Files
SmartUp/backend/test_group_binding_create_sync.py
liumangmang 92eb4888d1 test: replace TestClient with direct router calls to fix hang in binding tests
- Removed TestClient/FastAPI/app.main imports entirely
- Tests now call websites_router.create_binding() and
  websites_router.update_binding() directly with db_session
- Bypasses all ASGI transport and lifespan issues
- All 13 tests pass in 0.75s
2026-06-01 10:26:23 +08:00

390 lines
12 KiB
Python

import json
import sys
from pathlib import Path
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
sys.path.insert(0, str(Path(__file__).resolve().parent))
from app.database import Base
from app.models.snapshot import UpstreamRateSnapshot
from app.models.upstream import Upstream
from app.models.notification_log import NotificationLog
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.models.webhook_config import WebhookConfig
from app.routers import websites as websites_router
from app.schemas.website import BindingCreate, BindingUpdate
@pytest.fixture()
def engine():
    """Create a fresh in-memory SQLite database for each test.

    ``StaticPool`` keeps exactly one connection, and ``check_same_thread=False``
    lets that connection be used from any thread, so every session in a test
    sees the same in-memory database. After the test the engine is disposed so
    that connection (and with it the in-memory database) is released promptly
    instead of lingering until garbage collection.
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    # Import every model module so its tables are registered on Base.metadata
    # before create_all(); the imported names themselves are intentionally unused.
    from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website, revoked_token, upstream_key  # noqa: F401
    Base.metadata.create_all(bind=engine)
    try:
        yield engine
    finally:
        # Runs after dependent fixtures (e.g. db_session) have been torn down.
        engine.dispose()
@pytest.fixture()
def db_session(engine):
    """Yield a Session bound to the per-test ``engine`` and always close it afterwards.

    The session is created with autoflush and autocommit disabled, so tests
    control when data is flushed and committed.
    """
    session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
    session = session_factory()
    try:
        yield session
    finally:
        session.close()
def seed_rows(db_session, *, website_enabled=True, auto_sync_enabled=True):
    """Insert a target website, an upstream, and one rate snapshot for that upstream.

    The snapshot exposes a single upstream group ``"source"`` with rate ``"2"``.

    Args:
        db_session: Session to write into; it is committed twice.
        website_enabled: Value stored on ``Website.enabled``.
        auto_sync_enabled: Value stored on ``Website.auto_sync_enabled``.

    Returns:
        ``(website, upstream)``, both refreshed after commit so their ids are loaded.
    """
    target_site = Website(
        name="Target",
        site_type="sub2api",
        base_url="http://target.local",
        api_prefix="/api",
        auth_type="api_key",
        auth_config_json="{}",
        groups_endpoint="/groups",
        group_update_endpoint="/groups/{id}",
        enabled=website_enabled,
        auto_sync_enabled=auto_sync_enabled,
    )
    source_upstream = Upstream(
        name="Upstream",
        base_url="http://upstream.local",
        api_prefix="/api",
        auth_type="bearer",
        auth_config_json="{}",
    )
    db_session.add_all((target_site, source_upstream))
    db_session.commit()
    for row in (target_site, source_upstream):
        db_session.refresh(row)

    # The snapshot needs upstream.id, so it is written in a second commit.
    snapshot_groups = {
        "source": {
            "group_id": "source",
            "group_name": "Source",
            "rate": "2",
        }
    }
    db_session.add(
        UpstreamRateSnapshot(
            upstream_id=source_upstream.id,
            snapshot_json=json.dumps({"groups": snapshot_groups}),
        )
    )
    db_session.commit()
    return target_site, source_upstream
def make_body(website_id, upstream_id, *, enabled=True):
    """Build a ``BindingCreate`` payload mapping upstream group "source" to target group "target".

    The payload uses the ``max_plus_percent`` algorithm with ``percent=10``;
    ``enabled`` is passed through unchanged.
    """
    source_group = {
        "upstream_id": upstream_id,
        "upstream_name": "Upstream",
        "group_id": "source",
        "group_name": "Source",
    }
    return BindingCreate(
        website_id=website_id,
        target_group_id="target",
        target_group_name="Target group",
        source_groups=[source_group],
        percent=10,
        algorithm="max_plus_percent",
        enabled=enabled,
    )
def test_create_binding_runs_initial_sync(monkeypatch, db_session):
    """Creating an enabled binding immediately syncs the computed rate to the target.

    The source rate is 2 and the percent is 10, so the expected new rate is 2.2.
    It must be written once to the target group and recorded in one successful
    sync log (old rate "1", new rate "2.2").
    """
    website, upstream = seed_rows(db_session)
    written = []

    class RecordingClient:
        """Target-site stand-in: reports current rate "1" and records each write."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]

        def update_group_rate(self, endpoint, group_id, rate):
            written.append((endpoint, group_id, str(rate)))

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", RecordingClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", RecordingClient)

    created = websites_router.create_binding(
        make_body(website.id, upstream.id),
        db_session,
        object(),
    )

    assert created.target_group_id == "target"
    assert written == [("/groups/{id}", "target", "2.2000")]
    sync_log = db_session.query(WebsiteSyncLog).one()
    assert (sync_log.status, sync_log.message) == ("success", "同步成功")
    assert (sync_log.old_rate, sync_log.new_rate) == ("1", "2.2")
