2025-12-23 14:22:15 +08:00

122 lines
4.2 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import Any
def _require(d: dict[str, Any], key: str) -> Any:
    """Return ``d[key]``, insisting that it is present and non-empty.

    Only ``None`` and the empty string count as "empty"; other falsy values
    such as ``0`` or ``False`` are returned unchanged.

    Raises:
        ValueError: if ``key`` is absent from ``d`` or maps to ``None`` / ``""``.
    """
    if key in d:
        value = d[key]
        if value not in (None, ""):
            return value
    raise ValueError(f"missing required field: {key}")
@dataclass(frozen=True)
class RayConfig:
    """Ray cluster connection and job-submission settings.

    Build instances with ``from_dict`` from a raw mapping; ``to_public_dict``
    renders the config back into the nested key layout that ``from_dict`` reads.
    """

    address: str
    shared_root: str
    entrypoint_num_cpus: float
    entrypoint_resources: dict[str, float]
    runtime_env_env_vars: dict[str, str]
    user_code_path: str

    @staticmethod
    def from_dict(d: dict[str, Any]) -> "RayConfig":
        """Build a ``RayConfig`` from a raw mapping.

        Required keys (must be present and not None/""): ``address``,
        ``shared_root``.

        Optional keys:
          - ``entrypoint_num_cpus``: defaults to 1.0 when missing or None.
          - ``entrypoint_resources``: mapping of name -> amount (values cast to float).
          - ``runtime_env``: mapping; only its ``env_vars`` mapping is kept
            (keys and values cast to str).
          - ``user_code_path``: defaults to ``"<shared_root>/user/code"`` when
            missing, None or "".

        Raises:
            ValueError: a required key is missing/empty, or ``runtime_env``,
                ``runtime_env.env_vars`` or ``entrypoint_resources`` is present
                but not a mapping.
        """
        runtime_env = d.get("runtime_env") or {}
        if not isinstance(runtime_env, dict):
            # A non-mapping runtime_env used to be silently dropped (env_vars
            # became {}); reject it like the other mapping-typed fields.
            raise ValueError("runtime_env must be a mapping")
        env_vars = runtime_env.get("env_vars") or {}
        if not isinstance(env_vars, dict):
            raise ValueError("runtime_env.env_vars must be a mapping")

        entrypoint_resources = d.get("entrypoint_resources") or {}
        if not isinstance(entrypoint_resources, dict):
            raise ValueError("entrypoint_resources must be a mapping")

        address = str(_require(d, "address"))
        shared_root = str(_require(d, "shared_root"))

        # An explicit null means "use the default" rather than float(None) -> TypeError.
        num_cpus = d.get("entrypoint_num_cpus")
        entrypoint_num_cpus = 1.0 if num_cpus is None else float(num_cpus)

        # A present-but-null/empty user_code_path also falls back to the default;
        # previously None was stringified into the literal path "None".
        user_code_path = d.get("user_code_path")
        if user_code_path in (None, ""):
            user_code_path = f"{shared_root}/user/code"

        return RayConfig(
            address=address,
            shared_root=shared_root,
            entrypoint_num_cpus=entrypoint_num_cpus,
            entrypoint_resources={str(k): float(v) for k, v in entrypoint_resources.items()},
            runtime_env_env_vars={str(k): str(v) for k, v in env_vars.items()},
            user_code_path=str(user_code_path),
        )

    def to_public_dict(self) -> dict[str, Any]:
        """Return a plain-dict view of this config using the keys ``from_dict`` reads.

        Nested mappings are copied: the dataclass is frozen, but that does not
        protect dict contents, so handing out the live dicts would let callers
        mutate this instance through the returned value.
        """
        return {
            "address": self.address,
            "shared_root": self.shared_root,
            "entrypoint_num_cpus": self.entrypoint_num_cpus,
            "entrypoint_resources": dict(self.entrypoint_resources),
            "runtime_env": {"env_vars": dict(self.runtime_env_env_vars)},
            "user_code_path": self.user_code_path,
        }
