# File: argus-cluster/src/mvp/py/tests/test_scheduler.py
# (source listing metadata: 204 lines, 6.1 KiB, Python)

from __future__ import annotations
from pathlib import Path
import yaml
from types import SimpleNamespace
from argus.ray.models import RayConfig
from argus.service.config import V2Config
from argus.service.db import Db
def _mk_cfg(tmp_path: Path) -> tuple[RayConfig, V2Config]:
    """Build matching Ray and service configs rooted at ``tmp_path``.

    Both configs are parsed from one shared root dict, so the ``ray`` section
    (address, shared root, entrypoint resources, runtime env) and the
    ``service`` section (SQLite file, scheduler knobs) always agree. The
    scheduler is limited to one running task, and its ``*_s`` intervals
    (presumably seconds) are all 1.

    Returns:
        ``(ray_cfg, v2_cfg)`` parsed via ``RayConfig.from_dict`` and
        ``V2Config.from_root_dict`` respectively.
    """
    ray_section = {
        "address": "http://127.0.0.1:8265",
        "shared_root": str(tmp_path),
        "entrypoint_resources": {"worker_node": 1},
        "runtime_env": {"env_vars": {}},
    }
    service_section = {
        "sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
        "scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
    }
    root = {"ray": ray_section, "service": service_section}
    ray_cfg = RayConfig.from_dict(root)
    v2_cfg = V2Config.from_root_dict(root)
    return ray_cfg, v2_cfg
