from __future__ import annotations

import os
from pathlib import Path

import pytest
import yaml
from fastapi.testclient import TestClient

# A minimal PPO jobspec (YAML text) whose code/data paths sit under /private.
# Shared by the submit request and the directly-seeded task below.
_PPO_JOBSPEC_YAML = (
    "workload: ppo\n"
    "code_path: /private/common/code/verl\n"
    "model_id: m\n"
    "train_file: /private/common/datasets/t\n"
    "val_file: /private/common/datasets/v\n"
)


def _write_config(tmp_path: Path) -> Path:
    """Write a test service config YAML into ``tmp_path`` and return its path.

    The SQLite DB and the user data root are placed under ``tmp_path``.
    """
    cfg = {
        "ray": {
            "address": "http://127.0.0.1:8265",
            "shared_root": "/private",
            "entrypoint_resources": {"worker_node": 1},
            "runtime_env": {"env_vars": {}},
        },
        "data": {
            # Avoid touching real /private in tests. Keep ray.shared_root as /private
            # so existing path validation tests remain unchanged.
            "user_root": str(tmp_path / "users"),
        },
        "service": {
            "api": {"host": "127.0.0.1", "port": 0},
            "auth": {"token_env": "MVP_INTERNAL_TOKEN"},
            "sqlite": {"db_path": str(tmp_path / "mvp.sqlite3")},
            "scheduler": {"tick_s": 1, "retry_interval_s": 1, "max_running_tasks": 1},
        },
    }
    p = tmp_path / "cfg.yaml"
    p.write_text(yaml.safe_dump(cfg), encoding="utf-8")
    return p


def _install_fake_scheduler(monkeypatch: pytest.MonkeyPatch, app_mod, tool_factory=object) -> None:
    """Patch ``app_mod.Scheduler`` with an inert stand-in.

    The stand-in accepts any keyword arguments and sets ``self.tool`` to the
    result of calling ``tool_factory()``. That call happens once per
    instance, as the old per-test stubs did. Its ``run_forever`` returns
    immediately.

    Args:
        monkeypatch: pytest's monkeypatch fixture; the patch is undone at teardown.
        app_mod: the ``argus.service.app`` module under test.
        tool_factory: zero-arg callable producing the ``tool`` attribute.
    """

    class _Scheduler:
        def __init__(self, **kwargs):
            self.tool = tool_factory()

        def run_forever(self, stop_flag):
            return None

    monkeypatch.setattr(app_mod, "Scheduler", _Scheduler)


def test_auth_requires_token_env(tmp_path: Path, monkeypatch):
    """With MVP_INTERNAL_TOKEN unset, an API request returns HTTP 500."""
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    monkeypatch.delenv("MVP_INTERNAL_TOKEN", raising=False)
    _install_fake_scheduler(monkeypatch, app_mod)

    app = app_mod.create_app(str(cfg_path))
    with TestClient(app) as c:
        r = c.get("/api/v2/queue")
        assert r.status_code == 500


def test_task_submit_get_cancel_logs_queue(tmp_path: Path, monkeypatch):
    """Exercise the task API: submit, get, spec, queue, list/filter, cancel, logs."""
    from argus.service import app as app_mod
    from argus.service.config import V2Config
    from argus.service.db import Db

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
    # Deterministic task id so later requests can address the submitted task.
    monkeypatch.setattr(app_mod, "new_task_id", lambda workload, **kwargs: "tid1")

    class _Tool:
        """Fake scheduler tool: records stop() calls and serves fixed log text."""

        def __init__(self):
            self.stopped = []

        def stop(self, sid: str):
            self.stopped.append(sid)
            return True

        def logs(self, sid: str):
            return "a\nb\nc\n"

    _install_fake_scheduler(monkeypatch, app_mod, tool_factory=_Tool)

    app = app_mod.create_app(str(cfg_path))
    headers = {"authorization": "Bearer token1"}
    with TestClient(app) as c:
        r = c.post("/api/v2/tasks", headers=headers, data=_PPO_JOBSPEC_YAML)
        assert r.status_code == 200
        assert r.json()["task_id"] == "tid1"

        r2 = c.get("/api/v2/tasks/tid1", headers=headers)
        assert r2.status_code == 200
        # Matches the resolved defaults asserted below: 2 nodes x 4 GPUs.
        assert r2.json()["desired_resources"]["total_gpus"] == 8

        r2s = c.get("/api/v2/tasks/tid1/spec", headers=headers)
        assert r2s.status_code == 200
        assert "workload: ppo" in r2s.text
        assert "code_path:" in r2s.text
        # Spec endpoint returns a resolved TaskSpec, i.e. includes default values.
        assert "nnodes: 2" in r2s.text
        assert "n_gpus_per_node: 4" in r2s.text

        # Seed an attempt and ensure submission_id is reflected in the resolved spec.
        # This DB handle is reused for the seeding further down.
        root = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
        v2_cfg = V2Config.from_root_dict(root)
        db = Db(v2_cfg.sqlite.db_path)
        db.create_attempt(task_id="tid1", attempt_no=1, ray_submission_id="sid1")
        r2s2 = c.get("/api/v2/tasks/tid1/spec", headers=headers)
        assert r2s2.status_code == 200
        assert "submission_id: sid1" in r2s2.text

        r3 = c.get("/api/v2/queue", headers=headers)
        assert r3.status_code == 200
        assert "pending" in r3.json()

        r3b = c.get("/api/v2/tasks?limit=10", headers=headers)
        assert r3b.status_code == 200
        assert any(t.get("task_id") == "tid1" for t in r3b.json().get("tasks", []))

        r3c = c.get("/api/v2/tasks?limit=10&offset=0&states=QUEUED", headers=headers)
        assert r3c.status_code == 200
        assert all(t.get("state") == "QUEUED" for t in r3c.json().get("tasks", []))

        # Unknown state names in the filter are rejected.
        r3d = c.get("/api/v2/tasks?states=NOPE", headers=headers)
        assert r3d.status_code == 400

        r4 = c.post("/api/v2/tasks/tid1:cancel", headers=headers)
        assert r4.status_code == 200
        assert r4.json()["state"] == "CANCELED"

        # Seed a RUNNING task with an attempt, then page/filter and fetch its logs.
        db.create_task(
            task_id="tid2",
            workload="ppo",
            jobspec_yaml=_PPO_JOBSPEC_YAML,
            nnodes=2,
            n_gpus_per_node=4,
        )
        db.create_attempt(task_id="tid2", attempt_no=1, ray_submission_id="sid2")
        db.set_task_state(task_id="tid2", state="RUNNING", latest_attempt_no=1)

        r6 = c.get("/api/v2/tasks?limit=1&offset=0&states=RUNNING", headers=headers)
        assert r6.status_code == 200
        assert any(t.get("task_id") == "tid2" for t in r6.json().get("tasks", []))

        r7 = c.get("/api/v2/tasks?limit=1&offset=1&states=RUNNING", headers=headers)
        assert r7.status_code == 200
        assert "has_more" in r7.json()

        # The fake tool returns "a\nb\nc\n", so tail=1 leaves only the last line.
        r5 = c.get("/api/v2/tasks/tid2/logs?tail=1", headers=headers)
        assert r5.status_code == 200
        assert r5.text.strip() == "c"


