argus-cluster/src/mvp/py/tests/test_ray_job_tool.py

163 lines
5.0 KiB
Python

from __future__ import annotations
import json
import os
from pathlib import Path
import pytest
from argus.ray.models import JobSpec, RayConfig
def test_job_details_to_dict_supports_multiple_shapes():
    """Feed `_job_details_to_dict` each object shape it must normalise.

    Four shapes are covered:

    * an object exposing ``model_dump()``;
    * one exposing ``dict()``;
    * a plain object carrying instance attributes;
    * a slotted object with no attributes at all, for which the result must
      contain a ``"repr"`` entry.
    """
    from argus.ray.ray_job_tool import _job_details_to_dict as to_dict

    class WithModelDump:
        def model_dump(self):
            return {"a": 1}

    class WithDictMethod:
        def dict(self):
            return {"b": 2}

    class WithInstanceAttrs:
        def __init__(self):
            self.c = 3

    class Opaque:
        # No __dict__ and no conversion method: forces the repr fallback.
        __slots__ = ()

    assert to_dict(WithModelDump()) == {"a": 1}
    assert to_dict(WithDictMethod()) == {"b": 2}
    assert to_dict(WithInstanceAttrs())["c"] == 3
    assert "repr" in to_dict(Opaque())
def test_runtime_env_sets_defaults_and_pythonpath(monkeypatch):
    """Check the ``env_vars`` that `RayJobTool._runtime_env` builds for a job.

    The expected results are:

    * ``HF_HOME`` lives under the configured shared root;
    * ``PYTHONUNBUFFERED``, ``MVP_CODE_PATH`` and ``RAY_ADDRESS`` get fixed
      values;
    * ``PYTHONPATH`` starts with the tool code path (taken from
      ``MVP_TOOL_CODE_PATH``), then the job's code path, then the configured
      user code path.
    """
    from argus.ray.ray_job_tool import RayJobTool

    ray_cfg = RayConfig.from_dict(
        {
            "address": "http://127.0.0.1:8265",
            "shared_root": "/private",
            "entrypoint_resources": {"worker_node": 1},
            "runtime_env": {"env_vars": {"PYTHONPATH": "x"}},
            "user_code_path": "/private/user/code",
        }
    )
    job_spec = JobSpec.from_dict(
        {"workload": "sft", "code_path": "/c", "model_id": "m", "train_file": "t"}
    )
    # Set before the tool is built so either construction or _runtime_env can see it.
    monkeypatch.setenv("MVP_TOOL_CODE_PATH", "/tool")

    env_vars = RayJobTool(ray_cfg)._runtime_env(job_spec)["env_vars"]

    assert env_vars["HF_HOME"].startswith("/private/")
    exact_values = (
        ("PYTHONUNBUFFERED", "1"),
        ("MVP_CODE_PATH", "/c"),
        ("RAY_ADDRESS", "auto"),
    )
    for var_name, expected in exact_values:
        assert env_vars[var_name] == expected
    assert env_vars["PYTHONPATH"].startswith("/tool:/c:/private/user/code:")
