Files
SmartUp/backend/test_group_binding_create_sync.py
T
2026-06-01 09:06:01 +08:00

375 lines
12 KiB
Python

import json
import sys
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
sys.path.insert(0, str(Path(__file__).resolve().parent))
from app.database import Base, get_db
from app.main import app
from app.models.snapshot import UpstreamRateSnapshot
from app.models.upstream import Upstream
from app.models.notification_log import NotificationLog
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.models.webhook_config import WebhookConfig
from app.routers import websites as websites_router
from app.utils.auth import get_current_user
@pytest.fixture()
def db_session():
    """Yield a SQLAlchemy session bound to a fresh in-memory SQLite database.

    The engine uses ``StaticPool`` so every checkout reuses one connection,
    and therefore one in-memory database. ``check_same_thread=False`` lets that
    connection be used from the thread ``TestClient`` runs the app in.
    On teardown the session is closed, all tables are dropped and the engine
    is disposed.
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    # Imported only for their side effect of defining model tables on
    # ``Base.metadata`` before ``create_all`` runs; the names are unused.
    from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website, revoked_token, upstream_key  # noqa: F401

    Base.metadata.create_all(bind=engine)
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
        Base.metadata.drop_all(bind=engine)
        # Release the pooled connection so each test's in-memory database is
        # freed deterministically instead of lingering until garbage collection.
        engine.dispose()
@pytest.fixture()
def client(db_session):
    """Yield a ``TestClient`` for ``app`` with DB and auth dependencies stubbed.

    ``get_db`` is overridden to hand out the shared ``db_session`` fixture, and
    ``get_current_user`` returns a placeholder object so requests skip real
    authentication. Every dependency override is cleared on teardown.
    """

    def _shared_session():
        yield db_session

    overrides = app.dependency_overrides
    overrides[get_db] = _shared_session
    overrides[get_current_user] = lambda: object()
    try:
        yield TestClient(app)
    finally:
        overrides.clear()
def seed_rows(db_session, *, website_enabled=True, auto_sync_enabled=True):
    """Insert one target website, one upstream and one upstream rate snapshot.

    The snapshot lists a single upstream group ``"source"`` with rate ``"2"``.

    Args:
        db_session: Session the rows are added to and committed on.
        website_enabled: Value stored in ``Website.enabled``.
        auto_sync_enabled: Value stored in ``Website.auto_sync_enabled``.

    Returns:
        ``(website, upstream)``, both refreshed so their ``id`` is populated.
    """
    target_site = Website(
        name="Target",
        site_type="sub2api",
        base_url="http://target.local",
        api_prefix="/api",
        auth_type="api_key",
        auth_config_json="{}",
        groups_endpoint="/groups",
        group_update_endpoint="/groups/{id}",
        enabled=website_enabled,
        auto_sync_enabled=auto_sync_enabled,
    )
    source_upstream = Upstream(
        name="Upstream",
        base_url="http://upstream.local",
        api_prefix="/api",
        auth_type="bearer",
        auth_config_json="{}",
    )
    db_session.add_all([target_site, source_upstream])
    db_session.commit()
    for row in (target_site, source_upstream):
        db_session.refresh(row)

    snapshot_groups = {
        "source": {
            "group_id": "source",
            "group_name": "Source",
            "rate": "2",
        }
    }
    db_session.add(
        UpstreamRateSnapshot(
            upstream_id=source_upstream.id,
            snapshot_json=json.dumps({"groups": snapshot_groups}),
        )
    )
    db_session.commit()
    return target_site, source_upstream
def binding_payload(website_id, upstream_id, *, enabled=True):
    """Build the JSON body used to create a website group binding.

    The binding maps target group ``"target"`` on *website_id* to upstream
    group ``"source"`` of *upstream_id*, with algorithm ``max_plus_percent``
    and ``percent`` 10. *enabled* is copied into the body unchanged.
    """
    source_group = {
        "upstream_id": upstream_id,
        "upstream_name": "Upstream",
        "group_id": "source",
        "group_name": "Source",
    }
    return dict(
        website_id=website_id,
        target_group_id="target",
        target_group_name="Target group",
        source_groups=[source_group],
        percent=10,
        algorithm="max_plus_percent",
        enabled=enabled,
    )
def test_create_binding_runs_initial_sync(monkeypatch, client, db_session):
    """Creating a binding immediately writes the computed rate to the website."""
    website, upstream = seed_rows(db_session)
    writes = []

    class RecordingClient:
        """Stand-in for ``Sub2ApiWebsiteClient`` that records rate writes."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]

        def update_group_rate(self, endpoint, group_id, rate):
            writes.append((endpoint, group_id, str(rate)))

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", RecordingClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", RecordingClient)

    response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))

    assert response.status_code == 201
    body = response.json()
    assert body["target_group_id"] == "target"
    # Upstream rate 2 plus 10% is written to the website as "2.2000".
    assert writes == [("/groups/{id}", "target", "2.2000")]

    sync_log = db_session.query(WebsiteSyncLog).one()
    assert sync_log.status == "success"
    assert sync_log.message == "同步成功"
    assert sync_log.old_rate == "1"
    assert sync_log.new_rate == "2.2"