def test_create_binding_skips_write_when_website_auto_sync_disabled(monkeypatch, db_session):
    """Creating a binding on a website with auto-sync off must not write back.

    The binding is still persisted, and a ``success`` sync log records the
    computed rate. No old rate is read and nothing is sent to the target.
    The target client is patched so that any read or write fails the test,
    instead of silently attempting a real HTTP call to ``http://target.local``.
    """
    website, upstream = seed_rows(db_session, auto_sync_enabled=False)

    class NoNetworkClient:
        """Target-site stand-in that rejects every network operation."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            raise AssertionError("should not read target groups when website auto sync is disabled")

        def update_group_rate(self, endpoint, group_id, rate):
            raise AssertionError("should not write when website auto sync is disabled")

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", NoNetworkClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", NoNetworkClient)

    result = websites_router.create_binding(
        make_body(website.id, upstream.id),
        db_session,
        object(),
    )

    assert result.target_group_id == "target"
    assert db_session.query(WebsiteGroupBinding).count() == 1
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "success"
    # Message text: "website auto-sync not enabled, not written back".
    assert log.message == "网站未启用自动同步,未写回"
    assert log.old_rate is None
    assert log.new_rate == "2.2"
def test_create_binding_skips_write_when_binding_disabled(monkeypatch, db_session):
    """Creating a disabled binding records the computed rate but does not write back.

    The target client is patched so that any read or write fails the test,
    instead of silently attempting a real HTTP call to ``http://target.local``.
    This mirrors the guard used by the update-binding "disabled" test.
    """
    website, upstream = seed_rows(db_session)

    class NoNetworkClient:
        """Target-site stand-in that rejects every network operation."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            raise AssertionError("should not read target groups when binding is disabled")

        def update_group_rate(self, endpoint, group_id, rate):
            raise AssertionError("should not write when binding is disabled")

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", NoNetworkClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", NoNetworkClient)

    result = websites_router.create_binding(
        make_body(website.id, upstream.id, enabled=False),
        db_session,
        object(),
    )

    assert result.target_group_id == "target"
    log = db_session.query(WebsiteSyncLog).one()
    assert log.status == "success"
    # Message text: "binding not enabled, not written back".
    assert log.message == "绑定未启用,未写回"
    assert log.new_rate == "2.2"
def test_create_binding_keeps_binding_when_initial_sync_calculation_fails(db_session):
    """A failed initial rate calculation is logged but does not roll back the binding.

    All upstream snapshots are removed, so no source rate is available. The
    binding must still exist, and the sync log must be ``failed`` with no new rate.
    """
    website, upstream = seed_rows(db_session)
    # Drop the only snapshot so the calculation has nothing to work with.
    db_session.query(UpstreamRateSnapshot).delete()
    db_session.commit()

    created = websites_router.create_binding(
        make_body(website.id, upstream.id),
        db_session,
        object(),
    )

    assert created.target_group_id == "target"
    assert db_session.query(WebsiteGroupBinding).count() == 1
    failure_log = db_session.query(WebsiteSyncLog).one()
    assert failure_log.status == "failed"
    # Expected fragment: "no usable positive upstream rate".
    assert "没有可用的正数上游倍率" in failure_log.message
    assert failure_log.new_rate is None
