187 lines
7.5 KiB
Python
187 lines
7.5 KiB
Python
import pytest
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.database import Base
|
|
# Import all models to ensure they are registered with Base.metadata
|
|
from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website, revoked_token, upstream_key
|
|
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
|
|
from app.models.upstream import Upstream
|
|
from app.models.snapshot import UpstreamRateSnapshot
|
|
import json
|
|
from datetime import datetime, timezone
|
|
|
|
@pytest.fixture()
def db_session():
    """Yield a Session bound to a fresh, isolated in-memory SQLite database.

    The "sqlite://" URL is an in-memory database. ``StaticPool`` keeps a
    single connection that every checkout reuses, so all work done through
    this engine sees the same in-memory database. ``check_same_thread=False``
    lets that one connection be used from a thread other than the one that
    created it.

    Teardown closes the session and drops the schema. It then disposes the
    engine so the pooled connection is released instead of lingering after
    each test.
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
        try:
            Base.metadata.drop_all(bind=engine)
        finally:
            # Release the StaticPool's connection even if drop_all fails.
            engine.dispose()
|
|
|
|
def test_sync_website_group_bindings_success(db_session, monkeypatch):
    """Batch-sync two enabled bindings and expect both to succeed and be logged."""
    from app.routers.websites import sync_website_group_bindings

    site = Website(
        name="W1",
        base_url="http://w1",
        enabled=True,
        auto_sync_enabled=True,
        auth_config_json="{}",
        timeout_seconds=30,
    )
    db_session.add(site)
    db_session.commit()
    db_session.refresh(site)

    src = Upstream(name="U1", base_url="http://u1")
    db_session.add(src)
    db_session.commit()
    db_session.refresh(src)

    # A single upstream snapshot exposing the two source groups G1 and G2.
    snapshot_payload = {"groups": {"G1": {"rate": 1.0}, "G2": {"rate": 2.0}}}
    db_session.add(
        UpstreamRateSnapshot(
            upstream_id=src.id,
            snapshot_json=json.dumps(snapshot_payload),
            captured_at=datetime.now(timezone.utc),
        )
    )
    db_session.commit()

    # TG1 <- G1 and TG2 <- G2, both "max_plus_percent" with percent=10.
    bindings = [
        WebsiteGroupBinding(
            website_id=site.id,
            target_group_id=target,
            target_group_name=target,
            source_groups_json=json.dumps([{"upstream_id": src.id, "group_id": source}]),
            algorithm="max_plus_percent",
            percent=10,
            enabled=True,
        )
        for target, source in (("TG1", "G1"), ("TG2", "G2"))
    ]
    db_session.add_all(bindings)
    db_session.commit()

    class FakeClient:
        """Stand-in for Sub2ApiWebsiteClient that lists both target groups."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            return [
                {"id": "TG1", "rate_multiplier": 1.0},
                {"id": "TG2", "rate_multiplier": 2.0},
            ]

        def update_group_rate(self, endpoint, gid, rate):
            pass

    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient)

    resp = sync_website_group_bindings(site.id, db_session)

    assert resp.total == 2
    assert resp.success == 2
    assert "成功 2" in resp.message
    assert len(resp.logs) == 2

    # Every sync attempt must also be persisted as a WebsiteSyncLog row.
    persisted = db_session.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == site.id).all()
    assert len(persisted) == 2
