diff --git a/backend/test_group_binding_create_sync.py b/backend/test_group_binding_create_sync.py index 7d1994f..9f6297e 100644 --- a/backend/test_group_binding_create_sync.py +++ b/backend/test_group_binding_create_sync.py @@ -3,22 +3,20 @@ import sys from pathlib import Path import pytest -from fastapi.testclient import TestClient from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import StaticPool sys.path.insert(0, str(Path(__file__).resolve().parent)) -from app.database import Base, get_db -from app.main import app +from app.database import Base from app.models.snapshot import UpstreamRateSnapshot from app.models.upstream import Upstream from app.models.notification_log import NotificationLog from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog from app.models.webhook_config import WebhookConfig from app.routers import websites as websites_router -from app.utils.auth import get_current_user +from app.schemas.website import BindingCreate, BindingUpdate @pytest.fixture() @@ -45,33 +43,6 @@ def db_session(engine): db.close() -@pytest.fixture() -def client(engine, monkeypatch): - """Provide a TestClient with a fresh session factory for each request.""" - TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - - def override_get_db(): - db = TestingSessionLocal() - try: - yield db - finally: - db.close() - - # Mock lifespan-related database initialization to avoid using the real DB file - monkeypatch.setattr("app.main.init_db", lambda: None) - monkeypatch.setattr("app.main._init_admin", lambda: None) - monkeypatch.setattr("app.main.start_scheduler", lambda: None) - monkeypatch.setattr("app.main.stop_scheduler", lambda: None) - - app.dependency_overrides[get_db] = override_get_db - app.dependency_overrides[get_current_user] = lambda: object() - try: - with TestClient(app) as c: - yield c - finally: - app.dependency_overrides.clear() - - def seed_rows(db_session, *, website_enabled=True, auto_sync_enabled=True): website = Website( name="Target", @@ -113,24 +84,24 @@ def seed_rows(db_session, *, website_enabled=True, auto_sync_enabled=True): return website, upstream -def binding_payload(website_id, upstream_id, *, enabled=True): - return { - "website_id": website_id, - "target_group_id": "target", - "target_group_name": "Target group", - "source_groups": [{ +def make_body(website_id, upstream_id, *, enabled=True): + return BindingCreate( + website_id=website_id, + target_group_id="target", + target_group_name="Target group", + source_groups=[{ "upstream_id": upstream_id, "upstream_name": "Upstream", "group_id": "source", "group_name": "Source", }], - "percent": 10, - "algorithm": "max_plus_percent", - "enabled": enabled, - } + percent=10, + algorithm="max_plus_percent", + enabled=enabled, + ) -def test_create_binding_runs_initial_sync(monkeypatch, client, db_session): +def test_create_binding_runs_initial_sync(monkeypatch, db_session): website, upstream = seed_rows(db_session) calls = [] @@ -151,10 +122,13 @@ def test_create_binding_runs_initial_sync(monkeypatch, client, db_session): monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeClient) monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient) - response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id)) + result = websites_router.create_binding( + make_body(website.id, upstream.id), + db_session, + object(), + ) - assert response.status_code == 201 - assert response.json()["target_group_id"] == "target" + assert result.target_group_id == "target" assert calls == [("/groups/{id}", "target", "2.2000")] log = db_session.query(WebsiteSyncLog).one() assert log.status == "success" @@ -163,12 +137,16 @@ def test_create_binding_runs_initial_sync(monkeypatch, client, db_session): assert log.new_rate == "2.2" -def test_create_binding_skips_write_when_website_auto_sync_disabled(client, db_session): +def test_create_binding_skips_write_when_website_auto_sync_disabled(db_session): website, upstream = seed_rows(db_session, auto_sync_enabled=False) - response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id)) + result = websites_router.create_binding( + make_body(website.id, upstream.id), + db_session, + object(), + ) - assert response.status_code == 201 + assert result.target_group_id == "target" assert db_session.query(WebsiteGroupBinding).count() == 1 log = db_session.query(WebsiteSyncLog).one() assert log.status == "success" @@ -177,26 +155,34 @@ def test_create_binding_skips_write_when_website_auto_sync_disabled(client, db_s assert log.new_rate == "2.2" -def test_create_binding_skips_write_when_binding_disabled(client, db_session): +def test_create_binding_skips_write_when_binding_disabled(db_session): website, upstream = seed_rows(db_session) - response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id, enabled=False)) + result = websites_router.create_binding( + make_body(website.id, upstream.id, enabled=False), + db_session, + object(), + ) - assert response.status_code == 201 + assert result.target_group_id == "target" log = db_session.query(WebsiteSyncLog).one() assert log.status == "success" assert log.message == "绑定未启用,未写回" assert log.new_rate == "2.2" -def test_create_binding_keeps_binding_when_initial_sync_calculation_fails(client, db_session): +def test_create_binding_keeps_binding_when_initial_sync_calculation_fails(db_session): website, upstream = seed_rows(db_session) db_session.query(UpstreamRateSnapshot).delete() db_session.commit() - response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id)) + result = websites_router.create_binding( + make_body(website.id, upstream.id), + db_session, + object(), + ) - assert response.status_code == 201 + assert result.target_group_id == "target" assert db_session.query(WebsiteGroupBinding).count() == 1 log = db_session.query(WebsiteSyncLog).one() assert log.status == "failed" @@ -204,7 +190,7 @@ def test_create_binding_keeps_binding_when_initial_sync_calculation_fails(client assert log.new_rate is None -def test_update_binding_runs_sync_after_save(monkeypatch, client, db_session): +def test_update_binding_runs_sync_after_save(monkeypatch, db_session): website, upstream = seed_rows(db_session) binding = WebsiteGroupBinding( website_id=website.id, @@ -243,16 +229,18 @@ def test_update_binding_runs_sync_after_save(monkeypatch, client, db_session): monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeClient) monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient) - response = client.put( - f"/api/group-bindings/{binding.id}", - json={ - "target_group_name": "Target group", - "percent": 20, - "enabled": True, - }, + result = websites_router.update_binding( + binding.id, + BindingUpdate( + target_group_name="Target group", + percent=20, + enabled=True, + ), + db_session, + object(), ) - assert response.status_code == 200 + assert result.target_group_id == "target" assert calls == [("/groups/{id}", "target", "2.4000")] log = db_session.query(WebsiteSyncLog).one() assert log.status == "success" @@ -260,7 +248,7 @@ def test_update_binding_runs_sync_after_save(monkeypatch, client, db_session): assert log.new_rate == "2.4" -def test_update_binding_skips_write_when_disabled(monkeypatch, client, db_session): +def test_update_binding_skips_write_when_disabled(monkeypatch, db_session): website, upstream = seed_rows(db_session) binding = WebsiteGroupBinding( website_id=website.id, @@ -291,21 +279,20 @@ def test_update_binding_skips_write_when_disabled(monkeypatch, client, db_sessio monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeClient) monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient) - response = client.put( - f"/api/group-bindings/{binding.id}", - json={ - "target_group_name": "Target group", - "percent": 20, - }, + result = websites_router.update_binding( + binding.id, + BindingUpdate(target_group_name="Target group", percent=20), + db_session, + object(), ) - assert response.status_code == 200 + assert result.target_group_id == "target" log = db_session.query(WebsiteSyncLog).one() assert log.status == "success" assert log.message == "绑定未启用,未写回" -def test_create_binding_notifies_when_website_rate_changes(monkeypatch, client, db_session): +def test_create_binding_notifies_when_website_rate_changes(monkeypatch, db_session): website, upstream = seed_rows(db_session) webhook = WebhookConfig( name="Notify", @@ -341,9 +328,13 @@ def test_create_binding_notifies_when_website_rate_changes(monkeypatch, client, monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient) monkeypatch.setattr("app.services.webhook_service._send_generic", fake_send_generic) - response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id)) + result = websites_router.create_binding( + make_body(website.id, upstream.id), + db_session, + object(), + ) - assert response.status_code == 201 + assert result.target_group_id == "target" assert len(sent_payloads) == 1 _, payload = sent_payloads[0] assert payload["event"] == "website_rate_changed" @@ -355,7 +346,7 @@ def test_create_binding_notifies_when_website_rate_changes(monkeypatch, client, assert log.status == "success" -def test_create_binding_does_not_notify_when_website_rate_unchanged(monkeypatch, client, db_session): +def test_create_binding_does_not_notify_when_website_rate_unchanged(monkeypatch, db_session): website, upstream = seed_rows(db_session) webhook = WebhookConfig( name="Notify", @@ -388,7 +379,11 @@ def test_create_binding_does_not_notify_when_website_rate_unchanged(monkeypatch, monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient) monkeypatch.setattr("app.services.webhook_service._send_generic", fake_send_generic) - response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id)) + result = websites_router.create_binding( + make_body(website.id, upstream.id), + db_session, + object(), + ) - assert response.status_code == 201 + assert result.target_group_id == "target" assert db_session.query(NotificationLog).count() == 0