"""Tests for creating/updating website group bindings via ``app.routers.websites``.

Every test gets a fresh in-memory SQLite database (``engine`` / ``db_session``
fixtures). ``Sub2ApiWebsiteClient`` is patched both on the router module and on
``app.services.website_sync`` with in-process fakes, so no test talks to a real
website. The router handlers are called directly; their third positional
argument receives a bare ``object()`` placeholder (presumably the authenticated
admin dependency, which these tests do not exercise).
"""

import json
import sys
from pathlib import Path

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

# Put this file's directory on sys.path so the ``app`` package can be imported
# from it; must run before the ``app`` imports below.
sys.path.insert(0, str(Path(__file__).resolve().parent))

from app.database import Base
from app.models.snapshot import UpstreamRateSnapshot
from app.models.upstream import Upstream
from app.models.notification_log import NotificationLog
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.models.webhook_config import WebhookConfig
from app.routers import websites as websites_router
from app.schemas.website import BindingCreate, BindingUpdate


# --------------------------------------------------------------------------
# Fixtures
# --------------------------------------------------------------------------


@pytest.fixture()
def engine():
    """Create a fresh in-memory database for each test and dispose it afterwards."""
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        # StaticPool reuses a single connection, so every session bound to this
        # engine sees the same in-memory database.
        poolclass=StaticPool,
    )
    # Imported only for side effects: loading the model modules (presumably)
    # registers their tables on Base.metadata before create_all runs.
    from app.models import (  # noqa: F401
        admin_user,
        custom_page,
        notification_log,
        revoked_token,
        snapshot,
        upstream,
        upstream_key,
        webhook_config,
        website,
    )

    Base.metadata.create_all(bind=engine)
    try:
        yield engine
    finally:
        engine.dispose()


@pytest.fixture()
def db_session(engine):
    """Provide a fresh session for the test logic."""
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()


# --------------------------------------------------------------------------
# Data helpers
# --------------------------------------------------------------------------


def source_groups(upstream_id):
    """Return the source-group list that maps the upstream's ``"source"`` group."""
    return [{
        "upstream_id": upstream_id,
        "upstream_name": "Upstream",
        "group_id": "source",
        "group_name": "Source",
    }]


def seed_rows(db_session, *, website_enabled=True, auto_sync_enabled=True):
    """Insert a sub2api target website, an upstream and one rate snapshot.

    The snapshot exposes a single upstream group ``"source"`` with rate ``"2"``.

    Args:
        db_session: Session to insert into (each step is committed).
        website_enabled: Value for ``Website.enabled``.
        auto_sync_enabled: Value for ``Website.auto_sync_enabled``.

    Returns:
        ``(website, upstream)``, both committed and refreshed.
    """
    website = Website(
        name="Target",
        site_type="sub2api",
        base_url="http://target.local",
        api_prefix="/api",
        auth_type="api_key",
        auth_config_json="{}",
        groups_endpoint="/groups",
        group_update_endpoint="/groups/{id}",
        enabled=website_enabled,
        auto_sync_enabled=auto_sync_enabled,
    )
    upstream = Upstream(
        name="Upstream",
        base_url="http://upstream.local",
        api_prefix="/api",
        auth_type="bearer",
        auth_config_json="{}",
    )
    db_session.add_all([website, upstream])
    db_session.commit()
    db_session.refresh(website)
    db_session.refresh(upstream)

    snapshot = UpstreamRateSnapshot(
        upstream_id=upstream.id,
        snapshot_json=json.dumps({
            "groups": {
                "source": {
                    "group_id": "source",
                    "group_name": "Source",
                    "rate": "2",
                }
            }
        }),
    )
    db_session.add(snapshot)
    db_session.commit()
    return website, upstream


def make_body(website_id, upstream_id, *, enabled=True):
    """Build a ``BindingCreate`` mapping the ``"source"`` group onto ``"target"`` (+10%)."""
    return BindingCreate(
        website_id=website_id,
        target_group_id="target",
        target_group_name="Target group",
        source_groups=source_groups(upstream_id),
        percent=10,
        algorithm="max_plus_percent",
        enabled=enabled,
    )


