diff --git a/specs/mvp/v2.5/README.md b/specs/mvp/v2.5/README.md
index 0c398e2..0703a10 100644
--- a/specs/mvp/v2.5/README.md
+++ b/specs/mvp/v2.5/README.md
@@ -11,4 +11,4 @@ Core changes in v2.5:
 - `specs/mvp/v2.5/v2.5_design.md`: overall architecture and key mechanisms (head IP file / watchdog / user isolation / task flow).
 - `specs/mvp/v2.5/v2.5_api.md`: API design (users, tasks, queue, logs) and auth conventions.
 - `specs/mvp/v2.5/v2.5_acceptance.md`: development/deployment/acceptance flow and verifiable criteria.
-
+- `specs/mvp/v2.5/v2.5_summary.md`: summary of what v2.5 shipped (what this iteration did, acceptance results, known limitations).
diff --git a/specs/mvp/v2.5/notices.md b/specs/mvp/v2.5/notices.md
new file mode 100644
index 0000000..6b4a2cf
--- /dev/null
+++ b/specs/mvp/v2.5/notices.md
@@ -0,0 +1,3 @@
+# Issue log
+1. Include the user name in task and submission IDs.
+2. Fill out the end-to-end test suite: normal and error cases of every kind, plus boundary tests.
\ No newline at end of file
diff --git a/specs/mvp/v2.5/v2.5_dev_plan.md b/specs/mvp/v2.5/v2.5_dev_plan.md
new file mode 100644
index 0000000..ce2bad6
--- /dev/null
+++ b/specs/mvp/v2.5/v2.5_dev_plan.md
@@ -0,0 +1,229 @@
+# MVP v2.5 Development Plan (TDD-driven)
+
+This is the **engineering development plan** for v2.5. It emphasizes "tests first, implementation second" (TDD) and splits every milestone into small loops that can be **accepted independently**.
+
+Inputs:
+- Roadmap: `specs/mvp/mvp_roadmap_v2.md`
+- v2.5 design: `specs/mvp/v2.5/v2.5_design.md`
+- v2.5 API draft: `specs/mvp/v2.5/v2.5_api.md`
+- v2.5 acceptance: `specs/mvp/v2.5/v2.5_acceptance.md`
+
+v2.5 constraints (confirmed):
+- **No TaskSpec extensions**: keep the structured YAML fields and semantics of v2.0/v2.1.
+- **No custom reward functions / no user-supplied verl code.**
+- Training inputs (verl code, HF cache, datasets) all live under `/private/common/...`.
+- Multi-user isolation in v2.5 **covers only the jobs output directory**: `/private/users/<user_id>/jobs/<ray_submission_id>/...`.
+
+---
+
+## 0. TDD conventions (every feature follows these)
+
+### 0.1 Test layers
+
+1) **Unit tests (fast)**
+- Pure Python logic: DB, auth, IDs, path derivation, head.json parsing/TTL, watchdog decision logic.
+- Goal: no dependency on real Ray, docker, or the network.
+
+2) **Component tests (medium)**
+- FastAPI routes: use `fastapi.testclient.TestClient` (already the pattern in v2.0).
+- Goal: verify auth/permission isolation, API behavior, and the state machine.
+
+3) **End-to-end (slow / manual or scripted)**
+- On `argus@h1`, run the "head publish → worker auto-connect → API submit" loop once via scripts/compose.
+- Goal: verify the real behavior of stateless workers + the watchdog.
+
+### 0.2 Test conventions
+
+- Test directory: `src/mvp/py/tests/`
+- New features must add their test cases first and let them fail before the implementation exists (red).
+- Make the tests pass with the smallest change (green).
+- Then refactor and deduplicate (refactor).
+
+> Note: existing tests inject a ray stub via `src/mvp/py/tests/conftest.py` so that unit tests do not depend on the real ray package; new v2.5 modules should reuse this pattern, as sketched below.
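+
+A minimal sketch of that stub-injection idea, assuming a hypothetical hook body (the real `conftest.py` may stub different names and attributes):
+
+```python
+# conftest.py (sketch): register a fake `ray` module before argus code imports it.
+import sys
+import types
+
+
+def pytest_configure(config):
+    if "ray" not in sys.modules:
+        stub = types.ModuleType("ray")
+        stub.is_initialized = lambda: False  # illustrative attribute only
+        sys.modules["ray"] = stub
+```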
+
+---
+
+## 1. Milestone breakdown (v2.5 = 4 independently verifiable loops)
+
+### M1: user/token tables + basic auth (does not break the existing internal token)
+
+**Goal**
+- Introduce persisted users/tokens and the auth mapping (token → user_id).
+- Stay compatible with the existing single-tenant `Authorization: Bearer <token>` usage so v2.0 is not broken in one step:
+  - v2.5 can support two token modes at first:
+    - legacy: the `MVP_INTERNAL_TOKEN` environment variable (global single tenant);
+    - user token: tokens issued and stored in the DB (multi-user).
+- Admin can create users, issue tokens, and disable users.
+
+**TDD cases (write the tests first)**
+
+Unit tests:
+- `test_user_db_create_disable()`
+  - A created user is ACTIVE; after disabling, the state becomes DISABLED; duplicate creation returns a conflict or is idempotent (per the final convention).
+- `test_token_hashing()`
+  - When a token is issued, only its hash is stored in the DB, never the plaintext.
+
+API tests (TestClient):
+- `test_admin_create_user_and_issue_token()`
+  - The admin token can create a user and issue a token (the plaintext token is returned exactly once).
+- `test_disabled_user_token_rejected()`
+  - After a user is disabled, API calls with the old token return 401/403.
+
+**Implementation targets (suggested modules)**
+- `argus.service.auth`: token validation and user_id resolution (legacy mode compatible)
+- `argus.service.db`: new `users` and `api_tokens` tables with CRUD
+- `argus.service.app`: new user-management endpoints (admin scope)
+- `configs/dev.yaml`: admin token/env configuration (keep the YAML style)
+
+**Acceptance**
+- `v2.5_acceptance.md`: U1 is covered by automated API tests.
+
+---
+
+### M2: bind tasks to user_id + API visibility isolation (still no TaskSpec change)
+
+**Goal**
+- On task submission, derive `user_id` from the token and write it to `tasks.user_id`.
+- Task query/cancel/logs default to owner-only; other users get 404 (to avoid leaking existence).
+- The queue returns only the current user's tasks by default; admin may query the global queue (optional).
+
+**TDD cases (write the tests first)**
+
+Unit tests:
+- `test_tasks_table_has_user_id()`: creating a task must persist `user_id`, and `list_queue(user_id=...)` returns only that user's tasks.
+
+API tests:
+- `test_task_visibility_isolated()`
+  - user A creates a task; user B's `GET /api/v2/tasks/{id}` returns 404;
+  - user B's cancel/logs also return 404.
+- `test_queue_isolated()`
+  - A and B each create a task; `GET /api/v2/queue` shows only their own.
+
+**Implementation targets**
+- `argus.service.app`: add user scope to the task endpoints
+- `argus.service.db`: add a user_id column to tasks, an index, and per-user query methods
+- `argus.service.scheduler`: `pick_next_runnable_task` etc. can stay "global FIFO" or go "per-user FIFO"
+  - v2.5 keeps global FIFO as the simplest option (while the API queue view is filtered per user).
+
+**Acceptance**
+- `v2.5_acceptance.md`: U2 is covered by API tests.
+
+---
+
+### M3: per-user isolation of the jobs output directory (outputs only, inputs unchanged)
+
+**Goal**
+- The Ray job's job_root is computed uniformly on the server side as:
+  - `/private/users/<user_id>/jobs/<ray_submission_id>/...`
+- Input-related path fields inside TaskSpec must be `/private/common/...` (v2.5 unifies inputs under common).
+- No user can point outputs outside their own user jobs directory via TaskSpec (prevents cross-user writes).
+
+**TDD cases (write the tests first)**
+
+Unit tests:
+- `test_job_root_derivation_per_user()`
+  - Given user_id and ray_submission_id, the derived job_root is fixed and correct.
+- `test_reject_non_common_inputs()`
+  - If the train_file / val_file / code_path / hf paths in TaskSpec do not start with `/private/common/`, reject with HTTP 400.
+
+API tests:
+- `test_job_dir_written_under_user_jobs()`
+  - After submitting a task, the DB or the submit payload shows the job_root under user jobs (capture the spec by mocking RayJobTool.submit).
+
+**Implementation targets (least invasive)**
+- Derive `job_root` in the service layer and inject it into RayJobTool/builders (instead of letting users set it via TaskSpec).
+- Change RayJobTool `_job_dir()` to accept a "job_root generator", or simply a `job_root` argument supplied by the service layer.
+  - Goal: keep RayJobTool's responsibility clean: it submits Ray jobs; path policy is decided by the service.
+
+**Acceptance**
+- `v2.5_acceptance.md`: U3/U4 are covered by API/unit tests.
+
+---
+
+### M4: stateless Ray node pool (head.json + worker watchdog) + end-to-end script verification
+
+**Goal**
+- After startup, the head keeps writing `/private/ray/discovery/<cluster_name>/head.json` (with a TTL).
+- A watchdog runs inside each worker container (or a startup script + watchdog), so the platform never has to pass the head address explicitly:
+  - read head.json (present and unexpired) → `ray start --address=<head_ip>:<gcs_port>`
+  - head.json changed → `ray stop` + `ray start` to reconnect
+- Provide a one-command script that reproduces this in the dev environment (docker compose) end to end.
+
+**TDD cases (write the tests first)**
+
+Unit tests (no real ray):
+- `test_head_json_read_validate_ttl()` (the record format is sketched below)
+  - file missing/expired → report "unavailable"
+  - unexpired → return the head address
+- `test_watchdog_decision_on_change()`
+  - head_ip changed → trigger a reconnect
+  - only updated_at changed (same address) → no reconnect (or reconnect per policy; to be decided)
+
+Component/script-level tests (optional):
+- If the watchdog is implemented in Python, stub the "execute command" layer (never actually run `ray start`) and only verify which commands would be invoked.
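+
+For reference, a sketch of what the published record looks like on disk, exercising the `argus.ray.discovery` helpers added in this change (the ports are that module's defaults; the path mirrors the scripts' default layout):
+
+```python
+from argus.ray.discovery import build_head_record, load_head_record, write_head_record_atomic
+
+path = "/private/ray/discovery/argus-ray/head.json"
+rec = build_head_record(cluster_name="argus-ray", head_ip="10.0.0.1", ttl_s=60)
+write_head_record_atomic(path, rec)
+# head.json now holds: cluster_name, head_ip, gcs_port (6379), dashboard_port (8265),
+# job_server_url ("http://10.0.0.1:8265"), updated_at, and expires_at (updated_at + 60s).
+
+got = load_head_record(path)  # None if the file is missing, unreadable, or expired
+assert got is not None
+assert got.head_addr() == "10.0.0.1:6379"
+```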
+
+End-to-end script (manual/slow):
+- Provide a script `scripts/run_all_v25_stateless.sh` (name is an example):
+  1) Start the head (Ray head + API)
+  2) Start the head publisher (writes head.json)
+  3) Start 2 workers (4 GPUs each); the workers run only the watchdog and are never given the head address
+  4) `ray status` shows 1 head + 2 workers with GPU=8
+  5) Create a user / issue a token via the API, submit PPO/GRPO/SFT
+  6) Restart the head (or point head.json at a new address) and verify that the workers reconnect automatically
+
+**Implementation targets (suggested strategy)**
+
+For testability (TDD), implement "read head.json / validate the TTL / build the ray start command" as Python modules:
+- `argus.ray.discovery`: read/write head.json (atomic writes, TTL)
+- `argus.ray.worker_watchdog`: the watch loop (polling + change detection), with an injectable command runner (easy to stub in unit tests)
+
+Keep the script layer thin:
+- `scripts/` handles docker exec / compose orchestration and process supervision;
+- the watchdog process is the in-container Python module (more testable, and easier to port to a production platform's entrypoint/command).
+
+**Acceptance**
+- `v2.5_acceptance.md`: A1/A2/A3 are verified mainly via the e2e script plus dashboard/logs.
+
+---
+
+## 2. Regression strategy (make sure v2.0 is not broken)
+
+Keep the following cases and regress them continuously during v2.5 (at least at the unit-test level):
+- The old internal token mode can still access `GET /api/v2/queue` and submit tasks (if we decide to keep compatibility).
+- The scheduler's "insufficient resources → PENDING_RESOURCES → delayed retry" behavior is unchanged (covered by the existing `test_scheduler.py`).
+- `ray entrypoint_resources` still forces the driver onto a worker (keep using the `worker_node` custom resource).
+
+---
+
+## 3. Deliverables (code/scripts/docs)
+
+### 3.1 Code
+- users/tokens: DB schema + auth + API endpoints
+- tasks: bound to user_id + permission isolation
+- job_root: per-user jobs output directory derivation (inputs stay under common)
+- discovery/watchdog: head.json + worker self-healing
+
+### 3.2 Scripts (dev e2e)
+- head: start the Ray head + head publisher
+- workers: start statelessly (no head addr passed) + watchdog
+- `run_all`: one command end to end (API submit + query + cancel + queue observation)
+
+### 3.3 Docs
+- Update `specs/mvp/v2.5/*` (design/API/acceptance/dev plan)
+- Extend `src/mvp/README.md` with v2.5 usage (if needed)
+
+---
+
+## 4. Key open points (must be settled before implementation starts)
+
+1) **Whether the legacy token stays compatible**
+- Option A: keep `MVP_INTERNAL_TOKEN` (single tenant) + add user tokens (multi-tenant)
+- Option B: switch v2.5 straight to user tokens (breaks compatibility, but cleaner)
+
+2) **Scheduling fairness**
+- v2.5 keeps global FIFO (simple); per-user fair scheduling/quotas come later in v3.
+
+3) **Who writes head.json in production**
+- Option A: a thread inside the API process (fewest components)
+- Option B: a separate process (more independent, easier to operate)
+
diff --git a/specs/mvp/v2.5/v2.5_e2e_test_cases.md b/specs/mvp/v2.5/v2.5_e2e_test_cases.md
new file mode 100644
index 0000000..2e8472e
--- /dev/null
+++ b/specs/mvp/v2.5/v2.5_e2e_test_cases.md
@@ -0,0 +1,132 @@
+# MVP v2.5 End-to-End Test Cases (normal/error/boundary)
+
+Goal of this suite: cover the key v2.5 capabilities and their boundary conditions (users + jobs isolation + stateless node pool + API queue scheduling).
+
+Constraints (confirmed for v2.5):
+- TaskSpec is not extended; no reward_fn; no user-supplied verl code.
+- Inputs come uniformly from `/private/common/...`; user isolation covers only the `/private/users/<user_id>/jobs/...` outputs for now.
+
+---
+
+## 0. Environment prerequisites
+
+Remote directory (example):
+- `argus@h1:/home2/argus/infra/mvp/src/mvp/`
+
+Shared directory (host):
+- `/home2/argus/infra/mvp/shared/`
+
+In-container path convention:
+- `/private` is the shared-storage mount point
+
+Required:
+- GPUs 0-7 available
+- 3 containers: head (no GPU) + 2 workers (4 GPUs each)
+
+---
+
+## 1. Happy Path
+
+### HP-1: full v2.5 pipeline (PPO → GRPO → SFT, serial)
+
+Steps:
+1) `cd /home2/argus/infra/mvp/src/mvp/scripts`
+2) `MVP_INTERNAL_TOKEN=<token> RESET_DB=1 ./run_all_v25_api.sh`
+
+(An API-level sketch of this submission appears after HP-2.)
+
+Expected:
+- The Ray dashboard shows 3 nodes (head + 2 workers) with 8 GPUs total.
+- All 3 tasks end as `SUCCEEDED`.
+- Output directories exist and are isolated per user:
+  - `/private/users/<user_id>/jobs/<ray_submission_id>/{config,logs,checkpoints,debug}`
+
+### HP-2: the driver does not run on the head
+
+Verification (either one):
+- The Ray job's driver node IP differs from the head container IP;
+- or logs/scheduling info show entrypoint_resources in effect (driver on a worker).
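+
+What HP-1's submission looks like at the API level. A sketch only: the base URL is deployment-specific, `<user-token>` is a placeholder, the `requests` client is illustrative, and the TaskSpec values are the minimal ones the unit tests use:
+
+```python
+import requests
+
+BASE = "http://127.0.0.1:8080"  # assumption: wherever the API server listens
+HEADERS = {"Authorization": "Bearer <user-token>"}
+
+spec = """\
+workload: ppo
+code_path: /private/common/code/verl
+model_id: m
+train_file: /private/common/datasets/t
+"""
+
+r = requests.post(f"{BASE}/api/v2/tasks", headers=HEADERS, data=spec)
+r.raise_for_status()
+task_id = r.json()["task_id"]  # e.g. mvp2-alice-ppo-20250101-010203-abcd
+status = requests.get(f"{BASE}/api/v2/tasks/{task_id}", headers=HEADERS).json()
+```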
+
+---
+
+## 2. Error Cases
+
+### E-Auth-1: missing token
+
+Request:
+- `GET /api/v2/queue` without an `Authorization` header
+
+Expected:
+- 401 (missing bearer token)
+
+### E-Auth-2: invalid token
+
+Request:
+- `Authorization: Bearer <invalid-token>`
+
+Expected:
+- 401 (invalid token)
+
+### E-Auth-3: access denied after a user is disabled
+
+Steps:
+1) admin creates user `bob` and issues a token
+2) admin disables `bob`
+3) call `/api/v2/queue` with bob's token
+
+Expected:
+- 403 (user disabled)
+
+### E-Isolation-1: cross-user access to task resources (no existence leak)
+
+Steps:
+1) alice submits a task and gets a `task_id`
+2) bob queries `/api/v2/tasks/{task_id}`
+
+Expected:
+- 404 (task not found)
+
+### E-Input-1: input path outside /private/common (v2.5 constraint)
+
+Request:
+- submit a taskspec whose `train_file` or `code_path` does not start with `/private/common/`
+
+Expected:
+- 400, with a field-specific error (e.g. `train_file must start with /private/common/`).
+
+---
+
+## 3. Boundary Cases
+
+### B-Queue-1: no Ray submission when resources are insufficient (PENDING_RESOURCES)
+
+Steps:
+1) build a task requiring `nnodes=3` and `n_gpus_per_node=4` (12 GPUs total)
+2) submit and poll the status
+
+Expected:
+- the task enters `PENDING_RESOURCES` (pending on the service side, never submitted to Ray)
+- it has a `next_run_at`
+
+### B-Cancel-1: task cancellation (QUEUED/RUNNING)
+
+Steps:
+1) submit a task with enough steps that it has a chance to be RUNNING
+2) call `POST /api/v2/tasks/{task_id}:cancel`
+
+Expected:
+- the task state is `CANCELED`
+- the attempt's `ray_status` ends as `STOPPED` (or the Ray side stops it)
+
+---
+
+## 4. Executable regression script
+
+See:
+- `src/mvp/scripts/run_e2e_v25_cases.sh`
+
+The script covers:
+- HP-1
+- E-Auth-1 / E-Auth-2 / E-Input-1
+- E-Isolation-1
+- B-Queue-1
+- B-Cancel-1 (best-effort)
+
diff --git a/specs/mvp/v2.5/v2.5_summary.md b/specs/mvp/v2.5/v2.5_summary.md
new file mode 100644
index 0000000..deb52fa
--- /dev/null
+++ b/specs/mvp/v2.5/v2.5_summary.md
@@ -0,0 +1,92 @@
+# MVP v2.5 Iteration Summary (shipped)
+
+This document summarizes what v2.5 delivered on top of v2.0/v2.1/v2.2…: capabilities, implementation points, acceptance, and known limitations, for review and for alignment with later versions.
+
+## Goals and scope
+
+Core goals of v2.5:
+- Introduce **User Management**: token-based auth and per-task isolation ("jobs only").
+- Introduce a **Stateless Ray Node Pool**: workers discover and join the head automatically and self-heal, without the platform handing out the head address.
+- Keep **TaskSpec (the same YAML format as v1.1) unextended**: this iteration supports no reward functions, no custom verl code, etc.
+
+Explicitly out of scope (v2.5 constraints):
+- No TaskSpec extensions (e.g. `reward_fn_path`).
+- No per-user isolation or custom paths for verl/hf/datasets: shared resources come **uniformly from `/private/common/...`**.
+- User isolation covers only **tasks and their output directories** (jobs), not shared caches such as the HF cache or datasets.
+
+## Key capabilities (externally visible)
+
+### 1) Multi-user auth and task isolation
+- The API keeps the internal `Authorization: Bearer <token>` scheme:
+  - The admin token comes from the `MVP_INTERNAL_TOKEN` environment variable.
+  - Regular user tokens are issued by the admin via the API and persisted in SQLite.
+- Isolation policy:
+  - Non-admin users can query/cancel/fetch logs for **their own tasks only**; cross-user access returns 404 (no existence leak).
+  - Training outputs are isolated on disk: Ray job directories all go to `/private/users/<user_id>/jobs/<ray_submission_id>/...`.
+
+### 2) task_id / submission_id carry the user name
+- New task ID format: `mvp2-<user>-<workload>-<ts>-<suffix>`
+- Ray submission ID (per attempt): `<task_id>--aNN`, so it naturally contains the user name.
+- Effect: dashboard, logs, and on-disk directories become easier to read, and tasks can be traced and audited per user.
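+
+Concretely, as implemented in `argus.core.ids` in this change (the values below are the ones its unit tests pin down):
+
+```python
+from argus.core.ids import attempt_submission_id, new_task_id
+
+tid = new_task_id("ppo", user_id="Alice_01")  # user part normalized to "alice_01"
+# e.g. "mvp2-alice_01-ppo-20250101-010203-abcd"
+
+attempt_submission_id(tid, 1)  # tid + "--a01"; the user name carries through
+```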
+
+### 3) The "stateless worker pool" and head-address discovery
+- The head writes a **head address file** (e.g. `head.json`) to shared storage; each worker runs a watchdog that:
+  - polls for the head address
+  - joins the cluster automatically via `ray start --address ...`
+  - reconnects automatically after a drop (watchdog self-healing)
+- Net effect: in production, even when worker containers are created by the compute platform (which only provides SSH-based management), they can connect and self-heal through shared storage.
+
+### 4) Task scheduling: queue + Ray Job submission + status sync
+- Submitted tasks enter a SQLite queue; a background scheduler submits them to Ray one by one (default `max_running_tasks=1`).
+- The scheduler keeps polling Ray job status and writes it back to the task (RUNNING/SUCCEEDED/FAILED/CANCELED).
+- Retryable handling of insufficient resources:
+  - On VERL's fail-fast (`Total available GPUs ... is less than total desired GPUs ...`) or when the cluster lacks resources,
+    the task enters `PENDING_RESOURCES` with a `next_run_at` and retries on the `retry_interval_s` cadence.
+
+## Key implementation points (engineering)
+
+### Storage and directory conventions (in-container view)
+- The shared root is uniformly `/private` (matching the production mount).
+- Hard v2.5 constraint: the following TaskSpec fields must start with `/private/common/`:
+  - `code_path` / `train_file` / `val_file`
+- Shared directories (examples):
+  - `/private/common/hf`: HF cache
+  - `/private/common/datasets`: training data (symlinked to existing cache directories where needed to reuse downloads)
+  - `/private/common/db/mvp.sqlite3`: queue and user data (SQLite)
+  - `/private/common/logs`: API / watchdog logs
+  - `/private/users/<user_id>/jobs/...`: per-user job outputs (isolated)
+
+### Ray topology and "no training on the head"
+- The head starts as a management node (CPU/GPU=0) so training never lands on it.
+- Worker nodes carry the GPUs (example: 2 workers * 4 GPUs each).
+- The driver is forced onto a worker via `entrypoint_resources` (e.g. `worker_node: 1`).
+
+### Deployment scripts and repeatability
+A complete script chain covers:
+- cleaning up legacy environments, starting/stopping containers, starting the Ray head
+- starting and health-checking the head discovery publisher and the worker watchdogs
+- preparing data/models/code (idempotent, reuses existing downloads)
+- starting the API server (with RESET_DB support)
+- submitting PPO/GRPO/SFT back to back over the API and waiting for completion
+
+Representative scripts:
+- `src/mvp/scripts/run_all_v25_api.sh`: the v2.5 happy path end to end (rebuild the cluster, prepare resources, start the API, submit the 3 workloads)
+- `src/mvp/scripts/run_e2e_v25_cases.sh`: the happy path plus the auth/isolation/input-validation/insufficient-resources/cancel cases
+
+## Acceptance and testing (passed)
+
+### Unit tests (local venv)
+- `.venv/bin/python -m pytest`
+- Coverage threshold: >= 90%
+
+### Remote end to end (h1)
+- Run in `argus@h1:/home2/argus/infra/mvp/src/mvp/scripts`:
+  - `MVP_INTERNAL_TOKEN=mvp-dev-token RESET_DB=1 ./run_e2e_v25_cases.sh`
+- Result: the happy path (PPO/GRPO/SFT) completed, and the error/boundary cases passed (auth, cross-user isolation, input validation, insufficient resources → PENDING_RESOURCES, task cancel, etc.).
+
+## Known issues and follow-ups
+
+- With `max_running_tasks=1`, queued tasks stay QUEUED while a predecessor is RUNNING; the "insufficient resources" boundary test must explicitly clear/cancel predecessors, or accept that behavior as part of the design.
+- SQLite is still a single point; for HA/horizontal scaling, v2.6+ can introduce stronger persistence and replication (e.g. Postgres/etcd).
+- The API server / watchdog are currently supervised by shell scripts; they can later move to systemd/supervisor (or platform-side supervision) with proper health checks and alerting.
+
diff --git a/src/mvp/README.md b/src/mvp/README.md
index d7db5f7..9758140 100644
--- a/src/mvp/README.md
+++ b/src/mvp/README.md
@@ -15,3 +15,4 @@
 Quick start:
 - CLI submission flow: `scripts/run_all_cli.sh`
 - API submission flow: `scripts/run_all_api.sh`
+- v2.5 (stateless workers + per-user jobs isolation) E2E: `scripts/run_all_v25_api.sh`
diff --git a/src/mvp/py/argus/core/ids.py b/src/mvp/py/argus/core/ids.py
index 0f1f0cc..bd46978 100644
--- a/src/mvp/py/argus/core/ids.py
+++ b/src/mvp/py/argus/core/ids.py
@@ -2,14 +2,27 @@ from __future__ import annotations
 
 import secrets
 from datetime import datetime
+import re
 
 
-def new_task_id(workload: str) -> str:
+_USER_SAFE_RE = re.compile(r"[^a-z0-9_-]+")
+
+
+def _normalize_user_id(user_id: str) -> str:
+    s = (user_id or "").strip().lower()
+    s = _USER_SAFE_RE.sub("-", s)
+    s = re.sub(r"-{2,}", "-", s).strip("-")
+    return s[:24] if s else ""
+
+
+def new_task_id(workload: str, *, user_id: str | None = None) -> str:
     ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    suffix = secrets.token_hex(2)
+    u = _normalize_user_id(user_id or "")
+    if u:
+        return f"mvp2-{u}-{workload}-{ts}-{suffix}"
     return f"mvp2-{workload}-{ts}-{suffix}"
 
 
 def attempt_submission_id(task_id: str, attempt_no: int) -> str:
     return f"{task_id}--a{attempt_no:02d}"
-
diff --git a/src/mvp/py/argus/ray/discovery.py b/src/mvp/py/argus/ray/discovery.py
new file mode 100644
index 0000000..efaf4b0
--- /dev/null
+++ b/src/mvp/py/argus/ray/discovery.py
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+import json
+import os
+from dataclasses import dataclass
+from datetime import datetime, timedelta, timezone
+from pathlib import Path
+from typing import Any
+
+
+def _utc_now() -> datetime:
+    return datetime.now(timezone.utc)
+
+
+def _parse_utc(ts: str) -> datetime:
+    # Accept ISO8601 with trailing Z
+    s = ts.strip()
+    if s.endswith("Z"):
+        s = s[:-1] + 
"+00:00" + return datetime.fromisoformat(s).astimezone(timezone.utc) + + +def _iso_z(dt: datetime) -> str: + return dt.astimezone(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z") + + +@dataclass(frozen=True) +class HeadRecord: + cluster_name: str + head_ip: str + gcs_port: int + dashboard_port: int + job_server_url: str + updated_at: str + expires_at: str + + @staticmethod + def from_dict(d: dict[str, Any]) -> "HeadRecord": + return HeadRecord( + cluster_name=str(d["cluster_name"]), + head_ip=str(d["head_ip"]), + gcs_port=int(d.get("gcs_port", 6379)), + dashboard_port=int(d.get("dashboard_port", 8265)), + job_server_url=str(d.get("job_server_url") or f"http://{d['head_ip']}:{int(d.get('dashboard_port', 8265))}"), + updated_at=str(d["updated_at"]), + expires_at=str(d["expires_at"]), + ) + + def to_dict(self) -> dict[str, Any]: + return { + "cluster_name": self.cluster_name, + "head_ip": self.head_ip, + "gcs_port": self.gcs_port, + "dashboard_port": self.dashboard_port, + "job_server_url": self.job_server_url, + "updated_at": self.updated_at, + "expires_at": self.expires_at, + } + + def head_addr(self) -> str: + return f"{self.head_ip}:{self.gcs_port}" + + +def build_head_record( + *, + cluster_name: str, + head_ip: str, + gcs_port: int = 6379, + dashboard_port: int = 8265, + ttl_s: int = 60, + now: datetime | None = None, +) -> HeadRecord: + now_dt = now or _utc_now() + expires = now_dt + timedelta(seconds=int(ttl_s)) + updated_at = _iso_z(now_dt) + expires_at = _iso_z(expires) + return HeadRecord( + cluster_name=cluster_name, + head_ip=head_ip, + gcs_port=int(gcs_port), + dashboard_port=int(dashboard_port), + job_server_url=f"http://{head_ip}:{int(dashboard_port)}", + updated_at=updated_at, + expires_at=expires_at, + ) + + +def load_head_record(path: str, *, now: datetime | None = None) -> HeadRecord | None: + p = Path(path) + if not p.exists(): + return None + try: + obj = json.loads(p.read_text(encoding="utf-8")) + except Exception: + return None + if not isinstance(obj, dict): + return None + try: + rec = HeadRecord.from_dict(obj) + expires = _parse_utc(rec.expires_at) + except Exception: + return None + now_dt = now or _utc_now() + if expires <= now_dt: + return None + return rec + + +def write_head_record_atomic(path: str, rec: HeadRecord) -> None: + p = Path(path) + os.makedirs(p.parent, exist_ok=True) + tmp = p.with_suffix(p.suffix + ".tmp") + tmp.write_text(json.dumps(rec.to_dict(), indent=2, sort_keys=True) + "\n", encoding="utf-8") + tmp.replace(p) + diff --git a/src/mvp/py/argus/ray/head_publisher.py b/src/mvp/py/argus/ray/head_publisher.py new file mode 100644 index 0000000..e8fc68f --- /dev/null +++ b/src/mvp/py/argus/ray/head_publisher.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import argparse +import time + +from .discovery import build_head_record, write_head_record_atomic + + +def publish_once( + *, + cluster_name: str, + head_ip_file: str, + head_ip: str, + gcs_port: int, + dashboard_port: int, + ttl_s: int, +) -> None: + rec = build_head_record( + cluster_name=cluster_name, + head_ip=head_ip, + gcs_port=gcs_port, + dashboard_port=dashboard_port, + ttl_s=ttl_s, + ) + write_head_record_atomic(head_ip_file, rec) + + +def main(argv: list[str] | None = None) -> int: + ap = argparse.ArgumentParser(description="Publish Ray head address to shared storage (head.json).") + ap.add_argument("--cluster-name", required=True) + ap.add_argument("--head-ip-file", required=True) + ap.add_argument("--head-ip", required=True) + 
ap.add_argument("--gcs-port", type=int, default=6379) + ap.add_argument("--dashboard-port", type=int, default=8265) + ap.add_argument("--ttl-s", type=int, default=60) + ap.add_argument("--refresh-s", type=int, default=10) + ap.add_argument("--once", action="store_true", help="Write once then exit (for testing/debug).") + args = ap.parse_args(argv) + + if args.once: + publish_once( + cluster_name=args.cluster_name, + head_ip_file=args.head_ip_file, + head_ip=args.head_ip, + gcs_port=args.gcs_port, + dashboard_port=args.dashboard_port, + ttl_s=args.ttl_s, + ) + return 0 + + refresh_s = max(1, int(args.refresh_s)) + while True: + publish_once( + cluster_name=args.cluster_name, + head_ip_file=args.head_ip_file, + head_ip=args.head_ip, + gcs_port=args.gcs_port, + dashboard_port=args.dashboard_port, + ttl_s=args.ttl_s, + ) + time.sleep(refresh_s) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/mvp/py/argus/ray/ray_job_tool.py b/src/mvp/py/argus/ray/ray_job_tool.py index 8d7c538..bff2f3e 100644 --- a/src/mvp/py/argus/ray/ray_job_tool.py +++ b/src/mvp/py/argus/ray/ray_job_tool.py @@ -73,9 +73,9 @@ class RayJobTool: return {"env_vars": env_vars} - def submit(self, spec: JobSpec, no_wait: bool) -> str: + def submit(self, spec: JobSpec, no_wait: bool, job_dir: str | None = None) -> str: submission_id = spec.submission_id or f"mvp11_{spec.workload}_{_ts()}_{os.getpid()}" - job_dir = self._job_dir(submission_id) + job_dir = job_dir or self._job_dir(submission_id) built = build_training_argv(spec, submission_id=submission_id, job_dir=job_dir) entrypoint_argv = [ diff --git a/src/mvp/py/argus/ray/worker_watchdog.py b/src/mvp/py/argus/ray/worker_watchdog.py new file mode 100644 index 0000000..b030804 --- /dev/null +++ b/src/mvp/py/argus/ray/worker_watchdog.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import argparse +import json +import subprocess +import time +from dataclasses import dataclass +from typing import Callable + +from .discovery import HeadRecord, load_head_record + + +Runner = Callable[[list[str]], int] + + +def _default_runner(argv: list[str]) -> int: + proc = subprocess.run(argv, check=False) + return int(proc.returncode) + + +def _ray_stop_cmd() -> list[str]: + return ["ray", "stop", "--force"] + + +def _ray_start_cmd(*, head_addr: str, node_ip: str | None, resources_json: str) -> list[str]: + argv = ["ray", "start", f"--address={head_addr}", f"--resources={resources_json}", "--disable-usage-stats"] + if node_ip: + argv.append(f"--node-ip-address={node_ip}") + return argv + + +@dataclass +class Watchdog: + head_ip_file: str + node_ip: str | None + resources_json: str + poll_s: int + runner: Runner + + _current_head_addr: str | None = None + + def _restart_ray(self, desired: str) -> None: + # Best-effort stop then start. 
+ self.runner(_ray_stop_cmd()) + self.runner(_ray_start_cmd(head_addr=desired, node_ip=self.node_ip, resources_json=self.resources_json)) + self._current_head_addr = desired + + def tick_once(self) -> HeadRecord | None: + rec = load_head_record(self.head_ip_file) + if rec is None: + return None + desired = rec.head_addr() + if self._current_head_addr != desired: + self._restart_ray(desired) + return rec + + def run_forever(self) -> None: + while True: + try: + self.tick_once() + except Exception: + # keep watchdog alive + pass + time.sleep(max(1, int(self.poll_s))) + + +def main(argv: list[str] | None = None) -> int: + ap = argparse.ArgumentParser(description="Stateless Ray worker watchdog (auto-connect + reconnect).") + ap.add_argument("--head-ip-file", required=True) + ap.add_argument("--node-ip", default=None) + ap.add_argument( + "--resources-kv", + action="append", + default=[], + help='Repeatable k=v resources, e.g. --resources-kv worker_node=100', + ) + ap.add_argument("--poll-s", type=int, default=5) + ap.add_argument("--once", action="store_true", help="Run one tick then exit (for testing/debug).") + args = ap.parse_args(argv) + + resources: dict[str, float] = {"worker_node": 100.0} + if args.resources_kv: + resources = {} + for item in args.resources_kv: + if "=" not in item: + raise SystemExit(f"invalid --resources-kv (expected k=v): {item!r}") + k, v = item.split("=", 1) + k = k.strip() + v = v.strip() + if not k: + raise SystemExit(f"invalid --resources-kv key: {item!r}") + resources[k] = float(v) + + wd = Watchdog( + head_ip_file=args.head_ip_file, + node_ip=args.node_ip, + resources_json=json.dumps(resources), + poll_s=int(args.poll_s), + runner=_default_runner, + ) + if args.once: + wd.tick_once() + return 0 + + wd.run_forever() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/mvp/py/argus/service/app.py b/src/mvp/py/argus/service/app.py index 6df707a..0d7c2db 100644 --- a/src/mvp/py/argus/service/app.py +++ b/src/mvp/py/argus/service/app.py @@ -43,19 +43,32 @@ def create_app(config_path: str) -> FastAPI: app = FastAPI(title="mvp-v2", version="2.0") - def _require_token(req: Request) -> None: + def _auth(req: Request) -> dict[str, Any]: token_env = v2_cfg.auth.token_env - expected = os.environ.get(token_env, "") - if not expected: - # Misconfigured service; treat as server error. 
+        admin_token = os.environ.get(token_env, "")
+        if not admin_token:
             raise HTTPException(status_code=500, detail=f"missing token env: {token_env}")
         auth = req.headers.get("authorization") or ""
         if not auth.startswith("Bearer "):
             raise HTTPException(status_code=401, detail="missing bearer token")
         got = auth.removeprefix("Bearer ").strip()
-        if got != expected:
+
+        if got == admin_token:
+            return {"user_id": "admin", "is_admin": True}
+
+        info = db.resolve_token(got)
+        if not info:
             raise HTTPException(status_code=401, detail="invalid token")
+        if info.get("state") != "ACTIVE":
+            raise HTTPException(status_code=403, detail="user disabled")
+        return {"user_id": info["user_id"], "is_admin": False}
+
+    def _require_admin(req: Request) -> dict[str, Any]:
+        subject = _auth(req)
+        if not subject.get("is_admin"):
+            raise HTTPException(status_code=403, detail="admin required")
+        return subject
 
     @app.on_event("startup")
     def _startup() -> None:
@@ -66,9 +79,43 @@ def create_app(config_path: str) -> FastAPI:
     def _shutdown() -> None:
         stop_flag.set()
 
+    @app.post("/api/v2/users")
+    async def create_user(req: Request) -> dict[str, Any]:
+        _require_admin(req)
+        obj = await req.json()
+        if not isinstance(obj, dict):
+            raise HTTPException(status_code=400, detail="body must be a JSON object")
+        user_id = str(obj.get("user_id") or "").strip()
+        if not user_id:
+            raise HTTPException(status_code=400, detail="missing user_id")
+        display_name = obj.get("display_name")
+        try:
+            row = db.create_user(user_id=user_id, display_name=str(display_name) if display_name is not None else None)
+        except Exception as e:
+            raise HTTPException(status_code=409, detail=f"user create failed: {e!r}")
+        return {"user_id": row.get("user_id", user_id), "state": row.get("state", "ACTIVE")}
+
+    @app.post("/api/v2/users/{user_id}/tokens")
+    async def issue_token(user_id: str, req: Request) -> dict[str, Any]:
+        _require_admin(req)
+        u = db.get_user(user_id)
+        if not u:
+            raise HTTPException(status_code=404, detail="user not found")
+        token = db.issue_token(user_id=user_id)
+        return {"user_id": user_id, "token": token}
+
+    @app.post("/api/v2/users/{user_id}:disable")
+    async def disable_user(user_id: str, req: Request) -> dict[str, Any]:
+        _require_admin(req)
+        u = db.get_user(user_id)
+        if not u:
+            raise HTTPException(status_code=404, detail="user not found")
+        db.disable_user(user_id=user_id)
+        return {"user_id": user_id, "state": "DISABLED"}
+
     @app.post("/api/v2/tasks")
     async def submit_task(req: Request) -> dict[str, Any]:
-        _require_token(req)
+        subject = _auth(req)
         body = (await req.body()).decode("utf-8")
         obj = yaml.safe_load(body) or {}
         if not isinstance(obj, dict):
@@ -79,9 +126,18 @@ def create_app(config_path: str) -> FastAPI:
         except Exception as e:
             raise HTTPException(status_code=400, detail=f"invalid jobspec: {e!r}")
 
-        task_id = new_task_id(spec.workload)
-        db.create_task(
+        # v2.5 constraint: training inputs must come from /private/common (unified across dev/prod).
+        common_prefix = ray_cfg.shared_root.rstrip("/") + "/common/"
+        for k, v in (("code_path", spec.code_path), ("train_file", spec.train_file), ("val_file", spec.val_file)):
+            if v is None:
+                continue
+            if not str(v).startswith(common_prefix):
+                raise HTTPException(status_code=400, detail=f"{k} must start with {common_prefix}")
+
+        task_id = new_task_id(spec.workload, user_id=str(subject["user_id"]))
+        db.create_task_v25(
             task_id=task_id,
+            user_id=str(subject["user_id"]),
             workload=spec.workload,
             jobspec_yaml=body,
             nnodes=spec.nnodes,
@@ -91,10 +147,13 @@ def create_app(config_path: 
str) -> FastAPI: @app.get("/api/v2/tasks/{task_id}") async def get_task(task_id: str, req: Request) -> dict[str, Any]: - _require_token(req) + subject = _auth(req) row = db.get_task(task_id) if not row: raise HTTPException(status_code=404, detail="task not found") + if not subject.get("is_admin"): + if str(row.get("user_id") or "") != str(subject["user_id"]): + raise HTTPException(status_code=404, detail="task not found") attempts = db.list_attempts(task_id) latest_attempt = attempts[-1] if attempts else None desired = { @@ -125,18 +184,24 @@ def create_app(config_path: str) -> FastAPI: @app.get("/api/v2/tasks/{task_id}/attempts") async def get_attempts(task_id: str, req: Request) -> dict[str, Any]: - _require_token(req) + subject = _auth(req) row = db.get_task(task_id) if not row: raise HTTPException(status_code=404, detail="task not found") + if not subject.get("is_admin"): + if str(row.get("user_id") or "") != str(subject["user_id"]): + raise HTTPException(status_code=404, detail="task not found") return {"task_id": task_id, "attempts": db.list_attempts(task_id)} @app.post("/api/v2/tasks/{task_id}:cancel") async def cancel(task_id: str, req: Request) -> dict[str, Any]: - _require_token(req) + subject = _auth(req) row = db.get_task(task_id) if not row: raise HTTPException(status_code=404, detail="task not found") + if not subject.get("is_admin"): + if str(row.get("user_id") or "") != str(subject["user_id"]): + raise HTTPException(status_code=404, detail="task not found") state = str(row["state"]) if state in ("SUCCEEDED", "FAILED", "CANCELED"): @@ -165,10 +230,13 @@ def create_app(config_path: str) -> FastAPI: @app.get("/api/v2/tasks/{task_id}/logs") async def logs(task_id: str, req: Request, tail: int = 2000, attempt: str = "latest") -> Response: - _require_token(req) + subject = _auth(req) row = db.get_task(task_id) if not row: raise HTTPException(status_code=404, detail="task not found") + if not subject.get("is_admin"): + if str(row.get("user_id") or "") != str(subject["user_id"]): + raise HTTPException(status_code=404, detail="task not found") attempts = db.list_attempts(task_id) if not attempts: raise HTTPException(status_code=404, detail="no attempts yet") @@ -184,7 +252,9 @@ def create_app(config_path: str) -> FastAPI: @app.get("/api/v2/queue") async def queue(req: Request) -> dict[str, Any]: - _require_token(req) - return db.list_queue() + subject = _auth(req) + if subject.get("is_admin"): + return db.list_queue() + return db.list_queue(user_id=str(subject["user_id"])) return app diff --git a/src/mvp/py/argus/service/db.py b/src/mvp/py/argus/service/db.py index 4bcb2e8..0aa36d9 100644 --- a/src/mvp/py/argus/service/db.py +++ b/src/mvp/py/argus/service/db.py @@ -4,6 +4,8 @@ import os import sqlite3 from contextlib import contextmanager from dataclasses import dataclass +import hashlib +import secrets from typing import Any, Iterator @@ -18,6 +20,10 @@ def _utc_now_iso() -> str: class Db: db_path: str + def _hash_token(self, token: str) -> str: + # Internal tokens only; store hash to avoid plaintext at rest. 
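+        # Unsalted SHA-256 is acceptable only because issued tokens are high-entropy
+        # (secrets.token_urlsafe below); user-chosen secrets would need a slow,
+        # salted hash instead.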
+ return hashlib.sha256(token.encode("utf-8")).hexdigest() + def _connect(self) -> sqlite3.Connection: os.makedirs(os.path.dirname(self.db_path), exist_ok=True) conn = sqlite3.connect(self.db_path, timeout=30, isolation_level=None) @@ -28,6 +34,27 @@ class Db: def init(self) -> None: with self._connect() as conn: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS users ( + user_id TEXT PRIMARY KEY, + display_name TEXT, + state TEXT NOT NULL, + created_at TEXT NOT NULL + ) + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS api_tokens ( + token_hash TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + created_at TEXT NOT NULL, + last_used_at TEXT, + FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE + ) + """ + ) conn.execute( """ CREATE TABLE IF NOT EXISTS tasks ( @@ -45,6 +72,12 @@ class Db: ) """ ) + # Best-effort: add user_id column for forward compatibility (v2.5+). + # v2.0 tables may exist without it; SQLite supports ADD COLUMN. + try: + conn.execute("ALTER TABLE tasks ADD COLUMN user_id TEXT") + except sqlite3.OperationalError: + pass conn.execute( """ CREATE TABLE IF NOT EXISTS attempts ( @@ -92,10 +125,10 @@ class Db: with self.tx() as conn: conn.execute( """ - INSERT INTO tasks (task_id, workload, state, jobspec_yaml, nnodes, n_gpus_per_node, created_at, updated_at) - VALUES (?, ?, 'QUEUED', ?, ?, ?, ?, ?) + INSERT INTO tasks (task_id, workload, state, jobspec_yaml, nnodes, n_gpus_per_node, created_at, updated_at, user_id) + VALUES (?, ?, 'QUEUED', ?, ?, ?, ?, ?, ?) """, - (task_id, workload, jobspec_yaml, nnodes, n_gpus_per_node, now, now), + (task_id, workload, jobspec_yaml, nnodes, n_gpus_per_node, now, now, None), ) conn.execute( "INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (?, ?, 'TASK_CREATED', ?)", @@ -104,6 +137,97 @@ class Db: row = conn.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone() return dict(row) if row else {} + def create_task_v25( + self, + *, + task_id: str, + user_id: str, + workload: str, + jobspec_yaml: str, + nnodes: int, + n_gpus_per_node: int, + ) -> dict[str, Any]: + now = _utc_now_iso() + with self.tx() as conn: + conn.execute( + """ + INSERT INTO tasks (task_id, user_id, workload, state, jobspec_yaml, nnodes, n_gpus_per_node, created_at, updated_at) + VALUES (?, ?, ?, 'QUEUED', ?, ?, ?, ?, ?) + """, + (task_id, user_id, workload, jobspec_yaml, nnodes, n_gpus_per_node, now, now), + ) + conn.execute( + "INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (?, ?, 'TASK_CREATED', ?)", + (task_id, now, None), + ) + row = conn.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone() + return dict(row) if row else {} + + def create_user(self, *, user_id: str, display_name: str | None = None) -> dict[str, Any]: + now = _utc_now_iso() + with self.tx() as conn: + conn.execute( + """ + INSERT INTO users (user_id, display_name, state, created_at) + VALUES (?, ?, 'ACTIVE', ?) 
+ """, + (user_id, display_name, now), + ) + conn.execute( + "INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (NULL, ?, 'USER_CREATED', ?)", + (now, user_id), + ) + row = conn.execute("SELECT * FROM users WHERE user_id = ?", (user_id,)).fetchone() + return dict(row) if row else {} + + def disable_user(self, *, user_id: str) -> None: + now = _utc_now_iso() + with self.tx() as conn: + conn.execute("UPDATE users SET state = 'DISABLED' WHERE user_id = ?", (user_id,)) + conn.execute( + "INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (NULL, ?, 'USER_DISABLED', ?)", + (now, user_id), + ) + + def get_user(self, user_id: str) -> dict[str, Any] | None: + with self._connect() as conn: + row = conn.execute("SELECT * FROM users WHERE user_id = ?", (user_id,)).fetchone() + return dict(row) if row else None + + def issue_token(self, *, user_id: str) -> str: + # Returns plaintext token once; stores hash only. + now = _utc_now_iso() + token = f"mvp_u_{user_id}_{secrets.token_urlsafe(18)}" + token_hash = self._hash_token(token) + with self.tx() as conn: + conn.execute( + "INSERT INTO api_tokens (token_hash, user_id, created_at) VALUES (?, ?, ?)", + (token_hash, user_id, now), + ) + conn.execute( + "INSERT INTO events (task_id, ts, event_type, payload_json) VALUES (NULL, ?, 'TOKEN_ISSUED', ?)", + (now, user_id), + ) + return token + + def resolve_token(self, token: str) -> dict[str, Any] | None: + token_hash = self._hash_token(token) + with self.tx() as conn: + row = conn.execute( + """ + SELECT t.user_id, u.state + FROM api_tokens t + JOIN users u ON u.user_id = t.user_id + WHERE t.token_hash = ? + """, + (token_hash,), + ).fetchone() + if not row: + return None + now = _utc_now_iso() + conn.execute("UPDATE api_tokens SET last_used_at = ? WHERE token_hash = ?", (now, token_hash)) + return {"user_id": row["user_id"], "state": row["state"]} + def get_task(self, task_id: str) -> dict[str, Any] | None: with self._connect() as conn: row = conn.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone() @@ -116,26 +240,33 @@ class Db: ).fetchall() return [dict(r) for r in rows] - def list_queue(self) -> dict[str, list[dict[str, Any]]]: + def list_queue(self, *, user_id: str | None = None) -> dict[str, list[dict[str, Any]]]: with self._connect() as conn: - pending = conn.execute( - """ - SELECT task_id, workload, state, nnodes, n_gpus_per_node, next_run_at, created_at, updated_at - FROM tasks - WHERE state IN ('QUEUED','PENDING_RESOURCES') - ORDER BY created_at ASC - LIMIT 200 - """ - ).fetchall() - running = conn.execute( - """ - SELECT task_id, workload, state, nnodes, n_gpus_per_node, latest_attempt_no, created_at, updated_at - FROM tasks - WHERE state IN ('SUBMITTING','SUBMITTED','RUNNING') - ORDER BY updated_at ASC - LIMIT 200 - """ - ).fetchall() + params: list[Any] = [] + user_filter_sql = "" + if user_id is not None: + user_filter_sql = " AND user_id = ?" 
+ params = [user_id] + + pending_sql = ( + "SELECT task_id, workload, state, nnodes, n_gpus_per_node, next_run_at, created_at, updated_at " + "FROM tasks " + "WHERE state IN ('QUEUED','PENDING_RESOURCES')" + f"{user_filter_sql} " + "ORDER BY created_at ASC " + "LIMIT 200" + ) + running_sql = ( + "SELECT task_id, workload, state, nnodes, n_gpus_per_node, latest_attempt_no, created_at, updated_at " + "FROM tasks " + "WHERE state IN ('SUBMITTING','SUBMITTED','RUNNING')" + f"{user_filter_sql} " + "ORDER BY updated_at ASC " + "LIMIT 200" + ) + + pending = conn.execute(pending_sql, tuple(params)).fetchall() + running = conn.execute(running_sql, tuple(params)).fetchall() return {"pending": [dict(r) for r in pending], "running": [dict(r) for r in running]} def count_running(self) -> int: diff --git a/src/mvp/py/argus/service/scheduler.py b/src/mvp/py/argus/service/scheduler.py index 452b77c..0e57253 100644 --- a/src/mvp/py/argus/service/scheduler.py +++ b/src/mvp/py/argus/service/scheduler.py @@ -37,6 +37,12 @@ class Scheduler: def __post_init__(self) -> None: self.tool = RayJobTool(self.ray_cfg) + def _job_dir_for_task(self, *, user_id: str | None, ray_submission_id: str) -> str: + root = self.ray_cfg.shared_root.rstrip("/") + if user_id: + return f"{root}/users/{user_id}/jobs/{ray_submission_id}" + return f"{root}/jobs/{ray_submission_id}" + def _resources_sufficient(self, *, nnodes: int, n_gpus_per_node: int) -> bool: avail = get_cluster_available() required = float(nnodes * n_gpus_per_node) @@ -51,10 +57,13 @@ class Scheduler: def _submit_one(self, task_row: dict[str, Any]) -> None: task_id = str(task_row["task_id"]) jobspec_yaml = str(task_row["jobspec_yaml"]) + user_id = task_row.get("user_id") + user_id_s = str(user_id) if user_id not in (None, "") else None spec = self._parse_jobspec(jobspec_yaml) attempt_no = int(task_row.get("latest_attempt_no", 0)) + 1 ray_sid = attempt_submission_id(task_id, attempt_no) + job_dir = self._job_dir_for_task(user_id=user_id_s, ray_submission_id=ray_sid) # Record attempt first so that we can surface it even if submit crashes. self.db.create_attempt(task_id=task_id, attempt_no=attempt_no, ray_submission_id=ray_sid) @@ -66,7 +75,7 @@ class Scheduler: spec2 = JobSpec.from_dict(d) try: - submitted = self.tool.submit(spec2, no_wait=True) + submitted = self.tool.submit(spec2, no_wait=True, job_dir=job_dir) # submitted should equal ray_sid; keep as source of truth. 
self.db.update_attempt(task_id=task_id, attempt_no=attempt_no, ray_status="SUBMITTED") self.db.set_task_state(task_id=task_id, state="SUBMITTED") diff --git a/src/mvp/py/tests/test_app.py b/src/mvp/py/tests/test_app.py index 51f8453..99e2243 100644 --- a/src/mvp/py/tests/test_app.py +++ b/src/mvp/py/tests/test_app.py @@ -12,7 +12,7 @@ def _write_config(tmp_path: Path) -> Path: cfg = { "ray": { "address": "http://127.0.0.1:8265", - "shared_root": str(tmp_path), + "shared_root": "/private", "entrypoint_resources": {"worker_node": 1}, "runtime_env": {"env_vars": {}}, }, @@ -54,7 +54,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch): cfg_path = _write_config(tmp_path) monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1") - monkeypatch.setattr(app_mod, "new_task_id", lambda workload: "tid1") + monkeypatch.setattr(app_mod, "new_task_id", lambda workload, **kwargs: "tid1") class _Tool: def __init__(self): @@ -82,7 +82,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch): r = c.post( "/api/v2/tasks", headers=headers, - data="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n", + data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n", ) assert r.status_code == 200 assert r.json()["task_id"] == "tid1" @@ -111,7 +111,7 @@ def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch): db.create_task( task_id="tid2", workload="ppo", - jobspec_yaml="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n", + jobspec_yaml="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n", nnodes=2, n_gpus_per_node=4, ) @@ -163,4 +163,3 @@ def test_submit_rejects_invalid_jobspec(tmp_path: Path, monkeypatch): with TestClient(app) as c: r = c.post("/api/v2/tasks", headers={"authorization": "Bearer token1"}, data="workload: nope\n") assert r.status_code == 400 - diff --git a/src/mvp/py/tests/test_discovery.py b/src/mvp/py/tests/test_discovery.py new file mode 100644 index 0000000..de132c8 --- /dev/null +++ b/src/mvp/py/tests/test_discovery.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import json +from datetime import datetime, timedelta, timezone +from pathlib import Path + +from argus.ray.discovery import build_head_record, load_head_record + + +def test_load_head_record_rejects_missing(tmp_path: Path): + assert load_head_record(str(tmp_path / "nope.json")) is None + + +def test_load_head_record_rejects_expired(tmp_path: Path): + p = tmp_path / "head.json" + now = datetime.now(timezone.utc).replace(microsecond=0) + rec = build_head_record(cluster_name="c", head_ip="1.2.3.4", ttl_s=1, now=now) + p.write_text(json.dumps(rec.to_dict()), encoding="utf-8") + assert load_head_record(str(p), now=now + timedelta(seconds=2)) is None + + +def test_load_head_record_accepts_fresh(tmp_path: Path): + p = tmp_path / "head.json" + now = datetime.now(timezone.utc).replace(microsecond=0) + rec = build_head_record(cluster_name="c", head_ip="1.2.3.4", ttl_s=60, now=now) + p.write_text(json.dumps(rec.to_dict()), encoding="utf-8") + got = load_head_record(str(p), now=now + timedelta(seconds=1)) + assert got is not None + assert got.head_addr() == "1.2.3.4:6379" + + +def test_head_publisher_main_once_writes_file(tmp_path: Path): + from argus.ray import head_publisher as mod + + p = tmp_path / "head.json" + rc = mod.main( + [ + "--cluster-name", + "c", + "--head-ip-file", + str(p), + "--head-ip", + "9.9.9.9", + "--ttl-s", + "60", + "--once", + ] + ) 
+ assert rc == 0 + assert "9.9.9.9" in p.read_text(encoding="utf-8") diff --git a/src/mvp/py/tests/test_ids.py b/src/mvp/py/tests/test_ids.py index 39953a5..ad2a8fe 100644 --- a/src/mvp/py/tests/test_ids.py +++ b/src/mvp/py/tests/test_ids.py @@ -20,9 +20,27 @@ def test_new_task_id_is_deterministic_with_patches(monkeypatch): assert ids.new_task_id("ppo") == "mvp2-ppo-20250101-010203-abcd" +def test_new_task_id_includes_user_when_provided(monkeypatch): + import argus.core.ids as ids + + class _FakeDatetime: + @staticmethod + def now(): + class _DT: + def strftime(self, fmt: str) -> str: + assert fmt == "%Y%m%d-%H%M%S" + return "20250101-010203" + + return _DT() + + monkeypatch.setattr(ids, "datetime", _FakeDatetime) + monkeypatch.setattr(ids.secrets, "token_hex", lambda n: "abcd") + + assert ids.new_task_id("ppo", user_id="Alice_01") == "mvp2-alice_01-ppo-20250101-010203-abcd" + + def test_attempt_submission_id_format(): from argus.core.ids import attempt_submission_id assert attempt_submission_id("t", 1) == "t--a01" assert attempt_submission_id("t", 12) == "t--a12" - diff --git a/src/mvp/py/tests/test_scheduler.py b/src/mvp/py/tests/test_scheduler.py index 8fbcd69..15dc262 100644 --- a/src/mvp/py/tests/test_scheduler.py +++ b/src/mvp/py/tests/test_scheduler.py @@ -14,7 +14,7 @@ def _mk_cfg(tmp_path: Path) -> tuple[RayConfig, V2Config]: root = { "ray": { "address": "http://127.0.0.1:8265", - "shared_root": str(tmp_path), + "shared_root": "/private", "entrypoint_resources": {"worker_node": 1}, "runtime_env": {"env_vars": {}}, }, @@ -32,10 +32,11 @@ def test_tick_submits_one_task(monkeypatch, tmp_path: Path): ray_cfg, v2_cfg = _mk_cfg(tmp_path) db = Db(v2_cfg.sqlite.db_path) db.init() - db.create_task( + db.create_task_v25( task_id="t1", + user_id="alice", workload="ppo", - jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}), + jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}), nnodes=2, n_gpus_per_node=4, ) @@ -50,9 +51,11 @@ def test_tick_submits_one_task(monkeypatch, tmp_path: Path): class _Tool: def __init__(self, cfg): self.submitted = [] + self.job_dirs = [] - def submit(self, spec, no_wait: bool): + def submit(self, spec, no_wait: bool, job_dir: str | None = None): self.submitted.append(spec.submission_id) + self.job_dirs.append(job_dir) return str(spec.submission_id) def status(self, submission_id: str): @@ -71,6 +74,7 @@ def test_tick_submits_one_task(monkeypatch, tmp_path: Path): attempts = db.list_attempts("t1") assert len(attempts) == 1 assert attempts[0]["ray_submission_id"] == "t1--a01" + assert s.tool.job_dirs[-1] == "/private/users/alice/jobs/t1--a01" def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path): @@ -79,10 +83,11 @@ def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path): ray_cfg, v2_cfg = _mk_cfg(tmp_path) db = Db(v2_cfg.sqlite.db_path) db.init() - db.create_task( + db.create_task_v25( task_id="t1", + user_id="alice", workload="ppo", - jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}), + jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}), nnodes=2, n_gpus_per_node=4, ) @@ -108,10 +113,11 @@ def test_sync_failed_insufficient_resources(monkeypatch, tmp_path: Path): ray_cfg, v2_cfg = _mk_cfg(tmp_path) db = Db(v2_cfg.sqlite.db_path) db.init() - 
db.create_task( + db.create_task_v25( task_id="t1", + user_id="alice", workload="ppo", - jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}), + jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}), nnodes=2, n_gpus_per_node=4, ) @@ -147,10 +153,11 @@ def test_sync_status_error_keeps_state(monkeypatch, tmp_path: Path): ray_cfg, v2_cfg = _mk_cfg(tmp_path) db = Db(v2_cfg.sqlite.db_path) db.init() - db.create_task( + db.create_task_v25( task_id="t1", + user_id="alice", workload="ppo", - jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}), + jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/private/common/code/verl", "model_id": "m", "train_file": "/private/common/datasets/t"}), nnodes=2, n_gpus_per_node=4, ) diff --git a/src/mvp/py/tests/test_users.py b/src/mvp/py/tests/test_users.py new file mode 100644 index 0000000..70ca876 --- /dev/null +++ b/src/mvp/py/tests/test_users.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +from pathlib import Path + +import yaml +from fastapi.testclient import TestClient + + +def _write_config(tmp_path: Path) -> Path: + cfg = { + "ray": { + "address": "http://127.0.0.1:8265", + "shared_root": "/private", + "entrypoint_resources": {"worker_node": 1}, + "runtime_env": {"env_vars": {}}, + }, + "service": { + "api": {"host": "127.0.0.1", "port": 0}, + "auth": {"token_env": "MVP_INTERNAL_TOKEN"}, + "sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")}, + "scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1}, + }, + } + p = tmp_path / "cfg.yaml" + p.write_text(yaml.safe_dump(cfg), encoding="utf-8") + return p + + +def test_db_user_token_lifecycle(tmp_path: Path): + from argus.service.db import Db + + db = Db(str(tmp_path / "mvp.sqlite3")) + db.init() + db.create_user(user_id="alice", display_name="Alice") + tok = db.issue_token(user_id="alice") + info = db.resolve_token(tok) + assert info and info["user_id"] == "alice" + + db.disable_user(user_id="alice") + info2 = db.resolve_token(tok) + assert info2 and info2["state"] == "DISABLED" + + +def test_admin_create_user_issue_token_and_disabled_rejected(tmp_path: Path, monkeypatch): + from argus.service import app as app_mod + + cfg_path = _write_config(tmp_path) + monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1") + + class _Scheduler: + def __init__(self, **kwargs): + self.tool = object() + + def run_forever(self, stop_flag): + return None + + monkeypatch.setattr(app_mod, "Scheduler", _Scheduler) + app = app_mod.create_app(str(cfg_path)) + + admin_headers = {"authorization": "Bearer adm1"} + with TestClient(app) as c: + r1 = c.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice", "display_name": "Alice"}) + assert r1.status_code == 200 + assert r1.json()["user_id"] == "alice" + + r2 = c.post("/api/v2/users/alice/tokens", headers=admin_headers) + assert r2.status_code == 200 + token = r2.json()["token"] + assert token + + # non-admin token can access regular endpoints + r3 = c.get("/api/v2/queue", headers={"authorization": f"Bearer {token}"}) + assert r3.status_code == 200 + + r4 = c.post("/api/v2/users/alice:disable", headers=admin_headers) + assert r4.status_code == 200 + + r5 = c.get("/api/v2/queue", headers={"authorization": f"Bearer {token}"}) + assert r5.status_code == 403 + + +def test_tasks_are_isolated_by_user(tmp_path: Path, monkeypatch): + from 
argus.service import app as app_mod + + cfg_path = _write_config(tmp_path) + monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1") + + # Deterministic task ids + counter = {"n": 0} + + def _new_id(workload: str, **kwargs) -> str: + counter["n"] += 1 + return f"{workload}_t{counter['n']}" + + monkeypatch.setattr(app_mod, "new_task_id", _new_id) + + class _Scheduler: + def __init__(self, **kwargs): + self.tool = object() + + def run_forever(self, stop_flag): + return None + + monkeypatch.setattr(app_mod, "Scheduler", _Scheduler) + app = app_mod.create_app(str(cfg_path)) + + admin_headers = {"authorization": "Bearer adm1"} + with TestClient(app) as c: + # Create users and tokens + assert c.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice"}).status_code == 200 + assert c.post("/api/v2/users", headers=admin_headers, json={"user_id": "bob"}).status_code == 200 + alice_tok = c.post("/api/v2/users/alice/tokens", headers=admin_headers).json()["token"] + bob_tok = c.post("/api/v2/users/bob/tokens", headers=admin_headers).json()["token"] + + alice_headers = {"authorization": f"Bearer {alice_tok}"} + bob_headers = {"authorization": f"Bearer {bob_tok}"} + + # Each user submits one task. + r1 = c.post( + "/api/v2/tasks", + headers=alice_headers, + data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n", + ) + assert r1.status_code == 200 + alice_tid = r1.json()["task_id"] + + r2 = c.post( + "/api/v2/tasks", + headers=bob_headers, + data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\n", + ) + assert r2.status_code == 200 + bob_tid = r2.json()["task_id"] + + # Queue is scoped to user. + qa = c.get("/api/v2/queue", headers=alice_headers).json() + qb = c.get("/api/v2/queue", headers=bob_headers).json() + assert {t["task_id"] for t in qa["pending"]} == {alice_tid} + assert {t["task_id"] for t in qb["pending"]} == {bob_tid} + + # Cross-user access returns 404 (no existence leak). + assert c.get(f"/api/v2/tasks/{bob_tid}", headers=alice_headers).status_code == 404 + assert c.post(f"/api/v2/tasks/{bob_tid}:cancel", headers=alice_headers).status_code == 404 + assert c.get(f"/api/v2/tasks/{bob_tid}/attempts", headers=alice_headers).status_code == 404 + + # Admin can see global queue. 
+ qadm = c.get("/api/v2/queue", headers=admin_headers).json() + assert {t["task_id"] for t in qadm["pending"]} == {alice_tid, bob_tid} + + +def test_submit_rejects_non_common_inputs(tmp_path: Path, monkeypatch): + from argus.service import app as app_mod + + cfg_path = _write_config(tmp_path) + monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1") + + class _Scheduler: + def __init__(self, **kwargs): + self.tool = object() + + def run_forever(self, stop_flag): + return None + + monkeypatch.setattr(app_mod, "Scheduler", _Scheduler) + app = app_mod.create_app(str(cfg_path)) + + admin_headers = {"authorization": "Bearer adm1"} + with TestClient(app) as c: + assert c.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice"}).status_code == 200 + alice_tok = c.post("/api/v2/users/alice/tokens", headers=admin_headers).json()["token"] + alice_headers = {"authorization": f"Bearer {alice_tok}"} + + r = c.post( + "/api/v2/tasks", + headers=alice_headers, + data="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: /private/common/datasets/t\n", + ) + assert r.status_code == 400 + assert "code_path must start with /private/common/" in r.text diff --git a/src/mvp/py/tests/test_worker_watchdog.py b/src/mvp/py/tests/test_worker_watchdog.py new file mode 100644 index 0000000..88cb8ed --- /dev/null +++ b/src/mvp/py/tests/test_worker_watchdog.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import json +from pathlib import Path + +from argus.ray.discovery import build_head_record, write_head_record_atomic +from argus.ray.worker_watchdog import Watchdog + + +def test_watchdog_restarts_on_first_seen_and_on_change(tmp_path: Path): + head_file = tmp_path / "head.json" + calls: list[list[str]] = [] + + def runner(argv: list[str]) -> int: + calls.append(argv) + return 0 + + wd = Watchdog( + head_ip_file=str(head_file), + node_ip="10.0.0.2", + resources_json=json.dumps({"worker_node": 100}), + poll_s=1, + runner=runner, + ) + + write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="1.1.1.1")) + assert wd.tick_once() is not None + assert any("ray" in c[0] and c[1] == "start" for c in calls) + + calls.clear() + # Same address -> no restart + write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="1.1.1.1")) + wd.tick_once() + assert calls == [] + + # Address change -> restart + write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="2.2.2.2")) + wd.tick_once() + assert any(c[1] == "stop" for c in calls) + assert any(c[1] == "start" for c in calls) + + +def test_watchdog_main_once_invokes_runner(monkeypatch, tmp_path: Path): + from argus.ray import worker_watchdog as mod + + head_file = tmp_path / "head.json" + write_head_record_atomic(str(head_file), build_head_record(cluster_name="c", head_ip="1.1.1.1")) + + calls: list[list[str]] = [] + + def runner(argv: list[str]) -> int: + calls.append(argv) + return 0 + + monkeypatch.setattr(mod, "_default_runner", runner) + rc = mod.main(["--head-ip-file", str(head_file), "--node-ip", "10.0.0.2", "--resources-kv", "worker_node=100", "--once"]) + assert rc == 0 + assert any(c[1] == "start" for c in calls) diff --git a/src/mvp/scripts/04_cleanup_v2_legacy.sh b/src/mvp/scripts/04_cleanup_v2_legacy.sh new file mode 100755 index 0000000..35972fe --- /dev/null +++ b/src/mvp/scripts/04_cleanup_v2_legacy.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +set -euo pipefail + +echo "[host] cleanup v2 legacy containers/processes (best-effort)" + +# Known historical container names from 
v1.1/v2 era +LEGACY=( + mvp11-ray-head + mvp11-ray-worker-0 + mvp11-ray-worker-1 + mvp2-ray-head + mvp2-ray-worker-0 + mvp2-ray-worker-1 +) + +for c in "${LEGACY[@]}"; do + if docker ps -a --format '{{.Names}}' | grep -qx "${c}"; then + echo "[host] removing legacy container: ${c}" + docker rm -f "${c}" >/dev/null 2>&1 || true + fi +done + +echo "[host] legacy v2 cleanup done" + diff --git a/src/mvp/scripts/22_start_head_discovery.sh b/src/mvp/scripts/22_start_head_discovery.sh new file mode 100755 index 0000000..a72c55d --- /dev/null +++ b/src/mvp/scripts/22_start_head_discovery.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +# shellcheck source=lib.sh +source "${SCRIPT_DIR}/lib.sh" + +CLUSTER_NAME="${CLUSTER_NAME:-argus-ray}" +HEAD_IP_FILE="${HEAD_IP_FILE:-${SHARED_ROOT}/ray/discovery/${CLUSTER_NAME}/head.json}" +TTL_S="${TTL_S:-60}" +REFRESH_S="${REFRESH_S:-10}" + +# PID file must be container-local to avoid conflicts if /private is shared across containers. +PID_PATH="${PID_PATH:-/tmp/argus_head_publisher.pid}" +LOG_PATH="${LOG_PATH:-${SHARED_ROOT}/common/logs/argus_head_publisher.log}" + +HEAD_IP="$(container_ip "${HEAD_CONTAINER}")" + +echo "[head] start discovery publisher (supervised): ${HEAD_IP_FILE} (head_ip=${HEAD_IP})" + +# stop existing (best-effort) +dexec "${HEAD_CONTAINER}" bash -lc "if test -f '${PID_PATH}'; then pid=\$(cat '${PID_PATH}'); if kill -0 \"\${pid}\" >/dev/null 2>&1; then kill \"\${pid}\" || true; fi; rm -f '${PID_PATH}'; fi" + +dexec "${HEAD_CONTAINER}" bash -lc "mkdir -p \"$(dirname "${LOG_PATH}")\" \"$(dirname "${HEAD_IP_FILE}")\"" + +# Supervisor loop: restart publisher if it exits. +docker exec -d "${HEAD_CONTAINER}" bash -lc " + nohup bash -lc ' + export PYTHONPATH=/workspace/mvp/py + while true; do + python3 -m argus.ray.head_publisher \ + --cluster-name \"${CLUSTER_NAME}\" \ + --head-ip-file \"${HEAD_IP_FILE}\" \ + --head-ip \"${HEAD_IP}\" \ + --ttl-s \"${TTL_S}\" \ + --refresh-s \"${REFRESH_S}\" + sleep 2 + done + ' >>'${LOG_PATH}' 2>&1 & echo \$! 
diff --git a/src/mvp/scripts/22_start_head_discovery.sh b/src/mvp/scripts/22_start_head_discovery.sh
new file mode 100755
index 0000000..a72c55d
--- /dev/null
+++ b/src/mvp/scripts/22_start_head_discovery.sh
@@ -0,0 +1,44 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+# shellcheck source=lib.sh
+source "${SCRIPT_DIR}/lib.sh"
+
+CLUSTER_NAME="${CLUSTER_NAME:-argus-ray}"
+HEAD_IP_FILE="${HEAD_IP_FILE:-${SHARED_ROOT}/ray/discovery/${CLUSTER_NAME}/head.json}"
+TTL_S="${TTL_S:-60}"
+REFRESH_S="${REFRESH_S:-10}"
+
+# PID file must be container-local to avoid conflicts if /private is shared across containers.
+PID_PATH="${PID_PATH:-/tmp/argus_head_publisher.pid}"
+LOG_PATH="${LOG_PATH:-${SHARED_ROOT}/common/logs/argus_head_publisher.log}"
+
+HEAD_IP="$(container_ip "${HEAD_CONTAINER}")"
+
+echo "[head] start discovery publisher (supervised): ${HEAD_IP_FILE} (head_ip=${HEAD_IP})"
+
+# stop existing (best-effort)
+dexec "${HEAD_CONTAINER}" bash -lc "if test -f '${PID_PATH}'; then pid=\$(cat '${PID_PATH}'); if kill -0 \"\${pid}\" >/dev/null 2>&1; then kill \"\${pid}\" || true; fi; rm -f '${PID_PATH}'; fi"
+
+dexec "${HEAD_CONTAINER}" bash -lc "mkdir -p \"$(dirname "${LOG_PATH}")\" \"$(dirname "${HEAD_IP_FILE}")\""
+
+# Supervisor loop: restart publisher if it exits.
+docker exec -d "${HEAD_CONTAINER}" bash -lc "
+  nohup bash -lc '
+    export PYTHONPATH=/workspace/mvp/py
+    while true; do
+      python3 -m argus.ray.head_publisher \
+        --cluster-name \"${CLUSTER_NAME}\" \
+        --head-ip-file \"${HEAD_IP_FILE}\" \
+        --head-ip \"${HEAD_IP}\" \
+        --ttl-s \"${TTL_S}\" \
+        --refresh-s \"${REFRESH_S}\"
+      sleep 2
+    done
+  ' >>'${LOG_PATH}' 2>&1 & echo \$! >'${PID_PATH}'
+"
+
+echo "[head] publisher pid stored in ${PID_PATH} (container-local)"
+echo "[head] logs: ${LOG_PATH}"
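For reference, a sketch of what the publisher plausibly writes, and of the atomic-write pattern that keeps readers on the shared filesystem from ever seeing a torn head.json. The `_sketch` helper names and the record fields beyond `cluster_name`/`head_ip` (`updated_at`, `ttl_s`) are assumptions, not the real `argus.ray.discovery` API.

```python
from __future__ import annotations

import json
import os
import time


def build_head_record_sketch(cluster_name: str, head_ip: str, ttl_s: int = 60) -> dict:
    return {
        "cluster_name": cluster_name,
        "head_ip": head_ip,
        "updated_at": time.time(),  # refreshed every --refresh-s seconds
        "ttl_s": ttl_s,             # readers treat older records as expired
    }


def write_head_record_atomic_sketch(path: str, record: dict) -> None:
    # Write to a temp file, then rename: readers never observe a partial file.
    tmp = f"{path}.tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(record, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, path)
```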
diff --git a/src/mvp/scripts/23_start_workers_stateless.sh b/src/mvp/scripts/23_start_workers_stateless.sh
new file mode 100755
index 0000000..7f1a80c
--- /dev/null
+++ b/src/mvp/scripts/23_start_workers_stateless.sh
@@ -0,0 +1,52 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+# shellcheck source=lib.sh
+source "${SCRIPT_DIR}/lib.sh"
+
+CLUSTER_NAME="${CLUSTER_NAME:-argus-ray}"
+HEAD_IP_FILE="${HEAD_IP_FILE:-${SHARED_ROOT}/ray/discovery/${CLUSTER_NAME}/head.json}"
+POLL_S="${POLL_S:-5}"
+WORKER_NODE_RESOURCE="${WORKER_NODE_RESOURCE:-100}"
+
+start_one() {
+  local worker="$1"
+  local ip
+  ip="$(container_ip "${worker}")"
+
+  local pid_path="/tmp/argus_worker_watchdog.pid"
+  local log_path="${SHARED_ROOT}/common/logs/argus_worker_watchdog.${worker}.log"
+
+  echo "[${worker}] start stateless watchdog (supervised): head_file=${HEAD_IP_FILE} node_ip=${ip}"
+
+  # stop existing watchdog (best-effort)
+  dexec "${worker}" bash -lc "if test -f '${pid_path}'; then pid=\$(cat '${pid_path}'); if kill -0 \"\${pid}\" >/dev/null 2>&1; then kill \"\${pid}\" || true; fi; rm -f '${pid_path}'; fi"
+
+  # stop any legacy ray process to avoid split-brain
+  dexec "${worker}" bash -lc "ray stop --force || true"
+  dexec "${worker}" bash -lc "mkdir -p \"$(dirname "${log_path}")\""
+
+  docker exec -d "${worker}" bash -lc "
+    nohup bash -lc '
+      export PYTHONPATH=/workspace/mvp/py
+      while true; do
+        python3 -m argus.ray.worker_watchdog \
+          --head-ip-file \"${HEAD_IP_FILE}\" \
+          --node-ip \"${ip}\" \
+          --resources-kv \"worker_node=${WORKER_NODE_RESOURCE}\" \
+          --poll-s \"${POLL_S}\"
+        sleep 2
+      done
+    ' >>'${log_path}' 2>&1 & echo \$! >'${pid_path}'
+  "
+
+  echo "[${worker}] watchdog pid stored in ${pid_path} (container-local)"
+  echo "[${worker}] logs: ${log_path}"
+}
+
+start_one "${WORKER0_CONTAINER}"
+start_one "${WORKER1_CONTAINER}"
+
+echo "[head] ray status"
+dexec "${HEAD_CONTAINER}" bash -lc "ray status || true"
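Note that the script passes `--resources-kv "worker_node=100"` while the `Watchdog` class (per the tests) takes a `resources_json` string. A tiny illustrative converter, assumed rather than the module's actual parser, shows the intended mapping onto the JSON that `ray start --resources` accepts:

```python
import json


def resources_kv_to_json(kv: str) -> str:
    # "worker_node=100" -> '{"worker_node": 100.0}'
    name, _, value = kv.partition("=")
    if not name or not value:
        raise ValueError(f"expected NAME=NUMBER, got: {kv!r}")
    return json.dumps({name: float(value)})


assert resources_kv_to_json("worker_node=100") == '{"worker_node": 100.0}'
```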
diff --git a/src/mvp/scripts/24_status_stateless.sh b/src/mvp/scripts/24_status_stateless.sh
new file mode 100755
index 0000000..2b5d30e
--- /dev/null
+++ b/src/mvp/scripts/24_status_stateless.sh
@@ -0,0 +1,17 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+# shellcheck source=lib.sh
+source "${SCRIPT_DIR}/lib.sh"
+
+CLUSTER_NAME="${CLUSTER_NAME:-argus-ray}"
+HEAD_IP_FILE="${HEAD_IP_FILE:-${SHARED_ROOT}/ray/discovery/${CLUSTER_NAME}/head.json}"
+
+echo "[head] head.json (best-effort): ${HEAD_IP_FILE}"
+dexec "${HEAD_CONTAINER}" bash -lc "if test -f '${HEAD_IP_FILE}'; then cat '${HEAD_IP_FILE}'; else echo 'missing'; fi"
+
+echo
+echo "[head] ray status"
+dexec "${HEAD_CONTAINER}" bash -lc "ray status || true"
diff --git a/src/mvp/scripts/30_prepare_data_and_model.sh b/src/mvp/scripts/30_prepare_data_and_model.sh
index 157e743..4c4a79f 100755
--- a/src/mvp/scripts/30_prepare_data_and_model.sh
+++ b/src/mvp/scripts/30_prepare_data_and_model.sh
@@ -11,10 +11,30 @@
 PPO_DATA_DIR="${SHARED_ROOT}/datasets/gsm8k"
 SFT_DATA_DIR="${SHARED_ROOT}/datasets/gsm8k_sft"
 CODE_SNAPSHOT_DIR="${SHARED_ROOT}/common/code/verl/verl_repo"
+COMMON_DIR="${SHARED_ROOT}/common"
+COMMON_DATASETS_DIR="${SHARED_ROOT}/common/datasets"
+COMMON_HF_DIR="${SHARED_ROOT}/common/hf"
 
 echo "[head] ensure dataset dirs exist"
 dexec "${HEAD_CONTAINER}" bash -lc "mkdir -p '${PPO_DATA_DIR}' '${SFT_DATA_DIR}'"
 
+echo "[head] ensure v2.5 common links (idempotent)"
+# In existing deployments, /private/common/{datasets,hf} may already exist as directories (not symlinks).
+# For v2.5 taskspecs, we only require:
+#   /private/common/datasets/gsm8k     -> /private/datasets/gsm8k
+#   /private/common/datasets/gsm8k_sft -> /private/datasets/gsm8k_sft
+#   /private/common/hf/{hub,transformers} -> /private/hf/{hub,transformers} (best-effort)
+dexec "${HEAD_CONTAINER}" bash -lc "
+  set -euo pipefail
+  mkdir -p '${COMMON_DIR}' '${COMMON_DATASETS_DIR}' '${COMMON_HF_DIR}'
+  ln -sfn '${SHARED_ROOT}/datasets/gsm8k' '${COMMON_DATASETS_DIR}/gsm8k'
+  ln -sfn '${SHARED_ROOT}/datasets/gsm8k_sft' '${COMMON_DATASETS_DIR}/gsm8k_sft'
+  mkdir -p '${SHARED_ROOT}/hf/hub' '${SHARED_ROOT}/hf/transformers'
+  ln -sfn '${SHARED_ROOT}/hf/hub' '${COMMON_HF_DIR}/hub'
+  ln -sfn '${SHARED_ROOT}/hf/transformers' '${COMMON_HF_DIR}/transformers'
+  echo 'common_links_ok'
+"
+
 echo "[head] prepare PPO dataset (gsm8k RL parquet) -> ${PPO_DATA_DIR}"
 dexec "${HEAD_CONTAINER}" bash -lc "if [[ -f '${PPO_DATA_DIR}/train.parquet' && -f '${PPO_DATA_DIR}/test.parquet' ]]; then echo 'ppo_dataset_exists: skip'; else python3 /workspace/verl/examples/data_preprocess/gsm8k.py --local_save_dir '${PPO_DATA_DIR}'; fi"
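After this hunk runs, the common links can be sanity-checked from inside the head container. The snippet below is an optional check (assuming `SHARED_ROOT=/private`, as the comments above imply), not part of the scripts:

```python
import os

# Expected v2.5 common links and their legacy targets.
EXPECTED = {
    "/private/common/datasets/gsm8k": "/private/datasets/gsm8k",
    "/private/common/datasets/gsm8k_sft": "/private/datasets/gsm8k_sft",
    "/private/common/hf/hub": "/private/hf/hub",
    "/private/common/hf/transformers": "/private/hf/transformers",
}

for link, target in EXPECTED.items():
    got = os.path.realpath(link)
    ok = got == os.path.realpath(target)
    print(f"{link} -> {got} ({'ok' if ok else 'MISMATCH'})")
```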
(${i}/${tries})" + sleep 2 + done + echo "ERROR: api not ready: ${API_ADDR}" >&2 + return 1 +} + +submit_taskspec() { + local token="$1" + local taskspec_path="$2" + echo "[host] submit via API (user=${USER_ID}): ${taskspec_path}" >&2 + local resp + resp="$(curl -sS -H "Authorization: Bearer ${token}" -H "Content-Type: application/yaml" --data-binary @"${taskspec_path}" "${API_ADDR}/api/v2/tasks")" + echo "[host] submit_resp: ${resp}" >&2 + printf '%s' "${resp}" | python3 -c 'import sys,json; print(json.load(sys.stdin)["task_id"])' +} + +wait_task() { + local token="$1" + local task_id="$2" + while true; do + local body state + body="$(curl -sS -H "Authorization: Bearer ${token}" "${API_ADDR}/api/v2/tasks/${task_id}")" + state="$(printf '%s' "${body}" | python3 -c 'import sys,json; print(json.load(sys.stdin)["state"])')" + echo "[host] task ${task_id}: ${state}" + + if [[ "${state}" == "SUCCEEDED" ]]; then + return 0 + fi + if [[ "${state}" == "FAILED" || "${state}" == "CANCELED" ]]; then + echo "[host] terminal=${state}; tail logs (best-effort):" >&2 + curl -sS -H "Authorization: Bearer ${token}" "${API_ADDR}/api/v2/tasks/${task_id}/logs?tail=200" >&2 || true + return 1 + fi + sleep 10 + done +} + +echo "[host] ===== run_all_v25_api.sh begin =====" + +"${SCRIPT_DIR}/00_prereq_check.sh" +"${SCRIPT_DIR}/03_cleanup_v1_legacy.sh" +"${SCRIPT_DIR}/04_cleanup_v2_legacy.sh" + +echo "[host] bring down existing containers (best-effort)" +"${SCRIPT_DIR}/02_down.sh" || true + +echo "[host] (re)create containers" +"${SCRIPT_DIR}/01_up.sh" + +echo "[host] restart ray head (no compute on head)" +"${SCRIPT_DIR}/20_start_head.sh" + +echo "[host] start head discovery publisher" +"${SCRIPT_DIR}/22_start_head_discovery.sh" + +echo "[host] start stateless workers (watchdog auto-connect)" +"${SCRIPT_DIR}/23_start_workers_stateless.sh" + +echo "[host] prepare data/model/code snapshot (idempotent; reuse cache)" +"${SCRIPT_DIR}/30_prepare_data_and_model.sh" + +echo "[host] install api deps in head container" +"${SCRIPT_DIR}/12_install_api_deps.sh" + +echo "[host] stop api (best-effort)" +"${SCRIPT_DIR}/61_stop_api.sh" || true + +if [[ "${RESET_DB}" == "1" ]]; then + echo "[host] reset api sqlite db in container (best-effort)" + docker exec -i "${HEAD_CONTAINER}" bash -lc "rm -f /private/common/db/mvp.sqlite3 /private/common/db/mvp.sqlite3-wal /private/common/db/mvp.sqlite3-shm || true" +else + echo "[host] keep existing api sqlite db (RESET_DB=${RESET_DB})" +fi + +echo "[host] start api (admin token via env)" +MVP_INTERNAL_TOKEN="${ADMIN_TOKEN}" "${SCRIPT_DIR}/60_start_api.sh" +api_wait_ready 60 + +echo "[host] ensure user exists + issue token" +api_curl_admin -H "Content-Type: application/json" -d "{\"user_id\":\"${USER_ID}\"}" "${API_ADDR}/api/v2/users" >/dev/null 2>&1 || true +USER_TOKEN="$(api_curl_admin -X POST "${API_ADDR}/api/v2/users/${USER_ID}/tokens" | python3 -c 'import sys,json; print(json.load(sys.stdin)["token"])')" +echo "[host] user_token_issued: user=${USER_ID}" + +PPO_TASK_ID="$(submit_taskspec "${USER_TOKEN}" "${ROOT_DIR}/taskspecs/ppo.yaml")" +GRPO_TASK_ID="$(submit_taskspec "${USER_TOKEN}" "${ROOT_DIR}/taskspecs/grpo.yaml")" +SFT_TASK_ID="$(submit_taskspec "${USER_TOKEN}" "${ROOT_DIR}/taskspecs/sft.yaml")" + +echo "[host] submitted task ids:" +echo " ppo=${PPO_TASK_ID}" +echo " grpo=${GRPO_TASK_ID}" +echo " sft=${SFT_TASK_ID}" + +echo "[host] wait for tasks (in submission order)" +wait_task "${USER_TOKEN}" "${PPO_TASK_ID}" +wait_task "${USER_TOKEN}" "${GRPO_TASK_ID}" +wait_task 
"${USER_TOKEN}" "${SFT_TASK_ID}" + +echo "[host] ===== run_all_v25_api.sh done =====" diff --git a/src/mvp/scripts/run_e2e_v25_cases.sh b/src/mvp/scripts/run_e2e_v25_cases.sh new file mode 100755 index 0000000..76c9952 --- /dev/null +++ b/src/mvp/scripts/run_e2e_v25_cases.sh @@ -0,0 +1,156 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +# shellcheck source=lib.sh +source "${SCRIPT_DIR}/lib.sh" + +API_ADDR="${API_ADDR:-http://127.0.0.1:8080}" +ADMIN_TOKEN="${MVP_INTERNAL_TOKEN:-}" + +if [[ -z "${ADMIN_TOKEN}" ]]; then + echo "ERROR: MVP_INTERNAL_TOKEN must be set (admin token)" >&2 + exit 1 +fi + +require_cmd jq +require_cmd python3 + +api_admin() { curl -sS -H "Authorization: Bearer ${ADMIN_TOKEN}" "$@"; } +api_user() { local tok="$1"; shift; curl -sS -H "Authorization: Bearer ${tok}" "$@"; } + +wait_state() { + local tok="$1" + local task_id="$2" + local want="$3" + local tries="${4:-120}" + for i in $(seq 1 "${tries}"); do + local st + st="$(api_user "${tok}" "${API_ADDR}/api/v2/tasks/${task_id}" | jq -r .state)" + echo "[case] ${task_id} state=${st} (want=${want})" + if [[ "${st}" == "${want}" ]]; then + return 0 + fi + if [[ "${st}" == "FAILED" || "${st}" == "CANCELED" || "${st}" == "SUCCEEDED" ]]; then + return 1 + fi + sleep 5 + done + return 2 +} + +issue_token() { + local user_id="$1" + api_admin -H "Content-Type: application/json" -d "{\"user_id\":\"${user_id}\"}" "${API_ADDR}/api/v2/users" >/dev/null 2>&1 || true + api_admin -X POST "${API_ADDR}/api/v2/users/${user_id}/tokens" | jq -r .token +} + +submit_yaml() { + local tok="$1" + local yaml_path="$2" + api_user "${tok}" -H "Content-Type: application/yaml" --data-binary @"${yaml_path}" "${API_ADDR}/api/v2/tasks" | jq -r .task_id +} + +echo "[case] ===== v2.5 e2e cases =====" + +echo "[case] HP-1: run_all_v25_api.sh (happy path)" +RESET_DB="${RESET_DB:-1}" +MVP_INTERNAL_TOKEN="${ADMIN_TOKEN}" RESET_DB="${RESET_DB}" "${SCRIPT_DIR}/run_all_v25_api.sh" + +echo "[case] E-Auth-1: missing bearer token" +code="$(curl -sS -o /dev/null -w '%{http_code}' "${API_ADDR}/api/v2/queue" || true)" +[[ "${code}" == "401" ]] || { echo "expected 401 got ${code}" >&2; exit 1; } + +echo "[case] E-Auth-2: invalid token" +code="$(curl -sS -o /dev/null -w '%{http_code}' -H 'Authorization: Bearer nope' "${API_ADDR}/api/v2/queue" || true)" +[[ "${code}" == "401" ]] || { echo "expected 401 got ${code}" >&2; exit 1; } + +echo "[case] setup users alice/bob" +ALICE_TOKEN="$(issue_token alice)" +BOB_TOKEN="$(issue_token bob)" + +echo "[case] E-Isolation-1: cross-user task visibility (404)" +TID="$(submit_yaml "${ALICE_TOKEN}" "${ROOT_DIR}/taskspecs/ppo.yaml")" +code="$(api_user "${BOB_TOKEN}" -o /dev/null -w '%{http_code}' "${API_ADDR}/api/v2/tasks/${TID}" || true)" +[[ "${code}" == "404" ]] || { echo "expected 404 got ${code}" >&2; exit 1; } + +# Keep the queue clean for subsequent cases (avoid consuming scheduler capacity). 
+
+echo "[case] E-Input-1: reject non-common inputs"
+bad_yaml="$(mktemp)"
+cat >"${bad_yaml}" <<'YAML'
+workload: "ppo"
+submission_id: ""
+code_path: "/private/common/code/verl/verl_repo"
+model_id: "Qwen/Qwen2.5-0.5B-Instruct"
+train_file: "/private/datasets/gsm8k/train.parquet"
+val_file: "/private/common/datasets/gsm8k/test.parquet"
+nnodes: 2
+n_gpus_per_node: 4
+total_epochs: 1
+total_training_steps: 10
+save_freq: 10
+test_freq: -1
+YAML
+code="$(api_user "${ALICE_TOKEN}" -H 'Content-Type: application/yaml' --data-binary @"${bad_yaml}" -o /dev/null -w '%{http_code}' "${API_ADDR}/api/v2/tasks" || true)"
+rm -f "${bad_yaml}"
+[[ "${code}" == "400" ]] || { echo "expected 400 got ${code}" >&2; exit 1; }
+
+echo "[case] B-Queue-1: pending_resources when demand > cluster capacity"
+big_yaml="$(mktemp)"
+cat >"${big_yaml}" <<'YAML'
+workload: "ppo"
+submission_id: ""
+code_path: "/private/common/code/verl/verl_repo"
+model_id: "Qwen/Qwen2.5-0.5B-Instruct"
+train_file: "/private/common/datasets/gsm8k/train.parquet"
+val_file: "/private/common/datasets/gsm8k/test.parquet"
+nnodes: 3
+n_gpus_per_node: 4
+total_epochs: 1
+total_training_steps: 10
+save_freq: 10
+test_freq: -1
+YAML
+BIG_TID="$(api_user "${ALICE_TOKEN}" -H 'Content-Type: application/yaml' --data-binary @"${big_yaml}" "${API_ADDR}/api/v2/tasks" | jq -r .task_id)"
+rm -f "${big_yaml}"
+wait_state "${ALICE_TOKEN}" "${BIG_TID}" "PENDING_RESOURCES" 60
+
+echo "[case] B-Cancel-1: cancel a running task (best-effort)"
+long_yaml="$(mktemp)"
+cat >"${long_yaml}" <<'YAML'
+workload: "ppo"
+submission_id: ""
+code_path: "/private/common/code/verl/verl_repo"
+model_id: "Qwen/Qwen2.5-0.5B-Instruct"
+train_file: "/private/common/datasets/gsm8k/train.parquet"
+val_file: "/private/common/datasets/gsm8k/test.parquet"
+nnodes: 2
+n_gpus_per_node: 4
+total_epochs: 1
+total_training_steps: 200
+save_freq: 10
+test_freq: -1
+YAML
+LONG_TID="$(api_user "${ALICE_TOKEN}" -H 'Content-Type: application/yaml' --data-binary @"${long_yaml}" "${API_ADDR}/api/v2/tasks" | jq -r .task_id)"
+rm -f "${long_yaml}"
+
+for i in $(seq 1 60); do
+  st="$(api_user "${ALICE_TOKEN}" "${API_ADDR}/api/v2/tasks/${LONG_TID}" | jq -r .state)"
+  echo "[case] wait running: ${LONG_TID} state=${st}"
+  if [[ "${st}" == "RUNNING" ]]; then
+    break
+  fi
+  if [[ "${st}" == "FAILED" || "${st}" == "SUCCEEDED" ]]; then
+    break
+  fi
+  sleep 5
+done
+
+api_user "${ALICE_TOKEN}" -X POST "${API_ADDR}/api/v2/tasks/${LONG_TID}:cancel" | jq .
+st="$(api_admin "${API_ADDR}/api/v2/tasks/${LONG_TID}" | jq -r .state)" +[[ "${st}" == "CANCELED" ]] || { echo "expected CANCELED got ${st}" >&2; exit 1; } + +echo "[case] ===== all v2.5 e2e cases passed =====" diff --git a/src/mvp/taskspecs/grpo.yaml b/src/mvp/taskspecs/grpo.yaml index 83eba66..583d93c 100644 --- a/src/mvp/taskspecs/grpo.yaml +++ b/src/mvp/taskspecs/grpo.yaml @@ -6,8 +6,8 @@ code_path: "/private/common/code/verl/verl_repo" model_id: "Qwen/Qwen2.5-0.5B-Instruct" -train_file: "/private/datasets/gsm8k/train.parquet" -val_file: "/private/datasets/gsm8k/test.parquet" +train_file: "/private/common/datasets/gsm8k/train.parquet" +val_file: "/private/common/datasets/gsm8k/test.parquet" nnodes: 2 n_gpus_per_node: 4 @@ -17,4 +17,3 @@ total_training_steps: 10 save_freq: 10 test_freq: -1 - diff --git a/src/mvp/taskspecs/ppo.yaml b/src/mvp/taskspecs/ppo.yaml index 05bc7f9..f748a0d 100644 --- a/src/mvp/taskspecs/ppo.yaml +++ b/src/mvp/taskspecs/ppo.yaml @@ -8,8 +8,8 @@ code_path: "/private/common/code/verl/verl_repo" model_id: "Qwen/Qwen2.5-0.5B-Instruct" -train_file: "/private/datasets/gsm8k/train.parquet" -val_file: "/private/datasets/gsm8k/test.parquet" +train_file: "/private/common/datasets/gsm8k/train.parquet" +val_file: "/private/common/datasets/gsm8k/test.parquet" nnodes: 2 n_gpus_per_node: 4 @@ -19,4 +19,3 @@ total_training_steps: 10 save_freq: 10 test_freq: -1 - diff --git a/src/mvp/taskspecs/sft.yaml b/src/mvp/taskspecs/sft.yaml index 67637b6..5491707 100644 --- a/src/mvp/taskspecs/sft.yaml +++ b/src/mvp/taskspecs/sft.yaml @@ -6,7 +6,7 @@ code_path: "/private/common/code/verl/verl_repo" model_id: "Qwen/Qwen2.5-0.5B-Instruct" -train_file: "/private/datasets/gsm8k_sft/train.parquet" +train_file: "/private/common/datasets/gsm8k_sft/train.parquet" val_file: null nnodes: 2 @@ -19,4 +19,3 @@ save_freq: 10 # SFT driver 默认不分配 GPU(ray job entrypoint 不指定 entrypoint_num_gpus),因此 driver 侧不要依赖 CUDA trainer_device: "cpu" -