argus-cluster/src/mvp/py/tests/test_users.py
2025-12-31 15:16:42 +08:00

359 lines
13 KiB
Python

from __future__ import annotations
from pathlib import Path
import yaml
from fastapi.testclient import TestClient
def _write_config(tmp_path: Path) -> Path:
    """Dump a minimal service config to ``tmp_path / "cfg.yaml"`` and return that path.

    The SQLite database file and the per-user data root are both placed
    under ``tmp_path`` so each test works on its own scratch directory.
    """
    ray_section = {
        "address": "http://127.0.0.1:8265",
        "shared_root": "/private",
        "entrypoint_resources": {"worker_node": 1},
        "runtime_env": {"env_vars": {}},
    }
    # Keep ray.shared_root at /private so the existing path-validation tests
    # stay unchanged, but point the user root into tmp_path so the tests
    # never touch the real /private tree.
    data_section = {"user_root": str(tmp_path / "users")}
    service_section = {
        "api": {"host": "127.0.0.1", "port": 0},
        "auth": {"token_env": "MVP_INTERNAL_TOKEN"},
        "sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
        "scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
    }
    document = yaml.safe_dump(
        {"ray": ray_section, "data": data_section, "service": service_section}
    )
    config_file = tmp_path / "cfg.yaml"
    config_file.write_text(document, encoding="utf-8")
    return config_file
def test_db_user_token_lifecycle(tmp_path: Path):
    """An issued token resolves to its user, and reports DISABLED once the user is disabled."""
    from argus.service.db import Db

    database = Db(str(tmp_path / "mvp.sqlite3"))
    database.init()
    database.create_user(user_id="alice", display_name="Alice")
    token = database.issue_token(user_id="alice")

    # While the user is active the token maps back to them.
    before = database.resolve_token(token)
    assert before
    assert before["user_id"] == "alice"

    # After disabling, the token still resolves but carries the DISABLED state.
    database.disable_user(user_id="alice")
    after = database.resolve_token(token)
    assert after
    assert after["state"] == "DISABLED"
def test_admin_create_user_issue_token_and_disabled_rejected(tmp_path: Path, monkeypatch):
    """Walk the admin user lifecycle over HTTP: create, issue a token, disable.

    The user's token is accepted on ``/api/v2/queue`` while the user is
    active and gets a 403 after the admin disables the user.
    """
    from argus.service import app as app_mod

    class _Scheduler:
        """Stub patched in for ``app_mod.Scheduler``; ``run_forever`` returns immediately."""

        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    config_file = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    app = app_mod.create_app(str(config_file))
    admin_auth = {"authorization": "Bearer adm1"}

    with TestClient(app) as client:
        # Listing users requires admin: an unknown bearer token is turned away.
        denied = client.get("/api/v2/users", headers={"authorization": "Bearer nope"})
        assert denied.status_code in (401, 403)

        created = client.post(
            "/api/v2/users",
            headers=admin_auth,
            json={"user_id": "alice", "display_name": "Alice"},
        )
        assert created.status_code == 200
        assert created.json()["user_id"] == "alice"

        issued = client.post("/api/v2/users/alice/tokens", headers=admin_auth)
        assert issued.status_code == 200
        token = issued.json()["token"]
        assert token

        listing = client.get("/api/v2/users", headers=admin_auth)
        assert listing.status_code == 200
        known_users = listing.json()["users"]
        assert any(entry.get("user_id") == "alice" for entry in known_users)

        # A non-admin token can reach regular endpoints.
        user_auth = {"authorization": f"Bearer {token}"}
        allowed = client.get("/api/v2/queue", headers=user_auth)
        assert allowed.status_code == 200

        disabled = client.post("/api/v2/users/alice:disable", headers=admin_auth)
        assert disabled.status_code == 200

        # The same token is rejected once its user has been disabled.
        rejected = client.get("/api/v2/queue", headers=user_auth)
        assert rejected.status_code == 403
