from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
from types import SimpleNamespace
|
|
|
|
from argus.ray.models import RayConfig
|
|
from argus.service.config import V2Config
|
|
from argus.service.db import Db
|
|
|
|
|
|
def _mk_cfg(tmp_path: Path) -> tuple[RayConfig, V2Config]:
    """Build a minimal (RayConfig, V2Config) pair for scheduler tests.

    The sqlite db_path points into *tmp_path* so each test gets an isolated
    database; everything else is the smallest config the scheduler accepts.
    """
    ray_section = {
        "address": "http://127.0.0.1:8265",
        "shared_root": "/private",
        "entrypoint_resources": {"worker_node": 1},
        "runtime_env": {"env_vars": {}},
    }
    service_section = {
        "sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
        # 1-second tick/retry keeps the scheduler responsive inside tests.
        "scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
    }
    root = {"ray": ray_section, "service": service_section}
    return RayConfig.from_dict(root), V2Config.from_root_dict(root)


def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
    """A ready task is submitted exactly once and moves to SUBMITTED.

    Also checks the derived submission id (``t1--a01``) and that the job
    directory is placed under the submitting user's jobs folder.
    """
    from argus.service import scheduler as sched_mod

    ray_cfg, v2_cfg = _mk_cfg(tmp_path)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()
    jobspec = {
        "workload": "ppo",
        "code_path": "/private/common/code/verl",
        "model_id": "m",
        "train_file": "/private/common/datasets/t",
        "val_file": "/private/common/datasets/v",
    }
    db.create_task_v25(
        task_id="t1",
        user_id="alice",
        workload="ppo",
        jobspec_yaml=yaml.safe_dump(jobspec),
        nnodes=2,
        n_gpus_per_node=4,
    )

    # Cluster is reachable and has plenty of free GPUs.
    monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
    monkeypatch.setattr(
        sched_mod,
        "get_cluster_available",
        lambda: SimpleNamespace(total_available_gpus=999.0, total_available_npus=0.0),
    )

    class _FakeTool:
        """Stand-in for RayJobTool that records submissions instead of calling Ray."""

        def __init__(self, cfg):
            self.submitted = []
            self.job_dirs = []

        def submit(self, spec, no_wait: bool, job_dir: str | None = None):
            self.submitted.append(spec.submission_id)
            self.job_dirs.append(job_dir)
            return str(spec.submission_id)

        def status(self, submission_id: str):
            return "RUNNING"

        def logs(self, submission_id: str):
            return ""

    monkeypatch.setattr(sched_mod, "RayJobTool", _FakeTool)

    s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
    s.tick()

    row = db.get_task("t1")
    assert row and row["state"] == "SUBMITTED"
    attempts = db.list_attempts("t1")
    assert len(attempts) == 1
    assert attempts[0]["ray_submission_id"] == "t1--a01"
    assert s.tool.job_dirs[-1] == "/private/users/alice/jobs/t1--a01"


def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path):
    """With zero free accelerators the task is parked as PENDING_RESOURCES.

    The scheduler must also stamp ``next_run_at`` so the task is retried later.
    """
    from argus.service import scheduler as sched_mod

    ray_cfg, v2_cfg = _mk_cfg(tmp_path)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()
    jobspec = {
        "workload": "ppo",
        "code_path": "/private/common/code/verl",
        "model_id": "m",
        "train_file": "/private/common/datasets/t",
        "val_file": "/private/common/datasets/v",
    }
    db.create_task_v25(
        task_id="t1",
        user_id="alice",
        workload="ppo",
        jobspec_yaml=yaml.safe_dump(jobspec),
        nnodes=2,
        n_gpus_per_node=4,
    )

    # Cluster is reachable but reports no free GPUs or NPUs.
    monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
    monkeypatch.setattr(
        sched_mod,
        "get_cluster_available",
        lambda: SimpleNamespace(total_available_gpus=0.0, total_available_npus=0.0),
    )
    monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)

    s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
    s.tick()

    row = db.get_task("t1")
    assert row and row["state"] == "PENDING_RESOURCES"
    assert row["next_run_at"]


