argus-cluster/src/mvp/py/tests/test_app.py
2025-12-31 15:16:42 +08:00

306 lines
10 KiB
Python

from __future__ import annotations
import os
from pathlib import Path
import pytest
import yaml
from fastapi.testclient import TestClient
def _write_config(tmp_path: Path) -> Path:
    """Write a minimal service config YAML into ``tmp_path`` and return its path.

    The SQLite database file and the per-user data root both point inside
    ``tmp_path``, so each test gets its own isolated database and user dir.
    """
    ray_section = {
        "address": "http://127.0.0.1:8265",
        "shared_root": "/private",
        "entrypoint_resources": {"worker_node": 1},
        "runtime_env": {"env_vars": {}},
    }
    # Avoid touching real /private in tests. Keep ray.shared_root as /private
    # so existing path validation tests remain unchanged.
    data_section = {"user_root": str(tmp_path / "users")}
    service_section = {
        "api": {"host": "127.0.0.1", "port": 0},
        "auth": {"token_env": "MVP_INTERNAL_TOKEN"},
        "sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
        "scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
    }

    config_file = tmp_path / "cfg.yaml"
    document = {"ray": ray_section, "data": data_section, "service": service_section}
    config_file.write_text(yaml.safe_dump(document), encoding="utf-8")
    return config_file
def test_auth_requires_token_env(tmp_path: Path, monkeypatch):
    """With MVP_INTERNAL_TOKEN unset, GET /api/v2/queue responds with HTTP 500."""
    from argus.service import app as app_mod

    config_file = _write_config(tmp_path)
    monkeypatch.delenv("MVP_INTERNAL_TOKEN", raising=False)

    class _Scheduler:
        # Scheduler stub: placeholder tool, and run_forever returns immediately.
        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    service_app = app_mod.create_app(str(config_file))

    with TestClient(service_app) as client:
        resp = client.get("/api/v2/queue")
        assert resp.status_code == 500
