163 lines
5.0 KiB
Python
163 lines
5.0 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from argus.ray.models import JobSpec, RayConfig
|
|
|
|
|
|
def test_job_details_to_dict_supports_multiple_shapes():
    """Exercise ``_job_details_to_dict`` against four differently shaped objects.

    The inputs are: an object with ``model_dump()``, an object with
    ``dict()``, an object carrying a plain instance attribute, and an object
    with empty ``__slots__`` for which the result must contain a ``"repr"``
    key.
    """
    from argus.ray.ray_job_tool import _job_details_to_dict as to_dict

    class WithModelDump:
        # Only offers model_dump().
        def model_dump(self):
            return {"a": 1}

    class WithDictMethod:
        # Only offers dict().
        def dict(self):
            return {"b": 2}

    class WithInstanceAttr:
        # No conversion method; just an attribute set on the instance.
        def __init__(self):
            self.c = 3

    class Slotted:
        # Empty __slots__: no conversion method and no instance attributes.
        __slots__ = ()

    via_model_dump = to_dict(WithModelDump())
    assert via_model_dump == {"a": 1}

    via_dict_method = to_dict(WithDictMethod())
    assert via_dict_method == {"b": 2}

    via_instance_attr = to_dict(WithInstanceAttr())
    assert via_instance_attr["c"] == 3

    via_fallback = to_dict(Slotted())
    assert "repr" in via_fallback
|
|
|
|
|
|
def test_runtime_env_sets_defaults_and_pythonpath(monkeypatch):
    """Check the ``env_vars`` produced by ``RayJobTool._runtime_env``.

    With ``shared_root`` set to ``/private`` and ``MVP_TOOL_CODE_PATH`` set to
    ``/tool``, the resulting environment must carry the expected default
    values and a ``PYTHONPATH`` that starts with the tool path, the job code
    path and the user code path, in that order.
    """
    from argus.ray.ray_job_tool import RayJobTool

    ray_settings = {
        "address": "http://127.0.0.1:8265",
        "shared_root": "/private",
        "entrypoint_resources": {"worker_node": 1},
        "runtime_env": {"env_vars": {"PYTHONPATH": "x"}},
        "user_code_path": "/private/user/code",
    }
    job_settings = {
        "workload": "sft",
        "code_path": "/c",
        "model_id": "m",
        "train_file": "t",
    }

    cfg = RayConfig.from_dict(ray_settings)
    spec = JobSpec.from_dict(job_settings)
    monkeypatch.setenv("MVP_TOOL_CODE_PATH", "/tool")

    runtime_env = RayJobTool(cfg)._runtime_env(spec)
    env_vars = runtime_env["env_vars"]

    assert env_vars["HF_HOME"].startswith("/private/")

    # Values that must match exactly.
    exact_values = (
        ("PYTHONUNBUFFERED", "1"),
        ("MVP_CODE_PATH", "/c"),
        ("RAY_ADDRESS", "auto"),
    )
    for name, expected in exact_values:
        assert env_vars[name] == expected

    assert env_vars["PYTHONPATH"].startswith("/tool:/c:/private/user/code:")
