Files
SmartUp/backend/test_priority_sync.py
T
2026-06-03 17:03:11 +08:00

492 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
优先级同步测试套件 — 分组内竞争逻辑
核心规则(测试覆盖):
- 竞争分组键 = imported_target_group_id or group_id(老数据 fallback)
- 只有同一竞争分组内账号数 > 1 时才更新 priority / 发通知
- 不同分组各 1 个账号:不调用 update_account,不发通知
- 同一目标分组多账号:组内按倍率升序独立排序,priority 从 1 开始,每档间隔 10
- 两个目标分组各有多账号:彼此独立,每组内 priority 都从 1 开始,每档间隔 10
- 老数据 imported_target_group_id=NULL → fallback group_id,不报错
"""
import json
from datetime import datetime, timezone
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.models.upstream import Upstream
from app.models.website import Website, WebsiteSyncLog
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.snapshot import UpstreamRateSnapshot
from app.services.website_sync import (
build_rate_priority_map,
sync_account_priorities_for_upstream,
)
from app.services.website_client import Sub2ApiWebsiteClient
@pytest.fixture()
def db_session():
    """Yield a SQLAlchemy session backed by a fresh in-memory SQLite database.

    The schema declared on ``Base`` is created before the test and dropped
    afterwards. ``StaticPool`` is used so every checkout shares one
    connection, which an in-memory SQLite database needs in order to keep
    its tables visible across the session's lifetime.

    Yields:
        An open session bound to the temporary engine.
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        Base.metadata.create_all(bind=engine)
        TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        db = TestingSessionLocal()
        try:
            yield db
        finally:
            db.close()
            Base.metadata.drop_all(bind=engine)
    finally:
        # Release the pooled connection even when setup or teardown fails,
        # so engines do not accumulate across the test run.
        engine.dispose()
# ── 辅助 ─────────────────────────────────────────────────────────────────────
def _make_snapshot(db, upstream_id: int, groups: dict):
    """Insert and commit one rate snapshot row for ``upstream_id``.

    ``groups`` maps group_id -> rate. Each entry is serialized under the
    ``"groups"`` key of the snapshot JSON, with the group id reused as its
    ``group_name``.
    """
    group_entries = {}
    for group_id, rate in groups.items():
        group_entries[group_id] = {"group_name": group_id, "rate": rate}

    payload = json.dumps({"groups": group_entries})
    snapshot_row = UpstreamRateSnapshot(
        upstream_id=upstream_id,
        snapshot_json=payload,
        captured_at=datetime.now(timezone.utc),
    )
    db.add(snapshot_row)
    db.commit()
def _make_key(db, upstream_id, group_id, key_name, key_value,
              website_id, account_id, imported_target_group_id=None):
    """Persist an UpstreamGeneratedKey marked as imported and return it.

    The group id doubles as the group name. ``website_id`` and
    ``account_id`` are stored in the ``imported_*`` columns, and
    ``imported_target_group_id`` defaults to ``None``.
    """
    attributes = dict(
        upstream_id=upstream_id,
        group_id=group_id,
        group_name=group_id,
        key_name=key_name,
        key_value=key_value,
        imported_website_id=website_id,
        imported_account_id=account_id,
        imported_target_group_id=imported_target_group_id,
    )
    generated_key = UpstreamGeneratedKey(**attributes)
    db.add(generated_key)
    db.commit()
    return generated_key
class MockClient:
    """Injectable stand-in client that records every update_account call.

    The class is meant to be subclassed with ``_shared_calls`` bound to a
    list; each ``update_account`` call appends ``(account_id, data)`` to
    that list. Using the base class directly raises ``AttributeError``
    because no list has been bound.
    """

    # Bound by a subclass / factory. The name must match the attribute that
    # update_account reads below.
    _shared_calls: list

    def __init__(self, **kwargs):
        # Constructor arguments are accepted and ignored; the call list is
        # injected through the subclass, not through the instance.
        pass

    def __enter__(self):
        return self

    def __exit__(self, *a):
        # Returns None, so exceptions raised inside the with-block propagate.
        pass

    def update_account(self, account_id, data):
        """Record the call on the list bound to this instance's class."""
        type(self)._shared_calls.append((account_id, data))
