122 lines
4.2 KiB
Python
122 lines
4.2 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
|
|
def _require(d: dict[str, Any], key: str) -> Any:
|
|
if key not in d or d[key] in (None, ""):
|
|
raise ValueError(f"missing required field: {key}")
|
|
return d[key]
|
|
|
|
|
|
@dataclass(frozen=True)
class RayConfig:
    """Immutable settings for submitting jobs to a Ray cluster.

    Built from a plain mapping (e.g. parsed config) via :meth:`from_dict`
    and serialized back into the same shape with :meth:`to_public_dict`.
    """

    address: str  # Ray cluster address, e.g. "ray://host:port"
    shared_root: str  # root of the filesystem shared across nodes
    entrypoint_num_cpus: float  # CPUs reserved for the job entrypoint
    entrypoint_resources: dict[str, float]  # extra custom resources for the entrypoint
    runtime_env_env_vars: dict[str, str]  # env vars injected through Ray's runtime_env
    user_code_path: str  # location of user code, defaults under shared_root

    @staticmethod
    def from_dict(d: dict[str, Any]) -> "RayConfig":
        """Build a validated RayConfig from a raw mapping.

        Raises:
            ValueError: if a required field is missing/empty, or if
                ``runtime_env.env_vars`` / ``entrypoint_resources`` is
                present but not a mapping.
        """
        runtime_env = d.get("runtime_env") or {}
        # NOTE(review): a non-dict runtime_env silently degrades to "no env
        # vars" rather than raising; only a non-dict env_vars is rejected.
        env_vars = (runtime_env.get("env_vars") or {}) if isinstance(runtime_env, dict) else {}
        if not isinstance(env_vars, dict):
            raise ValueError("runtime_env.env_vars must be a mapping")

        entrypoint_resources = d.get("entrypoint_resources") or {}
        if not isinstance(entrypoint_resources, dict):
            raise ValueError("entrypoint_resources must be a mapping")

        # Validate shared_root exactly once; it also seeds the default
        # user_code_path (the original called _require on it twice).
        shared_root = str(_require(d, "shared_root"))

        return RayConfig(
            address=str(_require(d, "address")),
            shared_root=shared_root,
            entrypoint_num_cpus=float(d.get("entrypoint_num_cpus", 1)),
            entrypoint_resources={str(k): float(v) for k, v in entrypoint_resources.items()},
            runtime_env_env_vars={str(k): str(v) for k, v in env_vars.items()},
            user_code_path=str(d.get("user_code_path", f"{shared_root}/user/code")),
        )

    def to_public_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable view mirroring the from_dict input shape."""
        return {
            "address": self.address,
            "shared_root": self.shared_root,
            "entrypoint_num_cpus": self.entrypoint_num_cpus,
            "entrypoint_resources": self.entrypoint_resources,
            "runtime_env": {"env_vars": self.runtime_env_env_vars},
            "user_code_path": self.user_code_path,
        }
|
|
|
|
|
|
@dataclass(frozen=True)
class JobSpec:
    """Immutable description of a single training-job submission."""

    workload: str  # one of: ppo | grpo | sft
    submission_id: str | None  # caller-chosen id; None lets the backend assign one
    code_path: str
    model_id: str

    train_file: str
    val_file: str | None  # None disables validation data

    nnodes: int
    n_gpus_per_node: int

    total_epochs: int
    total_training_steps: int

    save_freq: int
    test_freq: int | None  # None disables periodic evaluation

    trainer_device: str | None  # only for sft (driver-side device)

    @staticmethod
    def _blank_to_none(v: Any) -> Any:
        """Normalize "" and the literal string "null" (presumably from
        templated configs — confirm against callers) to None."""
        return None if v in ("", "null") else v

    @staticmethod
    def from_dict(d: dict[str, Any]) -> "JobSpec":
        """Build a validated JobSpec from a raw mapping.

        Raises:
            ValueError: on a missing required field or unsupported workload.
        """
        workload = str(_require(d, "workload"))
        if workload not in ("ppo", "grpo", "sft"):
            raise ValueError(f"unsupported workload: {workload}")

        # Optional fields may arrive as "" or "null"; both mean "absent".
        # (Was duplicated inline for each field — now shared via helper.)
        val_file = JobSpec._blank_to_none(d.get("val_file"))
        test_freq = JobSpec._blank_to_none(d.get("test_freq"))

        return JobSpec(
            workload=workload,
            submission_id=(str(d["submission_id"]) if d.get("submission_id") else None),
            code_path=str(_require(d, "code_path")),
            model_id=str(_require(d, "model_id")),
            train_file=str(_require(d, "train_file")),
            val_file=(str(val_file) if val_file is not None else None),
            nnodes=int(d.get("nnodes", 2)),
            n_gpus_per_node=int(d.get("n_gpus_per_node", 4)),
            total_epochs=int(d.get("total_epochs", 1)),
            total_training_steps=int(d.get("total_training_steps", 10)),
            save_freq=int(d.get("save_freq", 10)),
            test_freq=(int(test_freq) if test_freq is not None else None),
            trainer_device=(str(d.get("trainer_device")) if d.get("trainer_device") else None),
        )

    def to_public_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable view.

        ``trainer_device`` is emitted only for the sft workload, defaulting
        to "cpu" when unset; a None submission_id round-trips as "".
        """
        out: dict[str, Any] = {
            "workload": self.workload,
            "submission_id": self.submission_id or "",
            "code_path": self.code_path,
            "model_id": self.model_id,
            "train_file": self.train_file,
            "val_file": self.val_file,
            "nnodes": self.nnodes,
            "n_gpus_per_node": self.n_gpus_per_node,
            "total_epochs": self.total_epochs,
            "total_training_steps": self.total_training_steps,
            "save_freq": self.save_freq,
            "test_freq": self.test_freq,
        }
        if self.workload == "sft":
            out["trainer_device"] = self.trainer_device or "cpu"
        return out
|