v2.5 functional tests passed; Docker image still to be added

This commit is contained in:
yuyr 2025-12-29 14:33:05 +08:00
parent ce8c2128b4
commit f02059126e
30 changed files with 1794 additions and 66 deletions

View File

@ -11,4 +11,4 @@ Core changes in v2.5:
- `specs/mvp/v2.5/v2.5_design.md`: overall architecture and key mechanisms (head IP file / watchdog / user isolation / task flow).
- `specs/mvp/v2.5/v2.5_api.md`: API design (users, tasks, queue, logs) and auth conventions.
- `specs/mvp/v2.5/v2.5_acceptance.md`: development/deployment/acceptance process and verifiable criteria.
- `specs/mvp/v2.5/v2.5_summary.md`: summary of what v2.5 delivers (what this iteration did, acceptance results, known limitations).

View File

@ -0,0 +1,3 @@
# Open issues
1. Include the user name in task and submission IDs
2. Complete the end-to-end test cases: normal, error, and boundary scenarios

View File

@ -0,0 +1,229 @@
# MVP v2.5 Development Plan (TDD-driven)
This is the **engineering development plan** for v2.5. It emphasizes "tests before implementation" (TDD) and splits each milestone into small, **independently verifiable** loops.
Inputs:
- Roadmap: `specs/mvp/mvp_roadmap_v2.md`
- v2.5 design: `specs/mvp/v2.5/v2.5_design.md`
- v2.5 API draft: `specs/mvp/v2.5/v2.5_api.md`
- v2.5 acceptance: `specs/mvp/v2.5/v2.5_acceptance.md`
v2.5 constraints (confirmed):
- **No TaskSpec extensions**: keep the v2.0/v2.1 structured YAML fields and semantics.
- **No custom reward functions / no user-provided verl code.**
- Training inputs (verl code, HF cache, datasets) all come from `/private/common/...`.
- Multi-user isolation in v2.5 **covers only the jobs output directory**: `/private/users/<uid>/jobs/<ray_submission_id>/...`.
---
## 0. TDD Conventions (apply to all features)
### 0.1 Test layers
1) **Unit tests (fast)**
- Pure Python logic: DB, auth, IDs, path derivation, head.json parsing/TTL, watchdog decision logic.
- Goal: no dependency on a real Ray cluster, Docker, or the network.
2) **Component tests (medium)**
- FastAPI routes: use `fastapi.testclient.TestClient` (already adopted in v2.0).
- Goal: verify auth/permission isolation, API behavior, and the state machine.
3) **End-to-end (slow; manual or scripted)**
- On `argus@h1`, run one "head publish → worker auto-connect → API submit" loop via scripts/compose.
- Goal: verify the real behavior of stateless workers + the watchdog.
### 0.2 Test conventions
- Test directory: `src/mvp/py/tests/`
- New features must first add test cases that fail while unimplemented (red).
- Implement the minimal change that turns the tests green (green).
- Refactor / remove duplication (refactor).
> Note: existing tests inject a ray stub via `src/mvp/py/tests/conftest.py` so unit tests do not depend on the real ray package; new v2.5 modules should reuse this pattern (a sketch of the pattern follows).
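A minimal sketch of the stub-injection idea, assuming the stub is registered in `sys.modules` before any `argus` module imports `ray`; the attribute names below are illustrative and may differ from the actual `conftest.py`:
```python
# tests/conftest.py — illustrative sketch only, not the real file contents.
import sys
import types


def _install_ray_stub() -> None:
    """Register a fake 'ray' module so unit tests never import the real package."""
    if "ray" in sys.modules:
        return
    ray_stub = types.ModuleType("ray")
    # Only the attributes the code under test touches need to exist.
    ray_stub.init = lambda *args, **kwargs: None
    ray_stub.available_resources = lambda: {"GPU": 8.0, "worker_node": 200.0}
    sys.modules["ray"] = ray_stub


_install_ray_stub()
```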
---
## 1. Milestones (v2.5 = 4 verifiable loops)
### M1: user/token tables + basic auth (keeps existing internal-token compatibility)
**Goals**
- Persist users/tokens and map auth tokens to users (token → user_id).
- Stay compatible with the existing `Authorization: Bearer <MVP_INTERNAL_TOKEN>` single-tenant mode so v2.0 usage does not break at once:
- v2.5 can initially support two token modes:
- legacy: the `MVP_INTERNAL_TOKEN` environment variable (global, single-tenant);
- user token: tokens issued and stored in the DB (multi-user).
- Admin can create users, issue tokens, and disable users.
**TDD cases (tests first)**
Unit tests:
- `test_user_db_create_disable()`
- A created user is ACTIVE; after disabling, the state becomes DISABLED; duplicate creation returns a conflict or is idempotent (per the final convention).
- `test_token_hashing()`
- When issuing a token, only the hash is stored in the DB, never the plaintext.
API tests (TestClient):
- `test_admin_create_user_and_issue_token()`
- The admin token can create users and issue tokens (the plaintext token is returned only once).
- `test_disabled_user_token_rejected()`
- After a user is disabled, API calls with the old token return 401/403.
**Implementation targets (suggested modules)**
- `argus.service.auth`: token validation and user_id resolution (legacy-compatible)
- `argus.service.db`: new `users` and `api_tokens` tables with CRUD
- `argus.service.app`: new user-management endpoints (admin scope)
- `configs/dev.yaml`: add admin token/env configuration (keep the YAML style)
**Acceptance**
- `v2.5_acceptance.md` U1 can be covered by automated API tests.
---
### M2: bind tasks to user_id + API visibility isolation (TaskSpec still unchanged)
**Goals**
- On task submission, derive `user_id` from the token and write it to `tasks.user_id`.
- Task query/cancel/logs are owner-only by default; other users get 404 (no existence leak).
- The queue returns only the current user's tasks by default; admin may query the global queue (optional).
**TDD cases (tests first)**
Unit tests:
- `test_tasks_table_has_user_id()`: creating a task must persist `user_id`, and `list_queue(user_id=...)` returns only that user's tasks.
API tests:
- `test_task_visibility_isolated()`
- User A creates a task; user B querying `/api/v2/tasks/{id}` gets 404.
- User B's cancel/logs requests also return 404.
- `test_queue_isolated()`
- A and B each create a task; `GET /api/v2/queue` shows only their own.
**Implementation targets**
- `argus.service.app`: add user scoping to the task endpoints
- `argus.service.db`: add a `user_id` column to `tasks`, an index, and user-filtered query methods
- `argus.service.scheduler`: `pick_next_runnable_task` etc. stay either "global FIFO" or "per-user FIFO":
- v2.5 keeps global FIFO (simplest), while the API queue view is filtered per user.
**Acceptance**
- `v2.5_acceptance.md` U2 can be covered by API tests.
---
### M3: per-user isolation of the jobs output directory (outputs only, inputs unchanged)
**Goals**
- The Ray job's job_root is derived uniformly by the service as:
- `/private/users/<uid>/jobs/<ray_submission_id>/...`
- Input-related path fields in the TaskSpec must be under `/private/common/...` (v2.5 inputs are common-only).
- No user can direct outputs outside their own jobs directory via the TaskSpec (prevents unauthorized writes).
**TDD cases (tests first)**
Unit tests:
- `test_job_root_derivation_per_user()`
- Given a user_id and ray_submission_id, the derived job_root is deterministic and correct.
- `test_reject_non_common_inputs()`
- If train_file / val_file / code_path / HF paths in the TaskSpec do not start with `/private/common/`, reject with HTTP 400.
API tests:
- `test_job_dir_written_under_user_jobs()`
- After submitting a task, the DB or submit payload shows a job_root under the user's jobs directory (capture the spec by mocking RayJobTool.submit).
**Implementation targets (minimal intrusion)**
- Derive `job_root` in the service layer and inject it into RayJobTool/builders (instead of letting users set it via the TaskSpec).
- Change RayJobTool `_job_dir()` to accept a job_root generator, or a `job_root` argument supplied by the service layer.
- Goal: keep RayJobTool's responsibility clear: it submits Ray jobs; path policy is decided by the service.
**Acceptance**
- `v2.5_acceptance.md` U3/U4 can be covered by API/unit tests.
---
### M4: Stateless Ray Node Pool (head.json + worker watchdog) + end-to-end script verification
**Goals**
- After startup, the head keeps writing `/private/ray/discovery/<cluster_name>/head.json` (with a TTL); an example record is sketched after this list.
- Worker containers run a watchdog (or a startup script + watchdog), so the platform never passes the head address explicitly:
- read head.json (present and not expired) → `ray start --address=<head_ip>:<gcs_port>`
- head.json changed → `ray stop` + `ray start` to reconnect
- In the dev environment (docker compose), provide a one-shot script that reproduces the e2e flow.
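For orientation, a sketch of the record the publisher writes; the field names come from `argus.ray.discovery.HeadRecord` in this change, while the concrete values are illustrative:
```python
# Illustrative head.json payload written (atomically) by argus.ray.head_publisher.
head_json_example = {
    "cluster_name": "argus-ray",
    "head_ip": "172.18.0.2",               # example container IP
    "gcs_port": 6379,
    "dashboard_port": 8265,
    "job_server_url": "http://172.18.0.2:8265",
    "updated_at": "2025-12-29T06:00:00Z",
    "expires_at": "2025-12-29T06:01:00Z",  # updated_at + ttl_s (default 60s)
}
```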
**TDD cases (tests first)**
Unit tests (no real ray):
- `test_head_json_read_validate_ttl()`
- File missing/expired → return "unavailable"
- Not expired → return the head address
- `test_watchdog_decision_on_change()`
- head_ip changes → trigger the reconnect action
- Only updated_at changes (same address) → no reconnect (or reconnect per policy; to be decided)
Component/script-level tests (optional):
- If the watchdog is implemented in Python, stub the command-execution layer (never actually run `ray start`) and only assert which commands would be invoked.
End-to-end script (manual/slow):
- Provide a script `scripts/run_all_v25_stateless.sh` (name is an example):
1) Start the head (Ray head + API)
2) Start the head publisher (writes head.json)
3) Start 2 workers (4 GPUs each); the workers only run the watchdog, with no head address passed
4) `ray status` shows 1 head + 2 workers with GPU=8
5) Create a user / issue a token via the API; submit PPO/GRPO/SFT
6) Restart the head (or point head.json at a new address) and verify the workers reconnect automatically
**Implementation targets (suggested strategy)**
For testability (TDD), implement "read head.json / check the TTL / build the ray start command" as Python modules:
- `argus.ray.discovery`: read/write head.json (atomic writes, TTL)
- `argus.ray.worker_watchdog`: watch loop (polling + change detection); command execution is injectable so unit tests can stub it
Keep the script layer thin:
- `scripts/` handles docker exec / compose orchestration and process supervision;
- the watchdog process runs as a Python module inside the container (more testable, and easier to port to a production platform's entrypoint/command).
**Acceptance**
- `v2.5_acceptance.md` A1/A2/A3 are verified mainly via the e2e script plus the dashboard/logs.
---
## 2. Regression Strategy (keep v2.0 working)
Keep and continuously re-run the following cases during v2.5 (at least unit-test coverage):
- The legacy internal-token mode can still access `GET /api/v2/queue` and submit tasks (if compatibility is kept).
- The scheduler's "insufficient resources → PENDING_RESOURCES → delayed retry" behavior is unchanged (covered by the existing `test_scheduler.py`).
- `ray entrypoint_resources` still forces the driver onto a worker (keeps using the `worker_node` custom resource).
---
## 3. Deliverables (code/scripts/docs)
### 3.1 Code
- users/tokens: DB schema + auth + API endpoints
- tasks: bind user_id + permission isolation
- job_root: derive the per-user jobs output directory (inputs stay in common)
- discovery/watchdog: head.json + worker self-healing
### 3.2 Scripts (dev e2e)
- head: start the Ray head + head publisher
- workers: start statelessly (no head addr passed) + watchdog
- `run_all`: one-shot end-to-end run (API submit + query + cancel + queue inspection)
### 3.3 Docs
- Update `specs/mvp/v2.5/*` (design/API/acceptance/dev plan)
- Add v2.5 usage to `src/mvp/README.md` (if needed)
---
## 4. Open Decisions (must be settled before implementation starts)
1) **Keep legacy token compatibility?**
- Option A: keep `MVP_INTERNAL_TOKEN` (single-tenant) + add user tokens (multi-tenant)
- Option B: switch v2.5 entirely to user tokens (breaks compatibility but is cleaner)
2) **Scheduling fairness**
- v2.5 keeps global FIFO (simple); v3 can introduce per-user fair scheduling/quotas.
3) **Who writes head.json in production**
- Option A: a thread in the API process (fewest components)
- Option B: a separate process (more independent, easier to operate)

View File

@ -0,0 +1,132 @@
# MVP v2.5 End-to-End Test Cases (normal / error / boundary)
Goal of this suite: cover the key v2.5 capabilities and their boundary conditions (users + jobs isolation + stateless node pool + API queue scheduling).
Constraints (confirmed for v2.5):
- No TaskSpec extensions; no reward_fn; no user-provided verl code.
- Inputs come only from `/private/common/...`; user isolation covers only the `/private/users/<uid>/jobs/...` outputs for now.
---
## 0. Environment Prerequisites
Remote directory (example):
- `argus@h1:/home2/argus/infra/mvp/src/mvp/`
Shared directory (host):
- `/home2/argus/infra/mvp/shared/`
In-container path convention:
- `/private` is the shared-storage mount point
Requirements:
- GPUs 0-7 available
- 3 containers: head (no GPU) + 2 workers (4 GPUs each)
---
## 1. Normal Cases (Happy Path)
### HP-1: v2.5 full chain (PPO → GRPO → SFT, serial)
Steps:
1) `cd /home2/argus/infra/mvp/src/mvp/scripts`
2) `MVP_INTERNAL_TOKEN=<admin_token> RESET_DB=1 ./run_all_v25_api.sh`
Expected:
- The Ray dashboard shows 3 nodes (head + 2 workers) with 8 GPUs in total.
- All 3 tasks end in `SUCCEEDED`.
- Output directories exist and are isolated per user:
- `/private/users/<uid>/jobs/<ray_submission_id>/{config,logs,checkpoints,debug}`
### HP-2: the driver does not run on the head
Verification (either option):
- The Ray job's driver node IP differs from the head container IP.
- Or logs/scheduling info show entrypoint_resources took effect (driver on a worker).
---
## 2. Error Cases
### E-Auth-1: missing token
Request:
- `GET /api/v2/queue` without `Authorization`
Expected:
- 401 (missing bearer token)
### E-Auth-2: invalid token
Request:
- `Authorization: Bearer <random>`
Expected:
- 401 (invalid token)
### E-Auth-3: access denied after a user is disabled
Steps:
1) Admin creates user `bob` and issues a token
2) Admin disables `bob`
3) Call `/api/v2/queue` with bob's token
Expected:
- 403 (user disabled)
### E-Isolation-1: cross-user access to task resources (no existence leak)
Steps:
1) alice submits a task and gets a `task_id`
2) bob queries `/api/v2/tasks/{task_id}`
Expected:
- 404 (task not found)
### E-Input-1: input paths outside /private/common (v2.5 constraint)
Request:
- Submit a taskspec whose `train_file` or `code_path` does not start with `/private/common/`
Expected:
- 400, with a field-specific error (e.g. `train_file must start with /private/common/`); a Python sketch of this kind of check follows.
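A minimal sketch of how E-Auth-1 and E-Input-1 could be exercised from Python; the API address and the user token are assumptions, and the shipped coverage lives in `run_e2e_v25_cases.sh`:
```python
# Illustrative checks for E-Auth-1 and E-Input-1 (API_ADDR / USER_TOKEN are placeholders).
import requests

API_ADDR = "http://127.0.0.1:8080"
USER_TOKEN = "<token issued by the admin>"  # placeholder

# E-Auth-1: no Authorization header -> 401
r = requests.get(f"{API_ADDR}/api/v2/queue", timeout=10)
assert r.status_code == 401

# E-Input-1: code_path outside /private/common/ -> 400 with a field-specific message
bad_spec = (
    "workload: ppo\n"
    "code_path: /tmp/my_verl\n"
    "model_id: m\n"
    "train_file: /private/common/datasets/gsm8k/train.parquet\n"
)
r = requests.post(
    f"{API_ADDR}/api/v2/tasks",
    headers={"Authorization": f"Bearer {USER_TOKEN}", "Content-Type": "application/yaml"},
    data=bad_spec,
    timeout=10,
)
assert r.status_code == 400
assert "code_path must start with /private/common/" in r.text
```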
---
## 3. Boundary Cases
### B-Queue-1: insufficient resources → no Ray submit (PENDING_RESOURCES)
Steps:
1) Build a task that requires `nnodes=3`, `n_gpus_per_node=4` (12 GPUs in total)
2) Submit it and poll the status
Expected:
- The task enters `PENDING_RESOURCES` (pending on the service side; not submitted to Ray)
- It has a `next_run_at` set
### B-Cancel-1: task cancellation (QUEUED/RUNNING)
Steps:
1) Submit a task with enough steps that it has a chance to reach RUNNING
2) Call `POST /api/v2/tasks/{task_id}:cancel`
Expected:
- The task state is `CANCELED`
- The attempt's `ray_status` ends as `STOPPED` (or the job is stopped on the Ray side)
---
## 4. Executable Regression Script
See:
- `src/mvp/scripts/run_e2e_v25_cases.sh`
The script covers:
- HP-1
- E-Auth-1 / E-Auth-2 / E-Input-1
- E-Isolation-1
- B-Queue-1
- B-Cancel-1 (best-effort)

View File

@ -0,0 +1,92 @@
# MVP v2.5 Iteration Summary (delivered)
This document summarizes the capabilities, implementation details, acceptance process, and known limitations that v2.5 adds on top of v2.0/v2.1/v2.2…, for retrospectives and alignment of later iterations.
## Goals and Scope
Core goals of v2.5:
- Introduce **User Management**: token-based auth and task-level isolation ("jobs-only" isolation).
- Introduce a **Stateless Ray Node Pool**: workers do not depend on the platform to hand out the head address; they discover it, connect, and self-heal automatically.
- Keep the **TaskSpec (same YAML format as v1.1) unextended**: this iteration does not support reward functions, custom verl code, etc.
Explicitly out of scope (v2.5 constraints):
- No TaskSpec extensions (e.g. `reward_fn_path`).
- No per-user isolation or custom paths for verl/HF/datasets: **always use the shared resources under `/private/common/...`**.
- User isolation covers only **tasks and artifact directories** (jobs), not shared caches such as the HF cache or datasets.
## Key Capabilities (externally visible)
### 1) Multi-user auth and task isolation
- The API still uses the internal `Authorization: Bearer <token>` scheme:
- the admin token comes from the `MVP_INTERNAL_TOKEN` environment variable (admin);
- regular user tokens are issued by the admin via the API and persisted to SQLite.
- Isolation policy:
- Non-admin users can query/cancel/fetch logs only for **their own tasks**; cross-user access returns 404 (no existence leak).
- Training artifacts are isolated on disk: Ray job directories are written under `/private/users/<user_id>/jobs/<ray_submission_id>/...`. A sketch of the admin → user-token flow follows.
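A minimal sketch of the admin → user-token flow against the new endpoints; the API address and admin token are assumptions (in the dev scripts the admin token is the `MVP_INTERNAL_TOKEN` value):
```python
# Sketch only: create a user, issue a token, and query the per-user queue.
import requests

API_ADDR = "http://127.0.0.1:8080"   # assumed dev address
ADMIN_TOKEN = "mvp-dev-token"        # assumed admin token (MVP_INTERNAL_TOKEN)
admin = {"Authorization": f"Bearer {ADMIN_TOKEN}"}

# Admin creates a user and issues a token (plaintext token is returned only once).
requests.post(f"{API_ADDR}/api/v2/users", headers=admin, json={"user_id": "alice"}, timeout=10)
token = requests.post(f"{API_ADDR}/api/v2/users/alice/tokens", headers=admin, timeout=10).json()["token"]

# The user token sees only alice's queue; the admin token sees the global queue.
queue = requests.get(
    f"{API_ADDR}/api/v2/queue",
    headers={"Authorization": f"Bearer {token}"},
    timeout=10,
).json()
print(queue["pending"], queue["running"])
```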
### 2) task_id / submission_id include the user name
- New task ID rule: `mvp2-<user>-<workload>-<YYYYMMDD-HHMMSS>-<suffix>`
- Ray submission ID (per attempt) rule: `<task_id>--aNN`, so it naturally carries the user name.
- Benefit: dashboard entries, logs, and output directories are more readable and easier to trace and audit per user; a short usage sketch follows.
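A usage sketch of the ID helpers added in this change; the timestamp and random suffix in the comments are illustrative:
```python
# Sketch: per-user task IDs and per-attempt Ray submission IDs.
from argus.core.ids import attempt_submission_id, new_task_id

task_id = new_task_id("ppo", user_id="Alice_01")
# e.g. "mvp2-alice_01-ppo-20251229-143305-ab12" (user part is lowercased/sanitized)
ray_sid = attempt_submission_id(task_id, 1)
# e.g. "mvp2-alice_01-ppo-20251229-143305-ab12--a01"
```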
### 3) "Stateless worker pool" and head-address discovery
- The head writes a **head-address file** (e.g. `head.json`) to shared storage; workers, via the watchdog:
- poll to discover the head address,
- join the cluster automatically with `ray start --address ...`,
- and reconnect automatically after a drop (watchdog self-healing).
- Result: in production, even when worker containers are created by the compute platform (SSH-managed only), they can still connect and self-heal through shared storage.
### 4) Task scheduling: queue + Ray job submission + status sync
- Tasks submitted via the API enter a SQLite queue; the background scheduler submits them to Ray one at a time (default `max_running_tasks=1`).
- The scheduler keeps polling Ray job status and writes the task state back (RUNNING/SUCCEEDED/FAILED/CANCELED).
- Retryable handling of insufficient resources:
- on VERL's fail-fast (`Total available GPUs ... is less than total desired GPUs ...`) or a cluster resource shortage,
the task enters `PENDING_RESOURCES`, gets a `next_run_at`, and retries on the `retry_interval_s` interval.
## Key Implementation Points
### Storage and directory conventions (in-container view)
- The shared root is `/private` (matching the production mount).
- Hard v2.5 constraint: the following TaskSpec fields must start with `/private/common/`:
- `code_path` / `train_file` / `val_file`
- Shared directories (examples):
- `/private/common/hf`: HF cache
- `/private/common/datasets`: training data (symlinked to existing cache dirs when needed, to reuse downloads)
- `/private/common/db/mvp.sqlite3`: queue and user info (SQLite)
- `/private/common/logs`: API / watchdog logs
- `/private/users/<uid>/jobs/...`: per-user job artifacts (isolated)
### Ray topology and "no training on the head"
- The head starts as a management node (CPU/GPU=0) so training tasks never land on it.
- Worker nodes have GPUs (example: 2 workers * 4 GPUs each).
- The driver is forced onto a worker via `entrypoint_resources` (e.g. `worker_node: 1`); a submission sketch follows.
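A minimal sketch of how this placement looks at the Ray Jobs API level; the dashboard address and entrypoint are placeholders, and the real submission path goes through `RayJobTool`:
```python
# Sketch only: the driver is scheduled onto a node advertising the custom
# "worker_node" resource (workers register it via `ray start --resources`;
# the head advertises none of it, so the driver cannot land there).
from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")  # head dashboard address (assumed)
client.submit_job(
    entrypoint="python -m trainer.main",               # placeholder entrypoint
    entrypoint_resources={"worker_node": 1},
)
```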
### Deployment scripts and repeatability
A complete script chain covers:
- cleaning up the legacy environment, starting/stopping containers, starting the Ray head
- starting and checking the head discovery publisher and the worker watchdogs
- preparing data/models/code (idempotent; reuses existing downloads)
- starting the API server (with RESET_DB support)
- submitting PPO/GRPO/SFT via the API in sequence and waiting for completion
Representative scripts:
- `src/mvp/scripts/run_all_v25_api.sh`: v2.5 happy-path end-to-end (rebuilds the cluster, prepares resources, starts the API, submits the 3 workload types)
- `src/mvp/scripts/run_e2e_v25_cases.sh`: adds auth/isolation/input-validation/insufficient-resources/cancel cases on top of the happy path
## Acceptance and Testing (passed)
### Unit tests (local venv)
- `.venv/bin/python -m pytest`
- Coverage threshold: >= 90%
### Remote end-to-end (h1)
- Run in `argus@h1:/home2/argus/infra/mvp/src/mvp/scripts`:
- `MVP_INTERNAL_TOKEN=mvp-dev-token RESET_DB=1 ./run_e2e_v25_cases.sh`
- Result: the happy path (PPO/GRPO/SFT) completes, and the error/boundary cases pass (auth, cross-user isolation, input validation, insufficient resources → PENDING_RESOURCES, task cancellation, etc.).
## Known Issues and Follow-ups
- With `max_running_tasks=1`, queued tasks stay QUEUED while an earlier task is RUNNING; the "insufficient resources" boundary test must explicitly clear/cancel earlier tasks, or this behavior is accepted as part of the design.
- SQLite is still a single point of failure; for HA/horizontal scaling, v2.6+ could introduce stronger persistence and replication (e.g. Postgres/etcd).
- The API server / watchdog are currently supervised by shell scripts; later they can be unified under systemd/supervisor (or platform-side supervision) with health checks and alerting.

View File

@ -15,3 +15,4 @@
Quick start:
- CLI submission flow: `scripts/run_all_cli.sh`
- API submission flow: `scripts/run_all_api.sh`
- v2.5 (stateless workers + user-isolated jobs) E2E: `scripts/run_all_v25_api.sh`

View File

@ -2,14 +2,27 @@ from __future__ import annotations
import secrets
from datetime import datetime
import re
def new_task_id(workload: str) -> str:
_USER_SAFE_RE = re.compile(r"[^a-z0-9_-]+")
def _normalize_user_id(user_id: str) -> str:
s = (user_id or "").strip().lower()
s = _USER_SAFE_RE.sub("-", s)
s = re.sub(r"-{2,}", "-", s).strip("-")
return s[:24] if s else ""
def new_task_id(workload: str, *, user_id: str | None = None) -> str:
ts = datetime.now().strftime("%Y%m%d-%H%M%S")
suffix = secrets.token_hex(2)
u = _normalize_user_id(user_id or "")
if u:
return f"mvp2-{u}-{workload}-{ts}-{suffix}"
return f"mvp2-{workload}-{ts}-{suffix}"
def attempt_submission_id(task_id: str, attempt_no: int) -> str:
return f"{task_id}--a{attempt_no:02d}"

View File

@ -0,0 +1,115 @@
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any
def _utc_now() -> datetime:
return datetime.now(timezone.utc)
def _parse_utc(ts: str) -> datetime:
# Accept ISO8601 with trailing Z
s = ts.strip()
if s.endswith("Z"):
s = s[:-1] + "+00:00"
return datetime.fromisoformat(s).astimezone(timezone.utc)
def _iso_z(dt: datetime) -> str:
return dt.astimezone(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
@dataclass(frozen=True)
class HeadRecord:
cluster_name: str
head_ip: str
gcs_port: int
dashboard_port: int
job_server_url: str
updated_at: str
expires_at: str
@staticmethod
def from_dict(d: dict[str, Any]) -> "HeadRecord":
return HeadRecord(
cluster_name=str(d["cluster_name"]),
head_ip=str(d["head_ip"]),
gcs_port=int(d.get("gcs_port", 6379)),
dashboard_port=int(d.get("dashboard_port", 8265)),
job_server_url=str(d.get("job_server_url") or f"http://{d['head_ip']}:{int(d.get('dashboard_port', 8265))}"),
updated_at=str(d["updated_at"]),
expires_at=str(d["expires_at"]),
)
def to_dict(self) -> dict[str, Any]:
return {
"cluster_name": self.cluster_name,
"head_ip": self.head_ip,
"gcs_port": self.gcs_port,
"dashboard_port": self.dashboard_port,
"job_server_url": self.job_server_url,
"updated_at": self.updated_at,
"expires_at": self.expires_at,
}
def head_addr(self) -> str:
return f"{self.head_ip}:{self.gcs_port}"
def build_head_record(
*,
cluster_name: str,
head_ip: str,
gcs_port: int = 6379,
dashboard_port: int = 8265,
ttl_s: int = 60,
now: datetime | None = None,
) -> HeadRecord:
now_dt = now or _utc_now()
expires = now_dt + timedelta(seconds=int(ttl_s))
updated_at = _iso_z(now_dt)
expires_at = _iso_z(expires)
return HeadRecord(
cluster_name=cluster_name,
head_ip=head_ip,
gcs_port=int(gcs_port),
dashboard_port=int(dashboard_port),
job_server_url=f"http://{head_ip}:{int(dashboard_port)}",
updated_at=updated_at,
expires_at=expires_at,
)
def load_head_record(path: str, *, now: datetime | None = None) -> HeadRecord | None:
p = Path(path)
if not p.exists():
return None
try:
obj = json.loads(p.read_text(encoding="utf-8"))
except Exception:
return None
if not isinstance(obj, dict):
return None
try:
rec = HeadRecord.from_dict(obj)
expires = _parse_utc(rec.expires_at)
except Exception:
return None
now_dt = now or _utc_now()
if expires <= now_dt:
return None
return rec
def write_head_record_atomic(path: str, rec: HeadRecord) -> None:
p = Path(path)
os.makedirs(p.parent, exist_ok=True)
tmp = p.with_suffix(p.suffix + ".tmp")
tmp.write_text(json.dumps(rec.to_dict(), indent=2, sort_keys=True) + "\n", encoding="utf-8")
tmp.replace(p)

View File

@ -0,0 +1,65 @@
from __future__ import annotations
import argparse
import time
from .discovery import build_head_record, write_head_record_atomic
def publish_once(
*,
cluster_name: str,
head_ip_file: str,
head_ip: str,
gcs_port: int,
dashboard_port: int,
ttl_s: int,
) -> None:
rec = build_head_record(
cluster_name=cluster_name,
head_ip=head_ip,
gcs_port=gcs_port,
dashboard_port=dashboard_port,
ttl_s=ttl_s,
)
write_head_record_atomic(head_ip_file, rec)
def main(argv: list[str] | None = None) -> int:
ap = argparse.ArgumentParser(description="Publish Ray head address to shared storage (head.json).")
ap.add_argument("--cluster-name", required=True)
ap.add_argument("--head-ip-file", required=True)
ap.add_argument("--head-ip", required=True)
ap.add_argument("--gcs-port", type=int, default=6379)
ap.add_argument("--dashboard-port", type=int, default=8265)
ap.add_argument("--ttl-s", type=int, default=60)
ap.add_argument("--refresh-s", type=int, default=10)
ap.add_argument("--once", action="store_true", help="Write once then exit (for testing/debug).")
args = ap.parse_args(argv)
if args.once:
publish_once(
cluster_name=args.cluster_name,
head_ip_file=args.head_ip_file,
head_ip=args.head_ip,
gcs_port=args.gcs_port,
dashboard_port=args.dashboard_port,
ttl_s=args.ttl_s,
)
return 0
refresh_s = max(1, int(args.refresh_s))
while True:
publish_once(
cluster_name=args.cluster_name,
head_ip_file=args.head_ip_file,
head_ip=args.head_ip,
gcs_port=args.gcs_port,
dashboard_port=args.dashboard_port,
ttl_s=args.ttl_s,
)
time.sleep(refresh_s)
if __name__ == "__main__":
raise SystemExit(main())

View File

@ -73,9 +73,9 @@ class RayJobTool:
return {"env_vars": env_vars}
def submit(self, spec: JobSpec, no_wait: bool) -> str:
def submit(self, spec: JobSpec, no_wait: bool, job_dir: str | None = None) -> str:
submission_id = spec.submission_id or f"mvp11_{spec.workload}_{_ts()}_{os.getpid()}"
job_dir = self._job_dir(submission_id)
job_dir = job_dir or self._job_dir(submission_id)
built = build_training_argv(spec, submission_id=submission_id, job_dir=job_dir)
entrypoint_argv = [

View File

@ -0,0 +1,110 @@
from __future__ import annotations
import argparse
import json
import subprocess
import time
from dataclasses import dataclass
from typing import Callable
from .discovery import HeadRecord, load_head_record
Runner = Callable[[list[str]], int]
def _default_runner(argv: list[str]) -> int:
proc = subprocess.run(argv, check=False)
return int(proc.returncode)
def _ray_stop_cmd() -> list[str]:
return ["ray", "stop", "--force"]
def _ray_start_cmd(*, head_addr: str, node_ip: str | None, resources_json: str) -> list[str]:
argv = ["ray", "start", f"--address={head_addr}", f"--resources={resources_json}", "--disable-usage-stats"]
if node_ip:
argv.append(f"--node-ip-address={node_ip}")
return argv
@dataclass
class Watchdog:
head_ip_file: str
node_ip: str | None
resources_json: str
poll_s: int
runner: Runner
_current_head_addr: str | None = None
def _restart_ray(self, desired: str) -> None:
# Best-effort stop then start.
self.runner(_ray_stop_cmd())
self.runner(_ray_start_cmd(head_addr=desired, node_ip=self.node_ip, resources_json=self.resources_json))
self._current_head_addr = desired
def tick_once(self) -> HeadRecord | None:
rec = load_head_record(self.head_ip_file)
if rec is None:
return None
desired = rec.head_addr()
if self._current_head_addr != desired:
self._restart_ray(desired)
return rec
def run_forever(self) -> None:
while True:
try:
self.tick_once()
except Exception:
# keep watchdog alive
pass
time.sleep(max(1, int(self.poll_s)))
def main(argv: list[str] | None = None) -> int:
ap = argparse.ArgumentParser(description="Stateless Ray worker watchdog (auto-connect + reconnect).")
ap.add_argument("--head-ip-file", required=True)
ap.add_argument("--node-ip", default=None)
ap.add_argument(
"--resources-kv",
action="append",
default=[],
help='Repeatable k=v resources, e.g. --resources-kv worker_node=100',
)
ap.add_argument("--poll-s", type=int, default=5)
ap.add_argument("--once", action="store_true", help="Run one tick then exit (for testing/debug).")
args = ap.parse_args(argv)
resources: dict[str, float] = {"worker_node": 100.0}
if args.resources_kv:
resources = {}
for item in args.resources_kv:
if "=" not in item:
raise SystemExit(f"invalid --resources-kv (expected k=v): {item!r}")
k, v = item.split("=", 1)
k = k.strip()
v = v.strip()
if not k:
raise SystemExit(f"invalid --resources-kv key: {item!r}")
resources[k] = float(v)
wd = Watchdog(
head_ip_file=args.head_ip_file,
node_ip=args.node_ip,
resources_json=json.dumps(resources),
poll_s=int(args.poll_s),
runner=_default_runner,
)
if args.once:
wd.tick_once()
return 0
wd.run_forever()
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@ -43,19 +43,32 @@ def create_app(config_path: str) -> FastAPI:
app = FastAPI(title="mvp-v2", version="2.0")
def _require_token(req: Request) -> None:
def _auth(req: Request) -> dict[str, Any]:
token_env = v2_cfg.auth.token_env
expected = os.environ.get(token_env, "")
if not expected:
# Misconfigured service; treat as server error.
admin_token = os.environ.get(token_env, "")
if not admin_token:
raise HTTPException(status_code=500, detail=f"missing token env: {token_env}")
auth = req.headers.get("authorization") or ""
if not auth.startswith("Bearer "):
raise HTTPException(status_code=401, detail="missing bearer token")
got = auth.removeprefix("Bearer ").strip()
if got != expected:
if got == admin_token:
return {"user_id": "admin", "is_admin": True}
info = db.resolve_token(got)
if not info:
raise HTTPException(status_code=401, detail="invalid token")
if info.get("state") != "ACTIVE":
raise HTTPException(status_code=403, detail="user disabled")
return {"user_id": info["user_id"], "is_admin": False}
def _require_admin(req: Request) -> dict[str, Any]:
subject = _auth(req)
if not subject.get("is_admin"):
raise HTTPException(status_code=403, detail="admin required")
return subject
@app.on_event("startup")
def _startup() -> None:
@ -66,9 +79,43 @@ def create_app(config_path: str) -> FastAPI:
def _shutdown() -> None:
stop_flag.set()
@app.post("/api/v2/users")
async def create_user(req: Request) -> dict[str, Any]:
_require_admin(req)
obj = await req.json()
if not isinstance(obj, dict):
raise HTTPException(status_code=400, detail="body must be a JSON object")
user_id = str(obj.get("user_id") or "").strip()
if not user_id:
raise HTTPException(status_code=400, detail="missing user_id")
display_name = obj.get("display_name")
try:
row = db.create_user(user_id=user_id, display_name=str(display_name) if display_name is not None else None)
except Exception as e:
raise HTTPException(status_code=409, detail=f"user create failed: {e!r}")
return {"user_id": row.get("user_id", user_id), "state": row.get("state", "ACTIVE")}
@app.post("/api/v2/users/{user_id}/tokens")
async def issue_token(user_id: str, req: Request) -> dict[str, Any]:
_require_admin(req)
u = db.get_user(user_id)
if not u:
raise HTTPException(status_code=404, detail="user not found")
token = db.issue_token(user_id=user_id)
return {"user_id": user_id, "token": token}
@app.post("/api/v2/users/{user_id}:disable")
async def disable_user(user_id: str, req: Request) -> dict[str, Any]:
_require_admin(req)
u = db.get_user(user_id)
if not u:
raise HTTPException(status_code=404, detail="user not found")
db.disable_user(user_id=user_id)
return {"user_id": user_id, "state": "DISABLED"}
@app.post("/api/v2/tasks")
async def submit_task(req: Request) -> dict[str, Any]:
_require_token(req)
subject = _auth(req)
body = (await req.body()).decode("utf-8")
obj = yaml.safe_load(body) or {}
if not isinstance(obj, dict):
@ -79,9 +126,18 @@ def create_app(config_path: str) -> FastAPI:
except Exception as e:
raise HTTPException(status_code=400, detail=f"invalid jobspec: {e!r}")
task_id = new_task_id(spec.workload)
db.create_task(
# v2.5 constraint: training inputs must come from /private/common (same in dev and prod).
common_prefix = ray_cfg.shared_root.rstrip("/") + "/common/"
for k, v in (("code_path", spec.code_path), ("train_file", spec.train_file), ("val_file", spec.val_file)):
if v is None:
continue
if not str(v).startswith(common_prefix):
raise HTTPException(status_code=400, detail=f"{k} must start with {common_prefix}")
task_id = new_task_id(spec.workload, user_id=str(subject["user_id"]))
db.create_task_v25(
task_id=task_id,
user_id=str(subject["user_id"]),
workload=spec.workload,
jobspec_yaml=body,
nnodes=spec.nnodes,
@ -91,10 +147,13 @@ def create_app(config_path: str) -> FastAPI:
@app.get("/api/v2/tasks/{task_id}")
async def get_task(task_id: str, req: Request) -> dict[str, Any]:
_require_token(req)
subject = _auth(req)
row = db.get_task(task_id)
if not row:
raise HTTPException(status_code=404, detail="task not found")
if not subject.get("is_admin"):
if str(row.get("user_id") or "") != str(subject["user_id"]):
raise HTTPException(status_code=404, detail="task not found")
attempts = db.list_attempts(task_id)
latest_attempt = attempts[-1] if attempts else None
desired = {
@ -125,18 +184,24 @@ def create_app(config_path: str) -> FastAPI:
@app.get("/api/v2/tasks/{task_id}/attempts")
async def get_attempts(task_id: str, req: Request) -> dict[str, Any]:
_require_token(req)
subject = _auth(req)
row = db.get_task(task_id)
if not row:
raise HTTPException(status_code=404, detail="task not found")
if not subject.get("is_admin"):
if str(row.get("user_id") or "") != str(subject["user_id"]):
raise HTTPException(status_code=404, detail="task not found")
return {"task_id": task_id, "attempts": db.list_attempts(task_id)}
@app.post("/api/v2/tasks/{task_id}:cancel")
async def cancel(task_id: str, req: Request) -> dict[str, Any]:
_require_token(req)
subject = _auth(req)
row = db.get_task(task_id)
if not row:
raise HTTPException(status_code=404, detail="task not found")
if not subject.get("is_admin"):
if str(row.get("user_id") or "") != str(subject["user_id"]):
raise HTTPException(status_code=404, detail="task not found")
state = str(row["state"])
if state in ("SUCCEEDED", "FAILED", "CANCELED"):
@ -165,10 +230,13 @@ def create_app(config_path: str) -> FastAPI:
@app.get("/api/v2/tasks/{task_id}/logs")
async def logs(task_id: str, req: Request, tail: int = 2000, attempt: str = "latest") -> Response:
_require_token(req)
subject = _auth(req)
row = db.get_task(task_id)
if not row:
raise HTTPException(status_code=404, detail="task not found")
if not subject.get("is_admin"):
if str(row.get("user_id") or "") != str(subject["user_id"]):
raise HTTPException(status_code=404, detail="task not found")
attempts = db.list_attempts(task_id)
if not attempts:
raise HTTPException(status_code=404, detail="no attempts yet")
@ -184,7 +252,9 @@ def create_app(config_path: str) -> FastAPI:
@app.get("/api/v2/queue")
async def queue(req: Request) -> dict[str, Any]:
_require_token(req)
return db.list_queue()
subject = _auth(req)
if subject.get("is_admin"):
return db.list_queue()
return db.list_queue(user_id=str(subject["user_id"]))
return app

View File

@ -4,6 +4,8 @@ import os
import sqlite3
from contextlib import contextmanager
from dataclasses import dataclass
import hashlib
import secrets
from typing import Any, Iterator
@ -18,6 +20,10 @@ def _utc_now_iso() -> str:
class Db:
db_path: str
def _hash_token(self, token: str) -> str:
# Internal tokens only; store hash to avoid plaintext at rest.
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def _connect(self) -> sqlite3.Connection:
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
conn = sqlite3.connect(self.db_path, timeout=30, isolation_level=None)
@ -28,6 +34,27 @@ class Db:
def init(self) -> None:
with self._connect() as conn:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS users (
user_id TEXT PRIMARY KEY,
display_name TEXT,
state TEXT NOT NULL,
created_at TEXT NOT NULL
)
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS api_tokens (
token_hash TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
created_at TEXT NOT NULL,
last_used_at TEXT,
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
)
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS tasks (
@ -45,6 +72,12 @@ class Db:
)
"""
)
# Best-effort: add user_id column for forward compatibility (v2.5+).
# v2.0 tables may exist without it; SQLite supports ADD COLUMN.
try:
conn.execute("ALTER TABLE tasks ADD COLUMN user_id TEXT")
except sqlite3.OperationalError:
pass
conn.execute(
"""
CREATE TABLE IF NOT EXISTS attempts (
@ -92,10 +125,10 @@ class Db:
with self.tx() as conn:
conn.execute(
"""
INSERT INTO tasks (task_id, workload, state, jobspec_yaml, nnodes, n_gpus_per_node, created_at, updated_at)
VALUES (?, ?, 'QUEUED', ?, ?, ?, ?, ?)
INSERT INTO tasks (task_id, workload, state, jobspec_yaml, nnodes, n_gpus_per_node, created_at, updated_at, user_id)
VALUES (?, ?, 'QUEUED', ?, ?, ?, ?, ?, ?)
""",
(task_id, workload, jobspec_yaml, nnodes, n_gpus_per_node, now, now),
(task_id, workload, jobspec_yaml, nnodes, n_gpus_per_node, now, now, None),
)
conn.execute(
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (?, ?, 'TASK_CREATED', ?)",
@ -104,6 +137,97 @@ class Db:
row = conn.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
return dict(row) if row else {}
def create_task_v25(
self,
*,
task_id: str,
user_id: str,
workload: str,
jobspec_yaml: str,
nnodes: int,
n_gpus_per_node: int,
) -> dict[str, Any]:
now = _utc_now_iso()
with self.tx() as conn:
conn.execute(
"""
INSERT INTO tasks (task_id, user_id, workload, state, jobspec_yaml, nnodes, n_gpus_per_node, created_at, updated_at)
VALUES (?, ?, ?, 'QUEUED', ?, ?, ?, ?, ?)
""",
(task_id, user_id, workload, jobspec_yaml, nnodes, n_gpus_per_node, now, now),
)
conn.execute(
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (?, ?, 'TASK_CREATED', ?)",
(task_id, now, None),
)
row = conn.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
return dict(row) if row else {}
def create_user(self, *, user_id: str, display_name: str | None = None) -> dict[str, Any]:
now = _utc_now_iso()
with self.tx() as conn:
conn.execute(
"""
INSERT INTO users (user_id, display_name, state, created_at)
VALUES (?, ?, 'ACTIVE', ?)
""",
(user_id, display_name, now),
)
conn.execute(
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (NULL, ?, 'USER_CREATED', ?)",
(now, user_id),
)
row = conn.execute("SELECT * FROM users WHERE user_id = ?", (user_id,)).fetchone()
return dict(row) if row else {}
def disable_user(self, *, user_id: str) -> None:
now = _utc_now_iso()
with self.tx() as conn:
conn.execute("UPDATE users SET state = 'DISABLED' WHERE user_id = ?", (user_id,))
conn.execute(
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (NULL, ?, 'USER_DISABLED', ?)",
(now, user_id),
)
def get_user(self, user_id: str) -> dict[str, Any] | None:
with self._connect() as conn:
row = conn.execute("SELECT * FROM users WHERE user_id = ?", (user_id,)).fetchone()
return dict(row) if row else None
def issue_token(self, *, user_id: str) -> str:
# Returns plaintext token once; stores hash only.
now = _utc_now_iso()
token = f"mvp_u_{user_id}_{secrets.token_urlsafe(18)}"
token_hash = self._hash_token(token)
with self.tx() as conn:
conn.execute(
"INSERT INTO api_tokens (token_hash, user_id, created_at) VALUES (?, ?, ?)",
(token_hash, user_id, now),
)
conn.execute(
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (NULL, ?, 'TOKEN_ISSUED', ?)",
(now, user_id),
)
return token
def resolve_token(self, token: str) -> dict[str, Any] | None:
token_hash = self._hash_token(token)
with self.tx() as conn:
row = conn.execute(
"""
SELECT t.user_id, u.state
FROM api_tokens t
JOIN users u ON u.user_id = t.user_id
WHERE t.token_hash = ?
""",
(token_hash,),
).fetchone()
if not row:
return None
now = _utc_now_iso()
conn.execute("UPDATE api_tokens SET last_used_at = ? WHERE token_hash = ?", (now, token_hash))
return {"user_id": row["user_id"], "state": row["state"]}
def get_task(self, task_id: str) -> dict[str, Any] | None:
with self._connect() as conn:
row = conn.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
@ -116,26 +240,33 @@ class Db:
).fetchall()
return [dict(r) for r in rows]
def list_queue(self) -> dict[str, list[dict[str, Any]]]:
def list_queue(self, *, user_id: str | None = None) -> dict[str, list[dict[str, Any]]]:
with self._connect() as conn:
pending = conn.execute(
"""
SELECT task_id, workload, state, nnodes, n_gpus_per_node, next_run_at, created_at, updated_at
FROM tasks
WHERE state IN ('QUEUED','PENDING_RESOURCES')
ORDER BY created_at ASC
LIMIT 200
"""
).fetchall()
running = conn.execute(
"""
SELECT task_id, workload, state, nnodes, n_gpus_per_node, latest_attempt_no, created_at, updated_at
FROM tasks
WHERE state IN ('SUBMITTING','SUBMITTED','RUNNING')
ORDER BY updated_at ASC
LIMIT 200
"""
).fetchall()
params: list[Any] = []
user_filter_sql = ""
if user_id is not None:
user_filter_sql = " AND user_id = ?"
params = [user_id]
pending_sql = (
"SELECT task_id, workload, state, nnodes, n_gpus_per_node, next_run_at, created_at, updated_at "
"FROM tasks "
"WHERE state IN ('QUEUED','PENDING_RESOURCES')"
f"{user_filter_sql} "
"ORDER BY created_at ASC "
"LIMIT 200"
)
running_sql = (
"SELECT task_id, workload, state, nnodes, n_gpus_per_node, latest_attempt_no, created_at, updated_at "
"FROM tasks "
"WHERE state IN ('SUBMITTING','SUBMITTED','RUNNING')"
f"{user_filter_sql} "
"ORDER BY updated_at ASC "
"LIMIT 200"
)
pending = conn.execute(pending_sql, tuple(params)).fetchall()
running = conn.execute(running_sql, tuple(params)).fetchall()
return {"pending": [dict(r) for r in pending], "running": [dict(r) for r in running]}
def count_running(self) -> int:

View File

@ -37,6 +37,12 @@ class Scheduler:
def __post_init__(self) -> None:
self.tool = RayJobTool(self.ray_cfg)
def _job_dir_for_task(self, *, user_id: str | None, ray_submission_id: str) -> str:
root = self.ray_cfg.shared_root.rstrip("/")
if user_id:
return f"{root}/users/{user_id}/jobs/{ray_submission_id}"
return f"{root}/jobs/{ray_submission_id}"
def _resources_sufficient(self, *, nnodes: int, n_gpus_per_node: int) -> bool:
avail = get_cluster_available()
required = float(nnodes * n_gpus_per_node)
@ -51,10 +57,13 @@ class Scheduler:
def _submit_one(self, task_row: dict[str, Any]) -> None:
task_id = str(task_row["task_id"])
jobspec_yaml = str(task_row["jobspec_yaml"])
user_id = task_row.get("user_id")
user_id_s = str(user_id) if user_id not in (None, "") else None
spec = self._parse_jobspec(jobspec_yaml)
attempt_no = int(task_row.get("latest_attempt_no", 0)) + 1
ray_sid = attempt_submission_id(task_id, attempt_no)
job_dir = self._job_dir_for_task(user_id=user_id_s, ray_submission_id=ray_sid)
# Record attempt first so that we can surface it even if submit crashes.
self.db.create_attempt(task_id=task_id, attempt_no=attempt_no, ray_submission_id=ray_sid)
@ -66,7 +75,7 @@ class Scheduler:
spec2 = JobSpec.from_dict(d)
try:
submitted = self.tool.submit(spec2, no_wait=True)
submitted = self.tool.submit(spec2, no_wait=True, job_dir=job_dir)
# submitted should equal ray_sid; keep as source of truth.
self.db.update_attempt(task_id=task_id, attempt_no=attempt_no, ray_status="SUBMITTED")
self.db.set_task_state(task_id=task_id, state="SUBMITTED")

View File

@ -12,7 +12,7 @@ def _write_config(tmp_path: Path) -> Path:
cfg = {
"ray": {
"address": "http://127.0.0.1:8265",
"shared_root": str(tmp_path),
"shared_root": "/private",
"entrypoint_resources": {"worker_node": 1},
"runtime_env": {"env_vars": {}},
},
@ -54,7 +54,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
cfg_path = _write_config(tmp_path)
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
monkeypatch.setattr(app_mod, "new_task_id", lambda workload: "tid1")
monkeypatch.setattr(app_mod, "new_task_id", lambda workload, **kwargs: "tid1")
class _Tool:
def __init__(self):
@ -82,7 +82,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
r = c.post(
"/api/v2/tasks",
headers=headers,
data="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n",
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
)
assert r.status_code == 200
assert r.json()["task_id"] == "tid1"
@ -111,7 +111,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
db.create_task(
task_id="tid2",
workload="ppo",
jobspec_yaml="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n",
jobspec_yaml="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
nnodes=2,
n_gpus_per_node=4,
)
@ -163,4 +163,3 @@ def test_submit_rejects_invalid_jobspec(tmp_path: Path, monkeypatch):
with TestClient(app) as c:
r = c.post("/api/v2/tasks", headers={"authorization": "Bearer token1"}, data="workload: nope\n")
assert r.status_code == 400

View File

@ -0,0 +1,50 @@
from __future__ import annotations
import json
from datetime import datetime, timedelta, timezone
from pathlib import Path
from argus.ray.discovery import build_head_record, load_head_record
def test_load_head_record_rejects_missing(tmp_path: Path):
assert load_head_record(str(tmp_path / "nope.json")) is None
def test_load_head_record_rejects_expired(tmp_path: Path):
p = tmp_path / "head.json"
now = datetime.now(timezone.utc).replace(microsecond=0)
rec = build_head_record(cluster_name="c", head_ip="1.2.3.4", ttl_s=1, now=now)
p.write_text(json.dumps(rec.to_dict()), encoding="utf-8")
assert load_head_record(str(p), now=now + timedelta(seconds=2)) is None
def test_load_head_record_accepts_fresh(tmp_path: Path):
p = tmp_path / "head.json"
now = datetime.now(timezone.utc).replace(microsecond=0)
rec = build_head_record(cluster_name="c", head_ip="1.2.3.4", ttl_s=60, now=now)
p.write_text(json.dumps(rec.to_dict()), encoding="utf-8")
got = load_head_record(str(p), now=now + timedelta(seconds=1))
assert got is not None
assert got.head_addr() == "1.2.3.4:6379"
def test_head_publisher_main_once_writes_file(tmp_path: Path):
from argus.ray import head_publisher as mod
p = tmp_path / "head.json"
rc = mod.main(
[
"--cluster-name",
"c",
"--head-ip-file",
str(p),
"--head-ip",
"9.9.9.9",
"--ttl-s",
"60",
"--once",
]
)
assert rc == 0
assert "9.9.9.9" in p.read_text(encoding="utf-8")

View File

@ -20,9 +20,27 @@ def test_new_task_id_is_deterministic_with_patches(monkeypatch):
assert ids.new_task_id("ppo") == "mvp2-ppo-20250101-010203-abcd"
def test_new_task_id_includes_user_when_provided(monkeypatch):
import argus.core.ids as ids
class _FakeDatetime:
@staticmethod
def now():
class _DT:
def strftime(self, fmt: str) -> str:
assert fmt == "%Y%m%d-%H%M%S"
return "20250101-010203"
return _DT()
monkeypatch.setattr(ids, "datetime", _FakeDatetime)
monkeypatch.setattr(ids.secrets, "token_hex", lambda n: "abcd")
assert ids.new_task_id("ppo", user_id="Alice_01") == "mvp2-alice_01-ppo-20250101-010203-abcd"
def test_attempt_submission_id_format():
from argus.core.ids import attempt_submission_id
assert attempt_submission_id("t", 1) == "t--a01"
assert attempt_submission_id("t", 12) == "t--a12"

View File

@ -14,7 +14,7 @@ def _mk_cfg(tmp_path: Path) -> tuple[RayConfig, V2Config]:
root = {
"ray": {
"address": "http://127.0.0.1:8265",
"shared_root": str(tmp_path),
"shared_root": "/private",
"entrypoint_resources": {"worker_node": 1},
"runtime_env": {"env_vars": {}},
},
@ -32,10 +32,11 @@ def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
db = Db(v2_cfg.sqlite.db_path)
db.init()
db.create_task(
db.create_task_v25(
task_id="t1",
user_id="alice",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
nnodes=2,
n_gpus_per_node=4,
)
@ -50,9 +51,11 @@ def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
class _Tool:
def __init__(self, cfg):
self.submitted = []
self.job_dirs = []
def submit(self, spec, no_wait: bool):
def submit(self, spec, no_wait: bool, job_dir: str | None = None):
self.submitted.append(spec.submission_id)
self.job_dirs.append(job_dir)
return str(spec.submission_id)
def status(self, submission_id: str):
@ -71,6 +74,7 @@ def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
attempts = db.list_attempts("t1")
assert len(attempts) == 1
assert attempts[0]["ray_submission_id"] == "t1--a01"
assert s.tool.job_dirs[-1] == "/private/users/alice/jobs/t1--a01"
def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path):
@ -79,10 +83,11 @@ def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path):
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
db = Db(v2_cfg.sqlite.db_path)
db.init()
db.create_task(
db.create_task_v25(
task_id="t1",
user_id="alice",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
nnodes=2,
n_gpus_per_node=4,
)
@ -108,10 +113,11 @@ def test_sync_failed_insufficient_resources(monkeypatch, tmp_path: Path):
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
db = Db(v2_cfg.sqlite.db_path)
db.init()
db.create_task(
db.create_task_v25(
task_id="t1",
user_id="alice",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
nnodes=2,
n_gpus_per_node=4,
)
@ -147,10 +153,11 @@ def test_sync_status_error_keeps_state(monkeypatch, tmp_path: Path):
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
db = Db(v2_cfg.sqlite.db_path)
db.init()
db.create_task(
db.create_task_v25(
task_id="t1",
user_id="alice",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
nnodes=2,
n_gpus_per_node=4,
)

View File

@ -0,0 +1,179 @@
from __future__ import annotations
from pathlib import Path
import yaml
from fastapi.testclient import TestClient
def _write_config(tmp_path: Path) -> Path:
cfg = {
"ray": {
"address": "http://127.0.0.1:8265",
"shared_root": "/private",
"entrypoint_resources": {"worker_node": 1},
"runtime_env": {"env_vars": {}},
},
"service": {
"api": {"host": "127.0.0.1", "port": 0},
"auth": {"token_env": "MVP_INTERNAL_TOKEN"},
"sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
"scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
},
}
p = tmp_path / "cfg.yaml"
p.write_text(yaml.safe_dump(cfg), encoding="utf-8")
return p
def test_db_user_token_lifecycle(tmp_path: Path):
from argus.service.db import Db
db = Db(str(tmp_path / "mvp.sqlite3"))
db.init()
db.create_user(user_id="alice", display_name="Alice")
tok = db.issue_token(user_id="alice")
info = db.resolve_token(tok)
assert info and info["user_id"] == "alice"
db.disable_user(user_id="alice")
info2 = db.resolve_token(tok)
assert info2 and info2["state"] == "DISABLED"
def test_admin_create_user_issue_token_and_disabled_rejected(tmp_path: Path, monkeypatch):
from argus.service import app as app_mod
cfg_path = _write_config(tmp_path)
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
class _Scheduler:
def __init__(self, **kwargs):
self.tool = object()
def run_forever(self, stop_flag):
return None
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
app = app_mod.create_app(str(cfg_path))
admin_headers = {"authorization": "Bearer adm1"}
with TestClient(app) as c:
r1 = c.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice", "display_name": "Alice"})
assert r1.status_code == 200
assert r1.json()["user_id"] == "alice"
r2 = c.post("/api/v2/users/alice/tokens", headers=admin_headers)
assert r2.status_code == 200
token = r2.json()["token"]
assert token
# non-admin token can access regular endpoints
r3 = c.get("/api/v2/queue", headers={"authorization": f"Bearer {token}"})
assert r3.status_code == 200
r4 = c.post("/api/v2/users/alice:disable", headers=admin_headers)
assert r4.status_code == 200
r5 = c.get("/api/v2/queue", headers={"authorization": f"Bearer {token}"})
assert r5.status_code == 403
def test_tasks_are_isolated_by_user(tmp_path: Path, monkeypatch):
from argus.service import app as app_mod
cfg_path = _write_config(tmp_path)
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
# Deterministic task ids
counter = {"n": 0}
def _new_id(workload: str, **kwargs) -> str:
counter["n"] += 1
return f"{workload}_t{counter['n']}"
monkeypatch.setattr(app_mod, "new_task_id", _new_id)
class _Scheduler:
def __init__(self, **kwargs):
self.tool = object()
def run_forever(self, stop_flag):
return None
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
app = app_mod.create_app(str(cfg_path))
admin_headers = {"authorization": "Bearer adm1"}
with TestClient(app) as c:
# Create users and tokens
assert c.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice"}).status_code == 200
assert c.post("/api/v2/users", headers=admin_headers, json={"user_id": "bob"}).status_code == 200
alice_tok = c.post("/api/v2/users/alice/tokens", headers=admin_headers).json()["token"]
bob_tok = c.post("/api/v2/users/bob/tokens", headers=admin_headers).json()["token"]
alice_headers = {"authorization": f"Bearer {alice_tok}"}
bob_headers = {"authorization": f"Bearer {bob_tok}"}
# Each user submits one task.
r1 = c.post(
"/api/v2/tasks",
headers=alice_headers,
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
)
assert r1.status_code == 200
alice_tid = r1.json()["task_id"]
r2 = c.post(
"/api/v2/tasks",
headers=bob_headers,
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
)
assert r2.status_code == 200
bob_tid = r2.json()["task_id"]
# Queue is scoped to user.
qa = c.get("/api/v2/queue", headers=alice_headers).json()
qb = c.get("/api/v2/queue", headers=bob_headers).json()
assert {t["task_id"] for t in qa["pending"]} == {alice_tid}
assert {t["task_id"] for t in qb["pending"]} == {bob_tid}
# Cross-user access returns 404 (no existence leak).
assert c.get(f"/api/v2/tasks/{bob_tid}", headers=alice_headers).status_code == 404
assert c.post(f"/api/v2/tasks/{bob_tid}:cancel", headers=alice_headers).status_code == 404
assert c.get(f"/api/v2/tasks/{bob_tid}/attempts", headers=alice_headers).status_code == 404
# Admin can see global queue.
qadm = c.get("/api/v2/queue", headers=admin_headers).json()
assert {t["task_id"] for t in qadm["pending"]} == {alice_tid, bob_tid}
def test_submit_rejects_non_common_inputs(tmp_path: Path, monkeypatch):
from argus.service import app as app_mod
cfg_path = _write_config(tmp_path)
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
class _Scheduler:
def __init__(self, **kwargs):
self.tool = object()
def run_forever(self, stop_flag):
return None
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
app = app_mod.create_app(str(cfg_path))
admin_headers = {"authorization": "Bearer adm1"}
with TestClient(app) as c:
assert c.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice"}).status_code == 200
alice_tok = c.post("/api/v2/users/alice/tokens", headers=admin_headers).json()["token"]
alice_headers = {"authorization": f"Bearer {alice_tok}"}
r = c.post(
"/api/v2/tasks",
headers=alice_headers,
data="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
)
assert r.status_code == 400
assert "code_path must start with /private/common/" in r.text

View File

@ -0,0 +1,58 @@
from __future__ import annotations
import json
from pathlib import Path
from argus.ray.discovery import build_head_record, write_head_record_atomic
from argus.ray.worker_watchdog import Watchdog
def test_watchdog_restarts_on_first_seen_and_on_change(tmp_path: Path):
head_file = tmp_path / "head.json"
calls: list[list[str]] = []
def runner(argv: list[str]) -> int:
calls.append(argv)
return 0
wd = Watchdog(
head_ip_file=str(head_file),
node_ip="10.0.0.2",
resources_json=json.dumps({"worker_node": 100}),
poll_s=1,
runner=runner,
)
write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="1.1.1.1"))
assert wd.tick_once() is not None
assert any("ray" in c[0] and c[1] == "start" for c in calls)
calls.clear()
# Same address -> no restart
write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="1.1.1.1"))
wd.tick_once()
assert calls == []
# Address change -> restart
write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="2.2.2.2"))
wd.tick_once()
assert any(c[1] == "stop" for c in calls)
assert any(c[1] == "start" for c in calls)
def test_watchdog_main_once_invokes_runner(monkeypatch, tmp_path: Path):
from argus.ray import worker_watchdog as mod
head_file = tmp_path / "head.json"
write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="1.1.1.1"))
calls: list[list[str]] = []
def runner(argv: list[str]) -> int:
calls.append(argv)
return 0
monkeypatch.setattr(mod, "_default_runner", runner)
rc = mod.main(["--head-ip-file", str(head_file), "--node-ip", "10.0.0.2", "--resources-kv", "worker_node=100", "--once"])
assert rc == 0
assert any(c[1] == "start" for c in calls)

View File

@ -0,0 +1,24 @@
#!/usr/bin/env bash
set -euo pipefail
echo "[host] cleanup v2 legacy containers/processes (best-effort)"
# Known historical container names from v1.1/v2 era
LEGACY=(
mvp11-ray-head
mvp11-ray-worker-0
mvp11-ray-worker-1
mvp2-ray-head
mvp2-ray-worker-0
mvp2-ray-worker-1
)
for c in "${LEGACY[@]}"; do
if docker ps -a --format '{{.Names}}' | grep -qx "${c}"; then
echo "[host] removing legacy container: ${c}"
docker rm -f "${c}" >/dev/null 2>&1 || true
fi
done
echo "[host] legacy v2 cleanup done"

View File

@ -0,0 +1,44 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"
CLUSTER_NAME="${CLUSTER_NAME:-argus-ray}"
HEAD_IP_FILE="${HEAD_IP_FILE:-${SHARED_ROOT}/ray/discovery/${CLUSTER_NAME}/head.json}"
TTL_S="${TTL_S:-60}"
REFRESH_S="${REFRESH_S:-10}"
# PID file must be container-local to avoid conflicts if /private is shared across containers.
PID_PATH="${PID_PATH:-/tmp/argus_head_publisher.pid}"
LOG_PATH="${LOG_PATH:-${SHARED_ROOT}/common/logs/argus_head_publisher.log}"
HEAD_IP="$(container_ip "${HEAD_CONTAINER}")"
echo "[head] start discovery publisher (supervised): ${HEAD_IP_FILE} (head_ip=${HEAD_IP})"
# stop existing (best-effort)
dexec "${HEAD_CONTAINER}" bash -lc "if test -f '${PID_PATH}'; then pid=\$(cat '${PID_PATH}'); if kill -0 \"\${pid}\" >/dev/null 2>&1; then kill \"\${pid}\" || true; fi; rm -f '${PID_PATH}'; fi"
dexec "${HEAD_CONTAINER}" bash -lc "mkdir -p \"$(dirname "${LOG_PATH}")\" \"$(dirname "${HEAD_IP_FILE}")\""
# Supervisor loop: restart publisher if it exits.
docker exec -d "${HEAD_CONTAINER}" bash -lc "
nohup bash -lc '
export PYTHONPATH=/workspace/mvp/py
while true; do
python3 -m argus.ray.head_publisher \
--cluster-name \"${CLUSTER_NAME}\" \
--head-ip-file \"${HEAD_IP_FILE}\" \
--head-ip \"${HEAD_IP}\" \
--ttl-s \"${TTL_S}\" \
--refresh-s \"${REFRESH_S}\"
sleep 2
done
' >>'${LOG_PATH}' 2>&1 & echo \$! >'${PID_PATH}'
"
echo "[head] publisher pid stored in ${PID_PATH} (container-local)"
echo "[head] logs: ${LOG_PATH}"

View File

@ -0,0 +1,52 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"
CLUSTER_NAME="${CLUSTER_NAME:-argus-ray}"
HEAD_IP_FILE="${HEAD_IP_FILE:-${SHARED_ROOT}/ray/discovery/${CLUSTER_NAME}/head.json}"
POLL_S="${POLL_S:-5}"
WORKER_NODE_RESOURCE="${WORKER_NODE_RESOURCE:-100}"
start_one() {
local worker="$1"
local ip
ip="$(container_ip "${worker}")"
local pid_path="/tmp/argus_worker_watchdog.pid"
local log_path="${SHARED_ROOT}/common/logs/argus_worker_watchdog.${worker}.log"
echo "[${worker}] start stateless watchdog (supervised): head_file=${HEAD_IP_FILE} node_ip=${ip}"
# stop existing watchdog (best-effort)
dexec "${worker}" bash -lc "if test -f '${pid_path}'; then pid=\$(cat '${pid_path}'); if kill -0 \"\${pid}\" >/dev/null 2>&1; then kill \"\${pid}\" || true; fi; rm -f '${pid_path}'; fi"
# stop any legacy ray process to avoid split-brain
dexec "${worker}" bash -lc "ray stop --force || true"
dexec "${worker}" bash -lc "mkdir -p \"$(dirname "${log_path}")\""
docker exec -d "${worker}" bash -lc "
nohup bash -lc '
export PYTHONPATH=/workspace/mvp/py
while true; do
python3 -m argus.ray.worker_watchdog \
--head-ip-file \"${HEAD_IP_FILE}\" \
--node-ip \"${ip}\" \
--resources-kv \"worker_node=${WORKER_NODE_RESOURCE}\" \
--poll-s \"${POLL_S}\"
sleep 2
done
' >>'${log_path}' 2>&1 & echo \$! >'${pid_path}'
"
echo "[${worker}] watchdog pid stored in ${pid_path} (container-local)"
echo "[${worker}] logs: ${log_path}"
}
start_one "${WORKER0_CONTAINER}"
start_one "${WORKER1_CONTAINER}"
echo "[head] ray status"
dexec "${HEAD_CONTAINER}" bash -lc "ray status || true"

View File

@ -0,0 +1,17 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"
CLUSTER_NAME="${CLUSTER_NAME:-argus-ray}"
HEAD_IP_FILE="${HEAD_IP_FILE:-${SHARED_ROOT}/ray/discovery/${CLUSTER_NAME}/head.json}"
echo "[head] head.json (best-effort): ${HEAD_IP_FILE}"
dexec "${HEAD_CONTAINER}" bash -lc "if test -f '${HEAD_IP_FILE}'; then cat '${HEAD_IP_FILE}'; else echo 'missing'; fi"
echo
echo "[head] ray status"
dexec "${HEAD_CONTAINER}" bash -lc "ray status || true"

View File

@ -11,10 +11,30 @@ PPO_DATA_DIR="${SHARED_ROOT}/datasets/gsm8k"
SFT_DATA_DIR="${SHARED_ROOT}/datasets/gsm8k_sft"
CODE_SNAPSHOT_DIR="${SHARED_ROOT}/common/code/verl/verl_repo"
COMMON_DIR="${SHARED_ROOT}/common"
COMMON_DATASETS_DIR="${SHARED_ROOT}/common/datasets"
COMMON_HF_DIR="${SHARED_ROOT}/common/hf"
echo "[head] ensure dataset dirs exist"
dexec "${HEAD_CONTAINER}" bash -lc "mkdir -p '${PPO_DATA_DIR}' '${SFT_DATA_DIR}'"
echo "[head] ensure v2.5 common links (idempotent)"
# In existing deployments, /private/common/{datasets,hf} may already exist as directories (not symlinks).
# For v2.5 taskspecs, we only require:
# /private/common/datasets/gsm8k -> /private/datasets/gsm8k
# /private/common/datasets/gsm8k_sft -> /private/datasets/gsm8k_sft
# /private/common/hf/{hub,transformers} -> /private/hf/{hub,transformers} (best-effort)
dexec "${HEAD_CONTAINER}" bash -lc "
set -euo pipefail
mkdir -p '${COMMON_DIR}' '${COMMON_DATASETS_DIR}' '${COMMON_HF_DIR}'
ln -sfn '${SHARED_ROOT}/datasets/gsm8k' '${COMMON_DATASETS_DIR}/gsm8k'
ln -sfn '${SHARED_ROOT}/datasets/gsm8k_sft' '${COMMON_DATASETS_DIR}/gsm8k_sft'
mkdir -p '${SHARED_ROOT}/hf/hub' '${SHARED_ROOT}/hf/transformers'
ln -sfn '${SHARED_ROOT}/hf/hub' '${COMMON_HF_DIR}/hub'
ln -sfn '${SHARED_ROOT}/hf/transformers' '${COMMON_HF_DIR}/transformers'
echo 'common_links_ok'
"
echo "[head] prepare PPO dataset (gsm8k RL parquet) -> ${PPO_DATA_DIR}"
dexec "${HEAD_CONTAINER}" bash -lc "if [[ -f '${PPO_DATA_DIR}/train.parquet' && -f '${PPO_DATA_DIR}/test.parquet' ]]; then echo 'ppo_dataset_exists: skip'; else python3 /workspace/verl/examples/data_preprocess/gsm8k.py --local_save_dir '${PPO_DATA_DIR}'; fi"

View File

@ -0,0 +1,137 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"
# E2E v2.5:
# - Clean legacy env
# - Start containers
# - Start ray head + discovery publisher
# - Start stateless worker watchdogs (auto-connect)
# - Prepare data/model/code (reuses existing downloads)
# - Start API
# - Create user + issue token
# - Submit PPO/GRPO/SFT via API and wait
API_ADDR="${API_ADDR:-http://127.0.0.1:8080}"
ADMIN_TOKEN="${MVP_INTERNAL_TOKEN:-}"
USER_ID="${USER_ID:-alice}"
RESET_DB="${RESET_DB:-1}"
if [[ -z "${ADMIN_TOKEN}" ]]; then
echo "ERROR: MVP_INTERNAL_TOKEN must be set in host env (admin token)" >&2
exit 1
fi
api_curl_admin() {
curl -sS -H "Authorization: Bearer ${ADMIN_TOKEN}" "$@"
}
api_wait_ready() {
local tries="${1:-60}"
for i in $(seq 1 "${tries}"); do
if curl -sS -m 2 "${API_ADDR}/docs" >/dev/null 2>&1; then
echo "[host] api_ready: ${API_ADDR}"
return 0
fi
echo "[host] waiting api... (${i}/${tries})"
sleep 2
done
echo "ERROR: api not ready: ${API_ADDR}" >&2
return 1
}
submit_taskspec() {
local token="$1"
local taskspec_path="$2"
echo "[host] submit via API (user=${USER_ID}): ${taskspec_path}" >&2
local resp
resp="$(curl -sS -H "Authorization: Bearer ${token}" -H "Content-Type: application/yaml" --data-binary @"${taskspec_path}" "${API_ADDR}/api/v2/tasks")"
echo "[host] submit_resp: ${resp}" >&2
printf '%s' "${resp}" | python3 -c 'import sys,json; print(json.load(sys.stdin)["task_id"])'
}
wait_task() {
local token="$1"
local task_id="$2"
while true; do
local body state
body="$(curl -sS -H "Authorization: Bearer ${token}" "${API_ADDR}/api/v2/tasks/${task_id}")"
state="$(printf '%s' "${body}" | python3 -c 'import sys,json; print(json.load(sys.stdin)["state"])')"
echo "[host] task ${task_id}: ${state}"
if [[ "${state}" == "SUCCEEDED" ]]; then
return 0
fi
if [[ "${state}" == "FAILED" || "${state}" == "CANCELED" ]]; then
echo "[host] terminal=${state}; tail logs (best-effort):" >&2
curl -sS -H "Authorization: Bearer ${token}" "${API_ADDR}/api/v2/tasks/${task_id}/logs?tail=200" >&2 || true
return 1
fi
sleep 10
done
}
echo "[host] ===== run_all_v25_api.sh begin ====="
"${SCRIPT_DIR}/00_prereq_check.sh"
"${SCRIPT_DIR}/03_cleanup_v1_legacy.sh"
"${SCRIPT_DIR}/04_cleanup_v2_legacy.sh"
echo "[host] bring down existing containers (best-effort)"
"${SCRIPT_DIR}/02_down.sh" || true
echo "[host] (re)create containers"
"${SCRIPT_DIR}/01_up.sh"
echo "[host] restart ray head (no compute on head)"
"${SCRIPT_DIR}/20_start_head.sh"
echo "[host] start head discovery publisher"
"${SCRIPT_DIR}/22_start_head_discovery.sh"
echo "[host] start stateless workers (watchdog auto-connect)"
"${SCRIPT_DIR}/23_start_workers_stateless.sh"
echo "[host] prepare data/model/code snapshot (idempotent; reuse cache)"
"${SCRIPT_DIR}/30_prepare_data_and_model.sh"
echo "[host] install api deps in head container"
"${SCRIPT_DIR}/12_install_api_deps.sh"
echo "[host] stop api (best-effort)"
"${SCRIPT_DIR}/61_stop_api.sh" || true
if [[ "${RESET_DB}" == "1" ]]; then
echo "[host] reset api sqlite db in container (best-effort)"
docker exec -i "${HEAD_CONTAINER}" bash -lc "rm -f /private/common/db/mvp.sqlite3 /private/common/db/mvp.sqlite3-wal /private/common/db/mvp.sqlite3-shm || true"
else
echo "[host] keep existing api sqlite db (RESET_DB=${RESET_DB})"
fi
echo "[host] start api (admin token via env)"
MVP_INTERNAL_TOKEN="${ADMIN_TOKEN}" "${SCRIPT_DIR}/60_start_api.sh"
api_wait_ready 60
echo "[host] ensure user exists + issue token"
api_curl_admin -H "Content-Type: application/json" -d "{\"user_id\":\"${USER_ID}\"}" "${API_ADDR}/api/v2/users" >/dev/null 2>&1 || true
USER_TOKEN="$(api_curl_admin -X POST "${API_ADDR}/api/v2/users/${USER_ID}/tokens" | python3 -c 'import sys,json; print(json.load(sys.stdin)["token"])')"
echo "[host] user_token_issued: user=${USER_ID}"
PPO_TASK_ID="$(submit_taskspec "${USER_TOKEN}" "${ROOT_DIR}/taskspecs/ppo.yaml")"
GRPO_TASK_ID="$(submit_taskspec "${USER_TOKEN}" "${ROOT_DIR}/taskspecs/grpo.yaml")"
SFT_TASK_ID="$(submit_taskspec "${USER_TOKEN}" "${ROOT_DIR}/taskspecs/sft.yaml")"
echo "[host] submitted task ids:"
echo " ppo=${PPO_TASK_ID}"
echo " grpo=${GRPO_TASK_ID}"
echo " sft=${SFT_TASK_ID}"
echo "[host] wait for tasks (in submission order)"
wait_task "${USER_TOKEN}" "${PPO_TASK_ID}"
wait_task "${USER_TOKEN}" "${GRPO_TASK_ID}"
wait_task "${USER_TOKEN}" "${SFT_TASK_ID}"
echo "[host] ===== run_all_v25_api.sh done ====="

View File

@ -0,0 +1,156 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"
API_ADDR="${API_ADDR:-http://127.0.0.1:8080}"
ADMIN_TOKEN="${MVP_INTERNAL_TOKEN:-}"
if [[ -z "${ADMIN_TOKEN}" ]]; then
echo "ERROR: MVP_INTERNAL_TOKEN must be set (admin token)" >&2
exit 1
fi
require_cmd jq
require_cmd python3
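# Usage sketch (host side; this script lives alongside run_all_v25_api.sh and lib.sh,
# the placeholder filename below is illustrative only):
#   MVP_INTERNAL_TOKEN=<admin-token> API_ADDR=http://127.0.0.1:8080 ./<this_e2e_cases_script>.sh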
api_admin() { curl -sS -H "Authorization: Bearer ${ADMIN_TOKEN}" "$@"; }
api_user() { local tok="$1"; shift; curl -sS -H "Authorization: Bearer ${tok}" "$@"; }
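# wait_state: poll a task until it reaches the wanted state.
# Returns 0 on match, 1 if a different terminal state (FAILED/CANCELED/SUCCEEDED) is hit first,
# and 2 if the wanted state is still not reached after ${tries} polls (5s apart).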
wait_state() {
local tok="$1"
local task_id="$2"
local want="$3"
local tries="${4:-120}"
for i in $(seq 1 "${tries}"); do
local st
st="$(api_user "${tok}" "${API_ADDR}/api/v2/tasks/${task_id}" | jq -r .state)"
echo "[case] ${task_id} state=${st} (want=${want})"
if [[ "${st}" == "${want}" ]]; then
return 0
fi
if [[ "${st}" == "FAILED" || "${st}" == "CANCELED" || "${st}" == "SUCCEEDED" ]]; then
return 1
fi
sleep 5
done
return 2
}
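# issue_token: ensure the user exists (idempotent POST /api/v2/users, errors ignored),
# then mint a fresh token for that user and print it.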
issue_token() {
local user_id="$1"
api_admin -H "Content-Type: application/json" -d "{\"user_id\":\"${user_id}\"}" "${API_ADDR}/api/v2/users" >/dev/null 2>&1 || true
api_admin -X POST "${API_ADDR}/api/v2/users/${user_id}/tokens" | jq -r .token
}
submit_yaml() {
local tok="$1"
local yaml_path="$2"
api_user "${tok}" -H "Content-Type: application/yaml" --data-binary @"${yaml_path}" "${API_ADDR}/api/v2/tasks" | jq -r .task_id
}
echo "[case] ===== v2.5 e2e cases ====="
echo "[case] HP-1: run_all_v25_api.sh (happy path)"
RESET_DB="${RESET_DB:-1}"
MVP_INTERNAL_TOKEN="${ADMIN_TOKEN}" RESET_DB="${RESET_DB}" "${SCRIPT_DIR}/run_all_v25_api.sh"
echo "[case] E-Auth-1: missing bearer token"
code="$(curl -sS -o /dev/null -w '%{http_code}' "${API_ADDR}/api/v2/queue" || true)"
[[ "${code}" == "401" ]] || { echo "expected 401 got ${code}" >&2; exit 1; }
echo "[case] E-Auth-2: invalid token"
code="$(curl -sS -o /dev/null -w '%{http_code}' -H 'Authorization: Bearer nope' "${API_ADDR}/api/v2/queue" || true)"
[[ "${code}" == "401" ]] || { echo "expected 401 got ${code}" >&2; exit 1; }
echo "[case] setup users alice/bob"
ALICE_TOKEN="$(issue_token alice)"
BOB_TOKEN="$(issue_token bob)"
echo "[case] E-Isolation-1: cross-user task visibility (404)"
TID="$(submit_yaml "${ALICE_TOKEN}" "${ROOT_DIR}/taskspecs/ppo.yaml")"
code="$(api_user "${BOB_TOKEN}" -o /dev/null -w '%{http_code}' "${API_ADDR}/api/v2/tasks/${TID}" || true)"
[[ "${code}" == "404" ]] || { echo "expected 404 got ${code}" >&2; exit 1; }
# Keep the queue clean for subsequent cases (avoid consuming scheduler capacity).
api_user "${ALICE_TOKEN}" -X POST "${API_ADDR}/api/v2/tasks/${TID}:cancel" >/dev/null 2>&1 || true
wait_state "${ALICE_TOKEN}" "${TID}" "CANCELED" 30 || true
echo "[case] E-Input-1: reject non-common inputs"
bad_yaml="$(mktemp)"
cat >"${bad_yaml}" <<'YAML'
workload: "ppo"
submission_id: ""
code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/datasets/gsm8k/train.parquet"
val_file: "/private/common/datasets/gsm8k/test.parquet"
nnodes: 2
n_gpus_per_node: 4
total_epochs: 1
total_training_steps: 10
save_freq: 10
test_freq: -1
YAML
code="$(api_user "${ALICE_TOKEN}" -H 'Content-Type: application/yaml' --data-binary @"${bad_yaml}" -o /dev/null -w '%{http_code}' "${API_ADDR}/api/v2/tasks" || true)"
rm -f "${bad_yaml}"
[[ "${code}" == "400" ]] || { echo "expected 400 got ${code}" >&2; exit 1; }
echo "[case] B-Queue-1: pending_resources when demand > cluster capacity"
big_yaml="$(mktemp)"
cat >"${big_yaml}" <<'YAML'
workload: "ppo"
submission_id: ""
code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/common/datasets/gsm8k/train.parquet"
val_file: "/private/common/datasets/gsm8k/test.parquet"
nnodes: 3
n_gpus_per_node: 4
total_epochs: 1
total_training_steps: 10
save_freq: 10
test_freq: -1
YAML
BIG_TID="$(api_user "${ALICE_TOKEN}" -H 'Content-Type: application/yaml' --data-binary @"${big_yaml}" "${API_ADDR}/api/v2/tasks" | jq -r .task_id)"
rm -f "${big_yaml}"
wait_state "${ALICE_TOKEN}" "${BIG_TID}" "PENDING_RESOURCES" 60
echo "[case] B-Cancel-1: cancel a running task (best-effort)"
long_yaml="$(mktemp)"
cat >"${long_yaml}" <<'YAML'
workload: "ppo"
submission_id: ""
code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/common/datasets/gsm8k/train.parquet"
val_file: "/private/common/datasets/gsm8k/test.parquet"
nnodes: 2
n_gpus_per_node: 4
total_epochs: 1
total_training_steps: 200
save_freq: 10
test_freq: -1
YAML
LONG_TID="$(api_user "${ALICE_TOKEN}" -H 'Content-Type: application/yaml' --data-binary @"${long_yaml}" "${API_ADDR}/api/v2/tasks" | jq -r .task_id)"
rm -f "${long_yaml}"
for i in $(seq 1 60); do
st="$(api_user "${ALICE_TOKEN}" "${API_ADDR}/api/v2/tasks/${LONG_TID}" | jq -r .state)"
echo "[case] wait running: ${LONG_TID} state=${st}"
if [[ "${st}" == "RUNNING" ]]; then
break
fi
if [[ "${st}" == "FAILED" || "${st}" == "SUCCEEDED" ]]; then
break
fi
sleep 5
done
api_user "${ALICE_TOKEN}" -X POST "${API_ADDR}/api/v2/tasks/${LONG_TID}:cancel" | jq .
st="$(api_admin "${API_ADDR}/api/v2/tasks/${LONG_TID}" | jq -r .state)"
[[ "${st}" == "CANCELED" ]] || { echo "expected CANCELED got ${st}" >&2; exit 1; }
echo "[case] ===== all v2.5 e2e cases passed ====="

View File

@ -6,8 +6,8 @@ code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/datasets/gsm8k/train.parquet"
val_file: "/private/datasets/gsm8k/test.parquet"
train_file: "/private/common/datasets/gsm8k/train.parquet"
val_file: "/private/common/datasets/gsm8k/test.parquet"
nnodes: 2
n_gpus_per_node: 4
@ -17,4 +17,3 @@ total_training_steps: 10
save_freq: 10
test_freq: -1

View File

@ -8,8 +8,8 @@ code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/datasets/gsm8k/train.parquet"
val_file: "/private/datasets/gsm8k/test.parquet"
train_file: "/private/common/datasets/gsm8k/train.parquet"
val_file: "/private/common/datasets/gsm8k/test.parquet"
nnodes: 2
n_gpus_per_node: 4
@ -19,4 +19,3 @@ total_training_steps: 10
save_freq: 10
test_freq: -1

View File

@ -6,7 +6,7 @@ code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/datasets/gsm8k_sft/train.parquet"
train_file: "/private/common/datasets/gsm8k_sft/train.parquet"
val_file: null
nnodes: 2
@ -19,4 +19,3 @@ save_freq: 10
# The SFT driver is not assigned a GPU by default (the ray job entrypoint does not set entrypoint_num_gpus), so the driver side must not rely on CUDA
trainer_device: "cpu"