def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
    """Exercise the task API end to end against a stubbed scheduler.

    Submits a PPO task and reads it back: the task, its resolved spec before
    and after an attempt exists, the queue, and the task list with state
    filters. It then cancels the task. Finally it seeds a RUNNING task with
    one attempt directly in SQLite, checks list pagination, and tails its
    logs through the fake tool.
    """
    from argus.service import app as app_mod
    from argus.service.config import V2Config
    from argus.service.db import Db

    # Minimal PPO jobspec, used both for the HTTP submission and the seeded DB row.
    ppo_jobspec_yaml = (
        "workload: ppo\n"
        "code_path: /private/common/code/verl\n"
        "model_id: m\n"
        "train_file: /private/common/datasets/t\n"
        "val_file: /private/common/datasets/v\n"
    )

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
    # Pin the generated task id so the assertions below can address it.
    monkeypatch.setattr(app_mod, "new_task_id", lambda workload, **kwargs: "tid1")

    class _Tool:
        # Fake job tool: records stop() calls and serves fixed log text.
        def __init__(self):
            self.stopped = []

        def stop(self, sid: str):
            self.stopped.append(sid)
            return True

        def logs(self, sid: str):
            return "a\nb\nc\n"

    class _Scheduler:
        # Scheduler stub: exposes the fake tool; run_forever returns immediately.
        def __init__(self, **kwargs):
            self.tool = _Tool()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    app = app_mod.create_app(str(cfg_path))
    headers = {"authorization": "Bearer token1"}

    with TestClient(app) as c:
        r = c.post("/api/v2/tasks", headers=headers, data=ppo_jobspec_yaml)
        assert r.status_code == 200
        assert r.json()["task_id"] == "tid1"

        r2 = c.get("/api/v2/tasks/tid1", headers=headers)
        assert r2.status_code == 200
        assert r2.json()["desired_resources"]["total_gpus"] == 8

        r2s = c.get("/api/v2/tasks/tid1/spec", headers=headers)
        assert r2s.status_code == 200
        assert "workload: ppo" in r2s.text
        assert "code_path:" in r2s.text
        # Spec endpoint returns a resolved TaskSpec, i.e. includes default values.
        assert "nnodes: 2" in r2s.text
        assert "n_gpus_per_node: 4" in r2s.text

        # Open the app's SQLite DB once; this handle is reused for all seeding below.
        root = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
        db = Db(V2Config.from_root_dict(root).sqlite.db_path)

        # Seed an attempt and ensure submission_id is reflected in the resolved spec.
        db.create_attempt(task_id="tid1", attempt_no=1, ray_submission_id="sid1")
        r2s2 = c.get("/api/v2/tasks/tid1/spec", headers=headers)
        assert r2s2.status_code == 200
        assert "submission_id: sid1" in r2s2.text

        r3 = c.get("/api/v2/queue", headers=headers)
        assert r3.status_code == 200
        assert "pending" in r3.json()

        r3b = c.get("/api/v2/tasks?limit=10", headers=headers)
        assert r3b.status_code == 200
        assert any(t.get("task_id") == "tid1" for t in r3b.json().get("tasks", []))

        r3c = c.get("/api/v2/tasks?limit=10&offset=0&states=QUEUED", headers=headers)
        assert r3c.status_code == 200
        assert all(t.get("state") == "QUEUED" for t in r3c.json().get("tasks", []))

        # Unknown state names in the filter are a client error.
        r3d = c.get("/api/v2/tasks?states=NOPE", headers=headers)
        assert r3d.status_code == 400

        r4 = c.post("/api/v2/tasks/tid1:cancel", headers=headers)
        assert r4.status_code == 200
        assert r4.json()["state"] == "CANCELED"

        # Seed a RUNNING task (tid2) with one attempt, then exercise state
        # filtering, pagination, and log tailing against it.
        db.create_task(
            task_id="tid2",
            workload="ppo",
            jobspec_yaml=ppo_jobspec_yaml,
            nnodes=2,
            n_gpus_per_node=4,
        )
        db.create_attempt(task_id="tid2", attempt_no=1, ray_submission_id="sid2")
        db.set_task_state(task_id="tid2", state="RUNNING", latest_attempt_no=1)

        r6 = c.get("/api/v2/tasks?limit=1&offset=0&states=RUNNING", headers=headers)
        assert r6.status_code == 200
        assert any(t.get("task_id") == "tid2" for t in r6.json().get("tasks", []))

        r7 = c.get("/api/v2/tasks?limit=1&offset=1&states=RUNNING", headers=headers)
        assert r7.status_code == 200
        assert "has_more" in r7.json()

        r5 = c.get("/api/v2/tasks/tid2/logs?tail=1", headers=headers)
        assert r5.status_code == 200
        # The fake tool returns "a\nb\nc\n"; tail=1 should leave only the last line.
        assert r5.text.strip() == "c"
def test_submit_rejects_non_mapping_jobspec(tmp_path: Path, monkeypatch):
    """A jobspec whose YAML top level is a list, not a mapping, gets HTTP 400."""
    from argus.service import app as app_mod

    config_file = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")

    class _Scheduler:
        # Scheduler stub: placeholder tool, and run_forever returns immediately.
        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    service_app = app_mod.create_app(str(config_file))

    auth_headers = {"authorization": "Bearer token1"}
    list_payload = "- 1\n- 2\n"
    with TestClient(service_app) as client:
        resp = client.post("/api/v2/tasks", headers=auth_headers, data=list_payload)
        assert resp.status_code == 400
def test_submit_rejects_invalid_jobspec(tmp_path: Path, monkeypatch):
    """A jobspec naming an unknown workload (with no other fields) gets HTTP 400."""
    from argus.service import app as app_mod

    config_file = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")

    class _Scheduler:
        # Scheduler stub: placeholder tool, and run_forever returns immediately.
        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    service_app = app_mod.create_app(str(config_file))

    auth_headers = {"authorization": "Bearer token1"}
    bad_payload = "workload: nope\n"
    with TestClient(service_app) as client:
        resp = client.post("/api/v2/tasks", headers=auth_headers, data=bad_payload)
        assert resp.status_code == 400
