222 lines
8.3 KiB
Python
222 lines
8.3 KiB
Python
import json
|
|
from datetime import datetime, timezone
|
|
import pytest
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.database import Base
|
|
from app.models.upstream import Upstream
|
|
from app.models.website import Website, WebsiteSyncLog
|
|
from app.models.upstream_key import UpstreamGeneratedKey
|
|
from app.models.snapshot import UpstreamRateSnapshot
|
|
from app.services.website_sync import (
|
|
build_rate_priority_map,
|
|
sync_account_priorities_for_upstream
|
|
)
|
|
from app.services.website_client import Sub2ApiWebsiteClient
|
|
|
|
@pytest.fixture()
def db_session():
    """Yield a SQLAlchemy session bound to a fresh in-memory SQLite database.

    ``"sqlite://"`` is an in-memory database. ``StaticPool`` together with
    ``check_same_thread=False`` makes every checkout reuse one connection,
    so all code in the test sees the same database. Tables are created from
    ``Base.metadata`` before the test runs.

    Teardown closes the session, drops all tables, and disposes the engine
    so its pooled connection is released rather than leaked per test.
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
        try:
            Base.metadata.drop_all(bind=engine)
        finally:
            # Always release the engine's connection, even if drop_all fails.
            engine.dispose()
|
|
|
|
def test_priority_sync_cross_upstream_group(db_session):
    """Same-named groups on different upstreams get separate priority entries.

    Both upstreams expose a group keyed ``"VIP"`` at different rates. The
    priority map must key entries as ``"<upstream_id>:<group>"`` so the two
    do not collide. The 1.0-rate group is expected to get priority 1 and the
    2.0-rate group priority 2.
    """
    first = Upstream(name="U1", base_url="http://u1")
    second = Upstream(name="U2", base_url="http://u2")
    db_session.add_all([first, second])
    db_session.commit()
    for upstream in (first, second):
        db_session.refresh(upstream)

    # One snapshot per upstream: identical group id "VIP", different rates.
    snapshots = [
        UpstreamRateSnapshot(
            upstream_id=upstream.id,
            snapshot_json=json.dumps(
                {"groups": {"VIP": {"group_name": "VIP", "rate": rate}}}
            ),
            captured_at=datetime.now(timezone.utc),
        )
        for upstream, rate in ((first, 1.0), (second, 2.0))
    ]
    db_session.add_all(snapshots)
    db_session.commit()

    priorities = build_rate_priority_map(db_session, {first.id, second.id})

    expected_first_key = f"{first.id}:VIP"
    expected_second_key = f"{second.id}:VIP"
    assert priorities[expected_first_key] == 1
    assert priorities[expected_second_key] == 2
    assert len(priorities) == 2
|
|
|
|
def test_priority_sync_full_website_update(db_session, monkeypatch):
    """Syncing one upstream re-prioritises every imported account on its website.

    Keys from two upstreams (U1/G1 and U2/G2) are imported into the same
    website. A priority sync triggered only for U1 is expected to push an
    update for both accounts, ranked by group rate:
    G1 (1.0) -> priority 1, G2 (2.0) -> priority 2.
    """
    # Website and upstreams
    w = Website(
        name="W1", base_url="http://w1", enabled=True,
        auth_config_json="{}", timeout_seconds=30,
    )
    u1 = Upstream(name="U1", base_url="http://u1")
    u2 = Upstream(name="U2", base_url="http://u2")
    db_session.add_all([w, u1, u2])
    db_session.commit()
    db_session.refresh(w)
    db_session.refresh(u1)
    db_session.refresh(u2)

    # Rate snapshots: one group per upstream
    db_session.add(UpstreamRateSnapshot(
        upstream_id=u1.id,
        snapshot_json=json.dumps({"groups": {"G1": {"rate": 1.0}}}),
        captured_at=datetime.now(timezone.utc),
    ))
    db_session.add(UpstreamRateSnapshot(
        upstream_id=u2.id,
        snapshot_json=json.dumps({"groups": {"G2": {"rate": 2.0}}}),
        captured_at=datetime.now(timezone.utc),
    ))
    db_session.commit()

    # Keys already imported to the same website as accounts A1 and A2
    k1 = UpstreamGeneratedKey(
        upstream_id=u1.id, group_id="G1", key_name="K1", key_value="V1",
        imported_website_id=w.id, imported_account_id="A1",
    )
    k2 = UpstreamGeneratedKey(
        upstream_id=u2.id, group_id="G2", key_name="K2", key_value="V2",
        imported_website_id=w.id, imported_account_id="A2",
    )
    db_session.add_all([k1, k2])
    db_session.commit()

    # Client stub that records update_account calls instead of doing I/O.
    update_calls = []

    class MockClient:
        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def update_account(self, account_id, data):
            update_calls.append((account_id, data))

    # Patch both the name bound inside website_sync and the defining module,
    # so whichever lookup path the sync uses gets the stub.
    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", MockClient)
    monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockClient)

    # Trigger sync for U1 only
    sync_account_priorities_for_upstream(db_session, u1.id)

    # Both A1 and A2 are updated because they belong to the same website.
    # The length check rules out duplicate updates for one account.
    assert len(update_calls) == 2
    priorities = {account_id: data["priority"] for account_id, data in update_calls}
    # G1 (rate 1.0) -> priority 1, G2 (rate 2.0) -> priority 2
    assert priorities == {"A1": 1, "A2": 2}
|
|
|
|
def test_priority_sync_log_structure(db_session, monkeypatch):
    """A priority sync writes a ``WebsiteSyncLog`` with the expected JSON layout.

    ``source_rates_json`` is expected to decode to a list where entry 0 holds
    the priority map (tagged ``_meta == "priority_map"``) and entry 1 is the
    per-account result for A1.
    """
    website = Website(
        name="W1",
        base_url="http://w1",
        enabled=True,
        auth_config_json="{}",
        timeout_seconds=30,
    )
    upstream = Upstream(name="U1", base_url="http://u1")
    db_session.add_all([website, upstream])
    db_session.commit()
    db_session.refresh(website)
    db_session.refresh(upstream)

    snapshot_payload = {"groups": {"G1": {"rate": 1.0}}}
    db_session.add(
        UpstreamRateSnapshot(
            upstream_id=upstream.id,
            snapshot_json=json.dumps(snapshot_payload),
            captured_at=datetime.now(timezone.utc),
        )
    )
    db_session.add(
        UpstreamGeneratedKey(
            upstream_id=upstream.id,
            group_id="G1",
            key_name="K1",
            key_value="V1",
            imported_website_id=website.id,
            imported_account_id="A1",
        )
    )
    db_session.commit()

    class NoOpClient:
        """Website client stand-in whose update_account discards every call."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def update_account(self, account_id, data):
            pass

    for target in (
        "app.services.website_sync.Sub2ApiWebsiteClient",
        "app.services.website_client.Sub2ApiWebsiteClient",
    ):
        monkeypatch.setattr(target, NoOpClient)

    sync_account_priorities_for_upstream(db_session, upstream.id)

    sync_log = (
        db_session.query(WebsiteSyncLog)
        .filter(WebsiteSyncLog.website_id == website.id)
        .first()
    )
    assert sync_log is not None
    assert sync_log.algorithm == "priority_sync"

    entries = json.loads(sync_log.source_rates_json)
    # Entry 0: priority-map metadata
    assert entries[0]["_meta"] == "priority_map"
    assert f"{upstream.id}:G1" in entries[0]["data"]
    # Entry 1: result for the single imported account
    assert entries[1]["account_id"] == "A1"