def test_submit_rejects_non_mapping_jobspec(tmp_path: Path, monkeypatch):
    """A jobspec whose YAML is a list (not a mapping) is rejected with 400."""
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
    _install_fake_scheduler(monkeypatch, app_mod)

    app = app_mod.create_app(str(cfg_path))
    with TestClient(app) as c:
        r = c.post("/api/v2/tasks", headers={"authorization": "Bearer token1"}, data="- 1\n- 2\n")
        assert r.status_code == 400


def test_submit_rejects_invalid_jobspec(tmp_path: Path, monkeypatch):
    """A jobspec naming an unknown workload is rejected with 400."""
    from argus.service import app as app_mod

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
    _install_fake_scheduler(monkeypatch, app_mod)

    app = app_mod.create_app(str(cfg_path))
    with TestClient(app) as c:
        r = c.post("/api/v2/tasks", headers={"authorization": "Bearer token1"}, data="workload: nope\n")
        assert r.status_code == 400


def test_me_sftp_reset_password_disabled_returns_400(tmp_path: Path, monkeypatch):
    """With no ``data.sftpgo`` config section, SFTP password reset returns 400."""
    from argus.service import app as app_mod
    from argus.service.config import V2Config
    from argus.service.db import Db

    cfg_path = _write_config(tmp_path)
    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
    _install_fake_scheduler(monkeypatch, app_mod)

    app = app_mod.create_app(str(cfg_path))

    # Seed a user and issue that user a token.
    root = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
    v2_cfg = V2Config.from_root_dict(root)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()
    db.create_user(user_id="u1", display_name=None)
    token = db.issue_token(user_id="u1")

    with TestClient(app) as c:
        r = c.post("/api/v2/me/sftp:reset_password", headers={"authorization": f"Bearer {token}"})
        assert r.status_code == 400


def test_me_sftp_reset_password_enabled_returns_password(tmp_path: Path, monkeypatch):
    """With SFTPGo enabled (admin client faked), reset returns a new password."""
    from argus.service import app as app_mod
    from argus.service.config import V2Config
    from argus.service.db import Db

    cfg = yaml.safe_load(_write_config(tmp_path).read_text(encoding="utf-8"))
    cfg["data"]["sftpgo"] = {
        "enabled": True,
        "host": "127.0.0.1",
        "sftp_port": 2022,
        "admin_api_base": "http://127.0.0.1:8081",
        "admin_user": "admin",
        "admin_password_env": "SFTPGO_ADMIN_PASSWORD",
    }
    cfg_path = tmp_path / "cfg_sftp.yaml"
    cfg_path.write_text(yaml.safe_dump(cfg), encoding="utf-8")

    monkeypatch.setenv("MVP_INTERNAL_TOKEN", "token1")
    monkeypatch.setenv("SFTPGO_ADMIN_PASSWORD", "pw1")

    class _FakeSFTPGo:
        """Records reset/enable calls instead of talking to an SFTPGo admin API."""

        def __init__(self, **kwargs):
            self.reset = []
            self.enabled = []

        def reset_password(self, username: str, new_password: str, home_dir: str):
            assert username
            assert new_password
            assert home_dir
            self.reset.append((username, home_dir))

        def enable_user(self, username: str, home_dir: str):
            self.enabled.append((username, home_dir))

    fake_client = _FakeSFTPGo()

    _install_fake_scheduler(monkeypatch, app_mod)
    # Whatever keyword arguments the app passes, hand back the shared fake.
    monkeypatch.setattr(app_mod, "SFTPGoAdminClient", lambda **kwargs: fake_client)

    app = app_mod.create_app(str(cfg_path))

    # Seed a user and issue that user a token.
    v2_cfg = V2Config.from_root_dict(cfg)
    db = Db(v2_cfg.sqlite.db_path)
    db.init()
    db.create_user(user_id="u1", display_name=None)
    token = db.issue_token(user_id="u1")

    with TestClient(app) as c:
        r = c.post("/api/v2/me/sftp:reset_password", headers={"authorization": f"Bearer {token}"})
        assert r.status_code == 200
        j = r.json()
        assert j["user_id"] == "u1"
        assert isinstance(j["password"], str) and len(j["password"]) >= 8