"""Tests for the website group-binding endpoints and the sync that runs after saving."""

import json
import sys
from pathlib import Path

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

# Put this file's directory on sys.path so the ``app`` package can be imported;
# this has to happen before any ``app.*`` import below.
sys.path.insert(0, str(Path(__file__).resolve().parent))

from app.database import Base, get_db  # noqa: E402
from app.main import app  # noqa: E402
from app.models.snapshot import UpstreamRateSnapshot  # noqa: E402
from app.models.upstream import Upstream  # noqa: E402
from app.models.notification_log import NotificationLog  # noqa: E402
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog  # noqa: E402
from app.models.webhook_config import WebhookConfig  # noqa: E402
from app.routers import websites as websites_router  # noqa: E402
from app.utils.auth import get_current_user  # noqa: E402


@pytest.fixture()
def engine():
    """Yield a fresh in-memory SQLite engine with all tables created.

    ``StaticPool`` keeps a single connection and ``check_same_thread=False``
    lets it be used from other threads, so every session bound to this engine
    (the test's own and the app's per-request ones) sees the same in-memory
    database. The engine is disposed when the test finishes so its pooled
    connection is not leaked across tests.
    """
    test_engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    # Import every model module so its tables are registered on Base.metadata
    # before create_all runs; the names themselves are unused.
    from app.models import (  # noqa: F401
        admin_user,
        upstream,
        snapshot,
        webhook_config,
        notification_log,
        custom_page,
        website,
        revoked_token,
        upstream_key,
    )

    Base.metadata.create_all(bind=test_engine)
    try:
        yield test_engine
    finally:
        test_engine.dispose()


@pytest.fixture()
def db_session(engine):
    """Yield a session on the test engine for seeding and asserting; closed afterwards."""
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()


@pytest.fixture()
def client(engine, monkeypatch):
    """Yield a TestClient bound to the test engine with startup work and auth stubbed out.

    Each request gets its own session from a session factory on the test engine
    (via a ``get_db`` override). Dependency overrides are cleared on teardown.
    """
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

    def override_get_db():
        db = TestingSessionLocal()
        try:
            yield db
        finally:
            db.close()

    # Stub the lifespan hooks so starting the app neither initialises the real
    # database file nor starts the scheduler.
    monkeypatch.setattr("app.main.init_db", lambda: None)
    monkeypatch.setattr("app.main._init_admin", lambda: None)
    monkeypatch.setattr("app.main.start_scheduler", lambda: None)
    monkeypatch.setattr("app.main.stop_scheduler", lambda: None)

    app.dependency_overrides[get_db] = override_get_db
    # Bypass authentication: the current-user dependency yields a dummy object.
    app.dependency_overrides[get_current_user] = lambda: object()
    try:
        with TestClient(app) as c:
            yield c
    finally:
        app.dependency_overrides.clear()


def seed_rows(db_session, *, website_enabled=True, auto_sync_enabled=True):
    """Insert a target website, one upstream, and a rate snapshot for that upstream.

    The snapshot holds a single upstream group ``source`` with rate ``"2"``.

    Returns:
        ``(website, upstream)``, both committed and refreshed so their ids are set.
    """
    website = Website(
        name="Target",
        site_type="sub2api",
        base_url="http://target.local",
        api_prefix="/api",
        auth_type="api_key",
        auth_config_json="{}",
        groups_endpoint="/groups",
        group_update_endpoint="/groups/{id}",
        enabled=website_enabled,
        auto_sync_enabled=auto_sync_enabled,
    )
    upstream = Upstream(
        name="Upstream",
        base_url="http://upstream.local",
        api_prefix="/api",
        auth_type="bearer",
        auth_config_json="{}",
    )
    db_session.add_all([website, upstream])
    db_session.commit()
    db_session.refresh(website)
    db_session.refresh(upstream)

    snapshot = UpstreamRateSnapshot(
        upstream_id=upstream.id,
        snapshot_json=json.dumps({
            "groups": {
                "source": {
                    "group_id": "source",
                    "group_name": "Source",
                    "rate": "2",
                }
            }
        }),
    )
    db_session.add(snapshot)
    db_session.commit()
    return website, upstream


def binding_payload(website_id, upstream_id, *, enabled=True):
    """Build the JSON body for ``POST /api/group-bindings``.

    The body binds target group ``target`` to the upstream's ``source`` group
    using ``max_plus_percent`` with a 10% markup.
    """
    return {
        "website_id": website_id,
        "target_group_id": "target",
        "target_group_name": "Target group",
        "source_groups": [{
            "upstream_id": upstream_id,
            "upstream_name": "Upstream",
            "group_id": "source",
            "group_name": "Source",
        }],
        "percent": 10,
        "algorithm": "max_plus_percent",
        "enabled": enabled,
    }


def _make_fake_website_client(*, current_rate="1", calls=None):
    """Build a stand-in class for ``Sub2ApiWebsiteClient``.

    The fake reports one group ``target`` whose current ``rate_multiplier`` is
    *current_rate*. When *calls* is a list, every ``update_group_rate`` call is
    appended to it as ``(endpoint, group_id, str(rate))``. Otherwise writes are
    ignored.
    """

    class FakeClient:
        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            return [{"id": "target", "name": "Target group", "rate_multiplier": current_rate}]

        def update_group_rate(self, endpoint, group_id, rate):
            if calls is not None:
                calls.append((endpoint, group_id, str(rate)))

    return FakeClient


def _patch_website_client(monkeypatch, fake_cls):
    """Replace ``Sub2ApiWebsiteClient`` in both modules that reference it."""
    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", fake_cls)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", fake_cls)


