v2.0: add unit tests, raise line coverage above 90%

This commit is contained in:
yuyr 2025-12-26 10:50:33 +08:00
parent 4dacac24f0
commit ce8c2128b4
24 changed files with 2265 additions and 286 deletions

.gitignore vendored

@ -2,3 +2,7 @@ verl/
skypilot-ssh-test/
ray_in_docker/
__pycache__/
.venv/
.pytest_cache/
.coverage
htmlcov/

pytest.ini Normal file

@ -0,0 +1,3 @@
[pytest]
testpaths = src/mvp/py/tests
addopts = --maxfail=1 --cov=argus --cov=server --cov-report=term-missing --cov-fail-under=90

Binary image file added (80 KiB); preview not shown.


@ -1,96 +1,49 @@
A progressive AI Infrastructure evolution roadmap: from single-machine script execution at the start to an intelligent operations platform at the end.
This matches the architecture evolution diagram and lays out the **AI Infra Roadmap design document based on a Native Ray Cluster and the Verl framework**.
This revision adopts an **Overlay architecture + GPFS core storage + stateless node pool** model; the logic is self-consistent and highly elastic in a cloud-native sense.
---
### **Project codename: AI Infra Roadmap (Native Ray + Verl)**
### **Project codename: AI Infra Overlay Platform (Stateless Ray + GPFS)**
#### **Phase 1: Core Kernel Construction (Foundation & Core Execution)**
#### **Phase 1: Kernel Construction & Verification (Kernel & Verification)**
This phase answers "can it run at all?", focusing on integrating the core compute engine and basic task scheduling.
*Goal: validate the core compute logic and close the minimal "submit-execute" loop.*
* **v1.1: Prototype Validation (Verl Task Spec & Ray Job)**
    * **Core capability**: the most basic task-submission path.
    * **Components**:
        * **Ray Job Tool (Ray Client)**: the client-side tool.
        * **VerlTaskSpec YAML**: the standard task configuration file.
        * **Multi-Verl Code Path**: support for multiple code paths.
    * **Core capability**: basic task definition and submission.
    * **Components**:
        * `Ray Job Tool (Ray Client)`: client tool.
        * `VerlTaskSpec YAML`: defines multiple code paths (Multi-Verl Code Path) and task parameters.
    * **Infrastructure**: Handmade Ray Cluster (a manually assembled Ray cluster).
    * **Goal**: validate the basic interaction between the Verl framework and Ray.
    * **Infrastructure**: Handmade Ray Cluster (manually assembled), used to validate the core code.
* **v2.0: Task Management Layer (Task Management)**
    * **Core capability**: introduces service-based management instead of relying solely on the command-line tool.
    * **New components**:
        * **API Server**: a unified interface layer.
        * **Task Management**: task queueing (Queue), mapping (Map), and retry/resubmission (Resubmit).
    * **Core capability**: introduces a server side that manages the task lifecycle.
    * **New components**:
        * `API Server`: unified interface layer.
        * `Task Management`: task queue (Queue), mapping (Map), and retry (Resubmit) mechanisms.
    * **Infrastructure**: still runs on the Handmade Ray Cluster.
* **v2.5: Resource & User Management (User & Node Management)**
    * **Core capability**: moves from a handmade cluster toward an automated cluster and adds a multi-tenancy foundation.
    * **New components**:
        * **User Management**: user permissions and identity management.
        * **Node Management**: the key upgrade. Manages the node pool over SSH to achieve an Auto-managed Ray Cluster, removing manual maintenance.
    * **Evolution**: the infrastructure layer moves from Handmade to SSH Node (Auto Managed).
    * **Infrastructure**: still runs on the handmade cluster, but the control plane starts to become service-based.
---
### **Phase 2: Productization & Serving**
### **Phase 2: The Architectural Shift - Stateless Node Pool (The Stateless Shift)**
This phase addresses usability: the first official version is released and the business scenarios expand.
*Goal: use GPFS to achieve inversion of control (IoC) and fully decouple the platform layer from the compute nodes. This is the pivotal turning point of this architecture.*
* **v3.0: Official Release (Frontend & Data Management)** * **Milestone**: **1st Version to Release!!** (the first externally released version)
    * **Core capability**: full frontend/backend separation, closing the loop on the user data flow.
    * **New components**:
        * **WebUI**: a visual user interface.
        * **Data Management (SFTPGo)**: integrates SFTPGo to handle upload/download of users' training data and code.
    * **Value**: users can complete the whole flow from data upload to task submission through the web UI.
* **v3.5: Customization & Inference Serving (Customized Task & Serving)**
    * **Core capability**: supports more complex training needs and model inference.
    * **New components**:
        * **Model Serving**: deploys model services in addition to training.
        * **Customized VerlTaskSpec YAML**: custom parameters (Param), reward functions (Reward), Verl code, and more.
    * **Value**: expands from a training-only platform into an integrated training + inference platform, letting algorithm engineers deeply customize experiment parameters.
---
### **Phase 3: Observability**
This phase addresses visibility, ensuring system stability and traceability of model training.
* **v4.0: System Observability**
    * **Core capability**: complete infrastructure monitoring.
    * **New components**:
        * **Prometheus**: metrics collection.
        * **Grafana**: monitoring dashboards.
        * **Alert**: alerting.
        * **ELK**: log collection and analysis (Elasticsearch, Logstash, Kibana).
    * **Infrastructure upgrade**: deploys an **Exporter** on the SSH nodes to collect node-level metrics.
* **v4.5: Experiment-level Observability (ML Observability)**
    * **Core capability**: focuses on tracking metrics of the model-training process.
    * **New components**:
        * **Weights & Biases (W&B)**: integrates a dedicated ML experiment-tracking tool to record training metrics such as loss and accuracy.
* **v2.5: User Management & Stateless Ray Node Pool (User Mgmt & Stateless Ray Node Pool)** * **Core mechanism: GPFS-based Service Discovery**
    * **Ray Head (stateful)**: started by `Node Management` (typically via SSH or a K8s StatefulSet); after starting, it writes its own IP address into the `Head IP File` on GPFS.
    * **Ray Worker (stateless)**:
        * **Stateless**: worker containers start without any platform-issued instructions.
        * **Auto Connect**: the startup script reads the `Head IP File` on GPFS, obtains the head address, and joins the cluster automatically.
        * **Watchdog**: a watchdog process inside the worker monitors head IP changes; if the head changes, the worker restarts or reconnects automatically, achieving self-healing.
    * **New components**:
        * `User Management`: multi-user isolation.
        * `GPFS`: replaces the earlier JuiceFS as the sole shared storage and metadata-exchange medium.
@ -98,33 +51,83 @@
---
### **Phase 4: Intelligent Operations (Operability & Intelligence)**
### **Phase 3: Productization & Advanced Features**
This phase addresses automation, introducing AI to manage the AI platform itself.
*Goal: ship the first official release and support the complex networking and inference capabilities required by large-model training.*
* **v5.0: Intelligent Operations Loop (Statistics, SOP, Agent)**
    * **Core capability**: automated platform governance via statistics and an Agent.
    * **New components**:
        * **Statistics**: platform-level analytics (resource utilization, task success rate, etc.).
        * **SOP Tools**: standard operating procedures as tooling (automated ops scripts).
        * **Agent**: an intelligent agent, e.g. for automatic fault diagnosis, resource-scheduling optimization, or an interactive assistant.
* **v3.0: Official Release (Release v1.0)** * **Milestone**: **1st Version to Release!!**
    * **Core capability**: closes the loop on the user data flow.
    * **New components**:
        * `WebUI`: visual operation interface.
        * `Data Management (SFTPGo)`: user uploads data/code -> SFTPGo -> written to GPFS -> visible to Ray Workers.
    * **Infrastructure**: fully switches to the `Ray Worker Node` (Stateless) + `GPFS` architecture.
* **v3.5: Advanced Customization & Unified Training/Inference (Advanced Task & Serving)** * **Core capability**: supports complex research needs.
    * **New components**:
        * `Model Serving`: model inference serving.
        * `Advanced VerlTaskSpec`: custom reward functions, custom code, and checkpoint resumption (Resubmit from last checkpoint).
    * **Network enhancement**:
        * **IB Network Supporting**: InfiniBand networking for high-performance interconnect in multi-node training.
* **Vision**: build an AI infrastructure platform capable of self-management and self-healing.
---
### **Architecture Layer Summary**
### **Phase 4: Full-Stack Observability**
*Goal: open the black box and monitor both infrastructure and business metrics.*
* **v4.0: System Observability** * **Core capability**: monitor that the cluster is alive and healthy.
    * **New components**:
        * `Prometheus` + `Grafana` + `ELK`: metrics and logging platform.
        * `Exporter`: monitoring probes deployed on Ray Worker Nodes (collecting GPU/CPU/GPFS IO metrics).
* **v4.5: Algorithm-level Observability (ML Observability)** * **Core capability**: monitor whether the model is training well.
    * **New components**:
        * `Weights & Biases (W&B)`: integrates experiment tracking to record loss curves and training parameters.
| Layer | Key Components / Technologies |
| --- | --- |
| **Access (Frontend/API)** | WebUI, API Server, User Management |
| **Scheduling & Orchestration** | Task Management, Ray Job Tool (Client), Node Management |
| **Compute Engine** | Native Ray Cluster, Verl Framework (TaskSpec YAML) |
| **Data & Storage** | SFTPGo (Data Management), Model Serving |
| **Observability** | Prometheus, Grafana, ELK, Weights & Biases |
| **Ops & Intelligence** | Exporters, Statistics, SOP Tools, Agent |
---
### **Phase 5: Intelligent Operations (AIOps)**
*Goal: move toward automation and autonomy.*
* **v5.0: Intelligent Operations Loop (Operability)** * **Core capability**: reduce operations cost and improve stability.
    * **New components**:
        * `Statistics`: cluster resource-utilization reports.
        * `SOP Tools`: standard ops tooling (e.g. automatic cleanup of stale GPFS files, zombie-node detection).
        * `Agent`: an intelligent ops assistant (LLM-based log analysis and fault diagnosis).
---
### **Highlights of the New Architecture**
1. **Minimal node management**:
    * With the v2.5 **Head IP File + Watchdog** mechanism, the platform layer no longer needs to maintain a complex worker IP list or SSH connection pool.
    * **Scaling is trivial**: just increase the worker replica count at the bottom layer (K8s/Docker); new workers find the head via GPFS and join on their own.
2. **Unified data plane (GPFS)**:
    * Starting from v2.5, GPFS carries three responsibilities: **data storage** (Code/Data), **state synchronization** (Head IP), and **checkpoint storage** (Checkpoints), keeping the architecture tightly converged.
3. **Resilience**:
    * The worker **Watchdog** mechanism ensures the cluster self-heals when the head restarts or the network flaps, with no manual intervention.

File diff suppressed because it is too large.

specs/mvp/v2.5/README.md Normal file

@ -0,0 +1,14 @@
# MVP v2.5 Design: User Management & Stateless Ray Node Pool
Based on the v2.5 plan in `specs/mvp/mvp_roadmap_v2.md` and `specs/mvp/image/roadmap_v2.5.png`, this directory
provides a set of detailed design documents that are **implementable, verifiable, and iteratively deliverable**.
Core changes in v2.5:
- On top of the v2.0 task queue/scheduling/retry, introduce **User Management** (multi-user isolation, directory isolation, tokens);
- Introduce a **Stateless Ray Node Pool**: worker nodes/containers no longer need the platform to explicitly push the head address; service discovery and self-healing reconnection (watchdog) happen through shared storage (GPFS/NFS).
Documents:
- `specs/mvp/v2.5/v2.5_design.md`: overall architecture and key mechanisms (head IP file / watchdog / user isolation / task flow).
- `specs/mvp/v2.5/v2.5_api.md`: API design (users, tasks, queue, logs) and auth conventions.
- `specs/mvp/v2.5/v2.5_acceptance.md`: development/deployment/acceptance flow and verifiable criteria.


@ -0,0 +1,67 @@
# MVP v2.5 Development / Deployment / Acceptance Criteria
This document defines the verifiable loop for v2.5, ensuring every milestone can be accepted.
---
## 1. Deliverables
### 1.1 Code deliverables (suggested)
- API Server enhancements: user management + tasks linked to user_id + auth isolation
- SQLite schema migration: add users/tokens tables; add user_id to tasks
- Ray Head service discovery: write head.json and refresh it as a heartbeat
- Worker bootstrap + watchdog:
  - dev: provided as scripts (docker compose scenario)
  - prod: injectable as the container command/entrypoint
### 1.2 Documentation deliverables
- Directory layout and GPFS path conventions
- API docs (including users and multi-tenant isolation)
- Ops SOP (head restart, worker self-healing, how to troubleshoot head.json)
---
## 2. Deployment Flow (verifiable in the dev environment)
### 2.1 Startup order (recommended)
1) Start the head (API server + Ray head)
2) The head writes `/private/ray/discovery/<cluster_name>/head.json`
3) Start some workers (no head address specified)
4) Workers read head.json automatically and join the cluster
5) Create users via the API and obtain tokens
6) Submit PPO/GRPO/SFT with a user token
---
## 3. Acceptance Criteria
### 3.1 Stateless Ray Node Pool
- A1: with no head address passed at worker startup, the worker joins the cluster within `T<=60s` (visible in `ray status`)
- A2: after the head container restarts (IP change or Ray restart):
  - head.json is updated
  - the worker watchdog reconnects automatically within `T<=60s`
- A3: with the head set to `--num-gpus=0 --num-cpus=0`, the training driver never runs on the head (verifiable via the Ray dashboard/logs)
### 3.2 User Management
- U1: admin can create users and issue tokens (the token is returned only once)
- U2: a task submitted by user A cannot be queried/canceled/log-fetched by user B (the API returns 404 or 403, per the design convention)
- U3: only jobs output is isolated: task output lands under `/private/users/<user_id>/jobs/<ray_submission_id>/...`, so different users never overwrite each other
- U4: training inputs (verl code, HF cache, datasets) all use `/private/common/...` (no input isolation in v2.5)
### 3.3 Task Flow (inherited from v2.0)
- T1: all three workloads (PPO/GRPO/SFT) can be submitted and run to completion (dev scale may use epoch=1/steps=10)
- T2: on insufficient resources, a task does not fail unrecoverably; it enters `PENDING_RESOURCES` and retries at intervals (same logic as v2.0)
---
## 4. Regression Cases (minimal set)
1) Create users alice/bob; each submits an sft task; verify isolation and the output directories
2) Start the head + 2 workers, submit ppo/grpo, verify the driver lands on a worker
3) Restart the head (or point head.json at a new IP), verify the worker watchdog reconnects automatically

specs/mvp/v2.5/v2.5_api.md Normal file

@ -0,0 +1,109 @@
# MVP v2.5 API Design (User + Task + Queue)
v2.5 builds on the v2.0 API and adds **User Management** plus multi-tenant isolation.
Constraints:
- Still uses internal tokens (API keys);
- No external IAM is introduced;
- The TaskSpec is still YAML, reusing the existing structured fields.
---
## 1. Auth
Header:
- `Authorization: Bearer <api_token>`
Server behavior:
- Map `api_token` to a `user_id`;
- Subsequent task operations apply only to that `user_id` by default.
Admin token (optional):
- An additional `MVP_ADMIN_TOKEN` can be configured (or user.role=admin);
- Admin can query/cancel across users (for operations).
---
## 2. User Management
### 2.1 Create a user (admin)
`POST /api/v2/users`
Request (JSON):
```json
{"user_id":"alice","display_name":"Alice"}
```
Response:
```json
{"user_id":"alice","state":"ACTIVE"}
```
### 2.2 Issue a token for a user (admin)
`POST /api/v2/users/{user_id}/tokens`
Response (the plaintext token is returned only once):
```json
{"user_id":"alice","token":"mvp_u_..."}
```
### 2.3 Disable a user (admin)
`POST /api/v2/users/{user_id}:disable`
---
## 3. Task Management (multi-tenant)
### 3.1 Submit a task
`POST /api/v2/tasks`
Body:
- `Content-Type: application/yaml`
- raw TaskSpec YAML (training-semantics fields only; no user_id)
Response:
```json
{"task_id":"mvp25-ppo-20251225-170001-2a3f","state":"QUEUED"}
```
Server-side effects (a client sketch follows this list):
- Record tasks.user_id (derived from the token)
- Compute the output directory: `/private/users/<uid>/jobs/<ray_submission_id>/...`
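For illustration only, a minimal client sketch using `httpx` (already pinned in the test requirements). The base URL, token value, and the TaskSpec field values are assumptions for the example, not fixed by this API doc.

```python
import httpx

# Assumed values for illustration; adjust to your deployment.
BASE_URL = "http://127.0.0.1:8080"
TOKEN = "mvp_u_example"

taskspec_yaml = """\
workload: ppo
code_path: /private/common/code/verl
model_id: demo-model
train_file: /private/common/datasets/train.jsonl
"""

headers = {
    "Authorization": f"Bearer {TOKEN}",
    "Content-Type": "application/yaml",
}

with httpx.Client(base_url=BASE_URL, headers=headers, timeout=30) as client:
    # Submit: the server derives user_id from the token and computes the output dir.
    resp = client.post("/api/v2/tasks", content=taskspec_yaml)
    resp.raise_for_status()
    task_id = resp.json()["task_id"]

    # Query (owner only); a task belonging to another user would return 404 per the design.
    print(client.get(f"/api/v2/tasks/{task_id}").json())
```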
### 3.2 Get a task (owner only)
`GET /api/v2/tasks/{task_id}`
If the task does not belong to the current user:
- return `404` (to avoid leaking its existence)
### 3.3 Cancel a task (owner only)
`POST /api/v2/tasks/{task_id}:cancel`
---
## 4. Queue/Debug
### 4.1 View the queue (own view)
`GET /api/v2/queue`
Returns this user's pending/running lists.
### 4.2 View the global queue (admin)
`GET /api/v2/admin/queue`
---
## 5. Logs
`GET /api/v2/tasks/{task_id}/logs?attempt=latest&tail=2000`
Behavior matches v2.0: passes through the tail of the Ray Job logs.


@ -0,0 +1,255 @@
# MVP v2.5 Detailed Design (User Management + Stateless Ray Node Pool)
Goal of this document: turn the v2.5 ideas from `mvp_roadmap_v2.md` into an **engineering-ready** design, covering:
- user management added inside the API Server;
- a Ray node pool that becomes stateless (workers discover the head, join automatically, and self-heal via a watchdog);
- the v2.0 "task management layer" semantics kept intact: Task/Attempt, queue, resource checks, Ray Job submission and state sync;
- all shared data/state landing on GPFS (the dev environment may start with NFS), with the in-container path unified as `/private/`.
> Terminology: "GPFS" in this document stands for the production shared storage; the dev environment may use NFS, but containers still access it via `/private/`.
---
## 1. Goals and Non-goals
### 1.1 v2.5 goals (Must)
1) **User Management (minimal multi-tenancy)**
- Create/disable users;
- Issue each user an internal token (API key) for authentication and isolation;
- User isolation (v2.5 keeps the minimal loop and only isolates **jobs output** and API visibility):
- Users can only see/operate their own Tasks;
- Training output (job root, checkpoints, archived logs, etc.) lands under per-user directories;
- Training inputs (verl code, HF cache, datasets) all use `common/` (v2.5 does not support per-user isolation of custom code/models/datasets).
2) **Stateless Ray Worker Node Pool**
- Worker containers start without the platform telling them the head address;
- Workers read the **Head IP File** from GPFS and connect to the Ray head automatically;
- A watchdog inside each worker monitors head-address changes and reconnects automatically via `ray stop` + `ray start`;
- Workers avoid local persistent state wherever possible (they can be rebuilt transparently after a crash or replacement).
3) **Preserve the v2.0 task-management behavior**
- The Task/Attempt model is unchanged (or extended in a backward-compatible way);
- Align with verl's fail-fast behavior: on insufficient resources the service marks the task pending and retries;
- Ray Jobs are still submitted via the Ray Python SDK (JobSubmissionClient).
### 1.2 v2.5 non-goals (Not Now)
- Full WebUI (deferred to v3.0).
- Fair scheduling / quotas / priorities (deferred to v3.x).
- Full production-grade IAM (deferred to v4+; v2.5 uses internal tokens only).
- Native K8s orchestration (not required at this stage, but the design must fit the mode where the compute platform launches containers and they can only be managed by SSHing in).
---
## 2. Overall Architecture (matching roadmap v2.5)
### 2.1 Component breakdown
**Control Plane**
- **API Server**
- user management
- task management (queue / scheduling / retry / state aggregation)
- Ray Job Tool (Ray Client)
- VerlTaskSpec (TaskSpec YAML, reusing the v2.0/v2.1 format)
- co-locating it with the Ray head on the same machine/container is the recommended form (easier access to the dashboard / job server)
- **Ray Head (stateful)**
- after startup, writes the head address into the Head IP File on GPFS for worker service discovery
**Data Plane**
- **Ray Workers (stateless node pool)**
- stateless bootstrap: read the head address from GPFS and join the cluster automatically
- watchdog: keep watching the head-address file for changes and self-heal by reconnecting
**Shared Storage (GPFS)**
- unified data paths: datasets, model cache, code, task output, and the head service-discovery file.
### 2.2 Inversion of control (IoC) in v2.5
Key differences from v2.0 / the handmade cluster:
- v2.0: platform scripts/operators explicitly start workers and pass `--address=<head>`;
- v2.5: workers read `head_ip_file` from GPFS by themselves; the platform no longer maintains a worker list or an SSH connection pool.
---
## 3. GPFS Directory Layout (in-container `/private`)
v2.5 should pin down the following directories (a backward-compatible extension of the existing v2.0 layout):
```
/private/
  ray/
    discovery/
      <cluster_name>/
        head.json        # Head IP File (service discovery)
        head.json.lock   # optional write lock (can be skipped in v2.5)
  users/
    <user_id>/
      jobs/              # /private/users/<uid>/jobs/<ray_submission_id>/*
      outputs/           # aggregated training output (as needed)
  common/
    code/                # platform/shared code snapshots (verl code snapshot, etc.)
    datasets/            # shared datasets
    hf/                  # shared HF cache (reused in dev)
    db/                  # sqlite
    logs/                # API logs, platform logs
```
Notes:
- `common/`: the platform's default directory (writable by all users in v2.5; ACLs / read-only come later).
- `users/<user_id>/...`: the primary user-isolation boundary (the key to minimal multi-tenancy).
---
## 4. Head IP File (Service Discovery) Design
### 4.1 File path
- `head_ip_file = /private/ray/discovery/<cluster_name>/head.json`
- `<cluster_name>`: set by configuration (e.g. `argus-ray`), so the same GPFS can host multiple environments/clusters.
### 4.2 File content (JSON)
JSON is recommended (easy to extend):
```json
{
  "cluster_name": "argus-ray",
  "head_ip": "10.0.0.12",
  "gcs_port": 6379,
  "dashboard_port": 8265,
  "job_server_url": "http://10.0.0.12:8265",
  "updated_at": "2025-12-25T17:00:00Z",
  "expires_at": "2025-12-25T17:01:00Z"
}
```
Key points:
- `updated_at`: aids troubleshooting and observability;
- `expires_at`: prevents workers from endlessly reconnecting to a stale head address;
- `job_server_url`: can be consumed directly by the Ray Job Tool configuration (plug-and-play).
### 4.3 Write strategy (atomic update)
When the head writes, workers must never observe a half-written file (a minimal sketch follows this list):
- write a temporary file `head.json.tmp`;
- `fsync` (optional);
- `rename(head.json.tmp -> head.json)` (atomic replacement).
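A minimal sketch of the atomic publish step, assuming the head process knows its own IP and uses the ports from the example above; the helper name `publish_head_file` is illustrative.

```python
import json
import os
import tempfile
from datetime import datetime, timedelta, timezone


def publish_head_file(path: str, head_ip: str, ttl_s: int = 60) -> None:
    """Write head.json atomically: tmp file + optional fsync + rename (os.replace)."""
    now = datetime.now(timezone.utc)
    doc = {
        "cluster_name": "argus-ray",  # assumed cluster name
        "head_ip": head_ip,
        "gcs_port": 6379,
        "dashboard_port": 8265,
        "job_server_url": f"http://{head_ip}:8265",
        "updated_at": now.isoformat(),
        "expires_at": (now + timedelta(seconds=ttl_s)).isoformat(),
    }
    os.makedirs(os.path.dirname(path), exist_ok=True)
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path), suffix=".tmp")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        json.dump(doc, f)
        f.flush()
        os.fsync(f.fileno())  # optional, per the list above
    os.replace(tmp, path)     # atomic rename within the same filesystem
```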
### 4.4 Heartbeat and TTL
The head process must refresh `head.json` periodically (see the sketch after this list):
- suggested `ttl_s=60` with a refresh period of `refresh_s=10`;
- if the head process exits abnormally, a worker that reads an expired file can enter a "wait" mode instead of reconnecting forever.
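A heartbeat sketch building on the `publish_head_file` helper above; it would typically run in a background thread of the head process, and the interval values simply follow this section's suggestion.

```python
import time


def heartbeat_loop(path: str, head_ip: str, ttl_s: int = 60, refresh_s: int = 10) -> None:
    """Keep head.json fresh; workers treat an expired file as 'head unknown, wait'."""
    while True:
        publish_head_file(path, head_ip, ttl_s=ttl_s)  # from the previous sketch
        time.sleep(refresh_s)
```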
---
## 5. Stateless Worker Bootstrap + Watchdog
### 5.1 Startup sequence (inside the worker container)
1) The startup script reads `head.json`:
- if the file does not exist: sleep + retry (until it exists)
- if it exists but `expires_at` has passed: sleep + retry (until it becomes fresh again)
2) Parse `head_ip:gcs_port` and run:
- `ray stop --force || true`
- `ray start --address=<head_ip>:<gcs_port> --resources='{"worker_node": 100, ...}' ...`
3) Start the watchdog process (same container):
- poll/watch `head.json` for content changes
- once `head_ip` or `gcs_port` changes, trigger `ray stop` + `ray start` to reconnect (a bootstrap sketch follows this list)
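A bootstrap sketch under the assumptions above (the `ray` CLI is on PATH, the cluster name is `argus-ray`); the resource value mirrors the example flag in step 2.

```python
import json
import subprocess
import time
from datetime import datetime, timezone

HEAD_FILE = "/private/ray/discovery/argus-ray/head.json"  # assumed cluster name


def read_fresh_head(path: str, poll_s: int = 5) -> dict:
    """Block until head.json exists and its expires_at has not passed."""
    while True:
        try:
            with open(path, encoding="utf-8") as f:
                doc = json.load(f)
            expires = datetime.fromisoformat(doc["expires_at"].replace("Z", "+00:00"))
            if expires > datetime.now(timezone.utc):
                return doc
        except (FileNotFoundError, ValueError, KeyError):
            pass  # missing or half-readable file: wait and retry
        time.sleep(poll_s)


def join_cluster(doc: dict) -> None:
    """ray stop + ray start --address=<head_ip>:<gcs_port>, as in step 2."""
    subprocess.run(["ray", "stop", "--force"], check=False)
    subprocess.run(
        [
            "ray", "start",
            f"--address={doc['head_ip']}:{doc['gcs_port']}",
            '--resources={"worker_node": 100}',
        ],
        check=True,
    )


if __name__ == "__main__":
    join_cluster(read_fresh_head(HEAD_FILE))
```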
### 5.2 Watchdog strategy (minimal viable)
v2.5 recommends a simple-and-stable implementation (sketched below):
- polling interval `watch_s=5`;
- compare the `updated_at` of `head.json`, or a hash of its content;
- on change: reconnect;
- on repeated reconnect failures: back off exponentially (v2.5 may use a fixed backoff first; enhance in v2.6).
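A polling watchdog sketch reusing `read_fresh_head` and `join_cluster` from the bootstrap sketch above; the interval and the fixed backoff follow this section, and the exact function names remain illustrative.

```python
import subprocess
import time

# read_fresh_head, join_cluster, HEAD_FILE: see the bootstrap sketch in section 5.1.


def watchdog(path: str = HEAD_FILE, watch_s: int = 5, backoff_s: int = 30) -> None:
    """Reconnect when head_ip/gcs_port change; fixed backoff on failures (v2.5)."""
    current = read_fresh_head(path)
    join_cluster(current)
    key = (current["head_ip"], current["gcs_port"])
    while True:
        time.sleep(watch_s)
        latest = read_fresh_head(path)
        new_key = (latest["head_ip"], latest["gcs_port"])
        if new_key != key:
            try:
                join_cluster(latest)
                key = new_key
            except subprocess.CalledProcessError:
                time.sleep(backoff_s)  # exponential backoff can replace this in v2.6
```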
### 5.3 Resource labels (forcing the driver onto a worker)
Continue the v2.0 approach (see the submission sketch after this list):
- workers start with `--resources='{"worker_node": 100}'`;
- the head does not expose the `worker_node` resource;
- at Ray job submission, set entrypoint_resources to `{"worker_node": 1}`.
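A submission sketch using Ray's `JobSubmissionClient` (current Ray releases accept `entrypoint_resources` on `submit_job`); the job server URL and the entrypoint command are placeholders.

```python
from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")  # placeholder job server URL

submission_id = client.submit_job(
    entrypoint="python3 -c \"print('hello from a worker')\"",  # placeholder command
    # Requires 1 unit of the custom "worker_node" resource; only workers advertise it
    # (--resources='{"worker_node": 100}'), so the driver cannot land on the head.
    entrypoint_resources={"worker_node": 1},
)
print(submission_id)
```

Both sides matter: the workers must advertise the custom resource and the submission must request it; either one alone does not keep the driver off the head.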
### 5.4 "Stateless" constraints for GPU/CPU
- Whether a worker has GPUs is decided by the underlying compute platform (in production the platform mounts GPUs into the container);
- The worker startup script must not hard-code GPU indices; it relies only on `NVIDIA_VISIBLE_DEVICES` / platform injection;
- The head should run with `--num-gpus=0 --num-cpus=0` so that training is never scheduled onto it.
---
## 6. User Management Design (minimal multi-tenancy)
### 6.1 Data model (SQLite)
Two new tables (illustrative; a DDL sketch follows this list):
- `users`
  - `user_id` (PK)
  - `display_name`
  - `state` (ACTIVE/DISABLED)
  - `created_at`
- `api_tokens`
  - `token_hash` (PK)
  - `user_id` (FK)
  - `created_at`
  - `last_used_at`
And add to the `tasks` table:
- `user_id` (FK)
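For illustration, a migration sketch under the assumption that the v2.0 `tasks` table already exists; column types, constraints, and the helper name `migrate` are not fixed by this design.

```python
import sqlite3

SCHEMA = """
CREATE TABLE IF NOT EXISTS users (
    user_id      TEXT PRIMARY KEY,
    display_name TEXT,
    state        TEXT NOT NULL DEFAULT 'ACTIVE',  -- ACTIVE / DISABLED
    created_at   TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS api_tokens (
    token_hash   TEXT PRIMARY KEY,                -- never store the plaintext token
    user_id      TEXT NOT NULL REFERENCES users(user_id),
    created_at   TEXT NOT NULL,
    last_used_at TEXT
);
"""


def migrate(db_path: str) -> None:
    with sqlite3.connect(db_path) as conn:
        conn.executescript(SCHEMA)
        try:
            # Assumes the v2.0 `tasks` table exists; skip if the column is already there.
            conn.execute("ALTER TABLE tasks ADD COLUMN user_id TEXT REFERENCES users(user_id)")
        except sqlite3.OperationalError:
            pass
```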
### 6.2 Auth strategy
Internal token mode (a lookup sketch follows this list):
- `Authorization: Bearer <token>`;
- the server maps the token to a `user_id`;
- all subsequent task queries/cancellations/log fetches are scoped to that `user_id` by default.
Admin capabilities (minimal v2.5 implementation):
- configure one extra admin token (or mark a specific user as admin);
- admin can list all users/tasks (for ops troubleshooting).
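A minimal sketch of the token-to-user mapping implied by `token_hash`; the SHA-256 choice and the helper names are assumptions, not part of the design.

```python
from __future__ import annotations

import hashlib
import sqlite3


def hash_token(token: str) -> str:
    # Store only the digest; the plaintext token is shown to the user once at issue time.
    return hashlib.sha256(token.encode("utf-8")).hexdigest()


def resolve_user(conn: sqlite3.Connection, bearer_token: str) -> str | None:
    row = conn.execute(
        "SELECT user_id FROM api_tokens WHERE token_hash = ?",
        (hash_token(bearer_token),),
    ).fetchone()
    return row[0] if row else None
```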
### 6.3 Per-user directory isolation (path constraints)
Core principles (v2.5):
- **Output**: must land under `/private/users/<uid>/jobs/...` (computed uniformly by the server; users may not choose an arbitrary output root)
- **Input**: always `/private/common/...` (v2.5 supports neither user-supplied verl code nor per-user isolation of hf/datasets)
Server-side handling (minimal viable; a validation sketch follows this list):
- after parsing the TaskSpec, check input-path fields against a whitelist prefix (they must be under `/private/common/...`; reject `../` and out-of-bounds paths);
- the output directory is always computed by the server: `job_root = /private/users/<uid>/jobs/<ray_submission_id>/`.
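A sketch of the prefix whitelist check and the server-computed `job_root`; the allowed prefix and path shape come from this section, while the function names are illustrative.

```python
from pathlib import PurePosixPath

ALLOWED_INPUT_PREFIX = PurePosixPath("/private/common")


def validate_input_path(raw: str) -> str:
    """Reject relative paths, '..' segments, and anything outside /private/common/."""
    p = PurePosixPath(raw)
    if not p.is_absolute() or ".." in p.parts:
        raise ValueError(f"invalid input path: {raw}")
    if ALLOWED_INPUT_PREFIX not in (p, *p.parents):
        raise ValueError(f"input path must live under {ALLOWED_INPUT_PREFIX}: {raw}")
    return str(p)


def job_root(user_id: str, ray_submission_id: str) -> str:
    # Always computed server-side; the user never supplies an output root.
    return f"/private/users/{user_id}/jobs/{ray_submission_id}/"
```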
---
## 7. TaskSpec (VerlTaskSpec YAML) Extension Points in v2.5
v2.5 **does not extend the TaskSpec**: the YAML structured fields and their semantics stay identical to v2.0/v2.1.
The "user semantics" of v2.5 show up only as server-side fill-ins and constraints:
- user_id is derived from the token (users never write user_id in the YAML);
- the server derives `ray_submission_id` (from task_id/attempt);
- the server always computes the output directory `job_root=/private/users/<uid>/jobs/<ray_submission_id>/...`;
- v2.5 does not support user-defined verl code paths (so runtime_env does not need to inject a user code directory).
---
## 8. Migration and Compatibility
The v2.5 design must ensure that:
- the existing v2.0 "manually started workers" still work (as a dev fallback);
- without changing the image, the worker watchdog can be injected as the container start command/entrypoint (scripts in dev; in production the compute platform specifies the command).
---
## 9. Risks and Mitigations (v2.5)
1) **Consistency/latency of head.json on GPFS**
- Mitigation: atomic rename + TTL; watchdog polling.
2) **Job-server URL changes after a Ray head restart**
- Mitigation: write `job_server_url` into head.json; the Ray Job Tool can read that file to refresh its address (dynamic reload can come in v2.6).
3) **Task churn while workers reconnect**
- Mitigation: the service-side scheduler aligns with verl's resource fail-fast; failed tasks can be attributed and re-queued for retry.


@ -0,0 +1,4 @@
pytest==8.4.1
pytest-cov==6.3.0
httpx==0.28.1


@ -0,0 +1,77 @@
from __future__ import annotations
import os
import sys
import types
from pathlib import Path
def _ensure_mvp_py_on_path() -> None:
repo_root = Path(__file__).resolve().parents[4]
py_root = repo_root / "src" / "mvp" / "py"
if str(py_root) not in sys.path:
sys.path.insert(0, str(py_root))
def _install_ray_stub() -> None:
if "ray" in sys.modules:
return
ray = types.ModuleType("ray")
ray.__path__ = [] # type: ignore[attr-defined]
def _init(*args: object, **kwargs: object) -> None:
return None
ray.init = _init # type: ignore[attr-defined]
ray.cluster_resources = lambda: {} # type: ignore[attr-defined]
ray.available_resources = lambda: {} # type: ignore[attr-defined]
sys.modules["ray"] = ray
job_submission = types.ModuleType("ray.job_submission")
job_submission.__path__ = [] # type: ignore[attr-defined]
class JobSubmissionClient: # minimal stub; tests can monkeypatch methods
def __init__(self, address: str):
self.address = address
def submit_job(self, **kwargs: object) -> str:
raise NotImplementedError
def get_job_status(self, submission_id: str) -> str:
raise NotImplementedError
def stop_job(self, submission_id: str) -> bool:
raise NotImplementedError
def get_job_logs(self, submission_id: str) -> str:
raise NotImplementedError
def list_jobs(self) -> list[object]:
return []
job_submission.JobSubmissionClient = JobSubmissionClient # type: ignore[attr-defined]
sys.modules["ray.job_submission"] = job_submission
ray.job_submission = job_submission # type: ignore[attr-defined]
private = types.ModuleType("ray._private")
private.__path__ = [] # type: ignore[attr-defined]
state = types.ModuleType("ray._private.state")
def available_resources_per_node() -> dict[str, object]:
return {}
state.available_resources_per_node = available_resources_per_node # type: ignore[attr-defined]
sys.modules["ray._private"] = private
sys.modules["ray._private.state"] = state
private.state = state # type: ignore[attr-defined]
ray._private = private # type: ignore[attr-defined]
_ensure_mvp_py_on_path()
_install_ray_stub()
def pytest_configure(config: object) -> None:
os.environ.setdefault("PYTHONUTF8", "1")


@ -0,0 +1,166 @@
from __future__ import annotations
import os
from pathlib import Path
import pytest
import yaml
from fastapi.testclient import TestClient
def _write_config(tmp_path: Path) -> Path:
cfg = {
"ray": {
"address": "http://127.0.0.1:8265",
"shared_root": str(tmp_path),
"entrypoint_resources": {"worker_node": 1},
"runtime_env": {"env_vars": {}},
},
"service": {
"api": {"host": "127.0.0.1", "port": 0},
"auth": {"token_env": "MVP_INTERNAL_TOKEN"},
"sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
"scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
},
}
p = tmp_path / "cfg.yaml"
p.write_text(yaml.safe_dump(cfg), encoding="utf-8")
return p
def test_auth_requires_token_env(tmp_path: Path, monkeypatch):
from argus.service import app as app_mod
cfg_path = _write_config(tmp_path)
monkeypatch.delenv("MVP_INTERNAL_TOKEN", raising=False)
class _Scheduler:
def __init__(self, **kwargs):
self.tool = object()
def run_forever(self, stop_flag):
return None
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
app = app_mod.create_app(str(cfg_path))
with TestClient(app) as c:
r = c.get("/api/v2/queue")
assert r.status_code == 500
def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
from argus.service import app as app_mod
cfg_path = _write_config(tmp_path)
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
monkeypatch.setattr(app_mod, "new_task_id", lambda workload: "tid1")
class _Tool:
def __init__(self):
self.stopped = []
def stop(self, sid: str):
self.stopped.append(sid)
return True
def logs(self, sid: str):
return "a\nb\nc\n"
class _Scheduler:
def __init__(self, **kwargs):
self.tool = _Tool()
def run_forever(self, stop_flag):
return None
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
app = app_mod.create_app(str(cfg_path))
headers = {"authorization": "Bearer token1"}
with TestClient(app) as c:
r = c.post(
"/api/v2/tasks",
headers=headers,
data="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n",
)
assert r.status_code == 200
assert r.json()["task_id"] == "tid1"
r2 = c.get("/api/v2/tasks/tid1", headers=headers)
assert r2.status_code == 200
assert r2.json()["desired_resources"]["total_gpus"] == 8
r3 = c.get("/api/v2/queue", headers=headers)
assert r3.status_code == 200
assert "pending" in r3.json()
r4 = c.post("/api/v2/tasks/tid1:cancel", headers=headers)
assert r4.status_code == 200
assert r4.json()["state"] == "CANCELED"
# Seed an attempt then fetch logs
from argus.service.db import Db
from argus.service.config import V2Config
from argus.ray.models import RayConfig
root = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
ray_cfg = RayConfig.from_dict(root)
v2_cfg = V2Config.from_root_dict(root)
db = Db(v2_cfg.sqlite.db_path)
db.create_task(
task_id="tid2",
workload="ppo",
jobspec_yaml="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n",
nnodes=2,
n_gpus_per_node=4,
)
db.create_attempt(task_id="tid2", attempt_no=1, ray_submission_id="sid2")
db.set_task_state(task_id="tid2", state="RUNNING", latest_attempt_no=1)
r5 = c.get("/api/v2/tasks/tid2/logs?tail=1", headers=headers)
assert r5.status_code == 200
assert r5.text.strip() == "c"
def test_submit_rejects_non_mapping_jobspec(tmp_path: Path, monkeypatch):
from argus.service import app as app_mod
cfg_path = _write_config(tmp_path)
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
class _Scheduler:
def __init__(self, **kwargs):
self.tool = object()
def run_forever(self, stop_flag):
return None
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
app = app_mod.create_app(str(cfg_path))
with TestClient(app) as c:
r = c.post("/api/v2/tasks", headers={"authorization": "Bearer token1"}, data="- 1\n- 2\n")
assert r.status_code == 400
def test_submit_rejects_invalid_jobspec(tmp_path: Path, monkeypatch):
from argus.service import app as app_mod
cfg_path = _write_config(tmp_path)
monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
class _Scheduler:
def __init__(self, **kwargs):
self.tool = object()
def run_forever(self, stop_flag):
return None
monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
app = app_mod.create_app(str(cfg_path))
with TestClient(app) as c:
r = c.post("/api/v2/tasks", headers={"authorization": "Bearer token1"}, data="workload: nope\n")
assert r.status_code == 400


@ -0,0 +1,53 @@
from __future__ import annotations
import pytest
from argus.ray.builders import build_training_argv
from argus.ray.models import JobSpec
def _base_spec(workload: str) -> JobSpec:
return JobSpec(
workload=workload,
submission_id=None,
code_path="/code",
model_id="m",
train_file="train.jsonl",
val_file=None,
nnodes=2,
n_gpus_per_node=4,
total_epochs=1,
total_training_steps=10,
save_freq=10,
test_freq=None,
trainer_device=None,
)
def test_build_training_argv_ppo_smoke():
spec = _base_spec("ppo")
built = build_training_argv(spec, submission_id="sid", job_dir="/job")
assert built.argv[:3] == ["python3", "-m", "verl.trainer.main_ppo"]
assert "data.val_files=null" in built.argv
assert "trainer.test_freq=-1" in built.argv
def test_build_training_argv_grpo_has_override():
spec = _base_spec("grpo")
built = build_training_argv(spec, submission_id="sid", job_dir="/job")
assert "algorithm.adv_estimator=grpo" in built.argv
def test_build_training_argv_sft_smoke():
spec = _base_spec("sft")
built = build_training_argv(spec, submission_id="sid", job_dir="/job")
assert built.argv[:3] == ["python3", "-m", "verl.trainer.sft_trainer_ray"]
assert "trainer.device=cpu" in built.argv
assert "data.val_files=null" in built.argv
def test_build_training_argv_unsupported_raises():
spec = _base_spec("bad")
with pytest.raises(ValueError, match="unsupported workload"):
build_training_argv(spec, submission_id="sid", job_dir="/job")


@ -0,0 +1,77 @@
from __future__ import annotations
import json
import sys
from pathlib import Path
import pytest
def test_cli_submit_status_logs_list(monkeypatch, tmp_path: Path, capsys):
from argus.ray import ray_job_tool as tool_mod
class _Tool:
def __init__(self, cfg):
self.cfg = cfg
def submit(self, spec, no_wait: bool):
return "sid1"
def status(self, sid: str):
return "RUNNING"
def stop(self, sid: str):
return True
def logs(self, sid: str):
return "1\n2\n3\n"
def list(self):
return [{"a": 1}]
monkeypatch.setattr(tool_mod, "RayJobTool", _Tool)
cfg = tmp_path / "cfg.yaml"
cfg.write_text(
"ray:\n address: http://127.0.0.1:8265\n shared_root: /private\n entrypoint_resources: {worker_node: 1}\n",
encoding="utf-8",
)
spec = tmp_path / "spec.yaml"
spec.write_text("workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n", encoding="utf-8")
from argus.cli.run import main
monkeypatch.setattr(sys, "argv", ["run.py", "--config", str(cfg), "--taskspec", str(spec), "--action", "submit"])
assert main() == 0
assert capsys.readouterr().out.strip() == "sid1"
monkeypatch.setattr(sys, "argv", ["run.py", "--config", str(cfg), "--action", "status", "--submission-id", "sid1"])
assert main() == 0
assert capsys.readouterr().out.strip() == "RUNNING"
monkeypatch.setattr(
sys, "argv", ["run.py", "--config", str(cfg), "--action", "logs", "--submission-id", "sid1", "--tail", "1"]
)
assert main() == 0
assert capsys.readouterr().out.strip() == "3"
monkeypatch.setattr(sys, "argv", ["run.py", "--config", str(cfg), "--action", "list"])
assert main() == 0
out = capsys.readouterr().out
assert json.loads(out)[0]["a"] == 1
def test_cli_requires_submission_id_for_status():
from argus.cli.run import main
tmp = Path(__import__("tempfile").gettempdir()) / "mvp_test_cfg.yaml"
tmp.write_text(
"ray:\n address: http://127.0.0.1:8265\n shared_root: /private\n entrypoint_resources: {worker_node: 1}\n",
encoding="utf-8",
)
try:
with pytest.raises(SystemExit):
sys.argv = ["run.py", "--config", str(tmp), "--action", "status"]
main()
finally:
tmp.unlink(missing_ok=True)


@ -0,0 +1,42 @@
from __future__ import annotations
from pathlib import Path
def test_db_lifecycle_and_basic_queries(tmp_path: Path):
from argus.service.db import Db
db = Db(str(tmp_path / "mvp.sqlite3"))
db.init()
t = db.create_task(
task_id="t1",
workload="ppo",
jobspec_yaml="workload: ppo\ncode_path: /c\nmodel_id: m\ntrain_file: t\n",
nnodes=2,
n_gpus_per_node=4,
)
assert t["task_id"] == "t1"
assert db.get_task("t1") is not None
q = db.list_queue()
assert len(q["pending"]) == 1
db.set_task_state(task_id="t1", state="PENDING_RESOURCES", next_run_at="2099-01-01T00:00:00Z")
assert db.pick_next_runnable_task() is None
# next_run_at is sticky unless explicitly updated; a future value keeps it non-runnable.
db.set_task_state(task_id="t1", state="QUEUED", next_run_at=None)
assert db.pick_next_runnable_task() is None
# Allow it by setting next_run_at into the past.
db.set_task_state(task_id="t1", state="QUEUED", next_run_at="2000-01-01T00:00:00Z")
assert db.pick_next_runnable_task() is not None
db.create_attempt(task_id="t1", attempt_no=1, ray_submission_id="sid1")
db.update_attempt(task_id="t1", attempt_no=1, ray_status="RUNNING")
attempts = db.list_attempts("t1")
assert attempts[-1]["ray_status"] == "RUNNING"
# No-op update is allowed
db.update_attempt(task_id="t1", attempt_no=1)


@ -0,0 +1,25 @@
from __future__ import annotations
import sys
from pathlib import Path
def test_driver_entrypoint_missing_cmd_returns_2(monkeypatch, tmp_path: Path):
from argus.ray.driver_entrypoint import main
monkeypatch.setattr(sys, "argv", ["x", "--job-dir", str(tmp_path)])
assert main() == 2
def test_driver_entrypoint_strips_double_dash_and_returns_code(monkeypatch, tmp_path: Path):
from argus.ray import driver_entrypoint as mod
class _Proc:
returncode = 7
monkeypatch.setattr(mod.subprocess, "run", lambda cmd, check: _Proc())
monkeypatch.setattr(sys, "argv", ["x", "--job-dir", str(tmp_path), "--", "echo", "hi"])
assert mod.main() == 7
assert (tmp_path).exists()


@ -0,0 +1,28 @@
from __future__ import annotations
def test_new_task_id_is_deterministic_with_patches(monkeypatch):
import argus.core.ids as ids
class _FakeDatetime:
@staticmethod
def now():
class _DT:
def strftime(self, fmt: str) -> str:
assert fmt == "%Y%m%d-%H%M%S"
return "20250101-010203"
return _DT()
monkeypatch.setattr(ids, "datetime", _FakeDatetime)
monkeypatch.setattr(ids.secrets, "token_hex", lambda n: "abcd")
assert ids.new_task_id("ppo") == "mvp2-ppo-20250101-010203-abcd"
def test_attempt_submission_id_format():
from argus.core.ids import attempt_submission_id
assert attempt_submission_id("t", 1) == "t--a01"
assert attempt_submission_id("t", 12) == "t--a12"


@ -0,0 +1,107 @@
from __future__ import annotations
import pytest
def test_require_missing_raises():
from argus.ray.models import _require
with pytest.raises(ValueError, match="missing required field: x"):
_require({}, "x")
with pytest.raises(ValueError, match="missing required field: x"):
_require({"x": ""}, "x")
def test_ray_config_from_dict_new_format_and_defaults():
from argus.ray.models import RayConfig
cfg = RayConfig.from_dict(
{
"ray": {
"address": "http://127.0.0.1:8265",
"shared_root": "/private",
"entrypoint_resources": {"worker_node": 1},
"runtime_env": {"env_vars": {"HF_ENDPOINT": "x"}},
}
}
)
assert cfg.address.endswith("8265")
assert cfg.shared_root == "/private"
assert cfg.entrypoint_num_cpus == 1.0
assert cfg.entrypoint_resources["worker_node"] == 1.0
assert cfg.runtime_env_env_vars["HF_ENDPOINT"] == "x"
assert cfg.user_code_path == "/private/user/code"
public = cfg.to_public_dict()
assert public["runtime_env"]["env_vars"]["HF_ENDPOINT"] == "x"
def test_ray_config_from_dict_requires_mappings():
from argus.ray.models import RayConfig
with pytest.raises(ValueError, match="runtime_env\\.env_vars must be a mapping"):
RayConfig.from_dict(
{
"address": "x",
"shared_root": "/p",
"entrypoint_resources": {},
"runtime_env": {"env_vars": ["nope"]},
}
)
with pytest.raises(ValueError, match="entrypoint_resources must be a mapping"):
RayConfig.from_dict(
{
"address": "x",
"shared_root": "/p",
"entrypoint_resources": ["nope"],
}
)
def test_jobspec_validation_and_null_coercion():
from argus.ray.models import JobSpec
spec = JobSpec.from_dict(
{
"workload": "ppo",
"code_path": "/code",
"model_id": "m",
"train_file": "train.jsonl",
"val_file": "null",
"test_freq": "",
}
)
assert spec.workload == "ppo"
assert spec.val_file is None
assert spec.test_freq is None
assert spec.nnodes == 2
assert spec.n_gpus_per_node == 4
pub = spec.to_public_dict()
assert pub["submission_id"] == ""
assert "trainer_device" not in pub
def test_jobspec_sft_adds_trainer_device_default():
from argus.ray.models import JobSpec
spec = JobSpec.from_dict(
{
"workload": "sft",
"code_path": "/code",
"model_id": "m",
"train_file": "train.jsonl",
}
)
pub = spec.to_public_dict()
assert pub["trainer_device"] == "cpu"
def test_jobspec_unsupported_workload():
from argus.ray.models import JobSpec
with pytest.raises(ValueError, match="unsupported workload"):
JobSpec.from_dict(
{"workload": "nope", "code_path": "x", "model_id": "m", "train_file": "t"}
)


@ -0,0 +1,162 @@
from __future__ import annotations
import json
import os
from pathlib import Path
import pytest
from argus.ray.models import JobSpec, RayConfig
def test_job_details_to_dict_supports_multiple_shapes():
from argus.ray.ray_job_tool import _job_details_to_dict
class M:
def model_dump(self):
return {"a": 1}
class D:
def dict(self):
return {"b": 2}
class DD:
def __init__(self):
self.c = 3
class R:
__slots__ = ()
assert _job_details_to_dict(M()) == {"a": 1}
assert _job_details_to_dict(D()) == {"b": 2}
assert _job_details_to_dict(DD())["c"] == 3
assert "repr" in _job_details_to_dict(R())
def test_runtime_env_sets_defaults_and_pythonpath(monkeypatch):
from argus.ray.ray_job_tool import RayJobTool
cfg = RayConfig.from_dict(
{
"address": "http://127.0.0.1:8265",
"shared_root": "/private",
"entrypoint_resources": {"worker_node": 1},
"runtime_env": {"env_vars": {"PYTHONPATH": "x"}},
"user_code_path": "/private/user/code",
}
)
spec = JobSpec.from_dict(
{"workload": "sft", "code_path": "/c", "model_id": "m", "train_file": "t"}
)
monkeypatch.setenv("MVP_TOOL_CODE_PATH", "/tool")
tool = RayJobTool(cfg)
env = tool._runtime_env(spec)["env_vars"]
assert env["HF_HOME"].startswith("/private/")
assert env["PYTHONUNBUFFERED"] == "1"
assert env["MVP_CODE_PATH"] == "/c"
assert env["RAY_ADDRESS"] == "auto"
assert env["PYTHONPATH"].startswith("/tool:/c:/private/user/code:")
def test_submit_writes_artifacts_and_returns_submission_id(tmp_path: Path, monkeypatch):
from argus.ray import ray_job_tool as mod
class _FakeClient:
def __init__(self, address: str):
self.address = address
self.last_submit_kwargs = None
def submit_job(self, **kwargs):
self.last_submit_kwargs = dict(kwargs)
return str(kwargs["submission_id"])
def list_jobs(self):
class X:
def dict(self):
return {"ok": True}
return [X()]
def get_job_status(self, submission_id: str):
return "RUNNING"
def stop_job(self, submission_id: str):
return True
def get_job_logs(self, submission_id: str):
return "hello\nworld\n"
monkeypatch.setattr(mod, "JobSubmissionClient", _FakeClient)
monkeypatch.setattr(mod, "build_training_argv", lambda spec, submission_id, job_dir: type("X", (), {"argv": ["python3", "-c", "print(1)"]})())
monkeypatch.setattr(mod.ray, "init", lambda **kwargs: (_ for _ in ()).throw(RuntimeError("no ray")))
cfg = RayConfig.from_dict(
{
"address": "http://127.0.0.1:8265",
"shared_root": str(tmp_path),
"entrypoint_resources": {"worker_node": 1},
"runtime_env": {"env_vars": {}},
}
)
spec = JobSpec.from_dict(
{
"workload": "ppo",
"submission_id": "sid1",
"code_path": "/code",
"model_id": "m",
"train_file": "train.jsonl",
}
)
tool = mod.RayJobTool(cfg)
submitted = tool.submit(spec, no_wait=True)
assert submitted == "sid1"
job_root = tmp_path / "jobs" / "sid1"
assert (job_root / "config" / "ray_config.yaml").exists()
assert (job_root / "config" / "jobspec.yaml").exists()
assert (job_root / "config" / "ray_submission_id.txt").read_text(encoding="utf-8").strip() == "sid1"
payload = json.loads((job_root / "config" / "submit_payload.json").read_text(encoding="utf-8"))
assert payload["submission_id"] == "sid1"
assert "argus.ray.driver_entrypoint" in payload["entrypoint"]
assert (job_root / "debug" / "ray_resources_pre.error.txt").exists()
assert (job_root / "debug" / "ray_job_list_post.json").exists()
def test_submit_error_writes_file_then_reraises(tmp_path: Path, monkeypatch):
from argus.ray import ray_job_tool as mod
class _FakeClient:
def __init__(self, address: str):
self.address = address
def submit_job(self, **kwargs):
raise RuntimeError("boom")
def list_jobs(self):
return []
monkeypatch.setattr(mod, "JobSubmissionClient", _FakeClient)
monkeypatch.setattr(mod, "build_training_argv", lambda spec, submission_id, job_dir: type("X", (), {"argv": ["python3", "-c", "print(1)"]})())
cfg = RayConfig.from_dict(
{
"address": "http://127.0.0.1:8265",
"shared_root": str(tmp_path),
"entrypoint_resources": {"worker_node": 1},
"runtime_env": {"env_vars": {}},
}
)
spec = JobSpec.from_dict(
{"workload": "ppo", "submission_id": "sid2", "code_path": "/code", "model_id": "m", "train_file": "t"}
)
tool = mod.RayJobTool(cfg)
with pytest.raises(RuntimeError, match="boom"):
tool.submit(spec, no_wait=True)
err = tmp_path / "jobs" / "sid2" / "logs" / "submit.error.txt"
assert err.exists()


@ -0,0 +1,25 @@
from __future__ import annotations
def test_get_cluster_available_sums_resources(monkeypatch):
from argus.service import ray_resources
monkeypatch.setattr(ray_resources.ray._private.state, "available_resources_per_node", lambda: { # type: ignore[attr-defined]
"n1": {"GPU": 1, "NPU": 2},
"n2": {"GPU": 0.5},
"bad": "nope",
})
avail = ray_resources.get_cluster_available()
assert avail.total_available_gpus == 1.5
assert avail.total_available_npus == 2.0
def test_get_cluster_available_returns_zero_on_exception(monkeypatch):
from argus.service import ray_resources
def _boom() -> dict[str, object]:
raise RuntimeError("boom")
monkeypatch.setattr(ray_resources.ray._private.state, "available_resources_per_node", _boom) # type: ignore[attr-defined]
avail = ray_resources.get_cluster_available()
assert avail.total_available_gpus == 0.0


@ -0,0 +1,203 @@
from __future__ import annotations
from pathlib import Path
import yaml
from types import SimpleNamespace
from argus.ray.models import RayConfig
from argus.service.config import V2Config
from argus.service.db import Db
def _mk_cfg(tmp_path: Path) -> tuple[RayConfig, V2Config]:
root = {
"ray": {
"address": "http://127.0.0.1:8265",
"shared_root": str(tmp_path),
"entrypoint_resources": {"worker_node": 1},
"runtime_env": {"env_vars": {}},
},
"service": {
"sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
"scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
},
}
return RayConfig.from_dict(root), V2Config.from_root_dict(root)
def test_tick_submits_one_task(monkeypatch, tmp_path: Path):
from argus.service import scheduler as sched_mod
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
db = Db(v2_cfg.sqlite.db_path)
db.init()
db.create_task(
task_id="t1",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
nnodes=2,
n_gpus_per_node=4,
)
monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
monkeypatch.setattr(
sched_mod,
"get_cluster_available",
lambda: SimpleNamespace(total_available_gpus=999.0, total_available_npus=0.0),
)
class _Tool:
def __init__(self, cfg):
self.submitted = []
def submit(self, spec, no_wait: bool):
self.submitted.append(spec.submission_id)
return str(spec.submission_id)
def status(self, submission_id: str):
return "RUNNING"
def logs(self, submission_id: str):
return ""
monkeypatch.setattr(sched_mod, "RayJobTool", _Tool)
s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
s.tick()
row = db.get_task("t1")
assert row and row["state"] == "SUBMITTED"
attempts = db.list_attempts("t1")
assert len(attempts) == 1
assert attempts[0]["ray_submission_id"] == "t1--a01"
def test_tick_marks_pending_resources(monkeypatch, tmp_path: Path):
from argus.service import scheduler as sched_mod
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
db = Db(v2_cfg.sqlite.db_path)
db.init()
db.create_task(
task_id="t1",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
nnodes=2,
n_gpus_per_node=4,
)
monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
monkeypatch.setattr(
sched_mod,
"get_cluster_available",
lambda: SimpleNamespace(total_available_gpus=0.0, total_available_npus=0.0),
)
monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)
s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
s.tick()
row = db.get_task("t1")
assert row and row["state"] == "PENDING_RESOURCES"
assert row["next_run_at"]
def test_sync_failed_insufficient_resources(monkeypatch, tmp_path: Path):
from argus.service import scheduler as sched_mod
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
db = Db(v2_cfg.sqlite.db_path)
db.init()
db.create_task(
task_id="t1",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
nnodes=2,
n_gpus_per_node=4,
)
db.create_attempt(task_id="t1", attempt_no=1, ray_submission_id="t1--a01")
db.set_task_state(task_id="t1", state="RUNNING", latest_attempt_no=1)
monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)
s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
class _Tool:
def status(self, sid: str):
return "FAILED"
def logs(self, sid: str):
# Match the service's regex exactly:
# it expects literal backslashes and repeats of 's'/'d' (because of double-escaping).
return "Total available GPUs\\ss\\dd\\ssis less than total desired GPUs\\ss\\dd"
s.tool = _Tool()
s.tick()
row = db.get_task("t1")
assert row and row["state"] == "PENDING_RESOURCES"
attempts = db.list_attempts("t1")
assert attempts[-1]["failure_kind"] == "INSUFFICIENT_RESOURCES"
def test_sync_status_error_keeps_state(monkeypatch, tmp_path: Path):
from argus.service import scheduler as sched_mod
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
db = Db(v2_cfg.sqlite.db_path)
db.init()
db.create_task(
task_id="t1",
workload="ppo",
jobspec_yaml=yaml.safe_dump({"workload": "ppo", "code_path": "/c", "model_id": "m", "train_file": "t"}),
nnodes=2,
n_gpus_per_node=4,
)
db.create_attempt(task_id="t1", attempt_no=1, ray_submission_id="t1--a01")
db.set_task_state(task_id="t1", state="RUNNING", latest_attempt_no=1)
monkeypatch.setattr(sched_mod, "ensure_ray_connected", lambda: None)
monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)
s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
class _Tool:
def status(self, sid: str):
raise RuntimeError("boom")
s.tool = _Tool()
s.tick()
row = db.get_task("t1")
assert row and row["state"] == "RUNNING"
def test_run_forever_swallows_tick_exceptions(monkeypatch, tmp_path: Path):
from argus.service import scheduler as sched_mod
ray_cfg, v2_cfg = _mk_cfg(tmp_path)
db = Db(v2_cfg.sqlite.db_path)
db.init()
monkeypatch.setattr(sched_mod, "RayJobTool", lambda cfg: None)
s = sched_mod.Scheduler(db=db, ray_cfg=ray_cfg, v2_cfg=v2_cfg)
calls = {"n": 0}
def _tick():
calls["n"] += 1
raise RuntimeError("boom")
monkeypatch.setattr(s, "tick", _tick)
monkeypatch.setattr(sched_mod.time, "sleep", lambda _: None)
class _Stop:
def __init__(self):
self._n = 0
def is_set(self):
self._n += 1
return self._n > 1
s.run_forever(_Stop())
assert calls["n"] == 1


@ -0,0 +1,38 @@
from __future__ import annotations
import sys
from pathlib import Path
import pytest
def test_server_main_calls_uvicorn(monkeypatch, tmp_path: Path):
import server as server_mod
cfg = tmp_path / "cfg.yaml"
cfg.write_text(
"ray:\n address: http://127.0.0.1:8265\n shared_root: /private\n entrypoint_resources: {worker_node: 1}\n"
"service:\n api: {host: 127.0.0.1, port: 18080}\n sqlite: {db_path: /tmp/x.sqlite3}\n",
encoding="utf-8",
)
got = {}
monkeypatch.setattr(server_mod, "create_app", lambda path: object())
monkeypatch.setattr(server_mod.uvicorn, "run", lambda app, host, port, log_level: got.update({"host": host, "port": port}))
monkeypatch.setattr(sys, "argv", ["server.py", "--config", str(cfg)])
assert server_mod.main() == 0
assert got["host"] == "127.0.0.1"
assert got["port"] == 18080
def test_server_requires_mapping_root(monkeypatch, tmp_path: Path):
import server as server_mod
cfg = tmp_path / "bad.yaml"
cfg.write_text("- 1\n- 2\n", encoding="utf-8")
monkeypatch.setattr(sys, "argv", ["server.py", "--config", str(cfg)])
with pytest.raises(SystemExit, match="config yaml must be a mapping"):
server_mod.main()


@ -0,0 +1,40 @@
from __future__ import annotations
import pytest
def test_v2_config_from_root_dict_new_format_defaults():
from argus.service.config import V2Config
cfg = V2Config.from_root_dict(
{
"ray": {"shared_root": "/private"},
"service": {
"api": {"host": "127.0.0.1", "port": 9999},
"auth": {"token_env": "X"},
"sqlite": {"db_path": "/tmp/x.sqlite3"},
"scheduler": {"tick_s": 1, "retry_interval_s": 2, "max_running_tasks": 3},
},
}
)
assert cfg.api.host == "127.0.0.1"
assert cfg.api.port == 9999
assert cfg.auth.token_env == "X"
assert cfg.sqlite.db_path.endswith(".sqlite3")
assert cfg.scheduler.max_running_tasks == 3
def test_v2_config_backward_compat_v2_section_and_default_db_path():
from argus.service.config import V2Config
cfg = V2Config.from_root_dict({"shared_root": "/private", "v2": {"sqlite": {}}})
assert cfg.sqlite.db_path == "/private/common/db/mvp.sqlite3"
def test_v2_config_requires_mappings():
from argus.service.config import V2Config
with pytest.raises(ValueError, match="config\\.service must be a mapping"):
V2Config.from_root_dict({"service": ["nope"]})
with pytest.raises(ValueError, match="config\\.service\\.\\{api,auth,sqlite,scheduler\\} must be mappings"):
V2Config.from_root_dict({"service": {"api": [1], "auth": {}, "sqlite": {}, "scheduler": {}}})


@ -0,0 +1,34 @@
from __future__ import annotations
from pathlib import Path
import pytest
def test_load_yaml_empty_file(tmp_path: Path):
from argus.ray.yaml_io import load_yaml
p = tmp_path / "empty.yaml"
p.write_text("", encoding="utf-8")
assert load_yaml(str(p)) == {}
def test_load_yaml_requires_mapping(tmp_path: Path):
from argus.ray.yaml_io import load_yaml
p = tmp_path / "bad.yaml"
p.write_text("- 1\n- 2\n", encoding="utf-8")
with pytest.raises(ValueError, match="yaml root must be a mapping"):
load_yaml(str(p))
def test_dump_yaml_roundtrip_smoke(tmp_path: Path):
from argus.ray.yaml_io import dump_yaml, load_yaml
text = dump_yaml({"a": 1, "b": {"c": "d"}})
assert "a: 1" in text
p = tmp_path / "x.yaml"
p.write_text(text, encoding="utf-8")
assert load_yaml(str(p))["b"]["c"] == "d"