359 lines
13 KiB
Python
359 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
from fastapi.testclient import TestClient
|
|
|
|
|
|
def _write_config(tmp_path: Path) -> Path:
    """Write a minimal service config YAML into *tmp_path* and return its path.

    ``ray.shared_root`` is kept at ``/private`` so the path-validation rules the
    tests exercise (``/private/common/...``, ``/private/users/<id>/...``) stay
    the same. ``data.user_root`` and the SQLite database are redirected into
    *tmp_path* so no test touches the real ``/private`` tree.
    """
    ray_section = {
        "address": "http://127.0.0.1:8265",
        "shared_root": "/private",
        "entrypoint_resources": {"worker_node": 1},
        "runtime_env": {"env_vars": {}},
    }
    # Per-user directories are materialised here instead of under /private.
    data_section = {"user_root": str(tmp_path / "users")}
    service_section = {
        "api": {"host": "127.0.0.1", "port": 0},
        # The tests set MVP_INTERNAL_TOKEN and use its value as the admin bearer token.
        "auth": {"token_env": "MVP_INTERNAL_TOKEN"},
        "sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
        "scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
    }
    config = {"ray": ray_section, "data": data_section, "service": service_section}

    config_path = tmp_path / "cfg.yaml"
    config_path.write_text(yaml.safe_dump(config), encoding="utf-8")
    return config_path
|
|
|
|
|
|
def test_db_user_token_lifecycle(tmp_path: Path):
    """A token issued to a user resolves to that user, and reports DISABLED once the user is disabled."""
    from argus.service.db import Db

    database = Db(str(tmp_path / "mvp.sqlite3"))
    database.init()
    database.create_user(user_id="alice", display_name="Alice")
    token = database.issue_token(user_id="alice")

    resolved = database.resolve_token(token)
    assert resolved
    assert resolved["user_id"] == "alice"

    # Disabling the user does not make the token unresolvable; it still
    # resolves, but now carries the DISABLED state.
    database.disable_user(user_id="alice")
    resolved_after_disable = database.resolve_token(token)
    assert resolved_after_disable
    assert resolved_after_disable["state"] == "DISABLED"
|
|
|
|
|
|
def test_admin_create_user_issue_token_and_disabled_rejected(tmp_path: Path, monkeypatch):
    """The admin can create users and mint tokens; a disabled user's token is refused with 403."""
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")

    class _IdleScheduler:
        """Stand-in for app_mod.Scheduler: exposes ``tool`` and a ``run_forever`` that returns immediately."""

        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _IdleScheduler)
    app = app_mod.create_app(str(cfg_path))

    def bearer(token: str) -> dict:
        # Build the authorization header the API expects.
        return {"authorization": f"Bearer {token}"}

    admin_headers = bearer("adm1")

    with TestClient(app) as client:
        # Listing users is admin-only; an unknown token must be turned away.
        denied = client.get("/api/v2/users", headers=bearer("nope"))
        assert denied.status_code in (401, 403)

        created = client.post(
            "/api/v2/users",
            headers=admin_headers,
            json={"user_id": "alice", "display_name": "Alice"},
        )
        assert created.status_code == 200
        assert created.json()["user_id"] == "alice"

        minted = client.post("/api/v2/users/alice/tokens", headers=admin_headers)
        assert minted.status_code == 200
        user_token = minted.json()["token"]
        assert user_token

        listing = client.get("/api/v2/users", headers=admin_headers)
        assert listing.status_code == 200
        assert any(u.get("user_id") == "alice" for u in listing.json()["users"])

        # A regular (non-admin) token is accepted on user-facing endpoints...
        assert client.get("/api/v2/queue", headers=bearer(user_token)).status_code == 200

        disabled = client.post("/api/v2/users/alice:disable", headers=admin_headers)
        assert disabled.status_code == 200

        # ...but once the user is disabled, the same token is forbidden.
        assert client.get("/api/v2/queue", headers=bearer(user_token)).status_code == 403
