from __future__ import annotations

import pytest


def test_require_missing_raises():
    from argus.ray.models import _require

    # Both an absent key and an empty-string value count as missing.
    with pytest.raises(ValueError, match="missing required field: x"):
        _require({}, "x")
    with pytest.raises(ValueError, match="missing required field: x"):
        _require({"x": ""}, "x")


def test_ray_config_from_dict_new_format_and_defaults():
    from argus.ray.models import RayConfig

    # New-format config nests settings under a top-level "ray" key.
    cfg = RayConfig.from_dict(
        {
            "ray": {
                "address": "http://127.0.0.1:8265",
                "shared_root": "/private",
                "entrypoint_resources": {"worker_node": 1},
                "runtime_env": {"env_vars": {"HF_ENDPOINT": "x"}},
            }
        }
    )
    assert cfg.address.endswith("8265")
    assert cfg.shared_root == "/private"
    # Unspecified fields fall back to defaults.
    assert cfg.entrypoint_num_cpus == 1.0
    assert cfg.entrypoint_resources["worker_node"] == 1.0
    assert cfg.runtime_env_env_vars["HF_ENDPOINT"] == "x"
    assert cfg.user_code_path == "/private/user/code"

    public = cfg.to_public_dict()
    assert public["runtime_env"]["env_vars"]["HF_ENDPOINT"] == "x"


def test_ray_config_from_dict_requires_mappings():
    from argus.ray.models import RayConfig

    with pytest.raises(ValueError, match="runtime_env\\.env_vars must be a mapping"):
        RayConfig.from_dict(
            {
                "address": "x",
                "shared_root": "/p",
                "entrypoint_resources": {},
                "runtime_env": {"env_vars": ["nope"]},
            }
        )
    with pytest.raises(ValueError, match="entrypoint_resources must be a mapping"):
        RayConfig.from_dict(
            {
                "address": "x",
                "shared_root": "/p",
                "entrypoint_resources": ["nope"],
            }
        )


def test_jobspec_validation_and_null_coercion():
    from argus.ray.models import JobSpec

    spec = JobSpec.from_dict(
        {
            "workload": "ppo",
            "code_path": "/code",
            "model_id": "m",
            "train_file": "train.jsonl",
            "val_file": "val.jsonl",
            "test_freq": "",
        }
    )
    assert spec.workload == "ppo"
    assert spec.val_file == "val.jsonl"
    # Empty strings for optional fields are coerced to None.
    assert spec.test_freq is None
    assert spec.total_training_steps is None
    assert spec.nnodes == 2
    assert spec.n_gpus_per_node == 4

    pub = spec.to_public_dict()
    assert pub["submission_id"] == ""
    # trainer_device is not set for PPO jobs (see the SFT test below).
    assert "trainer_device" not in pub


def test_jobspec_sft_adds_trainer_device_default():
    from argus.ray.models import JobSpec

    spec = JobSpec.from_dict(
        {
            "workload": "sft",
            "code_path": "/code",
            "model_id": "m",
            "train_file": "train.jsonl",
            "val_file": "val.jsonl",
        }
    )
    pub = spec.to_public_dict()
    assert pub["trainer_device"] == "cpu"


def test_jobspec_unsupported_workload():
    from argus.ray.models import JobSpec

    with pytest.raises(ValueError, match="unsupported workload"):
        JobSpec.from_dict(
            {"workload": "nope", "code_path": "x", "model_id": "m", "train_file": "t", "val_file": "v"}
        )