def make_mock_client(calls: list):
    """Build a MockClient subclass whose recorded calls land in ``calls``."""
    # A dynamically created subclass keeps each test's call list isolated.
    return type("_Client", (MockClient,), {"_shared_calls": calls})
# ── 用例 ─────────────────────────────────────────────────────────────────────
def test_no_update_when_different_groups_single_account_each(db_session, monkeypatch):
    """Different groups with one account each do not compete.

    Expects no update_account call, an empty result list, and no sync log.
    """
    site = Website(
        name="W1",
        base_url="http://w1",
        enabled=True,
        auth_config_json="{}",
        timeout_seconds=30,
    )
    first_upstream = Upstream(name="U1", base_url="http://u1")
    second_upstream = Upstream(name="U2", base_url="http://u2")
    db_session.add_all([site, first_upstream, second_upstream])
    db_session.commit()
    for row in (site, first_upstream, second_upstream):
        db_session.refresh(row)

    _make_snapshot(db_session, first_upstream.id, {"G1": 1.0})
    _make_snapshot(db_session, second_upstream.id, {"G2": 2.0})

    # Distinct groups, one account each; imported_target_group_id stays None
    # so the group_id fallback is exercised.
    _make_key(db_session, first_upstream.id, "G1", "K1", "V1", site.id, "A1")
    _make_key(db_session, second_upstream.id, "G2", "K2", "V2", site.id, "A2")

    recorded = []
    monkeypatch.setattr(
        "app.services.website_sync.Sub2ApiWebsiteClient",
        make_mock_client(recorded),
    )

    outcome = sync_account_priorities_for_upstream(db_session, first_upstream.id)

    # No competing group: nothing is pushed and the result list is empty.
    assert recorded == [], f"不应有更新调用,实际:{recorded}"
    assert outcome == []

    # No log row is written either.
    log_rows = (
        db_session.query(WebsiteSyncLog)
        .filter(WebsiteSyncLog.website_id == site.id)
        .all()
    )
    assert log_rows == []
def test_same_target_group_two_accounts_updated(db_session, monkeypatch):
    """Two accounts in one target group compete and get rate-ordered priorities."""
    w = Website(name="W1", base_url="http://w1", enabled=True,
                auth_config_json="{}", timeout_seconds=30)
    u1 = Upstream(name="U1", base_url="http://u1")
    u2 = Upstream(name="U2", base_url="http://u2")
    db_session.add_all([w, u1, u2])
    db_session.commit()
    for row in (w, u1, u2):
        db_session.refresh(row)

    _make_snapshot(db_session, u1.id, {"G1": 1.0})
    _make_snapshot(db_session, u2.id, {"G2": 2.0})

    # Both accounts belong to target group "TG1".
    _make_key(db_session, u1.id, "G1", "K1", "V1", w.id, "A1", imported_target_group_id="TG1")
    _make_key(db_session, u2.id, "G2", "K2", "V2", w.id, "A2", imported_target_group_id="TG1")

    update_calls = []
    monkeypatch.setattr(
        "app.services.website_sync.Sub2ApiWebsiteClient",
        make_mock_client(update_calls),
    )

    # The return value is not inspected here; only the recorded calls and
    # the written log are checked.
    sync_account_priorities_for_upstream(db_session, u1.id)

    assert len(update_calls) == 2
    priority_map = {aid: data["priority"] for aid, data in update_calls}
    # G1 rate=1.0 -> priority 1 (lower rate first); G2 rate=2.0 -> priority 11.
    assert priority_map["A1"] == 1
    assert priority_map["A2"] == 11

    # A sync log row was written for the website.
    log = db_session.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == w.id).first()
    assert log is not None
    assert log.algorithm == "priority_sync"