|
|
|
|
def test_sync_website_group_bindings_partial_failure(db_session, monkeypatch):
    """One binding's write-back raises; the batch reports 1 success and 1 failure."""
    from app.routers.websites import sync_website_group_bindings

    site = Website(
        name="W1",
        base_url="http://w1",
        enabled=True,
        auto_sync_enabled=True,
        auth_config_json="{}",
        timeout_seconds=30,
    )
    db_session.add(site)
    db_session.commit()
    db_session.refresh(site)

    src = Upstream(name="U1", base_url="http://u1")
    db_session.add(src)
    db_session.commit()
    db_session.refresh(src)

    snapshot_payload = {"groups": {"G1": {"rate": 1.0}, "G2": {"rate": 2.0}}}
    db_session.add(
        UpstreamRateSnapshot(
            upstream_id=src.id,
            snapshot_json=json.dumps(snapshot_payload),
            captured_at=datetime.now(timezone.utc),
        )
    )
    db_session.commit()

    # Unlike the success case, these bindings leave target_group_name unset.
    db_session.add_all(
        [
            WebsiteGroupBinding(
                website_id=site.id,
                target_group_id=target,
                source_groups_json=json.dumps([{"upstream_id": src.id, "group_id": source}]),
                algorithm="max_plus_percent",
                percent=10,
                enabled=True,
            )
            for target, source in (("TG1", "G1"), ("TG2", "G2"))
        ]
    )
    db_session.commit()

    class FakeClient:
        """Client stand-in: lists no groups and fails only TG2's write-back."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def get_groups(self, endpoint):
            # No target groups listed, so presumably no previous rate is found.
            return []

        def update_group_rate(self, endpoint, gid, rate):
            if gid == "TG2":
                raise Exception("Simulated error")

    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient)

    # The assertions below assume sync_binding still attempts the write-back
    # when the target group is absent from get_groups(). They also assume it
    # converts an exception raised during that write-back into a failed log
    # instead of aborting the whole batch.
    resp = sync_website_group_bindings(site.id, db_session)

    assert resp.total == 2
    assert resp.success == 1  # TG1: write-back succeeds
    assert resp.failed == 1  # TG2: write-back raises "Simulated error"
|
|
|
|
def test_sync_website_group_bindings_no_bindings(db_session):
    """A website with no bindings yields an empty result and an explanatory message."""
    from app.routers.websites import sync_website_group_bindings

    site = Website(
        name="W1",
        base_url="http://w1",
        enabled=True,
        auth_config_json="{}",
        timeout_seconds=30,
    )
    db_session.add(site)
    db_session.commit()
    db_session.refresh(site)

    result = sync_website_group_bindings(site.id, db_session)

    assert result.total == 0
    assert result.message == "暂无绑定可同步"
|
|
|
|
def test_sync_website_group_bindings_not_found(db_session):
    """Syncing an unknown website id raises HTTPException with status 404."""
    from fastapi import HTTPException

    from app.routers.websites import sync_website_group_bindings

    # The fixture starts from an empty database, so this id cannot exist.
    missing_id = 999
    with pytest.raises(HTTPException) as excinfo:
        sync_website_group_bindings(missing_id, db_session)
    assert excinfo.value.status_code == 404
|
|
|
|
def test_sync_website_group_bindings_exception_persists_log(db_session, monkeypatch):
    """A crash inside sync_binding is reported as a failed log and saved to the DB."""
    from app.routers.websites import sync_website_group_bindings

    site = Website(
        name="W1",
        base_url="http://w1",
        enabled=True,
        auth_config_json="{}",
        timeout_seconds=30,
    )
    db_session.add(site)
    db_session.commit()

    db_session.add(
        WebsiteGroupBinding(
            website_id=site.id,
            target_group_id="TG1",
            source_groups_json="[]",
            algorithm="max_plus_percent",
            percent=10,
            enabled=True,
        )
    )
    db_session.commit()

    def crashing_sync_binding(db, binding, write):
        # Parameter names kept as-is in case the router passes them by keyword.
        raise Exception("Fatal crash")

    # Patch the name as the router module looks it up.
    monkeypatch.setattr("app.routers.websites.sync_binding", crashing_sync_binding)

    resp = sync_website_group_bindings(site.id, db_session)

    assert resp.total == 1
    assert resp.failed == 1
    first_log = resp.logs[0]
    assert first_log.status == "failed"
    assert "Fatal crash" in first_log.message

    # The failure must also be written to the database, not only returned.
    stored = (
        db_session.query(WebsiteSyncLog)
        .filter(WebsiteSyncLog.website_id == site.id)
        .first()
    )
    assert stored is not None
    assert stored.status == "failed"
    assert "Fatal crash" in stored.message
|