def test_tasks_are_isolated_by_user(tmp_path: Path, monkeypatch):
    """Two users each submit a task; each sees only their own, the admin sees both.

    Cross-user reads, cancels and attempt listings answer 404 so that the
    existence of another user's task is not revealed.
    """
    from argus.service import app as app_mod

    class _Scheduler:
        """Stub patched in for ``app_mod.Scheduler``; ``run_forever`` returns immediately."""

        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    # Deterministic task ids: "<workload>_t1", "<workload>_t2", ...
    issued_ids = 0

    def _new_id(workload: str, **kwargs) -> str:
        nonlocal issued_ids
        issued_ids += 1
        return f"{workload}_t{issued_ids}"

    def _pending_ids(queue) -> set:
        # Collect the task ids listed under the queue's "pending" key.
        return {entry["task_id"] for entry in queue["pending"]}

    config_file = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
    monkeypatch.setattr(app_mod, "new_task_id", _new_id)
    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    app = app_mod.create_app(str(config_file))
    admin_auth = {"authorization": "Bearer adm1"}

    # Both users submit this same task spec.
    task_yaml = (
        "workload: ppo\n"
        "code_path: /private/common/code/verl\n"
        "model_id: m\n"
        "train_file: /private/common/datasets/t\n"
        "val_file: /private/common/datasets/v\n"
    )

    with TestClient(app) as client:
        # Create users and tokens.
        for user_id in ("alice", "bob"):
            created = client.post("/api/v2/users", headers=admin_auth, json={"user_id": user_id})
            assert created.status_code == 200
        alice_token = client.post("/api/v2/users/alice/tokens", headers=admin_auth).json()["token"]
        bob_token = client.post("/api/v2/users/bob/tokens", headers=admin_auth).json()["token"]
        alice_auth = {"authorization": f"Bearer {alice_token}"}
        bob_auth = {"authorization": f"Bearer {bob_token}"}

        # Each user submits one task.
        alice_submit = client.post("/api/v2/tasks", headers=alice_auth, data=task_yaml)
        assert alice_submit.status_code == 200
        alice_task = alice_submit.json()["task_id"]

        bob_submit = client.post("/api/v2/tasks", headers=bob_auth, data=task_yaml)
        assert bob_submit.status_code == 200
        bob_task = bob_submit.json()["task_id"]

        # Queue is scoped to the calling user.
        alice_queue = client.get("/api/v2/queue", headers=alice_auth).json()
        bob_queue = client.get("/api/v2/queue", headers=bob_auth).json()
        assert _pending_ids(alice_queue) == {alice_task}
        assert _pending_ids(bob_queue) == {bob_task}

        # Cross-user access returns 404 (no existence leak).
        peek = client.get(f"/api/v2/tasks/{bob_task}", headers=alice_auth)
        assert peek.status_code == 404
        cancel = client.post(f"/api/v2/tasks/{bob_task}:cancel", headers=alice_auth)
        assert cancel.status_code == 404
        attempts = client.get(f"/api/v2/tasks/{bob_task}/attempts", headers=alice_auth)
        assert attempts.status_code == 404

        # Admin can see the global queue.
        admin_queue = client.get("/api/v2/queue", headers=admin_auth).json()
        assert _pending_ids(admin_queue) == {alice_task, bob_task}
def test_submit_rejects_non_common_inputs(tmp_path: Path, monkeypatch):
    """A task whose ``code_path`` is ``/c`` is refused with a 400 naming the required prefix."""
    from argus.service import app as app_mod

    class _Scheduler:
        """Stub patched in for ``app_mod.Scheduler``; ``run_forever`` returns immediately."""

        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    config_file = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    app = app_mod.create_app(str(config_file))
    admin_auth = {"authorization": "Bearer adm1"}

    # Only code_path is off: it points at /c instead of /private/common/.
    bad_code_path_task = (
        "workload: ppo\n"
        "code_path: /c\n"
        "model_id: m\n"
        "train_file: /private/common/datasets/t\n"
        "val_file: /private/common/datasets/v\n"
    )

    with TestClient(app) as client:
        created = client.post("/api/v2/users", headers=admin_auth, json={"user_id": "alice"})
        assert created.status_code == 200
        alice_token = client.post("/api/v2/users/alice/tokens", headers=admin_auth).json()["token"]
        alice_auth = {"authorization": f"Bearer {alice_token}"}

        response = client.post("/api/v2/tasks", headers=alice_auth, data=bad_code_path_task)
        assert response.status_code == 400
        assert "code_path must start with /private/common/" in response.text
def test_submit_accepts_user_dataset_paths_and_local_model_paths(tmp_path: Path, monkeypatch):
    """Submissions using the caller's own datasets/ or models/ paths are accepted (200)."""
    from argus.service import app as app_mod

    class _Scheduler:
        """Stub patched in for ``app_mod.Scheduler``; ``run_forever`` returns immediately."""

        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    config_file = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    app = app_mod.create_app(str(config_file))
    admin_auth = {"authorization": "Bearer adm1"}

    # Train/val files live under alice's own datasets/ directory.
    own_datasets_task = (
        "workload: ppo\n"
        "code_path: /private/common/code/verl\n"
        "model_id: Qwen/Qwen2.5-0.5B-Instruct\n"
        "train_file: /private/users/alice/datasets/t\n"
        "val_file: /private/users/alice/datasets/v\n"
    )
    # model_id is a local path under alice's own models/ directory.
    own_model_task = (
        "workload: ppo\n"
        "code_path: /private/common/code/verl\n"
        "model_id: /private/users/alice/models/m1\n"
        "train_file: /private/common/datasets/t\n"
        "val_file: /private/common/datasets/v\n"
    )

    with TestClient(app) as client:
        created = client.post("/api/v2/users", headers=admin_auth, json={"user_id": "alice"})
        assert created.status_code == 200
        alice_token = client.post("/api/v2/users/alice/tokens", headers=admin_auth).json()["token"]
        alice_auth = {"authorization": f"Bearer {alice_token}"}

        # User dataset paths are allowed.
        datasets_response = client.post("/api/v2/tasks", headers=alice_auth, data=own_datasets_task)
        assert datasets_response.status_code == 200

        # Local model paths under user models/ are allowed (no TaskSpec schema change).
        model_response = client.post("/api/v2/tasks", headers=alice_auth, data=own_model_task)
        assert model_response.status_code == 200