def test_two_target_groups_independent_priority(db_session, monkeypatch):
    """Two multi-account target groups are ranked independently, each starting at 1."""
    site = Website(
        name="W1",
        base_url="http://w1",
        enabled=True,
        auth_config_json="{}",
        timeout_seconds=30,
    )
    upstream = Upstream(name="U1", base_url="http://u1")
    db_session.add_all([site, upstream])
    db_session.commit()
    for row in (site, upstream):
        db_session.refresh(row)

    rates = {
        "G1": 1.0,  # expected priority 1 inside TG1
        "G2": 2.0,  # expected priority 11 inside TG1
        "G3": 0.5,  # expected priority 1 inside TG2
        "G4": 3.0,  # expected priority 11 inside TG2
    }
    _make_snapshot(db_session, upstream.id, rates)

    # TG1 holds G1 (1.0) and G2 (2.0).
    _make_key(db_session, upstream.id, "G1", "K1", "V1", site.id, "A1", imported_target_group_id="TG1")
    _make_key(db_session, upstream.id, "G2", "K2", "V2", site.id, "A2", imported_target_group_id="TG1")
    # TG2 holds G3 (0.5) and G4 (3.0).
    _make_key(db_session, upstream.id, "G3", "K3", "V3", site.id, "A3", imported_target_group_id="TG2")
    _make_key(db_session, upstream.id, "G4", "K4", "V4", site.id, "A4", imported_target_group_id="TG2")

    recorded = []
    monkeypatch.setattr(
        "app.services.website_sync.Sub2ApiWebsiteClient",
        make_mock_client(recorded),
    )

    sync_account_priorities_for_upstream(db_session, upstream.id)

    pushed = {}
    for account_id, payload in recorded:
        pushed[account_id] = payload["priority"]

    assert len(recorded) == 4
    # Within TG1: G1 (1.0) -> 1, G2 (2.0) -> 11.
    assert pushed["A1"] == 1
    assert pushed["A2"] == 11
    # Within TG2: G3 (0.5) -> 1, G4 (3.0) -> 11 (numbering restarts at 1).
    assert pushed["A3"] == 1
    assert pushed["A4"] == 11
def test_old_data_null_target_group_fallback(db_session, monkeypatch):
    """Legacy rows with imported_target_group_id=NULL fall back to group_id without error.

    Edge case: two accounts share the same group_id, so they are expected to
    be treated as one competitive group and both get updated.
    """
    w = Website(name="W1", base_url="http://w1", enabled=True,
                auth_config_json="{}", timeout_seconds=30)
    u1 = Upstream(name="U1", base_url="http://u1")
    u2 = Upstream(name="U2", base_url="http://u2")
    db_session.add_all([w, u1, u2])
    db_session.commit()
    for row in (w, u1, u2):
        db_session.refresh(row)

    _make_snapshot(db_session, u1.id, {"GSHARED": 1.0})
    _make_snapshot(db_session, u2.id, {"GSHARED": 2.0})

    # Legacy data: imported_target_group_id=None and identical group_id.
    _make_key(db_session, u1.id, "GSHARED", "K1", "V1", w.id, "A1", imported_target_group_id=None)
    _make_key(db_session, u2.id, "GSHARED", "K2", "V2", w.id, "A2", imported_target_group_id=None)

    update_calls = []
    monkeypatch.setattr(
        "app.services.website_sync.Sub2ApiWebsiteClient",
        make_mock_client(update_calls),
    )

    # Must not raise; the return value is not inspected by this test.
    sync_account_priorities_for_upstream(db_session, u1.id)

    # Shared group_id "GSHARED" forms one competitive group: both updated.
    assert len(update_calls) == 2
    # Guard against one account being updated twice while the other is skipped.
    assert {account_id for account_id, _ in update_calls} == {"A1", "A2"}
