V3.0: optimize the WebUI interface

yuyr 2025-12-31 15:16:42 +08:00
parent 6d3fefc7a6
commit e3dcfe526f
13 changed files with 191 additions and 62 deletions

View File

@@ -21,14 +21,16 @@ def build_training_argv(spec: JobSpec, submission_id: str, job_dir: str) -> Buil
algo_overrides.append("algorithm.adv_estimator=grpo")
test_freq = spec.test_freq if spec.test_freq is not None else -1
val_file = spec.val_file if spec.val_file is not None else "null"
maybe_total_steps: list[str] = []
if spec.total_training_steps is not None:
maybe_total_steps.append(f"trainer.total_training_steps={spec.total_training_steps}")
argv = [
"python3",
"-m",
"verl.trainer.main_ppo",
f"data.train_files={spec.train_file}",
f"data.val_files={val_file}",
f"data.val_files={spec.val_file}",
"data.train_batch_size=256",
"data.max_prompt_length=512",
"data.max_response_length=512",
@@ -53,7 +55,7 @@ def build_training_argv(spec: JobSpec, submission_id: str, job_dir: str) -> Buil
f"trainer.save_freq={spec.save_freq}",
f"trainer.test_freq={test_freq}",
f"trainer.total_epochs={spec.total_epochs}",
f"trainer.total_training_steps={spec.total_training_steps}",
*maybe_total_steps,
"trainer.resume_mode=disable",
f"trainer.default_local_dir={job_dir}/checkpoints",
"+ray_kwargs.ray_init.address=auto",
@@ -62,8 +64,10 @@ def build_training_argv(spec: JobSpec, submission_id: str, job_dir: str) -> Buil
return BuiltCommand(argv=argv)
if spec.workload == "sft":
val_override = "null" if spec.val_file is None else spec.val_file
trainer_device = spec.trainer_device or "cpu"
maybe_total_steps: list[str] = []
if spec.total_training_steps is not None:
maybe_total_steps.append(f"trainer.total_training_steps={spec.total_training_steps}")
argv = [
"python3",
@@ -71,7 +75,7 @@ def build_training_argv(spec: JobSpec, submission_id: str, job_dir: str) -> Buil
"verl.trainer.sft_trainer_ray",
f"model.path={spec.model_id}",
f"data.train_files={spec.train_file}",
f"data.val_files={val_override}",
f"data.val_files={spec.val_file}",
"data.train_batch_size=64",
"data.micro_batch_size_per_gpu=1",
"data.max_token_len_per_gpu=2048",
@@ -80,7 +84,7 @@ def build_training_argv(spec: JobSpec, submission_id: str, job_dir: str) -> Buil
"trainer.project_name=mvp11-sft",
f"trainer.experiment_name={submission_id}",
f"trainer.total_epochs={spec.total_epochs}",
f"trainer.total_training_steps={spec.total_training_steps}",
*maybe_total_steps,
f"trainer.save_freq={spec.save_freq}",
"trainer.test_freq=-1",
"trainer.resume_mode=disable",
@@ -93,4 +97,3 @@ def build_training_argv(spec: JobSpec, submission_id: str, job_dir: str) -> Buil
return BuiltCommand(argv=argv)
raise ValueError(f"unsupported workload: {spec.workload}")
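The core pattern in this file is the `maybe_total_steps` splice: the Hydra override is appended to the argv only when the user actually set `total_training_steps`, so an unset field no longer forces a bogus `trainer.total_training_steps=None` onto the command line. A minimal sketch of the pattern (the `_Spec` stand-in is illustrative, not the real `JobSpec`):

from dataclasses import dataclass

@dataclass
class _Spec:  # trimmed stand-in for JobSpec, for illustration only
    total_training_steps: int | None

def build_argv(spec: _Spec) -> list[str]:
    # Emit the override only when a value was provided; otherwise leave the
    # trainer config untouched so VERL can derive the step count itself.
    maybe_total_steps: list[str] = []
    if spec.total_training_steps is not None:
        maybe_total_steps.append(f"trainer.total_training_steps={spec.total_training_steps}")
    return ["python3", "-m", "verl.trainer.main_ppo", *maybe_total_steps]

assert "trainer.total_training_steps=10" in build_argv(_Spec(10))
assert not any(a.startswith("trainer.total_training_steps=") for a in build_argv(_Spec(None)))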

View File

@@ -63,13 +63,13 @@ class JobSpec:
model_id: str
train_file: str
val_file: str | None
val_file: str
nnodes: int
n_gpus_per_node: int
total_epochs: int
total_training_steps: int
total_training_steps: int | None
save_freq: int
test_freq: int | None
@@ -82,25 +82,27 @@ class JobSpec:
if workload not in ("ppo", "grpo", "sft"):
raise ValueError(f"unsupported workload: {workload}")
val_file = d.get("val_file", None)
if val_file in ("", "null"):
val_file = None
test_freq = d.get("test_freq", None)
if test_freq in ("", "null"):
test_freq = None
total_training_steps = d.get("total_training_steps", None)
if total_training_steps in ("", "null"):
total_training_steps = None
return JobSpec(
workload=workload,
submission_id=(str(d["submission_id"]) if d.get("submission_id") else None),
code_path=str(_require(d, "code_path")),
model_id=str(_require(d, "model_id")),
train_file=str(_require(d, "train_file")),
val_file=(str(val_file) if val_file is not None else None),
val_file=str(_require(d, "val_file")),
nnodes=int(d.get("nnodes", 2)),
n_gpus_per_node=int(d.get("n_gpus_per_node", 4)),
total_epochs=int(d.get("total_epochs", 1)),
total_training_steps=int(d.get("total_training_steps", 10)),
# If omitted or explicitly null: do not override trainer.total_training_steps
# (let VERL derive it from epochs and dataset length).
total_training_steps=(int(total_training_steps) if total_training_steps is not None else None),
save_freq=int(d.get("save_freq", 10)),
test_freq=(int(test_freq) if test_freq is not None else None),
trainer_device=(str(d.get("trainer_device")) if d.get("trainer_device") else None),
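`from_dict` now treats empty strings and the literal `"null"` as "not set" for every optional field, and `total_training_steps` is kept as `None` instead of defaulting to a hard-coded 10. A hedged sketch of that coercion step (the `_opt_int` helper is illustrative; the real code inlines this logic per field):

def _opt_int(d: dict, key: str) -> int | None:
    # Missing, "", and "null" all mean "not set", mirroring how
    # JobSpec.from_dict coerces optional fields before typing them.
    v = d.get(key, None)
    if v in (None, "", "null"):
        return None
    return int(v)

assert _opt_int({"total_training_steps": "null"}, "total_training_steps") is None
assert _opt_int({"total_training_steps": ""}, "total_training_steps") is None
assert _opt_int({"total_training_steps": 10}, "total_training_steps") == 10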

View File

@@ -408,6 +408,38 @@ def create_app(config_path: str) -> FastAPI:
}
return out
@app.get("/api/v2/tasks/{task_id}/spec")
async def get_task_spec(task_id: str, req: Request) -> Response:
"""
Returns the TaskSpec YAML for this task.
Note: We prefer returning a "resolved" TaskSpec (includes default values) so the
UI can show a complete spec even if the user submitted a minimal YAML.
"""
subject = _auth(req)
row = db.get_task(task_id)
if not row:
raise HTTPException(status_code=404, detail="task not found")
if not subject.get("is_admin"):
if str(row.get("user_id") or "") != str(subject["user_id"]):
raise HTTPException(status_code=404, detail="task not found")
raw = str(row.get("jobspec_yaml") or "")
try:
obj = yaml.safe_load(raw) or {}
if not isinstance(obj, dict):
raise ValueError("jobspec must be a YAML mapping")
spec = JobSpec.from_dict(obj)
resolved = spec.to_public_dict()
attempts = db.list_attempts(task_id)
if attempts:
resolved["submission_id"] = str(attempts[-1].get("ray_submission_id") or resolved.get("submission_id") or "")
text = yaml.safe_dump(resolved, sort_keys=False)
except Exception:
# Backward compatibility / best-effort: if old rows contain invalid YAML,
# fall back to the original content.
text = raw
return Response(content=text, media_type="text/plain")
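The "resolved" spec is produced by a plain parse/rebuild/dump round trip: defaults filled in by `JobSpec.from_dict` become visible in the dumped YAML. A sketch of that flow, assuming `to_public_dict()` returns a plain mapping as the handler above implies:

import yaml

from argus.ray.models import JobSpec

raw = "workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\nval_file: v\n"
spec = JobSpec.from_dict(yaml.safe_load(raw) or {})
resolved = spec.to_public_dict()  # now includes defaults such as nnodes: 2
print(yaml.safe_dump(resolved, sort_keys=False))  # the text/plain body that /spec serves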
@app.get("/api/v2/tasks/{task_id}/attempts")
async def get_attempts(task_id: str, req: Request) -> dict[str, Any]:
subject = _auth(req)

View File

@@ -279,31 +279,56 @@ refresh();
@app.get("/ui/tasks/new")
async def ui_new_task() -> HTMLResponse:
ppo = """# PPO TaskSpec (YAML)
workload: ppo
nnodes: 2
n_gpus_per_node: 4
code_path: /private/common/code/verl/verl_repo
train_file: /private/common/datasets/gsm8k/train.parquet
val_file: /private/common/datasets/gsm8k/test.parquet
model_id: Qwen/Qwen2.5-0.5B-Instruct
workload: ppo # Workload type (required): ppo | grpo | sft
code_path: /private/common/code/verl/verl_repo # Code path (required): v3.0 always uses the verl snapshot under common (user-supplied code is not supported)
model_id: Qwen/Qwen2.5-0.5B-Instruct # Base model (required): a HuggingFace model ID or a local model path under /private/...
train_file: /private/common/datasets/gsm8k/train.parquet # Training data (required): parquet file path (under /private/common/datasets or /private/users/<user>/datasets)
val_file: /private/common/datasets/gsm8k/test.parquet # Validation data (required): parquet file path; VERL uses it to build the val dataset (must not be null)
# nnodes: 2 # Number of training nodes (optional, default: 2)
# n_gpus_per_node: 4 # GPUs per node (optional, default: 4)
# total_epochs: 1 # Total training epochs (optional, default: 1)
# total_training_steps: null # Total training steps (optional, default: null; when omitted, VERL derives it from epochs and dataset length)
# save_freq: 10 # Checkpoint save frequency in steps (optional, default: 10)
# test_freq: null # Validation frequency in steps (optional, default: null; the trainer treats it as -1 = no validation)
# submission_id: "" # Ray submission_id (optional, default: empty; usually auto-generated by the service, no need to set)
""".strip()
grpo = """# GRPO TaskSpec (YAML)
workload: grpo
nnodes: 2
n_gpus_per_node: 4
code_path: /private/common/code/verl/verl_repo
train_file: /private/common/datasets/gsm8k/train.parquet
val_file: /private/common/datasets/gsm8k/test.parquet
model_id: Qwen/Qwen2.5-0.5B-Instruct
workload: grpo # Workload type (required): ppo | grpo | sft (grpo automatically enables the matching algorithm config)
code_path: /private/common/code/verl/verl_repo # Code path (required): v3.0 always uses the verl snapshot under common (user-supplied code is not supported)
model_id: Qwen/Qwen2.5-0.5B-Instruct # Base model (required): a HuggingFace model ID or a local model path under /private/...
train_file: /private/common/datasets/gsm8k/train.parquet # Training data (required): parquet file path (under /private/common/datasets or /private/users/<user>/datasets)
val_file: /private/common/datasets/gsm8k/test.parquet # Validation data (required): parquet file path; VERL uses it to build the val dataset (must not be null)
# nnodes: 2 # Number of training nodes (optional, default: 2)
# n_gpus_per_node: 4 # GPUs per node (optional, default: 4)
# total_epochs: 1 # Total training epochs (optional, default: 1)
# total_training_steps: null # Total training steps (optional, default: null; when omitted, VERL derives it from epochs and dataset length)
# save_freq: 10 # Checkpoint save frequency in steps (optional, default: 10)
# test_freq: null # Validation frequency in steps (optional, default: null; the trainer treats it as -1 = no validation)
# submission_id: "" # Ray submission_id (optional, default: empty; usually auto-generated by the service, no need to set)
""".strip()
sft = """# SFT TaskSpec (YAML)
workload: sft
nnodes: 1
n_gpus_per_node: 1
code_path: /private/common/code/verl/verl_repo
train_file: /private/common/datasets/gsm8k_sft/train.parquet
val_file: /private/common/datasets/gsm8k_sft/test.parquet
model_id: Qwen/Qwen2.5-0.5B-Instruct
workload: sft # Workload type (required): ppo | grpo | sft
code_path: /private/common/code/verl/verl_repo # Code path (required): v3.0 always uses the verl snapshot under common (user-supplied code is not supported)
model_id: Qwen/Qwen2.5-0.5B-Instruct # Base model (required): a HuggingFace model ID or a local model path under /private/...
train_file: /private/common/datasets/gsm8k_sft/train.parquet # Training data (required): parquet file path (under /private/common/datasets or /private/users/<user>/datasets)
val_file: /private/common/datasets/gsm8k_sft/test.parquet # Validation data (required): parquet file path; VERL uses it to build the val dataset (must not be null)
# nnodes: 2 # Number of training nodes (optional, default: 2; set to 1 for single-node)
# n_gpus_per_node: 4 # GPUs per node (optional, default: 4; set to 1 for single-GPU)
# total_epochs: 1 # Total training epochs (optional, default: 1)
# total_training_steps: null # Total training steps (optional, default: null; when omitted, VERL derives it from epochs and dataset length)
# save_freq: 10 # Checkpoint save frequency in steps (optional, default: 10)
# test_freq: null # Validation frequency in steps (optional, default: null; the trainer treats it as -1 = no validation)
# trainer_device: cpu # SFT only: driver-side device (optional, default: cpu)
# submission_id: "" # Ray submission_id (optional, default: empty; usually auto-generated by the service, no need to set)
""".strip()
body = f"""
<h1>New Task</h1>
@@ -371,16 +396,29 @@ document.getElementById("submit").onclick = async () => {
<div style="height:10px"></div>
<pre id="out" class="muted">Loading...</pre>
</div>
<div style="height:12px"></div>
<div class="card">
<h3 style="margin-top:0">TaskSpec (YAML)</h3>
<div class="muted">Resolved TaskSpec (includes default values; submission_id reflects latest attempt when available).</div>
<div style="height:10px"></div>
<pre id="spec" class="muted">Loading...</pre>
</div>
""".strip()
script = f"""
document.getElementById("nav-ray-dashboard").href = curOriginWithPort(8265);
const out = document.getElementById("out");
const spec = document.getElementById("spec");
async function refresh() {{
out.textContent = "Loading...";
spec.textContent = "Loading...";
const resp = await apiFetch("/api/v2/tasks/{task_id}");
const text = await resp.text();
if (!resp.ok) {{ out.textContent = "Error: " + resp.status + "\\n" + text; return; }}
out.textContent = fmtJson(JSON.parse(text));
const resp2 = await apiFetch("/api/v2/tasks/{task_id}/spec");
const text2 = await resp2.text();
spec.textContent = resp2.ok ? text2 : ("Error: " + resp2.status + "\\n" + text2);
}}
document.getElementById("refresh").onclick = refresh;
document.getElementById("cancel").onclick = async () => {{

View File

@@ -87,7 +87,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
r = c.post(
"/api/v2/tasks",
headers=headers,
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\nval_file: /private/common/datasets/v\n",
)
assert r.status_code == 200
assert r.json()["task_id"] == "tid1"
@@ -95,6 +95,27 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
r2 = c.get("/api/v2/tasks/tid1", headers=headers)
assert r2.status_code == 200
assert r2.json()["desired_resources"]["total_gpus"] == 8
r2s = c.get("/api/v2/tasks/tid1/spec", headers=headers)
assert r2s.status_code == 200
assert "workload: ppo" in r2s.text
assert "code_path:" in r2s.text
# Spec endpoint returns a resolved TaskSpec, i.e. includes default values.
assert "nnodes: 2" in r2s.text
assert "n_gpus_per_node: 4" in r2s.text
# Seed an attempt and ensure submission_id is reflected in the resolved spec.
from argus.service.db import Db
from argus.service.config import V2Config
from argus.ray.models import RayConfig
root = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
ray_cfg = RayConfig.from_dict(root)
v2_cfg = V2Config.from_root_dict(root)
db = Db(v2_cfg.sqlite.db_path)
db.create_attempt(task_id="tid1", attempt_no=1, ray_submission_id="sid1")
r2s2 = c.get("/api/v2/tasks/tid1/spec", headers=headers)
assert r2s2.status_code == 200
assert "submission_id: sid1" in r2s2.text
r3 = c.get("/api/v2/queue", headers=headers)
assert r3.status_code == 200
@@ -116,10 +137,6 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
assert r4.json()["state"] == "CANCELED"
# Seed an attempt then fetch logs
from argus.service.db import Db
from argus.service.config import V2Config
from argus.ray.models import RayConfig
root = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
ray_cfg = RayConfig.from_dict(root)
v2_cfg = V2Config.from_root_dict(root)
@@ -127,7 +144,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
db.create_task(
task_id="tid2",
workload="ppo",
jobspec_yaml="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
jobspec_yaml="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\nval_file: /private/common/datasets/v\n",
nnodes=2,
n_gpus_per_node=4,
)

View File

@@ -6,18 +6,18 @@ from argus.ray.builders import build_training_argv
from argus.ray.models import JobSpec
def _base_spec(workload: str) -> JobSpec:
def _base_spec(workload: str, *, total_training_steps: int | None = 10) -> JobSpec:
return JobSpec(
workload=workload,
submission_id=None,
code_path="/code",
model_id="m",
train_file="train.jsonl",
val_file=None,
val_file="val.jsonl",
nnodes=2,
n_gpus_per_node=4,
total_epochs=1,
total_training_steps=10,
total_training_steps=total_training_steps,
save_freq=10,
test_freq=None,
trainer_device=None,
@@ -28,7 +28,7 @@ def test_build_training_argv_ppo_smoke():
spec = _base_spec("ppo")
built = build_training_argv(spec, submission_id="sid", job_dir="/job")
assert built.argv[:3] == ["python3", "-m", "verl.trainer.main_ppo"]
assert "data.val_files=null" in built.argv
assert "data.val_files=val.jsonl" in built.argv
assert "trainer.test_freq=-1" in built.argv
@@ -43,11 +43,16 @@ def test_build_training_argv_sft_smoke():
built = build_training_argv(spec, submission_id="sid", job_dir="/job")
assert built.argv[:3] == ["python3", "-m", "verl.trainer.sft_trainer_ray"]
assert "trainer.device=cpu" in built.argv
assert "data.val_files=null" in built.argv
assert "data.val_files=val.jsonl" in built.argv
def test_build_training_argv_omits_total_training_steps_when_unset():
spec = _base_spec("ppo", total_training_steps=None)
built = build_training_argv(spec, submission_id="sid", job_dir="/job")
assert not any(a.startswith("trainer.total_training_steps=") for a in built.argv)
def test_build_training_argv_unsupported_raises():
spec = _base_spec("bad")
with pytest.raises(ValueError, match="unsupported workload"):
build_training_argv(spec, submission_id="sid", job_dir="/job")

View File

@@ -37,7 +37,7 @@ def test_cli_submit_status_logs_list(monkeypatch, tmp_path: Path, capsys):
encoding="utf-8",
)
spec = tmp_path / "spec.yaml"
spec.write_text("workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n", encoding="utf-8")
spec.write_text("workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\nval_file: v\n", encoding="utf-8")
from argus.cli.run import main

View File

@@ -12,7 +12,7 @@ def test_db_lifecycle_and_basic_queries(tmp_path: Path):
t = db.create_task(
task_id="t1",
workload="ppo",
jobspec_yaml="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n",
jobspec_yaml="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\nval_file: v\n",
nnodes=2,
n_gpus_per_node=4,
)

View File

@@ -67,13 +67,14 @@ def test_jobspec_validation_and_null_coercion():
"code_path": "/code",
"model_id": "m",
"train_file": "train.jsonl",
"val_file": "null",
"val_file": "val.jsonl",
"test_freq": "",
}
)
assert spec.workload == "ppo"
assert spec.val_file is None
assert spec.val_file == "val.jsonl"
assert spec.test_freq is None
assert spec.total_training_steps is None
assert spec.nnodes == 2
assert spec.n_gpus_per_node == 4
@@ -91,6 +92,7 @@ def test_jobspec_sft_adds_trainer_device_default():
"code_path": "/code",
"model_id": "m",
"train_file": "train.jsonl",
"val_file": "val.jsonl",
}
)
pub = spec.to_public_dict()
@@ -102,6 +104,5 @@ def test_jobspec_unsupported_workload():
with pytest.raises(ValueError, match="unsupported workload"):
JobSpec.from_dict(
{"workload": "nope", "code_path": "x", "model_id": "m", "train_file": "t"}
{"workload": "nope", "code_path": "x", "model_id": "m", "train_file": "t", "val_file": "v"}
)

View File

@@ -46,7 +46,7 @@ def test_runtime_env_sets_defaults_and_pythonpath(monkeypatch):
}
)
spec = JobSpec.from_dict(
{"workload": "sft", "code_path": "/c", "model_id": "m", "train_file": "t"}
{"workload": "sft", "code_path": "/c", "model_id": "m", "train_file": "t", "val_file": "v"}
)
monkeypatch.setenv("MVP_TOOL_CODE_PATH", "/tool")
@@ -106,6 +106,7 @@ def test_submit_writes_artifacts_and_returns_submission_id(tmp_path: Path, monke
"code_path": "/code",
"model_id": "m",
"train_file": "train.jsonl",
"val_file": "val.jsonl",
}
)
@@ -151,7 +152,7 @@ def test_submit_error_writes_file_then_reraises(tmp_path: Path, monkeypatch):
}
)
spec = JobSpec.from_dict(
{"workload": "ppo", "submission_id": "sid2", "code_path": "/code", "model_id": "m", "train_file": "t"}
{"workload": "ppo", "submission_id": "sid2", "code_path": "/code", "model_id": "m", "train_file": "t", "val_file": "v"}
)
tool = mod.RayJobTool(cfg)

View File

@@ -36,7 +36,15 @@ def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
task_id="t1",
user_id="alice",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
jobspec_yaml=yaml.safe_dump(
{
"workload": "ppo",
"code_path": "/private/common/code/verl",
"model_id": "m",
"train_file": "/private/common/datasets/t",
"val_file": "/private/common/datasets/v",
}
),
nnodes=2,
n_gpus_per_node=4,
)
@@ -87,7 +95,15 @@ def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path):
task_id="t1",
user_id="alice",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
jobspec_yaml=yaml.safe_dump(
{
"workload": "ppo",
"code_path": "/private/common/code/verl",
"model_id": "m",
"train_file": "/private/common/datasets/t",
"val_file": "/private/common/datasets/v",
}
),
nnodes=2,
n_gpus_per_node=4,
)
@@ -117,7 +133,15 @@ def test_sync_failed_insufficient_resources(monkeypatch, tmp_path: Path):
task_id="t1",
user_id="alice",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
jobspec_yaml=yaml.safe_dump(
{
"workload": "ppo",
"code_path": "/private/common/code/verl",
"model_id": "m",
"train_file": "/private/common/datasets/t",
"val_file": "/private/common/datasets/v",
}
),
nnodes=2,
n_gpus_per_node=4,
)

View File

@@ -76,3 +76,5 @@ def test_ui_task_detail_shows_ids(tmp_path, monkeypatch):
assert r.status_code == 200
assert task_id in r.text
assert f"/ui/tasks/{task_id}/logs" in r.text
assert "TaskSpec (YAML)" in r.text
assert "/api/v2/tasks/" in r.text

View File

@@ -132,7 +132,7 @@ def test_tasks_are_isolated_by_user(tmp_path: Path, monkeypatch):
r1 = c.post(
"/api/v2/tasks",
headers=alice_headers,
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\nval_file: /private/common/datasets/v\n",
)
assert r1.status_code == 200
alice_tid = r1.json()["task_id"]
@@ -140,7 +140,7 @@ def test_tasks_are_isolated_by_user(tmp_path: Path, monkeypatch):
r2 = c.post(
"/api/v2/tasks",
headers=bob_headers,
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\nval_file: /private/common/datasets/v\n",
)
assert r2.status_code == 200
bob_tid = r2.json()["task_id"]
@@ -186,7 +186,7 @@ def test_submit_rejects_non_common_inputs(tmp_path: Path, monkeypatch):
r = c.post(
"/api/v2/tasks",
headers=alice_headers,
data="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
data="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: /private/common/datasets/t\nval_file: /private/common/datasets/v\n",
)
assert r.status_code == 400
assert "code_path must start with /private/common/" in r.text
@@ -223,6 +223,7 @@ def test_submit_accepts_user_dataset_paths_and_local_model_paths(tmp_path: Path,
"code_path: /private/common/code/verl\n"
"model_id: Qwen/Qwen2.5-0.5B-Instruct\n"
"train_file: /private/users/alice/datasets/t\n"
"val_file: /private/users/alice/datasets/v\n"
),
)
assert r1.status_code == 200
@@ -236,6 +237,7 @@ def test_submit_accepts_user_dataset_paths_and_local_model_paths(tmp_path: Path,
"code_path: /private/common/code/verl\n"
"model_id: /private/users/alice/models/m1\n"
"train_file: /private/common/datasets/t\n"
"val_file: /private/common/datasets/v\n"
),
)
assert r2.status_code == 200
@@ -274,6 +276,7 @@ def test_submit_rejects_cross_user_paths_and_bad_local_model_dirs(tmp_path: Path
"code_path: /private/common/code/verl\n"
"model_id: Qwen/Qwen2.5-0.5B-Instruct\n"
"train_file: /private/users/alice/datasets/t\n"
"val_file: /private/users/alice/datasets/v\n"
),
)
assert r1.status_code == 400
@@ -288,6 +291,7 @@ def test_submit_rejects_cross_user_paths_and_bad_local_model_dirs(tmp_path: Path
"code_path: /private/common/code/verl\n"
"model_id: /private/users/bob/jobs/j1/checkpoints\n"
"train_file: /private/common/datasets/t\n"
"val_file: /private/common/datasets/v\n"
),
)
assert r2.status_code == 400