feat: sync account priorities after rate changes
This commit is contained in:
@@ -0,0 +1,221 @@
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database import Base
|
||||
from app.models.upstream import Upstream
|
||||
from app.models.website import Website, WebsiteSyncLog
|
||||
from app.models.upstream_key import UpstreamGeneratedKey
|
||||
from app.models.snapshot import UpstreamRateSnapshot
|
||||
from app.services.website_sync import (
|
||||
build_rate_priority_map,
|
||||
sync_account_priorities_for_upstream
|
||||
)
|
||||
from app.services.website_client import Sub2ApiWebsiteClient
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session():
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db = TestingSessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
def test_priority_sync_cross_upstream_group(db_session):
|
||||
# Setup 2 upstreams
|
||||
u1 = Upstream(name="U1", base_url="http://u1")
|
||||
u2 = Upstream(name="U2", base_url="http://u2")
|
||||
db_session.add_all([u1, u2])
|
||||
db_session.commit()
|
||||
db_session.refresh(u1)
|
||||
db_session.refresh(u2)
|
||||
|
||||
# Setup snapshots for both with same group ID "VIP" but different rates
|
||||
s1 = UpstreamRateSnapshot(
|
||||
upstream_id=u1.id,
|
||||
snapshot_json=json.dumps({"groups": {"VIP": {"group_name": "VIP", "rate": 1.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
)
|
||||
s2 = UpstreamRateSnapshot(
|
||||
upstream_id=u2.id,
|
||||
snapshot_json=json.dumps({"groups": {"VIP": {"group_name": "VIP", "rate": 2.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db_session.add_all([s1, s2])
|
||||
db_session.commit()
|
||||
|
||||
priority_map = build_rate_priority_map(db_session, {u1.id, u2.id})
|
||||
|
||||
assert priority_map[f"{u1.id}:VIP"] == 1
|
||||
assert priority_map[f"{u2.id}:VIP"] == 2
|
||||
assert len(priority_map) == 2
|
||||
|
||||
def test_priority_sync_full_website_update(db_session, monkeypatch):
|
||||
# Setup website and upstreams
|
||||
w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}", timeout_seconds=30)
|
||||
u1 = Upstream(name="U1", base_url="http://u1")
|
||||
u2 = Upstream(name="U2", base_url="http://u2")
|
||||
db_session.add_all([w, u1, u2])
|
||||
db_session.commit()
|
||||
db_session.refresh(w)
|
||||
db_session.refresh(u1)
|
||||
db_session.refresh(u2)
|
||||
|
||||
# Setup snapshots
|
||||
db_session.add(UpstreamRateSnapshot(
|
||||
upstream_id=u1.id,
|
||||
snapshot_json=json.dumps({"groups": {"G1": {"rate": 1.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
))
|
||||
db_session.add(UpstreamRateSnapshot(
|
||||
upstream_id=u2.id,
|
||||
snapshot_json=json.dumps({"groups": {"G2": {"rate": 2.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
))
|
||||
db_session.commit()
|
||||
|
||||
# Setup keys imported to website
|
||||
k1 = UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", key_name="K1", key_value="V1",
|
||||
imported_website_id=w.id, imported_account_id="A1")
|
||||
k2 = UpstreamGeneratedKey(upstream_id=u2.id, group_id="G2", key_name="K2", key_value="V2",
|
||||
imported_website_id=w.id, imported_account_id="A2")
|
||||
db_session.add_all([k1, k2])
|
||||
db_session.commit()
|
||||
|
||||
# Mock Sub2ApiWebsiteClient
|
||||
update_calls = []
|
||||
class MockClient:
|
||||
def __init__(self, **kwargs): pass
|
||||
def __enter__(self): return self
|
||||
def __exit__(self, *args): pass
|
||||
def update_account(self, account_id, data):
|
||||
update_calls.append((account_id, data))
|
||||
|
||||
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", MockClient)
|
||||
monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockClient)
|
||||
|
||||
# Trigger sync for U1
|
||||
sync_account_priorities_for_upstream(db_session, u1.id)
|
||||
|
||||
# Verify BOTH A1 and A2 were updated because they belong to the same website
|
||||
assert len(update_calls) == 2
|
||||
account_ids = {c[0] for c in update_calls}
|
||||
assert account_ids == {"A1", "A2"}
|
||||
|
||||
# Priority check: G1(1.0) -> 1, G2(2.0) -> 2
|
||||
for aid, data in update_calls:
|
||||
if aid == "A1": assert data["priority"] == 1
|
||||
if aid == "A2": assert data["priority"] == 2
|
||||
|
||||
def test_priority_sync_log_structure(db_session, monkeypatch):
|
||||
w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}", timeout_seconds=30)
|
||||
u1 = Upstream(name="U1", base_url="http://u1")
|
||||
db_session.add_all([w, u1])
|
||||
db_session.commit()
|
||||
db_session.refresh(w)
|
||||
db_session.refresh(u1)
|
||||
|
||||
db_session.add(UpstreamRateSnapshot(
|
||||
upstream_id=u1.id,
|
||||
snapshot_json=json.dumps({"groups": {"G1": {"rate": 1.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
))
|
||||
db_session.add(UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", key_name="K1", key_value="V1",
|
||||
imported_website_id=w.id, imported_account_id="A1"))
|
||||
db_session.commit()
|
||||
|
||||
class MockClient:
|
||||
def __init__(self, **kwargs): pass
|
||||
def __enter__(self): return self
|
||||
def __exit__(self, *args): pass
|
||||
def update_account(self, account_id, data): pass
|
||||
|
||||
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", MockClient)
|
||||
monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockClient)
|
||||
|
||||
sync_account_priorities_for_upstream(db_session, u1.id)
|
||||
|
||||
log = db_session.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == w.id).first()
|
||||
assert log is not None
|
||||
assert log.algorithm == "priority_sync"
|
||||
|
||||
data = json.loads(log.source_rates_json)
|
||||
# The first item should be the priority map metadata
|
||||
assert data[0]["_meta"] == "priority_map"
|
||||
assert f"{u1.id}:G1" in data[0]["data"]
|
||||
# The second item should be the account result
|
||||
assert data[1]["account_id"] == "A1"
|
||||
|
||||
def test_import_auto_priority_by_rate(db_session, monkeypatch):
|
||||
from app.routers.websites import import_upstream_keys_as_accounts
|
||||
from app.schemas.website import ImportAccountsRequest
|
||||
|
||||
w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}",
|
||||
groups_endpoint="/groups", group_update_endpoint="/groups/{id}", timeout_seconds=30)
|
||||
u1 = Upstream(name="U1", base_url="http://u1")
|
||||
u2 = Upstream(name="U2", base_url="http://u2")
|
||||
db_session.add_all([w, u1, u2])
|
||||
db_session.commit()
|
||||
db_session.refresh(w)
|
||||
db_session.refresh(u1)
|
||||
db_session.refresh(u2)
|
||||
|
||||
db_session.add(UpstreamRateSnapshot(
|
||||
upstream_id=u1.id,
|
||||
snapshot_json=json.dumps({"groups": {"G1": {"rate": 2.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
))
|
||||
db_session.add(UpstreamRateSnapshot(
|
||||
upstream_id=u2.id,
|
||||
snapshot_json=json.dumps({"groups": {"G2": {"rate": 1.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
))
|
||||
|
||||
k1 = UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", group_name="G1", key_name="K1", key_value="V1")
|
||||
k2 = UpstreamGeneratedKey(upstream_id=u2.id, group_id="G2", group_name="G2", key_name="K2", key_value="V2")
|
||||
db_session.add_all([k1, k2])
|
||||
db_session.commit()
|
||||
|
||||
created_accounts = []
|
||||
class MockClient:
|
||||
def __init__(self, **kwargs): pass
|
||||
def __enter__(self): return self
|
||||
def __exit__(self, *args): pass
|
||||
def create_account(self, body):
|
||||
created_accounts.append(body)
|
||||
return {"id": f"remote-{len(created_accounts)}", "name": body["name"]}
|
||||
def extract_id(self, data): return data["id"]
|
||||
def account_exists(self, aid): return False
|
||||
|
||||
monkeypatch.setattr("app.routers.websites._client", lambda website: MockClient())
|
||||
monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockClient)
|
||||
|
||||
req = ImportAccountsRequest(
|
||||
upstream_key_ids=[k1.id, k2.id],
|
||||
target_group_map={},
|
||||
auto_priority_by_rate=True,
|
||||
priority=10,
|
||||
account_name_prefix="test",
|
||||
default_platform="openai"
|
||||
)
|
||||
|
||||
import_upstream_keys_as_accounts(w.id, req, db_session)
|
||||
|
||||
assert len(created_accounts) == 2
|
||||
# G2 has rate 1.0 -> priority 1
|
||||
# G1 has rate 2.0 -> priority 2
|
||||
p1 = next(a["priority"] for a in created_accounts if "G1" in a["name"])
|
||||
p2 = next(a["priority"] for a in created_accounts if "G2" in a["name"])
|
||||
assert p2 == 1
|
||||
assert p1 == 2
|
||||
Reference in New Issue
Block a user