def test_single_account_in_mixed_website(db_session, monkeypatch):
    """One website mixing a competing target group with a single-account one.

    Only the two accounts in the competing group are updated; the lone
    account in the other target group never reaches update_account.
    """
    site = Website(
        name="W1",
        base_url="http://w1",
        enabled=True,
        auth_config_json="{}",
        timeout_seconds=30,
    )
    upstream = Upstream(name="U1", base_url="http://u1")
    db_session.add_all([site, upstream])
    db_session.commit()
    for row in (site, upstream):
        db_session.refresh(row)

    _make_snapshot(db_session, upstream.id, {"G1": 1.0, "G2": 2.0, "G3": 3.0})

    # TG1: G1 + G2 compete.
    _make_key(db_session, upstream.id, "G1", "K1", "V1", site.id, "A1", imported_target_group_id="TG1")
    _make_key(db_session, upstream.id, "G2", "K2", "V2", site.id, "A2", imported_target_group_id="TG1")
    # TG2: only G3, a single account that should be skipped.
    _make_key(db_session, upstream.id, "G3", "K3", "V3", site.id, "A3", imported_target_group_id="TG2")

    recorded = []
    monkeypatch.setattr(
        "app.services.website_sync.Sub2ApiWebsiteClient",
        make_mock_client(recorded),
    )

    sync_account_priorities_for_upstream(db_session, upstream.id)

    touched_accounts = set()
    for account_id, _payload in recorded:
        touched_accounts.add(account_id)

    assert "A3" not in touched_accounts, "单账号分组不应被更新"
    assert touched_accounts == {"A1", "A2"}
def test_missing_rate_skips_entire_competitive_group(db_session, monkeypatch):
    """With no snapshot rate for any account, the whole group is skipped.

    Fewer than two accounts have a usable rate, so update_account is never
    called and the result list is empty.
    """
    site = Website(
        name="W1",
        base_url="http://w1",
        enabled=True,
        auth_config_json="{}",
        timeout_seconds=30,
    )
    upstream = Upstream(name="U1", base_url="http://u1")
    db_session.add_all([site, upstream])
    db_session.commit()
    for row in (site, upstream):
        db_session.refresh(row)

    # Deliberately no snapshot is inserted, so no rates are available.
    _make_key(db_session, upstream.id, "G1", "K1", "V1", site.id, "A1", imported_target_group_id="TG1")
    _make_key(db_session, upstream.id, "G2", "K2", "V2", site.id, "A2", imported_target_group_id="TG1")

    recorded = []
    monkeypatch.setattr(
        "app.services.website_sync.Sub2ApiWebsiteClient",
        make_mock_client(recorded),
    )

    outcome = sync_account_priorities_for_upstream(db_session, upstream.id)

    assert recorded == [], "无快照倍率时不应调用 update_account"
    assert outcome == []
def test_partial_missing_rate_sufficient_accounts_still_updates(db_session, monkeypatch):
    """Three accounts in one group, one without a rate: the other two still compete.

    The two accounts that have snapshot rates are updated normally; the one
    without a rate is left untouched.
    """
    site = Website(
        name="W1",
        base_url="http://w1",
        enabled=True,
        auth_config_json="{}",
        timeout_seconds=30,
    )
    upstream = Upstream(name="U1", base_url="http://u1")
    db_session.add_all([site, upstream])
    db_session.commit()
    for row in (site, upstream):
        db_session.refresh(row)

    # G1 and G2 appear in the snapshot; G3 does not.
    _make_snapshot(db_session, upstream.id, {"G1": 1.0, "G2": 2.0})
    _make_key(db_session, upstream.id, "G1", "K1", "V1", site.id, "A1", imported_target_group_id="TG1")
    _make_key(db_session, upstream.id, "G2", "K2", "V2", site.id, "A2", imported_target_group_id="TG1")
    _make_key(db_session, upstream.id, "G3", "K3", "V3", site.id, "A3", imported_target_group_id="TG1")

    recorded = []
    monkeypatch.setattr(
        "app.services.website_sync.Sub2ApiWebsiteClient",
        make_mock_client(recorded),
    )

    sync_account_priorities_for_upstream(db_session, upstream.id)

    pushed = {account_id: payload["priority"] for account_id, payload in recorded}

    # A3 has no snapshot rate: it is excluded from ranking and not updated.
    assert "A3" not in pushed
    # A1 (G1, rate 1.0) -> priority 1; A2 (G2, rate 2.0) -> priority 11.
    assert pushed["A1"] == 1
    assert pushed["A2"] == 11