def test_update_binding_runs_sync_after_save(monkeypatch, db_session):
    """Saving an update to an enabled binding re-syncs using the new percent.

    The binding starts at 10% and is updated to 20%. With source rate 2, the
    write must be 2.4, recorded in a single successful sync log.
    """
    website, upstream = seed_rows(db_session)
    source_groups = [{
        "upstream_id": upstream.id,
        "upstream_name": "Upstream",
        "group_id": "source",
        "group_name": "Source",
    }]
    existing = WebsiteGroupBinding(
        website_id=website.id,
        target_group_id="target",
        target_group_name="Target group",
        source_groups_json=json.dumps(source_groups, ensure_ascii=False),
        percent="10",
        algorithm="max_plus_percent",
        enabled=True,
    )
    db_session.add(existing)
    db_session.commit()
    db_session.refresh(existing)

    written = []

    class RecordingClient:
        """Target-site stand-in: reports current rate "1" and records each write."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]

        def update_group_rate(self, endpoint, group_id, rate):
            written.append((endpoint, group_id, str(rate)))

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", RecordingClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", RecordingClient)

    changes = BindingUpdate(
        target_group_name="Target group",
        percent=20,
        enabled=True,
    )
    updated = websites_router.update_binding(existing.id, changes, db_session, object())

    assert updated.target_group_id == "target"
    assert written == [("/groups/{id}", "target", "2.4000")]
    sync_log = db_session.query(WebsiteSyncLog).one()
    assert (sync_log.status, sync_log.message) == ("success", "同步成功")
    assert sync_log.new_rate == "2.4"
def test_update_binding_skips_write_when_disabled(monkeypatch, db_session):
    """Updating a disabled binding logs a skip and never constructs a target client.

    The patched client raises from ``__init__``, so merely instantiating it
    fails the test.
    """
    website, upstream = seed_rows(db_session)
    source_groups = [{
        "upstream_id": upstream.id,
        "upstream_name": "Upstream",
        "group_id": "source",
        "group_name": "Source",
    }]
    disabled = WebsiteGroupBinding(
        website_id=website.id,
        target_group_id="target",
        target_group_name="Target group",
        source_groups_json=json.dumps(source_groups, ensure_ascii=False),
        percent="10",
        algorithm="max_plus_percent",
        enabled=False,
    )
    db_session.add(disabled)
    db_session.commit()
    db_session.refresh(disabled)

    class ExplodingClient:
        """Target-site stand-in that refuses to be constructed."""

        def __init__(self, **kwargs):
            raise AssertionError("should not write when binding is disabled")

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", ExplodingClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", ExplodingClient)

    changes = BindingUpdate(target_group_name="Target group", percent=20)
    updated = websites_router.update_binding(disabled.id, changes, db_session, object())

    assert updated.target_group_id == "target"
    skip_log = db_session.query(WebsiteSyncLog).one()
    # Message text: "binding not enabled, not written back".
    assert (skip_log.status, skip_log.message) == ("success", "绑定未启用,未写回")
def test_create_binding_notifies_when_website_rate_changes(monkeypatch, db_session):
    """A rate change on the target site triggers exactly one ``website_rate_changed`` webhook.

    The target reports rate "1" and the sync computes "2.2". The generic
    webhook sender is captured, and its payload plus the resulting
    NotificationLog row are checked.
    """
    website, upstream = seed_rows(db_session)
    db_session.add(
        WebhookConfig(
            name="Notify",
            type="generic",
            url="http://notify.local/webhook",
            enabled=True,
            events_json=json.dumps(["website_rate_changed"]),
        )
    )
    db_session.commit()

    sent = []

    class SilentClient:
        """Target-site stand-in: reports current rate "1" and ignores writes."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]

        def update_group_rate(self, endpoint, group_id, rate):
            pass

    def capture_generic(url, payload, timeout=15.0):
        sent.append((url, payload))
        return "ok"

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", SilentClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", SilentClient)
    monkeypatch.setattr("app.services.webhook_service._send_generic", capture_generic)

    created = websites_router.create_binding(
        make_body(website.id, upstream.id),
        db_session,
        object(),
    )

    assert created.target_group_id == "target"
    assert len(sent) == 1
    payload = sent[0][1]
    assert payload["event"] == "website_rate_changed"
    assert payload["website"]["id"] == website.id
    target_group = payload["target_group"]
    assert (target_group["old_rate"], target_group["new_rate"]) == ("1", "2.2")
    notification = db_session.query(NotificationLog).one()
    assert (notification.event_type, notification.status) == ("website_rate_changed", "success")
def test_create_binding_does_not_notify_when_website_rate_unchanged(monkeypatch, db_session):
    """No webhook is sent when the target already has the computed rate.

    The target reports "2.2", which equals the computed rate. The patched
    sender fails the test if called, and no NotificationLog rows may exist.
    """
    website, upstream = seed_rows(db_session)
    db_session.add(
        WebhookConfig(
            name="Notify",
            type="generic",
            url="http://notify.local/webhook",
            enabled=True,
            events_json=json.dumps(["website_rate_changed"]),
        )
    )
    db_session.commit()

    class UpToDateClient:
        """Target-site stand-in whose group already carries rate "2.2"."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            return [{"id": "target", "name": "Target group", "rate_multiplier": "2.2"}]

        def update_group_rate(self, endpoint, group_id, rate):
            pass

    def refuse_generic(url, payload, timeout=15.0):
        raise AssertionError("should not notify when target rate is unchanged")

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", UpToDateClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", UpToDateClient)
    monkeypatch.setattr("app.services.webhook_service._send_generic", refuse_generic)

    created = websites_router.create_binding(
        make_body(website.id, upstream.id),
        db_session,
        object(),
    )

    assert created.target_group_id == "target"
    assert db_session.query(NotificationLog).count() == 0