V3.0 优化webui界面
This commit is contained in:
parent
6d3fefc7a6
commit
e3dcfe526f
@ -21,14 +21,16 @@ def build_training_argv(spec: JobSpec, submission_id: str, job_dir: str) -> Buil
|
||||
algo_overrides.append("algorithm.adv_estimator=grpo")
|
||||
|
||||
test_freq = spec.test_freq if spec.test_freq is not None else -1
|
||||
val_file = spec.val_file if spec.val_file is not None else "null"
|
||||
maybe_total_steps: list[str] = []
|
||||
if spec.total_training_steps is not None:
|
||||
maybe_total_steps.append(f"trainer.total_training_steps={spec.total_training_steps}")
|
||||
|
||||
argv = [
|
||||
"python3",
|
||||
"-m",
|
||||
"verl.trainer.main_ppo",
|
||||
f"data.train_files={spec.train_file}",
|
||||
f"data.val_files={val_file}",
|
||||
f"data.val_files={spec.val_file}",
|
||||
"data.train_batch_size=256",
|
||||
"data.max_prompt_length=512",
|
||||
"data.max_response_length=512",
|
||||
@ -53,7 +55,7 @@ def build_training_argv(spec: JobSpec, submission_id: str, job_dir: str) -> Buil
|
||||
f"trainer.save_freq={spec.save_freq}",
|
||||
f"trainer.test_freq={test_freq}",
|
||||
f"trainer.total_epochs={spec.total_epochs}",
|
||||
f"trainer.total_training_steps={spec.total_training_steps}",
|
||||
*maybe_total_steps,
|
||||
"trainer.resume_mode=disable",
|
||||
f"trainer.default_local_dir={job_dir}/checkpoints",
|
||||
"+ray_kwargs.ray_init.address=auto",
|
||||
@ -62,8 +64,10 @@ def build_training_argv(spec: JobSpec, submission_id: str, job_dir: str) -> Buil
|
||||
return BuiltCommand(argv=argv)
|
||||
|
||||
if spec.workload == "sft":
|
||||
val_override = "null" if spec.val_file is None else spec.val_file
|
||||
trainer_device = spec.trainer_device or "cpu"
|
||||
maybe_total_steps: list[str] = []
|
||||
if spec.total_training_steps is not None:
|
||||
maybe_total_steps.append(f"trainer.total_training_steps={spec.total_training_steps}")
|
||||
|
||||
argv = [
|
||||
"python3",
|
||||
@ -71,7 +75,7 @@ def build_training_argv(spec: JobSpec, submission_id: str, job_dir: str) -> Buil
|
||||
"verl.trainer.sft_trainer_ray",
|
||||
f"model.path={spec.model_id}",
|
||||
f"data.train_files={spec.train_file}",
|
||||
f"data.val_files={val_override}",
|
||||
f"data.val_files={spec.val_file}",
|
||||
"data.train_batch_size=64",
|
||||
"data.micro_batch_size_per_gpu=1",
|
||||
"data.max_token_len_per_gpu=2048",
|
||||
@ -80,7 +84,7 @@ def build_training_argv(spec: JobSpec, submission_id: str, job_dir: str) -> Buil
|
||||
"trainer.project_name=mvp11-sft",
|
||||
f"trainer.experiment_name={submission_id}",
|
||||
f"trainer.total_epochs={spec.total_epochs}",
|
||||
f"trainer.total_training_steps={spec.total_training_steps}",
|
||||
*maybe_total_steps,
|
||||
f"trainer.save_freq={spec.save_freq}",
|
||||
"trainer.test_freq=-1",
|
||||
"trainer.resume_mode=disable",
|
||||
@ -93,4 +97,3 @@ def build_training_argv(spec: JobSpec, submission_id: str, job_dir: str) -> Buil
|
||||
return BuiltCommand(argv=argv)
|
||||
|
||||
raise ValueError(f"unsupported workload: {spec.workload}")
|
||||
|
||||
|
||||
@ -63,13 +63,13 @@ class JobSpec:
|
||||
model_id: str
|
||||
|
||||
train_file: str
|
||||
val_file: str | None
|
||||
val_file: str
|
||||
|
||||
nnodes: int
|
||||
n_gpus_per_node: int
|
||||
|
||||
total_epochs: int
|
||||
total_training_steps: int
|
||||
total_training_steps: int | None
|
||||
|
||||
save_freq: int
|
||||
test_freq: int | None
|
||||
@ -82,25 +82,27 @@ class JobSpec:
|
||||
if workload not in ("ppo", "grpo", "sft"):
|
||||
raise ValueError(f"unsupported workload: {workload}")
|
||||
|
||||
val_file = d.get("val_file", None)
|
||||
if val_file in ("", "null"):
|
||||
val_file = None
|
||||
|
||||
test_freq = d.get("test_freq", None)
|
||||
if test_freq in ("", "null"):
|
||||
test_freq = None
|
||||
|
||||
total_training_steps = d.get("total_training_steps", None)
|
||||
if total_training_steps in ("", "null"):
|
||||
total_training_steps = None
|
||||
|
||||
return JobSpec(
|
||||
workload=workload,
|
||||
submission_id=(str(d["submission_id"]) if d.get("submission_id") else None),
|
||||
code_path=str(_require(d, "code_path")),
|
||||
model_id=str(_require(d, "model_id")),
|
||||
train_file=str(_require(d, "train_file")),
|
||||
val_file=(str(val_file) if val_file is not None else None),
|
||||
val_file=str(_require(d, "val_file")),
|
||||
nnodes=int(d.get("nnodes", 2)),
|
||||
n_gpus_per_node=int(d.get("n_gpus_per_node", 4)),
|
||||
total_epochs=int(d.get("total_epochs", 1)),
|
||||
total_training_steps=int(d.get("total_training_steps", 10)),
|
||||
# If omitted or explicitly null: do not override trainer.total_training_steps
|
||||
# (let VERL derive it from epochs and dataset length).
|
||||
total_training_steps=(int(total_training_steps) if total_training_steps is not None else None),
|
||||
save_freq=int(d.get("save_freq", 10)),
|
||||
test_freq=(int(test_freq) if test_freq is not None else None),
|
||||
trainer_device=(str(d.get("trainer_device")) if d.get("trainer_device") else None),
|
||||
|
||||
@ -408,6 +408,38 @@ def create_app(config_path: str) -> FastAPI:
|
||||
}
|
||||
return out
|
||||
|
||||
@app.get("/api/v2/tasks/{task_id}/spec")
|
||||
async def get_task_spec(task_id: str, req: Request) -> Response:
|
||||
"""
|
||||
Returns the TaskSpec YAML for this task.
|
||||
|
||||
Note: We prefer returning a "resolved" TaskSpec (includes default values) so the
|
||||
UI can show a complete spec even if the user submitted a minimal YAML.
|
||||
"""
|
||||
subject = _auth(req)
|
||||
row = db.get_task(task_id)
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="task not found")
|
||||
if not subject.get("is_admin"):
|
||||
if str(row.get("user_id") or "") != str(subject["user_id"]):
|
||||
raise HTTPException(status_code=404, detail="task not found")
|
||||
raw = str(row.get("jobspec_yaml") or "")
|
||||
try:
|
||||
obj = yaml.safe_load(raw) or {}
|
||||
if not isinstance(obj, dict):
|
||||
raise ValueError("jobspec must be a YAML mapping")
|
||||
spec = JobSpec.from_dict(obj)
|
||||
resolved = spec.to_public_dict()
|
||||
attempts = db.list_attempts(task_id)
|
||||
if attempts:
|
||||
resolved["submission_id"] = str(attempts[-1].get("ray_submission_id") or resolved.get("submission_id") or "")
|
||||
text = yaml.safe_dump(resolved, sort_keys=False)
|
||||
except Exception:
|
||||
# Backward compatibility / best-effort: if old rows contain invalid YAML,
|
||||
# fall back to the original content.
|
||||
text = raw
|
||||
return Response(content=text, media_type="text/plain")
|
||||
|
||||
@app.get("/api/v2/tasks/{task_id}/attempts")
|
||||
async def get_attempts(task_id: str, req: Request) -> dict[str, Any]:
|
||||
subject = _auth(req)
|
||||
|
||||
@ -279,31 +279,56 @@ refresh();
|
||||
@app.get("/ui/tasks/new")
|
||||
async def ui_new_task() -> HTMLResponse:
|
||||
ppo = """# PPO TaskSpec (YAML)
|
||||
workload: ppo
|
||||
nnodes: 2
|
||||
n_gpus_per_node: 4
|
||||
code_path: /private/common/code/verl/verl_repo
|
||||
train_file: /private/common/datasets/gsm8k/train.parquet
|
||||
val_file: /private/common/datasets/gsm8k/test.parquet
|
||||
model_id: Qwen/Qwen2.5-0.5B-Instruct
|
||||
workload: ppo # 任务类型(必填):ppo|grpo|sft
|
||||
code_path: /private/common/code/verl/verl_repo # 代码路径(必填):v3.0 固定使用 common 下的 verl 快照(不支持用户自定义代码)
|
||||
model_id: Qwen/Qwen2.5-0.5B-Instruct # 基础模型(必填):HuggingFace 模型 ID 或 /private/... 本地模型路径
|
||||
train_file: /private/common/datasets/gsm8k/train.parquet # 训练数据(必填):parquet 文件路径(支持 /private/common/datasets 或 /private/users/<user>/datasets)
|
||||
val_file: /private/common/datasets/gsm8k/test.parquet # 验证数据(必填):parquet 文件路径(VERL 侧会用来构建 val dataset,不能为 null)
|
||||
|
||||
# nnodes: 2 # 训练节点数(可选,默认:2)
|
||||
# n_gpus_per_node: 4 # 每节点 GPU 数(可选,默认:4)
|
||||
|
||||
# total_epochs: 1 # 总训练 epoch(可选,默认:1)
|
||||
# total_training_steps: null # 总训练 step(可选,默认:null;不传则让 VERL 按 epochs 和数据长度自动推导)
|
||||
# save_freq: 10 # checkpoint 保存频率(step)(可选,默认:10)
|
||||
# test_freq: null # 验证频率(step)(可选,默认:null;训练端会当成 -1=不验证)
|
||||
|
||||
# submission_id: "" # Ray submission_id(可选,默认空;通常由服务自动生成,无需填写)
|
||||
""".strip()
|
||||
grpo = """# GRPO TaskSpec (YAML)
|
||||
workload: grpo
|
||||
nnodes: 2
|
||||
n_gpus_per_node: 4
|
||||
code_path: /private/common/code/verl/verl_repo
|
||||
train_file: /private/common/datasets/gsm8k/train.parquet
|
||||
val_file: /private/common/datasets/gsm8k/test.parquet
|
||||
model_id: Qwen/Qwen2.5-0.5B-Instruct
|
||||
workload: grpo # 任务类型(必填):ppo|grpo|sft(grpo 会自动启用对应的算法配置)
|
||||
code_path: /private/common/code/verl/verl_repo # 代码路径(必填):v3.0 固定使用 common 下的 verl 快照(不支持用户自定义代码)
|
||||
model_id: Qwen/Qwen2.5-0.5B-Instruct # 基础模型(必填):HuggingFace 模型 ID 或 /private/... 本地模型路径
|
||||
train_file: /private/common/datasets/gsm8k/train.parquet # 训练数据(必填):parquet 文件路径(支持 /private/common/datasets 或 /private/users/<user>/datasets)
|
||||
val_file: /private/common/datasets/gsm8k/test.parquet # 验证数据(必填):parquet 文件路径(VERL 侧会用来构建 val dataset,不能为 null)
|
||||
|
||||
# nnodes: 2 # 训练节点数(可选,默认:2)
|
||||
# n_gpus_per_node: 4 # 每节点 GPU 数(可选,默认:4)
|
||||
|
||||
# total_epochs: 1 # 总训练 epoch(可选,默认:1)
|
||||
# total_training_steps: null # 总训练 step(可选,默认:null;不传则让 VERL 按 epochs 和数据长度自动推导)
|
||||
# save_freq: 10 # checkpoint 保存频率(step)(可选,默认:10)
|
||||
# test_freq: null # 验证频率(step)(可选,默认:null;训练端会当成 -1=不验证)
|
||||
|
||||
# submission_id: "" # Ray submission_id(可选,默认空;通常由服务自动生成,无需填写)
|
||||
""".strip()
|
||||
sft = """# SFT TaskSpec (YAML)
|
||||
workload: sft
|
||||
nnodes: 1
|
||||
n_gpus_per_node: 1
|
||||
code_path: /private/common/code/verl/verl_repo
|
||||
train_file: /private/common/datasets/gsm8k_sft/train.parquet
|
||||
val_file: /private/common/datasets/gsm8k_sft/test.parquet
|
||||
model_id: Qwen/Qwen2.5-0.5B-Instruct
|
||||
workload: sft # 任务类型(必填):ppo|grpo|sft
|
||||
code_path: /private/common/code/verl/verl_repo # 代码路径(必填):v3.0 固定使用 common 下的 verl 快照(不支持用户自定义代码)
|
||||
model_id: Qwen/Qwen2.5-0.5B-Instruct # 基础模型(必填):HuggingFace 模型 ID 或 /private/... 本地模型路径
|
||||
train_file: /private/common/datasets/gsm8k_sft/train.parquet # 训练数据(必填):parquet 文件路径(支持 /private/common/datasets 或 /private/users/<user>/datasets)
|
||||
val_file: /private/common/datasets/gsm8k_sft/test.parquet # 验证数据(必填):parquet 文件路径(VERL 侧会用来构建 val dataset,不能为 null)
|
||||
|
||||
# nnodes: 2 # 训练节点数(可选,默认:2;单机可设 1)
|
||||
# n_gpus_per_node: 4 # 每节点 GPU 数(可选,默认:4;单卡可设 1)
|
||||
|
||||
# total_epochs: 1 # 总训练 epoch(可选,默认:1)
|
||||
# total_training_steps: null # 总训练 step(可选,默认:null;不传则让 VERL 按 epochs 和数据长度自动推导)
|
||||
# save_freq: 10 # checkpoint 保存频率(step)(可选,默认:10)
|
||||
# test_freq: null # 验证频率(step)(可选,默认:null;训练端会当成 -1=不验证)
|
||||
|
||||
# trainer_device: cpu # 仅 SFT 生效:driver 侧 device(可选,默认:cpu)
|
||||
# submission_id: "" # Ray submission_id(可选,默认空;通常由服务自动生成,无需填写)
|
||||
""".strip()
|
||||
body = f"""
|
||||
<h1>New Task</h1>
|
||||
@ -371,16 +396,29 @@ document.getElementById("submit").onclick = async () => {
|
||||
<div style="height:10px"></div>
|
||||
<pre id="out" class="muted">Loading...</pre>
|
||||
</div>
|
||||
<div style="height:12px"></div>
|
||||
<div class="card">
|
||||
<h3 style="margin-top:0">TaskSpec (YAML)</h3>
|
||||
<div class="muted">Resolved TaskSpec (includes default values; submission_id reflects latest attempt when available).</div>
|
||||
<div style="height:10px"></div>
|
||||
<pre id="spec" class="muted">Loading...</pre>
|
||||
</div>
|
||||
""".strip()
|
||||
script = f"""
|
||||
document.getElementById("nav-ray-dashboard").href = curOriginWithPort(8265);
|
||||
const out = document.getElementById("out");
|
||||
const spec = document.getElementById("spec");
|
||||
async function refresh() {{
|
||||
out.textContent = "Loading...";
|
||||
spec.textContent = "Loading...";
|
||||
const resp = await apiFetch("/api/v2/tasks/{task_id}");
|
||||
const text = await resp.text();
|
||||
if (!resp.ok) {{ out.textContent = "Error: " + resp.status + "\\n" + text; return; }}
|
||||
out.textContent = fmtJson(JSON.parse(text));
|
||||
|
||||
const resp2 = await apiFetch("/api/v2/tasks/{task_id}/spec");
|
||||
const text2 = await resp2.text();
|
||||
spec.textContent = resp2.ok ? text2 : ("Error: " + resp2.status + "\\n" + text2);
|
||||
}}
|
||||
document.getElementById("refresh").onclick = refresh;
|
||||
document.getElementById("cancel").onclick = async () => {{
|
||||
|
||||
@ -87,7 +87,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
|
||||
r = c.post(
|
||||
"/api/v2/tasks",
|
||||
headers=headers,
|
||||
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
|
||||
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\nval_file: /private/common/datasets/v\n",
|
||||
)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["task_id"] == "tid1"
|
||||
@ -95,6 +95,27 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
|
||||
r2 = c.get("/api/v2/tasks/tid1", headers=headers)
|
||||
assert r2.status_code == 200
|
||||
assert r2.json()["desired_resources"]["total_gpus"] == 8
|
||||
r2s = c.get("/api/v2/tasks/tid1/spec", headers=headers)
|
||||
assert r2s.status_code == 200
|
||||
assert "workload: ppo" in r2s.text
|
||||
assert "code_path:" in r2s.text
|
||||
# Spec endpoint returns a resolved TaskSpec, i.e. includes default values.
|
||||
assert "nnodes: 2" in r2s.text
|
||||
assert "n_gpus_per_node: 4" in r2s.text
|
||||
|
||||
# Seed an attempt and ensure submission_id is reflected in the resolved spec.
|
||||
from argus.service.db import Db
|
||||
from argus.service.config import V2Config
|
||||
from argus.ray.models import RayConfig
|
||||
|
||||
root = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
|
||||
ray_cfg = RayConfig.from_dict(root)
|
||||
v2_cfg = V2Config.from_root_dict(root)
|
||||
db = Db(v2_cfg.sqlite.db_path)
|
||||
db.create_attempt(task_id="tid1", attempt_no=1, ray_submission_id="sid1")
|
||||
r2s2 = c.get("/api/v2/tasks/tid1/spec", headers=headers)
|
||||
assert r2s2.status_code == 200
|
||||
assert "submission_id: sid1" in r2s2.text
|
||||
|
||||
r3 = c.get("/api/v2/queue", headers=headers)
|
||||
assert r3.status_code == 200
|
||||
@ -116,10 +137,6 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
|
||||
assert r4.json()["state"] == "CANCELED"
|
||||
|
||||
# Seed an attempt then fetch logs
|
||||
from argus.service.db import Db
|
||||
from argus.service.config import V2Config
|
||||
from argus.ray.models import RayConfig
|
||||
|
||||
root = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
|
||||
ray_cfg = RayConfig.from_dict(root)
|
||||
v2_cfg = V2Config.from_root_dict(root)
|
||||
@ -127,7 +144,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
|
||||
db.create_task(
|
||||
task_id="tid2",
|
||||
workload="ppo",
|
||||
jobspec_yaml="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
|
||||
jobspec_yaml="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\nval_file: /private/common/datasets/v\n",
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
|
||||
@ -6,18 +6,18 @@ from argus.ray.builders import build_training_argv
|
||||
from argus.ray.models import JobSpec
|
||||
|
||||
|
||||
def _base_spec(workload: str) -> JobSpec:
|
||||
def _base_spec(workload: str, *, total_training_steps: int | None = 10) -> JobSpec:
|
||||
return JobSpec(
|
||||
workload=workload,
|
||||
submission_id=None,
|
||||
code_path="/code",
|
||||
model_id="m",
|
||||
train_file="train.jsonl",
|
||||
val_file=None,
|
||||
val_file="val.jsonl",
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
total_epochs=1,
|
||||
total_training_steps=10,
|
||||
total_training_steps=total_training_steps,
|
||||
save_freq=10,
|
||||
test_freq=None,
|
||||
trainer_device=None,
|
||||
@ -28,7 +28,7 @@ def test_build_training_argv_ppo_smoke():
|
||||
spec = _base_spec("ppo")
|
||||
built = build_training_argv(spec, submission_id="sid", job_dir="/job")
|
||||
assert built.argv[:3] == ["python3", "-m", "verl.trainer.main_ppo"]
|
||||
assert "data.val_files=null" in built.argv
|
||||
assert "data.val_files=val.jsonl" in built.argv
|
||||
assert "trainer.test_freq=-1" in built.argv
|
||||
|
||||
|
||||
@ -43,11 +43,16 @@ def test_build_training_argv_sft_smoke():
|
||||
built = build_training_argv(spec, submission_id="sid", job_dir="/job")
|
||||
assert built.argv[:3] == ["python3", "-m", "verl.trainer.sft_trainer_ray"]
|
||||
assert "trainer.device=cpu" in built.argv
|
||||
assert "data.val_files=null" in built.argv
|
||||
assert "data.val_files=val.jsonl" in built.argv
|
||||
|
||||
|
||||
def test_build_training_argv_omits_total_training_steps_when_unset():
|
||||
spec = _base_spec("ppo", total_training_steps=None)
|
||||
built = build_training_argv(spec, submission_id="sid", job_dir="/job")
|
||||
assert not any(a.startswith("trainer.total_training_steps=") for a in built.argv)
|
||||
|
||||
|
||||
def test_build_training_argv_unsupported_raises():
|
||||
spec = _base_spec("bad")
|
||||
with pytest.raises(ValueError, match="unsupported workload"):
|
||||
build_training_argv(spec, submission_id="sid", job_dir="/job")
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@ def test_cli_submit_status_logs_list(monkeypatch, tmp_path: Path, capsys):
|
||||
encoding="utf-8",
|
||||
)
|
||||
spec = tmp_path / "spec.yaml"
|
||||
spec.write_text("workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n", encoding="utf-8")
|
||||
spec.write_text("workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\nval_file: v\n", encoding="utf-8")
|
||||
|
||||
from argus.cli.run import main
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ def test_db_lifecycle_and_basic_queries(tmp_path: Path):
|
||||
t = db.create_task(
|
||||
task_id="t1",
|
||||
workload="ppo",
|
||||
jobspec_yaml="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n",
|
||||
jobspec_yaml="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\nval_file: v\n",
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
|
||||
@ -67,13 +67,14 @@ def test_jobspec_validation_and_null_coercion():
|
||||
"code_path": "/code",
|
||||
"model_id": "m",
|
||||
"train_file": "train.jsonl",
|
||||
"val_file": "null",
|
||||
"val_file": "val.jsonl",
|
||||
"test_freq": "",
|
||||
}
|
||||
)
|
||||
assert spec.workload == "ppo"
|
||||
assert spec.val_file is None
|
||||
assert spec.val_file == "val.jsonl"
|
||||
assert spec.test_freq is None
|
||||
assert spec.total_training_steps is None
|
||||
assert spec.nnodes == 2
|
||||
assert spec.n_gpus_per_node == 4
|
||||
|
||||
@ -91,6 +92,7 @@ def test_jobspec_sft_adds_trainer_device_default():
|
||||
"code_path": "/code",
|
||||
"model_id": "m",
|
||||
"train_file": "train.jsonl",
|
||||
"val_file": "val.jsonl",
|
||||
}
|
||||
)
|
||||
pub = spec.to_public_dict()
|
||||
@ -102,6 +104,5 @@ def test_jobspec_unsupported_workload():
|
||||
|
||||
with pytest.raises(ValueError, match="unsupported workload"):
|
||||
JobSpec.from_dict(
|
||||
{"workload": "nope", "code_path": "x", "model_id": "m", "train_file": "t"}
|
||||
{"workload": "nope", "code_path": "x", "model_id": "m", "train_file": "t", "val_file": "v"}
|
||||
)
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ def test_runtime_env_sets_defaults_and_pythonpath(monkeypatch):
|
||||
}
|
||||
)
|
||||
spec = JobSpec.from_dict(
|
||||
{"workload": "sft", "code_path": "/c", "model_id": "m", "train_file": "t"}
|
||||
{"workload": "sft", "code_path": "/c", "model_id": "m", "train_file": "t", "val_file": "v"}
|
||||
)
|
||||
monkeypatch.setenv("MVP_TOOL_CODE_PATH", "/tool")
|
||||
|
||||
@ -106,6 +106,7 @@ def test_submit_writes_artifacts_and_returns_submission_id(tmp_path: Path, monke
|
||||
"code_path": "/code",
|
||||
"model_id": "m",
|
||||
"train_file": "train.jsonl",
|
||||
"val_file": "val.jsonl",
|
||||
}
|
||||
)
|
||||
|
||||
@ -151,7 +152,7 @@ def test_submit_error_writes_file_then_reraises(tmp_path: Path, monkeypatch):
|
||||
}
|
||||
)
|
||||
spec = JobSpec.from_dict(
|
||||
{"workload": "ppo", "submission_id": "sid2", "code_path": "/code", "model_id": "m", "train_file": "t"}
|
||||
{"workload": "ppo", "submission_id": "sid2", "code_path": "/code", "model_id": "m", "train_file": "t", "val_file": "v"}
|
||||
)
|
||||
|
||||
tool = mod.RayJobTool(cfg)
|
||||
|
||||
@ -36,7 +36,15 @@ def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
|
||||
task_id="t1",
|
||||
user_id="alice",
|
||||
workload="ppo",
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
|
||||
jobspec_yaml=yaml.safe_dump(
|
||||
{
|
||||
"workload": "ppo",
|
||||
"code_path": "/private/common/code/verl",
|
||||
"model_id": "m",
|
||||
"train_file": "/private/common/datasets/t",
|
||||
"val_file": "/private/common/datasets/v",
|
||||
}
|
||||
),
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
@ -87,7 +95,15 @@ def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path):
|
||||
task_id="t1",
|
||||
user_id="alice",
|
||||
workload="ppo",
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
|
||||
jobspec_yaml=yaml.safe_dump(
|
||||
{
|
||||
"workload": "ppo",
|
||||
"code_path": "/private/common/code/verl",
|
||||
"model_id": "m",
|
||||
"train_file": "/private/common/datasets/t",
|
||||
"val_file": "/private/common/datasets/v",
|
||||
}
|
||||
),
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
@ -117,7 +133,15 @@ def test_sync_failed_insufficient_resources(monkeypatch, tmp_path: Path):
|
||||
task_id="t1",
|
||||
user_id="alice",
|
||||
workload="ppo",
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
|
||||
jobspec_yaml=yaml.safe_dump(
|
||||
{
|
||||
"workload": "ppo",
|
||||
"code_path": "/private/common/code/verl",
|
||||
"model_id": "m",
|
||||
"train_file": "/private/common/datasets/t",
|
||||
"val_file": "/private/common/datasets/v",
|
||||
}
|
||||
),
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
|
||||
@ -76,3 +76,5 @@ def test_ui_task_detail_shows_ids(tmp_path, monkeypatch):
|
||||
assert r.status_code == 200
|
||||
assert task_id in r.text
|
||||
assert f"/ui/tasks/{task_id}/logs" in r.text
|
||||
assert "TaskSpec (YAML)" in r.text
|
||||
assert "/api/v2/tasks/" in r.text
|
||||
|
||||
@ -132,7 +132,7 @@ def test_tasks_are_isolated_by_user(tmp_path: Path, monkeypatch):
|
||||
r1 = c.post(
|
||||
"/api/v2/tasks",
|
||||
headers=alice_headers,
|
||||
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
|
||||
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\nval_file: /private/common/datasets/v\n",
|
||||
)
|
||||
assert r1.status_code == 200
|
||||
alice_tid = r1.json()["task_id"]
|
||||
@ -140,7 +140,7 @@ def test_tasks_are_isolated_by_user(tmp_path: Path, monkeypatch):
|
||||
r2 = c.post(
|
||||
"/api/v2/tasks",
|
||||
headers=bob_headers,
|
||||
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
|
||||
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\nval_file: /private/common/datasets/v\n",
|
||||
)
|
||||
assert r2.status_code == 200
|
||||
bob_tid = r2.json()["task_id"]
|
||||
@ -186,7 +186,7 @@ def test_submit_rejects_non_common_inputs(tmp_path: Path, monkeypatch):
|
||||
r = c.post(
|
||||
"/api/v2/tasks",
|
||||
headers=alice_headers,
|
||||
data="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
|
||||
data="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: /private/common/datasets/t\nval_file: /private/common/datasets/v\n",
|
||||
)
|
||||
assert r.status_code == 400
|
||||
assert "code_path must start with /private/common/" in r.text
|
||||
@ -223,6 +223,7 @@ def test_submit_accepts_user_dataset_paths_and_local_model_paths(tmp_path: Path,
|
||||
"code_path: /private/common/code/verl\n"
|
||||
"model_id: Qwen/Qwen2.5-0.5B-Instruct\n"
|
||||
"train_file: /private/users/alice/datasets/t\n"
|
||||
"val_file: /private/users/alice/datasets/v\n"
|
||||
),
|
||||
)
|
||||
assert r1.status_code == 200
|
||||
@ -236,6 +237,7 @@ def test_submit_accepts_user_dataset_paths_and_local_model_paths(tmp_path: Path,
|
||||
"code_path: /private/common/code/verl\n"
|
||||
"model_id: /private/users/alice/models/m1\n"
|
||||
"train_file: /private/common/datasets/t\n"
|
||||
"val_file: /private/common/datasets/v\n"
|
||||
),
|
||||
)
|
||||
assert r2.status_code == 200
|
||||
@ -274,6 +276,7 @@ def test_submit_rejects_cross_user_paths_and_bad_local_model_dirs(tmp_path: Path
|
||||
"code_path: /private/common/code/verl\n"
|
||||
"model_id: Qwen/Qwen2.5-0.5B-Instruct\n"
|
||||
"train_file: /private/users/alice/datasets/t\n"
|
||||
"val_file: /private/users/alice/datasets/v\n"
|
||||
),
|
||||
)
|
||||
assert r1.status_code == 400
|
||||
@ -288,6 +291,7 @@ def test_submit_rejects_cross_user_paths_and_bad_local_model_dirs(tmp_path: Path
|
||||
"code_path: /private/common/code/verl\n"
|
||||
"model_id: /private/users/bob/jobs/j1/checkpoints\n"
|
||||
"train_file: /private/common/datasets/t\n"
|
||||
"val_file: /private/common/datasets/v\n"
|
||||
),
|
||||
)
|
||||
assert r2.status_code == 400
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user