def test_create_binding_skips_write_when_website_auto_sync_disabled(client, db_session):
    """With website auto-sync disabled, creation logs the rate but writes nothing.

    The website client is replaced by a recording fake. A regression that
    reaches the write path then fails ``writes == []`` instead of reaching the
    real client pointed at ``http://target.local``.
    """
    website, upstream = seed_rows(db_session, auto_sync_enabled=False)
    writes = []

    class RecordingClient:
        """Stand-in for ``Sub2ApiWebsiteClient`` that records rate writes."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]

        def update_group_rate(self, endpoint, group_id, rate):
            writes.append((endpoint, group_id, str(rate)))

    # A scoped MonkeyPatch keeps the test signature unchanged and undoes both
    # patches once the request has been handled.
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(websites_router, "Sub2ApiWebsiteClient", RecordingClient)
        mp.setattr("app.services.website_sync.Sub2ApiWebsiteClient", RecordingClient)
        response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))

    assert response.status_code == 201
    assert db_session.query(WebsiteGroupBinding).count() == 1
    assert writes == []
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "success"
    assert log.message == "网站未启用自动同步,未写回"
    assert log.old_rate is None
    assert log.new_rate == "2.2"
def test_create_binding_skips_write_when_binding_disabled(client, db_session):
    """A binding created disabled is saved and its rate logged, but never written.

    The website client is replaced by a recording fake so that a regression
    reaching the write path is caught by ``writes == []`` rather than going to
    the real client.
    """
    website, upstream = seed_rows(db_session)
    writes = []

    class RecordingClient:
        """Stand-in for ``Sub2ApiWebsiteClient`` that records rate writes."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]

        def update_group_rate(self, endpoint, group_id, rate):
            writes.append((endpoint, group_id, str(rate)))

    # A scoped MonkeyPatch keeps the test signature unchanged and undoes both
    # patches once the request has been handled.
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(websites_router, "Sub2ApiWebsiteClient", RecordingClient)
        mp.setattr("app.services.website_sync.Sub2ApiWebsiteClient", RecordingClient)
        response = client.post(
            "/api/group-bindings",
            json=binding_payload(website.id, upstream.id, enabled=False),
        )

    assert response.status_code == 201
    assert db_session.query(WebsiteGroupBinding).count() == 1
    assert writes == []
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "success"
    assert log.message == "绑定未启用,未写回"
    assert log.new_rate == "2.2"
def test_create_binding_keeps_binding_when_initial_sync_calculation_fails(client, db_session):
    """A failed rate calculation keeps the binding, logs a failure and writes nothing.

    The upstream snapshot is deleted so no upstream rate is available. The
    website client is replaced by a recording fake so the test never reaches
    the real client and can assert that no rate was written.
    """
    website, upstream = seed_rows(db_session)
    db_session.query(UpstreamRateSnapshot).delete()
    db_session.commit()
    writes = []

    class RecordingClient:
        """Stand-in for ``Sub2ApiWebsiteClient`` that records rate writes."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]

        def update_group_rate(self, endpoint, group_id, rate):
            writes.append((endpoint, group_id, str(rate)))

    # A scoped MonkeyPatch keeps the test signature unchanged and undoes both
    # patches once the request has been handled.
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(websites_router, "Sub2ApiWebsiteClient", RecordingClient)
        mp.setattr("app.services.website_sync.Sub2ApiWebsiteClient", RecordingClient)
        response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))

    assert response.status_code == 201
    assert db_session.query(WebsiteGroupBinding).count() == 1
    assert writes == []
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "failed"
    assert "没有可用的正数上游倍率" in log.message
    assert log.new_rate is None
def test_update_binding_runs_sync_after_save(monkeypatch, client, db_session):
    """Updating a binding re-runs the sync using the newly saved percent."""
    website, upstream = seed_rows(db_session)
    sources = [
        {
            "upstream_id": upstream.id,
            "upstream_name": "Upstream",
            "group_id": "source",
            "group_name": "Source",
        }
    ]
    existing = WebsiteGroupBinding(
        website_id=website.id,
        target_group_id="target",
        target_group_name="Target group",
        source_groups_json=json.dumps(sources, ensure_ascii=False),
        percent="10",
        algorithm="max_plus_percent",
        enabled=True,
    )
    db_session.add(existing)
    db_session.commit()
    db_session.refresh(existing)

    writes = []

    class RecordingClient:
        """Stand-in for ``Sub2ApiWebsiteClient`` that records rate writes."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]

        def update_group_rate(self, endpoint, group_id, rate):
            writes.append((endpoint, group_id, str(rate)))

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", RecordingClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", RecordingClient)

    update_body = {
        "target_group_name": "Target group",
        "percent": 20,
        "enabled": True,
    }
    response = client.put(f"/api/group-bindings/{existing.id}", json=update_body)

    assert response.status_code == 200
    # Percent raised from 10 to 20: upstream rate 2 plus 20% is written as "2.4000".
    assert writes == [("/groups/{id}", "target", "2.4000")]
    sync_log = db_session.query(WebsiteSyncLog).one()
    assert sync_log.status == "success"
    assert sync_log.message == "同步成功"
    assert sync_log.new_rate == "2.4"
