from __future__ import annotations import json import os from pathlib import Path import pytest from argus.ray.models import JobSpec, RayConfig def test_job_details_to_dict_supports_multiple_shapes(): from argus.ray.ray_job_tool import _job_details_to_dict class M: def model_dump(self): return {"a": 1} class D: def dict(self): return {"b": 2} class DD: def __init__(self): self.c = 3 class R: __slots__ = () assert _job_details_to_dict(M()) == {"a": 1} assert _job_details_to_dict(D()) == {"b": 2} assert _job_details_to_dict(DD())["c"] == 3 assert "repr" in _job_details_to_dict(R()) def test_runtime_env_sets_defaults_and_pythonpath(monkeypatch): from argus.ray.ray_job_tool import RayJobTool cfg = RayConfig.from_dict( { "address": "http://127.0.0.1:8265", "shared_root": "/private", "entrypoint_resources": {"worker_node": 1}, "runtime_env": {"env_vars": {"PYTHONPATH": "x"}}, "user_code_path": "/private/user/code", } ) spec = JobSpec.from_dict( {"workload": "sft", "code_path": "/c", "model_id": "m", "train_file": "t"} ) monkeypatch.setenv("MVP_TOOL_CODE_PATH", "/tool") tool = RayJobTool(cfg) env = tool._runtime_env(spec)["env_vars"] assert env["HF_HOME"].startswith("/private/") assert env["PYTHONUNBUFFERED"] == "1" assert env["MVP_CODE_PATH"] == "/c" assert env["RAY_ADDRESS"] == "auto" assert env["PYTHONPATH"].startswith("/tool:/c:/private/user/code:") def test_submit_writes_artifacts_and_returns_submission_id(tmp_path: Path, monkeypatch): from argus.ray import ray_job_tool as mod class _FakeClient: def __init__(self, address: str): self.address = address self.last_submit_kwargs = None def submit_job(self, **kwargs): self.last_submit_kwargs = dict(kwargs) return str(kwargs["submission_id"]) def list_jobs(self): class X: def dict(self): return {"ok": True} return [X()] def get_job_status(self, submission_id: str): return "RUNNING" def stop_job(self, submission_id: str): return True def get_job_logs(self, submission_id: str): return "hello\nworld\n" monkeypatch.setattr(mod, "JobSubmissionClient", _FakeClient) monkeypatch.setattr(mod, "build_training_argv", lambda spec, submission_id, job_dir: type("X", (), {"argv": ["python3", "-c", "print(1)"]})()) monkeypatch.setattr(mod.ray, "init", lambda **kwargs: (_ for _ in ()).throw(RuntimeError("no ray"))) cfg = RayConfig.from_dict( { "address": "http://127.0.0.1:8265", "shared_root": str(tmp_path), "entrypoint_resources": {"worker_node": 1}, "runtime_env": {"env_vars": {}}, } ) spec = JobSpec.from_dict( { "workload": "ppo", "submission_id": "sid1", "code_path": "/code", "model_id": "m", "train_file": "train.jsonl", } ) tool = mod.RayJobTool(cfg) submitted = tool.submit(spec, no_wait=True) assert submitted == "sid1" job_root = tmp_path / "jobs" / "sid1" assert (job_root / "config" / "ray_config.yaml").exists() assert (job_root / "config" / "jobspec.yaml").exists() assert (job_root / "config" / "ray_submission_id.txt").read_text(encoding="utf-8").strip() == "sid1" payload = json.loads((job_root / "config" / "submit_payload.json").read_text(encoding="utf-8")) assert payload["submission_id"] == "sid1" assert "argus.ray.driver_entrypoint" in payload["entrypoint"] assert (job_root / "debug" / "ray_resources_pre.error.txt").exists() assert (job_root / "debug" / "ray_job_list_post.json").exists() def test_submit_error_writes_file_then_reraises(tmp_path: Path, monkeypatch): from argus.ray import ray_job_tool as mod class _FakeClient: def __init__(self, address: str): self.address = address def submit_job(self, **kwargs): raise RuntimeError("boom") def list_jobs(self): return [] monkeypatch.setattr(mod, "JobSubmissionClient", _FakeClient) monkeypatch.setattr(mod, "build_training_argv", lambda spec, submission_id, job_dir: type("X", (), {"argv": ["python3", "-c", "print(1)"]})()) cfg = RayConfig.from_dict( { "address": "http://127.0.0.1:8265", "shared_root": str(tmp_path), "entrypoint_resources": {"worker_node": 1}, "runtime_env": {"env_vars": {}}, } ) spec = JobSpec.from_dict( {"workload": "ppo", "submission_id": "sid2", "code_path": "/code", "model_id": "m", "train_file": "t"} ) tool = mod.RayJobTool(cfg) with pytest.raises(RuntimeError, match="boom"): tool.submit(spec, no_wait=True) err = tmp_path / "jobs" / "sid2" / "logs" / "submit.error.txt" assert err.exists()