"""Unit tests for ``argus.ray.models``."""

from __future__ import annotations

import pytest


def test_require_missing_raises():
    """``_require`` rejects a key that is absent or maps to an empty string."""
    from argus.ray.models import _require

    with pytest.raises(ValueError, match="missing required field: x"):
        _require({}, "x")
    # An empty-string value is expected to be rejected just like a missing key.
    with pytest.raises(ValueError, match="missing required field: x"):
        _require({"x": ""}, "x")


def test_ray_config_from_dict_new_format_and_defaults():
    """Parse the nested ``{"ray": {...}}`` form and check derived/default fields."""
    from argus.ray.models import RayConfig

    cfg = RayConfig.from_dict(
        {
            "ray": {
                "address": "http://127.0.0.1:8265",
                "shared_root": "/private",
                "entrypoint_resources": {"worker_node": 1},
                "runtime_env": {"env_vars": {"HF_ENDPOINT": "x"}},
            }
        }
    )
    assert cfg.address.endswith("8265")
    assert cfg.shared_root == "/private"
    # entrypoint_num_cpus is not supplied above, so this pins its default.
    assert cfg.entrypoint_num_cpus == 1.0
    assert cfg.entrypoint_resources["worker_node"] == 1.0
    # runtime_env.env_vars is exposed on the config as runtime_env_env_vars.
    assert cfg.runtime_env_env_vars["HF_ENDPOINT"] == "x"
    # user_code_path is not supplied; the expected value lives under shared_root.
    assert cfg.user_code_path == "/private/user/code"

    # The public dict re-nests the env vars under runtime_env.env_vars.
    public = cfg.to_public_dict()
    assert public["runtime_env"]["env_vars"]["HF_ENDPOINT"] == "x"


def test_ray_config_from_dict_requires_mappings():
    """Non-mapping ``runtime_env.env_vars`` / ``entrypoint_resources`` raise ValueError."""
    from argus.ray.models import RayConfig

    # These configs are passed at the top level, without the ``ray`` wrapper.
    with pytest.raises(ValueError, match=r"runtime_env\.env_vars must be a mapping"):
        RayConfig.from_dict(
            {
                "address": "x",
                "shared_root": "/p",
                "entrypoint_resources": {},
                "runtime_env": {"env_vars": ["nope"]},
            }
        )

    with pytest.raises(ValueError, match="entrypoint_resources must be a mapping"):
        RayConfig.from_dict(
            {
                "address": "x",
                "shared_root": "/p",
                "entrypoint_resources": ["nope"],
            }
        )


def test_jobspec_validation_and_null_coercion():
    """A ``ppo`` JobSpec parses, coerces empty strings to None and applies defaults."""
    from argus.ray.models import JobSpec

    spec = JobSpec.from_dict(
        {
            "workload": "ppo",
            "code_path": "/code",
            "model_id": "m",
            "train_file": "train.jsonl",
            "val_file": "val.jsonl",
            "test_freq": "",
        }
    )
    assert spec.workload == "ppo"
    assert spec.val_file == "val.jsonl"
    # An explicit empty string is coerced to None.
    assert spec.test_freq is None
    # Fields not supplied at all: total_training_steps is None, node/GPU counts default.
    assert spec.total_training_steps is None
    assert spec.nnodes == 2
    assert spec.n_gpus_per_node == 4

    pub = spec.to_public_dict()
    assert pub["submission_id"] == ""
    # trainer_device is not emitted for the ppo workload.
    assert "trainer_device" not in pub


def test_jobspec_sft_adds_trainer_device_default():
    """An ``sft`` JobSpec's public dict includes ``trainer_device`` defaulting to "cpu"."""
    from argus.ray.models import JobSpec

    spec = JobSpec.from_dict(
        {
            "workload": "sft",
            "code_path": "/code",
            "model_id": "m",
            "train_file": "train.jsonl",
            "val_file": "val.jsonl",
        }
    )
    pub = spec.to_public_dict()
    assert pub["trainer_device"] == "cpu"


def test_jobspec_unsupported_workload():
    """An unknown ``workload`` value is rejected with ValueError."""
    from argus.ray.models import JobSpec

    with pytest.raises(ValueError, match="unsupported workload"):
        JobSpec.from_dict(
            {
                "workload": "nope",
                "code_path": "x",
                "model_id": "m",
                "train_file": "t",
                "val_file": "v",
            }
        )