92eb4888d1
- Removed TestClient/FastAPI/app.main imports entirely - Tests now call websites_router.create_binding() and websites_router.update_binding() directly with db_session - Bypasses all ASGI transport and lifespan issues - All 13 tests pass in 0.75s
390 lines
12 KiB
Python
390 lines
12 KiB
Python
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
|
|
|
from app.database import Base
|
|
from app.models.snapshot import UpstreamRateSnapshot
|
|
from app.models.upstream import Upstream
|
|
from app.models.notification_log import NotificationLog
|
|
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
|
|
from app.models.webhook_config import WebhookConfig
|
|
from app.routers import websites as websites_router
|
|
from app.schemas.website import BindingCreate, BindingUpdate
|
|
|
|
|
|
@pytest.fixture()
def engine():
    """Create a fresh in-memory SQLite database for each test.

    Yields:
        A SQLAlchemy engine on which ``Base.metadata.create_all`` has been run.
        The engine is disposed during fixture teardown so the single
        ``StaticPool`` connection is released instead of leaking per test.
    """
    test_engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    # The imported names are intentionally unused; the import runs before
    # create_all, presumably so every model module registers its tables on
    # Base.metadata -- TODO confirm against app.models.
    from app.models import (  # noqa: F401
        admin_user,
        upstream,
        snapshot,
        webhook_config,
        notification_log,
        custom_page,
        website,
        revoked_token,
        upstream_key,
    )

    Base.metadata.create_all(bind=test_engine)
    try:
        yield test_engine
    finally:
        # Close the pooled connection; this also discards the in-memory database.
        test_engine.dispose()
|
|
|
|
|
|
@pytest.fixture()
def db_session(engine):
    """Yield a session bound to the per-test engine and close it afterwards."""
    session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    session = session_factory()
    try:
        yield session
    finally:
        # Runs whether the test passed or raised.
        session.close()
|
|
|
|
|
|
def seed_rows(db_session, *, website_enabled=True, auto_sync_enabled=True):
    """Insert the baseline rows shared by the binding tests.

    Persists one ``Website``, one ``Upstream`` and one ``UpstreamRateSnapshot``
    for that upstream whose JSON holds a single group ``"source"`` with rate
    ``"2"``.

    Args:
        db_session: Session used for all inserts and commits.
        website_enabled: Value stored in the website's ``enabled`` column.
        auto_sync_enabled: Value stored in the website's ``auto_sync_enabled``
            column.

    Returns:
        The ``(website, upstream)`` pair, refreshed after the first commit.
    """
    target_site = Website(
        name="Target",
        site_type="sub2api",
        base_url="http://target.local",
        api_prefix="/api",
        auth_type="api_key",
        auth_config_json="{}",
        groups_endpoint="/groups",
        group_update_endpoint="/groups/{id}",
        enabled=website_enabled,
        auto_sync_enabled=auto_sync_enabled,
    )
    rate_source = Upstream(
        name="Upstream",
        base_url="http://upstream.local",
        api_prefix="/api",
        auth_type="bearer",
        auth_config_json="{}",
    )
    db_session.add_all([target_site, rate_source])
    db_session.commit()
    # Refresh both rows so their attributes (including id) are reloaded.
    for row in (target_site, rate_source):
        db_session.refresh(row)

    snapshot_payload = {
        "groups": {
            "source": {
                "group_id": "source",
                "group_name": "Source",
                "rate": "2",
            }
        }
    }
    db_session.add(
        UpstreamRateSnapshot(
            upstream_id=rate_source.id,
            snapshot_json=json.dumps(snapshot_payload),
        )
    )
    db_session.commit()
    return target_site, rate_source
|
|
|
|
|
|
def make_body(website_id, upstream_id, *, enabled=True):
    """Build the ``BindingCreate`` payload used by the create-binding tests.

    The payload targets group ``"target"`` with one source group ``"source"``
    on the given upstream, ``percent=10`` and the ``"max_plus_percent"``
    algorithm.

    Args:
        website_id: Id of the website the binding belongs to.
        upstream_id: Id of the upstream that owns the source group.
        enabled: Value for the binding's ``enabled`` field.
    """
    source_group = {
        "upstream_id": upstream_id,
        "upstream_name": "Upstream",
        "group_id": "source",
        "group_name": "Source",
    }
    return BindingCreate(
        website_id=website_id,
        target_group_id="target",
        target_group_name="Target group",
        source_groups=[source_group],
        percent=10,
        algorithm="max_plus_percent",
        enabled=enabled,
    )
