v2.5 functional tests passed; Docker image still to be added

This commit is contained in:
yuyr 2025-12-29 14:33:05 +08:00
parent ce8c2128b4
commit f02059126e
30 changed files with 1794 additions and 66 deletions

View File

@ -11,4 +11,4 @@ Core changes in v2.5:
- `specs/mvp/v2.5/v2.5_design.md`: overall architecture and key mechanisms (head IP file / watchdog / user isolation / task flow).
- `specs/mvp/v2.5/v2.5_api.md`: API design (users, tasks, queue, logs) and auth conventions.
- `specs/mvp/v2.5/v2.5_acceptance.md`: development/deployment/acceptance process and verifiable criteria.
- `specs/mvp/v2.5/v2.5_summary.md`: summary of what v2.5 delivers (what this iteration did, acceptance results, known limitations).

View File

@ -0,0 +1,3 @@
# Issue notes
1. Include the user name in task and submission IDs.
2. Complete the end-to-end test cases: normal, error, and boundary cases.

View File

@ -0,0 +1,229 @@
# MVP v2.5 Development Plan (TDD-driven)
This is the **engineering development plan** for v2.5. It emphasizes "write tests before implementation" (TDD) and splits each milestone into small, **independently verifiable** loops.
Inputs:
- Roadmap: `specs/mvp/mvp_roadmap_v2.md`
- v2.5 design: `specs/mvp/v2.5/v2.5_design.md`
- v2.5 API draft: `specs/mvp/v2.5/v2.5_api.md`
- v2.5 acceptance: `specs/mvp/v2.5/v2.5_acceptance.md`
v2.5 constraints (confirmed):
- **No TaskSpec extension**: keep the structured YAML fields and semantics from v2.0/v2.1.
- **No custom reward functions / no user-supplied verl code.**
- Training inputs (verl code, HF cache, datasets) come only from `/private/common/...`.
- Multi-user isolation in v2.5 **covers only the jobs output directory**: `/private/users/<uid>/jobs/<ray_submission_id>/...`.
---
## 0. TDD Conventions (apply to every feature)
### 0.1 Test layers
1) **Unit tests (fast)**
- Pure Python logic: DB, auth, IDs, path derivation, head.json parsing/TTL, watchdog decision logic.
- Goal: no dependency on a real Ray cluster, Docker, or the network.
2) **Component tests (medium)**
- FastAPI routes: use `fastapi.testclient.TestClient` (already adopted in v2.0).
- Goal: verify auth/permission isolation, API behavior, and the state machine.
3) **End-to-end (slow; manual or scripted)**
- On `argus@h1`, run the "head publish → worker auto-connect → API submit" loop once via scripts/compose.
- Goal: verify the real behavior of stateless workers + the watchdog.
### 0.2 Test conventions
- Test directory: `src/mvp/py/tests/`
- New features must add their test cases first and see them fail before implementation (red).
- Make the smallest change that turns the tests green (green).
- Then refactor / remove duplication (refactor).
> Note: existing tests inject a ray stub via `src/mvp/py/tests/conftest.py` so unit tests do not depend on the real ray package; new v2.5 modules should reuse this pattern (a sketch follows).
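A minimal sketch of such a stub, assuming the code under test only calls `ray.init` and reads available resources; the fixture name and fake attributes are illustrative, not the project's actual conftest:
```python
# tests/conftest.py — illustrative ray stub, assumed shape only
import sys
import types

import pytest


@pytest.fixture(autouse=True)
def ray_stub(monkeypatch):
    """Install a fake 'ray' module so unit tests never import the real package."""
    fake_ray = types.ModuleType("ray")
    fake_ray.init = lambda *args, **kwargs: None          # no-op cluster connect
    fake_ray.available_resources = lambda: {"GPU": 8.0}   # deterministic capacity
    monkeypatch.setitem(sys.modules, "ray", fake_ray)
    return fake_ray
```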
---
## 1. Milestone breakdown (v2.5 = 4 verifiable loops)
### M1: users/tokens tables + basic auth (without breaking the existing internal-token mode)
**Goal**
- Introduce persisted users/tokens and the auth mapping (token → user_id).
- Stay compatible with the existing `Authorization: Bearer <MVP_INTERNAL_TOKEN>` "single-tenant mode" so v2.0 usage is not broken all at once:
- v2.5 can initially support two token modes:
- legacy: environment variable `MVP_INTERNAL_TOKEN` (global single tenant);
- user token: tokens issued and stored in the DB (multi-user).
- Admin can create users, issue tokens, and disable users.
**TDD cases (write the tests first)**
Unit tests:
- `test_user_db_create_disable()`
- Creating a user yields ACTIVE; disabling switches the state to DISABLED; duplicate creation returns a conflict or is idempotent (per the final decision).
- `test_token_hashing()`
- When a token is issued, the DB stores only its hash, never the plaintext.
API tests (TestClient; a condensed sketch follows this list):
- `test_admin_create_user_and_issue_token()`
- An admin token can create a user and issue a token (the plaintext token is returned exactly once).
- `test_disabled_user_token_rejected()`
- After a user is disabled, API calls with the old token return 401/403.
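Condensed from `tests/test_v25_users_and_isolation.py` added in this commit; the `client` and `admin_headers` fixtures are assumed for brevity:
```python
def test_admin_create_user_and_issue_token(client, admin_headers):
    assert client.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice"}).status_code == 200

    # The plaintext token is returned exactly once; only its hash is persisted.
    token = client.post("/api/v2/users/alice/tokens", headers=admin_headers).json()["token"]
    user_headers = {"authorization": f"Bearer {token}"}
    assert client.get("/api/v2/queue", headers=user_headers).status_code == 200

    # Disabling the user invalidates previously issued tokens (403, user disabled).
    assert client.post("/api/v2/users/alice:disable", headers=admin_headers).status_code == 200
    assert client.get("/api/v2/queue", headers=user_headers).status_code == 403
```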
**Implementation targets (suggested modules)**
- `argus.service.auth`: token validation and user_id resolution (compatible with the legacy mode)
- `argus.service.db`: new `users` and `api_tokens` tables plus CRUD
- `argus.service.app`: new user-management endpoints (admin scope)
- `configs/dev.yaml`: admin token/env configuration (keep the YAML style)
**Acceptance**
- U1 in `v2.5_acceptance.md` can be covered by automated API tests.
---
### M2: bind tasks to user_id + API visibility isolation (still no TaskSpec change)
**Goal**
- On task submission, derive `user_id` from the token and write it to `tasks.user_id`.
- Task query/cancel/logs are allowed only for the owner by default; other users get 404 (so existence is not leaked).
- The queue endpoint returns only the current user's queue by default; admin may query the global queue (optional).
**TDD cases (write the tests first)**
Unit tests:
- `test_tasks_table_has_user_id()`: creating a task must persist `user_id`, and `list_queue(user_id=...)` returns only that user's tasks.
API tests:
- `test_task_visibility_isolated()`
- user A creates a task; user B's `GET /api/v2/tasks/{id}` returns 404;
- user B's cancel/logs calls also return 404.
- `test_queue_isolated()`
- A and B each create a task; `GET /api/v2/queue` shows only the caller's own tasks.
**Implementation targets**
- `argus.service.app`: add a user scope to the task endpoints
- `argus.service.db`: add a user_id column to tasks, an index, and per-user query methods
- `argus.service.scheduler`: does `pick_next_runnable_task` stay "global FIFO" or become "per-user FIFO"?
- v2.5 keeps "global FIFO" (simplest), while the API queue view is filtered per user.
**Acceptance**
- U2 in `v2.5_acceptance.md` can be covered by API tests.
---
### M3: per-user isolation of the jobs output directory (outputs only, inputs unchanged)
**Goal**
- The job_root directory of a Ray job is computed by the service, always as:
- `/private/users/<uid>/jobs/<ray_submission_id>/...`
- Input-related path fields in the TaskSpec must be `/private/common/...` (v2.5 unifies inputs under common).
- No user can point outputs outside the per-user jobs directory via the TaskSpec (prevents unauthorized writes).
**TDD cases (write the tests first)**
Unit tests:
- `test_job_root_derivation_per_user()`
- Given a user_id and a ray_submission_id, the derived job_root is fixed and correct.
- `test_reject_non_common_inputs()`
- If train_file / val_file / code_path / HF paths in the TaskSpec do not start with `/private/common/`, reject with HTTP 400.
API tests:
- `test_job_dir_written_under_user_jobs()`
- After submitting a task, the job_root visible in the DB or in the submit payload is under the user's jobs directory (capture the spec by mocking RayJobTool.submit).
**Implementation targets (minimal intrusion suggested; see the sketch after this list)**
- Derive `job_root` in the service layer and inject it into RayJobTool/builders (rather than letting users set it via the TaskSpec).
- Change RayJobTool `_job_dir()` to accept a "job_root generator" or a `job_root` argument supplied by the service layer.
- Goal: keep RayJobTool's responsibility clear (submit Ray jobs); the path policy is decided by the service.
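A sketch of the derivation and input check, condensed from `scheduler._job_dir_for_task` and the prefix validation in `app.py` in this commit (the standalone function names here are illustrative):
```python
SHARED_ROOT = "/private"  # ray.shared_root from the service config

def job_root_for(user_id: str | None, ray_submission_id: str) -> str:
    root = SHARED_ROOT.rstrip("/")
    if user_id:
        return f"{root}/users/{user_id}/jobs/{ray_submission_id}"
    return f"{root}/jobs/{ray_submission_id}"  # legacy single-tenant fallback

def check_common_inputs(fields: dict[str, str | None]) -> None:
    common_prefix = SHARED_ROOT.rstrip("/") + "/common/"
    for name, value in fields.items():  # e.g. code_path, train_file, val_file
        if value is not None and not str(value).startswith(common_prefix):
            raise ValueError(f"{name} must start with {common_prefix}")  # the API maps this to HTTP 400

assert job_root_for("alice", "t1--a01") == "/private/users/alice/jobs/t1--a01"
```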
**Acceptance**
- U3/U4 in `v2.5_acceptance.md` can be covered by API/unit tests.
---
### M4: stateless Ray node pool (head.json + worker watchdog) + end-to-end script verification
**Goal**
- After startup, the head keeps writing `/private/ray/discovery/<cluster_name>/head.json` (with a TTL); a sketch of the record follows this list.
- A watchdog runs inside each worker container (or a startup script + watchdog); the platform never has to pass the head address:
- read head.json (exists and not expired) → `ray start --address=<head_ip>:<gcs_port>`
- head.json changed → `ray stop` + `ray start` to reconnect
- Provide a one-command script to reproduce this in the dev environment (docker compose) (e2e).
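What the published record looks like, based on `argus.ray.discovery` in this commit (the IP is a placeholder; ports and TTL are the module defaults):
```python
from argus.ray.discovery import build_head_record, load_head_record, write_head_record_atomic

rec = build_head_record(cluster_name="argus-ray", head_ip="10.0.0.10", ttl_s=60)
write_head_record_atomic("/private/ray/discovery/argus-ray/head.json", rec)
# On disk: cluster_name, head_ip, gcs_port (6379), dashboard_port (8265),
# job_server_url, updated_at and expires_at (= updated_at + ttl_s, ISO-8601 UTC).

got = load_head_record("/private/ray/discovery/argus-ray/head.json")
if got is not None:            # None when the file is missing, unreadable, or expired
    print(got.head_addr())     # "10.0.0.10:6379" -> passed to `ray start --address=...`
```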
**TDD cases (write the tests first)**
Unit tests (no real ray):
- `test_head_json_read_validate_ttl()`
- file missing or expired → return "unavailable"
- not expired → return the head address
- `test_watchdog_decision_on_change()` (see the decision sketch after this list)
- head_ip changed → trigger a reconnect
- only updated_at changed (address unchanged) → no reconnect (or reconnect per policy; to be decided)
Component/script-level tests (optional):
- If the watchdog is implemented in Python, stub the "run command" layer (never actually run `ray start`) and only assert which commands would be invoked.
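The decision under test, condensed from `argus.ray.worker_watchdog.Watchdog.tick_once` in this commit (the stop/start commands are inlined here for brevity):
```python
def tick_once(self):
    rec = load_head_record(self.head_ip_file)      # None if missing or expired -> do nothing
    if rec is None:
        return None
    desired = rec.head_addr()                      # "<head_ip>:<gcs_port>"
    if self._current_head_addr != desired:         # first sighting or head address change
        self.runner(["ray", "stop", "--force"])    # best-effort stop
        self.runner(["ray", "start", f"--address={desired}", f"--resources={self.resources_json}"])
        self._current_head_addr = desired          # same address on the next tick -> no restart
    return rec
```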
End-to-end script (manual/slow):
- Provide a script `scripts/run_all_v25_stateless.sh` (name is illustrative):
1) Start the head (Ray head + API).
2) Start the head publisher (writes head.json).
3) Start 2 workers (4 GPUs each); workers run only the watchdog and are never given the head address.
4) `ray status` shows 1 head + 2 workers with GPU=8.
5) Create a user / issue a token via the API, then submit PPO/GRPO/SFT.
6) Restart the head (or point head.json at a new address) and verify the workers reconnect automatically.
**Implementation targets (suggested strategy)**
For testability (TDD), implement "read head.json / check the TTL / build the ray start command" as Python modules:
- `argus.ray.discovery`: read/write head.json (atomic writes, TTL)
- `argus.ray.worker_watchdog`: watch loop (polling + change detection); the command runner is injectable for unit-test stubbing
Keep the script layer thin:
- `scripts/` handles docker exec / compose orchestration and process supervision;
- the watchdog process runs as a Python module inside the container (more testable, easier to port to a production platform's entrypoint/command).
**Acceptance**
- A1/A2/A3 in `v2.5_acceptance.md` are verified mainly by the e2e script plus the dashboard/logs.
---
## 2. Regression strategy (keep v2.0 intact)
Keep and continuously re-run the following during v2.5 (at least at unit-test level):
- The legacy internal-token mode can still call `GET /api/v2/queue` and submit tasks (if compatibility is kept).
- The scheduler's "insufficient resources → PENDING_RESOURCES → delayed retry" behavior is unchanged (covered by the existing `test_scheduler.py`).
- `ray entrypoint_resources` still forces the driver onto a worker (keep using the `worker_node` custom resource).
---
## 3. Deliverables (code / scripts / docs)
### 3.1 Code
- users/tokens: DB schema + auth + API endpoints
- tasks: bind user_id + permission isolation
- job_root: derive the per-user jobs output directory (inputs stay under common)
- discovery/watchdog: head.json + worker self-healing
### 3.2 Scripts (dev e2e)
- head: start the Ray head + head publisher
- workers: start statelessly (no head address passed) + watchdog
- `run_all`: one command end to end (API submit + query + cancel + queue inspection)
### 3.3 Docs
- Update `specs/mvp/v2.5/*` (design/API/acceptance/development plan)
- Extend `src/mvp/README.md` with v2.5 usage (if needed)
---
## 4. Key open questions (must be settled before implementation)
1) **Keep legacy token compatibility?**
- Option A: keep `MVP_INTERNAL_TOKEN` (single tenant) + add user tokens (multi-user).
- Option B: v2.5 switches to user tokens only (breaks compatibility but is cleaner).
2) **Scheduling fairness**
- v2.5 keeps global FIFO (simple); per-user fair scheduling/quotas come in v3.
3) **Who writes head.json in production?**
- Option A: a thread inside the API process (fewest components).
- Option B: a separate process (more independent, easier to operate).

View File

@ -0,0 +1,132 @@
# MVP v2.5 End-to-End Test Cases (normal / error / boundary)
Goal of this suite: cover the key v2.5 capabilities and their boundary conditions (users + jobs isolation + stateless node pool + API queue scheduling).
Constraints (confirmed for v2.5):
- No TaskSpec extension; no reward_fn; no user-supplied verl code.
- Inputs come only from `/private/common/...`; user isolation covers only the `/private/users/<uid>/jobs/...` outputs.
---
## 0. Prerequisites
Remote directory (example):
- `argus@h1:/home2/argus/infra/mvp/src/mvp/`
Shared directory (host):
- `/home2/argus/infra/mvp/shared/`
Path convention inside containers:
- `/private` is the shared-storage mount point
Required:
- GPUs 0-7 available
- 3 containers: head (no GPU) + 2 workers (4 GPUs each)
---
## 1. Normal Cases (Happy Path)
### HP-1: v2.5 full pipeline (PPO → GRPO → SFT, serial)
Steps:
1) `cd /home2/argus/infra/mvp/src/mvp/scripts`
2) `MVP_INTERNAL_TOKEN=<admin_token> RESET_DB=1 ./run_all_v25_api.sh`
Expected:
- The Ray dashboard shows 3 nodes (head + 2 workers) with 8 GPUs in total.
- All 3 tasks end in `SUCCEEDED`.
- Output directories exist and are isolated per user:
- `/private/users/<uid>/jobs/<ray_submission_id>/{config,logs,checkpoints,debug}`
### HP-2: the driver does not run on the head
Verification (either is enough):
- the Ray job's driver node IP differs from the head container IP;
- or the logs/scheduling info show that entrypoint_resources took effect (driver on a worker).
---
## 2. Error Cases
### E-Auth-1: missing token
Request:
- `GET /api/v2/queue` without `Authorization`
Expected:
- 401 (missing bearer token)
### E-Auth-2: invalid token
Request:
- `Authorization: Bearer <random>`
Expected:
- 401 (invalid token)
### E-Auth-3: access rejected after the user is disabled
Steps:
1) Admin creates user `bob` and issues a token.
2) Admin disables `bob`.
3) Call `/api/v2/queue` with bob's token.
Expected:
- 403 (user disabled)
### E-Isolation-1: cross-user access to a task (no existence leak)
Steps:
1) alice submits a task and gets `task_id`.
2) bob queries `/api/v2/tasks/{task_id}`.
Expected:
- 404 (task not found)
### E-Input-1: input path not under /private/common (v2.5 constraint)
Request:
- Submit a taskspec whose `train_file` or `code_path` does not start with `/private/common/`.
Expected:
- 400 with a field-specific error (e.g. `train_file must start with /private/common/`); a client-side sketch follows.
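An illustrative client-side check for E-Input-1 (the base URL, port, and token are placeholders):
```python
import requests

BAD_SPEC = (
    "workload: ppo\n"
    "code_path: /tmp/my_verl\n"  # not under /private/common/ -> must be rejected
    "model_id: m\n"
    "train_file: /private/common/datasets/t\n"
)

r = requests.post(
    "http://127.0.0.1:8080/api/v2/tasks",
    headers={"Authorization": "Bearer <user_token>"},
    data=BAD_SPEC.encode("utf-8"),
)
assert r.status_code == 400
assert "code_path must start with /private/common/" in r.text
```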
---
## 3. Boundary Cases
### B-Queue-1: insufficient resources → no Ray submit (PENDING_RESOURCES)
Steps:
1) Build a task that needs `nnodes=3`, `n_gpus_per_node=4` (12 GPUs total).
2) Submit it and poll the state.
Expected:
- The task enters `PENDING_RESOURCES` (pending on the service side; nothing is submitted to Ray).
- It carries a `next_run_at`.
### B-Cancel-1: cancel a task (QUEUED/RUNNING)
Steps:
1) Submit a task with enough steps that it has a chance to reach RUNNING.
2) Call `POST /api/v2/tasks/{task_id}:cancel`.
Expected:
- The task state is `CANCELED`.
- The attempt's `ray_status` ends up `STOPPED` (or the Ray side is stopped).
---
## 4. Executable regression script
See:
- `src/mvp/scripts/run_e2e_v25_cases.sh`
The script covers:
- HP-1
- E-Auth-1 / E-Auth-2 / E-Input-1
- E-Isolation-1
- B-Queue-1
- B-Cancel-1 (best-effort)

View File

@ -0,0 +1,92 @@
# MVP v2.5 Iteration Summary (delivered)
This document summarizes what v2.5 adds on top of v2.0/v2.1/v2.2: capabilities, implementation points, acceptance, and known limitations, to support retrospectives and alignment of later iterations.
## Goals and scope
Core goals of v2.5:
- Introduce **user management**: token-based auth and per-task isolation ("jobs only").
- Introduce a **stateless Ray node pool**: workers do not rely on the platform to hand out the head address; they discover the head, connect, and self-heal automatically.
- Keep the **TaskSpec (same YAML format as v1.1) unextended**: this iteration does not support reward functions, custom verl code, etc.
Explicitly out of scope (v2.5 constraints):
- No TaskSpec extensions (e.g. `reward_fn_path`).
- No per-user isolation or custom paths for verl/HF/datasets: **everything uses the shared `/private/common/...`** resources.
- User isolation covers only **tasks and their output directories** (jobs), not shared caches such as the HF cache or datasets.
## Key capabilities (externally visible)
### 1) Multi-user auth and task isolation
- The API keeps the internal `Authorization: Bearer <token>` scheme:
- the admin token comes from the environment variable `MVP_INTERNAL_TOKEN` (admin);
- user tokens are issued by the admin via the API and persisted in SQLite.
- Isolation policy:
- non-admin users can only query/cancel/fetch logs for **their own tasks**; cross-user access returns 404 (no existence leak);
- training outputs are isolated on disk: Ray job directories are written under `/private/users/<user_id>/jobs/<ray_submission_id>/...`.
### 2) task_id / submission_id carry the user name
- New task ID rule: `mvp2-<user>-<workload>-<YYYYMMDD-HHMMSS>-<suffix>`.
- Ray submission ID (attempt) rule: `<task_id>--aNN`, so it naturally contains the user name.
- Benefit: dashboard, logs, and output directories are easier to read, trace, and audit per user (usage sketch below).
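Usage of the new ID helpers (from `argus.core.ids` in this commit; the timestamp and suffix below are sample values):
```python
from argus.core.ids import new_task_id, attempt_submission_id

task_id = new_task_id("ppo", user_id="Alice_01")
# e.g. "mvp2-alice_01-ppo-20250101-010203-abcd"
# (user ids are lowercased and restricted to [a-z0-9_-] before being embedded)

print(attempt_submission_id(task_id, 1))
# e.g. "mvp2-alice_01-ppo-20250101-010203-abcd--a01"
```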
### 3) "Stateless worker pool" and head-address discovery
- The head writes a **head address file** (e.g. `head.json`) to shared storage; each worker runs a watchdog that:
- polls for the head address,
- joins the cluster automatically via `ray start --address ...`,
- reconnects automatically after a disconnect (watchdog self-healing).
- Result: in production, even when worker containers are created by the compute platform (SSH management only), they can still connect and self-heal through shared storage.
### 4) Task scheduling: queue + Ray Job submission + status sync
- Tasks submitted via the API enter a SQLite queue; a background scheduler submits them to Ray one at a time (default `max_running_tasks=1`).
- The scheduler keeps polling Ray job status and writes task states back (RUNNING/SUCCEEDED/FAILED/CANCELED).
- Retryable "insufficient resources" handling (see the sketch below):
- on VERL's fail-fast (`Total available GPUs ... is less than total desired GPUs ...`) or when the cluster lacks capacity,
the task enters `PENDING_RESOURCES` with a `next_run_at` and is retried every `retry_interval_s`.
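A minimal sketch of that decision, assuming a DB helper like `set_task_pending_resources` (the real scheduler's method names may differ):
```python
from datetime import datetime, timedelta, timezone

def gate_submit(task, *, required_gpus: float, available_gpus: float, retry_interval_s: int, db, submit):
    if available_gpus < required_gpus:
        # Not enough GPUs: keep the task on the service side instead of submitting to Ray.
        next_run_at = datetime.now(timezone.utc) + timedelta(seconds=retry_interval_s)
        db.set_task_pending_resources(task["task_id"], next_run_at=next_run_at.isoformat())
        return "PENDING_RESOURCES"
    submit(task)  # enough capacity: hand the task to Ray
    return "SUBMITTED"
```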
## Key implementation points (engineering)
### Storage and directory conventions (container view)
- The shared root is `/private` (matching the production mount).
- Hard v2.5 constraint: the following TaskSpec fields must start with `/private/common/`:
- `code_path` / `train_file` / `val_file`
- Shared directories (examples):
- `/private/common/hf`: HF cache
- `/private/common/datasets`: training data (symlinked to existing cache directories where needed to reuse downloads)
- `/private/common/db/mvp.sqlite3`: queue and user data (SQLite)
- `/private/common/logs`: API / watchdog logs
- `/private/users/<uid>/jobs/...`: per-user job outputs (isolated)
### Ray topology and "no training on the head"
- The head starts as a management node (CPU/GPU=0) so training tasks never land on it.
- Worker nodes carry the GPUs (example: 2 workers * 4 GPUs each).
- The driver is forced onto a worker via `entrypoint_resources` (e.g. `worker_node: 1`); a sketch follows.
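A sketch of how `entrypoint_resources` pins the driver, using Ray's job submission client (the entrypoint and submission_id are placeholders, not the project's actual builder output):
```python
from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")  # Ray dashboard / job server on the head
client.submit_job(
    entrypoint="python -m verl.trainer.main_ppo",      # placeholder entrypoint
    submission_id="mvp2-alice-ppo-20250101-010203-abcd--a01",
    # Workers advertise the custom resource `worker_node`; the head does not,
    # so the job driver can only be scheduled onto a worker node.
    entrypoint_resources={"worker_node": 1},
)
```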
### Deployment scripts and repeatability
A complete script chain covers:
- cleaning up legacy environments, starting/stopping containers, starting the Ray head
- starting the head discovery publisher and worker watchdogs, plus status checks
- preparing data/models/code (idempotent, reuses existing downloads)
- starting the API server (with RESET_DB support)
- submitting PPO/GRPO/SFT via the API back to back and waiting for completion
Representative scripts:
- `src/mvp/scripts/run_all_v25_api.sh`: v2.5 happy-path end to end (rebuilds the cluster, prepares resources, starts the API, submits the 3 workloads)
- `src/mvp/scripts/run_e2e_v25_cases.sh`: adds auth/isolation/input-validation/insufficient-resources/cancel cases on top of the happy path
## Acceptance and testing (passed)
### Unit tests (local venv)
- `.venv/bin/python -m pytest`
- Coverage threshold: >= 90%
### Remote end to end (h1)
- Run from `argus@h1:/home2/argus/infra/mvp/src/mvp/scripts`:
- `MVP_INTERNAL_TOKEN=mvp-dev-token RESET_DB=1 ./run_e2e_v25_cases.sh`
- Result: the happy path (PPO/GRPO/SFT) completes, and the error/boundary cases pass (auth, cross-user isolation, input validation, insufficient resources → PENDING_RESOURCES, task cancel, etc.).
## Known issues and follow-ups
- With `max_running_tasks=1`, queued tasks stay QUEUED while a predecessor is RUNNING; the "insufficient resources" boundary test must explicitly clear/cancel predecessors, or accept this behavior as designed.
- SQLite remains a single point; for HA or horizontal scaling, v2.6+ could introduce stronger persistence and replication (e.g. Postgres/etcd).
- The API server and watchdogs are currently supervised by shell scripts; they could later be unified under systemd/supervisor (or platform-side supervision) with health checks and alerting.

View File

@ -15,3 +15,4 @@
Quick start:
- CLI submission flow: `scripts/run_all_cli.sh`
- API submission flow: `scripts/run_all_api.sh`
- v2.5 (stateless workers + per-user jobs isolation) E2E: `scripts/run_all_v25_api.sh`

View File

@ -2,14 +2,27 @@ from __future__ import annotations
import secrets
from datetime import datetime
import re
_USER_SAFE_RE = re.compile(r"[^a-z0-9_-]+")
def _normalize_user_id(user_id: str) -> str:
s = (user_id or "").strip().lower()
s = _USER_SAFE_RE.sub("-", s)
s = re.sub(r"-{2,}", "-", s).strip("-")
return s[:24] if s else ""
def new_task_id(workload: str, *, user_id: str | None = None) -> str:
ts = datetime.now().strftime("%Y%m%d-%H%M%S")
suffix = secrets.token_hex(2)
u = _normalize_user_id(user_id or "")
if u:
return f"mvp2-{u}-{workload}-{ts}-{suffix}"
return f"mvp2-{workload}-{ts}-{suffix}"
def attempt_submission_id(task_id: str, attempt_no: int) -> str:
return f"{task_id}--a{attempt_no:02d}"

View File

@ -0,0 +1,115 @@
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any
def _utc_now() -> datetime:
return datetime.now(timezone.utc)
def _parse_utc(ts: str) -> datetime:
# Accept ISO8601 with trailing Z
s = ts.strip()
if s.endswith("Z"):
s = s[:-1] + "+00:00"
return datetime.fromisoformat(s).astimezone(timezone.utc)
def _iso_z(dt: datetime) -> str:
return dt.astimezone(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
@dataclass(frozen=True)
class HeadRecord:
cluster_name: str
head_ip: str
gcs_port: int
dashboard_port: int
job_server_url: str
updated_at: str
expires_at: str
@staticmethod
def from_dict(d: dict[str, Any]) -> "HeadRecord":
return HeadRecord(
cluster_name=str(d["cluster_name"]),
head_ip=str(d["head_ip"]),
gcs_port=int(d.get("gcs_port", 6379)),
dashboard_port=int(d.get("dashboard_port", 8265)),
job_server_url=str(d.get("job_server_url") or f"http://{d['head_ip']}:{int(d.get('dashboard_port', 8265))}"),
updated_at=str(d["updated_at"]),
expires_at=str(d["expires_at"]),
)
def to_dict(self) -> dict[str, Any]:
return {
"cluster_name": self.cluster_name,
"head_ip": self.head_ip,
"gcs_port": self.gcs_port,
"dashboard_port": self.dashboard_port,
"job_server_url": self.job_server_url,
"updated_at": self.updated_at,
"expires_at": self.expires_at,
}
def head_addr(self) -> str:
return f"{self.head_ip}:{self.gcs_port}"
def build_head_record(
*,
cluster_name: str,
head_ip: str,
gcs_port: int = 6379,
dashboard_port: int = 8265,
ttl_s: int = 60,
now: datetime | None = None,
) -> HeadRecord:
now_dt = now or _utc_now()
expires = now_dt + timedelta(seconds=int(ttl_s))
updated_at = _iso_z(now_dt)
expires_at = _iso_z(expires)
return HeadRecord(
cluster_name=cluster_name,
head_ip=head_ip,
gcs_port=int(gcs_port),
dashboard_port=int(dashboard_port),
job_server_url=f"http://{head_ip}:{int(dashboard_port)}",
updated_at=updated_at,
expires_at=expires_at,
)
def load_head_record(path: str, *, now: datetime | None = None) -> HeadRecord | None:
p = Path(path)
if not p.exists():
return None
try:
obj = json.loads(p.read_text(encoding="utf-8"))
except Exception:
return None
if not isinstance(obj, dict):
return None
try:
rec = HeadRecord.from_dict(obj)
expires = _parse_utc(rec.expires_at)
except Exception:
return None
now_dt = now or _utc_now()
if expires <= now_dt:
return None
return rec
def write_head_record_atomic(path: str, rec: HeadRecord) -> None:
p = Path(path)
os.makedirs(p.parent, exist_ok=True)
tmp = p.with_suffix(p.suffix + ".tmp")
tmp.write_text(json.dumps(rec.to_dict(), indent=2, sort_keys=True) + "\n", encoding="utf-8")
tmp.replace(p)

View File

@ -0,0 +1,65 @@
from __future__ import annotations
import argparse
import time
from .discovery import build_head_record, write_head_record_atomic
def publish_once(
*,
cluster_name: str,
head_ip_file: str,
head_ip: str,
gcs_port: int,
dashboard_port: int,
ttl_s: int,
) -> None:
rec = build_head_record(
cluster_name=cluster_name,
head_ip=head_ip,
gcs_port=gcs_port,
dashboard_port=dashboard_port,
ttl_s=ttl_s,
)
write_head_record_atomic(head_ip_file, rec)
def main(argv: list[str] | None = None) -> int:
ap = argparse.ArgumentParser(description="Publish Ray head address to shared storage (head.json).")
ap.add_argument("--cluster-name", required=True)
ap.add_argument("--head-ip-file", required=True)
ap.add_argument("--head-ip", required=True)
ap.add_argument("--gcs-port", type=int, default=6379)
ap.add_argument("--dashboard-port", type=int, default=8265)
ap.add_argument("--ttl-s", type=int, default=60)
ap.add_argument("--refresh-s", type=int, default=10)
ap.add_argument("--once", action="store_true", help="Write once then exit (for testing/debug).")
args = ap.parse_args(argv)
if args.once:
publish_once(
cluster_name=args.cluster_name,
head_ip_file=args.head_ip_file,
head_ip=args.head_ip,
gcs_port=args.gcs_port,
dashboard_port=args.dashboard_port,
ttl_s=args.ttl_s,
)
return 0
refresh_s = max(1, int(args.refresh_s))
while True:
publish_once(
cluster_name=args.cluster_name,
head_ip_file=args.head_ip_file,
head_ip=args.head_ip,
gcs_port=args.gcs_port,
dashboard_port=args.dashboard_port,
ttl_s=args.ttl_s,
)
time.sleep(refresh_s)
if __name__ == "__main__":
raise SystemExit(main())

View File

@ -73,9 +73,9 @@ class RayJobTool:
return {"env_vars": env_vars} return {"env_vars": env_vars}
def submit(self, spec: JobSpec, no_wait: bool) -> str: def submit(self, spec: JobSpec, no_wait: bool, job_dir: str | None = None) -> str:
submission_id = spec.submission_id or f"mvp11_{spec.workload}_{_ts()}_{os.getpid()}" submission_id = spec.submission_id or f"mvp11_{spec.workload}_{_ts()}_{os.getpid()}"
job_dir = self._job_dir(submission_id) job_dir = job_dir or self._job_dir(submission_id)
built = build_training_argv(spec, submission_id=submission_id, job_dir=job_dir) built = build_training_argv(spec, submission_id=submission_id, job_dir=job_dir)
entrypoint_argv = [ entrypoint_argv = [

View File

@ -0,0 +1,110 @@
from __future__ import annotations
import argparse
import json
import subprocess
import time
from dataclasses import dataclass
from typing import Callable
from .discovery import HeadRecord, load_head_record
Runner = Callable[[list[str]], int]
def _default_runner(argv: list[str]) -> int:
proc = subprocess.run(argv, check=False)
return int(proc.returncode)
def _ray_stop_cmd() -> list[str]:
return ["ray", "stop", "--force"]
def _ray_start_cmd(*, head_addr: str, node_ip: str | None, resources_json: str) -> list[str]:
argv = ["ray", "start", f"--address={head_addr}", f"--resources={resources_json}", "--disable-usage-stats"]
if node_ip:
argv.append(f"--node-ip-address={node_ip}")
return argv
@dataclass
class Watchdog:
head_ip_file: str
node_ip: str | None
resources_json: str
poll_s: int
runner: Runner
_current_head_addr: str | None = None
def _restart_ray(self, desired: str) -> None:
# Best-effort stop then start.
self.runner(_ray_stop_cmd())
self.runner(_ray_start_cmd(head_addr=desired, node_ip=self.node_ip, resources_json=self.resources_json))
self._current_head_addr = desired
def tick_once(self) -> HeadRecord | None:
rec = load_head_record(self.head_ip_file)
if rec is None:
return None
desired = rec.head_addr()
if self._current_head_addr != desired:
self._restart_ray(desired)
return rec
def run_forever(self) -> None:
while True:
try:
self.tick_once()
except Exception:
# keep watchdog alive
pass
time.sleep(max(1, int(self.poll_s)))
def main(argv: list[str] | None = None) -> int:
ap = argparse.ArgumentParser(description="Stateless Ray worker watchdog (auto-connect + reconnect).")
ap.add_argument("--head-ip-file", required=True)
ap.add_argument("--node-ip", default=None)
ap.add_argument(
"--resources-kv",
action="append",
default=[],
help='Repeatable k=v resources, e.g. --resources-kv worker_node=100',
)
ap.add_argument("--poll-s", type=int, default=5)
ap.add_argument("--once", action="store_true", help="Run one tick then exit (for testing/debug).")
args = ap.parse_args(argv)
resources: dict[str, float] = {"worker_node": 100.0}
if args.resources_kv:
resources = {}
for item in args.resources_kv:
if "=" not in item:
raise SystemExit(f"invalid --resources-kv (expected k=v): {item!r}")
k, v = item.split("=", 1)
k = k.strip()
v = v.strip()
if not k:
raise SystemExit(f"invalid --resources-kv key: {item!r}")
resources[k] = float(v)
wd = Watchdog(
head_ip_file=args.head_ip_file,
node_ip=args.node_ip,
resources_json=json.dumps(resources),
poll_s=int(args.poll_s),
runner=_default_runner,
)
if args.once:
wd.tick_once()
return 0
wd.run_forever()
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@ -43,19 +43,32 @@ def create_app(config_path: str) -> FastAPI:
app = FastAPI(title="mvp-v2", version="2.0")
def _auth(req: Request) -> dict[str, Any]:
token_env = v2_cfg.auth.token_env
admin_token = os.environ.get(token_env, "")
if not admin_token:
# Misconfigured service; treat as server error.
raise HTTPException(status_code=500, detail=f"missing token env: {token_env}")
auth = req.headers.get("authorization") or ""
if not auth.startswith("Bearer "):
raise HTTPException(status_code=401, detail="missing bearer token")
got = auth.removeprefix("Bearer ").strip()
if got == admin_token:
return {"user_id": "admin", "is_admin": True}
info = db.resolve_token(got)
if not info:
raise HTTPException(status_code=401, detail="invalid token") raise HTTPException(status_code=401, detail="invalid token")
if info.get("state") != "ACTIVE":
raise HTTPException(status_code=403, detail="user disabled")
return {"user_id": info["user_id"], "is_admin": False}
def _require_admin(req: Request) -> dict[str, Any]:
subject = _auth(req)
if not subject.get("is_admin"):
raise HTTPException(status_code=403, detail="admin required")
return subject
@app.on_event("startup") @app.on_event("startup")
def _startup() -> None: def _startup() -> None:
@ -66,9 +79,43 @@ def create_app(config_path: str) -> FastAPI:
def _shutdown() -> None: def _shutdown() -> None:
stop_flag.set() stop_flag.set()
@app.post("/api/v2/users")
async def create_user(req: Request) -> dict[str, Any]:
_require_admin(req)
obj = await req.json()
if not isinstance(obj, dict):
raise HTTPException(status_code=400, detail="body must be a JSON object")
user_id = str(obj.get("user_id") or "").strip()
if not user_id:
raise HTTPException(status_code=400, detail="missing user_id")
display_name = obj.get("display_name")
try:
row = db.create_user(user_id=user_id, display_name=str(display_name) if display_name is not None else None)
except Exception as e:
raise HTTPException(status_code=409, detail=f"user create failed: {e!r}")
return {"user_id": row.get("user_id", user_id), "state": row.get("state", "ACTIVE")}
@app.post("/api/v2/users/{user_id}/tokens")
async def issue_token(user_id: str, req: Request) -> dict[str, Any]:
_require_admin(req)
u = db.get_user(user_id)
if not u:
raise HTTPException(status_code=404, detail="user not found")
token = db.issue_token(user_id=user_id)
return {"user_id": user_id, "token": token}
@app.post("/api/v2/users/{user_id}:disable")
async def disable_user(user_id: str, req: Request) -> dict[str, Any]:
_require_admin(req)
u = db.get_user(user_id)
if not u:
raise HTTPException(status_code=404, detail="user not found")
db.disable_user(user_id=user_id)
return {"user_id": user_id, "state": "DISABLED"}
@app.post("/api/v2/tasks") @app.post("/api/v2/tasks")
async def submit_task(req: Request) -> dict[str, Any]: async def submit_task(req: Request) -> dict[str, Any]:
_require_token(req) subject = _auth(req)
body = (await req.body()).decode("utf-8") body = (await req.body()).decode("utf-8")
obj = yaml.safe_load(body) or {} obj = yaml.safe_load(body) or {}
if not isinstance(obj, dict): if not isinstance(obj, dict):
@ -79,9 +126,18 @@ def create_app(config_path: str) -> FastAPI:
except Exception as e:
raise HTTPException(status_code=400, detail=f"invalid jobspec: {e!r}")
# v2.5 constraint: training inputs must come from /private/common (same in dev and prod).
common_prefix = ray_cfg.shared_root.rstrip("/") + "/common/"
for k, v in (("code_path", spec.code_path), ("train_file", spec.train_file), ("val_file", spec.val_file)):
if v is None:
continue
if not str(v).startswith(common_prefix):
raise HTTPException(status_code=400, detail=f"{k} must start with {common_prefix}")
task_id = new_task_id(spec.workload, user_id=str(subject["user_id"]))
db.create_task_v25(
task_id=task_id,
user_id=str(subject["user_id"]),
workload=spec.workload,
jobspec_yaml=body,
nnodes=spec.nnodes,
@ -91,10 +147,13 @@ def create_app(config_path: str) -> FastAPI:
@app.get("/api/v2/tasks/{task_id}") @app.get("/api/v2/tasks/{task_id}")
async def get_task(task_id: str, req: Request) -> dict[str, Any]: async def get_task(task_id: str, req: Request) -> dict[str, Any]:
_require_token(req) subject = _auth(req)
row = db.get_task(task_id) row = db.get_task(task_id)
if not row: if not row:
raise HTTPException(status_code=404, detail="task not found") raise HTTPException(status_code=404, detail="task not found")
if not subject.get("is_admin"):
if str(row.get("user_id") or "") != str(subject["user_id"]):
raise HTTPException(status_code=404, detail="task not found")
attempts = db.list_attempts(task_id) attempts = db.list_attempts(task_id)
latest_attempt = attempts[-1] if attempts else None latest_attempt = attempts[-1] if attempts else None
desired = { desired = {
@ -125,18 +184,24 @@ def create_app(config_path: str) -> FastAPI:
@app.get("/api/v2/tasks/{task_id}/attempts") @app.get("/api/v2/tasks/{task_id}/attempts")
async def get_attempts(task_id: str, req: Request) -> dict[str, Any]: async def get_attempts(task_id: str, req: Request) -> dict[str, Any]:
_require_token(req) subject = _auth(req)
row = db.get_task(task_id) row = db.get_task(task_id)
if not row: if not row:
raise HTTPException(status_code=404, detail="task not found") raise HTTPException(status_code=404, detail="task not found")
if not subject.get("is_admin"):
if str(row.get("user_id") or "") != str(subject["user_id"]):
raise HTTPException(status_code=404, detail="task not found")
return {"task_id": task_id, "attempts": db.list_attempts(task_id)} return {"task_id": task_id, "attempts": db.list_attempts(task_id)}
@app.post("/api/v2/tasks/{task_id}:cancel") @app.post("/api/v2/tasks/{task_id}:cancel")
async def cancel(task_id: str, req: Request) -> dict[str, Any]: async def cancel(task_id: str, req: Request) -> dict[str, Any]:
_require_token(req) subject = _auth(req)
row = db.get_task(task_id) row = db.get_task(task_id)
if not row: if not row:
raise HTTPException(status_code=404, detail="task not found") raise HTTPException(status_code=404, detail="task not found")
if not subject.get("is_admin"):
if str(row.get("user_id") or "") != str(subject["user_id"]):
raise HTTPException(status_code=404, detail="task not found")
state = str(row["state"]) state = str(row["state"])
if state in ("SUCCEEDED", "FAILED", "CANCELED"): if state in ("SUCCEEDED", "FAILED", "CANCELED"):
@ -165,10 +230,13 @@ def create_app(config_path: str) -> FastAPI:
@app.get("/api/v2/tasks/{task_id}/logs") @app.get("/api/v2/tasks/{task_id}/logs")
async def logs(task_id: str, req: Request, tail: int = 2000, attempt: str = "latest") -> Response: async def logs(task_id: str, req: Request, tail: int = 2000, attempt: str = "latest") -> Response:
_require_token(req) subject = _auth(req)
row = db.get_task(task_id) row = db.get_task(task_id)
if not row: if not row:
raise HTTPException(status_code=404, detail="task not found") raise HTTPException(status_code=404, detail="task not found")
if not subject.get("is_admin"):
if str(row.get("user_id") or "") != str(subject["user_id"]):
raise HTTPException(status_code=404, detail="task not found")
attempts = db.list_attempts(task_id) attempts = db.list_attempts(task_id)
if not attempts: if not attempts:
raise HTTPException(status_code=404, detail="no attempts yet") raise HTTPException(status_code=404, detail="no attempts yet")
@ -184,7 +252,9 @@ def create_app(config_path: str) -> FastAPI:
@app.get("/api/v2/queue") @app.get("/api/v2/queue")
async def queue(req: Request) -> dict[str, Any]: async def queue(req: Request) -> dict[str, Any]:
_require_token(req) subject = _auth(req)
if subject.get("is_admin"):
return db.list_queue() return db.list_queue()
return db.list_queue(user_id=str(subject["user_id"]))
return app return app

View File

@ -4,6 +4,8 @@ import os
import sqlite3
from contextlib import contextmanager
from dataclasses import dataclass
import hashlib
import secrets
from typing import Any, Iterator
@ -18,6 +20,10 @@ def _utc_now_iso() -> str:
class Db:
db_path: str
def _hash_token(self, token: str) -> str:
# Internal tokens only; store hash to avoid plaintext at rest.
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def _connect(self) -> sqlite3.Connection:
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
conn = sqlite3.connect(self.db_path, timeout=30, isolation_level=None)
@ -28,6 +34,27 @@ class Db:
def init(self) -> None:
with self._connect() as conn:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS users (
user_id TEXT PRIMARY KEY,
display_name TEXT,
state TEXT NOT NULL,
created_at TEXT NOT NULL
)
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS api_tokens (
token_hash TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
created_at TEXT NOT NULL,
last_used_at TEXT,
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
)
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS tasks (
@ -45,6 +72,12 @@ class Db:
)
"""
)
# Best-effort: add user_id column for forward compatibility (v2.5+).
# v2.0 tables may exist without it; SQLite supports ADD COLUMN.
try:
conn.execute("ALTER TABLE tasks ADD COLUMN user_id TEXT")
except sqlite3.OperationalError:
pass
conn.execute(
"""
CREATE TABLE IF NOT EXISTS attempts (
@ -92,10 +125,10 @@ class Db:
with self.tx() as conn:
conn.execute(
"""
INSERT INTO tasks (task_id, workload, state, jobspec_yaml, nnodes, n_gpus_per_node, created_at, updated_at, user_id)
VALUES (?, ?, 'QUEUED', ?, ?, ?, ?, ?, ?)
""",
(task_id, workload, jobspec_yaml, nnodes, n_gpus_per_node, now, now, None),
)
conn.execute(
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (?, ?, 'TASK_CREATED', ?)",
@ -104,6 +137,97 @@ class Db:
row = conn.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone() row = conn.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
return dict(row) if row else {} return dict(row) if row else {}
def create_task_v25(
self,
*,
task_id: str,
user_id: str,
workload: str,
jobspec_yaml: str,
nnodes: int,
n_gpus_per_node: int,
) -> dict[str, Any]:
now = _utc_now_iso()
with self.tx() as conn:
conn.execute(
"""
INSERT INTO tasks (task_id, user_id, workload, state, jobspec_yaml, nnodes, n_gpus_per_node, created_at, updated_at)
VALUES (?, ?, ?, 'QUEUED', ?, ?, ?, ?, ?)
""",
(task_id, user_id, workload, jobspec_yaml, nnodes, n_gpus_per_node, now, now),
)
conn.execute(
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (?, ?, 'TASK_CREATED', ?)",
(task_id, now, None),
)
row = conn.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
return dict(row) if row else {}
def create_user(self, *, user_id: str, display_name: str | None = None) -> dict[str, Any]:
now = _utc_now_iso()
with self.tx() as conn:
conn.execute(
"""
INSERT INTO users (user_id, display_name, state, created_at)
VALUES (?, ?, 'ACTIVE', ?)
""",
(user_id, display_name, now),
)
conn.execute(
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (NULL, ?, 'USER_CREATED', ?)",
(now, user_id),
)
row = conn.execute("SELECT * FROM users WHERE user_id = ?", (user_id,)).fetchone()
return dict(row) if row else {}
def disable_user(self, *, user_id: str) -> None:
now = _utc_now_iso()
with self.tx() as conn:
conn.execute("UPDATE users SET state = 'DISABLED' WHERE user_id = ?", (user_id,))
conn.execute(
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (NULL, ?, 'USER_DISABLED', ?)",
(now, user_id),
)
def get_user(self, user_id: str) -> dict[str, Any] | None:
with self._connect() as conn:
row = conn.execute("SELECT * FROM users WHERE user_id = ?", (user_id,)).fetchone()
return dict(row) if row else None
def issue_token(self, *, user_id: str) -> str:
# Returns plaintext token once; stores hash only.
now = _utc_now_iso()
token = f"mvp_u_{user_id}_{secrets.token_urlsafe(18)}"
token_hash = self._hash_token(token)
with self.tx() as conn:
conn.execute(
"INSERT INTO api_tokens (token_hash, user_id, created_at) VALUES (?, ?, ?)",
(token_hash, user_id, now),
)
conn.execute(
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (NULL, ?, 'TOKEN_ISSUED', ?)",
(now, user_id),
)
return token
def resolve_token(self, token: str) -> dict[str, Any] | None:
token_hash = self._hash_token(token)
with self.tx() as conn:
row = conn.execute(
"""
SELECT t.user_id, u.state
FROM api_tokens t
JOIN users u ON u.user_id = t.user_id
WHERE t.token_hash = ?
""",
(token_hash,),
).fetchone()
if not row:
return None
now = _utc_now_iso()
conn.execute("UPDATE api_tokens SET last_used_at = ? WHERE token_hash = ?", (now, token_hash))
return {"user_id": row["user_id"], "state": row["state"]}
def get_task(self, task_id: str) -> dict[str, Any] | None:
with self._connect() as conn:
row = conn.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
@ -116,26 +240,33 @@ class Db:
).fetchall()
return [dict(r) for r in rows]
def list_queue(self, *, user_id: str | None = None) -> dict[str, list[dict[str, Any]]]:
with self._connect() as conn:
params: list[Any] = []
user_filter_sql = ""
if user_id is not None:
user_filter_sql = " AND user_id = ?"
params = [user_id]
pending_sql = (
"SELECT task_id, workload, state, nnodes, n_gpus_per_node, next_run_at, created_at, updated_at "
"FROM tasks "
"WHERE state IN ('QUEUED','PENDING_RESOURCES')"
f"{user_filter_sql} "
"ORDER BY created_at ASC "
"LIMIT 200"
)
running_sql = (
"SELECT task_id, workload, state, nnodes, n_gpus_per_node, latest_attempt_no, created_at, updated_at "
"FROM tasks "
"WHERE state IN ('SUBMITTING','SUBMITTED','RUNNING')"
f"{user_filter_sql} "
"ORDER BY updated_at ASC "
"LIMIT 200"
)
pending = conn.execute(pending_sql, tuple(params)).fetchall()
running = conn.execute(running_sql, tuple(params)).fetchall()
return {"pending": [dict(r) for r in pending], "running": [dict(r) for r in running]} return {"pending": [dict(r) for r in pending], "running": [dict(r) for r in running]}
def count_running(self) -> int: def count_running(self) -> int:

View File

@ -37,6 +37,12 @@ class Scheduler:
def __post_init__(self) -> None:
self.tool = RayJobTool(self.ray_cfg)
def _job_dir_for_task(self, *, user_id: str | None, ray_submission_id: str) -> str:
root = self.ray_cfg.shared_root.rstrip("/")
if user_id:
return f"{root}/users/{user_id}/jobs/{ray_submission_id}"
return f"{root}/jobs/{ray_submission_id}"
def _resources_sufficient(self, *, nnodes: int, n_gpus_per_node: int) -> bool:
avail = get_cluster_available()
required = float(nnodes * n_gpus_per_node)
@ -51,10 +57,13 @@ class Scheduler:
def _submit_one(self, task_row: dict[str, Any]) -> None:
task_id = str(task_row["task_id"])
jobspec_yaml = str(task_row["jobspec_yaml"])
user_id = task_row.get("user_id")
user_id_s = str(user_id) if user_id not in (None, "") else None
spec = self._parse_jobspec(jobspec_yaml)
attempt_no = int(task_row.get("latest_attempt_no", 0)) + 1
ray_sid = attempt_submission_id(task_id, attempt_no)
job_dir = self._job_dir_for_task(user_id=user_id_s, ray_submission_id=ray_sid)
# Record attempt first so that we can surface it even if submit crashes.
self.db.create_attempt(task_id=task_id, attempt_no=attempt_no, ray_submission_id=ray_sid)
@ -66,7 +75,7 @@ class Scheduler:
spec2 = JobSpec.from_dict(d)
try:
submitted = self.tool.submit(spec2, no_wait=True, job_dir=job_dir)
# submitted should equal ray_sid; keep as source of truth.
self.db.update_attempt(task_id=task_id, attempt_no=attempt_no, ray_status="SUBMITTED")
self.db.set_task_state(task_id=task_id, state="SUBMITTED")

View File

@ -12,7 +12,7 @@ def _write_config(tmp_path: Path) -> Path:
cfg = {
"ray": {
"address": "http://127.0.0.1:8265",
"shared_root": "/private",
"entrypoint_resources": {"worker_node": 1},
"runtime_env": {"env_vars": {}},
},
@ -54,7 +54,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
cfg_path = _write_config(tmp_path)
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
monkeypatch.setattr(app_mod, "new_task_id", lambda workload, **kwargs: "tid1")
class _Tool:
def __init__(self):
@ -82,7 +82,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
r = c.post(
"/api/v2/tasks",
headers=headers,
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
)
assert r.status_code == 200
assert r.json()["task_id"] == "tid1"
@ -111,7 +111,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
db.create_task(
task_id="tid2",
workload="ppo",
jobspec_yaml="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
nnodes=2,
n_gpus_per_node=4,
)
@ -163,4 +163,3 @@ def test_submit_rejects_invalid_jobspec(tmp_path: Path, monkeypatch):
with TestClient(app) as c:
r = c.post("/api/v2/tasks", headers={"authorization": "Bearer token1"}, data="workload: nope\n")
assert r.status_code == 400

View File

@ -0,0 +1,50 @@
from __future__ import annotations
import json
from datetime import datetime, timedelta, timezone
from pathlib import Path
from argus.ray.discovery import build_head_record, load_head_record
def test_load_head_record_rejects_missing(tmp_path: Path):
assert load_head_record(str(tmp_path / "nope.json")) is None
def test_load_head_record_rejects_expired(tmp_path: Path):
p = tmp_path / "head.json"
now = datetime.now(timezone.utc).replace(microsecond=0)
rec = build_head_record(cluster_name="c", head_ip="1.2.3.4", ttl_s=1, now=now)
p.write_text(json.dumps(rec.to_dict()), encoding="utf-8")
assert load_head_record(str(p), now=now + timedelta(seconds=2)) is None
def test_load_head_record_accepts_fresh(tmp_path: Path):
p = tmp_path / "head.json"
now = datetime.now(timezone.utc).replace(microsecond=0)
rec = build_head_record(cluster_name="c", head_ip="1.2.3.4", ttl_s=60, now=now)
p.write_text(json.dumps(rec.to_dict()), encoding="utf-8")
got = load_head_record(str(p), now=now + timedelta(seconds=1))
assert got is not None
assert got.head_addr() == "1.2.3.4:6379"
def test_head_publisher_main_once_writes_file(tmp_path: Path):
from argus.ray import head_publisher as mod
p = tmp_path / "head.json"
rc = mod.main(
[
"--cluster-name",
"c",
"--head-ip-file",
str(p),
"--head-ip",
"9.9.9.9",
"--ttl-s",
"60",
"--once",
]
)
assert rc == 0
assert "9.9.9.9" in p.read_text(encoding="utf-8")

View File

@ -20,9 +20,27 @@ def test_new_task_id_is_deterministic_with_patches(monkeypatch):
assert ids.new_task_id("ppo") == "mvp2-ppo-20250101-010203-abcd" assert ids.new_task_id("ppo") == "mvp2-ppo-20250101-010203-abcd"
def test_new_task_id_includes_user_when_provided(monkeypatch):
import argus.core.ids as ids
class _FakeDatetime:
@staticmethod
def now():
class _DT:
def strftime(self, fmt: str) -> str:
assert fmt == "%Y%m%d-%H%M%S"
return "20250101-010203"
return _DT()
monkeypatch.setattr(ids, "datetime", _FakeDatetime)
monkeypatch.setattr(ids.secrets, "token_hex", lambda n: "abcd")
assert ids.new_task_id("ppo", user_id="Alice_01") == "mvp2-alice_01-ppo-20250101-010203-abcd"
def test_attempt_submission_id_format():
from argus.core.ids import attempt_submission_id
assert attempt_submission_id("t", 1) == "t--a01"
assert attempt_submission_id("t", 12) == "t--a12"

View File

@ -14,7 +14,7 @@ def _mk_cfg(tmp_path: Path) -> tuple[RayConfig, V2Config]:
root = {
"ray": {
"address": "http://127.0.0.1:8265",
"shared_root": "/private",
"entrypoint_resources": {"worker_node": 1},
"runtime_env": {"env_vars": {}},
},
@ -32,10 +32,11 @@ def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
db = Db(v2_cfg.sqlite.db_path)
db.init()
db.create_task_v25(
task_id="t1",
user_id="alice",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
nnodes=2,
n_gpus_per_node=4,
)
@ -50,9 +51,11 @@ def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
class _Tool:
def __init__(self, cfg):
self.submitted = []
self.job_dirs = []
def submit(self, spec, no_wait: bool, job_dir: str | None = None):
self.submitted.append(spec.submission_id)
self.job_dirs.append(job_dir)
return str(spec.submission_id)
def status(self, submission_id: str):
@ -71,6 +74,7 @@ def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
attempts = db.list_attempts("t1") attempts = db.list_attempts("t1")
assert len(attempts) == 1 assert len(attempts) == 1
assert attempts[0]["ray_submission_id"] == "t1--a01" assert attempts[0]["ray_submission_id"] == "t1--a01"
assert s.tool.job_dirs[-1] == "/private/users/alice/jobs/t1--a01"
def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path): def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path):
@ -79,10 +83,11 @@ def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path):
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
db = Db(v2_cfg.sqlite.db_path)
db.init()
db.create_task_v25(
task_id="t1",
user_id="alice",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
nnodes=2,
n_gpus_per_node=4,
)
@ -108,10 +113,11 @@ def test_sync_failed_insufficient_resources(monkeypatch, tmp_path: Path):
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
db = Db(v2_cfg.sqlite.db_path)
db.init()
db.create_task_v25(
task_id="t1",
user_id="alice",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
nnodes=2,
n_gpus_per_node=4,
)
@ -147,10 +153,11 @@ def test_sync_status_error_keeps_state(monkeypatch, tmp_path: Path):
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
db = Db(v2_cfg.sqlite.db_path)
db.init()
db.create_task_v25(
task_id="t1",
user_id="alice",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
nnodes=2,
n_gpus_per_node=4,
)

View File

@ -0,0 +1,179 @@
from __future__ import annotations
from pathlib import Path
import yaml
from fastapi.testclient import TestClient
def _write_config(tmp_path: Path) -> Path:
cfg = {
"ray": {
"address": "http://127.0.0.1:8265",
"shared_root": "/private",
"entrypoint_resources": {"worker_node": 1},
"runtime_env": {"env_vars": {}},
},
"service": {
"api": {"host": "127.0.0.1", "port": 0},
"auth": {"token_env": "MVP_INTERNAL_TOKEN"},
"sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
"scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
},
}
p = tmp_path / "cfg.yaml"
p.write_text(yaml.safe_dump(cfg), encoding="utf-8")
return p
def test_db_user_token_lifecycle(tmp_path: Path):
from argus.service.db import Db
db = Db(str(tmp_path / "mvp.sqlite3"))
db.init()
db.create_user(user_id="alice", display_name="Alice")
tok = db.issue_token(user_id="alice")
info = db.resolve_token(tok)
assert info and info["user_id"] == "alice"
db.disable_user(user_id="alice")
info2 = db.resolve_token(tok)
assert info2 and info2["state"] == "DISABLED"
def test_admin_create_user_issue_token_and_disabled_rejected(tmp_path: Path, monkeypatch):
from argus.service import app as app_mod
cfg_path = _write_config(tmp_path)
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
class _Scheduler:
def __init__(self, **kwargs):
self.tool = object()
def run_forever(self, stop_flag):
return None
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
app = app_mod.create_app(str(cfg_path))
admin_headers = {"authorization": "Bearer adm1"}
with TestClient(app) as c:
r1 = c.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice", "display_name": "Alice"})
assert r1.status_code == 200
assert r1.json()["user_id"] == "alice"
r2 = c.post("/api/v2/users/alice/tokens", headers=admin_headers)
assert r2.status_code == 200
token = r2.json()["token"]
assert token
# non-admin token can access regular endpoints
r3 = c.get("/api/v2/queue", headers={"authorization": f"Bearer {token}"})
assert r3.status_code == 200
r4 = c.post("/api/v2/users/alice:disable", headers=admin_headers)
assert r4.status_code == 200
r5 = c.get("/api/v2/queue", headers={"authorization": f"Bearer {token}"})
assert r5.status_code == 403
def test_tasks_are_isolated_by_user(tmp_path: Path, monkeypatch):
from argus.service import app as app_mod
cfg_path = _write_config(tmp_path)
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
# Deterministic task ids
counter = {"n": 0}
def _new_id(workload: str, **kwargs) -> str:
counter["n"] += 1
return f"{workload}_t{counter['n']}"
monkeypatch.setattr(app_mod, "new_task_id", _new_id)
class _Scheduler:
def __init__(self, **kwargs):
self.tool = object()
def run_forever(self, stop_flag):
return None
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
app = app_mod.create_app(str(cfg_path))
admin_headers = {"authorization": "Bearer adm1"}
with TestClient(app) as c:
# Create users and tokens
assert c.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice"}).status_code == 200
assert c.post("/api/v2/users", headers=admin_headers, json={"user_id": "bob"}).status_code == 200
alice_tok = c.post("/api/v2/users/alice/tokens", headers=admin_headers).json()["token"]
bob_tok = c.post("/api/v2/users/bob/tokens", headers=admin_headers).json()["token"]
alice_headers = {"authorization": f"Bearer {alice_tok}"}
bob_headers = {"authorization": f"Bearer {bob_tok}"}
# Each user submits one task.
r1 = c.post(
"/api/v2/tasks",
headers=alice_headers,
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
)
assert r1.status_code == 200
alice_tid = r1.json()["task_id"]
r2 = c.post(
"/api/v2/tasks",
headers=bob_headers,
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
)
assert r2.status_code == 200
bob_tid = r2.json()["task_id"]
# Queue is scoped to user.
qa = c.get("/api/v2/queue", headers=alice_headers).json()
qb = c.get("/api/v2/queue", headers=bob_headers).json()
assert {t["task_id"] for t in qa["pending"]} == {alice_tid}
assert {t["task_id"] for t in qb["pending"]} == {bob_tid}
# Cross-user access returns 404 (no existence leak).
assert c.get(f"/api/v2/tasks/{bob_tid}", headers=alice_headers).status_code == 404
assert c.post(f"/api/v2/tasks/{bob_tid}:cancel", headers=alice_headers).status_code == 404
assert c.get(f"/api/v2/tasks/{bob_tid}/attempts", headers=alice_headers).status_code == 404
# Admin can see global queue.
qadm = c.get("/api/v2/queue", headers=admin_headers).json()
assert {t["task_id"] for t in qadm["pending"]} == {alice_tid, bob_tid}
def test_submit_rejects_non_common_inputs(tmp_path: Path, monkeypatch):
from argus.service import app as app_mod
cfg_path = _write_config(tmp_path)
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
class _Scheduler:
def __init__(self, **kwargs):
self.tool = object()
def run_forever(self, stop_flag):
return None
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
app = app_mod.create_app(str(cfg_path))
admin_headers = {"authorization": "Bearer adm1"}
with TestClient(app) as c:
assert c.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice"}).status_code == 200
alice_tok = c.post("/api/v2/users/alice/tokens", headers=admin_headers).json()["token"]
alice_headers = {"authorization": f"Bearer {alice_tok}"}
r = c.post(
"/api/v2/tasks",
headers=alice_headers,
data="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
)
assert r.status_code == 400
assert "code_path must start with /private/common/" in r.text

View File

@ -0,0 +1,58 @@
from __future__ import annotations
import json
from pathlib import Path
from argus.ray.discovery import build_head_record, write_head_record_atomic
from argus.ray.worker_watchdog import Watchdog
def test_watchdog_restarts_on_first_seen_and_on_change(tmp_path: Path):
head_file = tmp_path / "head.json"
calls: list[list[str]] = []
def runner(argv: list[str]) -> int:
calls.append(argv)
return 0
wd = Watchdog(
head_ip_file=str(head_file),
node_ip="10.0.0.2",
resources_json=json.dumps({"worker_node": 100}),
poll_s=1,
runner=runner,
)
write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="1.1.1.1"))
assert wd.tick_once() is not None
assert any("ray" in c[0] and c[1] == "start" for c in calls)
calls.clear()
# Same address -> no restart
write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="1.1.1.1"))
wd.tick_once()
assert calls == []
# Address change -> restart
write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="2.2.2.2"))
wd.tick_once()
assert any(c[1] == "stop" for c in calls)
assert any(c[1] == "start" for c in calls)
def test_watchdog_main_once_invokes_runner(monkeypatch, tmp_path: Path):
from argus.ray import worker_watchdog as mod
head_file = tmp_path / "head.json"
write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="1.1.1.1"))
calls: list[list[str]] = []
def runner(argv: list[str]) -> int:
calls.append(argv)
return 0
monkeypatch.setattr(mod, "_default_runner", runner)
rc = mod.main(["--head-ip-file", str(head_file), "--node-ip", "10.0.0.2", "--resources-kv", "worker_node=100", "--once"])
assert rc == 0
assert any(c[1] == "start" for c in calls)

View File

@ -0,0 +1,24 @@
#!/usr/bin/env bash
set -euo pipefail
echo "[host] cleanup v2 legacy containers/processes (best-effort)"
# Known historical container names from v1.1/v2 era
LEGACY=(
mvp11-ray-head
mvp11-ray-worker-0
mvp11-ray-worker-1
mvp2-ray-head
mvp2-ray-worker-0
mvp2-ray-worker-1
)
for c in "${LEGACY[@]}"; do
if docker ps -a --format '{{.Names}}' | grep -qx "${c}"; then
echo "[host] removing legacy container: ${c}"
docker rm -f "${c}" >/dev/null 2>&1 || true
fi
done
echo "[host] legacy v2 cleanup done"

View File

@ -0,0 +1,44 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"
CLUSTER_NAME="${CLUSTER_NAME:-argus-ray}"
HEAD_IP_FILE="${HEAD_IP_FILE:-${SHARED_ROOT}/ray/discovery/${CLUSTER_NAME}/head.json}"
TTL_S="${TTL_S:-60}"
REFRESH_S="${REFRESH_S:-10}"
# PID file must be container-local to avoid conflicts if /private is shared across containers.
PID_PATH="${PID_PATH:-/tmp/argus_head_publisher.pid}"
LOG_PATH="${LOG_PATH:-${SHARED_ROOT}/common/logs/argus_head_publisher.log}"
HEAD_IP="$(container_ip "${HEAD_CONTAINER}")"
echo "[head] start discovery publisher (supervised): ${HEAD_IP_FILE} (head_ip=${HEAD_IP})"
# stop existing (best-effort)
dexec "${HEAD_CONTAINER}" bash -lc "if test -f '${PID_PATH}'; then pid=\$(cat '${PID_PATH}'); if kill -0 \"\${pid}\" >/dev/null 2>&1; then kill \"\${pid}\" || true; fi; rm -f '${PID_PATH}'; fi"
dexec "${HEAD_CONTAINER}" bash -lc "mkdir -p \"$(dirname "${LOG_PATH}")\" \"$(dirname "${HEAD_IP_FILE}")\""
# Supervisor loop: restart publisher if it exits.
docker exec -d "${HEAD_CONTAINER}" bash -lc "
nohup bash -lc '
export PYTHONPATH=/workspace/mvp/py
while true; do
python3 -m argus.ray.head_publisher \
--cluster-name \"${CLUSTER_NAME}\" \
--head-ip-file \"${HEAD_IP_FILE}\" \
--head-ip \"${HEAD_IP}\" \
--ttl-s \"${TTL_S}\" \
--refresh-s \"${REFRESH_S}\"
sleep 2
done
' >>'${LOG_PATH}' 2>&1 & echo \$! >'${PID_PATH}'
"
echo "[head] publisher pid stored in ${PID_PATH} (container-local)"
echo "[head] logs: ${LOG_PATH}"

View File

@ -0,0 +1,52 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"
CLUSTER_NAME="${CLUSTER_NAME:-argus-ray}"
HEAD_IP_FILE="${HEAD_IP_FILE:-${SHARED_ROOT}/ray/discovery/${CLUSTER_NAME}/head.json}"
POLL_S="${POLL_S:-5}"
WORKER_NODE_RESOURCE="${WORKER_NODE_RESOURCE:-100}"
start_one() {
local worker="$1"
local ip
ip="$(container_ip "${worker}")"
local pid_path="/tmp/argus_worker_watchdog.pid"
local log_path="${SHARED_ROOT}/common/logs/argus_worker_watchdog.${worker}.log"
echo "[${worker}] start stateless watchdog (supervised): head_file=${HEAD_IP_FILE} node_ip=${ip}"
# stop existing watchdog (best-effort)
dexec "${worker}" bash -lc "if test -f '${pid_path}'; then pid=\$(cat '${pid_path}'); if kill -0 \"\${pid}\" >/dev/null 2>&1; then kill \"\${pid}\" || true; fi; rm -f '${pid_path}'; fi"
# stop any legacy ray process to avoid split-brain
dexec "${worker}" bash -lc "ray stop --force || true"
dexec "${worker}" bash -lc "mkdir -p \"$(dirname "${log_path}")\""
docker exec -d "${worker}" bash -lc "
nohup bash -lc '
export PYTHONPATH=/workspace/mvp/py
while true; do
python3 -m argus.ray.worker_watchdog \
--head-ip-file \"${HEAD_IP_FILE}\" \
--node-ip \"${ip}\" \
--resources-kv \"worker_node=${WORKER_NODE_RESOURCE}\" \
--poll-s \"${POLL_S}\"
sleep 2
done
' >>'${log_path}' 2>&1 & echo \$! >'${pid_path}'
"
echo "[${worker}] watchdog pid stored in ${pid_path} (container-local)"
echo "[${worker}] logs: ${log_path}"
}
start_one "${WORKER0_CONTAINER}"
start_one "${WORKER1_CONTAINER}"
echo "[head] ray status"
dexec "${HEAD_CONTAINER}" bash -lc "ray status || true"

View File

@ -0,0 +1,17 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"
CLUSTER_NAME="${CLUSTER_NAME:-argus-ray}"
HEAD_IP_FILE="${HEAD_IP_FILE:-${SHARED_ROOT}/ray/discovery/${CLUSTER_NAME}/head.json}"
echo "[head] head.json (best-effort): ${HEAD_IP_FILE}"
dexec "${HEAD_CONTAINER}" bash -lc "if test -f '${HEAD_IP_FILE}'; then cat '${HEAD_IP_FILE}'; else echo 'missing'; fi"
echo
echo "[head] ray status"
dexec "${HEAD_CONTAINER}" bash -lc "ray status || true"

View File

@ -11,10 +11,30 @@ PPO_DATA_DIR="${SHARED_ROOT}/datasets/gsm8k"
SFT_DATA_DIR="${SHARED_ROOT}/datasets/gsm8k_sft" SFT_DATA_DIR="${SHARED_ROOT}/datasets/gsm8k_sft"
CODE_SNAPSHOT_DIR="${SHARED_ROOT}/common/code/verl/verl_repo" CODE_SNAPSHOT_DIR="${SHARED_ROOT}/common/code/verl/verl_repo"
COMMON_DIR="${SHARED_ROOT}/common"
COMMON_DATASETS_DIR="${SHARED_ROOT}/common/datasets"
COMMON_HF_DIR="${SHARED_ROOT}/common/hf"
echo "[head] ensure dataset dirs exist" echo "[head] ensure dataset dirs exist"
dexec "${HEAD_CONTAINER}" bash -lc "mkdir -p '${PPO_DATA_DIR}' '${SFT_DATA_DIR}'" dexec "${HEAD_CONTAINER}" bash -lc "mkdir -p '${PPO_DATA_DIR}' '${SFT_DATA_DIR}'"
echo "[head] ensure v2.5 common links (idempotent)"
# In existing deployments, /private/common/{datasets,hf} may already exist as directories (not symlinks).
# For v2.5 taskspecs, we only require:
# /private/common/datasets/gsm8k -> /private/datasets/gsm8k
# /private/common/datasets/gsm8k_sft -> /private/datasets/gsm8k_sft
# /private/common/hf/{hub,transformers} -> /private/hf/{hub,transformers} (best-effort)
dexec "${HEAD_CONTAINER}" bash -lc "
set -euo pipefail
mkdir -p '${COMMON_DIR}' '${COMMON_DATASETS_DIR}' '${COMMON_HF_DIR}'
ln -sfn '${SHARED_ROOT}/datasets/gsm8k' '${COMMON_DATASETS_DIR}/gsm8k'
ln -sfn '${SHARED_ROOT}/datasets/gsm8k_sft' '${COMMON_DATASETS_DIR}/gsm8k_sft'
mkdir -p '${SHARED_ROOT}/hf/hub' '${SHARED_ROOT}/hf/transformers'
ln -sfn '${SHARED_ROOT}/hf/hub' '${COMMON_HF_DIR}/hub'
ln -sfn '${SHARED_ROOT}/hf/transformers' '${COMMON_HF_DIR}/transformers'
echo 'common_links_ok'
"
echo "[head] prepare PPO dataset (gsm8k RL parquet) -> ${PPO_DATA_DIR}" echo "[head] prepare PPO dataset (gsm8k RL parquet) -> ${PPO_DATA_DIR}"
dexec "${HEAD_CONTAINER}" bash -lc "if [[ -f '${PPO_DATA_DIR}/train.parquet' && -f '${PPO_DATA_DIR}/test.parquet' ]]; then echo 'ppo_dataset_exists: skip'; else python3 /workspace/verl/examples/data_preprocess/gsm8k.py --local_save_dir '${PPO_DATA_DIR}'; fi" dexec "${HEAD_CONTAINER}" bash -lc "if [[ -f '${PPO_DATA_DIR}/train.parquet' && -f '${PPO_DATA_DIR}/test.parquet' ]]; then echo 'ppo_dataset_exists: skip'; else python3 /workspace/verl/examples/data_preprocess/gsm8k.py --local_save_dir '${PPO_DATA_DIR}'; fi"

View File

@ -0,0 +1,137 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"
# E2E v2.5:
# - Clean legacy env
# - Start containers
# - Start ray head + discovery publisher
# - Start stateless worker watchdogs (auto-connect)
# - Prepare data/model/code (reuses existing downloads)
# - Start API
# - Create user + issue token
# - Submit PPO/GRPO/SFT via API and wait
API_ADDR="${API_ADDR:-http://127.0.0.1:8080}"
ADMIN_TOKEN="${MVP_INTERNAL_TOKEN:-}"
USER_ID="${USER_ID:-alice}"
RESET_DB="${RESET_DB:-1}"
if [[ -z "${ADMIN_TOKEN}" ]]; then
echo "ERROR: MVP_INTERNAL_TOKEN must be set in host env (admin token)" >&2
exit 1
fi
api_curl_admin() {
curl -sS -H "Authorization: Bearer ${ADMIN_TOKEN}" "$@"
}
api_wait_ready() {
local tries="${1:-60}"
for i in $(seq 1 "${tries}"); do
if curl -sS -m 2 "${API_ADDR}/docs" >/dev/null 2>&1; then
echo "[host] api_ready: ${API_ADDR}"
return 0
fi
echo "[host] waiting api... (${i}/${tries})"
sleep 2
done
echo "ERROR: api not ready: ${API_ADDR}" >&2
return 1
}
submit_taskspec() {
local token="$1"
local taskspec_path="$2"
echo "[host] submit via API (user=${USER_ID}): ${taskspec_path}" >&2
local resp
resp="$(curl -sS -H "Authorization: Bearer ${token}" -H "Content-Type: application/yaml" --data-binary @"${taskspec_path}" "${API_ADDR}/api/v2/tasks")"
echo "[host] submit_resp: ${resp}" >&2
printf '%s' "${resp}" | python3 -c 'import sys,json; print(json.load(sys.stdin)["task_id"])'
}
wait_task() {
local token="$1"
local task_id="$2"
while true; do
local body state
body="$(curl -sS -H "Authorization: Bearer ${token}" "${API_ADDR}/api/v2/tasks/${task_id}")"
state="$(printf '%s' "${body}" | python3 -c 'import sys,json; print(json.load(sys.stdin)["state"])')"
echo "[host] task ${task_id}: ${state}"
if [[ "${state}" == "SUCCEEDED" ]]; then
return 0
fi
if [[ "${state}" == "FAILED" || "${state}" == "CANCELED" ]]; then
echo "[host] terminal=${state}; tail logs (best-effort):" >&2
curl -sS -H "Authorization: Bearer ${token}" "${API_ADDR}/api/v2/tasks/${task_id}/logs?tail=200" >&2 || true
return 1
fi
sleep 10
done
}
echo "[host] ===== run_all_v25_api.sh begin ====="
"${SCRIPT_DIR}/00_prereq_check.sh"
"${SCRIPT_DIR}/03_cleanup_v1_legacy.sh"
"${SCRIPT_DIR}/04_cleanup_v2_legacy.sh"
echo "[host] bring down existing containers (best-effort)"
"${SCRIPT_DIR}/02_down.sh" || true
echo "[host] (re)create containers"
"${SCRIPT_DIR}/01_up.sh"
echo "[host] restart ray head (no compute on head)"
"${SCRIPT_DIR}/20_start_head.sh"
echo "[host] start head discovery publisher"
"${SCRIPT_DIR}/22_start_head_discovery.sh"
echo "[host] start stateless workers (watchdog auto-connect)"
"${SCRIPT_DIR}/23_start_workers_stateless.sh"
echo "[host] prepare data/model/code snapshot (idempotent; reuse cache)"
"${SCRIPT_DIR}/30_prepare_data_and_model.sh"
echo "[host] install api deps in head container"
"${SCRIPT_DIR}/12_install_api_deps.sh"
echo "[host] stop api (best-effort)"
"${SCRIPT_DIR}/61_stop_api.sh" || true
if [[ "${RESET_DB}" == "1" ]]; then
echo "[host] reset api sqlite db in container (best-effort)"
docker exec -i "${HEAD_CONTAINER}" bash -lc "rm -f /private/common/db/mvp.sqlite3 /private/common/db/mvp.sqlite3-wal /private/common/db/mvp.sqlite3-shm || true"
else
echo "[host] keep existing api sqlite db (RESET_DB=${RESET_DB})"
fi
echo "[host] start api (admin token via env)"
MVP_INTERNAL_TOKEN="${ADMIN_TOKEN}" "${SCRIPT_DIR}/60_start_api.sh"
api_wait_ready 60
echo "[host] ensure user exists + issue token"
api_curl_admin -H "Content-Type: application/json" -d "{\"user_id\":\"${USER_ID}\"}" "${API_ADDR}/api/v2/users" >/dev/null 2>&1 || true
USER_TOKEN="$(api_curl_admin -X POST "${API_ADDR}/api/v2/users/${USER_ID}/tokens" | python3 -c 'import sys,json; print(json.load(sys.stdin)["token"])')"
echo "[host] user_token_issued: user=${USER_ID}"
PPO_TASK_ID="$(submit_taskspec "${USER_TOKEN}" "${ROOT_DIR}/taskspecs/ppo.yaml")"
GRPO_TASK_ID="$(submit_taskspec "${USER_TOKEN}" "${ROOT_DIR}/taskspecs/grpo.yaml")"
SFT_TASK_ID="$(submit_taskspec "${USER_TOKEN}" "${ROOT_DIR}/taskspecs/sft.yaml")"
echo "[host] submitted task ids:"
echo " ppo=${PPO_TASK_ID}"
echo " grpo=${GRPO_TASK_ID}"
echo " sft=${SFT_TASK_ID}"
echo "[host] wait for tasks (in submission order)"
wait_task "${USER_TOKEN}" "${PPO_TASK_ID}"
wait_task "${USER_TOKEN}" "${GRPO_TASK_ID}"
wait_task "${USER_TOKEN}" "${SFT_TASK_ID}"
echo "[host] ===== run_all_v25_api.sh done ====="

View File

@ -0,0 +1,156 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"
API_ADDR="${API_ADDR:-http://127.0.0.1:8080}"
ADMIN_TOKEN="${MVP_INTERNAL_TOKEN:-}"
if [[ -z "${ADMIN_TOKEN}" ]]; then
echo "ERROR: MVP_INTERNAL_TOKEN must be set (admin token)" >&2
exit 1
fi
require_cmd jq
require_cmd python3
api_admin() { curl -sS -H "Authorization: Bearer ${ADMIN_TOKEN}" "$@"; }
api_user() { local tok="$1"; shift; curl -sS -H "Authorization: Bearer ${tok}" "$@"; }
wait_state() {
local tok="$1"
local task_id="$2"
local want="$3"
local tries="${4:-120}"
for i in $(seq 1 "${tries}"); do
local st
st="$(api_user "${tok}" "${API_ADDR}/api/v2/tasks/${task_id}" | jq -r .state)"
echo "[case] ${task_id} state=${st} (want=${want})"
if [[ "${st}" == "${want}" ]]; then
return 0
fi
if [[ "${st}" == "FAILED" || "${st}" == "CANCELED" || "${st}" == "SUCCEEDED" ]]; then
return 1
fi
sleep 5
done
return 2
}
issue_token() {
local user_id="$1"
api_admin -H "Content-Type: application/json" -d "{\"user_id\":\"${user_id}\"}" "${API_ADDR}/api/v2/users" >/dev/null 2>&1 || true
api_admin -X POST "${API_ADDR}/api/v2/users/${user_id}/tokens" | jq -r .token
}
submit_yaml() {
local tok="$1"
local yaml_path="$2"
api_user "${tok}" -H "Content-Type: application/yaml" --data-binary @"${yaml_path}" "${API_ADDR}/api/v2/tasks" | jq -r .task_id
}
echo "[case] ===== v2.5 e2e cases ====="
echo "[case] HP-1: run_all_v25_api.sh (happy path)"
RESET_DB="${RESET_DB:-1}"
MVP_INTERNAL_TOKEN="${ADMIN_TOKEN}" RESET_DB="${RESET_DB}" "${SCRIPT_DIR}/run_all_v25_api.sh"
echo "[case] E-Auth-1: missing bearer token"
code="$(curl -sS -o /dev/null -w '%{http_code}' "${API_ADDR}/api/v2/queue" || true)"
[[ "${code}" == "401" ]] || { echo "expected 401 got ${code}" >&2; exit 1; }
echo "[case] E-Auth-2: invalid token"
code="$(curl -sS -o /dev/null -w '%{http_code}' -H 'Authorization: Bearer nope' "${API_ADDR}/api/v2/queue" || true)"
[[ "${code}" == "401" ]] || { echo "expected 401 got ${code}" >&2; exit 1; }
echo "[case] setup users alice/bob"
ALICE_TOKEN="$(issue_token alice)"
BOB_TOKEN="$(issue_token bob)"
echo "[case] E-Isolation-1: cross-user task visibility (404)"
TID="$(submit_yaml "${ALICE_TOKEN}" "${ROOT_DIR}/taskspecs/ppo.yaml")"
code="$(api_user "${BOB_TOKEN}" -o /dev/null -w '%{http_code}' "${API_ADDR}/api/v2/tasks/${TID}" || true)"
[[ "${code}" == "404" ]] || { echo "expected 404 got ${code}" >&2; exit 1; }
# Keep the queue clean for subsequent cases (avoid consuming scheduler capacity).
api_user "${ALICE_TOKEN}" -X POST "${API_ADDR}/api/v2/tasks/${TID}:cancel" >/dev/null 2>&1 || true
wait_state "${ALICE_TOKEN}" "${TID}" "CANCELED" 30 || true
echo "[case] E-Input-1: reject non-common inputs"
bad_yaml="$(mktemp)"
cat >"${bad_yaml}" <<'YAML'
workload: "ppo"
submission_id: ""
code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/datasets/gsm8k/train.parquet"
val_file: "/private/common/datasets/gsm8k/test.parquet"
nnodes: 2
n_gpus_per_node: 4
total_epochs: 1
total_training_steps: 10
save_freq: 10
test_freq: -1
YAML
code="$(api_user "${ALICE_TOKEN}" -H 'Content-Type: application/yaml' --data-binary @"${bad_yaml}" -o /dev/null -w '%{http_code}' "${API_ADDR}/api/v2/tasks" || true)"
rm -f "${bad_yaml}"
[[ "${code}" == "400" ]] || { echo "expected 400 got ${code}" >&2; exit 1; }
echo "[case] B-Queue-1: pending_resources when demand > cluster capacity"
big_yaml="$(mktemp)"
cat >"${big_yaml}" <<'YAML'
workload: "ppo"
submission_id: ""
code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/common/datasets/gsm8k/train.parquet"
val_file: "/private/common/datasets/gsm8k/test.parquet"
nnodes: 3
n_gpus_per_node: 4
total_epochs: 1
total_training_steps: 10
save_freq: 10
test_freq: -1
YAML
BIG_TID="$(api_user "${ALICE_TOKEN}" -H 'Content-Type: application/yaml' --data-binary @"${big_yaml}" "${API_ADDR}/api/v2/tasks" | jq -r .task_id)"
rm -f "${big_yaml}"
wait_state "${ALICE_TOKEN}" "${BIG_TID}" "PENDING_RESOURCES" 60
echo "[case] B-Cancel-1: cancel a running task (best-effort)"
long_yaml="$(mktemp)"
cat >"${long_yaml}" <<'YAML'
workload: "ppo"
submission_id: ""
code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/common/datasets/gsm8k/train.parquet"
val_file: "/private/common/datasets/gsm8k/test.parquet"
nnodes: 2
n_gpus_per_node: 4
total_epochs: 1
total_training_steps: 200
save_freq: 10
test_freq: -1
YAML
LONG_TID="$(api_user "${ALICE_TOKEN}" -H 'Content-Type: application/yaml' --data-binary @"${long_yaml}" "${API_ADDR}/api/v2/tasks" | jq -r .task_id)"
rm -f "${long_yaml}"
for i in $(seq 1 60); do
st="$(api_user "${ALICE_TOKEN}" "${API_ADDR}/api/v2/tasks/${LONG_TID}" | jq -r .state)"
echo "[case] wait running: ${LONG_TID} state=${st}"
if [[ "${st}" == "RUNNING" ]]; then
break
fi
if [[ "${st}" == "FAILED" || "${st}" == "SUCCEEDED" ]]; then
break
fi
sleep 5
done
api_user "${ALICE_TOKEN}" -X POST "${API_ADDR}/api/v2/tasks/${LONG_TID}:cancel" | jq .
st="$(api_admin "${API_ADDR}/api/v2/tasks/${LONG_TID}" | jq -r .state)"
[[ "${st}" == "CANCELED" ]] || { echo "expected CANCELED got ${st}" >&2; exit 1; }
echo "[case] ===== all v2.5 e2e cases passed ====="

View File

@ -6,8 +6,8 @@ code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct" model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/datasets/gsm8k/train.parquet" train_file: "/private/common/datasets/gsm8k/train.parquet"
val_file: "/private/datasets/gsm8k/test.parquet" val_file: "/private/common/datasets/gsm8k/test.parquet"
nnodes: 2 nnodes: 2
n_gpus_per_node: 4 n_gpus_per_node: 4
@ -17,4 +17,3 @@ total_training_steps: 10
save_freq: 10 save_freq: 10
test_freq: -1 test_freq: -1

View File

@ -8,8 +8,8 @@ code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct" model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/datasets/gsm8k/train.parquet" train_file: "/private/common/datasets/gsm8k/train.parquet"
val_file: "/private/datasets/gsm8k/test.parquet" val_file: "/private/common/datasets/gsm8k/test.parquet"
nnodes: 2 nnodes: 2
n_gpus_per_node: 4 n_gpus_per_node: 4
@ -19,4 +19,3 @@ total_training_steps: 10
save_freq: 10 save_freq: 10
test_freq: -1 test_freq: -1

View File

@ -6,7 +6,7 @@ code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct" model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/datasets/gsm8k_sft/train.parquet" train_file: "/private/common/datasets/gsm8k_sft/train.parquet"
val_file: null val_file: null
nnodes: 2 nnodes: 2
@ -19,4 +19,3 @@ save_freq: 10
# SFT driver 默认不分配 GPUray job entrypoint 不指定 entrypoint_num_gpus因此 driver 侧不要依赖 CUDA # SFT driver 默认不分配 GPUray job entrypoint 不指定 entrypoint_num_gpus因此 driver 侧不要依赖 CUDA
trainer_device: "cpu" trainer_device: "cpu"