|
|
|
|
def test_import_auto_priority_by_rate(db_session, monkeypatch):
    """Importing with ``auto_priority_by_rate`` ranks accounts by group rate.

    The request also passes a fixed ``priority=10``. The test expects the
    rate-derived priorities to win: G2 (rate 1.0) -> 1, G1 (rate 2.0) -> 2.
    """
    from app.routers.websites import import_upstream_keys_as_accounts
    from app.schemas.website import ImportAccountsRequest

    w = Website(
        name="W1", base_url="http://w1", enabled=True, auth_config_json="{}",
        groups_endpoint="/groups", group_update_endpoint="/groups/{id}",
        timeout_seconds=30,
    )
    u1 = Upstream(name="U1", base_url="http://u1")
    u2 = Upstream(name="U2", base_url="http://u2")
    db_session.add_all([w, u1, u2])
    db_session.commit()
    db_session.refresh(w)
    db_session.refresh(u1)
    db_session.refresh(u2)

    # U1/G1 is the more expensive group (2.0); U2/G2 the cheaper one (1.0).
    db_session.add(UpstreamRateSnapshot(
        upstream_id=u1.id,
        snapshot_json=json.dumps({"groups": {"G1": {"rate": 2.0}}}),
        captured_at=datetime.now(timezone.utc),
    ))
    db_session.add(UpstreamRateSnapshot(
        upstream_id=u2.id,
        snapshot_json=json.dumps({"groups": {"G2": {"rate": 1.0}}}),
        captured_at=datetime.now(timezone.utc),
    ))

    k1 = UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", group_name="G1",
                              key_name="K1", key_value="V1")
    k2 = UpstreamGeneratedKey(upstream_id=u2.id, group_id="G2", group_name="G2",
                              key_name="K2", key_value="V2")
    db_session.add_all([k1, k2])
    db_session.commit()

    # Client stub that records create_account bodies instead of doing I/O.
    created_accounts = []

    class MockClient:
        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def create_account(self, body):
            created_accounts.append(body)
            return {"id": f"remote-{len(created_accounts)}", "name": body["name"]}

        def extract_id(self, data):
            return data["id"]

        def account_exists(self, aid):
            return False

    monkeypatch.setattr("app.routers.websites._client", lambda website: MockClient())
    monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockClient)

    req = ImportAccountsRequest(
        upstream_key_ids=[k1.id, k2.id],
        target_group_map={},
        auto_priority_by_rate=True,
        priority=10,
        account_name_prefix="test",
        default_platform="openai",
    )

    import_upstream_keys_as_accounts(w.id, req, db_session)

    assert len(created_accounts) == 2

    def priority_for(group_name):
        # Look up the single created account whose name mentions group_name.
        # Fail with a readable message, not a bare StopIteration, when it is
        # missing or ambiguous.
        matches = [a["priority"] for a in created_accounts if group_name in a["name"]]
        assert len(matches) == 1, (
            f"expected exactly one account named with {group_name!r}, "
            f"got {created_accounts!r}"
        )
        return matches[0]

    # G2 has rate 1.0 -> priority 1; G1 has rate 2.0 -> priority 2
    assert priority_for("G2") == 1
    assert priority_for("G1") == 2
|