Files
SmartUp/backend/test_batch_sync.py
T
2026-06-01 09:06:01 +08:00

187 lines
7.5 KiB
Python

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.database import Base
# Import all models to ensure they are registered with Base.metadata
from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website, revoked_token, upstream_key
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.models.upstream import Upstream
from app.models.snapshot import UpstreamRateSnapshot
import json
from datetime import datetime, timezone
@pytest.fixture()
def db_session():
    """Yield a session bound to a fresh in-memory SQLite database.

    ``StaticPool`` keeps a single connection for the engine, so the
    ``sqlite://`` in-memory database survives across checkouts, and
    ``check_same_thread=False`` lets that connection be used from other
    threads. Every table registered on ``Base.metadata`` is created before the
    test and dropped afterwards; the engine is then disposed so its pooled
    connection is closed deterministically rather than left to the garbage
    collector.
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
        try:
            Base.metadata.drop_all(bind=engine)
        finally:
            # Release the single pooled connection even if drop_all fails.
            engine.dispose()
def test_sync_website_group_bindings_success(db_session, monkeypatch):
    """Batch-syncing two healthy bindings writes both groups and logs two successes."""
    from app.routers.websites import sync_website_group_bindings

    # Website with auto-sync enabled.
    w = Website(name="W1", base_url="http://w1", enabled=True, auto_sync_enabled=True, auth_config_json="{}", timeout_seconds=30)
    db_session.add(w)
    db_session.commit()
    db_session.refresh(w)

    # Upstream U1 whose latest snapshot exposes source groups G1 (1.0) and G2 (2.0).
    u = Upstream(name="U1", base_url="http://u1")
    db_session.add(u)
    db_session.commit()
    db_session.refresh(u)
    db_session.add(UpstreamRateSnapshot(
        upstream_id=u.id,
        snapshot_json=json.dumps({"groups": {"G1": {"rate": 1.0}, "G2": {"rate": 2.0}}}),
        captured_at=datetime.now(timezone.utc),
    ))
    db_session.commit()

    # Two enabled bindings: TG1 <- G1 and TG2 <- G2.
    b1 = WebsiteGroupBinding(website_id=w.id, target_group_id="TG1", target_group_name="TG1",
                             source_groups_json=json.dumps([{"upstream_id": u.id, "group_id": "G1"}]),
                             algorithm="max_plus_percent", percent=10, enabled=True)
    b2 = WebsiteGroupBinding(website_id=w.id, target_group_id="TG2", target_group_name="TG2",
                             source_groups_json=json.dumps([{"upstream_id": u.id, "group_id": "G2"}]),
                             algorithm="max_plus_percent", percent=10, enabled=True)
    db_session.add_all([b1, b2])
    db_session.commit()

    # Target group ids passed to update_group_rate, so the test can confirm write-back happened.
    written = []

    class MockClient:
        """Network-free stand-in for Sub2ApiWebsiteClient."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            return [{"id": "TG1", "rate_multiplier": 1.0}, {"id": "TG2", "rate_multiplier": 2.0}]

        def update_group_rate(self, endpoint, gid, rate):
            written.append(gid)

    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", MockClient)

    resp = sync_website_group_bindings(w.id, db_session)
    assert resp.total == 2
    assert resp.success == 2
    assert resp.failed == 0
    assert "成功 2" in resp.message
    assert len(resp.logs) == 2
    # Both target groups must actually have been written back to the website.
    assert sorted(written) == ["TG1", "TG2"]

    # One sync log per binding is persisted.
    logs = db_session.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == w.id).all()
    assert len(logs) == 2