def test_me_sftp_reset_password_disabled_returns_400(tmp_path: Path, monkeypatch):
    """Without an sftpgo config section, a user's SFTP password reset gets HTTP 400."""
    from argus.service import app as app_mod

    config_file = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")

    class _Scheduler:
        # Scheduler stub: placeholder tool, and run_forever returns immediately.
        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    service_app = app_mod.create_app(str(config_file))

    # Seed a user plus a bearer token directly in the app's SQLite database.
    from argus.service.config import V2Config
    from argus.service.db import Db

    parsed = yaml.safe_load(config_file.read_text(encoding="utf-8"))
    settings = V2Config.from_root_dict(parsed)
    store = Db(settings.sqlite.db_path)
    store.init()
    store.create_user(user_id="u1", display_name=None)
    user_token = store.issue_token(user_id="u1")

    auth_headers = {"authorization": f"Bearer {user_token}"}
    with TestClient(service_app) as client:
        resp = client.post("/api/v2/me/sftp:reset_password", headers=auth_headers)
        assert resp.status_code == 400
def test_me_sftp_reset_password_enabled_returns_password(tmp_path: Path, monkeypatch):
    """With sftpgo enabled and its admin client faked, reset returns a new password.

    The response must carry the caller's user_id and a password string of at
    least 8 characters.
    """
    from argus.service import app as app_mod

    # Start from the base config and add an enabled sftpgo section.
    cfg = yaml.safe_load(_write_config(tmp_path).read_text(encoding="utf-8"))
    sftpgo_section = {
        "enabled": True,
        "host": "127.0.0.1",
        "sftp_port": 2022,
        "admin_api_base": "http://127.0.0.1:8081",
        "admin_user": "admin",
        "admin_password_env": "SFTPGO_ADMIN_PASSWORD",
    }
    cfg["data"]["sftpgo"] = sftpgo_section
    cfg_path = tmp_path / "cfg_sftp.yaml"
    cfg_path.write_text(yaml.safe_dump(cfg), encoding="utf-8")

    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
    monkeypatch.setenv("SFTPGO_ADMIN_PASSWORD", "pw1")

    class _FakeSFTPGo:
        # Records reset/enable calls instead of talking to an SFTPGo admin API.
        def __init__(self, **kwargs):
            self.reset = []
            self.enabled = []

        def reset_password(self, username: str, new_password: str, home_dir: str):
            assert username and new_password and home_dir
            self.reset.append((username, home_dir))

        def enable_user(self, username: str, home_dir: str):
            self.enabled.append((username, home_dir))

    recorder = _FakeSFTPGo()

    class _FakeSFTPGoFactory:
        # Stands in for the SFTPGoAdminClient constructor; always yields the recorder.
        def __call__(self, **kwargs):
            return recorder

    class _Scheduler:
        # Scheduler stub: placeholder tool, and run_forever returns immediately.
        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    monkeypatch.setattr(app_mod, "SFTPGoAdminClient", _FakeSFTPGoFactory())
    service_app = app_mod.create_app(str(cfg_path))

    # Seed a user plus a bearer token directly in the app's SQLite database.
    from argus.service.db import Db
    from argus.service.config import V2Config

    settings = V2Config.from_root_dict(cfg)
    store = Db(settings.sqlite.db_path)
    store.init()
    store.create_user(user_id="u1", display_name=None)
    user_token = store.issue_token(user_id="u1")

    auth_headers = {"authorization": f"Bearer {user_token}"}
    with TestClient(service_app) as client:
        resp = client.post("/api/v2/me/sftp:reset_password", headers=auth_headers)
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["user_id"] == "u1"
        new_password = payload["password"]
        assert isinstance(new_password, str)
        assert len(new_password) >= 8