def test_priority_sync_log_structure(db_session, monkeypatch):
    """Check the shape of the written sync log (needs a competing group)."""
    site = Website(
        name="W1",
        base_url="http://w1",
        enabled=True,
        auth_config_json="{}",
        timeout_seconds=30,
    )
    upstream = Upstream(name="U1", base_url="http://u1")
    db_session.add_all([site, upstream])
    db_session.commit()
    for row in (site, upstream):
        db_session.refresh(row)

    _make_snapshot(db_session, upstream.id, {"G1": 1.0, "G2": 2.0})

    # Two accounts in the same target group, so a log row is expected.
    _make_key(db_session, upstream.id, "G1", "K1", "V1", site.id, "A1", imported_target_group_id="TG1")
    _make_key(db_session, upstream.id, "G2", "K2", "V2", site.id, "A2", imported_target_group_id="TG1")

    class _SilentClient:
        """Client double that accepts update_account calls and discards them."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *a):
            pass

        def update_account(self, account_id, data):
            pass

    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", _SilentClient)

    sync_account_priorities_for_upstream(db_session, upstream.id)

    log_row = (
        db_session.query(WebsiteSyncLog)
        .filter(WebsiteSyncLog.website_id == site.id)
        .first()
    )
    assert log_row is not None
    assert log_row.algorithm == "priority_sync"

    entries = json.loads(log_row.source_rates_json)
    # The first entry carries the priority_map metadata.
    assert entries[0]["_meta"] == "priority_map"

    # The remaining entries are per-account results.
    logged_accounts = set()
    for entry in entries[1:]:
        if "account_id" in entry:
            logged_accounts.add(entry["account_id"])
    assert {"A1", "A2"} == logged_accounts
# ── 保留:build_rate_priority_map 单元测试(供初始导入使用)────────────────
def test_priority_sync_cross_upstream_group(db_session):
    """build_rate_priority_map keeps equal group_ids from different upstreams apart."""
    first_upstream = Upstream(name="U1", base_url="http://u1")
    second_upstream = Upstream(name="U2", base_url="http://u2")
    db_session.add_all([first_upstream, second_upstream])
    db_session.commit()
    for row in (first_upstream, second_upstream):
        db_session.refresh(row)

    _make_snapshot(db_session, first_upstream.id, {"VIP": 1.0})
    _make_snapshot(db_session, second_upstream.id, {"VIP": 2.0})

    upstream_ids = {first_upstream.id, second_upstream.id}
    ranking = build_rate_priority_map(db_session, upstream_ids)

    # Keys are "<upstream_id>:<group_id>", so the two "VIP" groups stay distinct.
    first_key = f"{first_upstream.id}:VIP"
    second_key = f"{second_upstream.id}:VIP"
    assert ranking[first_key] == 1
    assert ranking[second_key] == 11
    assert len(ranking) == 2
def test_import_auto_priority_by_rate(db_session, monkeypatch):
    """Initial import with auto_priority_by_rate=True assigns priority by global rate."""
    from app.routers.websites import import_upstream_keys_as_accounts
    from app.schemas.website import ImportAccountsRequest

    site = Website(
        name="W1",
        base_url="http://w1",
        enabled=True,
        auth_config_json="{}",
        groups_endpoint="/groups",
        group_update_endpoint="/groups/{id}",
        timeout_seconds=30,
    )
    first_upstream = Upstream(name="U1", base_url="http://u1")
    second_upstream = Upstream(name="U2", base_url="http://u2")
    db_session.add_all([site, first_upstream, second_upstream])
    db_session.commit()
    for row in (site, first_upstream, second_upstream):
        db_session.refresh(row)

    _make_snapshot(db_session, first_upstream.id, {"G1": 2.0})
    _make_snapshot(db_session, second_upstream.id, {"G2": 1.0})

    # These keys are not imported yet, so they are built without the imported_* fields.
    first_key = UpstreamGeneratedKey(
        upstream_id=first_upstream.id,
        group_id="G1",
        group_name="G1",
        key_name="K1",
        key_value="V1",
    )
    second_key = UpstreamGeneratedKey(
        upstream_id=second_upstream.id,
        group_id="G2",
        group_name="G2",
        key_name="K2",
        key_value="V2",
    )
    db_session.add_all([first_key, second_key])
    db_session.commit()

    created_accounts = []

    class MockImportClient:
        """Client double that records every account creation request body."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *a):
            pass

        def create_account(self, body):
            created_accounts.append(body)
            remote_id = f"remote-{len(created_accounts)}"
            return {"id": remote_id, "name": body["name"]}

        def extract_id(self, data):
            return data["id"]

        def account_exists(self, aid):
            return False

    monkeypatch.setattr("app.routers.websites._client", lambda website: MockImportClient())
    monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockImportClient)

    request = ImportAccountsRequest(
        upstream_key_ids=[first_key.id, second_key.id],
        target_group_map={},
        auto_priority_by_rate=True,
        priority=10,
        account_name_prefix="test",
        default_platform="openai",
    )
    import_upstream_keys_as_accounts(site.id, request, db_session)

    assert len(created_accounts) == 2

    def _priority_of(group_tag):
        # First created account whose name mentions the group tag.
        return next(
            account["priority"]
            for account in created_accounts
            if group_tag in account["name"]
        )

    # G2 rate=1.0 -> priority 1; G1 rate=2.0 -> priority 11.
    g1_priority = _priority_of("G1")
    g2_priority = _priority_of("G2")
    assert g2_priority == 1
    assert g1_priority == 11