def _add_binding(db_session, website, upstream, *, enabled=True):
    """Insert a ``target`` <- ``source`` binding (10%, max_plus_percent) directly in the DB."""
    binding = WebsiteGroupBinding(
        website_id=website.id,
        target_group_id="target",
        target_group_name="Target group",
        source_groups_json=json.dumps([{
            "upstream_id": upstream.id,
            "upstream_name": "Upstream",
            "group_id": "source",
            "group_name": "Source",
        }], ensure_ascii=False),
        percent="10",
        algorithm="max_plus_percent",
        enabled=enabled,
    )
    db_session.add(binding)
    db_session.commit()
    db_session.refresh(binding)
    return binding


def _add_rate_change_webhook(db_session):
    """Insert an enabled generic webhook subscribed to ``website_rate_changed``."""
    webhook = WebhookConfig(
        name="Notify",
        type="generic",
        url="http://notify.local/webhook",
        enabled=True,
        events_json=json.dumps(["website_rate_changed"]),
    )
    db_session.add(webhook)
    db_session.commit()
    return webhook


def test_create_binding_runs_initial_sync(monkeypatch, client, db_session):
    website, upstream = seed_rows(db_session)
    calls = []
    _patch_website_client(monkeypatch, _make_fake_website_client(calls=calls))

    response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))

    assert response.status_code == 201
    assert response.json()["target_group_id"] == "target"
    # Source rate 2 plus 10% is written back to the target group.
    assert calls == [("/groups/{id}", "target", "2.2000")]
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "success"
    assert log.message == "同步成功"
    assert log.old_rate == "1"
    assert log.new_rate == "2.2"


def test_create_binding_skips_write_when_website_auto_sync_disabled(client, db_session):
    website, upstream = seed_rows(db_session, auto_sync_enabled=False)

    response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))

    assert response.status_code == 201
    assert db_session.query(WebsiteGroupBinding).count() == 1
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "success"
    assert log.message == "网站未启用自动同步,未写回"
    assert log.old_rate is None
    assert log.new_rate == "2.2"


def test_create_binding_skips_write_when_binding_disabled(client, db_session):
    website, upstream = seed_rows(db_session)

    response = client.post(
        "/api/group-bindings",
        json=binding_payload(website.id, upstream.id, enabled=False),
    )

    assert response.status_code == 201
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "success"
    assert log.message == "绑定未启用,未写回"
    assert log.new_rate == "2.2"


def test_create_binding_keeps_binding_when_initial_sync_calculation_fails(client, db_session):
    website, upstream = seed_rows(db_session)
    # Without a snapshot there is no upstream rate to calculate from.
    db_session.query(UpstreamRateSnapshot).delete()
    db_session.commit()

    response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))

    assert response.status_code == 201
    assert db_session.query(WebsiteGroupBinding).count() == 1
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "failed"
    assert "没有可用的正数上游倍率" in log.message
    assert log.new_rate is None


def test_update_binding_runs_sync_after_save(monkeypatch, client, db_session):
    website, upstream = seed_rows(db_session)
    binding = _add_binding(db_session, website, upstream, enabled=True)
    calls = []
    _patch_website_client(monkeypatch, _make_fake_website_client(calls=calls))

    response = client.put(
        f"/api/group-bindings/{binding.id}",
        json={
            "target_group_name": "Target group",
            "percent": 20,
            "enabled": True,
        },
    )

    assert response.status_code == 200
    # The updated 20% markup is applied: 2 -> 2.4.
    assert calls == [("/groups/{id}", "target", "2.4000")]
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "success"
    assert log.message == "同步成功"
    assert log.new_rate == "2.4"


def test_update_binding_skips_write_when_disabled(monkeypatch, client, db_session):
    website, upstream = seed_rows(db_session)
    binding = _add_binding(db_session, website, upstream, enabled=False)

    class ExplodingClient:
        """Fails the test if the sync ever tries to construct a client."""

        def __init__(self, **kwargs):
            raise AssertionError("should not write when binding is disabled")

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

    _patch_website_client(monkeypatch, ExplodingClient)

    response = client.put(
        f"/api/group-bindings/{binding.id}",
        json={
            "target_group_name": "Target group",
            "percent": 20,
        },
    )

    assert response.status_code == 200
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "success"
    assert log.message == "绑定未启用,未写回"


def test_create_binding_notifies_when_website_rate_changes(monkeypatch, client, db_session):
    website, upstream = seed_rows(db_session)
    _add_rate_change_webhook(db_session)
    sent_payloads = []

    def fake_send_generic(url, payload, timeout=15.0):
        sent_payloads.append((url, payload))
        return "ok"

    _patch_website_client(monkeypatch, _make_fake_website_client())
    monkeypatch.setattr("app.services.webhook_service._send_generic", fake_send_generic)

    response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))

    assert response.status_code == 201
    assert len(sent_payloads) == 1
    _, payload = sent_payloads[0]
    assert payload["event"] == "website_rate_changed"
    assert payload["website"]["id"] == website.id
    assert payload["target_group"]["old_rate"] == "1"
    assert payload["target_group"]["new_rate"] == "2.2"
    log = db_session.query(NotificationLog).one()
    assert log.event_type == "website_rate_changed"
    assert log.status == "success"


def test_create_binding_does_not_notify_when_website_rate_unchanged(monkeypatch, client, db_session):
    website, upstream = seed_rows(db_session)
    _add_rate_change_webhook(db_session)

    def fake_send_generic(url, payload, timeout=15.0):
        raise AssertionError("should not notify when target rate is unchanged")

    # The target group already sits at the computed rate (2.2).
    _patch_website_client(monkeypatch, _make_fake_website_client(current_rate="2.2"))
    monkeypatch.setattr("app.services.webhook_service._send_generic", fake_send_generic)

    response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))

    assert response.status_code == 201
    assert db_session.query(NotificationLog).count() == 0