@dataclass(frozen=True)
class JobSpec:
    """Description of one training job submission.

    ``workload`` selects the trainer (ppo, grpo or sft); the remaining fields
    are paths plus sizing/scheduling knobs. Build instances with ``from_dict``.
    """

    workload: str  # ppo|grpo|sft
    submission_id: str | None
    code_path: str
    model_id: str
    train_file: str
    val_file: str | None
    nnodes: int
    n_gpus_per_node: int
    total_epochs: int
    total_training_steps: int
    save_freq: int
    test_freq: int | None
    trainer_device: str | None  # only for sft (driver-side device)

    @staticmethod
    def _opt(d: dict[str, Any], key: str) -> Any:
        """Return ``d[key]``, or None when it is missing, None, "" or "null".

        The string "null" is accepted as well, presumably because some inputs
        stringify None before they reach here — TODO confirm the source.
        """
        value = d.get(key)
        return None if value in (None, "", "null") else value

    @staticmethod
    def _int_or(d: dict[str, Any], key: str, default: int) -> int:
        """Return ``int(d[key])``, or ``default`` when the value is unset (see ``_opt``)."""
        value = JobSpec._opt(d, key)
        return default if value is None else int(value)

    @staticmethod
    def from_dict(d: dict[str, Any]) -> "JobSpec":
        """Build a ``JobSpec`` from a raw mapping.

        Required keys (must be present and not None/""): ``workload``,
        ``code_path``, ``model_id``, ``train_file``.

        Optional keys and the defaults used when they are missing, None, ""
        or "null": ``nnodes`` 2, ``n_gpus_per_node`` 4, ``total_epochs`` 1,
        ``total_training_steps`` 10, ``save_freq`` 10; ``val_file`` and
        ``test_freq`` default to None. ``submission_id`` and ``trainer_device``
        become None when falsy.

        Raises:
            ValueError: a required key is missing/empty, ``workload`` is not
                one of ppo/grpo/sft, or a numeric field is not int-convertible.
        """
        workload = str(_require(d, "workload"))
        if workload not in ("ppo", "grpo", "sft"):
            raise ValueError(f"unsupported workload: {workload}")

        val_file = JobSpec._opt(d, "val_file")
        test_freq = JobSpec._opt(d, "test_freq")
        return JobSpec(
            workload=workload,
            submission_id=(str(d["submission_id"]) if d.get("submission_id") else None),
            code_path=str(_require(d, "code_path")),
            model_id=str(_require(d, "model_id")),
            train_file=str(_require(d, "train_file")),
            val_file=(str(val_file) if val_file is not None else None),
            # Unset numeric values fall back to their defaults instead of
            # raising TypeError from int(None).
            nnodes=JobSpec._int_or(d, "nnodes", 2),
            n_gpus_per_node=JobSpec._int_or(d, "n_gpus_per_node", 4),
            total_epochs=JobSpec._int_or(d, "total_epochs", 1),
            total_training_steps=JobSpec._int_or(d, "total_training_steps", 10),
            save_freq=JobSpec._int_or(d, "save_freq", 10),
            test_freq=(int(test_freq) if test_freq is not None else None),
            trainer_device=(str(d.get("trainer_device")) if d.get("trainer_device") else None),
        )

    def to_public_dict(self) -> dict[str, Any]:
        """Return a plain-dict view of this job spec.

        ``submission_id`` is rendered as "" when unset. ``trainer_device`` is
        only included for the sft workload, where it is rendered as "cpu" when unset.
        """
        out: dict[str, Any] = {
            "workload": self.workload,
            "submission_id": self.submission_id or "",
            "code_path": self.code_path,
            "model_id": self.model_id,
            "train_file": self.train_file,
            "val_file": self.val_file,
            "nnodes": self.nnodes,
            "n_gpus_per_node": self.n_gpus_per_node,
            "total_epochs": self.total_epochs,
            "total_training_steps": self.total_training_steps,
            "save_freq": self.save_freq,
            "test_freq": self.test_freq,
        }
        if self.workload == "sft":
            out["trainer_device"] = self.trainer_device or "cpu"
        return out