157 lines
4.7 KiB
Python
157 lines
4.7 KiB
Python
from __future__ import annotations
|
|
|
|
import pytest
|
|
|
|
|
|
def test_require_missing_raises():
    """Check that `_require` raises ValueError for an absent or empty-string field."""
    from argus.ray.models import _require

    expected_message = "missing required field: x"

    # Key is not present in the mapping at all.
    with pytest.raises(ValueError, match=expected_message):
        _require({}, "x")

    # Key is present but its value is the empty string.
    with pytest.raises(ValueError, match=expected_message):
        _require({"x": ""}, "x")
|
|
|
|
|
|
def test_ray_config_from_dict_new_format_and_defaults():
    """Parse a config nested under the "ray" key and check parsed and derived values."""
    from argus.ray.models import RayConfig

    ray_section = {
        "address": "http://127.0.0.1:8265",
        "shared_root": "/private",
        "entrypoint_resources": {"worker_node": 1},
        "runtime_env": {"env_vars": {"HF_ENDPOINT": "x"}},
    }
    config = RayConfig.from_dict({"ray": ray_section})

    # Values taken directly from the input mapping.
    assert config.address.endswith("8265")
    assert config.shared_root == "/private"

    # entrypoint_num_cpus was not supplied, so this checks the fallback value.
    assert config.entrypoint_num_cpus == 1.0
    assert config.entrypoint_resources["worker_node"] == 1.0
    assert config.runtime_env_env_vars["HF_ENDPOINT"] == "x"

    # Code paths were not supplied; they are expected under shared_root.
    assert config.user_code_path == "/private/user/code"
    assert config.verl_code_path == "/private/common/code/verl/verl_repo"

    # The public representation keeps the nested runtime_env layout.
    exported = config.to_public_dict()
    assert exported["runtime_env"]["env_vars"]["HF_ENDPOINT"] == "x"
|
|
|
|
|
|
def test_ray_config_from_dict_requires_mappings():
    """Check that non-mapping `runtime_env.env_vars` / `entrypoint_resources` are rejected."""
    from argus.ray.models import RayConfig

    # A list where runtime_env.env_vars should be a mapping.
    bad_env_vars = {
        "address": "x",
        "shared_root": "/p",
        "entrypoint_resources": {},
        "runtime_env": {"env_vars": ["nope"]},
    }
    with pytest.raises(ValueError, match="runtime_env\\.env_vars must be a mapping"):
        RayConfig.from_dict(bad_env_vars)

    # A list where entrypoint_resources should be a mapping.
    bad_resources = {
        "address": "x",
        "shared_root": "/p",
        "entrypoint_resources": ["nope"],
    }
    with pytest.raises(ValueError, match="entrypoint_resources must be a mapping"):
        RayConfig.from_dict(bad_resources)
|
|
|
|
|
|
def test_jobspec_validation_and_null_coercion():
    """Build a PPO JobSpec and check defaults and empty-string-to-None coercion."""
    from argus.ray.models import JobSpec

    payload = {
        "workload": "ppo",
        "code_path": "/code",
        "model_id": "m",
        "train_file": "train.jsonl",
        "val_file": "val.jsonl",
        # Empty string is expected to come back as None.
        "test_freq": "",
    }
    job = JobSpec.from_dict(payload)

    assert job.workload == "ppo"
    assert job.val_file == "val.jsonl"
    assert job.test_freq is None
    # Not supplied in the payload at all.
    assert job.total_training_steps is None
    # Topology fields were omitted, so these check the fallback values.
    assert job.nnodes == 2
    assert job.n_gpus_per_node == 4

    exported = job.to_public_dict()
    assert exported["submission_id"] == ""
    # trainer_device must not be exported for a PPO workload.
    assert "trainer_device" not in exported
|
|
|
|
|
|
def test_jobspec_sft_adds_trainer_device_default():
    """Check that an SFT JobSpec exports `trainer_device` as "cpu" when not supplied."""
    from argus.ray.models import JobSpec

    payload = {
        "workload": "sft",
        "code_path": "/code",
        "model_id": "m",
        "train_file": "train.jsonl",
        "val_file": "val.jsonl",
    }
    exported = JobSpec.from_dict(payload).to_public_dict()

    assert exported["trainer_device"] == "cpu"
|
|
|
|
|
|
def test_jobspec_unsupported_workload():
    """Check that an unrecognised workload name raises ValueError."""
    from argus.ray.models import JobSpec

    payload = {
        "workload": "nope",
        "code_path": "x",
        "model_id": "m",
        "train_file": "t",
        "val_file": "v",
    }
    with pytest.raises(ValueError, match="unsupported workload"):
        JobSpec.from_dict(payload)
|
|
|
|
|
|
def test_parse_taskspec_basic_no_kind_compat():
    """Check that a payload without a "kind" key is parsed into a JobSpec."""
    from argus.ray.models import JobSpec, parse_taskspec

    payload = {
        "workload": "ppo",
        "code_path": "/code",
        "model_id": "m",
        "train_file": "train.jsonl",
        "val_file": "val.jsonl",
    }
    parsed = parse_taskspec(payload)

    assert isinstance(parsed, JobSpec)
    assert parsed.workload == "ppo"
|
|
|
|
|
|
def test_parse_taskspec_advanced_smoke():
    """Check that a `kind: advanced` payload is parsed into an AdvancedTaskSpec."""
    from argus.ray.models import AdvancedTaskSpec, parse_taskspec

    payload = {
        "kind": "advanced",
        "nnodes": 2,
        "n_gpus_per_node": 4,
        "command": "python3 -m verl.trainer.main_ppo +ray_kwargs.ray_init.address=auto",
    }
    parsed = parse_taskspec(payload)

    assert isinstance(parsed, AdvancedTaskSpec)
    assert parsed.kind == "advanced"
    # The advanced spec also reports "advanced" as its workload.
    assert parsed.workload == "advanced"
    assert parsed.nnodes == 2
    assert parsed.n_gpus_per_node == 4
|
|
|
|
|
|
def test_parse_taskspec_advanced_requires_command_nnodes_gpus():
    """Check that an advanced payload missing command, nnodes or n_gpus_per_node is rejected."""
    from argus.ray.models import parse_taskspec

    command = "python3 -m verl.trainer.main_ppo"

    # Each payload omits exactly one of the three fields under test.
    without_command = {"kind": "advanced", "nnodes": 1, "n_gpus_per_node": 1}
    with pytest.raises(ValueError, match="missing required field: command"):
        parse_taskspec(without_command)

    without_nnodes = {"kind": "advanced", "n_gpus_per_node": 1, "command": command}
    with pytest.raises(ValueError, match="missing required field: nnodes"):
        parse_taskspec(without_nnodes)

    without_gpus = {"kind": "advanced", "nnodes": 1, "command": command}
    with pytest.raises(ValueError, match="missing required field: n_gpus_per_node"):
        parse_taskspec(without_gpus)
|