v2.5 functional tests passed; Docker image still to be added
This commit is contained in:
parent
ce8c2128b4
commit
f02059126e
@@ -11,4 +11,4 @@ Core changes in v2.5:

- `specs/mvp/v2.5/v2.5_design.md`: overall architecture and key mechanisms (head IP file / watchdog / user isolation / task flow).
- `specs/mvp/v2.5/v2.5_api.md`: API design (users, tasks, queue, logs) and authentication conventions.
- `specs/mvp/v2.5/v2.5_acceptance.md`: development/deployment/acceptance workflow and verifiable criteria.
- `specs/mvp/v2.5/v2.5_summary.md`: summary of what v2.5 delivered (what this iteration implemented, acceptance results, known limitations).
specs/mvp/v2.5/notices.md  (new file, 3 lines)
@@ -0,0 +1,3 @@
# Issue log

1. Include the user name in task and submission IDs.
2. Complete the end-to-end test cases: normal and error cases, plus boundary conditions.
specs/mvp/v2.5/v2.5_dev_plan.md  (new file, 229 lines)
@@ -0,0 +1,229 @@
# MVP v2.5 Development Plan (TDD-driven)

This document is the **engineering development plan** for v2.5. It emphasizes "write the tests first, then the implementation" (TDD) and splits every milestone into small loops that can be **accepted independently**.

Inputs:
- Roadmap: `specs/mvp/mvp_roadmap_v2.md`
- v2.5 design: `specs/mvp/v2.5/v2.5_design.md`
- v2.5 API draft: `specs/mvp/v2.5/v2.5_api.md`
- v2.5 acceptance: `specs/mvp/v2.5/v2.5_acceptance.md`

v2.5 constraints (confirmed):
- **No TaskSpec extensions**: keep the structured YAML fields and semantics of v2.0/v2.1.
- **No custom reward functions / no user-supplied verl code.**
- Training inputs (verl code, HF cache, datasets) all come from `/private/common/...`.
- Multi-user isolation in v2.5 **only isolates the jobs output directory**: `/private/users/<uid>/jobs/<ray_submission_id>/...`.

---

## 0. TDD conventions (apply to every feature)

### 0.1 Test layers

1) **Unit tests (fast)**
   - Pure Python logic: DB, auth, IDs, path derivation, head.json parsing/TTL, watchdog decision logic.
   - Goal: no dependency on a real Ray cluster, docker, or the network.

2) **Component tests (medium)**
   - FastAPI routes via `fastapi.testclient.TestClient` (already used in v2.0).
   - Goal: verify auth/permission isolation, API behavior, and the state machine.

3) **End-to-end (slow, manual or scripted)**
   - On `argus@h1`, run the "head publish → worker auto-connect → API submit" loop via scripts/compose.
   - Goal: verify the real behavior of stateless workers + the watchdog.

### 0.2 Test conventions

- Test directory: `src/mvp/py/tests/`
- Every new feature first gets test cases that fail before the implementation exists (red).
- Make the tests pass with the smallest possible change (green).
- Then refactor / de-duplicate (refactor).

> Note: the existing tests inject a ray stub via `src/mvp/py/tests/conftest.py` so that unit tests do not depend on the real ray package; new v2.5 modules should reuse this pattern.

---

## 1. Milestone breakdown (v2.5 = 4 verifiable loops)

### M1: User table / token table + basic auth (keeps the existing internal-token mode working)

**Goals**
- Introduce persistence and an auth mapping for users/tokens (token → user_id).
- Stay compatible with the existing `Authorization: Bearer <MVP_INTERNAL_TOKEN>` "single-tenant mode" so v2.0 usage is not broken all at once:
  - v2.5 can support two token modes:
    - legacy: environment variable `MVP_INTERNAL_TOKEN` (global single tenant);
    - user token: tokens issued and persisted in the DB (multi-user).
- Admin can create users, issue tokens, and disable users.

**TDD cases (write the tests first)**

Unit tests:
- `test_user_db_create_disable()`
  - Creating a user yields ACTIVE; disabling moves it to DISABLED; duplicate creation returns a conflict or is idempotent (per the final convention).
- `test_token_hashing()`
  - When a token is issued, the DB stores only its hash, never the plaintext (a hashing sketch follows these test lists).

API tests (TestClient):
- `test_admin_create_user_and_issue_token()`
  - An admin token can create a user and issue a token (the plaintext token is returned only once).
- `test_disabled_user_token_rejected()`
  - After a user is disabled, calls with the old token return 401/403.
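A minimal sketch of the hashing behavior `test_token_hashing()` pins down, assuming the SQLite `api_tokens(token_hash, user_id, created_at)` table introduced later in this commit; the standalone helper names are illustrative:

```python
import hashlib
import secrets
import sqlite3

def hash_token(token: str) -> str:
    # Only the SHA-256 digest is persisted; the plaintext never hits disk.
    return hashlib.sha256(token.encode("utf-8")).hexdigest()

def issue_token(conn: sqlite3.Connection, user_id: str) -> str:
    token = f"mvp_u_{user_id}_{secrets.token_urlsafe(18)}"
    conn.execute(
        "INSERT INTO api_tokens (token_hash, user_id, created_at) VALUES (?, ?, datetime('now'))",
        (hash_token(token), user_id),
    )
    return token  # the plaintext leaves the process exactly once

# test_token_hashing() then asserts that the stored value differs from the
# plaintext and that hash_token(token) equals the persisted token_hash.
```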
**Implementation targets (suggested modules)**
- `argus.service.auth`: token validation and user_id resolution (compatible with the legacy mode)
- `argus.service.db`: new `users` and `api_tokens` tables with CRUD
- `argus.service.app`: new user-management endpoints (admin scope)
- `configs/dev.yaml`: add admin token/env configuration (keep the YAML style)

**Acceptance**
- `v2.5_acceptance.md`: U1 is covered by automated API tests.

---

### M2: Bind tasks to user_id + API visibility isolation (still no TaskSpec changes)

**Goals**
- When a task is submitted, derive `user_id` from the token and write it to `tasks.user_id`.
- Task query/cancel/logs are owner-only by default; other users get 404 (so existence is not leaked).
- The queue returns only the current user's tasks by default; admin may query the global queue (optional).

**TDD cases (write the tests first)**

Unit tests:
- `test_tasks_table_has_user_id()`: creating a task must persist `user_id`, and `list_queue(user_id=...)` returns only that user's tasks.

API tests:
- `test_task_visibility_isolated()`
  - user A creates a task; user B's `GET /api/v2/tasks/{id}` returns 404;
  - user B's cancel/logs also return 404.
- `test_queue_isolated()`
  - A and B each create a task; `GET /api/v2/queue` shows only the caller's tasks.

**Implementation targets**
- `argus.service.app`: add a user scope to the task endpoints
- `argus.service.db`: add a `user_id` column to `tasks`, an index, and per-user query methods
- `argus.service.scheduler`: `pick_next_runnable_task` etc. stay either "global FIFO" or "per-user FIFO"
  - v2.5 keeps "global FIFO" as the simplest option (but the API queue view is filtered per user).

**Acceptance**
- `v2.5_acceptance.md`: U2 is covered by API tests.

---

### M3: Per-user isolation of the jobs output directory (outputs only, inputs unchanged)

**Goals**
- The job_root directory of a Ray job is computed uniformly on the server side:
  - `/private/users/<uid>/jobs/<ray_submission_id>/...`
- Input-related path fields in the TaskSpec must be `/private/common/...` (v2.5 unifies inputs under common).
- No user can point outputs outside their own user jobs directory via the TaskSpec (prevents unauthorized writes).

**TDD cases (write the tests first)**

Unit tests:
- `test_job_root_derivation_per_user()`
  - Given user_id and ray_submission_id, the derived job_root is deterministic and correct.
- `test_reject_non_common_inputs()`
  - If `train_file` / `val_file` / `code_path` / HF paths in the TaskSpec do not start with `/private/common/`, reject with HTTP 400.

API tests:
- `test_job_dir_written_under_user_jobs()`
  - After submitting a task, the DB or the submit payload shows job_root under the user jobs directory (the spec can be captured by mocking RayJobTool.submit).

**Implementation targets (minimal intrusion)**
- Derive `job_root` in the service layer and inject it into RayJobTool/builders (rather than letting users specify it via the TaskSpec).
- Change RayJobTool `_job_dir()` to accept a "job_root generator" or a direct `job_root` argument supplied by the service layer.
- Goal: keep RayJobTool's responsibility clear (submit Ray jobs); path policy is decided by the service. A derivation sketch follows this milestone.

**Acceptance**
- `v2.5_acceptance.md`: U3/U4 are covered by API/unit tests.
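A minimal sketch of the path policy above, mirroring the service/scheduler behavior added in this commit; the standalone function names and the hard-coded shared root are illustrative:

```python
from fastapi import HTTPException

SHARED_ROOT = "/private"  # ray.shared_root in the service config

def derive_job_root(user_id: str | None, ray_submission_id: str) -> str:
    # Outputs are isolated per user; tasks without a user fall back to the shared jobs dir.
    root = SHARED_ROOT.rstrip("/")
    if user_id:
        return f"{root}/users/{user_id}/jobs/{ray_submission_id}"
    return f"{root}/jobs/{ray_submission_id}"

def validate_common_inputs(code_path: str | None, train_file: str | None, val_file: str | None) -> None:
    # v2.5 constraint: all training inputs must live under /private/common/.
    prefix = f"{SHARED_ROOT}/common/"
    for name, value in (("code_path", code_path), ("train_file", train_file), ("val_file", val_file)):
        if value is not None and not str(value).startswith(prefix):
            raise HTTPException(status_code=400, detail=f"{name} must start with {prefix}")
```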
---

### M4: Stateless Ray node pool (head.json + worker watchdog) + end-to-end script verification

**Goals**
- After start-up the head keeps writing `/private/ray/discovery/<cluster_name>/head.json` (with a TTL); a record sketch follows this list.
- A watchdog runs inside each worker container (or start script + watchdog), so the platform never has to pass the head address explicitly:
  - read head.json (present and not expired) → `ray start --address=<head_ip>:<gcs_port>`
  - head.json changed → `ray stop` + `ray start` to reconnect
- Provide a one-shot reproduction script in the dev environment (docker compose) for e2e.
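A minimal sketch of publishing and reading the head record with the `argus.ray.discovery` helpers added in this commit; the discovery path and IP are examples:

```python
from argus.ray.discovery import build_head_record, load_head_record, write_head_record_atomic

path = "/private/ray/discovery/mvp/head.json"  # example discovery path

# Head side: publish the address with a 60s TTL (the publisher refreshes it periodically).
rec = build_head_record(cluster_name="mvp", head_ip="10.0.0.5", gcs_port=6379, dashboard_port=8265, ttl_s=60)
write_head_record_atomic(path, rec)

# Worker side: a fresh record yields "<head_ip>:<gcs_port>"; a missing or expired file yields None.
got = load_head_record(path)
if got is not None:
    print(got.head_addr())      # "10.0.0.5:6379"
    print(got.job_server_url)   # "http://10.0.0.5:8265"
```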
**TDD cases (write the tests first)**

Unit tests (no real ray):
- `test_head_json_read_validate_ttl()`
  - File missing/expired → report "unavailable"
  - Not expired → return the head address
- `test_watchdog_decision_on_change()`
  - head_ip changed → trigger a reconnect
  - only updated_at changed (address unchanged) → no reconnect (or reconnect per policy; to be decided)

Component/script-level tests (optional):
- If the watchdog is implemented in Python, stub the "run command" layer (never actually run `ray start`) and only verify which commands would be invoked; a usage sketch follows.
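A minimal sketch of the injected-runner pattern using the `Watchdog` class added in this commit; the head-file path is an example:

```python
from argus.ray.worker_watchdog import Watchdog

calls: list[list[str]] = []

def fake_runner(argv: list[str]) -> int:
    # Record the command instead of executing it, so no `ray` process is ever started.
    calls.append(argv)
    return 0

wd = Watchdog(
    head_ip_file="/private/ray/discovery/mvp/head.json",  # example path
    node_ip=None,
    resources_json='{"worker_node": 100.0}',
    poll_s=5,
    runner=fake_runner,
)
wd.tick_once()  # reads head.json; on a new or changed address it issues ray stop + ray start via the runner
print(calls)    # e.g. [["ray", "stop", "--force"], ["ray", "start", "--address=...", ...]] when a head record exists
```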
End-to-end script (manual/slow):
- Provide a script such as `scripts/run_all_v25_stateless.sh` (name is an example):
  1) Start the head (Ray head + API)
  2) Start the head publisher (writes head.json)
  3) Start 2 workers (4 GPUs each); workers run only the watchdog and are not given the head address
  4) `ray status` shows 1 head + 2 workers with GPU=8
  5) Create a user / issue a token via the API, then submit PPO/GRPO/SFT
  6) Restart the head (or point head.json at a new address) and verify the workers reconnect automatically

**Implementation targets (suggested strategy)**

For testability (TDD), implement "read head.json / check TTL / build the ray start command" as Python modules:
- `argus.ray.discovery`: read/write head.json (atomic writes, TTL)
- `argus.ray.worker_watchdog`: watch loop (polling + change detection); the command runner is injectable (easy to stub in unit tests)

Keep the script layer thin:
- `scripts/` handles docker exec / compose orchestration and process supervision;
- the watchdog process is the in-container python module (more testable and easier to port to a production platform's entrypoint/command).

**Acceptance**
- `v2.5_acceptance.md`: A1/A2/A3 are verified mainly by the e2e script plus the dashboard/logs.

---

## 2. Regression strategy (make sure v2.0 is not broken)

Keep the following cases and regress them continuously during v2.5 (at least at unit-test level):
- The old internal-token mode can still access `GET /api/v2/queue` and submit tasks (if compatibility is kept).
- The scheduler's "insufficient resources → PENDING_RESOURCES → delayed retry" behavior is unchanged (covered by the existing `test_scheduler.py`).
- `ray entrypoint_resources` still forces the driver onto a worker (keep using the `worker_node` custom resource).

---

## 3. Deliverables (code / scripts / docs)

### 3.1 Code
- users/tokens: DB schema + auth + API endpoints
- tasks: bind user_id + permission isolation
- job_root: derive the per-user jobs output directory (inputs stay under common)
- discovery/watchdog: head.json + worker self-healing

### 3.2 Scripts (dev e2e)
- head: start the Ray head + head publisher
- workers: start statelessly (no head addr passed) + watchdog
- `run_all`: one-shot run (API submit + query + cancel + observe the queue)

### 3.3 Docs
- Update `specs/mvp/v2.5/*` (design/API/acceptance/dev plan)
- Add v2.5 usage to `src/mvp/README.md` (if needed)

---

## 4. Key open questions (must be settled before implementation starts)

1) **Keep legacy token compatibility?**
   - Option A: keep `MVP_INTERNAL_TOKEN` (single tenant) + add user tokens (multi-tenant)
   - Option B: switch v2.5 directly to user tokens (breaks compatibility, but cleaner)

2) **Scheduling fairness**
   - v2.5 stays with global FIFO (simple); per-user fair scheduling/quotas come in v3.

3) **Who writes head.json in production**
   - Option A: a thread in the API process (fewest components)
   - Option B: a separate process (more independent, easier to operate)
specs/mvp/v2.5/v2.5_e2e_test_cases.md  (new file, 132 lines)
@@ -0,0 +1,132 @@
# MVP v2.5 End-to-End Test Cases (normal / error / boundary)

Goal of this case set: cover the key v2.5 capabilities and their boundary conditions (users + jobs isolation + stateless node pool + API queue scheduling).

Constraints (confirmed for v2.5):
- No TaskSpec extensions; no reward_fn; no user-supplied verl code.
- Inputs come only from `/private/common/...`; user isolation covers only the `/private/users/<uid>/jobs/...` outputs for now.

---

## 0. Environment prerequisites

Remote directory (example):
- `argus@h1:/home2/argus/infra/mvp/src/mvp/`

Shared directory (host):
- `/home2/argus/infra/mvp/shared/`

In-container path convention:
- `/private` is the shared-storage mount point

Requirements:
- GPUs 0-7 available
- 3 containers: head (no GPU) + 2 workers (4 GPUs each)

---

## 1. Normal Cases (Happy Path)

### HP-1: v2.5 full pipeline (PPO → GRPO → SFT, serial)

Steps:
1) `cd /home2/argus/infra/mvp/src/mvp/scripts`
2) `MVP_INTERNAL_TOKEN=<admin_token> RESET_DB=1 ./run_all_v25_api.sh`

Expected:
- The Ray dashboard shows 3 nodes (head + 2 workers) with 8 GPUs in total.
- All 3 tasks end in `SUCCEEDED`.
- Output directories exist and are isolated per user:
  - `/private/users/<uid>/jobs/<ray_submission_id>/{config,logs,checkpoints,debug}`

### HP-2: Driver does not run on the head

Verification (pick one):
- The Ray job's driver node IP differs from the head container IP;
- or the logs/scheduling info show that entrypoint_resources took effect (driver runs on a worker).

---

## 2. Error Cases

### E-Auth-1: Missing token

Request:
- `GET /api/v2/queue` without an `Authorization` header

Expected:
- 401 (missing bearer token)

### E-Auth-2: Invalid token

Request:
- `Authorization: Bearer <random>`

Expected:
- 401 (invalid token)

### E-Auth-3: Access rejected after the user is disabled

Steps:
1) Admin creates user `bob` and issues a token
2) Admin disables `bob`
3) Call `/api/v2/queue` with bob's token

Expected:
- 403 (user disabled)

### E-Isolation-1: Cross-user access to task resources (no existence leak)

Steps:
1) alice submits a task and gets `task_id`
2) bob queries `/api/v2/tasks/{task_id}`

Expected:
- 404 (task not found)

### E-Input-1: Input path outside /private/common (v2.5 constraint)

Request:
- Submit a taskspec whose `train_file` or `code_path` does not start with `/private/common/`

Expected:
- 400 with a field-specific error (for example `train_file must start with /private/common/`). A TestClient sketch of these auth/input checks follows.
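A minimal TestClient sketch of the auth/input checks above, following the repo's test pattern (in the real tests the background Scheduler is stubbed first, as in `test_users.py`); the config path and tokens are placeholders:

```python
from fastapi.testclient import TestClient
from argus.service import app as app_mod

app = app_mod.create_app("configs/dev.yaml")  # placeholder config path
with TestClient(app) as c:
    # E-Auth-1: no Authorization header at all.
    assert c.get("/api/v2/queue").status_code == 401

    # E-Auth-2: a bearer token that resolves to nothing.
    assert c.get("/api/v2/queue", headers={"authorization": "Bearer random"}).status_code == 401

    # E-Input-1: code_path outside /private/common/ is rejected before anything is queued.
    r = c.post(
        "/api/v2/tasks",
        headers={"authorization": "Bearer <user_token>"},  # token issued via /api/v2/users/<uid>/tokens
        data="workload: ppo\ncode_path: /tmp/x\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
    )
    assert r.status_code == 400
```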
---

## 3. Boundary Cases

### B-Queue-1: Do not submit to Ray when resources are insufficient (PENDING_RESOURCES)

Steps:
1) Build a task requiring `nnodes=3` and `n_gpus_per_node=4` (12 GPUs total)
2) Submit and poll the state

Expected:
- The task enters `PENDING_RESOURCES` (pending on the service side; nothing is submitted to Ray)
- `next_run_at` is set

### B-Cancel-1: Cancel a task (QUEUED/RUNNING)

Steps:
1) Submit a task with a long step count (so it has a chance to reach RUNNING)
2) Call `POST /api/v2/tasks/{task_id}:cancel`

Expected:
- The task state is `CANCELED`
- The attempt's `ray_status` ends as `STOPPED` (or the job is stopped on the Ray side)

---

## 4. Executable regression script

See:
- `src/mvp/scripts/run_e2e_v25_cases.sh`

The script covers:
- HP-1
- E-Auth-1 / E-Auth-2 / E-Input-1
- E-Isolation-1
- B-Queue-1
- B-Cancel-1 (best-effort)
specs/mvp/v2.5/v2.5_summary.md  (new file, 92 lines)
@@ -0,0 +1,92 @@
# MVP v2.5 Iteration Summary (delivered)

This document summarizes the capabilities v2.5 delivered on top of v2.0/v2.1/v2.2…, the implementation points, how it was accepted, and the known limitations, for retrospection and alignment of later versions.

## Goals and scope

Core goals of v2.5:
- Introduce **user management**: token-based auth and per-task isolation ("jobs only").
- Introduce a **stateless Ray node pool**: workers no longer depend on the platform to hand out the head address; they discover, connect, and self-heal automatically.
- Keep the **TaskSpec (same YAML format as v1.1) unextended**: no reward functions or custom verl code in this iteration.

Explicit non-goals (v2.5 constraints):
- No TaskSpec extensions (for example `reward_fn_path`).
- No isolation or custom paths for user-supplied verl/hf/dataset resources: **everything uses the shared `/private/common/...`** resources.
- User isolation covers only **tasks and the artifact directory** (jobs), not shared caches such as the HF cache or datasets.

## Key capabilities (user-facing)

### 1) Multi-user auth and task isolation
- The API still uses the internal `Authorization: Bearer <token>` scheme:
  - The admin token comes from the environment variable `MVP_INTERNAL_TOKEN`.
  - Business-user tokens are issued by the admin via the API and persisted to SQLite.
- Isolation policy:
  - Non-admin users can query/cancel/fetch logs for **their own tasks only**; cross-user access returns 404 (no existence leak).
  - Training artifacts are isolated on disk: Ray job directories are written to `/private/users/<user_id>/jobs/<ray_submission_id>/...`.

### 2) task_id / submission_id carry the user name
- New task ID rule: `mvp2-<user>-<workload>-<YYYYMMDD-HHMMSS>-<suffix>`
- Ray submission id (attempt) rule: `<task_id>--aNN`, so it naturally contains the user name as well; a short example follows.
- Benefit: dashboards, logs, and on-disk directories become easier to read, trace, and audit per user.
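A short illustration using the `argus.core.ids` helpers from this commit (the timestamp and suffix shown are examples):

```python
from argus.core.ids import new_task_id, attempt_submission_id

task_id = new_task_id("ppo", user_id="Alice_01")
# e.g. "mvp2-alice_01-ppo-20250101-010203-abcd"  (user is normalized to lowercase, safe chars only)

attempt_submission_id(task_id, 1)
# "mvp2-alice_01-ppo-20250101-010203-abcd--a01"
```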
### 3) "Stateless worker pool" and head-address discovery
- The head writes a **head address file** (for example `head.json`) to shared storage; workers run a watchdog that:
  - polls for the head address
  - automatically joins the cluster via `ray start --address ...`
  - reconnects automatically after a disconnect (watchdog self-healing)
- Result: in production, even when worker containers are created by the compute platform (SSH-managed only), they can connect and self-heal through shared storage.

### 4) Task scheduling: queue + Ray Job submission + status sync
- Tasks submitted through the API enter a SQLite queue; a background scheduler submits them to Ray one by one (default `max_running_tasks=1`).
- The scheduler keeps polling Ray job status and writes it back to the task state (RUNNING/SUCCEEDED/FAILED/CANCELED).
- Retryable failures caused by insufficient resources:
  - On VERL's fail-fast (`Total available GPUs ... is less than total desired GPUs ...`) or when the cluster lacks resources,
    the task enters `PENDING_RESOURCES`, gets a `next_run_at`, and is retried every `retry_interval_s`.

## Key implementation points (engineering)

### Storage and directory conventions (in-container view)
- The shared root path is `/private` (matching the production mount).
- Hard v2.5 constraint: the following TaskSpec fields must start with `/private/common/`:
  - `code_path` / `train_file` / `val_file`
- Shared directories (examples):
  - `/private/common/hf`: HF cache
  - `/private/common/datasets`: training data (symlinked to existing cache directories when needed to reuse downloads)
  - `/private/common/db/mvp.sqlite3`: queue and user data (SQLite)
  - `/private/common/logs`: API / watchdog logs
  - `/private/users/<uid>/jobs/...`: per-user job artifacts (isolated)

### Ray topology and "no training on the head"
- The head starts as a management node (CPU/GPU=0), so training never lands on it.
- Worker nodes carry the GPUs (example: 2 workers * 4 GPUs each).
- The driver is forced onto a worker via `entrypoint_resources` (for example `worker_node: 1`).

### Deployment scripts and repeatability
The full script chain covers:
- cleaning up the legacy environment, starting/stopping containers, starting the Ray head
- starting and checking the head discovery publisher and the worker watchdog
- preparing data/models/code (idempotent; reuses existing downloads)
- starting the API server (with RESET_DB support)
- submitting PPO/GRPO/SFT through the API in sequence and waiting for completion

Representative scripts:
- `src/mvp/scripts/run_all_v25_api.sh`: v2.5 happy-path end-to-end (rebuild the cluster, prepare resources, start the API, submit the 3 task types)
- `src/mvp/scripts/run_e2e_v25_cases.sh`: happy path plus auth / isolation / input-validation / resource-shortage / cancel cases

## Acceptance and testing (passed)

### Unit tests (local venv)
- `.venv/bin/python -m pytest`
- Coverage threshold: >= 90%

### Remote end-to-end (h1)
- Run in `argus@h1:/home2/argus/infra/mvp/src/mvp/scripts`:
  - `MVP_INTERNAL_TOKEN=mvp-dev-token RESET_DB=1 ./run_e2e_v25_cases.sh`
- Result: the happy path (PPO/GRPO/SFT) completed, and the error/boundary cases passed (auth, cross-user isolation, input validation, resource shortage → PENDING_RESOURCES, task cancel, etc.).

## Known issues and follow-up suggestions

- With `max_running_tasks=1`, queued tasks stay QUEUED while an earlier task is RUNNING; the "insufficient resources" boundary test therefore has to clear/cancel earlier tasks explicitly, or this behavior is accepted as part of the design.
- SQLite is still a single point of failure; for HA/horizontal scaling, v2.6+ could introduce stronger persistence and multiple replicas (for example Postgres/etcd).
- The API server / watchdog are currently supervised by scripts; a follow-up could unify them under systemd/supervisor (or platform-side supervision) and add health checks and alerting.
@@ -15,3 +15,4 @@

 Quick start:
 - CLI submission flow: `scripts/run_all_cli.sh`
 - API submission flow: `scripts/run_all_api.sh`
+- v2.5 (stateless workers + per-user jobs isolation) E2E: `scripts/run_all_v25_api.sh`
@@ -2,14 +2,27 @@ from __future__ import annotations

 import secrets
 from datetime import datetime
+import re


-def new_task_id(workload: str) -> str:
+_USER_SAFE_RE = re.compile(r"[^a-z0-9_-]+")
+
+
+def _normalize_user_id(user_id: str) -> str:
+    s = (user_id or "").strip().lower()
+    s = _USER_SAFE_RE.sub("-", s)
+    s = re.sub(r"-{2,}", "-", s).strip("-")
+    return s[:24] if s else ""
+
+
+def new_task_id(workload: str, *, user_id: str | None = None) -> str:
     ts = datetime.now().strftime("%Y%m%d-%H%M%S")
     suffix = secrets.token_hex(2)
+    u = _normalize_user_id(user_id or "")
+    if u:
+        return f"mvp2-{u}-{workload}-{ts}-{suffix}"
     return f"mvp2-{workload}-{ts}-{suffix}"


 def attempt_submission_id(task_id: str, attempt_no: int) -> str:
     return f"{task_id}--a{attempt_no:02d}"
src/mvp/py/argus/ray/discovery.py  (new file, 115 lines)
@@ -0,0 +1,115 @@
from __future__ import annotations

import json
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any


def _utc_now() -> datetime:
    return datetime.now(timezone.utc)


def _parse_utc(ts: str) -> datetime:
    # Accept ISO8601 with trailing Z
    s = ts.strip()
    if s.endswith("Z"):
        s = s[:-1] + "+00:00"
    return datetime.fromisoformat(s).astimezone(timezone.utc)


def _iso_z(dt: datetime) -> str:
    return dt.astimezone(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")


@dataclass(frozen=True)
class HeadRecord:
    cluster_name: str
    head_ip: str
    gcs_port: int
    dashboard_port: int
    job_server_url: str
    updated_at: str
    expires_at: str

    @staticmethod
    def from_dict(d: dict[str, Any]) -> "HeadRecord":
        return HeadRecord(
            cluster_name=str(d["cluster_name"]),
            head_ip=str(d["head_ip"]),
            gcs_port=int(d.get("gcs_port", 6379)),
            dashboard_port=int(d.get("dashboard_port", 8265)),
            job_server_url=str(d.get("job_server_url") or f"http://{d['head_ip']}:{int(d.get('dashboard_port', 8265))}"),
            updated_at=str(d["updated_at"]),
            expires_at=str(d["expires_at"]),
        )

    def to_dict(self) -> dict[str, Any]:
        return {
            "cluster_name": self.cluster_name,
            "head_ip": self.head_ip,
            "gcs_port": self.gcs_port,
            "dashboard_port": self.dashboard_port,
            "job_server_url": self.job_server_url,
            "updated_at": self.updated_at,
            "expires_at": self.expires_at,
        }

    def head_addr(self) -> str:
        return f"{self.head_ip}:{self.gcs_port}"


def build_head_record(
    *,
    cluster_name: str,
    head_ip: str,
    gcs_port: int = 6379,
    dashboard_port: int = 8265,
    ttl_s: int = 60,
    now: datetime | None = None,
) -> HeadRecord:
    now_dt = now or _utc_now()
    expires = now_dt + timedelta(seconds=int(ttl_s))
    updated_at = _iso_z(now_dt)
    expires_at = _iso_z(expires)
    return HeadRecord(
        cluster_name=cluster_name,
        head_ip=head_ip,
        gcs_port=int(gcs_port),
        dashboard_port=int(dashboard_port),
        job_server_url=f"http://{head_ip}:{int(dashboard_port)}",
        updated_at=updated_at,
        expires_at=expires_at,
    )


def load_head_record(path: str, *, now: datetime | None = None) -> HeadRecord | None:
    p = Path(path)
    if not p.exists():
        return None
    try:
        obj = json.loads(p.read_text(encoding="utf-8"))
    except Exception:
        return None
    if not isinstance(obj, dict):
        return None
    try:
        rec = HeadRecord.from_dict(obj)
        expires = _parse_utc(rec.expires_at)
    except Exception:
        return None
    now_dt = now or _utc_now()
    if expires <= now_dt:
        return None
    return rec


def write_head_record_atomic(path: str, rec: HeadRecord) -> None:
    p = Path(path)
    os.makedirs(p.parent, exist_ok=True)
    tmp = p.with_suffix(p.suffix + ".tmp")
    tmp.write_text(json.dumps(rec.to_dict(), indent=2, sort_keys=True) + "\n", encoding="utf-8")
    tmp.replace(p)
src/mvp/py/argus/ray/head_publisher.py  (new file, 65 lines)
@@ -0,0 +1,65 @@
from __future__ import annotations

import argparse
import time

from .discovery import build_head_record, write_head_record_atomic


def publish_once(
    *,
    cluster_name: str,
    head_ip_file: str,
    head_ip: str,
    gcs_port: int,
    dashboard_port: int,
    ttl_s: int,
) -> None:
    rec = build_head_record(
        cluster_name=cluster_name,
        head_ip=head_ip,
        gcs_port=gcs_port,
        dashboard_port=dashboard_port,
        ttl_s=ttl_s,
    )
    write_head_record_atomic(head_ip_file, rec)


def main(argv: list[str] | None = None) -> int:
    ap = argparse.ArgumentParser(description="Publish Ray head address to shared storage (head.json).")
    ap.add_argument("--cluster-name", required=True)
    ap.add_argument("--head-ip-file", required=True)
    ap.add_argument("--head-ip", required=True)
    ap.add_argument("--gcs-port", type=int, default=6379)
    ap.add_argument("--dashboard-port", type=int, default=8265)
    ap.add_argument("--ttl-s", type=int, default=60)
    ap.add_argument("--refresh-s", type=int, default=10)
    ap.add_argument("--once", action="store_true", help="Write once then exit (for testing/debug).")
    args = ap.parse_args(argv)

    if args.once:
        publish_once(
            cluster_name=args.cluster_name,
            head_ip_file=args.head_ip_file,
            head_ip=args.head_ip,
            gcs_port=args.gcs_port,
            dashboard_port=args.dashboard_port,
            ttl_s=args.ttl_s,
        )
        return 0

    refresh_s = max(1, int(args.refresh_s))
    while True:
        publish_once(
            cluster_name=args.cluster_name,
            head_ip_file=args.head_ip_file,
            head_ip=args.head_ip,
            gcs_port=args.gcs_port,
            dashboard_port=args.dashboard_port,
            ttl_s=args.ttl_s,
        )
        time.sleep(refresh_s)


if __name__ == "__main__":
    raise SystemExit(main())
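A hedged invocation sketch of the publisher via `main()`, matching its `--once` testing/debug mode; the cluster name, path, and IP are examples:

```python
from argus.ray import head_publisher

# One-shot publish; without --once the publisher loops forever and rewrites
# head.json every --refresh-s seconds.
head_publisher.main([
    "--cluster-name", "mvp",                                   # example cluster name
    "--head-ip-file", "/private/ray/discovery/mvp/head.json",  # example discovery path
    "--head-ip", "10.0.0.5",                                   # example head IP
    "--ttl-s", "60",
    "--once",
])
```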
@@ -73,9 +73,9 @@ class RayJobTool:

         return {"env_vars": env_vars}

-    def submit(self, spec: JobSpec, no_wait: bool) -> str:
+    def submit(self, spec: JobSpec, no_wait: bool, job_dir: str | None = None) -> str:
         submission_id = spec.submission_id or f"mvp11_{spec.workload}_{_ts()}_{os.getpid()}"
-        job_dir = self._job_dir(submission_id)
+        job_dir = job_dir or self._job_dir(submission_id)

         built = build_training_argv(spec, submission_id=submission_id, job_dir=job_dir)
         entrypoint_argv = [
src/mvp/py/argus/ray/worker_watchdog.py  (new file, 110 lines)
@@ -0,0 +1,110 @@
from __future__ import annotations

import argparse
import json
import subprocess
import time
from dataclasses import dataclass
from typing import Callable

from .discovery import HeadRecord, load_head_record


Runner = Callable[[list[str]], int]


def _default_runner(argv: list[str]) -> int:
    proc = subprocess.run(argv, check=False)
    return int(proc.returncode)


def _ray_stop_cmd() -> list[str]:
    return ["ray", "stop", "--force"]


def _ray_start_cmd(*, head_addr: str, node_ip: str | None, resources_json: str) -> list[str]:
    argv = ["ray", "start", f"--address={head_addr}", f"--resources={resources_json}", "--disable-usage-stats"]
    if node_ip:
        argv.append(f"--node-ip-address={node_ip}")
    return argv


@dataclass
class Watchdog:
    head_ip_file: str
    node_ip: str | None
    resources_json: str
    poll_s: int
    runner: Runner

    _current_head_addr: str | None = None

    def _restart_ray(self, desired: str) -> None:
        # Best-effort stop then start.
        self.runner(_ray_stop_cmd())
        self.runner(_ray_start_cmd(head_addr=desired, node_ip=self.node_ip, resources_json=self.resources_json))
        self._current_head_addr = desired

    def tick_once(self) -> HeadRecord | None:
        rec = load_head_record(self.head_ip_file)
        if rec is None:
            return None
        desired = rec.head_addr()
        if self._current_head_addr != desired:
            self._restart_ray(desired)
        return rec

    def run_forever(self) -> None:
        while True:
            try:
                self.tick_once()
            except Exception:
                # keep watchdog alive
                pass
            time.sleep(max(1, int(self.poll_s)))


def main(argv: list[str] | None = None) -> int:
    ap = argparse.ArgumentParser(description="Stateless Ray worker watchdog (auto-connect + reconnect).")
    ap.add_argument("--head-ip-file", required=True)
    ap.add_argument("--node-ip", default=None)
    ap.add_argument(
        "--resources-kv",
        action="append",
        default=[],
        help='Repeatable k=v resources, e.g. --resources-kv worker_node=100',
    )
    ap.add_argument("--poll-s", type=int, default=5)
    ap.add_argument("--once", action="store_true", help="Run one tick then exit (for testing/debug).")
    args = ap.parse_args(argv)

    resources: dict[str, float] = {"worker_node": 100.0}
    if args.resources_kv:
        resources = {}
        for item in args.resources_kv:
            if "=" not in item:
                raise SystemExit(f"invalid --resources-kv (expected k=v): {item!r}")
            k, v = item.split("=", 1)
            k = k.strip()
            v = v.strip()
            if not k:
                raise SystemExit(f"invalid --resources-kv key: {item!r}")
            resources[k] = float(v)

    wd = Watchdog(
        head_ip_file=args.head_ip_file,
        node_ip=args.node_ip,
        resources_json=json.dumps(resources),
        poll_s=int(args.poll_s),
        runner=_default_runner,
    )
    if args.once:
        wd.tick_once()
        return 0

    wd.run_forever()
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
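A hedged invocation sketch of one watchdog tick via `main()`; the discovery path is an example and `--once` is the testing/debug mode defined above:

```python
from argus.ray import worker_watchdog

# Single tick: read head.json and, if a fresh head address is found,
# issue `ray stop --force` followed by `ray start --address=...` via the default runner.
worker_watchdog.main([
    "--head-ip-file", "/private/ray/discovery/mvp/head.json",  # example discovery path
    "--resources-kv", "worker_node=100",
    "--once",
])
# In a container entrypoint the same command runs without --once and polls every --poll-s seconds.
```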
@ -43,19 +43,32 @@ def create_app(config_path: str) -> FastAPI:
|
||||
|
||||
app = FastAPI(title="mvp-v2", version="2.0")
|
||||
|
||||
def _require_token(req: Request) -> None:
|
||||
def _auth(req: Request) -> dict[str, Any]:
|
||||
token_env = v2_cfg.auth.token_env
|
||||
expected = os.environ.get(token_env, "")
|
||||
if not expected:
|
||||
# Misconfigured service; treat as server error.
|
||||
admin_token = os.environ.get(token_env, "")
|
||||
if not admin_token:
|
||||
raise HTTPException(status_code=500, detail=f"missing token env: {token_env}")
|
||||
|
||||
auth = req.headers.get("authorization") or ""
|
||||
if not auth.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="missing bearer token")
|
||||
got = auth.removeprefix("Bearer ").strip()
|
||||
if got != expected:
|
||||
|
||||
if got == admin_token:
|
||||
return {"user_id": "admin", "is_admin": True}
|
||||
|
||||
info = db.resolve_token(got)
|
||||
if not info:
|
||||
raise HTTPException(status_code=401, detail="invalid token")
|
||||
if info.get("state") != "ACTIVE":
|
||||
raise HTTPException(status_code=403, detail="user disabled")
|
||||
return {"user_id": info["user_id"], "is_admin": False}
|
||||
|
||||
def _require_admin(req: Request) -> dict[str, Any]:
|
||||
subject = _auth(req)
|
||||
if not subject.get("is_admin"):
|
||||
raise HTTPException(status_code=403, detail="admin required")
|
||||
return subject
|
||||
|
||||
@app.on_event("startup")
|
||||
def _startup() -> None:
|
||||
@ -66,9 +79,43 @@ def create_app(config_path: str) -> FastAPI:
|
||||
def _shutdown() -> None:
|
||||
stop_flag.set()
|
||||
|
||||
@app.post("/api/v2/users")
|
||||
async def create_user(req: Request) -> dict[str, Any]:
|
||||
_require_admin(req)
|
||||
obj = await req.json()
|
||||
if not isinstance(obj, dict):
|
||||
raise HTTPException(status_code=400, detail="body must be a JSON object")
|
||||
user_id = str(obj.get("user_id") or "").strip()
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=400, detail="missing user_id")
|
||||
display_name = obj.get("display_name")
|
||||
try:
|
||||
row = db.create_user(user_id=user_id, display_name=str(display_name) if display_name is not None else None)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=409, detail=f"user create failed: {e!r}")
|
||||
return {"user_id": row.get("user_id", user_id), "state": row.get("state", "ACTIVE")}
|
||||
|
||||
@app.post("/api/v2/users/{user_id}/tokens")
|
||||
async def issue_token(user_id: str, req: Request) -> dict[str, Any]:
|
||||
_require_admin(req)
|
||||
u = db.get_user(user_id)
|
||||
if not u:
|
||||
raise HTTPException(status_code=404, detail="user not found")
|
||||
token = db.issue_token(user_id=user_id)
|
||||
return {"user_id": user_id, "token": token}
|
||||
|
||||
@app.post("/api/v2/users/{user_id}:disable")
|
||||
async def disable_user(user_id: str, req: Request) -> dict[str, Any]:
|
||||
_require_admin(req)
|
||||
u = db.get_user(user_id)
|
||||
if not u:
|
||||
raise HTTPException(status_code=404, detail="user not found")
|
||||
db.disable_user(user_id=user_id)
|
||||
return {"user_id": user_id, "state": "DISABLED"}
|
||||
|
||||
@app.post("/api/v2/tasks")
|
||||
async def submit_task(req: Request) -> dict[str, Any]:
|
||||
_require_token(req)
|
||||
subject = _auth(req)
|
||||
body = (await req.body()).decode("utf-8")
|
||||
obj = yaml.safe_load(body) or {}
|
||||
if not isinstance(obj, dict):
|
||||
@ -79,9 +126,18 @@ def create_app(config_path: str) -> FastAPI:
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"invalid jobspec: {e!r}")
|
||||
|
||||
task_id = new_task_id(spec.workload)
|
||||
db.create_task(
|
||||
# v2.5 constraint: training inputs must come from /private/common (unified across dev/prod).
|
||||
common_prefix = ray_cfg.shared_root.rstrip("/") + "/common/"
|
||||
for k, v in (("code_path", spec.code_path), ("train_file", spec.train_file), ("val_file", spec.val_file)):
|
||||
if v is None:
|
||||
continue
|
||||
if not str(v).startswith(common_prefix):
|
||||
raise HTTPException(status_code=400, detail=f"{k} must start with {common_prefix}")
|
||||
|
||||
task_id = new_task_id(spec.workload, user_id=str(subject["user_id"]))
|
||||
db.create_task_v25(
|
||||
task_id=task_id,
|
||||
user_id=str(subject["user_id"]),
|
||||
workload=spec.workload,
|
||||
jobspec_yaml=body,
|
||||
nnodes=spec.nnodes,
|
||||
@ -91,10 +147,13 @@ def create_app(config_path: str) -> FastAPI:
|
||||
|
||||
@app.get("/api/v2/tasks/{task_id}")
|
||||
async def get_task(task_id: str, req: Request) -> dict[str, Any]:
|
||||
_require_token(req)
|
||||
subject = _auth(req)
|
||||
row = db.get_task(task_id)
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="task not found")
|
||||
if not subject.get("is_admin"):
|
||||
if str(row.get("user_id") or "") != str(subject["user_id"]):
|
||||
raise HTTPException(status_code=404, detail="task not found")
|
||||
attempts = db.list_attempts(task_id)
|
||||
latest_attempt = attempts[-1] if attempts else None
|
||||
desired = {
|
||||
@ -125,18 +184,24 @@ def create_app(config_path: str) -> FastAPI:
|
||||
|
||||
@app.get("/api/v2/tasks/{task_id}/attempts")
|
||||
async def get_attempts(task_id: str, req: Request) -> dict[str, Any]:
|
||||
_require_token(req)
|
||||
subject = _auth(req)
|
||||
row = db.get_task(task_id)
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="task not found")
|
||||
if not subject.get("is_admin"):
|
||||
if str(row.get("user_id") or "") != str(subject["user_id"]):
|
||||
raise HTTPException(status_code=404, detail="task not found")
|
||||
return {"task_id": task_id, "attempts": db.list_attempts(task_id)}
|
||||
|
||||
@app.post("/api/v2/tasks/{task_id}:cancel")
|
||||
async def cancel(task_id: str, req: Request) -> dict[str, Any]:
|
||||
_require_token(req)
|
||||
subject = _auth(req)
|
||||
row = db.get_task(task_id)
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="task not found")
|
||||
if not subject.get("is_admin"):
|
||||
if str(row.get("user_id") or "") != str(subject["user_id"]):
|
||||
raise HTTPException(status_code=404, detail="task not found")
|
||||
|
||||
state = str(row["state"])
|
||||
if state in ("SUCCEEDED", "FAILED", "CANCELED"):
|
||||
@ -165,10 +230,13 @@ def create_app(config_path: str) -> FastAPI:
|
||||
|
||||
@app.get("/api/v2/tasks/{task_id}/logs")
|
||||
async def logs(task_id: str, req: Request, tail: int = 2000, attempt: str = "latest") -> Response:
|
||||
_require_token(req)
|
||||
subject = _auth(req)
|
||||
row = db.get_task(task_id)
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="task not found")
|
||||
if not subject.get("is_admin"):
|
||||
if str(row.get("user_id") or "") != str(subject["user_id"]):
|
||||
raise HTTPException(status_code=404, detail="task not found")
|
||||
attempts = db.list_attempts(task_id)
|
||||
if not attempts:
|
||||
raise HTTPException(status_code=404, detail="no attempts yet")
|
||||
@ -184,7 +252,9 @@ def create_app(config_path: str) -> FastAPI:
|
||||
|
||||
@app.get("/api/v2/queue")
|
||||
async def queue(req: Request) -> dict[str, Any]:
|
||||
_require_token(req)
|
||||
return db.list_queue()
|
||||
subject = _auth(req)
|
||||
if subject.get("is_admin"):
|
||||
return db.list_queue()
|
||||
return db.list_queue(user_id=str(subject["user_id"]))
|
||||
|
||||
return app
|
||||
|
||||
@ -4,6 +4,8 @@ import os
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import Any, Iterator
|
||||
|
||||
|
||||
@ -18,6 +20,10 @@ def _utc_now_iso() -> str:
|
||||
class Db:
|
||||
db_path: str
|
||||
|
||||
def _hash_token(self, token: str) -> str:
|
||||
# Internal tokens only; store hash to avoid plaintext at rest.
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
||||
conn = sqlite3.connect(self.db_path, timeout=30, isolation_level=None)
|
||||
@ -28,6 +34,27 @@ class Db:
|
||||
|
||||
def init(self) -> None:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
user_id TEXT PRIMARY KEY,
|
||||
display_name TEXT,
|
||||
state TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS api_tokens (
|
||||
token_hash TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
last_used_at TEXT,
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
@ -45,6 +72,12 @@ class Db:
|
||||
)
|
||||
"""
|
||||
)
|
||||
# Best-effort: add user_id column for forward compatibility (v2.5+).
|
||||
# v2.0 tables may exist without it; SQLite supports ADD COLUMN.
|
||||
try:
|
||||
conn.execute("ALTER TABLE tasks ADD COLUMN user_id TEXT")
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS attempts (
|
||||
@ -92,10 +125,10 @@ class Db:
|
||||
with self.tx() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO tasks (task_id, workload, state, jobspec_yaml, nnodes, n_gpus_per_node, created_at, updated_at)
|
||||
VALUES (?, ?, 'QUEUED', ?, ?, ?, ?, ?)
|
||||
INSERT INTO tasks (task_id, workload, state, jobspec_yaml, nnodes, n_gpus_per_node, created_at, updated_at, user_id)
|
||||
VALUES (?, ?, 'QUEUED', ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(task_id, workload, jobspec_yaml, nnodes, n_gpus_per_node, now, now),
|
||||
(task_id, workload, jobspec_yaml, nnodes, n_gpus_per_node, now, now, None),
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (?, ?, 'TASK_CREATED', ?)",
|
||||
@ -104,6 +137,97 @@ class Db:
|
||||
row = conn.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
|
||||
return dict(row) if row else {}
|
||||
|
||||
def create_task_v25(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
user_id: str,
|
||||
workload: str,
|
||||
jobspec_yaml: str,
|
||||
nnodes: int,
|
||||
n_gpus_per_node: int,
|
||||
) -> dict[str, Any]:
|
||||
now = _utc_now_iso()
|
||||
with self.tx() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO tasks (task_id, user_id, workload, state, jobspec_yaml, nnodes, n_gpus_per_node, created_at, updated_at)
|
||||
VALUES (?, ?, ?, 'QUEUED', ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(task_id, user_id, workload, jobspec_yaml, nnodes, n_gpus_per_node, now, now),
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (?, ?, 'TASK_CREATED', ?)",
|
||||
(task_id, now, None),
|
||||
)
|
||||
row = conn.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
|
||||
return dict(row) if row else {}
|
||||
|
||||
def create_user(self, *, user_id: str, display_name: str | None = None) -> dict[str, Any]:
|
||||
now = _utc_now_iso()
|
||||
with self.tx() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO users (user_id, display_name, state, created_at)
|
||||
VALUES (?, ?, 'ACTIVE', ?)
|
||||
""",
|
||||
(user_id, display_name, now),
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (NULL, ?, 'USER_CREATED', ?)",
|
||||
(now, user_id),
|
||||
)
|
||||
row = conn.execute("SELECT * FROM users WHERE user_id = ?", (user_id,)).fetchone()
|
||||
return dict(row) if row else {}
|
||||
|
||||
def disable_user(self, *, user_id: str) -> None:
|
||||
now = _utc_now_iso()
|
||||
with self.tx() as conn:
|
||||
conn.execute("UPDATE users SET state = 'DISABLED' WHERE user_id = ?", (user_id,))
|
||||
conn.execute(
|
||||
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (NULL, ?, 'USER_DISABLED', ?)",
|
||||
(now, user_id),
|
||||
)
|
||||
|
||||
def get_user(self, user_id: str) -> dict[str, Any] | None:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute("SELECT * FROM users WHERE user_id = ?", (user_id,)).fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def issue_token(self, *, user_id: str) -> str:
|
||||
# Returns plaintext token once; stores hash only.
|
||||
now = _utc_now_iso()
|
||||
token = f"mvp_u_{user_id}_{secrets.token_urlsafe(18)}"
|
||||
token_hash = self._hash_token(token)
|
||||
with self.tx() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO api_tokens (token_hash, user_id, created_at) VALUES (?, ?, ?)",
|
||||
(token_hash, user_id, now),
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (NULL, ?, 'TOKEN_ISSUED', ?)",
|
||||
(now, user_id),
|
||||
)
|
||||
return token
|
||||
|
||||
def resolve_token(self, token: str) -> dict[str, Any] | None:
|
||||
token_hash = self._hash_token(token)
|
||||
with self.tx() as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT t.user_id, u.state
|
||||
FROM api_tokens t
|
||||
JOIN users u ON u.user_id = t.user_id
|
||||
WHERE t.token_hash = ?
|
||||
""",
|
||||
(token_hash,),
|
||||
).fetchone()
|
||||
if not row:
|
||||
return None
|
||||
now = _utc_now_iso()
|
||||
conn.execute("UPDATE api_tokens SET last_used_at = ? WHERE token_hash = ?", (now, token_hash))
|
||||
return {"user_id": row["user_id"], "state": row["state"]}
|
||||
|
||||
def get_task(self, task_id: str) -> dict[str, Any] | None:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
|
||||
@ -116,26 +240,33 @@ class Db:
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
def list_queue(self) -> dict[str, list[dict[str, Any]]]:
|
||||
def list_queue(self, *, user_id: str | None = None) -> dict[str, list[dict[str, Any]]]:
|
||||
with self._connect() as conn:
|
||||
pending = conn.execute(
|
||||
"""
|
||||
SELECT task_id, workload, state, nnodes, n_gpus_per_node, next_run_at, created_at, updated_at
|
||||
FROM tasks
|
||||
WHERE state IN ('QUEUED','PENDING_RESOURCES')
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 200
|
||||
"""
|
||||
).fetchall()
|
||||
running = conn.execute(
|
||||
"""
|
||||
SELECT task_id, workload, state, nnodes, n_gpus_per_node, latest_attempt_no, created_at, updated_at
|
||||
FROM tasks
|
||||
WHERE state IN ('SUBMITTING','SUBMITTED','RUNNING')
|
||||
ORDER BY updated_at ASC
|
||||
LIMIT 200
|
||||
"""
|
||||
).fetchall()
|
||||
params: list[Any] = []
|
||||
user_filter_sql = ""
|
||||
if user_id is not None:
|
||||
user_filter_sql = " AND user_id = ?"
|
||||
params = [user_id]
|
||||
|
||||
pending_sql = (
|
||||
"SELECT task_id, workload, state, nnodes, n_gpus_per_node, next_run_at, created_at, updated_at "
|
||||
"FROM tasks "
|
||||
"WHERE state IN ('QUEUED','PENDING_RESOURCES')"
|
||||
f"{user_filter_sql} "
|
||||
"ORDER BY created_at ASC "
|
||||
"LIMIT 200"
|
||||
)
|
||||
running_sql = (
|
||||
"SELECT task_id, workload, state, nnodes, n_gpus_per_node, latest_attempt_no, created_at, updated_at "
|
||||
"FROM tasks "
|
||||
"WHERE state IN ('SUBMITTING','SUBMITTED','RUNNING')"
|
||||
f"{user_filter_sql} "
|
||||
"ORDER BY updated_at ASC "
|
||||
"LIMIT 200"
|
||||
)
|
||||
|
||||
pending = conn.execute(pending_sql, tuple(params)).fetchall()
|
||||
running = conn.execute(running_sql, tuple(params)).fetchall()
|
||||
return {"pending": [dict(r) for r in pending], "running": [dict(r) for r in running]}
|
||||
|
||||
def count_running(self) -> int:
|
||||
|
||||
@ -37,6 +37,12 @@ class Scheduler:
|
||||
def __post_init__(self) -> None:
|
||||
self.tool = RayJobTool(self.ray_cfg)
|
||||
|
||||
def _job_dir_for_task(self, *, user_id: str | None, ray_submission_id: str) -> str:
|
||||
root = self.ray_cfg.shared_root.rstrip("/")
|
||||
if user_id:
|
||||
return f"{root}/users/{user_id}/jobs/{ray_submission_id}"
|
||||
return f"{root}/jobs/{ray_submission_id}"
|
||||
|
||||
def _resources_sufficient(self, *, nnodes: int, n_gpus_per_node: int) -> bool:
|
||||
avail = get_cluster_available()
|
||||
required = float(nnodes * n_gpus_per_node)
|
||||
@ -51,10 +57,13 @@ class Scheduler:
|
||||
def _submit_one(self, task_row: dict[str, Any]) -> None:
|
||||
task_id = str(task_row["task_id"])
|
||||
jobspec_yaml = str(task_row["jobspec_yaml"])
|
||||
user_id = task_row.get("user_id")
|
||||
user_id_s = str(user_id) if user_id not in (None, "") else None
|
||||
|
||||
spec = self._parse_jobspec(jobspec_yaml)
|
||||
attempt_no = int(task_row.get("latest_attempt_no", 0)) + 1
|
||||
ray_sid = attempt_submission_id(task_id, attempt_no)
|
||||
job_dir = self._job_dir_for_task(user_id=user_id_s, ray_submission_id=ray_sid)
|
||||
|
||||
# Record attempt first so that we can surface it even if submit crashes.
|
||||
self.db.create_attempt(task_id=task_id, attempt_no=attempt_no, ray_submission_id=ray_sid)
|
||||
@ -66,7 +75,7 @@ class Scheduler:
|
||||
spec2 = JobSpec.from_dict(d)
|
||||
|
||||
try:
|
||||
submitted = self.tool.submit(spec2, no_wait=True)
|
||||
submitted = self.tool.submit(spec2, no_wait=True, job_dir=job_dir)
|
||||
# submitted should equal ray_sid; keep as source of truth.
|
||||
self.db.update_attempt(task_id=task_id, attempt_no=attempt_no, ray_status="SUBMITTED")
|
||||
self.db.set_task_state(task_id=task_id, state="SUBMITTED")
|
||||
|
||||
@ -12,7 +12,7 @@ def _write_config(tmp_path: Path) -> Path:
|
||||
cfg = {
|
||||
"ray": {
|
||||
"address": "http://127.0.0.1:8265",
|
||||
"shared_root": str(tmp_path),
|
||||
"shared_root": "/private",
|
||||
"entrypoint_resources": {"worker_node": 1},
|
||||
"runtime_env": {"env_vars": {}},
|
||||
},
|
||||
@ -54,7 +54,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
|
||||
|
||||
cfg_path = _write_config(tmp_path)
|
||||
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
|
||||
monkeypatch.setattr(app_mod, "new_task_id", lambda workload: "tid1")
|
||||
monkeypatch.setattr(app_mod, "new_task_id", lambda workload, **kwargs: "tid1")
|
||||
|
||||
class _Tool:
|
||||
def __init__(self):
|
||||
@ -82,7 +82,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
|
||||
r = c.post(
|
||||
"/api/v2/tasks",
|
||||
headers=headers,
|
||||
data="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n",
|
||||
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
|
||||
)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["task_id"] == "tid1"
|
||||
@ -111,7 +111,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
|
||||
db.create_task(
|
||||
task_id="tid2",
|
||||
workload="ppo",
|
||||
jobspec_yaml="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n",
|
||||
jobspec_yaml="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
@ -163,4 +163,3 @@ def test_submit_rejects_invalid_jobspec(tmp_path: Path, monkeypatch):
|
||||
with TestClient(app) as c:
|
||||
r = c.post("/api/v2/tasks", headers={"authorization": "Bearer token1"}, data="workload: nope\n")
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
src/mvp/py/tests/test_discovery.py  (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from argus.ray.discovery import build_head_record, load_head_record
|
||||
|
||||
|
||||
def test_load_head_record_rejects_missing(tmp_path: Path):
|
||||
assert load_head_record(str(tmp_path / "nope.json")) is None
|
||||
|
||||
|
||||
def test_load_head_record_rejects_expired(tmp_path: Path):
|
||||
p = tmp_path / "head.json"
|
||||
now = datetime.now(timezone.utc).replace(microsecond=0)
|
||||
rec = build_head_record(cluster_name="c", head_ip="1.2.3.4", ttl_s=1, now=now)
|
||||
p.write_text(json.dumps(rec.to_dict()), encoding="utf-8")
|
||||
assert load_head_record(str(p), now=now + timedelta(seconds=2)) is None
|
||||
|
||||
|
||||
def test_load_head_record_accepts_fresh(tmp_path: Path):
|
||||
p = tmp_path / "head.json"
|
||||
now = datetime.now(timezone.utc).replace(microsecond=0)
|
||||
rec = build_head_record(cluster_name="c", head_ip="1.2.3.4", ttl_s=60, now=now)
|
||||
p.write_text(json.dumps(rec.to_dict()), encoding="utf-8")
|
||||
got = load_head_record(str(p), now=now + timedelta(seconds=1))
|
||||
assert got is not None
|
||||
assert got.head_addr() == "1.2.3.4:6379"
|
||||
|
||||
|
||||
def test_head_publisher_main_once_writes_file(tmp_path: Path):
|
||||
from argus.ray import head_publisher as mod
|
||||
|
||||
p = tmp_path / "head.json"
|
||||
rc = mod.main(
|
||||
[
|
||||
"--cluster-name",
|
||||
"c",
|
||||
"--head-ip-file",
|
||||
str(p),
|
||||
"--head-ip",
|
||||
"9.9.9.9",
|
||||
"--ttl-s",
|
||||
"60",
|
||||
"--once",
|
||||
]
|
||||
)
|
||||
assert rc == 0
|
||||
assert "9.9.9.9" in p.read_text(encoding="utf-8")
|
||||
@ -20,9 +20,27 @@ def test_new_task_id_is_deterministic_with_patches(monkeypatch):
|
||||
assert ids.new_task_id("ppo") == "mvp2-ppo-20250101-010203-abcd"
|
||||
|
||||
|
||||
def test_new_task_id_includes_user_when_provided(monkeypatch):
|
||||
import argus.core.ids as ids
|
||||
|
||||
class _FakeDatetime:
|
||||
@staticmethod
|
||||
def now():
|
||||
class _DT:
|
||||
def strftime(self, fmt: str) -> str:
|
||||
assert fmt == "%Y%m%d-%H%M%S"
|
||||
return "20250101-010203"
|
||||
|
||||
return _DT()
|
||||
|
||||
monkeypatch.setattr(ids, "datetime", _FakeDatetime)
|
||||
monkeypatch.setattr(ids.secrets, "token_hex", lambda n: "abcd")
|
||||
|
||||
assert ids.new_task_id("ppo", user_id="Alice_01") == "mvp2-alice_01-ppo-20250101-010203-abcd"
|
||||
|
||||
|
||||
def test_attempt_submission_id_format():
|
||||
from argus.core.ids import attempt_submission_id
|
||||
|
||||
assert attempt_submission_id("t", 1) == "t--a01"
|
||||
assert attempt_submission_id("t", 12) == "t--a12"
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ def _mk_cfg(tmp_path: Path) -> tuple[RayConfig, V2Config]:
|
||||
root = {
|
||||
"ray": {
|
||||
"address": "http://127.0.0.1:8265",
|
||||
"shared_root": str(tmp_path),
|
||||
"shared_root": "/private",
|
||||
"entrypoint_resources": {"worker_node": 1},
|
||||
"runtime_env": {"env_vars": {}},
|
||||
},
|
||||
@ -32,10 +32,11 @@ def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
|
||||
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
|
||||
db = Db(v2_cfg.sqlite.db_path)
|
||||
db.init()
|
||||
db.create_task(
|
||||
db.create_task_v25(
|
||||
task_id="t1",
|
||||
user_id="alice",
|
||||
workload="ppo",
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
@ -50,9 +51,11 @@ def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
|
||||
class _Tool:
|
||||
def __init__(self, cfg):
|
||||
self.submitted = []
|
||||
self.job_dirs = []
|
||||
|
||||
def submit(self, spec, no_wait: bool):
|
||||
def submit(self, spec, no_wait: bool, job_dir: str | None = None):
|
||||
self.submitted.append(spec.submission_id)
|
||||
self.job_dirs.append(job_dir)
|
||||
return str(spec.submission_id)
|
||||
|
||||
def status(self, submission_id: str):
|
||||
@ -71,6 +74,7 @@ def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
|
||||
attempts = db.list_attempts("t1")
|
||||
assert len(attempts) == 1
|
||||
assert attempts[0]["ray_submission_id"] == "t1--a01"
|
||||
assert s.tool.job_dirs[-1] == "/private/users/alice/jobs/t1--a01"
|
||||
|
||||
|
||||
def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path):
|
||||
@ -79,10 +83,11 @@ def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path):
|
||||
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
|
||||
db = Db(v2_cfg.sqlite.db_path)
|
||||
db.init()
|
||||
db.create_task(
|
||||
db.create_task_v25(
|
||||
task_id="t1",
|
||||
user_id="alice",
|
||||
workload="ppo",
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
@ -108,10 +113,11 @@ def test_sync_failed_insufficient_resources(monkeypatch, tmp_path: Path):
|
||||
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
|
||||
db = Db(v2_cfg.sqlite.db_path)
|
||||
db.init()
|
||||
db.create_task(
|
||||
db.create_task_v25(
|
||||
task_id="t1",
|
||||
user_id="alice",
|
||||
workload="ppo",
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
@ -147,10 +153,11 @@ def test_sync_status_error_keeps_state(monkeypatch, tmp_path: Path):
|
||||
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
|
||||
db = Db(v2_cfg.sqlite.db_path)
|
||||
db.init()
|
||||
db.create_task(
|
||||
db.create_task_v25(
|
||||
task_id="t1",
|
||||
user_id="alice",
|
||||
workload="ppo",
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}),
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
|
||||
src/mvp/py/tests/test_users.py  (new file, 179 lines)
@@ -0,0 +1,179 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def _write_config(tmp_path: Path) -> Path:
|
||||
cfg = {
|
||||
"ray": {
|
||||
"address": "http://127.0.0.1:8265",
|
||||
"shared_root": "/private",
|
||||
"entrypoint_resources": {"worker_node": 1},
|
||||
"runtime_env": {"env_vars": {}},
|
||||
},
|
||||
"service": {
|
||||
"api": {"host": "127.0.0.1", "port": 0},
|
||||
"auth": {"token_env": "MVP_INTERNAL_TOKEN"},
|
||||
"sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
|
||||
"scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
|
||||
},
|
||||
}
|
||||
p = tmp_path / "cfg.yaml"
|
||||
p.write_text(yaml.safe_dump(cfg), encoding="utf-8")
|
||||
return p
|
||||
|
||||
|
||||
def test_db_user_token_lifecycle(tmp_path: Path):
|
||||
from argus.service.db import Db
|
||||
|
||||
db = Db(str(tmp_path / "mvp.sqlite3"))
|
||||
db.init()
|
||||
db.create_user(user_id="alice", display_name="Alice")
|
||||
tok = db.issue_token(user_id="alice")
|
||||
info = db.resolve_token(tok)
|
||||
assert info and info["user_id"] == "alice"
|
||||
|
||||
db.disable_user(user_id="alice")
|
||||
info2 = db.resolve_token(tok)
|
||||
assert info2 and info2["state"] == "DISABLED"
|
||||
|
||||
|
||||
def test_admin_create_user_issue_token_and_disabled_rejected(tmp_path: Path, monkeypatch):
|
||||
from argus.service import app as app_mod
|
||||
|
||||
cfg_path = _write_config(tmp_path)
|
||||
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
|
||||
|
||||
class _Scheduler:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool = object()
|
||||
|
||||
def run_forever(self, stop_flag):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
|
||||
app = app_mod.create_app(str(cfg_path))
|
||||
|
||||
admin_headers = {"authorization": "Bearer adm1"}
|
||||
with TestClient(app) as c:
|
||||
r1 = c.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice", "display_name": "Alice"})
|
||||
assert r1.status_code == 200
|
||||
assert r1.json()["user_id"] == "alice"
|
||||
|
||||
r2 = c.post("/api/v2/users/alice/tokens", headers=admin_headers)
|
||||
assert r2.status_code == 200
|
||||
token = r2.json()["token"]
|
||||
assert token
|
||||
|
||||
# non-admin token can access regular endpoints
|
||||
r3 = c.get("/api/v2/queue", headers={"authorization": f"Bearer {token}"})
|
||||
assert r3.status_code == 200
|
||||
|
||||
r4 = c.post("/api/v2/users/alice:disable", headers=admin_headers)
|
||||
assert r4.status_code == 200
|
||||
|
||||
r5 = c.get("/api/v2/queue", headers={"authorization": f"Bearer {token}"})
|
||||
assert r5.status_code == 403
|
||||
|
||||
|
||||
def test_tasks_are_isolated_by_user(tmp_path: Path, monkeypatch):
|
||||
from argus.service import app as app_mod
|
||||
|
||||
cfg_path = _write_config(tmp_path)
|
||||
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
|
||||
|
||||
# Deterministic task ids
|
||||
counter = {"n": 0}
|
||||
|
||||
def _new_id(workload: str, **kwargs) -> str:
|
||||
counter["n"] += 1
|
||||
return f"{workload}_t{counter['n']}"
|
||||
|
||||
monkeypatch.setattr(app_mod, "new_task_id", _new_id)
|
||||
|
||||
class _Scheduler:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool = object()
|
||||
|
||||
def run_forever(self, stop_flag):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
|
||||
app = app_mod.create_app(str(cfg_path))
|
||||
|
||||
admin_headers = {"authorization": "Bearer adm1"}
|
||||
with TestClient(app) as c:
|
||||
# Create users and tokens
|
||||
assert c.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice"}).status_code == 200
|
||||
assert c.post("/api/v2/users", headers=admin_headers, json={"user_id": "bob"}).status_code == 200
|
||||
alice_tok = c.post("/api/v2/users/alice/tokens", headers=admin_headers).json()["token"]
|
||||
bob_tok = c.post("/api/v2/users/bob/tokens", headers=admin_headers).json()["token"]
|
||||
|
||||
alice_headers = {"authorization": f"Bearer {alice_tok}"}
|
||||
bob_headers = {"authorization": f"Bearer {bob_tok}"}
|
||||
|
||||
# Each user submits one task.
|
||||
r1 = c.post(
|
||||
"/api/v2/tasks",
|
||||
headers=alice_headers,
|
||||
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
|
||||
)
|
||||
assert r1.status_code == 200
|
||||
alice_tid = r1.json()["task_id"]
|
||||
|
||||
r2 = c.post(
|
||||
"/api/v2/tasks",
|
||||
headers=bob_headers,
|
||||
data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
|
||||
)
|
||||
assert r2.status_code == 200
|
||||
bob_tid = r2.json()["task_id"]
|
||||
|
||||
# Queue is scoped to user.
|
||||
qa = c.get("/api/v2/queue", headers=alice_headers).json()
|
||||
qb = c.get("/api/v2/queue", headers=bob_headers).json()
|
||||
assert {t["task_id"] for t in qa["pending"]} == {alice_tid}
|
||||
assert {t["task_id"] for t in qb["pending"]} == {bob_tid}
|
||||
|
||||
# Cross-user access returns 404 (no existence leak).
|
||||
assert c.get(f"/api/v2/tasks/{bob_tid}", headers=alice_headers).status_code == 404
|
||||
assert c.post(f"/api/v2/tasks/{bob_tid}:cancel", headers=alice_headers).status_code == 404
|
||||
assert c.get(f"/api/v2/tasks/{bob_tid}/attempts", headers=alice_headers).status_code == 404
|
||||
|
||||
# Admin can see global queue.
|
||||
qadm = c.get("/api/v2/queue", headers=admin_headers).json()
|
||||
assert {t["task_id"] for t in qadm["pending"]} == {alice_tid, bob_tid}
|
||||
|
||||
|
||||
def test_submit_rejects_non_common_inputs(tmp_path: Path, monkeypatch):
|
||||
from argus.service import app as app_mod
|
||||
|
||||
cfg_path = _write_config(tmp_path)
|
||||
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
|
||||
|
||||
class _Scheduler:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool = object()
|
||||
|
||||
def run_forever(self, stop_flag):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
|
||||
app = app_mod.create_app(str(cfg_path))
|
||||
|
||||
admin_headers = {"authorization": "Bearer adm1"}
|
||||
with TestClient(app) as c:
|
||||
assert c.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice"}).status_code == 200
|
||||
alice_tok = c.post("/api/v2/users/alice/tokens", headers=admin_headers).json()["token"]
|
||||
alice_headers = {"authorization": f"Bearer {alice_tok}"}
|
||||
|
||||
r = c.post(
|
||||
"/api/v2/tasks",
|
||||
headers=alice_headers,
|
||||
data="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: /private/common/datasets/t\n",
|
||||
)
|
||||
assert r.status_code == 400
|
||||
assert "code_path must start with /private/common/" in r.text
|
||||
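The last test pins the error message for rejected inputs. Below is a minimal sketch of the kind of check the API could apply; the helper name `validate_common_inputs` and the exact field list are assumptions for illustration, not the actual `argus.service` implementation.

    # Hypothetical sketch: reject TaskSpec inputs outside /private/common/.
    COMMON_PREFIX = "/private/common/"

    def validate_common_inputs(spec: dict) -> None:
        """Raise ValueError if any input path escapes the shared common tree."""
        for field in ("code_path", "train_file", "val_file"):
            value = spec.get(field)
            if value is None:
                continue  # e.g. SFT may set val_file: null
            if not str(value).startswith(COMMON_PREFIX):
                raise ValueError(f"{field} must start with {COMMON_PREFIX}")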
58
src/mvp/py/tests/test_worker_watchdog.py
Normal file
@ -0,0 +1,58 @@
from __future__ import annotations

import json
from pathlib import Path

from argus.ray.discovery import build_head_record, write_head_record_atomic
from argus.ray.worker_watchdog import Watchdog


def test_watchdog_restarts_on_first_seen_and_on_change(tmp_path: Path):
    head_file = tmp_path / "head.json"
    calls: list[list[str]] = []

    def runner(argv: list[str]) -> int:
        calls.append(argv)
        return 0

    wd = Watchdog(
        head_ip_file=str(head_file),
        node_ip="10.0.0.2",
        resources_json=json.dumps({"worker_node": 100}),
        poll_s=1,
        runner=runner,
    )

    write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="1.1.1.1"))
    assert wd.tick_once() is not None
    assert any("ray" in c[0] and c[1] == "start" for c in calls)

    calls.clear()
    # Same address -> no restart
    write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="1.1.1.1"))
    wd.tick_once()
    assert calls == []

    # Address change -> restart
    write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="2.2.2.2"))
    wd.tick_once()
    assert any(c[1] == "stop" for c in calls)
    assert any(c[1] == "start" for c in calls)


def test_watchdog_main_once_invokes_runner(monkeypatch, tmp_path: Path):
    from argus.ray import worker_watchdog as mod

    head_file = tmp_path / "head.json"
    write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="1.1.1.1"))

    calls: list[list[str]] = []

    def runner(argv: list[str]) -> int:
        calls.append(argv)
        return 0

    monkeypatch.setattr(mod, "_default_runner", runner)
    rc = mod.main(["--head-ip-file", str(head_file), "--node-ip", "10.0.0.2", "--resources-kv", "worker_node=100", "--once"])
    assert rc == 0
    assert any(c[1] == "start" for c in calls)
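These tests drive the watchdog purely through an injected `runner`. A minimal sketch of the decision loop they imply is shown below; the class name, the `_last_addr` attribute, and the GCS port are assumptions made for illustration, not the real `argus.ray.worker_watchdog` internals.

    # Hypothetical sketch of the tick logic implied by the tests above.
    import json
    from pathlib import Path


    class WatchdogSketch:
        def __init__(self, head_ip_file: str, node_ip: str, resources_json: str, runner):
            self.head_ip_file = head_ip_file
            self.node_ip = node_ip
            self.resources_json = resources_json
            self.runner = runner
            self._last_addr: str | None = None

        def tick_once(self):
            record = json.loads(Path(self.head_ip_file).read_text(encoding="utf-8"))
            addr = f"{record['head_ip']}:6379"  # assumed default GCS port
            if addr == self._last_addr:
                return None  # head unchanged -> keep the current ray worker process
            if self._last_addr is not None:
                self.runner(["ray", "stop", "--force"])  # head moved -> drop the stale connection
            self.runner([
                "ray", "start",
                "--address", addr,
                "--node-ip-address", self.node_ip,
                "--resources", self.resources_json,
            ])
            self._last_addr = addr
            return addr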
24
src/mvp/scripts/04_cleanup_v2_legacy.sh
Executable file
@ -0,0 +1,24 @@
#!/usr/bin/env bash
set -euo pipefail

echo "[host] cleanup v2 legacy containers/processes (best-effort)"

# Known historical container names from v1.1/v2 era
LEGACY=(
  mvp11-ray-head
  mvp11-ray-worker-0
  mvp11-ray-worker-1
  mvp2-ray-head
  mvp2-ray-worker-0
  mvp2-ray-worker-1
)

for c in "${LEGACY[@]}"; do
  if docker ps -a --format '{{.Names}}' | grep -qx "${c}"; then
    echo "[host] removing legacy container: ${c}"
    docker rm -f "${c}" >/dev/null 2>&1 || true
  fi
done

echo "[host] legacy v2 cleanup done"
44
src/mvp/scripts/22_start_head_discovery.sh
Executable file
@ -0,0 +1,44 @@
#!/usr/bin/env bash
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"

CLUSTER_NAME="${CLUSTER_NAME:-argus-ray}"
HEAD_IP_FILE="${HEAD_IP_FILE:-${SHARED_ROOT}/ray/discovery/${CLUSTER_NAME}/head.json}"
TTL_S="${TTL_S:-60}"
REFRESH_S="${REFRESH_S:-10}"

# PID file must be container-local to avoid conflicts if /private is shared across containers.
PID_PATH="${PID_PATH:-/tmp/argus_head_publisher.pid}"
LOG_PATH="${LOG_PATH:-${SHARED_ROOT}/common/logs/argus_head_publisher.log}"

HEAD_IP="$(container_ip "${HEAD_CONTAINER}")"

echo "[head] start discovery publisher (supervised): ${HEAD_IP_FILE} (head_ip=${HEAD_IP})"

# stop existing (best-effort)
dexec "${HEAD_CONTAINER}" bash -lc "if test -f '${PID_PATH}'; then pid=\$(cat '${PID_PATH}'); if kill -0 \"\${pid}\" >/dev/null 2>&1; then kill \"\${pid}\" || true; fi; rm -f '${PID_PATH}'; fi"

dexec "${HEAD_CONTAINER}" bash -lc "mkdir -p \"$(dirname "${LOG_PATH}")\" \"$(dirname "${HEAD_IP_FILE}")\""

# Supervisor loop: restart publisher if it exits.
docker exec -d "${HEAD_CONTAINER}" bash -lc "
nohup bash -lc '
export PYTHONPATH=/workspace/mvp/py
while true; do
  python3 -m argus.ray.head_publisher \
    --cluster-name \"${CLUSTER_NAME}\" \
    --head-ip-file \"${HEAD_IP_FILE}\" \
    --head-ip \"${HEAD_IP}\" \
    --ttl-s \"${TTL_S}\" \
    --refresh-s \"${REFRESH_S}\"
  sleep 2
done
' >>'${LOG_PATH}' 2>&1 & echo \$! >'${PID_PATH}'
"

echo "[head] publisher pid stored in ${PID_PATH} (container-local)"
echo "[head] logs: ${LOG_PATH}"
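For reference, a minimal sketch of the record the publisher could write and how the write can be made atomic on the shared mount; the field names beyond `cluster_name`/`head_ip` (e.g. `ttl_s`, `updated_at`) and the helper names are assumptions, not the actual contents of `argus.ray.discovery`.

    # Hypothetical sketch of a head.json record and an atomic write, assuming the
    # discovery file only needs cluster identity, head address, and freshness info.
    import json
    import os
    import tempfile
    import time


    def build_head_record_sketch(cluster_name: str, head_ip: str, ttl_s: int = 60) -> dict:
        return {
            "cluster_name": cluster_name,
            "head_ip": head_ip,
            "ttl_s": ttl_s,
            "updated_at": time.time(),  # workers can treat records older than ttl_s as stale
        }


    def write_record_atomic_sketch(path: str, record: dict) -> None:
        # Write to a temp file in the same directory, then rename: readers on the
        # shared /private mount never observe a partially written head.json.
        d = os.path.dirname(path) or "."
        fd, tmp = tempfile.mkstemp(dir=d, prefix=".head.json.")
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(record, f)
            os.replace(tmp, path)
        except BaseException:
            os.unlink(tmp)
            raise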
52
src/mvp/scripts/23_start_workers_stateless.sh
Executable file
@ -0,0 +1,52 @@
#!/usr/bin/env bash
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"

CLUSTER_NAME="${CLUSTER_NAME:-argus-ray}"
HEAD_IP_FILE="${HEAD_IP_FILE:-${SHARED_ROOT}/ray/discovery/${CLUSTER_NAME}/head.json}"
POLL_S="${POLL_S:-5}"
WORKER_NODE_RESOURCE="${WORKER_NODE_RESOURCE:-100}"

start_one() {
  local worker="$1"
  local ip
  ip="$(container_ip "${worker}")"

  local pid_path="/tmp/argus_worker_watchdog.pid"
  local log_path="${SHARED_ROOT}/common/logs/argus_worker_watchdog.${worker}.log"

  echo "[${worker}] start stateless watchdog (supervised): head_file=${HEAD_IP_FILE} node_ip=${ip}"

  # stop existing watchdog (best-effort)
  dexec "${worker}" bash -lc "if test -f '${pid_path}'; then pid=\$(cat '${pid_path}'); if kill -0 \"\${pid}\" >/dev/null 2>&1; then kill \"\${pid}\" || true; fi; rm -f '${pid_path}'; fi"

  # stop any legacy ray process to avoid split-brain
  dexec "${worker}" bash -lc "ray stop --force || true"
  dexec "${worker}" bash -lc "mkdir -p \"$(dirname "${log_path}")\""

  docker exec -d "${worker}" bash -lc "
  nohup bash -lc '
  export PYTHONPATH=/workspace/mvp/py
  while true; do
    python3 -m argus.ray.worker_watchdog \
      --head-ip-file \"${HEAD_IP_FILE}\" \
      --node-ip \"${ip}\" \
      --resources-kv \"worker_node=${WORKER_NODE_RESOURCE}\" \
      --poll-s \"${POLL_S}\"
    sleep 2
  done
  ' >>'${log_path}' 2>&1 & echo \$! >'${pid_path}'
  "

  echo "[${worker}] watchdog pid stored in ${pid_path} (container-local)"
  echo "[${worker}] logs: ${log_path}"
}

start_one "${WORKER0_CONTAINER}"
start_one "${WORKER1_CONTAINER}"

echo "[head] ray status"
dexec "${HEAD_CONTAINER}" bash -lc "ray status || true"
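The watchdog receives `--resources-kv worker_node=100` but has to hand Ray a JSON resources dict. A minimal parsing sketch is shown below; the helper name and its behaviour on malformed input are assumptions rather than the real CLI code.

    # Hypothetical sketch: turn "worker_node=100" (optionally comma-separated) into
    # the JSON string that `ray start --resources` expects, e.g. '{"worker_node": 100.0}'.
    import json


    def resources_kv_to_json(kv: str) -> str:
        resources: dict[str, float] = {}
        for pair in kv.split(","):
            pair = pair.strip()
            if not pair:
                continue
            key, _, value = pair.partition("=")
            if not key or not value:
                raise ValueError(f"invalid resources kv entry: {pair!r}")
            resources[key] = float(value)
        return json.dumps(resources)


    assert resources_kv_to_json("worker_node=100") == '{"worker_node": 100.0}'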
17
src/mvp/scripts/24_status_stateless.sh
Executable file
@ -0,0 +1,17 @@
#!/usr/bin/env bash
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"

CLUSTER_NAME="${CLUSTER_NAME:-argus-ray}"
HEAD_IP_FILE="${HEAD_IP_FILE:-${SHARED_ROOT}/ray/discovery/${CLUSTER_NAME}/head.json}"

echo "[head] head.json (best-effort): ${HEAD_IP_FILE}"
dexec "${HEAD_CONTAINER}" bash -lc "if test -f '${HEAD_IP_FILE}'; then cat '${HEAD_IP_FILE}'; else echo 'missing'; fi"

echo
echo "[head] ray status"
dexec "${HEAD_CONTAINER}" bash -lc "ray status || true"
@ -11,10 +11,30 @@ PPO_DATA_DIR="${SHARED_ROOT}/datasets/gsm8k"
SFT_DATA_DIR="${SHARED_ROOT}/datasets/gsm8k_sft"

CODE_SNAPSHOT_DIR="${SHARED_ROOT}/common/code/verl/verl_repo"
COMMON_DIR="${SHARED_ROOT}/common"
COMMON_DATASETS_DIR="${SHARED_ROOT}/common/datasets"
COMMON_HF_DIR="${SHARED_ROOT}/common/hf"

echo "[head] ensure dataset dirs exist"
dexec "${HEAD_CONTAINER}" bash -lc "mkdir -p '${PPO_DATA_DIR}' '${SFT_DATA_DIR}'"

echo "[head] ensure v2.5 common links (idempotent)"
# In existing deployments, /private/common/{datasets,hf} may already exist as directories (not symlinks).
# For v2.5 taskspecs, we only require:
# /private/common/datasets/gsm8k -> /private/datasets/gsm8k
# /private/common/datasets/gsm8k_sft -> /private/datasets/gsm8k_sft
# /private/common/hf/{hub,transformers} -> /private/hf/{hub,transformers} (best-effort)
dexec "${HEAD_CONTAINER}" bash -lc "
  set -euo pipefail
  mkdir -p '${COMMON_DIR}' '${COMMON_DATASETS_DIR}' '${COMMON_HF_DIR}'
  ln -sfn '${SHARED_ROOT}/datasets/gsm8k' '${COMMON_DATASETS_DIR}/gsm8k'
  ln -sfn '${SHARED_ROOT}/datasets/gsm8k_sft' '${COMMON_DATASETS_DIR}/gsm8k_sft'
  mkdir -p '${SHARED_ROOT}/hf/hub' '${SHARED_ROOT}/hf/transformers'
  ln -sfn '${SHARED_ROOT}/hf/hub' '${COMMON_HF_DIR}/hub'
  ln -sfn '${SHARED_ROOT}/hf/transformers' '${COMMON_HF_DIR}/transformers'
  echo 'common_links_ok'
"

echo "[head] prepare PPO dataset (gsm8k RL parquet) -> ${PPO_DATA_DIR}"
dexec "${HEAD_CONTAINER}" bash -lc "if [[ -f '${PPO_DATA_DIR}/train.parquet' && -f '${PPO_DATA_DIR}/test.parquet' ]]; then echo 'ppo_dataset_exists: skip'; else python3 /workspace/verl/examples/data_preprocess/gsm8k.py --local_save_dir '${PPO_DATA_DIR}'; fi"
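As a quick sanity check of the common-links layout created above, something like the following could be run inside the head container; the path list simply restates the comments in the hunk and is not an existing helper in the repo.

    # Hypothetical sanity check for the /private/common layout created above.
    from pathlib import Path

    EXPECTED_LINKS = {
        "/private/common/datasets/gsm8k": "/private/datasets/gsm8k",
        "/private/common/datasets/gsm8k_sft": "/private/datasets/gsm8k_sft",
        "/private/common/hf/hub": "/private/hf/hub",
        "/private/common/hf/transformers": "/private/hf/transformers",
    }

    for link, target in EXPECTED_LINKS.items():
        p = Path(link)
        ok = p.is_symlink() and str(p.resolve()) == str(Path(target).resolve())
        print(f"{link} -> {target}: {'ok' if ok else 'MISSING/WRONG'}")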
137
src/mvp/scripts/run_all_v25_api.sh
Executable file
@ -0,0 +1,137 @@
#!/usr/bin/env bash
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"

# E2E v2.5:
# - Clean legacy env
# - Start containers
# - Start ray head + discovery publisher
# - Start stateless worker watchdogs (auto-connect)
# - Prepare data/model/code (reuses existing downloads)
# - Start API
# - Create user + issue token
# - Submit PPO/GRPO/SFT via API and wait

API_ADDR="${API_ADDR:-http://127.0.0.1:8080}"
ADMIN_TOKEN="${MVP_INTERNAL_TOKEN:-}"
USER_ID="${USER_ID:-alice}"
RESET_DB="${RESET_DB:-1}"

if [[ -z "${ADMIN_TOKEN}" ]]; then
  echo "ERROR: MVP_INTERNAL_TOKEN must be set in host env (admin token)" >&2
  exit 1
fi

api_curl_admin() {
  curl -sS -H "Authorization: Bearer ${ADMIN_TOKEN}" "$@"
}

api_wait_ready() {
  local tries="${1:-60}"
  for i in $(seq 1 "${tries}"); do
    if curl -sS -m 2 "${API_ADDR}/docs" >/dev/null 2>&1; then
      echo "[host] api_ready: ${API_ADDR}"
      return 0
    fi
    echo "[host] waiting api... (${i}/${tries})"
    sleep 2
  done
  echo "ERROR: api not ready: ${API_ADDR}" >&2
  return 1
}

submit_taskspec() {
  local token="$1"
  local taskspec_path="$2"
  echo "[host] submit via API (user=${USER_ID}): ${taskspec_path}" >&2
  local resp
  resp="$(curl -sS -H "Authorization: Bearer ${token}" -H "Content-Type: application/yaml" --data-binary @"${taskspec_path}" "${API_ADDR}/api/v2/tasks")"
  echo "[host] submit_resp: ${resp}" >&2
  printf '%s' "${resp}" | python3 -c 'import sys,json; print(json.load(sys.stdin)["task_id"])'
}

wait_task() {
  local token="$1"
  local task_id="$2"
  while true; do
    local body state
    body="$(curl -sS -H "Authorization: Bearer ${token}" "${API_ADDR}/api/v2/tasks/${task_id}")"
    state="$(printf '%s' "${body}" | python3 -c 'import sys,json; print(json.load(sys.stdin)["state"])')"
    echo "[host] task ${task_id}: ${state}"

    if [[ "${state}" == "SUCCEEDED" ]]; then
      return 0
    fi
    if [[ "${state}" == "FAILED" || "${state}" == "CANCELED" ]]; then
      echo "[host] terminal=${state}; tail logs (best-effort):" >&2
      curl -sS -H "Authorization: Bearer ${token}" "${API_ADDR}/api/v2/tasks/${task_id}/logs?tail=200" >&2 || true
      return 1
    fi
    sleep 10
  done
}

echo "[host] ===== run_all_v25_api.sh begin ====="

"${SCRIPT_DIR}/00_prereq_check.sh"
"${SCRIPT_DIR}/03_cleanup_v1_legacy.sh"
"${SCRIPT_DIR}/04_cleanup_v2_legacy.sh"

echo "[host] bring down existing containers (best-effort)"
"${SCRIPT_DIR}/02_down.sh" || true

echo "[host] (re)create containers"
"${SCRIPT_DIR}/01_up.sh"

echo "[host] restart ray head (no compute on head)"
"${SCRIPT_DIR}/20_start_head.sh"

echo "[host] start head discovery publisher"
"${SCRIPT_DIR}/22_start_head_discovery.sh"

echo "[host] start stateless workers (watchdog auto-connect)"
"${SCRIPT_DIR}/23_start_workers_stateless.sh"

echo "[host] prepare data/model/code snapshot (idempotent; reuse cache)"
"${SCRIPT_DIR}/30_prepare_data_and_model.sh"

echo "[host] install api deps in head container"
"${SCRIPT_DIR}/12_install_api_deps.sh"

echo "[host] stop api (best-effort)"
"${SCRIPT_DIR}/61_stop_api.sh" || true

if [[ "${RESET_DB}" == "1" ]]; then
  echo "[host] reset api sqlite db in container (best-effort)"
  docker exec -i "${HEAD_CONTAINER}" bash -lc "rm -f /private/common/db/mvp.sqlite3 /private/common/db/mvp.sqlite3-wal /private/common/db/mvp.sqlite3-shm || true"
else
  echo "[host] keep existing api sqlite db (RESET_DB=${RESET_DB})"
fi

echo "[host] start api (admin token via env)"
MVP_INTERNAL_TOKEN="${ADMIN_TOKEN}" "${SCRIPT_DIR}/60_start_api.sh"
api_wait_ready 60

echo "[host] ensure user exists + issue token"
api_curl_admin -H "Content-Type: application/json" -d "{\"user_id\":\"${USER_ID}\"}" "${API_ADDR}/api/v2/users" >/dev/null 2>&1 || true
USER_TOKEN="$(api_curl_admin -X POST "${API_ADDR}/api/v2/users/${USER_ID}/tokens" | python3 -c 'import sys,json; print(json.load(sys.stdin)["token"])')"
echo "[host] user_token_issued: user=${USER_ID}"

PPO_TASK_ID="$(submit_taskspec "${USER_TOKEN}" "${ROOT_DIR}/taskspecs/ppo.yaml")"
GRPO_TASK_ID="$(submit_taskspec "${USER_TOKEN}" "${ROOT_DIR}/taskspecs/grpo.yaml")"
SFT_TASK_ID="$(submit_taskspec "${USER_TOKEN}" "${ROOT_DIR}/taskspecs/sft.yaml")"

echo "[host] submitted task ids:"
echo " ppo=${PPO_TASK_ID}"
echo " grpo=${GRPO_TASK_ID}"
echo " sft=${SFT_TASK_ID}"

echo "[host] wait for tasks (in submission order)"
wait_task "${USER_TOKEN}" "${PPO_TASK_ID}"
wait_task "${USER_TOKEN}" "${GRPO_TASK_ID}"
wait_task "${USER_TOKEN}" "${SFT_TASK_ID}"

echo "[host] ===== run_all_v25_api.sh done ====="
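The same flow can be driven from Python. The sketch below mirrors the curl calls in the script above using `requests`; the endpoint paths and response fields come from the script, everything else (module layout, token handling) is assumed and only illustrative.

    # Hypothetical Python client for the flow above (same endpoints and fields
    # as the curl calls in run_all_v25_api.sh).
    import time

    import requests

    API_ADDR = "http://127.0.0.1:8080"
    ADMIN_TOKEN = "adm1"  # placeholder; use the real admin token from the environment


    def _headers(token: str) -> dict:
        return {"Authorization": f"Bearer {token}"}


    def issue_user_token(user_id: str) -> str:
        # Creating an already-existing user is best-effort, mirroring the script.
        requests.post(f"{API_ADDR}/api/v2/users", headers=_headers(ADMIN_TOKEN), json={"user_id": user_id})
        r = requests.post(f"{API_ADDR}/api/v2/users/{user_id}/tokens", headers=_headers(ADMIN_TOKEN))
        r.raise_for_status()
        return r.json()["token"]


    def submit_taskspec(token: str, taskspec_path: str) -> str:
        with open(taskspec_path, "rb") as f:
            r = requests.post(
                f"{API_ADDR}/api/v2/tasks",
                headers={**_headers(token), "Content-Type": "application/yaml"},
                data=f.read(),
            )
        r.raise_for_status()
        return r.json()["task_id"]


    def wait_task(token: str, task_id: str, poll_s: int = 10) -> str:
        while True:
            state = requests.get(f"{API_ADDR}/api/v2/tasks/{task_id}", headers=_headers(token)).json()["state"]
            print(f"task {task_id}: {state}")
            if state in ("SUCCEEDED", "FAILED", "CANCELED"):
                return state
            time.sleep(poll_s)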
156
src/mvp/scripts/run_e2e_v25_cases.sh
Executable file
@ -0,0 +1,156 @@
#!/usr/bin/env bash
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=lib.sh
source "${SCRIPT_DIR}/lib.sh"

API_ADDR="${API_ADDR:-http://127.0.0.1:8080}"
ADMIN_TOKEN="${MVP_INTERNAL_TOKEN:-}"

if [[ -z "${ADMIN_TOKEN}" ]]; then
  echo "ERROR: MVP_INTERNAL_TOKEN must be set (admin token)" >&2
  exit 1
fi

require_cmd jq
require_cmd python3

api_admin() { curl -sS -H "Authorization: Bearer ${ADMIN_TOKEN}" "$@"; }
api_user() { local tok="$1"; shift; curl -sS -H "Authorization: Bearer ${tok}" "$@"; }

wait_state() {
  local tok="$1"
  local task_id="$2"
  local want="$3"
  local tries="${4:-120}"
  for i in $(seq 1 "${tries}"); do
    local st
    st="$(api_user "${tok}" "${API_ADDR}/api/v2/tasks/${task_id}" | jq -r .state)"
    echo "[case] ${task_id} state=${st} (want=${want})"
    if [[ "${st}" == "${want}" ]]; then
      return 0
    fi
    if [[ "${st}" == "FAILED" || "${st}" == "CANCELED" || "${st}" == "SUCCEEDED" ]]; then
      return 1
    fi
    sleep 5
  done
  return 2
}

issue_token() {
  local user_id="$1"
  api_admin -H "Content-Type: application/json" -d "{\"user_id\":\"${user_id}\"}" "${API_ADDR}/api/v2/users" >/dev/null 2>&1 || true
  api_admin -X POST "${API_ADDR}/api/v2/users/${user_id}/tokens" | jq -r .token
}

submit_yaml() {
  local tok="$1"
  local yaml_path="$2"
  api_user "${tok}" -H "Content-Type: application/yaml" --data-binary @"${yaml_path}" "${API_ADDR}/api/v2/tasks" | jq -r .task_id
}

echo "[case] ===== v2.5 e2e cases ====="

echo "[case] HP-1: run_all_v25_api.sh (happy path)"
RESET_DB="${RESET_DB:-1}"
MVP_INTERNAL_TOKEN="${ADMIN_TOKEN}" RESET_DB="${RESET_DB}" "${SCRIPT_DIR}/run_all_v25_api.sh"

echo "[case] E-Auth-1: missing bearer token"
code="$(curl -sS -o /dev/null -w '%{http_code}' "${API_ADDR}/api/v2/queue" || true)"
[[ "${code}" == "401" ]] || { echo "expected 401 got ${code}" >&2; exit 1; }

echo "[case] E-Auth-2: invalid token"
code="$(curl -sS -o /dev/null -w '%{http_code}' -H 'Authorization: Bearer nope' "${API_ADDR}/api/v2/queue" || true)"
[[ "${code}" == "401" ]] || { echo "expected 401 got ${code}" >&2; exit 1; }

echo "[case] setup users alice/bob"
ALICE_TOKEN="$(issue_token alice)"
BOB_TOKEN="$(issue_token bob)"

echo "[case] E-Isolation-1: cross-user task visibility (404)"
TID="$(submit_yaml "${ALICE_TOKEN}" "${ROOT_DIR}/taskspecs/ppo.yaml")"
code="$(api_user "${BOB_TOKEN}" -o /dev/null -w '%{http_code}' "${API_ADDR}/api/v2/tasks/${TID}" || true)"
[[ "${code}" == "404" ]] || { echo "expected 404 got ${code}" >&2; exit 1; }

# Keep the queue clean for subsequent cases (avoid consuming scheduler capacity).
api_user "${ALICE_TOKEN}" -X POST "${API_ADDR}/api/v2/tasks/${TID}:cancel" >/dev/null 2>&1 || true
wait_state "${ALICE_TOKEN}" "${TID}" "CANCELED" 30 || true

echo "[case] E-Input-1: reject non-common inputs"
bad_yaml="$(mktemp)"
cat >"${bad_yaml}" <<'YAML'
workload: "ppo"
submission_id: ""
code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/datasets/gsm8k/train.parquet"
val_file: "/private/common/datasets/gsm8k/test.parquet"
nnodes: 2
n_gpus_per_node: 4
total_epochs: 1
total_training_steps: 10
save_freq: 10
test_freq: -1
YAML
code="$(api_user "${ALICE_TOKEN}" -H 'Content-Type: application/yaml' --data-binary @"${bad_yaml}" -o /dev/null -w '%{http_code}' "${API_ADDR}/api/v2/tasks" || true)"
rm -f "${bad_yaml}"
[[ "${code}" == "400" ]] || { echo "expected 400 got ${code}" >&2; exit 1; }

echo "[case] B-Queue-1: pending_resources when demand > cluster capacity"
big_yaml="$(mktemp)"
cat >"${big_yaml}" <<'YAML'
workload: "ppo"
submission_id: ""
code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/common/datasets/gsm8k/train.parquet"
val_file: "/private/common/datasets/gsm8k/test.parquet"
nnodes: 3
n_gpus_per_node: 4
total_epochs: 1
total_training_steps: 10
save_freq: 10
test_freq: -1
YAML
BIG_TID="$(api_user "${ALICE_TOKEN}" -H 'Content-Type: application/yaml' --data-binary @"${big_yaml}" "${API_ADDR}/api/v2/tasks" | jq -r .task_id)"
rm -f "${big_yaml}"
wait_state "${ALICE_TOKEN}" "${BIG_TID}" "PENDING_RESOURCES" 60

echo "[case] B-Cancel-1: cancel a running task (best-effort)"
long_yaml="$(mktemp)"
cat >"${long_yaml}" <<'YAML'
workload: "ppo"
submission_id: ""
code_path: "/private/common/code/verl/verl_repo"
model_id: "Qwen/Qwen2.5-0.5B-Instruct"
train_file: "/private/common/datasets/gsm8k/train.parquet"
val_file: "/private/common/datasets/gsm8k/test.parquet"
nnodes: 2
n_gpus_per_node: 4
total_epochs: 1
total_training_steps: 200
save_freq: 10
test_freq: -1
YAML
LONG_TID="$(api_user "${ALICE_TOKEN}" -H 'Content-Type: application/yaml' --data-binary @"${long_yaml}" "${API_ADDR}/api/v2/tasks" | jq -r .task_id)"
rm -f "${long_yaml}"

for i in $(seq 1 60); do
  st="$(api_user "${ALICE_TOKEN}" "${API_ADDR}/api/v2/tasks/${LONG_TID}" | jq -r .state)"
  echo "[case] wait running: ${LONG_TID} state=${st}"
  if [[ "${st}" == "RUNNING" ]]; then
    break
  fi
  if [[ "${st}" == "FAILED" || "${st}" == "SUCCEEDED" ]]; then
    break
  fi
  sleep 5
done

api_user "${ALICE_TOKEN}" -X POST "${API_ADDR}/api/v2/tasks/${LONG_TID}:cancel" | jq .
st="$(api_admin "${API_ADDR}/api/v2/tasks/${LONG_TID}" | jq -r .state)"
[[ "${st}" == "CANCELED" ]] || { echo "expected CANCELED got ${st}" >&2; exit 1; }

echo "[case] ===== all v2.5 e2e cases passed ====="
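The B-Queue-1 case relies on the scheduler comparing requested GPUs against cluster capacity; a toy version of that arithmetic is shown below. The capacity numbers are only the two-worker, 4-GPU-per-node lab setup these cases assume, not values read from the real scheduler.

    # Toy illustration of the B-Queue-1 capacity reasoning (numbers assumed):
    # a 3-node x 4-GPU request cannot fit a cluster with 2 workers x 4 GPUs,
    # so that task should park in PENDING_RESOURCES.
    def gpu_demand(nnodes: int, n_gpus_per_node: int) -> int:
        return nnodes * n_gpus_per_node


    CLUSTER_GPUS = 2 * 4  # assumed: 2 worker nodes, 4 GPUs each

    assert gpu_demand(nnodes=2, n_gpus_per_node=4) <= CLUSTER_GPUS   # HP-1 tasks can run
    assert gpu_demand(nnodes=3, n_gpus_per_node=4) > CLUSTER_GPUS    # B-Queue-1 stays pending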
@ -6,8 +6,8 @@ code_path: "/private/common/code/verl/verl_repo"

model_id: "Qwen/Qwen2.5-0.5B-Instruct"

train_file: "/private/datasets/gsm8k/train.parquet"
val_file: "/private/datasets/gsm8k/test.parquet"
train_file: "/private/common/datasets/gsm8k/train.parquet"
val_file: "/private/common/datasets/gsm8k/test.parquet"

nnodes: 2
n_gpus_per_node: 4
@ -17,4 +17,3 @@ total_training_steps: 10

save_freq: 10
test_freq: -1
@ -8,8 +8,8 @@ code_path: "/private/common/code/verl/verl_repo"

model_id: "Qwen/Qwen2.5-0.5B-Instruct"

train_file: "/private/datasets/gsm8k/train.parquet"
val_file: "/private/datasets/gsm8k/test.parquet"
train_file: "/private/common/datasets/gsm8k/train.parquet"
val_file: "/private/common/datasets/gsm8k/test.parquet"

nnodes: 2
n_gpus_per_node: 4
@ -19,4 +19,3 @@ total_training_steps: 10

save_freq: 10
test_freq: -1
@ -6,7 +6,7 @@ code_path: "/private/common/code/verl/verl_repo"

model_id: "Qwen/Qwen2.5-0.5B-Instruct"

train_file: "/private/datasets/gsm8k_sft/train.parquet"
train_file: "/private/common/datasets/gsm8k_sft/train.parquet"
val_file: null

nnodes: 2
@ -19,4 +19,3 @@ save_freq: 10

# The SFT driver gets no GPU by default (the ray job entrypoint does not set entrypoint_num_gpus), so the driver side must not depend on CUDA
trainer_device: "cpu"