def test_reorder_priority_endpoint_scopes_to_current_website(db_session, monkeypatch):
    """Manual reorder only touches the requested website.

    Accounts imported from the same upstream into another website must not
    be updated.
    """
    from app.routers.websites import reorder_account_priorities
    from app.schemas.website import ReorderPriorityRequest

    target_site = Website(
        name="W1",
        base_url="http://w1",
        enabled=True,
        auth_config_json="{}",
        timeout_seconds=30,
    )
    other_site = Website(
        name="W2",
        base_url="http://w2",
        enabled=True,
        auth_config_json="{}",
        timeout_seconds=30,
    )
    upstream = Upstream(name="U1", base_url="http://u1")
    db_session.add_all([target_site, other_site, upstream])
    db_session.commit()
    for row in (target_site, other_site, upstream):
        db_session.refresh(row)

    _make_snapshot(db_session, upstream.id, {"G1": 1.0, "G2": 2.0})

    # Same upstream groups imported into both websites under target group "TG1".
    _make_key(db_session, upstream.id, "G1", "K1", "V1", target_site.id, "W1A1", imported_target_group_id="TG1")
    _make_key(db_session, upstream.id, "G2", "K2", "V2", target_site.id, "W1A2", imported_target_group_id="TG1")
    _make_key(db_session, upstream.id, "G1", "K3", "V3", other_site.id, "W2A1", imported_target_group_id="TG1")
    _make_key(db_session, upstream.id, "G2", "K4", "V4", other_site.id, "W2A2", imported_target_group_id="TG1")

    recorded = []
    monkeypatch.setattr(
        "app.services.website_sync.Sub2ApiWebsiteClient",
        make_mock_client(recorded),
    )

    request = ReorderPriorityRequest(upstream_id=upstream.id)
    response = reorder_account_priorities(target_site.id, request, db_session)

    pushed = {}
    for account_id, payload in recorded:
        pushed[account_id] = payload["priority"]

    assert response.success is True
    # Only W1's accounts appear; W2's accounts were never pushed.
    assert pushed == {"W1A1": 1, "W1A2": 11}