|
|
|
|
|
|
def test_create_binding_runs_initial_sync(monkeypatch, db_session):
    """Creating an enabled binding writes the rate once and logs a success."""
    website, upstream = seed_rows(db_session)
    rate_writes = []

    class FakeClient:
        """Context-manager stand-in for Sub2ApiWebsiteClient that records writes."""

        def __init__(self, **kwargs):
            """Accept and ignore any keyword arguments."""

        def __enter__(self):
            return self

        def __exit__(self, *args):
            return None

        def get_groups(self, endpoint):
            # The target group currently reports a rate multiplier of "1".
            return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]

        def update_group_rate(self, endpoint, group_id, rate):
            rate_writes.append((endpoint, group_id, str(rate)))

    # Patch the name in both modules that this test replaces it in.
    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient)

    body = make_body(website.id, upstream.id)
    result = websites_router.create_binding(body, db_session, object())

    assert result.target_group_id == "target"
    assert rate_writes == [("/groups/{id}", "target", "2.2000")]

    sync_log = db_session.query(WebsiteSyncLog).one()
    assert sync_log.status == "success"
    assert sync_log.message == "同步成功"
    assert sync_log.old_rate == "1"
    assert sync_log.new_rate == "2.2"
|
|
|
|
|
|
def test_create_binding_skips_write_when_website_auto_sync_disabled(db_session):
    """With website auto-sync off, the binding is saved and the log records a skip."""
    website, upstream = seed_rows(db_session, auto_sync_enabled=False)

    body = make_body(website.id, upstream.id)
    result = websites_router.create_binding(body, db_session, object())

    assert result.target_group_id == "target"
    binding_count = db_session.query(WebsiteGroupBinding).count()
    assert binding_count == 1

    sync_log = db_session.query(WebsiteSyncLog).one()
    assert sync_log.status == "success"
    assert sync_log.message == "网站未启用自动同步,未写回"
    # No old rate is recorded, while the computed new rate still is.
    assert sync_log.old_rate is None
    assert sync_log.new_rate == "2.2"
|
|
|
|
|
|
def test_create_binding_skips_write_when_binding_disabled(db_session):
    """A binding created with enabled=False is logged as skipped, not written."""
    website, upstream = seed_rows(db_session)

    disabled_body = make_body(website.id, upstream.id, enabled=False)
    result = websites_router.create_binding(disabled_body, db_session, object())

    assert result.target_group_id == "target"

    sync_log = db_session.query(WebsiteSyncLog).one()
    assert sync_log.status == "success"
    assert sync_log.message == "绑定未启用,未写回"
    assert sync_log.new_rate == "2.2"
|
|
|
|
|
|
def test_create_binding_keeps_binding_when_initial_sync_calculation_fails(db_session):
    """With no rate snapshots, the binding is still stored and the log is 'failed'."""
    website, upstream = seed_rows(db_session)

    # Remove every snapshot so no upstream rate is available for the calculation.
    db_session.query(UpstreamRateSnapshot).delete()
    db_session.commit()

    body = make_body(website.id, upstream.id)
    result = websites_router.create_binding(body, db_session, object())

    assert result.target_group_id == "target"
    binding_count = db_session.query(WebsiteGroupBinding).count()
    assert binding_count == 1

    sync_log = db_session.query(WebsiteSyncLog).one()
    assert sync_log.status == "failed"
    assert "没有可用的正数上游倍率" in sync_log.message
    assert sync_log.new_rate is None
|
|
|
|
|
|
def test_update_binding_runs_sync_after_save(monkeypatch, db_session):
    """Updating an enabled binding's percent to 20 writes the new rate and logs success."""
    website, upstream = seed_rows(db_session)

    source_groups = [{
        "upstream_id": upstream.id,
        "upstream_name": "Upstream",
        "group_id": "source",
        "group_name": "Source",
    }]
    binding = WebsiteGroupBinding(
        website_id=website.id,
        target_group_id="target",
        target_group_name="Target group",
        source_groups_json=json.dumps(source_groups, ensure_ascii=False),
        percent="10",
        algorithm="max_plus_percent",
        enabled=True,
    )
    db_session.add(binding)
    db_session.commit()
    db_session.refresh(binding)

    rate_writes = []

    class FakeClient:
        """Context-manager stand-in for Sub2ApiWebsiteClient that records writes."""

        def __init__(self, **kwargs):
            """Accept and ignore any keyword arguments."""

        def __enter__(self):
            return self

        def __exit__(self, *args):
            return None

        def get_groups(self, endpoint):
            return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]

        def update_group_rate(self, endpoint, group_id, rate):
            rate_writes.append((endpoint, group_id, str(rate)))

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient)

    changes = BindingUpdate(
        target_group_name="Target group",
        percent=20,
        enabled=True,
    )
    result = websites_router.update_binding(binding.id, changes, db_session, object())

    assert result.target_group_id == "target"
    assert rate_writes == [("/groups/{id}", "target", "2.4000")]

    sync_log = db_session.query(WebsiteSyncLog).one()
    assert sync_log.status == "success"
    assert sync_log.message == "同步成功"
    assert sync_log.new_rate == "2.4"