def seed_binding(db_session, website, upstream, *, enabled):
    """Persist a ``"target"`` binding (10%, ``max_plus_percent``) and return it refreshed."""
    binding = WebsiteGroupBinding(
        website_id=website.id,
        target_group_id="target",
        target_group_name="Target group",
        source_groups_json=json.dumps(source_groups(upstream.id), ensure_ascii=False),
        percent="10",
        algorithm="max_plus_percent",
        enabled=enabled,
    )
    db_session.add(binding)
    db_session.commit()
    db_session.refresh(binding)
    return binding


def seed_webhook(db_session):
    """Persist an enabled generic webhook subscribed to ``website_rate_changed``."""
    webhook = WebhookConfig(
        name="Notify",
        type="generic",
        url="http://notify.local/webhook",
        enabled=True,
        events_json=json.dumps(["website_rate_changed"]),
    )
    db_session.add(webhook)
    db_session.commit()
    return webhook


# --------------------------------------------------------------------------
# Client fakes
# --------------------------------------------------------------------------


def make_fake_client(*, current_rate="1", calls=None):
    """Build a fake ``Sub2ApiWebsiteClient`` class.

    Args:
        current_rate: ``rate_multiplier`` reported for the ``"target"`` group.
        calls: Optional list; every ``update_group_rate`` call is appended to
            it as ``(endpoint, group_id, str(rate))``.
    """

    class FakeClient:
        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            return None  # never suppress exceptions

        def get_groups(self, endpoint):
            return [{"id": "target", "name": "Target group", "rate_multiplier": current_rate}]

        def update_group_rate(self, endpoint, group_id, rate):
            if calls is not None:
                calls.append((endpoint, group_id, str(rate)))

    return FakeClient


def make_forbidden_client(reason):
    """Build a client class whose construction fails the test with *reason*."""

    class ForbiddenClient:
        def __init__(self, *args, **kwargs):
            raise AssertionError(reason)

        def __enter__(self):
            return self

        def __exit__(self, *args):
            return None

    return ForbiddenClient


def install_client(monkeypatch, client_cls):
    """Patch ``Sub2ApiWebsiteClient`` on the router and on the sync service."""
    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", client_cls)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", client_cls)


# --------------------------------------------------------------------------
# create_binding
# --------------------------------------------------------------------------


def test_create_binding_runs_initial_sync(monkeypatch, db_session):
    """Creating a binding writes the computed rate (source 2, +10% -> 2.2)."""
    website, upstream = seed_rows(db_session)
    calls = []
    install_client(monkeypatch, make_fake_client(current_rate="1", calls=calls))

    result = websites_router.create_binding(
        make_body(website.id, upstream.id),
        db_session,
        object(),
    )

    assert result.target_group_id == "target"
    assert calls == [("/groups/{id}", "target", "2.2000")]
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "success"
    assert log.message == "同步成功"
    assert log.old_rate == "1"
    assert log.new_rate == "2.2"


def test_create_binding_skips_write_when_website_auto_sync_disabled(monkeypatch, db_session):
    """With website auto-sync off, the rate is computed but no client is created."""
    website, upstream = seed_rows(db_session, auto_sync_enabled=False)
    install_client(
        monkeypatch,
        make_forbidden_client("should not write when website auto sync is disabled"),
    )

    result = websites_router.create_binding(
        make_body(website.id, upstream.id),
        db_session,
        object(),
    )

    assert result.target_group_id == "target"
    assert db_session.query(WebsiteGroupBinding).count() == 1
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "success"
    assert log.message == "网站未启用自动同步,未写回"
    assert log.old_rate is None
    assert log.new_rate == "2.2"


def test_create_binding_skips_write_when_binding_disabled(monkeypatch, db_session):
    """A disabled binding is saved and its rate computed, but nothing is written."""
    website, upstream = seed_rows(db_session)
    install_client(
        monkeypatch,
        make_forbidden_client("should not write when binding is disabled"),
    )

    result = websites_router.create_binding(
        make_body(website.id, upstream.id, enabled=False),
        db_session,
        object(),
    )

    assert result.target_group_id == "target"
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "success"
    assert log.message == "绑定未启用,未写回"
    assert log.new_rate == "2.2"