def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
    """A queued task with ample cluster capacity is submitted on the first tick.

    Checks that the task moves to SUBMITTED and that exactly one attempt is
    recorded with submission id ``t1--a01``. It also checks that the fake Ray
    job tool actually received exactly one ``submit()`` call with that id.
    """
    from argus.service import scheduler as sched_mod

    ray_cfg, v2_cfg = _mk_cfg(tmp_path)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()
    jobspec = {"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}
    db.create_task(
        task_id="t1",
        workload="ppo",
        jobspec_yaml=yaml.safe_dump(jobspec),
        nnodes=2,
        n_gpus_per_node=4,
    )

    monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
    # Report far more free GPUs than the task asks for (2 nodes x 4 GPUs) so
    # the scheduler does not park it as PENDING_RESOURCES.
    monkeypatch.setattr(
        sched_mod,
        "get_cluster_available",
        lambda: SimpleNamespace(total_available_gpus=999.0, total_available_npus=0.0),
    )

    # Shared by every _Tool instance, so the check below holds regardless of
    # which instance the scheduler ends up calling submit() on.
    submitted: list[str] = []

    class _Tool:
        def __init__(self, cfg):
            self.submitted = submitted

        def submit(self, spec, no_wait: bool):
            self.submitted.append(spec.submission_id)
            return str(spec.submission_id)

        def status(self, submission_id: str):
            return "RUNNING"

        def logs(self, submission_id: str):
            return ""

    monkeypatch.setattr(sched_mod, "RayJobTool", _Tool)
    s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
    s.tick()

    row = db.get_task("t1")
    assert row and row["state"] == "SUBMITTED"
    attempts = db.list_attempts("t1")
    assert len(attempts) == 1
    assert attempts[0]["ray_submission_id"] == "t1--a01"
    # Previously the recorded submissions were never inspected: make sure the
    # job was really handed to the tool, exactly once, under the same id.
    assert [str(sid) for sid in submitted] == ["t1--a01"]
def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path):
    """With zero free GPUs, tick() parks the task as PENDING_RESOURCES.

    The task must also carry a non-empty ``next_run_at`` for the later retry.
    """
    from argus.service import scheduler as scheduler_module

    ray_config, service_config = _mk_cfg(tmp_path)
    database = Db(service_config.sqlite.db_path)
    database.init()
    jobspec = {"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}
    database.create_task(
        task_id="t1",
        workload="ppo",
        jobspec_yaml=yaml.safe_dump(jobspec),
        nnodes=2,
        n_gpus_per_node=4,
    )

    def _empty_cluster():
        return SimpleNamespace(total_available_gpus=0.0, total_available_npus=0.0)

    monkeypatch.setattr(scheduler_module, "ensure_ray_connected", lambda: None)
    monkeypatch.setattr(scheduler_module, "get_cluster_available", _empty_cluster)
    # The job tool is stubbed out as None: this path is not expected to submit.
    monkeypatch.setattr(scheduler_module, "RayJobTool", lambda cfg: None)

    scheduler = scheduler_module.Scheduler(
        db=database, ray_cfg=ray_config, v2_cfg=service_config
    )
    scheduler.tick()

    task = database.get_task("t1")
    assert task and task["state"] == "PENDING_RESOURCES"
    assert task["next_run_at"]
def test_sync_failed_insufficient_resources(monkeypatch, tmp_path: Path):
    """A FAILED job whose logs report too few GPUs is re-queued, not failed.

    Expects the latest attempt to be classified as INSUFFICIENT_RESOURCES and
    the task to return to PENDING_RESOURCES.
    """
    from argus.service import scheduler as scheduler_module

    ray_config, service_config = _mk_cfg(tmp_path)
    database = Db(service_config.sqlite.db_path)
    database.init()
    jobspec = {"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}
    database.create_task(
        task_id="t1",
        workload="ppo",
        jobspec_yaml=yaml.safe_dump(jobspec),
        nnodes=2,
        n_gpus_per_node=4,
    )
    # Pretend attempt 1 was already submitted and the task is RUNNING, so tick()
    # has an in-flight attempt whose status it must sync.
    database.create_attempt(task_id="t1", attempt_no=1, ray_submission_id="t1--a01")
    database.set_task_state(task_id="t1", state="RUNNING", latest_attempt_no=1)

    monkeypatch.setattr(scheduler_module, "ensure_ray_connected", lambda: None)
    monkeypatch.setattr(scheduler_module, "RayJobTool", lambda cfg: None)
    scheduler = scheduler_module.Scheduler(
        db=database, ray_cfg=ray_config, v2_cfg=service_config
    )

    class _FailedJobTool:
        """Fake tool: the job FAILED and its logs carry a GPU-shortage message."""

        def status(self, sid: str):
            return "FAILED"

        def logs(self, sid: str):
            # Shaped to match the service's regex exactly: that regex expects
            # literal backslashes and doubled 's'/'d' characters (a
            # double-escaping artifact), so plain spaces/digits are not used.
            # NOTE(review): a genuine Ray message with ordinary whitespace and
            # numbers would likely NOT match that pattern; consider fixing the
            # service regex and then switching this fixture to a realistic line.
            return "Total available GPUs\\ss\\dd\\ssis less than total desired GPUs\\ss\\dd"

    scheduler.tool = _FailedJobTool()
    scheduler.tick()

    task = database.get_task("t1")
    assert task and task["state"] == "PENDING_RESOURCES"
    latest_attempt = database.list_attempts("t1")[-1]
    assert latest_attempt["failure_kind"] == "INSUFFICIENT_RESOURCES"
def test_sync_status_error_keeps_state(monkeypatch, tmp_path: Path):
    """If querying a job's status raises, tick() leaves the task state alone.

    tick() is called outside any ``pytest.raises``, so it also must not let
    the status error propagate.
    """
    from argus.service import scheduler as scheduler_module

    ray_config, service_config = _mk_cfg(tmp_path)
    database = Db(service_config.sqlite.db_path)
    database.init()
    jobspec = {"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}
    database.create_task(
        task_id="t1",
        workload="ppo",
        jobspec_yaml=yaml.safe_dump(jobspec),
        nnodes=2,
        n_gpus_per_node=4,
    )
    # An in-flight first attempt with the task marked RUNNING.
    database.create_attempt(task_id="t1", attempt_no=1, ray_submission_id="t1--a01")
    database.set_task_state(task_id="t1", state="RUNNING", latest_attempt_no=1)

    monkeypatch.setattr(scheduler_module, "ensure_ray_connected", lambda: None)
    monkeypatch.setattr(scheduler_module, "RayJobTool", lambda cfg: None)
    scheduler = scheduler_module.Scheduler(
        db=database, ray_cfg=ray_config, v2_cfg=service_config
    )

    class _BrokenStatusTool:
        """Fake tool whose status lookup always blows up."""

        def status(self, sid: str):
            raise RuntimeError("boom")

    scheduler.tool = _BrokenStatusTool()
    scheduler.tick()

    task = database.get_task("t1")
    assert task and task["state"] == "RUNNING"
def test_run_forever_swallows_tick_exceptions(monkeypatch, tmp_path: Path):
    """An exception raised by tick() must not escape run_forever().

    The stop stand-in allows a single pass through the loop, so tick() should
    run exactly once and its RuntimeError must be contained.
    """
    from argus.service import scheduler as scheduler_module

    ray_config, service_config = _mk_cfg(tmp_path)
    database = Db(service_config.sqlite.db_path)
    database.init()
    monkeypatch.setattr(scheduler_module, "RayJobTool", lambda cfg: None)
    scheduler = scheduler_module.Scheduler(
        db=database, ray_cfg=ray_config, v2_cfg=service_config
    )

    tick_calls = []

    def _exploding_tick():
        tick_calls.append(None)
        raise RuntimeError("boom")

    monkeypatch.setattr(scheduler, "tick", _exploding_tick)
    # Avoid real waiting between iterations. scheduler_module.time is presumably
    # the stdlib time module, so this patch is global; monkeypatch undoes it.
    monkeypatch.setattr(scheduler_module.time, "sleep", lambda _: None)

    class _StopAfterFirstCheck:
        """Stop-flag stand-in: reports "not set" once, then "set" from then on."""

        def __init__(self):
            self._checks = 0

        def is_set(self):
            self._checks += 1
            return self._checks >= 2

    scheduler.run_forever(_StopAfterFirstCheck())
    assert len(tick_calls) == 1