|
|
|
|
|
|
def test_tasks_are_isolated_by_user(tmp_path: Path, monkeypatch):
    """Users see only their own tasks, other users' tasks look nonexistent (404), and the admin sees all."""
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")

    # Deterministic task ids: <workload>_t1, <workload>_t2, ...
    issued = 0

    def _new_id(workload: str, **kwargs) -> str:
        nonlocal issued
        issued += 1
        return f"{workload}_t{issued}"

    monkeypatch.setattr(app_mod, "new_task_id", _new_id)

    class _IdleScheduler:
        """Stand-in for app_mod.Scheduler: exposes ``tool`` and a ``run_forever`` that returns immediately."""

        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _IdleScheduler)
    app = app_mod.create_app(str(cfg_path))

    admin_headers = {"authorization": "Bearer adm1"}
    users = ("alice", "bob")
    # Both users submit the same PPO spec; every input lives under /private/common/.
    task_yaml = (
        "workload: ppo\n"
        "code_path: /private/common/code/verl\n"
        "model_id: m\n"
        "train_file: /private/common/datasets/t\n"
        "val_file: /private/common/datasets/v\n"
    )

    with TestClient(app) as client:
        # Create all users first, then mint their tokens.
        for user in users:
            resp = client.post("/api/v2/users", headers=admin_headers, json={"user_id": user})
            assert resp.status_code == 200

        user_headers = {}
        for user in users:
            token = client.post(f"/api/v2/users/{user}/tokens", headers=admin_headers).json()["token"]
            user_headers[user] = {"authorization": f"Bearer {token}"}
        alice_headers, bob_headers = user_headers["alice"], user_headers["bob"]

        # Each user submits one task.
        task_ids = {}
        for user in users:
            resp = client.post("/api/v2/tasks", headers=user_headers[user], data=task_yaml)
            assert resp.status_code == 200
            task_ids[user] = resp.json()["task_id"]
        alice_tid, bob_tid = task_ids["alice"], task_ids["bob"]

        def pending_ids(headers: dict) -> set:
            # Ids of the pending tasks visible to the caller identified by `headers`.
            queue = client.get("/api/v2/queue", headers=headers).json()
            return {t["task_id"] for t in queue["pending"]}

        # The queue is scoped to the caller.
        assert pending_ids(alice_headers) == {alice_tid}
        assert pending_ids(bob_headers) == {bob_tid}

        # Cross-user access returns 404 rather than revealing that the task exists.
        assert client.get(f"/api/v2/tasks/{bob_tid}", headers=alice_headers).status_code == 404
        assert client.post(f"/api/v2/tasks/{bob_tid}:cancel", headers=alice_headers).status_code == 404
        assert client.get(f"/api/v2/tasks/{bob_tid}/attempts", headers=alice_headers).status_code == 404

        # The admin token sees the global queue.
        assert pending_ids(admin_headers) == {alice_tid, bob_tid}
|
|
|
|
|
|
def test_submit_rejects_non_common_inputs(tmp_path: Path, monkeypatch):
    """A task whose code_path is not under /private/common/ is rejected with 400 and an explanatory message."""
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")

    class _IdleScheduler:
        """Stand-in for app_mod.Scheduler: exposes ``tool`` and a ``run_forever`` that returns immediately."""

        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _IdleScheduler)
    app = app_mod.create_app(str(cfg_path))

    admin_headers = {"authorization": "Bearer adm1"}
    # Every field is valid except code_path, which points outside /private/common/.
    bad_code_path_spec = (
        "workload: ppo\n"
        "code_path: /c\n"
        "model_id: m\n"
        "train_file: /private/common/datasets/t\n"
        "val_file: /private/common/datasets/v\n"
    )

    with TestClient(app) as client:
        created = client.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice"})
        assert created.status_code == 200
        token = client.post("/api/v2/users/alice/tokens", headers=admin_headers).json()["token"]

        resp = client.post(
            "/api/v2/tasks",
            headers={"authorization": f"Bearer {token}"},
            data=bad_code_path_spec,
        )
        assert resp.status_code == 400
        assert "code_path must start with /private/common/" in resp.text
|
|
|
|
|
|
def test_submit_accepts_user_dataset_paths_and_local_model_paths(tmp_path: Path, monkeypatch):
    """A user may reference datasets under their own datasets/ and a local model under their own models/."""
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")

    class _IdleScheduler:
        """Stand-in for app_mod.Scheduler: exposes ``tool`` and a ``run_forever`` that returns immediately."""

        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _IdleScheduler)
    app = app_mod.create_app(str(cfg_path))

    admin_headers = {"authorization": "Bearer adm1"}

    # Train/val files under the user's own datasets/ directory.
    own_datasets_spec = (
        "workload: ppo\n"
        "code_path: /private/common/code/verl\n"
        "model_id: Qwen/Qwen2.5-0.5B-Instruct\n"
        "train_file: /private/users/alice/datasets/t\n"
        "val_file: /private/users/alice/datasets/v\n"
    )
    # A local model path under the user's models/ directory, passed through the
    # existing model_id field (no TaskSpec schema change).
    own_model_spec = (
        "workload: ppo\n"
        "code_path: /private/common/code/verl\n"
        "model_id: /private/users/alice/models/m1\n"
        "train_file: /private/common/datasets/t\n"
        "val_file: /private/common/datasets/v\n"
    )

    with TestClient(app) as client:
        created = client.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice"})
        assert created.status_code == 200
        token = client.post("/api/v2/users/alice/tokens", headers=admin_headers).json()["token"]
        alice_headers = {"authorization": f"Bearer {token}"}

        for spec in (own_datasets_spec, own_model_spec):
            resp = client.post("/api/v2/tasks", headers=alice_headers, data=spec)
            assert resp.status_code == 200