def test_create_binding_keeps_binding_when_initial_sync_calculation_fails(monkeypatch, db_session):
    """Without any upstream snapshot the sync fails, yet the binding is kept."""
    website, upstream = seed_rows(db_session)
    db_session.query(UpstreamRateSnapshot).delete()
    db_session.commit()
    calls = []
    install_client(monkeypatch, make_fake_client(calls=calls))

    result = websites_router.create_binding(
        make_body(website.id, upstream.id),
        db_session,
        object(),
    )

    assert result.target_group_id == "target"
    assert db_session.query(WebsiteGroupBinding).count() == 1
    assert calls == []
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "failed"
    assert "没有可用的正数上游倍率" in log.message
    assert log.new_rate is None


# --------------------------------------------------------------------------
# update_binding
# --------------------------------------------------------------------------


def test_update_binding_runs_sync_after_save(monkeypatch, db_session):
    """Updating percent to 20 re-syncs and writes 2.4 (source 2, +20%)."""
    website, upstream = seed_rows(db_session)
    binding = seed_binding(db_session, website, upstream, enabled=True)
    calls = []
    install_client(monkeypatch, make_fake_client(current_rate="1", calls=calls))

    result = websites_router.update_binding(
        binding.id,
        BindingUpdate(
            target_group_name="Target group",
            percent=20,
            enabled=True,
        ),
        db_session,
        object(),
    )

    assert result.target_group_id == "target"
    assert calls == [("/groups/{id}", "target", "2.4000")]
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "success"
    assert log.message == "同步成功"
    assert log.new_rate == "2.4"


def test_update_binding_skips_write_when_disabled(monkeypatch, db_session):
    """Updating a disabled binding logs a skip and never constructs a client."""
    website, upstream = seed_rows(db_session)
    binding = seed_binding(db_session, website, upstream, enabled=False)
    install_client(
        monkeypatch,
        make_forbidden_client("should not write when binding is disabled"),
    )

    result = websites_router.update_binding(
        binding.id,
        BindingUpdate(target_group_name="Target group", percent=20),
        db_session,
        object(),
    )

    assert result.target_group_id == "target"
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "success"
    assert log.message == "绑定未启用,未写回"


# --------------------------------------------------------------------------
# Notifications
# --------------------------------------------------------------------------


def test_create_binding_notifies_when_website_rate_changes(monkeypatch, db_session):
    """A rate change 1 -> 2.2 sends one ``website_rate_changed`` webhook."""
    website, upstream = seed_rows(db_session)
    seed_webhook(db_session)
    sent_payloads = []

    def fake_send_generic(url, payload, timeout=15.0):
        sent_payloads.append((url, payload))
        return "ok"

    install_client(monkeypatch, make_fake_client(current_rate="1"))
    monkeypatch.setattr("app.services.webhook_service._send_generic", fake_send_generic)

    result = websites_router.create_binding(
        make_body(website.id, upstream.id),
        db_session,
        object(),
    )

    assert result.target_group_id == "target"
    assert len(sent_payloads) == 1
    _, payload = sent_payloads[0]
    assert payload["event"] == "website_rate_changed"
    assert payload["website"]["id"] == website.id
    assert payload["target_group"]["old_rate"] == "1"
    assert payload["target_group"]["new_rate"] == "2.2"
    log = db_session.query(NotificationLog).one()
    assert log.event_type == "website_rate_changed"
    assert log.status == "success"


def test_create_binding_does_not_notify_when_website_rate_unchanged(monkeypatch, db_session):
    """When the target already has rate 2.2, no webhook is sent or logged."""
    website, upstream = seed_rows(db_session)
    seed_webhook(db_session)

    def fake_send_generic(url, payload, timeout=15.0):
        raise AssertionError("should not notify when target rate is unchanged")

    install_client(monkeypatch, make_fake_client(current_rate="2.2"))
    monkeypatch.setattr("app.services.webhook_service._send_generic", fake_send_generic)

    result = websites_router.create_binding(
        make_body(website.id, upstream.id),
        db_session,
        object(),
    )

    assert result.target_group_id == "target"
    assert db_session.query(NotificationLog).count() == 0