|
|
|
|
|
|
def test_submit_writes_artifacts_and_returns_submission_id(tmp_path: Path, monkeypatch):
    """Submit a job through a fake client and inspect the files left behind.

    ``JobSubmissionClient`` and ``build_training_argv`` are replaced with
    stand-ins, and ``ray.init`` is made to raise. ``submit`` must return the
    submission id from the spec, and the expected config and debug files must
    exist under ``<shared_root>/jobs/sid1``.
    """
    from argus.ray import ray_job_tool as mod

    class _FakeClient:
        # In-memory replacement for the job submission client.

        def __init__(self, address: str):
            self.last_submit_kwargs = None
            self.address = address

        def submit_job(self, **kwargs):
            # Keep a copy of the call arguments and echo the submission id.
            self.last_submit_kwargs = {**kwargs}
            return str(self.last_submit_kwargs["submission_id"])

        def list_jobs(self):
            class _Job:
                def dict(self):
                    return {"ok": True}

            return [_Job()]

        def get_job_status(self, submission_id: str):
            return "RUNNING"

        def stop_job(self, submission_id: str):
            return True

        def get_job_logs(self, submission_id: str):
            return "hello\nworld\n"

    def _fake_build_training_argv(spec, submission_id, job_dir):
        # Returns an object whose only attribute is a fixed argv list.
        class _Built:
            argv = ["python3", "-c", "print(1)"]

        return _Built()

    def _ray_init_unavailable(**kwargs):
        # Any attempt to initialise ray fails.
        raise RuntimeError("no ray")

    monkeypatch.setattr(mod, "JobSubmissionClient", _FakeClient)
    monkeypatch.setattr(mod, "build_training_argv", _fake_build_training_argv)
    monkeypatch.setattr(mod.ray, "init", _ray_init_unavailable)

    ray_settings = {
        "address": "http://127.0.0.1:8265",
        "shared_root": str(tmp_path),
        "entrypoint_resources": {"worker_node": 1},
        "runtime_env": {"env_vars": {}},
    }
    job_settings = {
        "workload": "ppo",
        "submission_id": "sid1",
        "code_path": "/code",
        "model_id": "m",
        "train_file": "train.jsonl",
    }
    cfg = RayConfig.from_dict(ray_settings)
    spec = JobSpec.from_dict(job_settings)

    result = mod.RayJobTool(cfg).submit(spec, no_wait=True)
    assert result == "sid1"

    job_root = tmp_path / "jobs" / "sid1"
    config_dir = job_root / "config"
    debug_dir = job_root / "debug"

    for config_name in ("ray_config.yaml", "jobspec.yaml"):
        assert (config_dir / config_name).exists()

    recorded_id = (config_dir / "ray_submission_id.txt").read_text(encoding="utf-8")
    assert recorded_id.strip() == "sid1"

    payload_text = (config_dir / "submit_payload.json").read_text(encoding="utf-8")
    payload = json.loads(payload_text)
    assert payload["submission_id"] == "sid1"
    assert "argus.ray.driver_entrypoint" in payload["entrypoint"]

    # ray.init was stubbed to raise, so the pre-submit error file is expected.
    for debug_name in ("ray_resources_pre.error.txt", "ray_job_list_post.json"):
        assert (debug_dir / debug_name).exists()
|
|
|
|
|
|
def test_submit_error_writes_file_then_reraises(tmp_path: Path, monkeypatch):
    """A failing ``submit_job`` must propagate and leave ``submit.error.txt``.

    The fake client raises ``RuntimeError("boom")`` on submission. The test
    expects that error to reach the caller and the file
    ``<shared_root>/jobs/sid2/logs/submit.error.txt`` to exist afterwards.
    """
    from argus.ray import ray_job_tool as mod

    class _FakeClient:
        # Client stand-in whose submission always fails.

        def __init__(self, address: str):
            self.address = address

        def submit_job(self, **kwargs):
            raise RuntimeError("boom")

        def list_jobs(self):
            return []

    def _fake_build_training_argv(spec, submission_id, job_dir):
        # Returns an object whose only attribute is a fixed argv list.
        class _Built:
            argv = ["python3", "-c", "print(1)"]

        return _Built()

    monkeypatch.setattr(mod, "JobSubmissionClient", _FakeClient)
    monkeypatch.setattr(mod, "build_training_argv", _fake_build_training_argv)

    ray_settings = {
        "address": "http://127.0.0.1:8265",
        "shared_root": str(tmp_path),
        "entrypoint_resources": {"worker_node": 1},
        "runtime_env": {"env_vars": {}},
    }
    job_settings = {
        "workload": "ppo",
        "submission_id": "sid2",
        "code_path": "/code",
        "model_id": "m",
        "train_file": "t",
    }
    cfg = RayConfig.from_dict(ray_settings)
    spec = JobSpec.from_dict(job_settings)

    tool = mod.RayJobTool(cfg)
    with pytest.raises(RuntimeError, match="boom"):
        tool.submit(spec, no_wait=True)

    error_file = tmp_path / "jobs" / "sid2" / "logs" / "submit.error.txt"
    assert error_file.exists()
|