from __future__ import annotations

from dataclasses import dataclass
from typing import Any


def _require(d: dict[str, Any], key: str) -> Any:
    """Return ``d[key]``, treating ``None`` and ``""`` the same as absent.

    Raises:
        ValueError: if *key* is not in *d*, or maps to ``None`` or ``""``.
    """
    if key not in d or d[key] in (None, ""):
        raise ValueError(f"missing required field: {key}")
    return d[key]


def _get_or_default(d: dict[str, Any], key: str, default: Any) -> Any:
    """Return ``d[key]``, or *default* when the key is absent, ``None`` or ``""``.

    Uses the same notion of "unset" as ``_require`` so that an explicit null
    in the input falls back to the default instead of reaching ``int()`` /
    ``float()`` / ``str()`` and either raising ``TypeError`` or producing the
    literal string ``"None"``.
    """
    value = d.get(key)
    if value is None or value == "":
        return default
    return value


@dataclass(frozen=True)
class RayConfig:
    """Frozen record of the cluster-side settings parsed from a plain mapping.

    NOTE(review): the ``entrypoint_*`` / ``runtime_env`` names presumably
    mirror Ray job-submission options — confirm against the caller.
    """

    address: str
    shared_root: str
    entrypoint_num_cpus: float
    entrypoint_resources: dict[str, float]
    runtime_env_env_vars: dict[str, str]
    user_code_path: str

    @staticmethod
    def from_dict(d: dict[str, Any]) -> "RayConfig":
        """Build a ``RayConfig`` from a mapping, coercing every field's type.

        Required keys: ``address``, ``shared_root``.
        Optional keys (unset, ``None`` or ``""`` select the default):
        ``entrypoint_num_cpus`` (1), ``entrypoint_resources`` ({}),
        ``runtime_env.env_vars`` ({}), ``user_code_path``
        (``"<shared_root>/user/code"``).

        Raises:
            ValueError: a required key is missing/empty, ``entrypoint_resources``
                or ``runtime_env.env_vars`` is not a mapping, or a numeric
                field cannot be parsed.
        """
        runtime_env = d.get("runtime_env") or {}
        # A non-mapping ``runtime_env`` is tolerated and treated as empty;
        # only a malformed ``env_vars`` inside it is rejected.
        if isinstance(runtime_env, dict):
            env_vars = runtime_env.get("env_vars") or {}
        else:
            env_vars = {}
        if not isinstance(env_vars, dict):
            raise ValueError("runtime_env.env_vars must be a mapping")

        entrypoint_resources = d.get("entrypoint_resources") or {}
        if not isinstance(entrypoint_resources, dict):
            raise ValueError("entrypoint_resources must be a mapping")

        address = str(_require(d, "address"))
        shared_root = str(_require(d, "shared_root"))
        return RayConfig(
            address=address,
            shared_root=shared_root,
            entrypoint_num_cpus=float(_get_or_default(d, "entrypoint_num_cpus", 1)),
            entrypoint_resources={
                str(k): float(v) for k, v in entrypoint_resources.items()
            },
            runtime_env_env_vars={str(k): str(v) for k, v in env_vars.items()},
            # An explicit null/empty path must fall back to the derived
            # default rather than become the literal string "None" or "".
            user_code_path=str(
                _get_or_default(d, "user_code_path", f"{shared_root}/user/code")
            ),
        )

    def to_public_dict(self) -> dict[str, Any]:
        """Return a plain-dict view using the same key layout ``from_dict`` reads.

        The nested dicts are copies, so mutating the result cannot alter this
        frozen instance.
        """
        return {
            "address": self.address,
            "shared_root": self.shared_root,
            "entrypoint_num_cpus": self.entrypoint_num_cpus,
            "entrypoint_resources": dict(self.entrypoint_resources),
            "runtime_env": {"env_vars": dict(self.runtime_env_env_vars)},
            "user_code_path": self.user_code_path,
        }


@dataclass(frozen=True)
class JobSpec:
    """Frozen description of one training job parsed from a plain mapping."""

    workload: str  # ppo|grpo|sft
    submission_id: str | None
    code_path: str
    model_id: str
    train_file: str
    val_file: str | None
    nnodes: int
    n_gpus_per_node: int
    total_epochs: int
    total_training_steps: int
    save_freq: int
    test_freq: int | None
    trainer_device: str | None  # only for sft (driver-side device)

    @staticmethod
    def from_dict(d: dict[str, Any]) -> "JobSpec":
        """Build a ``JobSpec`` from a mapping, coercing every field's type.

        Required keys: ``workload`` (one of ``ppo``/``grpo``/``sft``),
        ``code_path``, ``model_id``, ``train_file``.
        Optional integer keys (unset, ``None`` or ``""`` select the default):
        ``nnodes`` (2), ``n_gpus_per_node`` (4), ``total_epochs`` (1),
        ``total_training_steps`` (10), ``save_freq`` (10).
        ``val_file`` and ``test_freq`` become ``None`` when unset, ``""`` or
        the string ``"null"``; ``submission_id`` and ``trainer_device`` become
        ``None`` when falsy.

        Raises:
            ValueError: a required key is missing/empty, the workload is not
                supported, or an integer field cannot be parsed.
        """
        workload = str(_require(d, "workload"))
        if workload not in ("ppo", "grpo", "sft"):
            raise ValueError(f"unsupported workload: {workload}")

        # The string "null" is accepted as a spelling of "not set".
        val_file = d.get("val_file")
        if val_file in ("", "null"):
            val_file = None

        test_freq = d.get("test_freq")
        if test_freq in ("", "null"):
            test_freq = None

        submission_id = d.get("submission_id")
        trainer_device = d.get("trainer_device")

        return JobSpec(
            workload=workload,
            submission_id=str(submission_id) if submission_id else None,
            code_path=str(_require(d, "code_path")),
            model_id=str(_require(d, "model_id")),
            train_file=str(_require(d, "train_file")),
            val_file=str(val_file) if val_file is not None else None,
            nnodes=int(_get_or_default(d, "nnodes", 2)),
            n_gpus_per_node=int(_get_or_default(d, "n_gpus_per_node", 4)),
            total_epochs=int(_get_or_default(d, "total_epochs", 1)),
            total_training_steps=int(_get_or_default(d, "total_training_steps", 10)),
            save_freq=int(_get_or_default(d, "save_freq", 10)),
            test_freq=int(test_freq) if test_freq is not None else None,
            trainer_device=str(trainer_device) if trainer_device else None,
        )

    def to_public_dict(self) -> dict[str, Any]:
        """Return a plain-dict view of the spec.

        ``submission_id`` is rendered as ``""`` when unset. ``trainer_device``
        is included only for the ``sft`` workload and falls back to ``"cpu"``.
        """
        out: dict[str, Any] = {
            "workload": self.workload,
            "submission_id": self.submission_id or "",
            "code_path": self.code_path,
            "model_id": self.model_id,
            "train_file": self.train_file,
            "val_file": self.val_file,
            "nnodes": self.nnodes,
            "n_gpus_per_node": self.n_gpus_per_node,
            "total_epochs": self.total_epochs,
            "total_training_steps": self.total_training_steps,
            "save_freq": self.save_freq,
            "test_freq": self.test_freq,
        }
        if self.workload == "sft":
            out["trainer_device"] = self.trainer_device or "cpu"
        return out