argus-cluster/src/mvp/py/tests/test_models.py
2025-12-31 15:16:42 +08:00

109 lines
3.1 KiB
Python

from __future__ import annotations
import pytest
def test_require_missing_raises():
    """Check that ``_require`` raises for an absent key and for an empty string.

    Each call must raise ``ValueError`` with a message matching
    ``missing required field: x``.
    """
    from argus.ray.models import _require

    expected_message = "missing required field: x"
    # First payload lacks the key entirely; second carries it with an
    # empty-string value. Both must be rejected with the same message.
    for payload in ({}, {"x": ""}):
        with pytest.raises(ValueError, match=expected_message):
            _require(payload, "x")
def test_ray_config_from_dict_new_format_and_defaults():
    """Build a ``RayConfig`` from a dict nested under ``"ray"`` and check its fields.

    Covers the values passed in explicitly, two fields that are not present
    in the input, and the ``runtime_env`` section of ``to_public_dict()``.
    """
    from argus.ray.models import RayConfig

    ray_section = {
        "address": "http://127.0.0.1:8265",
        "shared_root": "/private",
        "entrypoint_resources": {"worker_node": 1},
        "runtime_env": {"env_vars": {"HF_ENDPOINT": "x"}},
    }
    config = RayConfig.from_dict({"ray": ray_section})

    # Fields supplied explicitly in the input mapping.
    assert config.address.endswith("8265")
    assert config.shared_root == "/private"
    assert config.entrypoint_resources["worker_node"] == 1.0
    assert config.runtime_env_env_vars["HF_ENDPOINT"] == "x"

    # Fields that the input mapping does not mention at all.
    assert config.entrypoint_num_cpus == 1.0
    assert config.user_code_path == "/private/user/code"

    # The public view must still expose the env var that was passed in.
    exported = config.to_public_dict()
    assert exported["runtime_env"]["env_vars"]["HF_ENDPOINT"] == "x"
def test_ray_config_from_dict_requires_mappings():
    """Check that ``RayConfig.from_dict`` rejects lists where mappings are expected.

    A list for ``runtime_env.env_vars`` and a list for
    ``entrypoint_resources`` must each raise ``ValueError`` with the
    corresponding "must be a mapping" message.
    """
    from argus.ray.models import RayConfig

    # ``env_vars`` is a list here instead of a mapping.
    env_vars_as_list = {
        "address": "x",
        "shared_root": "/p",
        "entrypoint_resources": {},
        "runtime_env": {"env_vars": ["nope"]},
    }
    with pytest.raises(ValueError, match="runtime_env\\.env_vars must be a mapping"):
        RayConfig.from_dict(env_vars_as_list)

    # ``entrypoint_resources`` is a list here instead of a mapping.
    resources_as_list = {
        "address": "x",
        "shared_root": "/p",
        "entrypoint_resources": ["nope"],
    }
    with pytest.raises(ValueError, match="entrypoint_resources must be a mapping"):
        RayConfig.from_dict(resources_as_list)
def test_jobspec_validation_and_null_coercion():
    """Build a ``ppo`` ``JobSpec`` and check parsed fields and its public dict.

    The input passes ``test_freq`` as an empty string and omits
    ``total_training_steps``; both are expected to read back as ``None``.
    """
    from argus.ray.models import JobSpec

    raw_job = {
        "workload": "ppo",
        "code_path": "/code",
        "model_id": "m",
        "train_file": "train.jsonl",
        "val_file": "val.jsonl",
        "test_freq": "",
    }
    job = JobSpec.from_dict(raw_job)

    # Values that were given explicitly.
    assert job.workload == "ppo"
    assert job.val_file == "val.jsonl"

    # Empty-string input and an absent key both surface as None.
    assert job.test_freq is None
    assert job.total_training_steps is None

    # Sizing fields that the input does not mention.
    assert job.nnodes == 2
    assert job.n_gpus_per_node == 4

    exported = job.to_public_dict()
    assert exported["submission_id"] == ""
    assert "trainer_device" not in exported
def test_jobspec_sft_adds_trainer_device_default():
    """Check that an ``sft`` ``JobSpec`` reports ``trainer_device == "cpu"``.

    The input does not set ``trainer_device``; the value is read from
    ``to_public_dict()``.
    """
    from argus.ray.models import JobSpec

    raw_job = {
        "workload": "sft",
        "code_path": "/code",
        "model_id": "m",
        "train_file": "train.jsonl",
        "val_file": "val.jsonl",
    }
    exported = JobSpec.from_dict(raw_job).to_public_dict()

    assert exported["trainer_device"] == "cpu"
def test_jobspec_unsupported_workload():
    """Check that ``JobSpec.from_dict`` raises ``ValueError`` for workload ``"nope"``.

    The error message must match ``unsupported workload``.
    """
    from argus.ray.models import JobSpec

    # Every field is populated so that only the workload value is unusual.
    bad_job = {
        "workload": "nope",
        "code_path": "x",
        "model_id": "m",
        "train_file": "t",
        "val_file": "v",
    }
    with pytest.raises(ValueError, match="unsupported workload"):
        JobSpec.from_dict(bad_job)