v2.0: add unit tests, line coverage above 90%
This commit is contained in:
parent
4dacac24f0
commit
ce8c2128b4
4
.gitignore
vendored
@ -2,3 +2,7 @@ verl/
skypilot-ssh-test/
ray_in_docker/
__pycache__/
.venv/
.pytest_cache/
.coverage
htmlcov/
3
pytest.ini
Normal file
@ -0,0 +1,3 @@
[pytest]
testpaths = src/mvp/py/tests
addopts = --maxfail=1 --cov=argus --cov=server --cov-report=term-missing --cov-fail-under=90
BIN
specs/mvp/image/roadmap_v2.5.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 80 KiB |
@ -1,96 +1,49 @@
|
||||
渐进的 AI Infrastructure 演进路线图。从最初的单机脚本执行,到最终的智能化运维平台
|
||||
|
||||
对应架构演进图,设计**基于 Native Ray Cluster 与 Verl 框架的 AI Infra Roadmap 设计文档**。
|
||||
这一版的设计采用了 **Overlay 架构 + GPFS 核心存储 + 无状态(Stateless)节点池** 的模式,逻辑非常自洽且具备极高的云原生弹性。
|
||||
|
||||
---
|
||||
|
||||
### **项目代号:AI Infra Roadmap (Native Ray + Verl)**
|
||||
### **项目代号:AI Infra Overlay Platform (Stateless Ray + GPFS)**
|
||||
|
||||
#### **阶段一:核心内核构建 (Foundation & Core Execution)**
|
||||
#### **阶段一:内核构建与验证 (Kernel & Verification)**
|
||||
|
||||
这一阶段主要解决“能不能跑”的问题,聚焦于核心计算引擎的对接和基础任务调度。
|
||||
*目标:验证核心计算逻辑,跑通“提交-执行”的最小闭环。*
|
||||
|
||||
* **v1.1: 原型验证 (Verl Task Spec & Ray Job)**
|
||||
* **核心功能**:实现了最基础的任务提交链路。
|
||||
* **组件**:
|
||||
* **Ray Job Tool (Ray Client)**:作为客户端工具。
|
||||
* **VerlTaskSpec YAML**:定义任务的标准配置文件。
|
||||
* **Multi-Verl Code Path**:支持多代码路径。
|
||||
* **核心功能**:实现基础的任务定义与提交。
|
||||
* **组件**:
|
||||
* `Ray Job Tool (Ray Client)`:客户端工具。
|
||||
* `VerlTaskSpec YAML`:定义多代码路径 (Multi-Verl Code Path) 和任务参数。
|
||||
|
||||
* **基础设施**:Handmade Ray Cluster(手工搭建的 Ray 集群)。
|
||||
* **目标**:验证 Verl 框架与 Ray 的基本交互。
|
||||
|
||||
* **基础设施**:Handmade Ray Cluster(手工搭建的集群),用于验证核心代码。
|
||||
|
||||
|
||||
* **v2.0: 任务管理层 (Task Management)**
|
||||
* **核心功能**:引入了服务化管理,不再单纯依赖命令行工具。
|
||||
* **新增组件**:
|
||||
* **API Server**:提供统一的接口层。
|
||||
* **Task Management**:实现了任务队列 (Queue)、映射 (Map) 和重试/重新提交 (Resubmit) 机制。
|
||||
* **核心功能**:引入服务端,管理任务生命周期。
|
||||
* **新增组件**:
|
||||
* `API Server`:统一接口层。
|
||||
* `Task Management`:实现任务的队列 (Queue)、映射 (Map) 和重试 (Resubmit) 机制。
|
||||
|
||||
|
||||
* **基础设施**:仍运行在 Handmade Ray Cluster 上。
|
||||
|
||||
|
||||
* **v2.5: 资源与用户管理 (User & Node Management)**
|
||||
* **核心功能**:从“手工集群”迈向“自动化集群”,并增加了多租户基础。
|
||||
* **新增组件**:
|
||||
* **User Management**:用户权限与身份管理。
|
||||
* **Node Management**:核心升级点。支持通过 SSH 管理节点池,实现 Auto-managed Ray Cluster(自动管理的 Ray 集群),不再手动维护。
|
||||
|
||||
|
||||
* **演进**:基础设施层由 Handmade 变为 SSH Node (Auto Managed)。
|
||||
* **基础设施**:仍运行在手工集群上,但控制面开始服务化。
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
### **阶段二:产品化与服务化 (Productization & Serving)**
|
||||
### **阶段二:架构质变 - 无状态节点池 (The Stateless Shift)**
|
||||
|
||||
这一阶段主要解决“好不好用”的问题,发布了第一个正式版本,并扩展了业务场景。
|
||||
*目标:通过 GPFS 实现控制反转 (IoC),彻底解耦平台层与计算节点层。这是本架构最关键的转折点。*
|
||||
|
||||
* **v3.0: 正式发布版 (Frontend & Data Management)**
  * **里程碑**:**1st Version to Release!!** (首个对外发布版本)
|
||||
* **核心功能**:完整的前后端分离,闭环了用户的数据流。
|
||||
* **新增组件**:
|
||||
* **WebUI**:提供可视化的用户界面。
|
||||
* **Data Management (SFTPGo)**:集成了 SFTPGo,解决用户训练数据、代码的上传与下载问题。
|
||||
|
||||
|
||||
* **价值**:用户可以通过 Web 界面完成从数据上传到任务提交的全流程。
|
||||
|
||||
|
||||
* **v3.5: 定制化与推理服务 (Customized Task & Serving)**
|
||||
* **核心功能**:支持更复杂的训练需求和模型推理。
|
||||
* **新增组件**:
|
||||
* **Model Serving**:不仅能训练,还能部署模型服务。
|
||||
* **Customized VerlTaskSpec YAML**:支持自定义参数 (Param)、奖励函数 (Reward)、Verl 代码等。
|
||||
|
||||
|
||||
* **价值**:从单一的训练平台扩展为“训练+推理”的一体化平台,且支持算法工程师深度定制实验参数。
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
### **阶段三:可观测性体系 (Observability)**
|
||||
|
||||
这一阶段主要解决“看得清”的问题,确保系统的稳定性和模型训练的可追踪性。
|
||||
|
||||
* **v4.0: 系统级可观测性 (System Observability)**
|
||||
* **核心功能**:建立完整的基础设施监控。
|
||||
* **新增组件**:
|
||||
* **Prometheus**:指标采集。
|
||||
* **Grafana**:监控大盘展示。
|
||||
* **Alert**:告警系统。
|
||||
* **ELK**:日志收集与分析 (Elasticsearch, Logstash, Kibana)。
|
||||
|
||||
|
||||
* **基础设施升级**:在 SSH Node 上部署了 **Exporter**,用于采集节点层面的 metrics。
|
||||
|
||||
|
||||
* **v4.5: 实验级可观测性 (ML Observability)**
|
||||
* **核心功能**:专注于模型训练过程的指标追踪。
|
||||
* **新增组件**:
|
||||
* **Weights & Biases (W&B / wandb)**:集成专业的 ML 实验追踪工具,用于记录 Loss、Accuracy 等训练指标。
|
||||
* **v2.5: 用户管理 & 无状态 Ray 节点池 (User Mgmt & Stateless Ray Node Pool)**
  * **核心机制:基于 GPFS 的服务发现 (Service Discovery)**
|
||||
* **Ray Head (有状态)**:由 `Node Management` 启动(通常通过 SSH 或 K8s StatefulSet)。启动后,将自身的 IP 地址写入 GPFS 中的 `Head IP File`。
|
||||
* **Ray Worker (无状态)**:
|
||||
* **Stateless**:Worker 容器启动时不依赖平台指令。
|
||||
* **Auto Connect**:启动脚本读取 GPFS 中的 `Head IP File`,获得 Head 地址并自动加入集群。
|
||||
* **Watchdog**:Worker 内部运行看门狗进程,监控 Head IP 变化。如果 Head 变动,Worker 自动重启或重连,实现自愈。
|
||||
* **新增组件**:
|
||||
* `User Management`:多用户隔离。
|
||||
* `GPFS`:取代了之前的 JuiceFS,作为唯一的共享存储和元数据交换媒介。
|
||||
|
||||
|
||||
|
||||
@ -98,33 +51,83 @@
|
||||
|
||||
---
|
||||
|
||||
### **阶段四:智能化运维 (Operability & Intelligence)**
|
||||
### **阶段三:产品化与高级能力 (Productization & Advanced Features)**
|
||||
|
||||
这一阶段主要解决“自动化”的问题,引入 AI 来管理 AI 平台。
|
||||
*目标:发布首个正式版本,并支持大模型训练所需的复杂网络与推理能力。*
|
||||
|
||||
* **v5.0: 智能运维闭环 (Statistics, SOP, Agent)**
|
||||
* **核心功能**:通过数据统计和 Agent 实现平台的自动化治理。
|
||||
* **新增组件**:
|
||||
* **Statistics**:平台维度的统计分析(资源利用率、任务成功率等)。
|
||||
* **SOP Tools**:标准作业程序工具化(自动化运维脚本)。
|
||||
* **Agent**:智能体。可能用于自动故障诊断、资源自动调度优化或交互式助手。
|
||||
* **v3.0: 正式发布版 (Release v1.0)**
  * **里程碑**:**1st Version to Release!!**
|
||||
* **核心功能**:闭环用户数据流。
|
||||
* **新增组件**:
|
||||
* `WebUI`:可视化操作界面。
|
||||
* `Data Management (SFTPGo)`:用户上传数据/代码 -> SFTPGo -> 写入 GPFS -> Ray Worker 可见。
|
||||
|
||||
|
||||
* **基础设施**:全量切换到 `Ray Worker Node` (Stateless) + `GPFS` 的架构。
|
||||
|
||||
|
||||
* **v3.5: 高级定制与训推一体 (Advanced Task & Serving)**
  * **核心功能**:支持复杂的科研需求。
|
||||
* **新增组件**:
|
||||
* `Model Serving`:支持模型推理服务。
|
||||
* `Advanced VerlTaskSpec`:支持自定义 Reward Function、自定义代码、Checkpoint 断点续训 (Resubmit from last checkpoint)。
|
||||
|
||||
|
||||
* **网络增强**:
|
||||
* **IB Network Supporting**:支持 InfiniBand 网络,确保多机训练的高性能互联。
|
||||
|
||||
|
||||
* **愿景**:打造一个具备自我管理、自我修复能力的 AI 基础设施平台。
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
### **架构层级总结**
|
||||
### **阶段四:全链路可观测性 (Full-Stack Observability)**
|
||||
|
||||
*目标:打开黑盒,监控基础设施与业务指标。*
|
||||
|
||||
* **v4.0: 系统级可观测性 (System Observability)**
  * **核心功能**:监控集群“活着”且“健康”。
|
||||
* **新增组件**:
|
||||
* `Prometheus` + `Grafana` + `ELK`:指标与日志平台。
|
||||
* `Exporter`:部署在 Ray Worker Node 中的监控探针(采集 GPU/CPU/GPFS IO 指标)。
|
||||
|
||||
|
||||
|
||||
|
||||
* **v4.5: 算法级可观测性 (ML Observability)**
  * **核心功能**:监控模型“练得好不好”。
|
||||
* **新增组件**:
|
||||
* `Weights & Biases (W&B / wandb)`:集成实验追踪工具,记录 Loss 曲线和训练参数。
|
||||
|
||||
|
||||
|
||||
|
||||
| 层级 | 关键组件/技术 |
|
||||
| --- | --- |
|
||||
| **接入层 (Frontend/API)** | WebUI, API Server, User Management |
|
||||
| **调度与编排 (Orchestration)** | Task Management, Ray Job Tool (Client), Node Management |
|
||||
| **计算引擎 (Compute)** | Native Ray Cluster, Verl Framework (TaskSpec YAML) |
|
||||
| **数据与存储 (Data)** | SFTPGo (Data Management), Model Serving |
|
||||
| **可观测性 (Observability)** | Prometheus, Grafana, ELK, Weights & Biases |
|
||||
| **运维与智能 (Ops)** | Exporters, Statistics, SOP Tools, Agent |
|
||||
|
||||
---
|
||||
|
||||
### **阶段五:智能化运维 (AIOps)**
|
||||
|
||||
*目标:迈向自动化与自治。*
|
||||
|
||||
* **v5.0: 智能运维闭环 (Operability)**
  * **核心功能**:降低运维成本,提升稳定性。
|
||||
* **新增组件**:
|
||||
* `Statistics`:集群资源利用率统计报表。
|
||||
* `SOP Tools`:标准运维工具(如自动清理 GPFS 垃圾文件、僵尸节点检测)。
|
||||
* `Agent`:智能运维助手(基于 LLM 的日志分析与故障诊断)。
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
### **新架构核心亮点总结**
|
||||
|
||||
1. **极简的节点管理**:
|
||||
* 利用 v2.5 的 **Head IP File + Watchdog** 机制,平台层不再需要维护复杂的 Worker IP 列表和 SSH 连接池。
|
||||
* **扩缩容极其简单**:只需在底层(K8s/Docker)增加 Worker 副本数,它们就会自动通过 GPFS 找到 Head 并加入战斗。
|
||||
|
||||
|
||||
2. **统一的数据平面 (GPFS)**:
|
||||
* 从 v2.5 开始,GPFS 承担了 **数据存储** (Code/Data)、**状态同步** (Head IP) 和 **检查点存储** (Checkpoints) 三大职责,架构非常收敛。
|
||||
|
||||
|
||||
3. **高弹性 (Resilience)**:
|
||||
* Worker 的 **Watchdog** 机制确保了当 Head 重启或网络抖动时,集群具备自我修复能力,无需人工干预。
|
||||
File diff suppressed because it is too large
14
specs/mvp/v2.5/README.md
Normal file
@ -0,0 +1,14 @@
|
||||
# MVP v2.5 (Design) — User Management & Stateless Ray Node Pool

Based on the v2.5 plan in `specs/mvp/mvp_roadmap_v2.md` and `specs/mvp/image/roadmap_v2.5.png`,
this directory provides a set of detailed design documents meant to be implementable, verifiable, and deliverable iteratively.

Core changes in v2.5:

- On top of the v2.0 task queue / scheduling / retry layer, introduce **User Management** (multi-user isolation, per-user directories, tokens).
- Introduce a **Stateless Ray Node Pool**: worker nodes/containers no longer need the platform to hand them the head address; service discovery and self-healing reconnection (watchdog) go through shared storage (GPFS/NFS).

Documents:

- `specs/mvp/v2.5/v2.5_design.md`: overall architecture and key mechanisms (head IP file / watchdog / user isolation / task flow).
- `specs/mvp/v2.5/v2.5_api.md`: API design (users, tasks, queue, logs) and auth conventions.
- `specs/mvp/v2.5/v2.5_acceptance.md`: development / deployment / acceptance workflow and verifiable criteria.
67
specs/mvp/v2.5/v2.5_acceptance.md
Normal file
@ -0,0 +1,67 @@
|
||||
# MVP v2.5 Development / Deployment / Acceptance Criteria

This document defines the "verifiable loop" for v2.5 so that every milestone can be accepted against concrete checks.

---

## 1. Deliverables

### 1.1 Code (suggested)

- API Server enhancements: user management + tasks associated with user_id + auth-based isolation
- SQLite schema migration: new users/tokens tables; tasks gains a user_id column
- Ray Head service discovery: head.json writing and heartbeat refresh
- Worker bootstrap + watchdog:
  - dev: provided as scripts (docker compose scenario)
  - prod: injectable as the container command/entrypoint

### 1.2 Documentation

- Directory layout and GPFS path conventions
- API documentation (including users and multi-tenant isolation)
- Ops SOP: head restart, worker self-healing, how to troubleshoot head.json

---

## 2. Deployment Flow (verifiable in the dev environment)

### 2.1 Startup order (recommended)

1) Start the head (API server + Ray head)
2) The head writes `/private/ray/discovery/<cluster_name>/head.json`
3) Start any number of workers (no head address needs to be passed)
4) Workers read head.json and join the cluster automatically
5) Create a user via the API and obtain a token
6) Submit PPO/GRPO/SFT with the user token

A minimal scripted check for steps 2-4 is sketched below.
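As a dev-environment sanity check for steps 2-4, the sketch below waits until `head.json` exists and is fresh before declaring discovery healthy. It is an illustrative helper, not one of the deliverables; the path follows `v2.5_design.md` (cluster name `argus-ray` assumed) and `expires_at` is assumed to be an ISO-8601 UTC timestamp.

```python
import json
import time
from datetime import datetime, timezone
from pathlib import Path

HEAD_JSON = Path("/private/ray/discovery/argus-ray/head.json")  # cluster_name assumed


def head_is_fresh(path: Path = HEAD_JSON) -> bool:
    """True if head.json exists and its expires_at is still in the future."""
    if not path.exists():
        return False
    info = json.loads(path.read_text(encoding="utf-8"))
    expires = datetime.fromisoformat(info["expires_at"].replace("Z", "+00:00"))
    return datetime.now(timezone.utc) < expires


if __name__ == "__main__":
    deadline = time.time() + 60  # acceptance A1/A2 allow a 60s budget
    while time.time() < deadline:
        if head_is_fresh():
            print("head.json is fresh; workers should be able to join")
            break
        time.sleep(5)
    else:
        raise SystemExit("head.json missing or stale after 60s")
```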

---

## 3. Acceptance Criteria

### 3.1 Stateless Ray Node Pool

- A1: started without a head address, a worker joins the cluster within `T<=60s` (visible in `ray status`)
- A2: after the head container restarts (IP change or Ray restart):
  - head.json is updated
  - the worker watchdog reconnects automatically within `T<=60s`
- A3: with the head started with `--num-gpus=0 --num-cpus=0`, the training driver never lands on the head (verifiable via the Ray dashboard/logs)

### 3.2 User Management

- U1: admin can create users and issue tokens (the token is returned only once)
- U2: a task submitted by user A cannot be queried/cancelled/read (logs) by user B (API returns 404 or 403, per the design convention)
- U3: only job outputs are isolated: task outputs land under `/private/users/<user_id>/jobs/<ray_submission_id>/...`, and different users never overwrite each other
- U4: training inputs (verl code, HF cache, datasets) all use `/private/common/...` (v2.5 does not isolate inputs)

### 3.3 Task Flow (inherited from v2.0)

- T1: all three workloads (PPO/GRPO/SFT) can be submitted and run end to end (dev scale may use epoch=1/steps=10)
- T2: when resources are insufficient, tasks do not fail unrecoverably; they enter `PENDING_RESOURCES` and retry on an interval (same logic as v2.0)

---

## 4. Regression Cases (minimal set)

1) Create users alice/bob, submit an SFT job as each, verify isolation and output directories
2) Start head + 2 workers, submit PPO/GRPO, verify the driver lands on a worker
3) Restart the head (or point head.json at a new IP), verify the worker watchdog reconnects automatically
109
specs/mvp/v2.5/v2.5_api.md
Normal file
@ -0,0 +1,109 @@
|
||||
# MVP v2.5 API Design (User + Task + Queue)

v2.5 adds **User Management** and multi-tenant isolation on top of the v2.0 API.

Constraints:
- still uses internal tokens (API keys);
- no external IAM is introduced;
- the TaskSpec stays YAML (reusing the existing structured fields).

---

## 1. Auth

Header:
- `Authorization: Bearer <api_token>`

Server-side behavior:
- map `api_token` to a `user_id`
- subsequent task operations are scoped to that `user_id` by default

Admin token (optional):
- an additional `MVP_ADMIN_TOKEN` can be configured (or user.role=admin)
- admin can query/cancel across users (for operations).

A sketch of the token-to-user mapping as a request dependency follows.
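The snippet below is a minimal sketch of the mapping described above, written as a FastAPI dependency (FastAPI is what the existing app and tests use). `lookup_user_id` is a hypothetical helper backed by the `api_tokens` table from the design doc; this is illustrative, not the actual server code.

```python
from __future__ import annotations

from fastapi import Depends, HTTPException, Request


def lookup_user_id(token: str) -> str | None:
    """Hypothetical helper: resolve a bearer token to a user_id via the api_tokens table."""
    raise NotImplementedError


def current_user(request: Request) -> str:
    auth = request.headers.get("authorization", "")
    if not auth.lower().startswith("bearer "):
        raise HTTPException(status_code=401, detail="missing bearer token")
    user_id = lookup_user_id(auth.split(" ", 1)[1].strip())
    if user_id is None:
        raise HTTPException(status_code=401, detail="invalid token")
    return user_id


# Usage: task routes receive the caller's user_id and scope every query to it, e.g.
#   @app.get("/api/v2/tasks/{task_id}")
#   def get_task(task_id: str, user_id: str = Depends(current_user)): ...
```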

---

## 2. User Management

### 2.1 Create a user (admin)

`POST /api/v2/users`

Request (JSON):
```json
{"user_id":"alice","display_name":"Alice"}
```

Response:
```json
{"user_id":"alice","state":"ACTIVE"}
```

### 2.2 Issue a token for a user (admin)

`POST /api/v2/users/{user_id}/tokens`

Response (the plaintext token is returned only once):
```json
{"user_id":"alice","token":"mvp_u_..."}
```

### 2.3 Disable a user (admin)

`POST /api/v2/users/{user_id}:disable`

---

## 3. Task Management (multi-tenant)

### 3.1 Submit a task

`POST /api/v2/tasks`

Body:
- `Content-Type: application/yaml`
- raw TaskSpec YAML (training-semantics fields only; no user_id)

Response:
```json
{"task_id":"mvp25-ppo-20251225-170001-2a3f","state":"QUEUED"}
```

Server-side side effects:
- record tasks.user_id (derived from the token)
- compute the output directory: `/private/users/<uid>/jobs/<ray_submission_id>/...`

A minimal client-side example of this call is sketched below.
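A minimal client-side sketch of the submit call, using `httpx` (already listed in requirements-dev.txt). The base URL, token, and TaskSpec fields are placeholders that mirror examples elsewhere in this spec, not canonical values.

```python
import httpx

BASE_URL = "http://127.0.0.1:18080"  # assumed API server address
TOKEN = "mvp_u_..."                  # user token issued via POST /api/v2/users/{user_id}/tokens

taskspec_yaml = """\
workload: ppo
code_path: /private/common/code/verl
model_id: m
train_file: /private/common/datasets/train.jsonl
"""

resp = httpx.post(
    f"{BASE_URL}/api/v2/tasks",
    headers={
        "Authorization": f"Bearer {TOKEN}",
        "Content-Type": "application/yaml",
    },
    content=taskspec_yaml,
)
resp.raise_for_status()
print(resp.json())  # e.g. {"task_id": "...", "state": "QUEUED"}
```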

### 3.2 Get a task (owner only)

`GET /api/v2/tasks/{task_id}`

If the task does not belong to the current user:
- return `404` (to avoid leaking existence)

### 3.3 Cancel a task (owner only)

`POST /api/v2/tasks/{task_id}:cancel`

---

## 4. Queue/Debug

### 4.1 View the queue (caller's view)

`GET /api/v2/queue`

Returns that user's pending/running lists.

### 4.2 Admin view of the global queue (admin)

`GET /api/v2/admin/queue`

---

## 5. Logs

`GET /api/v2/tasks/{task_id}/logs?attempt=latest&tail=2000`

Behavior matches v2.0: pass through the Ray Job logs tail.
255
specs/mvp/v2.5/v2.5_design.md
Normal file
@ -0,0 +1,255 @@
|
||||
# MVP v2.5 Detailed Design (User Management + Stateless Ray Node Pool)

Goal of this document: turn the v2.5 ideas in `mvp_roadmap_v2.md` into an engineering-ready design, covering:
- user management added inside the API Server;
- a Ray node pool that becomes stateless (workers discover the head themselves, join automatically, and self-heal via a watchdog);
- the v2.0 "task management layer" semantics preserved: Task/Attempt, queue, resource checks, Ray Job submission and state sync;
- all shared data/state landing on GPFS (dev can start with NFS), mounted inside containers at `/private/`.

> Terminology: "GPFS" here stands for the production shared storage; dev can use NFS, but containers still access it under `/private/`.

---

## 1. Goals and Non-Goals

### 1.1 v2.5 goals (Must)

1) **User Management (minimal multi-tenancy)**
   - support creating/disabling users;
   - issue an internal token (API key) per user, used for authentication and isolation;
   - user isolation (v2.5 closes the smallest loop: only **job outputs** and API visibility are isolated):
     - users can only see/operate their own Tasks;
     - training outputs (job root, checkpoints, archived logs, etc.) land in per-user directories;
     - training inputs (verl code, HF cache, datasets) all use `common/` (v2.5 does not isolate user-specific code/models/datasets).

2) **Stateless Ray Worker Node Pool**
   - worker containers do not need to be told the head address by the platform at startup;
   - workers read the **Head IP File** from GPFS and connect to the Ray head automatically;
   - a watchdog inside the worker monitors head-address changes and reconnects with `ray stop` + `ray start` when it changes;
   - workers depend on as little local persistent state as possible (after a crash/replacement they can be rebuilt transparently).

3) **Keep the v2.0 task management behavior**
   - the Task/Attempt model is unchanged (or extended backward-compatibly);
   - aligned with verl's fail-fast behavior: on insufficient resources the service keeps the task pending and retries;
   - Ray Jobs are still submitted through the Ray Python SDK (JobSubmissionClient).

### 1.2 v2.5 non-goals (Not Now)

- Full WebUI (deferred to v3.0).
- Fair scheduling / quotas / priorities (deferred to v3.x).
- Full production-grade IAM (deferred to v4+); v2.5 uses internal tokens only.
- K8s-native orchestration (not required at this stage, but the design must fit the mode where the compute platform launches containers and we can only ssh in to manage them).

---

## 2. Overall Architecture (mapping to roadmap v2.5)

### 2.1 Components

**Control Plane**
- **API Server**
  - user management
  - task management (queue / scheduling / retry / state aggregation)
  - Ray Job Tool (Ray Client)
  - VerlTaskSpec (TaskSpec YAML, same format as v2.0/v2.1)
  - co-locating it with the Ray head (same host/container) is the recommended shape (easy access to the dashboard / job server)
- **Ray Head (stateful)**
  - after startup, writes the head address into the Head IP File on GPFS for worker service discovery

**Data Plane**
- **Ray Workers (stateless node pool)**
  - stateless bootstrap: read the head address from GPFS and join the cluster automatically
  - watchdog: keep watching the head-address file for changes and self-heal by reconnecting

**Shared Storage (GPFS)**
- one unified data path: data, model cache, code, task outputs, and the head service-discovery file.

### 2.2 Inversion of control (IoC) in v2.5

The key difference from v2.0 / the handmade cluster:
- v2.0: platform scripts / operators start workers explicitly with `--address=<head>`.
- v2.5: workers read the `head_ip_file` from GPFS themselves; the platform no longer maintains a worker list or an SSH connection pool.

---

## 3. GPFS Directory Layout (`/private` inside containers)

The following layout should be frozen in v2.5 (a backward-compatible extension of v2.0):

```
/private/
  ray/
    discovery/
      <cluster_name>/
        head.json           # Head IP File (service discovery)
        head.json.lock      # optional write lock (can be skipped in v2.5)
  users/
    <user_id>/
      jobs/                 # /private/users/<uid>/jobs/<ray_submission_id>/*
      outputs/              # aggregated training outputs (as needed)
  common/
    code/                   # platform/shared code snapshots (verl code snapshot, etc.)
    datasets/               # shared datasets
    hf/                     # shared HF cache (reused in dev)
    db/                     # sqlite
    logs/                   # API logs, platform logs
```

Notes:
- `common/`: the platform default area (in v2.5 it is writable by all users; ACLs / read-only come later).
- `users/<user_id>/...`: the main isolation boundary (the key to minimal multi-tenancy).

---

## 4. Head IP File (Service Discovery) Design

### 4.1 File path

- `head_ip_file = /private/ray/discovery/<cluster_name>/head.json`
- `<cluster_name>`: set by configuration (e.g. `argus-ray`); allows multiple environments/clusters on the same GPFS.

### 4.2 File content (JSON)

JSON is recommended (easy to extend):

```json
{
  "cluster_name": "argus-ray",
  "head_ip": "10.0.0.12",
  "gcs_port": 6379,
  "dashboard_port": 8265,
  "job_server_url": "http://10.0.0.12:8265",
  "updated_at": "2025-12-25T17:00:00Z",
  "expires_at": "2025-12-25T17:01:00Z"
}
```

Key points:
- `updated_at`: helps troubleshooting and observability;
- `expires_at`: keeps workers from endlessly reconnecting to a stale head address;
- `job_server_url`: can be used directly as the Ray Job Tool address (drop-in configuration).

### 4.3 Write strategy (atomic update)

The head must guarantee that workers never read a half-written file:
- write a temp file `head.json.tmp`;
- `fsync` (optional);
- `rename(head.json.tmp -> head.json)` (atomic replacement).

### 4.4 Heartbeat and TTL

The head process must refresh `head.json` periodically:
- suggested `ttl_s=60`, refresh period `refresh_s=10`;
- if the head process dies, workers that read an expired file can enter a "wait" mode instead of reconnecting forever.

A combined sketch of the atomic write (§4.3) and the heartbeat (§4.4) follows.
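The sketch below illustrates §4.3/§4.4 together: an atomic `head.json` writer plus a refresh loop with the suggested `ttl_s=60` / `refresh_s=10`. It is a minimal illustration, not the actual Node Management code; the field names follow the §4.2 example and the `argus-ray` cluster name is assumed.

```python
import json
import os
import time
from datetime import datetime, timedelta, timezone
from pathlib import Path

HEAD_JSON = Path("/private/ray/discovery/argus-ray/head.json")
TTL_S = 60
REFRESH_S = 10


def write_head_json(head_ip: str, gcs_port: int = 6379, dashboard_port: int = 8265) -> None:
    """Atomically replace head.json: write a temp file, fsync, then rename."""
    now = datetime.now(timezone.utc)
    payload = {
        "cluster_name": "argus-ray",
        "head_ip": head_ip,
        "gcs_port": gcs_port,
        "dashboard_port": dashboard_port,
        "job_server_url": f"http://{head_ip}:{dashboard_port}",
        "updated_at": now.isoformat(),
        "expires_at": (now + timedelta(seconds=TTL_S)).isoformat(),
    }
    HEAD_JSON.parent.mkdir(parents=True, exist_ok=True)
    tmp = HEAD_JSON.parent / (HEAD_JSON.name + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(payload, f)
        f.flush()
        os.fsync(f.fileno())    # optional, but cheap insurance on shared storage
    os.replace(tmp, HEAD_JSON)  # atomic rename on the same filesystem


def refresh_forever(head_ip: str) -> None:
    """Heartbeat loop: keep expires_at fresh for as long as the head is alive."""
    while True:
        write_head_json(head_ip)
        time.sleep(REFRESH_S)
```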

---

## 5. Stateless Worker Bootstrap + Watchdog

### 5.1 Startup sequence (inside the worker container)

1) The startup script reads `head.json`:
   - if the file does not exist: sleep + retry (until it exists)
   - if it exists but `expires_at` has passed: sleep + retry (until it is fresh again)
2) Parse `head_ip:gcs_port` and run:
   - `ray stop --force || true`
   - `ray start --address=<head_ip>:<gcs_port> --resources='{"worker_node": 100, ...}' ...`
3) Start the watchdog process (same container):
   - poll/watch `head.json` for content changes
   - as soon as `head_ip` or `gcs_port` changes, trigger `ray stop` + `ray start` to reconnect

### 5.2 Watchdog strategy (minimum viable)

v2.5 recommends the "simple and stable" implementation:
- polling interval `watch_s=5`;
- compare `updated_at` or a hash of `head.json`;
- on a change: reconnect;
- on repeated reconnect failures: exponential backoff (v2.5 can start with a fixed backoff; harden in v2.6).

A combined bootstrap + watchdog sketch follows.
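A minimal sketch of the bootstrap and watchdog loop from §5.1/§5.2, written as a single Python script so it could be injected as a container entrypoint. Paths, the resource label, and the polling interval mirror this section; the real worker script may differ.

```python
from __future__ import annotations

import json
import subprocess
import time
from pathlib import Path

HEAD_JSON = Path("/private/ray/discovery/argus-ray/head.json")
WATCH_S = 5


def read_head() -> dict | None:
    """Return the parsed head.json, or None if it is missing or unreadable."""
    try:
        return json.loads(HEAD_JSON.read_text(encoding="utf-8"))
    except (FileNotFoundError, json.JSONDecodeError):
        return None


def connect(head: dict) -> None:
    """(Re)join the cluster: stop any running raylet, then start against the head."""
    address = f"{head['head_ip']}:{head['gcs_port']}"
    subprocess.run(["ray", "stop", "--force"], check=False)
    subprocess.run(
        ["ray", "start", "--address", address, "--resources", '{"worker_node": 100}'],
        check=True,
    )


def main() -> None:
    current = None
    while True:
        head = read_head()
        if head is not None:
            key = (head.get("head_ip"), head.get("gcs_port"))
            if key != current:          # first connect, or the head moved
                connect(head)
                current = key
        time.sleep(WATCH_S)


if __name__ == "__main__":
    main()
```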

### 5.3 Resource label (forcing the driver onto a worker)

Same idea as in v2.0:
- workers start with `--resources='{"worker_node": 100}'`
- the head does not expose the `worker_node` resource
- Ray job submission sets entrypoint_resources: `{"worker_node": 1}` (see the sketch below)
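For illustration, this is roughly how the label is applied on the submit side with Ray's `JobSubmissionClient` (the same SDK the Ray Job Tool uses); the address and entrypoint are placeholders, not real values.

```python
from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://10.0.0.12:8265")  # job_server_url from head.json

submission_id = client.submit_job(
    entrypoint="python3 -m argus.ray.driver_entrypoint -- echo hi",  # placeholder entrypoint
    submission_id="mvp25-demo--a01",
    # The head exposes no "worker_node" resource, so requiring it here guarantees
    # the driver process is scheduled onto a worker node, never onto the head.
    entrypoint_resources={"worker_node": 1},
)
print(submission_id)
```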

### 5.4 "Stateless" constraints for GPU/CPU

- whether a worker has GPUs is decided by the underlying compute platform (in production the platform mounts GPUs into the container);
- the worker startup script must not hard-code GPU indices; rely only on `NVIDIA_VISIBLE_DEVICES` / platform injection;
- the head should run with `--num-gpus=0 --num-cpus=0` so training is never scheduled onto it.

---

## 6. User Management Design (minimal multi-tenancy)

### 6.1 Data model (SQLite)

Two new tables (illustrative):
- `users`
  - `user_id` (PK)
  - `display_name`
  - `state` (ACTIVE/DISABLED)
  - `created_at`
- `api_tokens`
  - `token_hash` (PK)
  - `user_id` (FK)
  - `created_at`
  - `last_used_at`

And the `tasks` table gains:
- `user_id` (FK)

A minimal schema sketch follows.
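A minimal DDL sketch matching the model above, using the standard-library `sqlite3`. Column types and the migration statement are assumptions for illustration; the real schema migration lives in the service code.

```python
import sqlite3

SCHEMA = """
CREATE TABLE IF NOT EXISTS users (
    user_id      TEXT PRIMARY KEY,
    display_name TEXT,
    state        TEXT NOT NULL DEFAULT 'ACTIVE',   -- ACTIVE / DISABLED
    created_at   TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS api_tokens (
    token_hash   TEXT PRIMARY KEY,                 -- only the hash is stored, never the plaintext
    user_id      TEXT NOT NULL REFERENCES users(user_id),
    created_at   TEXT NOT NULL,
    last_used_at TEXT
);
"""

# v2.0 -> v2.5 migration: the existing tasks table gains a user_id column.
MIGRATE_TASKS = "ALTER TABLE tasks ADD COLUMN user_id TEXT;"


def migrate(db_path: str) -> None:
    with sqlite3.connect(db_path) as conn:
        conn.executescript(SCHEMA)
        try:
            conn.execute(MIGRATE_TASKS)
        except sqlite3.OperationalError:
            pass  # column already exists; the migration is idempotent
```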

### 6.2 Auth strategy

Internal token mode:
- `Authorization: Bearer <token>`
- the server maps the token to a `user_id`
- all subsequent task queries / cancels / log reads are scoped to that `user_id` by default

Admin capabilities (v2.5 minimal):
- configure one extra admin token (or mark a specific user as admin)
- admin can list all users/tasks (for operations and troubleshooting).

A token issuance / verification sketch follows.
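To make the `token_hash` column concrete, here is one plausible issuance/verification scheme (random token, SHA-256 hash stored). The `mvp_u_` prefix comes from the API doc; everything else is an assumption for illustration.

```python
from __future__ import annotations

import hashlib
import secrets


def issue_token() -> tuple[str, str]:
    """Create a new API token; the caller stores (token_hash, user_id) in api_tokens.
    The plaintext is shown to the user exactly once."""
    plaintext = "mvp_u_" + secrets.token_urlsafe(24)
    token_hash = hashlib.sha256(plaintext.encode("utf-8")).hexdigest()
    return plaintext, token_hash


def resolve_user_id(presented_token: str, token_hash_to_user: dict[str, str]) -> str | None:
    """Map a presented bearer token back to a user_id (the dict stands in for the api_tokens table)."""
    digest = hashlib.sha256(presented_token.encode("utf-8")).hexdigest()
    return token_hash_to_user.get(digest)
```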

### 6.3 Per-user directory isolation (path constraints)

Core principles (v2.5 version):
- **Outputs**: must land under `/private/users/<uid>/jobs/...` (computed by the server; users cannot pick an arbitrary output root)
- **Inputs**: always `/private/common/...` (v2.5 supports neither user-supplied verl code nor per-user hf/datasets isolation)

Server-side handling (minimum viable):
- after parsing the TaskSpec, validate input path fields against a whitelist prefix (must be `/private/common/...`; reject `../` and any path that escapes it);
- the output directory is always computed by the server: `job_root = /private/users/<uid>/jobs/<ray_submission_id>/`.

A sketch of both checks follows.
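A minimal sketch of the prefix check and the job_root derivation, assuming POSIX paths inside the container; the helper names are illustrative, not the service's real functions.

```python
from pathlib import PurePosixPath

COMMON_PREFIX = PurePosixPath("/private/common")
USERS_ROOT = PurePosixPath("/private/users")


def validate_input_path(raw: str) -> PurePosixPath:
    """Accept only absolute paths under /private/common, rejecting '..' escapes."""
    path = PurePosixPath(raw)
    if not path.is_absolute() or ".." in path.parts:
        raise ValueError(f"invalid input path: {raw}")
    if path != COMMON_PREFIX and COMMON_PREFIX not in path.parents:
        raise ValueError(f"input path must live under {COMMON_PREFIX}: {raw}")
    return path


def job_root(user_id: str, ray_submission_id: str) -> PurePosixPath:
    """The output root is derived by the server, never taken from the TaskSpec."""
    return USERS_ROOT / user_id / "jobs" / ray_submission_id
```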

---

## 7. TaskSpec (VerlTaskSpec YAML) Extension Points in v2.5

v2.5 **does not extend the TaskSpec**: the YAML structured fields and semantics stay identical to v2.0/v2.1.

The "user semantics" of v2.5 live entirely in server-side enrichment and constraints:
- user_id is derived from the token (users never write user_id into the YAML);
- the server derives `ray_submission_id` (from task_id/attempt);
- the server computes the output directory `job_root=/private/users/<uid>/jobs/<ray_submission_id>/...`;
- v2.5 does not support user-supplied verl code paths (so runtime_env does not need to inject a per-user code directory).

---

## 8. Migration and Compatibility

The v2.5 design must satisfy:
- the existing v2.0 "manually started workers" flow still runs (as the dev fallback);
- without changing the image, the worker watchdog can be injected as the container command/entrypoint (scripts in dev; in production the compute platform specifies the command).

---

## 9. Risks and Mitigations (v2.5)

1) **head.json consistency/latency on GPFS**
   - Mitigation: atomic rename + TTL; watchdog polling.

2) **Job server URL changes after a Ray head restart**
   - Mitigation: head.json carries `job_server_url`; the Ray Job Tool can re-read the file to update its address (dynamic reload can come in v2.6).

3) **Task churn while workers reconnect**
   - Mitigation: the service-side scheduler aligns with verl's resource fail-fast; failures are attributable and re-queued for retry.
4
src/mvp/py/requirements-dev.txt
Normal file
@ -0,0 +1,4 @@
pytest==8.4.1
pytest-cov==6.3.0
httpx==0.28.1
77
src/mvp/py/tests/conftest.py
Normal file
@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _ensure_mvp_py_on_path() -> None:
|
||||
repo_root = Path(__file__).resolve().parents[4]
|
||||
py_root = repo_root / "src" / "mvp" / "py"
|
||||
if str(py_root) not in sys.path:
|
||||
sys.path.insert(0, str(py_root))
|
||||
|
||||
|
||||
def _install_ray_stub() -> None:
|
||||
if "ray" in sys.modules:
|
||||
return
|
||||
|
||||
ray = types.ModuleType("ray")
|
||||
ray.__path__ = [] # type: ignore[attr-defined]
|
||||
|
||||
def _init(*args: object, **kwargs: object) -> None:
|
||||
return None
|
||||
|
||||
ray.init = _init # type: ignore[attr-defined]
|
||||
ray.cluster_resources = lambda: {} # type: ignore[attr-defined]
|
||||
ray.available_resources = lambda: {} # type: ignore[attr-defined]
|
||||
sys.modules["ray"] = ray
|
||||
|
||||
job_submission = types.ModuleType("ray.job_submission")
|
||||
job_submission.__path__ = [] # type: ignore[attr-defined]
|
||||
|
||||
class JobSubmissionClient: # minimal stub; tests can monkeypatch methods
|
||||
def __init__(self, address: str):
|
||||
self.address = address
|
||||
|
||||
def submit_job(self, **kwargs: object) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_job_status(self, submission_id: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def stop_job(self, submission_id: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_job_logs(self, submission_id: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_jobs(self) -> list[object]:
|
||||
return []
|
||||
|
||||
job_submission.JobSubmissionClient = JobSubmissionClient # type: ignore[attr-defined]
|
||||
sys.modules["ray.job_submission"] = job_submission
|
||||
ray.job_submission = job_submission # type: ignore[attr-defined]
|
||||
|
||||
private = types.ModuleType("ray._private")
|
||||
private.__path__ = [] # type: ignore[attr-defined]
|
||||
state = types.ModuleType("ray._private.state")
|
||||
|
||||
def available_resources_per_node() -> dict[str, object]:
|
||||
return {}
|
||||
|
||||
state.available_resources_per_node = available_resources_per_node # type: ignore[attr-defined]
|
||||
sys.modules["ray._private"] = private
|
||||
sys.modules["ray._private.state"] = state
|
||||
private.state = state # type: ignore[attr-defined]
|
||||
ray._private = private # type: ignore[attr-defined]
|
||||
|
||||
|
||||
_ensure_mvp_py_on_path()
|
||||
_install_ray_stub()
|
||||
|
||||
|
||||
def pytest_configure(config: object) -> None:
|
||||
os.environ.setdefault("PYTHONUTF8", "1")
|
||||
|
||||
166
src/mvp/py/tests/test_app.py
Normal file
@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def _write_config(tmp_path: Path) -> Path:
|
||||
cfg = {
|
||||
"ray": {
|
||||
"address": "http://127.0.0.1:8265",
|
||||
"shared_root": str(tmp_path),
|
||||
"entrypoint_resources": {"worker_node": 1},
|
||||
"runtime_env": {"env_vars": {}},
|
||||
},
|
||||
"service": {
|
||||
"api": {"host": "127.0.0.1", "port": 0},
|
||||
"auth": {"token_env": "MVP_INTERNAL_TOKEN"},
|
||||
"sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
|
||||
"scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
|
||||
},
|
||||
}
|
||||
p = tmp_path / "cfg.yaml"
|
||||
p.write_text(yaml.safe_dump(cfg), encoding="utf-8")
|
||||
return p
|
||||
|
||||
|
||||
def test_auth_requires_token_env(tmp_path: Path, monkeypatch):
|
||||
from argus.service import app as app_mod
|
||||
|
||||
cfg_path = _write_config(tmp_path)
|
||||
monkeypatch.delenv("MVP_INTERNAL_TOKEN", raising=False)
|
||||
|
||||
class _Scheduler:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool = object()
|
||||
|
||||
def run_forever(self, stop_flag):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
|
||||
app = app_mod.create_app(str(cfg_path))
|
||||
|
||||
with TestClient(app) as c:
|
||||
r = c.get("/api/v2/queue")
|
||||
assert r.status_code == 500
|
||||
|
||||
|
||||
def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
|
||||
from argus.service import app as app_mod
|
||||
|
||||
cfg_path = _write_config(tmp_path)
|
||||
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
|
||||
monkeypatch.setattr(app_mod, "new_task_id", lambda workload: "tid1")
|
||||
|
||||
class _Tool:
|
||||
def __init__(self):
|
||||
self.stopped = []
|
||||
|
||||
def stop(self, sid: str):
|
||||
self.stopped.append(sid)
|
||||
return True
|
||||
|
||||
def logs(self, sid: str):
|
||||
return "a\nb\nc\n"
|
||||
|
||||
class _Scheduler:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool = _Tool()
|
||||
|
||||
def run_forever(self, stop_flag):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
|
||||
app = app_mod.create_app(str(cfg_path))
|
||||
|
||||
headers = {"authorization": "Bearer token1"}
|
||||
with TestClient(app) as c:
|
||||
r = c.post(
|
||||
"/api/v2/tasks",
|
||||
headers=headers,
|
||||
data="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n",
|
||||
)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["task_id"] == "tid1"
|
||||
|
||||
r2 = c.get("/api/v2/tasks/tid1", headers=headers)
|
||||
assert r2.status_code == 200
|
||||
assert r2.json()["desired_resources"]["total_gpus"] == 8
|
||||
|
||||
r3 = c.get("/api/v2/queue", headers=headers)
|
||||
assert r3.status_code == 200
|
||||
assert "pending" in r3.json()
|
||||
|
||||
r4 = c.post("/api/v2/tasks/tid1:cancel", headers=headers)
|
||||
assert r4.status_code == 200
|
||||
assert r4.json()["state"] == "CANCELED"
|
||||
|
||||
# Seed an attempt then fetch logs
|
||||
from argus.service.db import Db
|
||||
from argus.service.config import V2Config
|
||||
from argus.ray.models import RayConfig
|
||||
|
||||
root = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
|
||||
ray_cfg = RayConfig.from_dict(root)
|
||||
v2_cfg = V2Config.from_root_dict(root)
|
||||
db = Db(v2_cfg.sqlite.db_path)
|
||||
db.create_task(
|
||||
task_id="tid2",
|
||||
workload="ppo",
|
||||
jobspec_yaml="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n",
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
db.create_attempt(task_id="tid2", attempt_no=1, ray_submission_id="sid2")
|
||||
db.set_task_state(task_id="tid2", state="RUNNING", latest_attempt_no=1)
|
||||
|
||||
r5 = c.get("/api/v2/tasks/tid2/logs?tail=1", headers=headers)
|
||||
assert r5.status_code == 200
|
||||
assert r5.text.strip() == "c"
|
||||
|
||||
|
||||
def test_submit_rejects_non_mapping_jobspec(tmp_path: Path, monkeypatch):
|
||||
from argus.service import app as app_mod
|
||||
|
||||
cfg_path = _write_config(tmp_path)
|
||||
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
|
||||
|
||||
class _Scheduler:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool = object()
|
||||
|
||||
def run_forever(self, stop_flag):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
|
||||
app = app_mod.create_app(str(cfg_path))
|
||||
|
||||
with TestClient(app) as c:
|
||||
r = c.post("/api/v2/tasks", headers={"authorization": "Bearer token1"}, data="- 1\n- 2\n")
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
def test_submit_rejects_invalid_jobspec(tmp_path: Path, monkeypatch):
|
||||
from argus.service import app as app_mod
|
||||
|
||||
cfg_path = _write_config(tmp_path)
|
||||
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
|
||||
|
||||
class _Scheduler:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool = object()
|
||||
|
||||
def run_forever(self, stop_flag):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
|
||||
app = app_mod.create_app(str(cfg_path))
|
||||
|
||||
with TestClient(app) as c:
|
||||
r = c.post("/api/v2/tasks", headers={"authorization": "Bearer token1"}, data="workload: nope\n")
|
||||
assert r.status_code == 400
|
||||
|
||||
53
src/mvp/py/tests/test_builders.py
Normal file
@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from argus.ray.builders import build_training_argv
|
||||
from argus.ray.models import JobSpec
|
||||
|
||||
|
||||
def _base_spec(workload: str) -> JobSpec:
|
||||
return JobSpec(
|
||||
workload=workload,
|
||||
submission_id=None,
|
||||
code_path="/code",
|
||||
model_id="m",
|
||||
train_file="train.jsonl",
|
||||
val_file=None,
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
total_epochs=1,
|
||||
total_training_steps=10,
|
||||
save_freq=10,
|
||||
test_freq=None,
|
||||
trainer_device=None,
|
||||
)
|
||||
|
||||
|
||||
def test_build_training_argv_ppo_smoke():
|
||||
spec = _base_spec("ppo")
|
||||
built = build_training_argv(spec, submission_id="sid", job_dir="/job")
|
||||
assert built.argv[:3] == ["python3", "-m", "verl.trainer.main_ppo"]
|
||||
assert "data.val_files=null" in built.argv
|
||||
assert "trainer.test_freq=-1" in built.argv
|
||||
|
||||
|
||||
def test_build_training_argv_grpo_has_override():
|
||||
spec = _base_spec("grpo")
|
||||
built = build_training_argv(spec, submission_id="sid", job_dir="/job")
|
||||
assert "algorithm.adv_estimator=grpo" in built.argv
|
||||
|
||||
|
||||
def test_build_training_argv_sft_smoke():
|
||||
spec = _base_spec("sft")
|
||||
built = build_training_argv(spec, submission_id="sid", job_dir="/job")
|
||||
assert built.argv[:3] == ["python3", "-m", "verl.trainer.sft_trainer_ray"]
|
||||
assert "trainer.device=cpu" in built.argv
|
||||
assert "data.val_files=null" in built.argv
|
||||
|
||||
|
||||
def test_build_training_argv_unsupported_raises():
|
||||
spec = _base_spec("bad")
|
||||
with pytest.raises(ValueError, match="unsupported workload"):
|
||||
build_training_argv(spec, submission_id="sid", job_dir="/job")
|
||||
|
||||
77
src/mvp/py/tests/test_cli_run.py
Normal file
@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_cli_submit_status_logs_list(monkeypatch, tmp_path: Path, capsys):
|
||||
from argus.ray import ray_job_tool as tool_mod
|
||||
|
||||
class _Tool:
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
|
||||
def submit(self, spec, no_wait: bool):
|
||||
return "sid1"
|
||||
|
||||
def status(self, sid: str):
|
||||
return "RUNNING"
|
||||
|
||||
def stop(self, sid: str):
|
||||
return True
|
||||
|
||||
def logs(self, sid: str):
|
||||
return "1\n2\n3\n"
|
||||
|
||||
def list(self):
|
||||
return [{"a": 1}]
|
||||
|
||||
monkeypatch.setattr(tool_mod, "RayJobTool", _Tool)
|
||||
|
||||
cfg = tmp_path / "cfg.yaml"
|
||||
cfg.write_text(
|
||||
"ray:\n address: http://127.0.0.1:8265\n shared_root: /private\n entrypoint_resources: {worker_node: 1}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
spec = tmp_path / "spec.yaml"
|
||||
spec.write_text("workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n", encoding="utf-8")
|
||||
|
||||
from argus.cli.run import main
|
||||
|
||||
monkeypatch.setattr(sys, "argv", ["run.py", "--config", str(cfg), "--taskspec", str(spec), "--action", "submit"])
|
||||
assert main() == 0
|
||||
assert capsys.readouterr().out.strip() == "sid1"
|
||||
|
||||
monkeypatch.setattr(sys, "argv", ["run.py", "--config", str(cfg), "--action", "status", "--submission-id", "sid1"])
|
||||
assert main() == 0
|
||||
assert capsys.readouterr().out.strip() == "RUNNING"
|
||||
|
||||
monkeypatch.setattr(
|
||||
sys, "argv", ["run.py", "--config", str(cfg), "--action", "logs", "--submission-id", "sid1", "--tail", "1"]
|
||||
)
|
||||
assert main() == 0
|
||||
assert capsys.readouterr().out.strip() == "3"
|
||||
|
||||
monkeypatch.setattr(sys, "argv", ["run.py", "--config", str(cfg), "--action", "list"])
|
||||
assert main() == 0
|
||||
out = capsys.readouterr().out
|
||||
assert json.loads(out)[0]["a"] == 1
|
||||
|
||||
|
||||
def test_cli_requires_submission_id_for_status():
|
||||
from argus.cli.run import main
|
||||
|
||||
tmp = Path(__import__("tempfile").gettempdir()) / "mvp_test_cfg.yaml"
|
||||
tmp.write_text(
|
||||
"ray:\n address: http://127.0.0.1:8265\n shared_root: /private\n entrypoint_resources: {worker_node: 1}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
try:
|
||||
with pytest.raises(SystemExit):
|
||||
sys.argv = ["run.py", "--config", str(tmp), "--action", "status"]
|
||||
main()
|
||||
finally:
|
||||
tmp.unlink(missing_ok=True)
|
||||
42
src/mvp/py/tests/test_db.py
Normal file
@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_db_lifecycle_and_basic_queries(tmp_path: Path):
|
||||
from argus.service.db import Db
|
||||
|
||||
db = Db(str(tmp_path / "mvp.sqlite3"))
|
||||
db.init()
|
||||
|
||||
t = db.create_task(
|
||||
task_id="t1",
|
||||
workload="ppo",
|
||||
jobspec_yaml="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n",
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
assert t["task_id"] == "t1"
|
||||
assert db.get_task("t1") is not None
|
||||
|
||||
q = db.list_queue()
|
||||
assert len(q["pending"]) == 1
|
||||
|
||||
db.set_task_state(task_id="t1", state="PENDING_RESOURCES", next_run_at="2099-01-01T00:00:00Z")
|
||||
assert db.pick_next_runnable_task() is None
|
||||
|
||||
# next_run_at is sticky unless explicitly updated; a future value keeps it non-runnable.
|
||||
db.set_task_state(task_id="t1", state="QUEUED", next_run_at=None)
|
||||
assert db.pick_next_runnable_task() is None
|
||||
|
||||
# Allow it by setting next_run_at into the past.
|
||||
db.set_task_state(task_id="t1", state="QUEUED", next_run_at="2000-01-01T00:00:00Z")
|
||||
assert db.pick_next_runnable_task() is not None
|
||||
|
||||
db.create_attempt(task_id="t1", attempt_no=1, ray_submission_id="sid1")
|
||||
db.update_attempt(task_id="t1", attempt_no=1, ray_status="RUNNING")
|
||||
attempts = db.list_attempts("t1")
|
||||
assert attempts[-1]["ray_status"] == "RUNNING"
|
||||
|
||||
# No-op update is allowed
|
||||
db.update_attempt(task_id="t1", attempt_no=1)
|
||||
25
src/mvp/py/tests/test_driver_entrypoint.py
Normal file
@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_driver_entrypoint_missing_cmd_returns_2(monkeypatch, tmp_path: Path):
|
||||
from argus.ray.driver_entrypoint import main
|
||||
|
||||
monkeypatch.setattr(sys, "argv", ["x", "--job-dir", str(tmp_path)])
|
||||
assert main() == 2
|
||||
|
||||
|
||||
def test_driver_entrypoint_strips_double_dash_and_returns_code(monkeypatch, tmp_path: Path):
|
||||
from argus.ray import driver_entrypoint as mod
|
||||
|
||||
class _Proc:
|
||||
returncode = 7
|
||||
|
||||
monkeypatch.setattr(mod.subprocess, "run", lambda cmd, check: _Proc())
|
||||
monkeypatch.setattr(sys, "argv", ["x", "--job-dir", str(tmp_path), "--", "echo", "hi"])
|
||||
|
||||
assert mod.main() == 7
|
||||
assert (tmp_path).exists()
|
||||
|
||||
28
src/mvp/py/tests/test_ids.py
Normal file
@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def test_new_task_id_is_deterministic_with_patches(monkeypatch):
|
||||
import argus.core.ids as ids
|
||||
|
||||
class _FakeDatetime:
|
||||
@staticmethod
|
||||
def now():
|
||||
class _DT:
|
||||
def strftime(self, fmt: str) -> str:
|
||||
assert fmt == "%Y%m%d-%H%M%S"
|
||||
return "20250101-010203"
|
||||
|
||||
return _DT()
|
||||
|
||||
monkeypatch.setattr(ids, "datetime", _FakeDatetime)
|
||||
monkeypatch.setattr(ids.secrets, "token_hex", lambda n: "abcd")
|
||||
|
||||
assert ids.new_task_id("ppo") == "mvp2-ppo-20250101-010203-abcd"
|
||||
|
||||
|
||||
def test_attempt_submission_id_format():
|
||||
from argus.core.ids import attempt_submission_id
|
||||
|
||||
assert attempt_submission_id("t", 1) == "t--a01"
|
||||
assert attempt_submission_id("t", 12) == "t--a12"
|
||||
|
||||
107
src/mvp/py/tests/test_models.py
Normal file
@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_require_missing_raises():
|
||||
from argus.ray.models import _require
|
||||
|
||||
with pytest.raises(ValueError, match="missing required field: x"):
|
||||
_require({}, "x")
|
||||
with pytest.raises(ValueError, match="missing required field: x"):
|
||||
_require({"x": ""}, "x")
|
||||
|
||||
|
||||
def test_ray_config_from_dict_new_format_and_defaults():
|
||||
from argus.ray.models import RayConfig
|
||||
|
||||
cfg = RayConfig.from_dict(
|
||||
{
|
||||
"ray": {
|
||||
"address": "http://127.0.0.1:8265",
|
||||
"shared_root": "/private",
|
||||
"entrypoint_resources": {"worker_node": 1},
|
||||
"runtime_env": {"env_vars": {"HF_ENDPOINT": "x"}},
|
||||
}
|
||||
}
|
||||
)
|
||||
assert cfg.address.endswith("8265")
|
||||
assert cfg.shared_root == "/private"
|
||||
assert cfg.entrypoint_num_cpus == 1.0
|
||||
assert cfg.entrypoint_resources["worker_node"] == 1.0
|
||||
assert cfg.runtime_env_env_vars["HF_ENDPOINT"] == "x"
|
||||
assert cfg.user_code_path == "/private/user/code"
|
||||
|
||||
public = cfg.to_public_dict()
|
||||
assert public["runtime_env"]["env_vars"]["HF_ENDPOINT"] == "x"
|
||||
|
||||
|
||||
def test_ray_config_from_dict_requires_mappings():
|
||||
from argus.ray.models import RayConfig
|
||||
|
||||
with pytest.raises(ValueError, match="runtime_env\\.env_vars must be a mapping"):
|
||||
RayConfig.from_dict(
|
||||
{
|
||||
"address": "x",
|
||||
"shared_root": "/p",
|
||||
"entrypoint_resources": {},
|
||||
"runtime_env": {"env_vars": ["nope"]},
|
||||
}
|
||||
)
|
||||
with pytest.raises(ValueError, match="entrypoint_resources must be a mapping"):
|
||||
RayConfig.from_dict(
|
||||
{
|
||||
"address": "x",
|
||||
"shared_root": "/p",
|
||||
"entrypoint_resources": ["nope"],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_jobspec_validation_and_null_coercion():
|
||||
from argus.ray.models import JobSpec
|
||||
|
||||
spec = JobSpec.from_dict(
|
||||
{
|
||||
"workload": "ppo",
|
||||
"code_path": "/code",
|
||||
"model_id": "m",
|
||||
"train_file": "train.jsonl",
|
||||
"val_file": "null",
|
||||
"test_freq": "",
|
||||
}
|
||||
)
|
||||
assert spec.workload == "ppo"
|
||||
assert spec.val_file is None
|
||||
assert spec.test_freq is None
|
||||
assert spec.nnodes == 2
|
||||
assert spec.n_gpus_per_node == 4
|
||||
|
||||
pub = spec.to_public_dict()
|
||||
assert pub["submission_id"] == ""
|
||||
assert "trainer_device" not in pub
|
||||
|
||||
|
||||
def test_jobspec_sft_adds_trainer_device_default():
|
||||
from argus.ray.models import JobSpec
|
||||
|
||||
spec = JobSpec.from_dict(
|
||||
{
|
||||
"workload": "sft",
|
||||
"code_path": "/code",
|
||||
"model_id": "m",
|
||||
"train_file": "train.jsonl",
|
||||
}
|
||||
)
|
||||
pub = spec.to_public_dict()
|
||||
assert pub["trainer_device"] == "cpu"
|
||||
|
||||
|
||||
def test_jobspec_unsupported_workload():
|
||||
from argus.ray.models import JobSpec
|
||||
|
||||
with pytest.raises(ValueError, match="unsupported workload"):
|
||||
JobSpec.from_dict(
|
||||
{"workload": "nope", "code_path": "x", "model_id": "m", "train_file": "t"}
|
||||
)
|
||||
|
||||
162
src/mvp/py/tests/test_ray_job_tool.py
Normal file
@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from argus.ray.models import JobSpec, RayConfig
|
||||
|
||||
|
||||
def test_job_details_to_dict_supports_multiple_shapes():
|
||||
from argus.ray.ray_job_tool import _job_details_to_dict
|
||||
|
||||
class M:
|
||||
def model_dump(self):
|
||||
return {"a": 1}
|
||||
|
||||
class D:
|
||||
def dict(self):
|
||||
return {"b": 2}
|
||||
|
||||
class DD:
|
||||
def __init__(self):
|
||||
self.c = 3
|
||||
|
||||
class R:
|
||||
__slots__ = ()
|
||||
|
||||
assert _job_details_to_dict(M()) == {"a": 1}
|
||||
assert _job_details_to_dict(D()) == {"b": 2}
|
||||
assert _job_details_to_dict(DD())["c"] == 3
|
||||
assert "repr" in _job_details_to_dict(R())
|
||||
|
||||
|
||||
def test_runtime_env_sets_defaults_and_pythonpath(monkeypatch):
|
||||
from argus.ray.ray_job_tool import RayJobTool
|
||||
|
||||
cfg = RayConfig.from_dict(
|
||||
{
|
||||
"address": "http://127.0.0.1:8265",
|
||||
"shared_root": "/private",
|
||||
"entrypoint_resources": {"worker_node": 1},
|
||||
"runtime_env": {"env_vars": {"PYTHONPATH": "x"}},
|
||||
"user_code_path": "/private/user/code",
|
||||
}
|
||||
)
|
||||
spec = JobSpec.from_dict(
|
||||
{"workload": "sft", "code_path": "/c", "model_id": "m", "train_file": "t"}
|
||||
)
|
||||
monkeypatch.setenv("MVP_TOOL_CODE_PATH", "/tool")
|
||||
|
||||
tool = RayJobTool(cfg)
|
||||
env = tool._runtime_env(spec)["env_vars"]
|
||||
assert env["HF_HOME"].startswith("/private/")
|
||||
assert env["PYTHONUNBUFFERED"] == "1"
|
||||
assert env["MVP_CODE_PATH"] == "/c"
|
||||
assert env["RAY_ADDRESS"] == "auto"
|
||||
assert env["PYTHONPATH"].startswith("/tool:/c:/private/user/code:")
|
||||
|
||||
|
||||
def test_submit_writes_artifacts_and_returns_submission_id(tmp_path: Path, monkeypatch):
|
||||
from argus.ray import ray_job_tool as mod
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self, address: str):
|
||||
self.address = address
|
||||
self.last_submit_kwargs = None
|
||||
|
||||
def submit_job(self, **kwargs):
|
||||
self.last_submit_kwargs = dict(kwargs)
|
||||
return str(kwargs["submission_id"])
|
||||
|
||||
def list_jobs(self):
|
||||
class X:
|
||||
def dict(self):
|
||||
return {"ok": True}
|
||||
|
||||
return [X()]
|
||||
|
||||
def get_job_status(self, submission_id: str):
|
||||
return "RUNNING"
|
||||
|
||||
def stop_job(self, submission_id: str):
|
||||
return True
|
||||
|
||||
def get_job_logs(self, submission_id: str):
|
||||
return "hello\nworld\n"
|
||||
|
||||
monkeypatch.setattr(mod, "JobSubmissionClient", _FakeClient)
|
||||
monkeypatch.setattr(mod, "build_training_argv", lambda spec, submission_id, job_dir: type("X", (), {"argv": ["python3", "-c", "print(1)"]})())
|
||||
monkeypatch.setattr(mod.ray, "init", lambda **kwargs: (_ for _ in ()).throw(RuntimeError("no ray")))
|
||||
|
||||
cfg = RayConfig.from_dict(
|
||||
{
|
||||
"address": "http://127.0.0.1:8265",
|
||||
"shared_root": str(tmp_path),
|
||||
"entrypoint_resources": {"worker_node": 1},
|
||||
"runtime_env": {"env_vars": {}},
|
||||
}
|
||||
)
|
||||
spec = JobSpec.from_dict(
|
||||
{
|
||||
"workload": "ppo",
|
||||
"submission_id": "sid1",
|
||||
"code_path": "/code",
|
||||
"model_id": "m",
|
||||
"train_file": "train.jsonl",
|
||||
}
|
||||
)
|
||||
|
||||
tool = mod.RayJobTool(cfg)
|
||||
submitted = tool.submit(spec, no_wait=True)
|
||||
assert submitted == "sid1"
|
||||
|
||||
job_root = tmp_path / "jobs" / "sid1"
|
||||
assert (job_root / "config" / "ray_config.yaml").exists()
|
||||
assert (job_root / "config" / "jobspec.yaml").exists()
|
||||
assert (job_root / "config" / "ray_submission_id.txt").read_text(encoding="utf-8").strip() == "sid1"
|
||||
|
||||
payload = json.loads((job_root / "config" / "submit_payload.json").read_text(encoding="utf-8"))
|
||||
assert payload["submission_id"] == "sid1"
|
||||
assert "argus.ray.driver_entrypoint" in payload["entrypoint"]
|
||||
|
||||
assert (job_root / "debug" / "ray_resources_pre.error.txt").exists()
|
||||
assert (job_root / "debug" / "ray_job_list_post.json").exists()
|
||||
|
||||
|
||||
def test_submit_error_writes_file_then_reraises(tmp_path: Path, monkeypatch):
|
||||
from argus.ray import ray_job_tool as mod
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self, address: str):
|
||||
self.address = address
|
||||
|
||||
def submit_job(self, **kwargs):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
def list_jobs(self):
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(mod, "JobSubmissionClient", _FakeClient)
|
||||
monkeypatch.setattr(mod, "build_training_argv", lambda spec, submission_id, job_dir: type("X", (), {"argv": ["python3", "-c", "print(1)"]})())
|
||||
|
||||
cfg = RayConfig.from_dict(
|
||||
{
|
||||
"address": "http://127.0.0.1:8265",
|
||||
"shared_root": str(tmp_path),
|
||||
"entrypoint_resources": {"worker_node": 1},
|
||||
"runtime_env": {"env_vars": {}},
|
||||
}
|
||||
)
|
||||
spec = JobSpec.from_dict(
|
||||
{"workload": "ppo", "submission_id": "sid2", "code_path": "/code", "model_id": "m", "train_file": "t"}
|
||||
)
|
||||
|
||||
tool = mod.RayJobTool(cfg)
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
tool.submit(spec, no_wait=True)
|
||||
|
||||
err = tmp_path / "jobs" / "sid2" / "logs" / "submit.error.txt"
|
||||
assert err.exists()
|
||||
25
src/mvp/py/tests/test_ray_resources.py
Normal file
@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
def test_get_cluster_available_sums_resources(monkeypatch):
|
||||
from argus.service import ray_resources
|
||||
|
||||
monkeypatch.setattr(ray_resources.ray._private.state, "available_resources_per_node", lambda: { # type: ignore[attr-defined]
|
||||
"n1": {"GPU": 1, "NPU": 2},
|
||||
"n2": {"GPU": 0.5},
|
||||
"bad": "nope",
|
||||
})
|
||||
|
||||
avail = ray_resources.get_cluster_available()
|
||||
assert avail.total_available_gpus == 1.5
|
||||
assert avail.total_available_npus == 2.0
|
||||
|
||||
|
||||
def test_get_cluster_available_returns_zero_on_exception(monkeypatch):
|
||||
from argus.service import ray_resources
|
||||
|
||||
def _boom() -> dict[str, object]:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(ray_resources.ray._private.state, "available_resources_per_node", _boom) # type: ignore[attr-defined]
|
||||
avail = ray_resources.get_cluster_available()
|
||||
assert avail.total_available_gpus == 0.0
|
||||
203
src/mvp/py/tests/test_scheduler.py
Normal file
@ -0,0 +1,203 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from types import SimpleNamespace
|
||||
|
||||
from argus.ray.models import RayConfig
|
||||
from argus.service.config import V2Config
|
||||
from argus.service.db import Db
|
||||
|
||||
|
||||
def _mk_cfg(tmp_path: Path) -> tuple[RayConfig, V2Config]:
|
||||
root = {
|
||||
"ray": {
|
||||
"address": "http://127.0.0.1:8265",
|
||||
"shared_root": str(tmp_path),
|
||||
"entrypoint_resources": {"worker_node": 1},
|
||||
"runtime_env": {"env_vars": {}},
|
||||
},
|
||||
"service": {
|
||||
"sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
|
||||
"scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
|
||||
},
|
||||
}
|
||||
return RayConfig.from_dict(root), V2Config.from_root_dict(root)
|
||||
|
||||
|
||||
def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
|
||||
from argus.service import scheduler as sched_mod
|
||||
|
||||
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
|
||||
db = Db(v2_cfg.sqlite.db_path)
|
||||
db.init()
|
||||
db.create_task(
|
||||
task_id="t1",
|
||||
workload="ppo",
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
sched_mod,
|
||||
"get_cluster_available",
|
||||
lambda: SimpleNamespace(total_available_gpus=999.0, total_available_npus=0.0),
|
||||
)
|
||||
|
||||
class _Tool:
|
||||
def __init__(self, cfg):
|
||||
self.submitted = []
|
||||
|
||||
def submit(self, spec, no_wait: bool):
|
||||
self.submitted.append(spec.submission_id)
|
||||
return str(spec.submission_id)
|
||||
|
||||
def status(self, submission_id: str):
|
||||
return "RUNNING"
|
||||
|
||||
def logs(self, submission_id: str):
|
||||
return ""
|
||||
|
||||
monkeypatch.setattr(sched_mod, "RayJobTool", _Tool)
|
||||
|
||||
s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
|
||||
s.tick()
|
||||
|
||||
row = db.get_task("t1")
|
||||
assert row and row["state"] == "SUBMITTED"
|
||||
attempts = db.list_attempts("t1")
|
||||
assert len(attempts) == 1
|
||||
assert attempts[0]["ray_submission_id"] == "t1--a01"
|
||||
|
||||
|
||||
def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path):
|
||||
from argus.service import scheduler as sched_mod
|
||||
|
||||
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
|
||||
db = Db(v2_cfg.sqlite.db_path)
|
||||
db.init()
|
||||
db.create_task(
|
||||
task_id="t1",
|
||||
workload="ppo",
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
sched_mod,
|
||||
"get_cluster_available",
|
||||
lambda: SimpleNamespace(total_available_gpus=0.0, total_available_npus=0.0),
|
||||
)
|
||||
monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)
|
||||
|
||||
s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
|
||||
s.tick()
|
||||
row = db.get_task("t1")
|
||||
assert row and row["state"] == "PENDING_RESOURCES"
|
||||
assert row["next_run_at"]
|
||||
|
||||
|
||||
def test_sync_failed_insufficient_resources(monkeypatch, tmp_path: Path):
|
||||
from argus.service import scheduler as sched_mod
|
||||
|
||||
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
|
||||
db = Db(v2_cfg.sqlite.db_path)
|
||||
db.init()
|
||||
db.create_task(
|
||||
task_id="t1",
|
||||
workload="ppo",
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
db.create_attempt(task_id="t1", attempt_no=1, ray_submission_id="t1--a01")
|
||||
db.set_task_state(task_id="t1", state="RUNNING", latest_attempt_no=1)
|
||||
|
||||
monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
|
||||
monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)
|
||||
|
||||
s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
|
||||
|
||||
class _Tool:
|
||||
def status(self, sid: str):
|
||||
return "FAILED"
|
||||
|
||||
def logs(self, sid: str):
|
||||
# Match the service's regex exactly:
|
||||
# it expects literal backslashes and repeats of 's'/'d' (because of double-escaping).
|
||||
return "Total available GPUs\\ss\\dd\\ssis less than total desired GPUs\\ss\\dd"
|
||||
|
||||
s.tool = _Tool()
|
||||
s.tick()
|
||||
|
||||
row = db.get_task("t1")
|
||||
assert row and row["state"] == "PENDING_RESOURCES"
|
||||
attempts = db.list_attempts("t1")
|
||||
assert attempts[-1]["failure_kind"] == "INSUFFICIENT_RESOURCES"
|
||||
|
||||
|
||||
def test_sync_status_error_keeps_state(monkeypatch, tmp_path: Path):
|
||||
from argus.service import scheduler as sched_mod
|
||||
|
||||
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
|
||||
db = Db(v2_cfg.sqlite.db_path)
|
||||
db.init()
|
||||
db.create_task(
|
||||
task_id="t1",
|
||||
workload="ppo",
|
||||
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
|
||||
nnodes=2,
|
||||
n_gpus_per_node=4,
|
||||
)
|
||||
db.create_attempt(task_id="t1", attempt_no=1, ray_submission_id="t1--a01")
|
||||
db.set_task_state(task_id="t1", state="RUNNING", latest_attempt_no=1)
|
||||
|
||||
monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
|
||||
monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)
|
||||
|
||||
s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
|
||||
|
||||
class _Tool:
|
||||
def status(self, sid: str):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
s.tool = _Tool()
|
||||
s.tick()
|
||||
row = db.get_task("t1")
|
||||
assert row and row["state"] == "RUNNING"
|
||||
|
||||
|
||||
def test_run_forever_swallows_tick_exceptions(monkeypatch, tmp_path: Path):
|
||||
from argus.service import scheduler as sched_mod
|
||||
|
||||
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
|
||||
db = Db(v2_cfg.sqlite.db_path)
|
||||
db.init()
|
||||
|
||||
monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)
|
||||
s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
|
||||
|
||||
calls = {"n": 0}
|
||||
|
||||
def _tick():
|
||||
calls["n"] += 1
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(s, "tick", _tick)
|
||||
monkeypatch.setattr(sched_mod.time, "sleep", lambda _: None)
|
||||
|
||||
class _Stop:
|
||||
def __init__(self):
|
||||
self._n = 0
|
||||
|
||||
def is_set(self):
|
||||
self._n += 1
|
||||
return self._n > 1
|
||||
|
||||
s.run_forever(_Stop())
|
||||
assert calls["n"] == 1
|
||||
38
src/mvp/py/tests/test_server.py
Normal file
@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_server_main_calls_uvicorn(monkeypatch, tmp_path: Path):
|
||||
import server as server_mod
|
||||
|
||||
cfg = tmp_path / "cfg.yaml"
|
||||
cfg.write_text(
|
||||
"ray:\n address: http://127.0.0.1:8265\n shared_root: /private\n entrypoint_resources: {worker_node: 1}\n"
|
||||
"service:\n api: {host: 127.0.0.1, port: 18080}\n sqlite: {db_path: /tmp/x.sqlite3}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
got = {}
|
||||
|
||||
monkeypatch.setattr(server_mod, "create_app", lambda path: object())
|
||||
monkeypatch.setattr(server_mod.uvicorn, "run", lambda app, host, port, log_level: got.update({"host": host, "port": port}))
|
||||
monkeypatch.setattr(sys, "argv", ["server.py", "--config", str(cfg)])
|
||||
|
||||
assert server_mod.main() == 0
|
||||
assert got["host"] == "127.0.0.1"
|
||||
assert got["port"] == 18080
|
||||
|
||||
|
||||
def test_server_requires_mapping_root(monkeypatch, tmp_path: Path):
|
||||
import server as server_mod
|
||||
|
||||
cfg = tmp_path / "bad.yaml"
|
||||
cfg.write_text("- 1\n- 2\n", encoding="utf-8")
|
||||
monkeypatch.setattr(sys, "argv", ["server.py", "--config", str(cfg)])
|
||||
with pytest.raises(SystemExit, match="config yaml must be a mapping"):
|
||||
server_mod.main()
|
||||
|
||||
40
src/mvp/py/tests/test_service_config.py
Normal file
@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_v2_config_from_root_dict_new_format_defaults():
|
||||
from argus.service.config import V2Config
|
||||
|
||||
cfg = V2Config.from_root_dict(
|
||||
{
|
||||
"ray": {"shared_root": "/private"},
|
||||
"service": {
|
||||
"api": {"host": "127.0.0.1", "port": 9999},
|
||||
"auth": {"token_env": "X"},
|
||||
"sqlite": {"db_path": "/tmp/x.sqlite3"},
|
||||
"scheduler": {"tick_s": 1, "retry_interval_s": 2, "max_running_tasks": 3},
|
||||
},
|
||||
}
|
||||
)
|
||||
assert cfg.api.host == "127.0.0.1"
|
||||
assert cfg.api.port == 9999
|
||||
assert cfg.auth.token_env == "X"
|
||||
assert cfg.sqlite.db_path.endswith(".sqlite3")
|
||||
assert cfg.scheduler.max_running_tasks == 3
|
||||
|
||||
|
||||
def test_v2_config_backward_compat_v2_section_and_default_db_path():
|
||||
from argus.service.config import V2Config
|
||||
|
||||
cfg = V2Config.from_root_dict({"shared_root": "/private", "v2": {"sqlite": {}}})
|
||||
assert cfg.sqlite.db_path == "/private/common/db/mvp.sqlite3"
|
||||
|
||||
|
||||
def test_v2_config_requires_mappings():
|
||||
from argus.service.config import V2Config
|
||||
|
||||
with pytest.raises(ValueError, match="config\\.service must be a mapping"):
|
||||
V2Config.from_root_dict({"service": ["nope"]})
|
||||
with pytest.raises(ValueError, match="config\\.service\\.\\{api,auth,sqlite,scheduler\\} must be mappings"):
|
||||
V2Config.from_root_dict({"service": {"api": [1], "auth": {}, "sqlite": {}, "scheduler": {}}})
|
||||
34
src/mvp/py/tests/test_yaml_io.py
Normal file
@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_load_yaml_empty_file(tmp_path: Path):
|
||||
from argus.ray.yaml_io import load_yaml
|
||||
|
||||
p = tmp_path / "empty.yaml"
|
||||
p.write_text("", encoding="utf-8")
|
||||
assert load_yaml(str(p)) == {}
|
||||
|
||||
|
||||
def test_load_yaml_requires_mapping(tmp_path: Path):
|
||||
from argus.ray.yaml_io import load_yaml
|
||||
|
||||
p = tmp_path / "bad.yaml"
|
||||
p.write_text("- 1\n- 2\n", encoding="utf-8")
|
||||
with pytest.raises(ValueError, match="yaml root must be a mapping"):
|
||||
load_yaml(str(p))
|
||||
|
||||
|
||||
def test_dump_yaml_roundtrip_smoke(tmp_path: Path):
|
||||
from argus.ray.yaml_io import dump_yaml, load_yaml
|
||||
|
||||
text = dump_yaml({"a": 1, "b": {"c": "d"}})
|
||||
assert "a: 1" in text
|
||||
|
||||
p = tmp_path / "x.yaml"
|
||||
p.write_text(text, encoding="utf-8")
|
||||
assert load_yaml(str(p))["b"]["c"] == "d"
|
||||
|
||||