def test_submit_writes_artifacts_and_returns_submission_id(tmp_path: Path, monkeypatch):
    """Check a ``no_wait`` submit: it persists job artifacts and returns the id.

    Ray is faked end to end:

    * the job-submission client records the keyword arguments it receives;
    * the training-argv builder returns a trivial command;
    * ``ray.init`` raises, so the tool's pre-submit resource probe takes its
      error path and leaves ``ray_resources_pre.error.txt`` behind.
    """
    from types import SimpleNamespace

    from argus.ray import ray_job_tool as mod

    # Every fake client the tool constructs, so the test can inspect what
    # actually reached ``submit_job`` (the tool owns the instances).
    clients = []

    class _FakeClient:
        def __init__(self, address: str):
            self.address = address
            self.last_submit_kwargs = None
            clients.append(self)

        def submit_job(self, **kwargs):
            self.last_submit_kwargs = dict(kwargs)
            return str(kwargs["submission_id"])

        def list_jobs(self):
            class X:
                def dict(self):
                    return {"ok": True}

            return [X()]

        def get_job_status(self, submission_id: str):
            return "RUNNING"

        def stop_job(self, submission_id: str):
            return True

        def get_job_logs(self, submission_id: str):
            return "hello\nworld\n"

    def _fake_build_training_argv(spec, submission_id, job_dir):
        # Only ``.argv`` is provided by the stub.
        return SimpleNamespace(argv=["python3", "-c", "print(1)"])

    def _fail_ray_init(*args, **kwargs):
        # Simulate "no Ray runtime reachable" without touching a real cluster.
        raise RuntimeError("no ray")

    monkeypatch.setattr(mod, "JobSubmissionClient", _FakeClient)
    monkeypatch.setattr(mod, "build_training_argv", _fake_build_training_argv)
    monkeypatch.setattr(mod.ray, "init", _fail_ray_init)

    cfg = RayConfig.from_dict(
        {
            "address": "http://127.0.0.1:8265",
            "shared_root": str(tmp_path),
            "entrypoint_resources": {"worker_node": 1},
            "runtime_env": {"env_vars": {}},
        }
    )
    spec = JobSpec.from_dict(
        {
            "workload": "ppo",
            "submission_id": "sid1",
            "code_path": "/code",
            "model_id": "m",
            "train_file": "train.jsonl",
        }
    )
    tool = mod.RayJobTool(cfg)
    submitted = tool.submit(spec, no_wait=True)
    assert submitted == "sid1"

    # The fake client must actually have been asked to submit this job.
    submit_calls = [c.last_submit_kwargs for c in clients if c.last_submit_kwargs is not None]
    assert submit_calls, "JobSubmissionClient.submit_job was never called"
    assert submit_calls[-1]["submission_id"] == "sid1"
    assert "argus.ray.driver_entrypoint" in submit_calls[-1]["entrypoint"]

    job_root = tmp_path / "jobs" / "sid1"
    config_dir = job_root / "config"
    assert (config_dir / "ray_config.yaml").exists()
    assert (config_dir / "jobspec.yaml").exists()
    assert (config_dir / "ray_submission_id.txt").read_text(encoding="utf-8").strip() == "sid1"
    payload = json.loads((config_dir / "submit_payload.json").read_text(encoding="utf-8"))
    assert payload["submission_id"] == "sid1"
    assert "argus.ray.driver_entrypoint" in payload["entrypoint"]
    assert (job_root / "debug" / "ray_resources_pre.error.txt").exists()
    assert (job_root / "debug" / "ray_job_list_post.json").exists()
def test_submit_error_writes_file_then_reraises(tmp_path: Path, monkeypatch):
    """Check that a failing ``submit_job`` is recorded to disk and then re-raised.

    ``ray.init`` is stubbed to fail as well. ``submit()`` appears to call it
    for a pre-submit resource snapshot. Without the stub, this test would reach
    the real Ray runtime installed on the machine, or try to start one.
    """
    from types import SimpleNamespace

    from argus.ray import ray_job_tool as mod

    class _FakeClient:
        def __init__(self, address: str):
            self.address = address

        def submit_job(self, **kwargs):
            raise RuntimeError("boom")

        def list_jobs(self):
            return []

    def _fake_build_training_argv(spec, submission_id, job_dir):
        # Only ``.argv`` is provided by the stub.
        return SimpleNamespace(argv=["python3", "-c", "print(1)"])

    def _fail_ray_init(*args, **kwargs):
        # Keep the test hermetic: never contact or start a real Ray runtime.
        raise RuntimeError("no ray")

    monkeypatch.setattr(mod, "JobSubmissionClient", _FakeClient)
    monkeypatch.setattr(mod, "build_training_argv", _fake_build_training_argv)
    monkeypatch.setattr(mod.ray, "init", _fail_ray_init)

    cfg = RayConfig.from_dict(
        {
            "address": "http://127.0.0.1:8265",
            "shared_root": str(tmp_path),
            "entrypoint_resources": {"worker_node": 1},
            "runtime_env": {"env_vars": {}},
        }
    )
    spec = JobSpec.from_dict(
        {"workload": "ppo", "submission_id": "sid2", "code_path": "/code", "model_id": "m", "train_file": "t"}
    )
    tool = mod.RayJobTool(cfg)
    with pytest.raises(RuntimeError, match="boom"):
        tool.submit(spec, no_wait=True)

    err = tmp_path / "jobs" / "sid2" / "logs" / "submit.error.txt"
    assert err.exists()