def test_sync_website_group_bindings_partial_failure(db_session, monkeypatch):
    """A failing write-back for one binding is counted as failed without aborting the batch.

    The mock website reports no existing groups and rejects updates to TG2.
    The expected outcome is that TG1 is written successfully while TG2 fails.
    """
    from app.routers.websites import sync_website_group_bindings

    w = Website(name="W1", base_url="http://w1", enabled=True, auto_sync_enabled=True, auth_config_json="{}", timeout_seconds=30)
    db_session.add(w)
    db_session.commit()
    db_session.refresh(w)

    u = Upstream(name="U1", base_url="http://u1")
    db_session.add(u)
    db_session.commit()
    db_session.refresh(u)
    db_session.add(UpstreamRateSnapshot(
        upstream_id=u.id,
        snapshot_json=json.dumps({"groups": {"G1": {"rate": 1.0}, "G2": {"rate": 2.0}}}),
        captured_at=datetime.now(timezone.utc),
    ))
    db_session.commit()

    b1 = WebsiteGroupBinding(website_id=w.id, target_group_id="TG1",
                             source_groups_json=json.dumps([{"upstream_id": u.id, "group_id": "G1"}]),
                             algorithm="max_plus_percent", percent=10, enabled=True)
    b2 = WebsiteGroupBinding(website_id=w.id, target_group_id="TG2",
                             source_groups_json=json.dumps([{"upstream_id": u.id, "group_id": "G2"}]),
                             algorithm="max_plus_percent", percent=10, enabled=True)
    db_session.add_all([b1, b2])
    db_session.commit()

    class MockClient:
        """Network-free website client whose update fails for TG2 only."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            # No existing target groups, so no previous rate can be found.
            # This test expects sync to still proceed to update_group_rate.
            return []

        def update_group_rate(self, endpoint, gid, rate):
            if gid == "TG2":
                raise Exception("Simulated error")

    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", MockClient)

    resp = sync_website_group_bindings(w.id, db_session)
    assert resp.total == 2
    # TG1: write-back succeeds. TG2: write-back raises and is reported as a failed log.
    assert resp.success == 1
    assert resp.failed == 1
    # Every binding, failed or not, contributes one log to the response.
    assert len(resp.logs) == 2
def test_sync_website_group_bindings_no_bindings(db_session):
    """A website with no group bindings produces an empty batch result."""
    from app.routers.websites import sync_website_group_bindings

    site = Website(
        name="W1",
        base_url="http://w1",
        enabled=True,
        auth_config_json="{}",
        timeout_seconds=30,
    )
    db_session.add(site)
    db_session.commit()
    db_session.refresh(site)

    result = sync_website_group_bindings(site.id, db_session)

    assert result.total == 0
    assert result.message == "暂无绑定可同步"
def test_sync_website_group_bindings_not_found(db_session):
    """Syncing an id with no matching website raises an HTTP 404."""
    from app.routers.websites import sync_website_group_bindings
    from fastapi import HTTPException

    unknown_id = 999  # the database is empty, so no website has this id
    with pytest.raises(HTTPException) as excinfo:
        sync_website_group_bindings(unknown_id, db_session)

    # Checked after the raises-block; inside it, this line would never run.
    assert excinfo.value.status_code == 404
def test_sync_website_group_bindings_exception_persists_log(db_session, monkeypatch):
    """An unexpected exception from sync_binding becomes one failed log, returned and persisted."""
    from app.routers.websites import sync_website_group_bindings

    w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}", timeout_seconds=30)
    db_session.add(w)
    db_session.commit()

    b1 = WebsiteGroupBinding(website_id=w.id, target_group_id="TG1",
                             source_groups_json="[]",
                             algorithm="max_plus_percent", percent=10, enabled=True)
    db_session.add(b1)
    db_session.commit()

    # Replace the router's sync_binding so every binding crashes.
    def mock_sync_binding(db, binding, write):
        raise Exception("Fatal crash")

    monkeypatch.setattr("app.routers.websites.sync_binding", mock_sync_binding)

    resp = sync_website_group_bindings(w.id, db_session)
    assert resp.total == 1
    assert resp.failed == 1
    assert len(resp.logs) == 1
    assert resp.logs[0].status == "failed"
    assert "Fatal crash" in resp.logs[0].message

    # The failure must be persisted exactly once, not merely returned in the response.
    logs_in_db = db_session.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == w.id).all()
    assert len(logs_in_db) == 1
    log_in_db = logs_in_db[0]
    assert log_in_db.status == "failed"
    assert "Fatal crash" in log_in_db.message