def test_update_binding_skips_write_when_disabled(monkeypatch, client, db_session):
    """Updating a disabled binding logs the skip and never builds a website client."""
    website, upstream = seed_rows(db_session)
    sources = [
        {
            "upstream_id": upstream.id,
            "upstream_name": "Upstream",
            "group_id": "source",
            "group_name": "Source",
        }
    ]
    existing = WebsiteGroupBinding(
        website_id=website.id,
        target_group_id="target",
        target_group_name="Target group",
        source_groups_json=json.dumps(sources, ensure_ascii=False),
        percent="10",
        algorithm="max_plus_percent",
        enabled=False,
    )
    db_session.add(existing)
    db_session.commit()
    db_session.refresh(existing)

    class ForbiddenClient:
        """Stand-in for ``Sub2ApiWebsiteClient`` that fails as soon as it is built."""

        def __init__(self, **kwargs):
            raise AssertionError("should not write when binding is disabled")

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", ForbiddenClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", ForbiddenClient)

    # "enabled" is omitted, so the binding stays disabled after the update.
    update_body = {
        "target_group_name": "Target group",
        "percent": 20,
    }
    response = client.put(f"/api/group-bindings/{existing.id}", json=update_body)

    assert response.status_code == 200
    sync_log = db_session.query(WebsiteSyncLog).one()
    assert sync_log.status == "success"
    assert sync_log.message == "绑定未启用,未写回"
def test_create_binding_notifies_when_website_rate_changes(monkeypatch, client, db_session):
    """A changed target rate sends exactly one ``website_rate_changed`` webhook."""
    website, upstream = seed_rows(db_session)
    db_session.add(
        WebhookConfig(
            name="Notify",
            type="generic",
            url="http://notify.local/webhook",
            enabled=True,
            events_json=json.dumps(["website_rate_changed"]),
        )
    )
    db_session.commit()

    delivered = []

    class QuietClient:
        """Stand-in for ``Sub2ApiWebsiteClient``; reports rate 1 and accepts writes."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]

        def update_group_rate(self, endpoint, group_id, rate):
            pass

    def capture_generic(url, payload, timeout=15.0):
        # Replaces the generic webhook sender; records instead of POSTing.
        delivered.append((url, payload))
        return "ok"

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", QuietClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", QuietClient)
    monkeypatch.setattr("app.services.webhook_service._send_generic", capture_generic)

    response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))

    assert response.status_code == 201
    assert len(delivered) == 1
    payload = delivered[0][1]
    assert payload["event"] == "website_rate_changed"
    assert payload["website"]["id"] == website.id
    target_group = payload["target_group"]
    assert target_group["old_rate"] == "1"
    assert target_group["new_rate"] == "2.2"

    notification = db_session.query(NotificationLog).one()
    assert notification.event_type == "website_rate_changed"
    assert notification.status == "success"
def test_create_binding_does_not_notify_when_website_rate_unchanged(monkeypatch, client, db_session):
    """No webhook is sent when the website already has the computed rate."""
    website, upstream = seed_rows(db_session)
    db_session.add(
        WebhookConfig(
            name="Notify",
            type="generic",
            url="http://notify.local/webhook",
            enabled=True,
            events_json=json.dumps(["website_rate_changed"]),
        )
    )
    db_session.commit()

    class AlreadySyncedClient:
        """Stand-in for ``Sub2ApiWebsiteClient`` reporting the target already at 2.2."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            return [{"id": "target", "name": "Target group", "rate_multiplier": "2.2"}]

        def update_group_rate(self, endpoint, group_id, rate):
            pass

    def forbid_generic(url, payload, timeout=15.0):
        # Any webhook delivery in this scenario is a test failure.
        raise AssertionError("should not notify when target rate is unchanged")

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", AlreadySyncedClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", AlreadySyncedClient)
    monkeypatch.setattr("app.services.webhook_service._send_generic", forbid_generic)

    response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))

    assert response.status_code == 201
    assert db_session.query(NotificationLog).count() == 0