from __future__ import annotations

import os

from pathlib import Path

import pytest
import yaml
from fastapi.testclient import TestClient
def _write_config(tmp_path: Path) -> Path:
    """Write a minimal service config YAML under *tmp_path* and return its path."""
    config = {
        "ray": {
            "address": "http://127.0.0.1:8265",
            "shared_root": "/private",
            "entrypoint_resources": {"worker_node": 1},
            "runtime_env": {"env_vars": {}},
        },
        "data": {
            # Avoid touching real /private in tests. Keep ray.shared_root as /private
            # so existing path validation tests remain unchanged.
            "user_root": str(tmp_path / "users"),
        },
        "service": {
            "api": {"host": "127.0.0.1", "port": 0},
            "auth": {"token_env": "MVP_INTERNAL_TOKEN"},
            "sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
            "scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
        },
    }
    cfg_path = tmp_path / "cfg.yaml"
    cfg_path.write_text(yaml.safe_dump(config), encoding="utf-8")
    return cfg_path


def test_auth_requires_token_env(tmp_path: Path, monkeypatch):
    """When the auth token env var is absent, authenticated routes return 500."""
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    # Ensure the configured token env var is definitely unset.
    monkeypatch.delenv("MVP_INTERNAL_TOKEN", raising=False)

    class _Scheduler:
        # Stand-in scheduler so create_app does not start real background work.
        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    app = app_mod.create_app(str(cfg_path))

    with TestClient(app) as client:
        resp = client.get("/api/v2/queue")
        assert resp.status_code == 500


def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
    """End-to-end task lifecycle: submit, inspect spec, list/filter, cancel, logs.

    Fix: the original imported ``RayConfig`` and bound ``ray_cfg`` twice without
    ever using them, and re-parsed the config / re-opened the Db a second time
    even though the first ``db`` handle (same sqlite path) was still usable.
    Both redundancies are removed; behavior of the test is unchanged.
    """
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
    # Pin the generated task id so assertions can reference it directly.
    monkeypatch.setattr(app_mod, "new_task_id", lambda workload, **kwargs: "tid1")

    class _Tool:
        # Fake Ray tool: records stop() calls and serves canned log output.
        def __init__(self):
            self.stopped = []

        def stop(self, sid: str):
            self.stopped.append(sid)
            return True

        def logs(self, sid: str):
            return "a\nb\nc\n"

    class _Scheduler:
        # Stand-in scheduler so create_app does not start real background work.
        def __init__(self, **kwargs):
            self.tool = _Tool()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    app = app_mod.create_app(str(cfg_path))

    headers = {"authorization": "Bearer token1"}
    with TestClient(app) as c:
        r = c.post(
            "/api/v2/tasks",
            headers=headers,
            data="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\nval_file: /private/common/datasets/v\n",
        )
        assert r.status_code == 200
        assert r.json()["task_id"] == "tid1"

        r2 = c.get("/api/v2/tasks/tid1", headers=headers)
        assert r2.status_code == 200
        assert r2.json()["desired_resources"]["total_gpus"] == 8

        r2s = c.get("/api/v2/tasks/tid1/spec", headers=headers)
        assert r2s.status_code == 200
        assert "workload: ppo" in r2s.text
        assert "code_path:" in r2s.text
        # Spec endpoint returns a resolved TaskSpec, i.e. includes default values.
        assert "nnodes: 2" in r2s.text
        assert "n_gpus_per_node: 4" in r2s.text

        # Seed an attempt and ensure submission_id is reflected in the resolved spec.
        from argus.service.config import V2Config
        from argus.service.db import Db

        root = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
        v2_cfg = V2Config.from_root_dict(root)
        db = Db(v2_cfg.sqlite.db_path)
        db.create_attempt(task_id="tid1", attempt_no=1, ray_submission_id="sid1")

        r2s2 = c.get("/api/v2/tasks/tid1/spec", headers=headers)
        assert r2s2.status_code == 200
        assert "submission_id: sid1" in r2s2.text

        r3 = c.get("/api/v2/queue", headers=headers)
        assert r3.status_code == 200
        assert "pending" in r3.json()

        r3b = c.get("/api/v2/tasks?limit=10", headers=headers)
        assert r3b.status_code == 200
        assert any(t.get("task_id") == "tid1" for t in r3b.json().get("tasks", []))

        r3c = c.get("/api/v2/tasks?limit=10&offset=0&states=QUEUED", headers=headers)
        assert r3c.status_code == 200
        assert all(t.get("state") == "QUEUED" for t in r3c.json().get("tasks", []))

        # Unknown state filter is rejected.
        r3d = c.get("/api/v2/tasks?states=NOPE", headers=headers)
        assert r3d.status_code == 400

        r4 = c.post("/api/v2/tasks/tid1:cancel", headers=headers)
        assert r4.status_code == 200
        assert r4.json()["state"] == "CANCELED"

        # Seed a second, RUNNING task with an attempt, then fetch listing + logs.
        # Reuse the existing `db` handle — it points at the same sqlite file.
        db.create_task(
            task_id="tid2",
            workload="ppo",
            jobspec_yaml="workload: ppo\ncode_path: /private/common/code/verl\nmodel_id: m\ntrain_file: /private/common/datasets/t\nval_file: /private/common/datasets/v\n",
            nnodes=2,
            n_gpus_per_node=4,
        )
        db.create_attempt(task_id="tid2", attempt_no=1, ray_submission_id="sid2")
        db.set_task_state(task_id="tid2", state="RUNNING", latest_attempt_no=1)

        r6 = c.get("/api/v2/tasks?limit=1&offset=0&states=RUNNING", headers=headers)
        assert r6.status_code == 200
        assert any(t.get("task_id") == "tid2" for t in r6.json().get("tasks", []))

        r7 = c.get("/api/v2/tasks?limit=1&offset=1&states=RUNNING", headers=headers)
        assert r7.status_code == 200
        assert "has_more" in r7.json()

        # tail=1 should return only the last log line from the fake tool.
        r5 = c.get("/api/v2/tasks/tid2/logs?tail=1", headers=headers)
        assert r5.status_code == 200
        assert r5.text.strip() == "c"