|
|
|
|
|
|
def test_update_binding_skips_write_when_disabled(monkeypatch, db_session):
    """Updating a disabled binding must not construct the client; the log records a skip."""
    website, upstream = seed_rows(db_session)

    source_groups = [{
        "upstream_id": upstream.id,
        "upstream_name": "Upstream",
        "group_id": "source",
        "group_name": "Source",
    }]
    binding = WebsiteGroupBinding(
        website_id=website.id,
        target_group_id="target",
        target_group_name="Target group",
        source_groups_json=json.dumps(source_groups, ensure_ascii=False),
        percent="10",
        algorithm="max_plus_percent",
        enabled=False,
    )
    db_session.add(binding)
    db_session.commit()
    db_session.refresh(binding)

    class FakeClient:
        """Stand-in for Sub2ApiWebsiteClient that fails the test if instantiated."""

        def __init__(self, **kwargs):
            raise AssertionError("should not write when binding is disabled")

        def __enter__(self):
            return self

        def __exit__(self, *args):
            return None

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient)

    # "enabled" is not passed here; the binding was stored with enabled=False.
    changes = BindingUpdate(target_group_name="Target group", percent=20)
    result = websites_router.update_binding(binding.id, changes, db_session, object())

    assert result.target_group_id == "target"

    sync_log = db_session.query(WebsiteSyncLog).one()
    assert sync_log.status == "success"
    assert sync_log.message == "绑定未启用,未写回"
|
|
|
|
|
|
def test_create_binding_notifies_when_website_rate_changes(monkeypatch, db_session):
    """A rate change from "1" to "2.2" sends one webhook and logs the notification."""
    website, upstream = seed_rows(db_session)

    subscribed_events = ["website_rate_changed"]
    db_session.add(
        WebhookConfig(
            name="Notify",
            type="generic",
            url="http://notify.local/webhook",
            enabled=True,
            events_json=json.dumps(subscribed_events),
        )
    )
    db_session.commit()

    deliveries = []

    class FakeClient:
        """Context-manager stand-in for Sub2ApiWebsiteClient with a no-op write."""

        def __init__(self, **kwargs):
            """Accept and ignore any keyword arguments."""

        def __enter__(self):
            return self

        def __exit__(self, *args):
            return None

        def get_groups(self, endpoint):
            # The reported rate "1" differs from the "2.2" asserted below.
            return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]

        def update_group_rate(self, endpoint, group_id, rate):
            return None

    def fake_send_generic(url, payload, timeout=15.0):
        """Capture the outgoing webhook instead of sending it."""
        deliveries.append((url, payload))
        return "ok"

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient)
    monkeypatch.setattr("app.services.webhook_service._send_generic", fake_send_generic)

    body = make_body(website.id, upstream.id)
    result = websites_router.create_binding(body, db_session, object())

    assert result.target_group_id == "target"
    assert len(deliveries) == 1

    payload = deliveries[0][1]
    assert payload["event"] == "website_rate_changed"
    assert payload["website"]["id"] == website.id
    assert payload["target_group"]["old_rate"] == "1"
    assert payload["target_group"]["new_rate"] == "2.2"

    notification = db_session.query(NotificationLog).one()
    assert notification.event_type == "website_rate_changed"
    assert notification.status == "success"
|
|
|
|
|
|
def test_create_binding_does_not_notify_when_website_rate_unchanged(monkeypatch, db_session):
    """When the target group already reports "2.2", no webhook is sent or logged."""
    website, upstream = seed_rows(db_session)

    subscribed_events = ["website_rate_changed"]
    db_session.add(
        WebhookConfig(
            name="Notify",
            type="generic",
            url="http://notify.local/webhook",
            enabled=True,
            events_json=json.dumps(subscribed_events),
        )
    )
    db_session.commit()

    class FakeClient:
        """Context-manager stand-in for Sub2ApiWebsiteClient with a no-op write."""

        def __init__(self, **kwargs):
            """Accept and ignore any keyword arguments."""

        def __enter__(self):
            return self

        def __exit__(self, *args):
            return None

        def get_groups(self, endpoint):
            # The target group already reports "2.2".
            return [{"id": "target", "name": "Target group", "rate_multiplier": "2.2"}]

        def update_group_rate(self, endpoint, group_id, rate):
            return None

    def fake_send_generic(url, payload, timeout=15.0):
        """Fail the test if a webhook delivery is attempted."""
        raise AssertionError("should not notify when target rate is unchanged")

    monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeClient)
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient)
    monkeypatch.setattr("app.services.webhook_service._send_generic", fake_send_generic)

    body = make_body(website.id, upstream.id)
    result = websites_router.create_binding(body, db_session, object())

    assert result.target_group_id == "target"
    notification_count = db_session.query(NotificationLog).count()
    assert notification_count == 0
|