def test_submit_rejects_cross_user_paths_and_bad_local_model_dirs(tmp_path: Path, monkeypatch):
    """Bob gets a 400 for using alice's datasets or a local model path outside models/."""
    from argus.service import app as app_mod

    class _Scheduler:
        """Stub patched in for ``app_mod.Scheduler``; ``run_forever`` returns immediately."""

        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    config_file = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    app = app_mod.create_app(str(config_file))
    admin_auth = {"authorization": "Bearer adm1"}

    # Train/val files point into alice's tree although bob is the submitter.
    foreign_datasets_task = (
        "workload: ppo\n"
        "code_path: /private/common/code/verl\n"
        "model_id: Qwen/Qwen2.5-0.5B-Instruct\n"
        "train_file: /private/users/alice/datasets/t\n"
        "val_file: /private/users/alice/datasets/v\n"
    )
    # model_id is inside bob's tree, but under jobs/ rather than models/.
    misplaced_model_task = (
        "workload: ppo\n"
        "code_path: /private/common/code/verl\n"
        "model_id: /private/users/bob/jobs/j1/checkpoints\n"
        "train_file: /private/common/datasets/t\n"
        "val_file: /private/common/datasets/v\n"
    )

    with TestClient(app) as client:
        for user_id in ("alice", "bob"):
            created = client.post("/api/v2/users", headers=admin_auth, json={"user_id": user_id})
            assert created.status_code == 200
        # Alice's token is issued to keep the request sequence, but is not used below.
        _alice_token = client.post("/api/v2/users/alice/tokens", headers=admin_auth).json()["token"]
        bob_token = client.post("/api/v2/users/bob/tokens", headers=admin_auth).json()["token"]
        bob_auth = {"authorization": f"Bearer {bob_token}"}

        # Cross-user dataset path should be rejected.
        datasets_response = client.post("/api/v2/tasks", headers=bob_auth, data=foreign_datasets_task)
        assert datasets_response.status_code == 400
        assert "/private/users/bob/datasets/" in datasets_response.text

        # Local model path must be under models/.
        model_response = client.post("/api/v2/tasks", headers=bob_auth, data=misplaced_model_task)
        assert model_response.status_code == 400
        assert "model_id local path must start with" in model_response.text
def test_me_returns_paths_and_retention(tmp_path: Path, monkeypatch):
    """``GET /api/v2/me`` reports the caller's id, per-user paths and retention settings.

    NOTE(review): the config written by ``_write_config`` sets no retention
    values, so the expected 3/7 days are presumably service defaults — confirm.
    """
    from argus.service import app as app_mod

    class _Scheduler:
        """Stub patched in for ``app_mod.Scheduler``; ``run_forever`` returns immediately."""

        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    config_file = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    app = app_mod.create_app(str(config_file))
    admin_auth = {"authorization": "Bearer adm1"}

    # Path key in the response -> suffix the reported path must end with.
    expected_suffixes = (
        ("home", "/users/alice"),
        ("jobs", "/users/alice/jobs"),
        ("trash_jobs", "/users/alice/trash/jobs"),
    )

    with TestClient(app) as client:
        created = client.post("/api/v2/users", headers=admin_auth, json={"user_id": "alice"})
        assert created.status_code == 200
        alice_token = client.post("/api/v2/users/alice/tokens", headers=admin_auth).json()["token"]

        response = client.get("/api/v2/me", headers={"authorization": f"Bearer {alice_token}"})
        assert response.status_code == 200
        profile = response.json()

        assert profile["user_id"] == "alice"
        for key, suffix in expected_suffixes:
            assert profile["paths"][key].endswith(suffix)
        assert profile["retention"]["jobs_trash_after_days"] == 3
        assert profile["retention"]["jobs_purge_after_days"] == 7
def test_create_user_creates_user_dirs(tmp_path: Path, monkeypatch):
    """Creating a user lays out that user's directories under the configured user root."""
    from argus.service import app as app_mod

    class _Scheduler:
        """Stub patched in for ``app_mod.Scheduler``; ``run_forever`` returns immediately."""

        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    config_file = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")
    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    app = app_mod.create_app(str(config_file))
    admin_auth = {"authorization": "Bearer adm1"}

    # Sub-directories, relative to the user's home, that must exist afterwards.
    expected_dirs = (
        ("datasets",),
        ("models",),
        ("code",),
        ("jobs",),
        ("trash", "jobs"),
    )

    with TestClient(app) as client:
        created = client.post("/api/v2/users", headers=admin_auth, json={"user_id": "alice"})
        assert created.status_code == 200

        # _write_config points data.user_root at tmp_path / "users".
        alice_home = tmp_path / "users" / "alice"
        for parts in expected_dirs:
            assert alice_home.joinpath(*parts).is_dir()