def test_submit_rejects_non_mapping_jobspec(tmp_path: Path, monkeypatch):
    """A YAML body that parses to a list (not a mapping) is rejected with 400."""
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")

    class _Scheduler:
        # Stand-in scheduler so create_app does not start real background work.
        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    app = app_mod.create_app(str(cfg_path))

    with TestClient(app) as client:
        resp = client.post("/api/v2/tasks", headers={"authorization": "Bearer token1"}, data="- 1\n- 2\n")
        assert resp.status_code == 400


def test_submit_rejects_invalid_jobspec(tmp_path: Path, monkeypatch):
    """A mapping jobspec with an unknown workload is rejected with 400."""
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")

    class _Scheduler:
        # Stand-in scheduler so create_app does not start real background work.
        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    app = app_mod.create_app(str(cfg_path))

    with TestClient(app) as client:
        resp = client.post("/api/v2/tasks", headers={"authorization": "Bearer token1"}, data="workload: nope\n")
        assert resp.status_code == 400


def test_me_sftp_reset_password_disabled_returns_400(tmp_path: Path, monkeypatch):
    """With no sftpgo section in the config, the reset endpoint returns 400."""
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")

    class _Scheduler:
        # Stand-in scheduler so create_app does not start real background work.
        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    app = app_mod.create_app(str(cfg_path))

    # seed user + token
    from argus.service.config import V2Config
    from argus.service.db import Db

    root = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
    v2_cfg = V2Config.from_root_dict(root)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()
    db.create_user(user_id="u1", display_name=None)
    token = db.issue_token(user_id="u1")

    with TestClient(app) as client:
        resp = client.post("/api/v2/me/sftp:reset_password", headers={"authorization": f"Bearer {token}"})
        assert resp.status_code == 400


def test_me_sftp_reset_password_enabled_returns_password(tmp_path: Path, monkeypatch):
    """With sftpgo enabled, resetting the SFTP password returns a fresh password."""
    from argus.service import app as app_mod

    # Start from the base config and layer in an enabled sftpgo section.
    base_cfg = yaml.safe_load(_write_config(tmp_path).read_text(encoding="utf-8"))
    base_cfg["data"]["sftpgo"] = {
        "enabled": True,
        "host": "127.0.0.1",
        "sftp_port": 2022,
        "admin_api_base": "http://127.0.0.1:8081",
        "admin_user": "admin",
        "admin_password_env": "SFTPGO_ADMIN_PASSWORD",
    }
    cfg_path = tmp_path / "cfg_sftp.yaml"
    cfg_path.write_text(yaml.safe_dump(base_cfg), encoding="utf-8")

    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
    monkeypatch.setenv("SFTPGO_ADMIN_PASSWORD", "pw1")

    class _FakeSFTPGo:
        # Records reset/enable calls instead of talking to a real SFTPGo admin API.
        def __init__(self, **kwargs):
            self.reset = []
            self.enabled = []

        def reset_password(self, username: str, new_password: str, home_dir: str):
            assert username
            assert new_password
            assert home_dir
            self.reset.append((username, home_dir))

        def enable_user(self, username: str, home_dir: str):
            self.enabled.append((username, home_dir))

    stub_client = _FakeSFTPGo()

    class _FakeSFTPGoFactory:
        # Mimics the SFTPGoAdminClient constructor; always hands back the stub.
        def __call__(self, **kwargs):
            return stub_client

    class _Scheduler:
        # Stand-in scheduler so create_app does not start real background work.
        def __init__(self, **kwargs):
            self.tool = object()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)
    monkeypatch.setattr(app_mod, "SFTPGoAdminClient", _FakeSFTPGoFactory())
    app = app_mod.create_app(str(cfg_path))

    # seed user in DB
    from argus.service.config import V2Config
    from argus.service.db import Db

    v2_cfg = V2Config.from_root_dict(base_cfg)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()
    db.create_user(user_id="u1", display_name=None)
    token = db.issue_token(user_id="u1")

    with TestClient(app) as client:
        resp = client.post("/api/v2/me/sftp:reset_password", headers={"authorization": f"Bearer {token}"})
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["user_id"] == "u1"
        assert isinstance(payload["password"], str) and len(payload["password"]) >= 8
|