|
|
|
|
|
|
def test_submit_rejects_cross_user_paths_and_bad_local_model_dirs(tmp_path: Path, monkeypatch):
    """Bob may not reference Alice's datasets, and a local model_id must live under his own models/ dir.

    Both rejections must come back as 400 with a message naming the allowed prefix.
    """
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")

    class _Scheduler:
        """Stand-in for app_mod.Scheduler: exposes ``tool`` and a ``run_forever`` that returns immediately."""

        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    app = app_mod.create_app(str(cfg_path))

    admin_headers = {"authorization": "Bearer adm1"}
    with TestClient(app) as c:
        # Both users exist. Only bob needs a token, because every submission
        # below is made as bob. (The original also minted an unused token for alice.)
        assert c.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice"}).status_code == 200
        assert c.post("/api/v2/users", headers=admin_headers, json={"user_id": "bob"}).status_code == 200
        bob_tok_resp = c.post("/api/v2/users/bob/tokens", headers=admin_headers)
        # Fail with a clear status mismatch instead of a KeyError on "token".
        assert bob_tok_resp.status_code == 200
        bob_headers = {"authorization": f"Bearer {bob_tok_resp.json()['token']}"}

        # Cross-user dataset path should be rejected.
        r1 = c.post(
            "/api/v2/tasks",
            headers=bob_headers,
            data=(
                "workload: ppo\n"
                "code_path: /private/common/code/verl\n"
                "model_id: Qwen/Qwen2.5-0.5B-Instruct\n"
                "train_file: /private/users/alice/datasets/t\n"
                "val_file: /private/users/alice/datasets/v\n"
            ),
        )
        assert r1.status_code == 400
        # The error names bob's own datasets prefix as the allowed location.
        assert "/private/users/bob/datasets/" in r1.text

        # Local model path must be under models/ (a job checkpoint dir is not accepted).
        r2 = c.post(
            "/api/v2/tasks",
            headers=bob_headers,
            data=(
                "workload: ppo\n"
                "code_path: /private/common/code/verl\n"
                "model_id: /private/users/bob/jobs/j1/checkpoints\n"
                "train_file: /private/common/datasets/t\n"
                "val_file: /private/common/datasets/v\n"
            ),
        )
        assert r2.status_code == 400
        assert "model_id local path must start with" in r2.text
|
|
|
|
|
|
def test_me_returns_paths_and_retention(tmp_path: Path, monkeypatch):
    """/api/v2/me reports the caller's home/jobs/trash paths and the job retention windows.

    The config written by ``_write_config`` sets no retention options, and the
    endpoint is expected to report 3 days to trash and 7 days to purge.
    """
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")

    class _IdleScheduler:
        """Stand-in for app_mod.Scheduler: exposes ``tool`` and a ``run_forever`` that returns immediately."""

        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _IdleScheduler)
    app = app_mod.create_app(str(cfg_path))

    admin_headers = {"authorization": "Bearer adm1"}
    with TestClient(app) as client:
        created = client.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice"})
        assert created.status_code == 200
        token = client.post("/api/v2/users/alice/tokens", headers=admin_headers).json()["token"]

        resp = client.get("/api/v2/me", headers={"authorization": f"Bearer {token}"})
        assert resp.status_code == 200
        me = resp.json()
        assert me["user_id"] == "alice"

        paths = me["paths"]
        assert paths["home"].endswith("/users/alice")
        assert paths["jobs"].endswith("/users/alice/jobs")
        assert paths["trash_jobs"].endswith("/users/alice/trash/jobs")

        retention = me["retention"]
        assert retention["jobs_trash_after_days"] == 3
        assert retention["jobs_purge_after_days"] == 7
|
|
|
|
|
|
def test_create_user_creates_user_dirs(tmp_path: Path, monkeypatch):
    """Creating a user materialises their standard directory tree under data.user_root."""
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "adm1")

    class _IdleScheduler:
        """Stand-in for app_mod.Scheduler: exposes ``tool`` and a ``run_forever`` that returns immediately."""

        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _IdleScheduler)
    app = app_mod.create_app(str(cfg_path))

    admin_headers = {"authorization": "Bearer adm1"}
    # Sub-directories (relative to the user's home) that must exist after creation.
    expected_subdirs = (
        ("datasets",),
        ("models",),
        ("code",),
        ("jobs",),
        ("trash", "jobs"),
    )

    with TestClient(app) as client:
        created = client.post("/api/v2/users", headers=admin_headers, json={"user_id": "alice"})
        assert created.status_code == 200

        # _write_config points data.user_root at tmp_path / "users".
        home = tmp_path / "users" / "alice"
        for parts in expected_subdirs:
            assert home.joinpath(*parts).is_dir()
|