def test_sync_failed_insufficient_resources(monkeypatch, tmp_path: Path):
    """A FAILED attempt whose logs report a GPU shortage is re-queued.

    The scheduler should classify the failure as INSUFFICIENT_RESOURCES and
    move the task back to PENDING_RESOURCES rather than failing it terminally.
    """
    from argus.service import scheduler as sched_mod

    ray_cfg, v2_cfg = _mk_cfg(tmp_path)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()
    db.create_task_v25(
        task_id="t1",
        user_id="alice",
        workload="ppo",
        jobspec_yaml=yaml.safe_dump(
            {
                "workload": "ppo",
                "code_path": "/private/common/code/verl",
                "model_id": "m",
                "train_file": "/private/common/datasets/t",
                "val_file": "/private/common/datasets/v",
            }
        ),
        nnodes=2,
        n_gpus_per_node=4,
    )
    # Seed an in-flight attempt so tick() takes the status-sync path
    # instead of the submission path.
    db.create_attempt(task_id="t1", attempt_no=1, ray_submission_id="t1--a01")
    db.set_task_state(task_id="t1", state="RUNNING", latest_attempt_no=1)

    monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
    monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)

    s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)

    class _Tool:
        # Fake job tool: the attempt has FAILED, and its logs carry the
        # resource-shortage message the scheduler greps for.
        def status(self, sid: str):
            return "FAILED"

        def logs(self, sid: str):
            # Match the service's regex exactly:
            # it expects literal backslashes and repeats of 's'/'d' (because of double-escaping).
            return "Total available GPUs\\ss\\dd\\ssis less than total desired GPUs\\ss\\dd"

    s.tool = _Tool()
    s.tick()

    row = db.get_task("t1")
    assert row and row["state"] == "PENDING_RESOURCES"
    attempts = db.list_attempts("t1")
    assert attempts[-1]["failure_kind"] == "INSUFFICIENT_RESOURCES"


def test_sync_status_error_keeps_state(monkeypatch, tmp_path: Path):
    """If polling Ray for job status raises, the task state is left untouched."""
    from argus.service import scheduler as sched_mod

    ray_cfg, v2_cfg = _mk_cfg(tmp_path)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()
    jobspec = {
        "workload": "ppo",
        "code_path": "/private/common/code/verl",
        "model_id": "m",
        "train_file": "/private/common/datasets/t",
    }
    db.create_task_v25(
        task_id="t1",
        user_id="alice",
        workload="ppo",
        jobspec_yaml=yaml.safe_dump(jobspec),
        nnodes=2,
        n_gpus_per_node=4,
    )
    # Put the task into the status-sync path with one attempt in flight.
    db.create_attempt(task_id="t1", attempt_no=1, ray_submission_id="t1--a01")
    db.set_task_state(task_id="t1", state="RUNNING", latest_attempt_no=1)

    monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
    monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)

    s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)

    class _BrokenTool:
        """status() always raises, simulating an unreachable Ray job API."""

        def status(self, sid: str):
            raise RuntimeError("boom")

    s.tool = _BrokenTool()
    s.tick()
    row = db.get_task("t1")
    assert row and row["state"] == "RUNNING"


def test_run_forever_swallows_tick_exceptions(monkeypatch, tmp_path: Path):
    """run_forever keeps going (and returns normally) when tick() raises."""
    from argus.service import scheduler as sched_mod

    ray_cfg, v2_cfg = _mk_cfg(tmp_path)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()

    monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)
    s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)

    tick_calls = {"n": 0}

    def _exploding_tick():
        # Count the invocation, then blow up — run_forever must absorb this.
        tick_calls["n"] += 1
        raise RuntimeError("boom")

    monkeypatch.setattr(s, "tick", _exploding_tick)
    # Neutralize the inter-tick sleep so the test finishes instantly.
    monkeypatch.setattr(sched_mod.time, "sleep", lambda _: None)

    class _StopAfterOne:
        """Stop-event stub: not set on the first poll, set on every later one."""

        def __init__(self):
            self._polls = 0

        def is_set(self):
            self._polls += 1
            return self._polls > 1

    s.run_forever(_StopAfterOne())
    assert tick_calls["n"] == 1