"""Unit tests for the v2 scheduler: submission, resource gating, failure
classification, and the run_forever loop. All Ray interactions are stubbed
via monkeypatch so the tests run against a temporary SQLite database only."""

from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace

import yaml

from argus.ray.models import RayConfig
from argus.service.config import V2Config
from argus.service.db import Db

# Canonical PPO jobspec used by most tests; individual tests may pass a
# reduced variant to _create_task.
_BASE_JOBSPEC = {
    "workload": "ppo",
    "code_path": "/private/common/code/verl",
    "model_id": "m",
    "train_file": "/private/common/datasets/t",
    "val_file": "/private/common/datasets/v",
}


def _mk_cfg(tmp_path: Path) -> tuple[RayConfig, V2Config]:
    """Build a minimal (RayConfig, V2Config) pair rooted in *tmp_path*.

    The SQLite file lives under tmp_path so each test gets a fresh DB.
    """
    root = {
        "ray": {
            "address": "http://127.0.0.1:8265",
            "shared_root": "/private",
            "entrypoint_resources": {"worker_node": 1},
            "runtime_env": {"env_vars": {}},
        },
        "service": {
            "sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
            "scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
        },
    }
    return RayConfig.from_dict(root), V2Config.from_root_dict(root)


def _create_task(db: Db, jobspec: dict | None = None) -> None:
    """Seed task "t1" for user "alice" (2 nodes x 4 GPUs) into *db*.

    *jobspec* defaults to _BASE_JOBSPEC; pass a custom dict to seed a
    variant (e.g. one without val_file).
    """
    db.create_task_v25(
        task_id="t1",
        user_id="alice",
        workload="ppo",
        jobspec_yaml=yaml.safe_dump(jobspec if jobspec is not None else _BASE_JOBSPEC),
        nnodes=2,
        n_gpus_per_node=4,
    )


def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
    """With ample GPUs, one tick submits the queued task and records the attempt."""
    from argus.service import scheduler as sched_mod

    ray_cfg, v2_cfg = _mk_cfg(tmp_path)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()
    _create_task(db)

    monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
    monkeypatch.setattr(
        sched_mod,
        "get_cluster_available",
        lambda: SimpleNamespace(total_available_gpus=999.0, total_available_npus=0.0),
    )

    class _Tool:
        """Fake RayJobTool capturing submit() calls."""

        def __init__(self, cfg):
            self.submitted = []
            self.job_dirs = []

        def submit(self, spec, no_wait: bool, job_dir: str | None = None):
            self.submitted.append(spec.submission_id)
            self.job_dirs.append(job_dir)
            return str(spec.submission_id)

        def status(self, submission_id: str):
            return "RUNNING"

        def logs(self, submission_id: str):
            return ""

    monkeypatch.setattr(sched_mod, "RayJobTool", _Tool)

    s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
    s.tick()

    row = db.get_task("t1")
    assert row and row["state"] == "SUBMITTED"
    attempts = db.list_attempts("t1")
    assert len(attempts) == 1
    assert attempts[0]["ray_submission_id"] == "t1--a01"
    # Job dir is derived from user_id + submission_id under shared_root.
    assert s.tool.job_dirs[-1] == "/private/users/alice/jobs/t1--a01"


def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path):
    """With zero free GPUs, a tick parks the task as PENDING_RESOURCES."""
    from argus.service import scheduler as sched_mod

    ray_cfg, v2_cfg = _mk_cfg(tmp_path)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()
    _create_task(db)

    monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
    monkeypatch.setattr(
        sched_mod,
        "get_cluster_available",
        lambda: SimpleNamespace(total_available_gpus=0.0, total_available_npus=0.0),
    )
    monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)

    s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
    s.tick()

    row = db.get_task("t1")
    assert row and row["state"] == "PENDING_RESOURCES"
    # A retry time must be scheduled so the task is re-examined later.
    assert row["next_run_at"]


def test_sync_failed_insufficient_resources(monkeypatch, tmp_path: Path):
    """A FAILED attempt whose logs match the insufficient-GPU pattern is
    reclassified as INSUFFICIENT_RESOURCES and the task re-queued."""
    from argus.service import scheduler as sched_mod

    ray_cfg, v2_cfg = _mk_cfg(tmp_path)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()
    _create_task(db)
    db.create_attempt(task_id="t1", attempt_no=1, ray_submission_id="t1--a01")
    db.set_task_state(task_id="t1", state="RUNNING", latest_attempt_no=1)

    monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
    monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)
    s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)

    class _Tool:
        def status(self, sid: str):
            return "FAILED"

        def logs(self, sid: str):
            # Match the service's regex exactly: it expects literal
            # backslashes and repeats of 's'/'d' (because of double-escaping).
            return "Total available GPUs\\ss\\dd\\ssis less than total desired GPUs\\ss\\dd"

    s.tool = _Tool()
    s.tick()

    row = db.get_task("t1")
    assert row and row["state"] == "PENDING_RESOURCES"
    attempts = db.list_attempts("t1")
    assert attempts[-1]["failure_kind"] == "INSUFFICIENT_RESOURCES"


def test_sync_status_error_keeps_state(monkeypatch, tmp_path: Path):
    """If the status poll raises, the tick must not change task state."""
    from argus.service import scheduler as sched_mod

    ray_cfg, v2_cfg = _mk_cfg(tmp_path)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()
    # Variant jobspec: no val_file (matches the original seeded data).
    _create_task(db, {k: v for k, v in _BASE_JOBSPEC.items() if k != "val_file"})
    db.create_attempt(task_id="t1", attempt_no=1, ray_submission_id="t1--a01")
    db.set_task_state(task_id="t1", state="RUNNING", latest_attempt_no=1)

    monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
    monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)
    s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)

    class _Tool:
        def status(self, sid: str):
            raise RuntimeError("boom")

    s.tool = _Tool()
    s.tick()

    row = db.get_task("t1")
    assert row and row["state"] == "RUNNING"


def test_run_forever_swallows_tick_exceptions(monkeypatch, tmp_path: Path):
    """run_forever must survive a tick() that raises, and keep looping
    until the stop event is set."""
    from argus.service import scheduler as sched_mod

    ray_cfg, v2_cfg = _mk_cfg(tmp_path)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()
    monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)
    s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)

    calls = {"n": 0}

    def _tick():
        calls["n"] += 1
        raise RuntimeError("boom")

    monkeypatch.setattr(s, "tick", _tick)
    # Avoid real sleeping between loop iterations.
    monkeypatch.setattr(sched_mod.time, "sleep", lambda _: None)

    class _Stop:
        """Stop-event stand-in: reports set after the first is_set() call."""

        def __init__(self):
            self._n = 0

        def is_set(self):
            self._n += 1
            return self._n > 1

    s.run_forever(_Stop())
    assert calls["n"] == 1