Files
SmartUp/backend/test_priority_sync.py
T
liumangmang e519d1804b fix(priority-sync): narrow account priority update to competitive groups only
Root cause: sync_account_priorities_for_upstream() was doing a global
priority re-rank across ALL imported accounts on a website whenever any
upstream rate changed, triggering spurious account_priority_changed
notifications for accounts in different target groups with no competition.

Fix:
- Add imported_target_group_id / imported_target_group_name to
  UpstreamGeneratedKey (nullable; old data falls back to group_id)
- Write imported_target_group_id on account import in websites.py
- Rewrite sync_account_priorities_for_upstream():
  * bucket accounts by competition_group = imported_target_group_id or group_id
  * only process buckets with count > 1 (genuine competition)
  * each competitive bucket independently sorted by rate; priority starts at 1
  * single-account groups: completely skipped (no update_account, no notification)
  * no competitive groups at all: early return, no log, no notification
- Remove auto priority update in re-import idempotency path (was also
  incorrect; now fully delegated to sync_account_priorities_for_upstream)
- Fix Sub2ApiWebsiteClient local import in sync fn → use module-level name
  so monkeypatch works correctly in tests

Tests: rewrite test_priority_sync.py
- REMOVED: test_priority_sync_full_website_update (was asserting the buggy behavior)
- NEW: test_no_update_when_different_groups_single_account_each
- NEW: test_same_target_group_two_accounts_updated
- NEW: test_two_target_groups_independent_priority
- NEW: test_old_data_null_target_group_fallback
- NEW: test_single_account_in_mixed_website
- UPDATED: test_priority_sync_log_structure (now requires competitive group)
- KEPT: test_priority_sync_cross_upstream_group, test_import_auto_priority_by_rate

All 25 tests pass (8 priority_sync + 17 existing upstream tests).
2026-06-01 19:13:14 +08:00

397 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
优先级同步测试套件 — 分组内竞争逻辑
核心规则(测试覆盖):
- 竞争分组键 = imported_target_group_id or group_id(老数据 fallback)
- 只有同一竞争分组内账号数 > 1 时才更新 priority / 发通知
- 不同分组各 1 个账号:不调用 update_account,不发通知
- 同一目标分组多账号:组内按倍率升序独立排序,priority 从 1 开始
- 两个目标分组各有多账号:彼此独立,每组内 priority 都从 1 开始
- 老数据 imported_target_group_id=NULL → fallback 到 group_id,不报错
"""
import json
from datetime import datetime, timezone
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.models.upstream import Upstream
from app.models.website import Website, WebsiteSyncLog
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.snapshot import UpstreamRateSnapshot
from app.services.website_sync import (
build_rate_priority_map,
sync_account_priorities_for_upstream,
)
from app.services.website_client import Sub2ApiWebsiteClient
@pytest.fixture()
def db_session():
    """Yield a SQLAlchemy session bound to a fresh in-memory SQLite database.

    ``StaticPool`` together with ``check_same_thread=False`` makes every
    checkout reuse one single connection. The in-memory SQLite database lives
    only as long as that connection, so the whole test sees one database.
    All tables registered on ``Base.metadata`` are created before the test and
    dropped afterwards.
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = session_factory()
    try:
        yield db
    finally:
        db.close()
        Base.metadata.drop_all(bind=engine)
        # Close the pooled connection explicitly. Without this, each test keeps
        # its SQLite connection open until the engine is garbage-collected.
        engine.dispose()
# ── 辅助 ─────────────────────────────────────────────────────────────────────
def _make_snapshot(db, upstream_id: int, groups: dict):
    """Persist one rate snapshot for *upstream_id* and commit it.

    ``groups`` maps group_id -> rate. Each entry is stored under
    ``snapshot_json["groups"][group_id]`` as ``{"group_name": group_id, "rate": rate}``,
    so the group name simply mirrors the id.
    """
    group_entries = {}
    for group_id, rate in groups.items():
        group_entries[group_id] = {"group_name": group_id, "rate": rate}
    snapshot = UpstreamRateSnapshot(
        upstream_id=upstream_id,
        snapshot_json=json.dumps({"groups": group_entries}),
        captured_at=datetime.now(timezone.utc),
    )
    db.add(snapshot)
    db.commit()
def _make_key(db, upstream_id, group_id, key_name, key_value,
              website_id, account_id, imported_target_group_id=None):
    """Persist an UpstreamGeneratedKey that is already linked to an imported account.

    The key is marked as imported into ``website_id`` as remote account
    ``account_id``. ``group_name`` mirrors ``group_id``. Leaving
    ``imported_target_group_id`` as None mimics legacy rows written before that
    column existed. The row is committed and the ORM object returned.
    """
    columns = dict(
        upstream_id=upstream_id,
        group_id=group_id,
        group_name=group_id,
        key_name=key_name,
        key_value=key_value,
        imported_website_id=website_id,
        imported_account_id=account_id,
        imported_target_group_id=imported_target_group_id,
    )
    key = UpstreamGeneratedKey(**columns)
    db.add(key)
    db.commit()
    return key
class MockClient:
    """Test double for ``Sub2ApiWebsiteClient`` that records ``update_account`` calls.

    Do not instantiate the base class directly. Obtain a subclass from
    ``make_mock_client(calls)``, which binds ``_shared_calls`` to a list owned
    by the test. The mock accepts arbitrary keyword arguments in its
    constructor and supports use as a context manager.
    """

    # Recorder list shared by all instances of a bound subclass. This is an
    # annotation only: the base class deliberately has no recorder.
    _shared_calls: list

    def __init__(self, **kwargs):
        # Constructor arguments are irrelevant to the mock.
        pass

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Returning None lets exceptions raised in the with-body propagate.
        return None

    def update_account(self, account_id, data):
        """Append ``(account_id, data)`` to the bound recorder instead of calling a server."""
        calls = getattr(type(self), "_shared_calls", None)
        if calls is None:
            raise TypeError(
                "MockClient has no recorder bound; create it via make_mock_client()"
            )
        calls.append((account_id, data))
def make_mock_client(calls: list):
    """Return a fresh ``MockClient`` subclass whose ``update_account`` appends into *calls*.

    A new subclass is built on every call, so separate tests never share a
    recorder list.
    """
    # Equivalent to a nested ``class _Client(MockClient)`` with a class attribute.
    return type("_Client", (MockClient,), {"_shared_calls": calls})
# ── 用例 ─────────────────────────────────────────────────────────────────────
def test_no_update_when_different_groups_single_account_each(db_session, monkeypatch):
    """Two groups with one account each: no competition, so no update call and no log."""
    website = Website(name="W1", base_url="http://w1", enabled=True,
                      auth_config_json="{}", timeout_seconds=30)
    up_a = Upstream(name="U1", base_url="http://u1")
    up_b = Upstream(name="U2", base_url="http://u2")
    db_session.add_all([website, up_a, up_b])
    db_session.commit()
    for obj in (website, up_a, up_b):
        db_session.refresh(obj)

    _make_snapshot(db_session, up_a.id, {"G1": 1.0})
    _make_snapshot(db_session, up_b.id, {"G2": 2.0})
    # imported_target_group_id is left as None, so each key falls back to its
    # own group_id and ends up alone in its competition group.
    _make_key(db_session, up_a.id, "G1", "K1", "V1", website.id, "A1")
    _make_key(db_session, up_b.id, "G2", "K2", "V2", website.id, "A2")

    update_calls = []
    monkeypatch.setattr(
        "app.services.website_sync.Sub2ApiWebsiteClient",
        make_mock_client(update_calls),
    )

    results = sync_account_priorities_for_upstream(db_session, up_a.id)

    # Expected: an empty result and no update_account call at all.
    assert update_calls == [], f"不应有更新调用,实际:{update_calls}"
    assert results == []
    # Expected: no sync log row for this website either.
    logged = (
        db_session.query(WebsiteSyncLog)
        .filter(WebsiteSyncLog.website_id == website.id)
        .all()
    )
    assert logged == []
def test_same_target_group_two_accounts_updated(db_session, monkeypatch):
    """Two accounts imported into the same target group compete and are ranked by rate."""
    website = Website(name="W1", base_url="http://w1", enabled=True,
                      auth_config_json="{}", timeout_seconds=30)
    up_a = Upstream(name="U1", base_url="http://u1")
    up_b = Upstream(name="U2", base_url="http://u2")
    db_session.add_all([website, up_a, up_b])
    db_session.commit()
    for obj in (website, up_a, up_b):
        db_session.refresh(obj)

    _make_snapshot(db_session, up_a.id, {"G1": 1.0})
    _make_snapshot(db_session, up_b.id, {"G2": 2.0})
    # Both accounts share target group "TG1", so they are in one competition group.
    _make_key(db_session, up_a.id, "G1", "K1", "V1", website.id, "A1",
              imported_target_group_id="TG1")
    _make_key(db_session, up_b.id, "G2", "K2", "V2", website.id, "A2",
              imported_target_group_id="TG1")

    update_calls = []
    monkeypatch.setattr(
        "app.services.website_sync.Sub2ApiWebsiteClient",
        make_mock_client(update_calls),
    )

    sync_account_priorities_for_upstream(db_session, up_a.id)

    assert len(update_calls) == 2
    priority_by_account = {
        account_id: payload["priority"] for account_id, payload in update_calls
    }
    # Lower rate wins: G1 (1.0) -> priority 1, G2 (2.0) -> priority 2.
    assert priority_by_account["A1"] == 1
    assert priority_by_account["A2"] == 2
    # A competitive group must produce a sync log entry.
    log = (
        db_session.query(WebsiteSyncLog)
        .filter(WebsiteSyncLog.website_id == website.id)
        .first()
    )
    assert log is not None
    assert log.algorithm == "priority_sync"
def test_two_target_groups_independent_priority(db_session, monkeypatch):
    """Two competitive target groups are ranked independently, each starting at priority 1."""
    website = Website(name="W1", base_url="http://w1", enabled=True,
                      auth_config_json="{}", timeout_seconds=30)
    upstream = Upstream(name="U1", base_url="http://u1")
    db_session.add_all([website, upstream])
    db_session.commit()
    for obj in (website, upstream):
        db_session.refresh(obj)

    _make_snapshot(db_session, upstream.id, {
        "G1": 1.0,  # TG1, expected priority 1
        "G2": 2.0,  # TG1, expected priority 2
        "G3": 0.5,  # TG2, expected priority 1
        "G4": 3.0,  # TG2, expected priority 2
    })
    # (group_id, key_name, key_value, account_id, target_group)
    key_specs = [
        ("G1", "K1", "V1", "A1", "TG1"),
        ("G2", "K2", "V2", "A2", "TG1"),
        ("G3", "K3", "V3", "A3", "TG2"),
        ("G4", "K4", "V4", "A4", "TG2"),
    ]
    for group_id, key_name, key_value, account_id, target in key_specs:
        _make_key(db_session, upstream.id, group_id, key_name, key_value,
                  website.id, account_id, imported_target_group_id=target)

    update_calls = []
    monkeypatch.setattr(
        "app.services.website_sync.Sub2ApiWebsiteClient",
        make_mock_client(update_calls),
    )

    sync_account_priorities_for_upstream(db_session, upstream.id)

    priority_by_account = {
        account_id: payload["priority"] for account_id, payload in update_calls
    }
    assert len(update_calls) == 4
    # TG1: G1 (1.0) -> 1, G2 (2.0) -> 2.
    assert priority_by_account["A1"] == 1
    assert priority_by_account["A2"] == 2
    # TG2 restarts at 1, even though G3 (0.5) is the cheapest group overall.
    assert priority_by_account["A3"] == 1
    assert priority_by_account["A4"] == 2
def test_old_data_null_target_group_fallback(db_session, monkeypatch):
    """Legacy keys with imported_target_group_id=NULL fall back to group_id without errors.

    Edge case: two such keys from different upstreams share group_id "GSHARED".
    They must be treated as one competition group, so both accounts are updated
    and ranked by rate.
    """
    website = Website(name="W1", base_url="http://w1", enabled=True,
                      auth_config_json="{}", timeout_seconds=30)
    up_a = Upstream(name="U1", base_url="http://u1")
    up_b = Upstream(name="U2", base_url="http://u2")
    db_session.add_all([website, up_a, up_b])
    db_session.commit()
    for obj in (website, up_a, up_b):
        db_session.refresh(obj)

    _make_snapshot(db_session, up_a.id, {"GSHARED": 1.0})
    _make_snapshot(db_session, up_b.id, {"GSHARED": 2.0})
    # Legacy rows: no target group recorded, identical group_id.
    _make_key(db_session, up_a.id, "GSHARED", "K1", "V1", website.id, "A1",
              imported_target_group_id=None)
    _make_key(db_session, up_b.id, "GSHARED", "K2", "V2", website.id, "A2",
              imported_target_group_id=None)

    update_calls = []
    monkeypatch.setattr(
        "app.services.website_sync.Sub2ApiWebsiteClient",
        make_mock_client(update_calls),
    )

    # Must not raise on the NULL target group.
    sync_account_priorities_for_upstream(db_session, up_a.id)

    # Both members of the shared fallback group are updated exactly once...
    assert len(update_calls) == 2
    priority_by_account = {
        account_id: payload["priority"] for account_id, payload in update_calls
    }
    assert set(priority_by_account) == {"A1", "A2"}
    # ...and are ranked by rate: U1/GSHARED (1.0) before U2/GSHARED (2.0).
    assert priority_by_account["A1"] == 1
    assert priority_by_account["A2"] == 2
def test_single_account_in_mixed_website(db_session, monkeypatch):
    """One website, two target groups: TG1 has two accounts (competes), TG2 has one.

    Only the two TG1 accounts may be updated. The lone TG2 account must not
    receive an update_account call.
    """
    website = Website(name="W1", base_url="http://w1", enabled=True,
                      auth_config_json="{}", timeout_seconds=30)
    upstream = Upstream(name="U1", base_url="http://u1")
    db_session.add_all([website, upstream])
    db_session.commit()
    for obj in (website, upstream):
        db_session.refresh(obj)

    _make_snapshot(db_session, upstream.id, {"G1": 1.0, "G2": 2.0, "G3": 3.0})
    # (group_id, key_name, key_value, account_id, target_group)
    for group_id, key_name, key_value, account_id, target in (
        ("G1", "K1", "V1", "A1", "TG1"),  # TG1: competes with A2
        ("G2", "K2", "V2", "A2", "TG1"),
        ("G3", "K3", "V3", "A3", "TG2"),  # TG2: single account, skipped
    ):
        _make_key(db_session, upstream.id, group_id, key_name, key_value,
                  website.id, account_id, imported_target_group_id=target)

    update_calls = []
    monkeypatch.setattr(
        "app.services.website_sync.Sub2ApiWebsiteClient",
        make_mock_client(update_calls),
    )

    sync_account_priorities_for_upstream(db_session, upstream.id)

    updated_ids = {account_id for account_id, _payload in update_calls}
    assert "A3" not in updated_ids, "单账号分组不应被更新"
    assert updated_ids == {"A1", "A2"}
def test_priority_sync_log_structure(db_session, monkeypatch):
    """Check the sync log payload layout; a log is only written for a competitive group."""
    website = Website(name="W1", base_url="http://w1", enabled=True,
                      auth_config_json="{}", timeout_seconds=30)
    upstream = Upstream(name="U1", base_url="http://u1")
    db_session.add_all([website, upstream])
    db_session.commit()
    for obj in (website, upstream):
        db_session.refresh(obj)

    _make_snapshot(db_session, upstream.id, {"G1": 1.0, "G2": 2.0})
    # Two accounts in the same target group, so the sync writes a log.
    for group_id, key_name, key_value, account_id in (
        ("G1", "K1", "V1", "A1"),
        ("G2", "K2", "V2", "A2"),
    ):
        _make_key(db_session, upstream.id, group_id, key_name, key_value,
                  website.id, account_id, imported_target_group_id="TG1")

    class _NoopClient:
        """Client stand-in that accepts update_account calls and discards them."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *exc_info):
            pass

        def update_account(self, account_id, data):
            pass

    monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", _NoopClient)

    sync_account_priorities_for_upstream(db_session, upstream.id)

    log = (
        db_session.query(WebsiteSyncLog)
        .filter(WebsiteSyncLog.website_id == website.id)
        .first()
    )
    assert log is not None
    assert log.algorithm == "priority_sync"
    payload = json.loads(log.source_rates_json)
    # Entry 0 carries the priority-map metadata; later entries are per-account results.
    assert payload[0]["_meta"] == "priority_map"
    logged_accounts = {
        entry["account_id"] for entry in payload[1:] if "account_id" in entry
    }
    assert logged_accounts == {"A1", "A2"}
# ── 保留:build_rate_priority_map 单元测试(供初始导入使用)────────────────
def test_priority_sync_cross_upstream_group(db_session):
    """build_rate_priority_map keeps equal group ids on different upstreams apart.

    Entries are keyed as "<upstream_id>:<group_id>", so two "VIP" groups are
    ranked separately instead of overwriting each other.
    """
    first = Upstream(name="U1", base_url="http://u1")
    second = Upstream(name="U2", base_url="http://u2")
    db_session.add_all([first, second])
    db_session.commit()
    for obj in (first, second):
        db_session.refresh(obj)

    _make_snapshot(db_session, first.id, {"VIP": 1.0})
    _make_snapshot(db_session, second.id, {"VIP": 2.0})

    ranking = build_rate_priority_map(db_session, {first.id, second.id})

    assert ranking[f"{first.id}:VIP"] == 1
    assert ranking[f"{second.id}:VIP"] == 2
    assert len(ranking) == 2
def test_import_auto_priority_by_rate(db_session, monkeypatch):
    """Initial import with auto_priority_by_rate=True ranks new accounts by rate across upstreams."""
    from app.routers.websites import import_upstream_keys_as_accounts
    from app.schemas.website import ImportAccountsRequest

    website = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}",
                      groups_endpoint="/groups", group_update_endpoint="/groups/{id}",
                      timeout_seconds=30)
    up_a = Upstream(name="U1", base_url="http://u1")
    up_b = Upstream(name="U2", base_url="http://u2")
    db_session.add_all([website, up_a, up_b])
    db_session.commit()
    for obj in (website, up_a, up_b):
        db_session.refresh(obj)

    _make_snapshot(db_session, up_a.id, {"G1": 2.0})
    _make_snapshot(db_session, up_b.id, {"G2": 1.0})
    # Keys that have not been imported yet (no imported_* columns set).
    key_g1 = UpstreamGeneratedKey(upstream_id=up_a.id, group_id="G1", group_name="G1",
                                  key_name="K1", key_value="V1")
    key_g2 = UpstreamGeneratedKey(upstream_id=up_b.id, group_id="G2", group_name="G2",
                                  key_name="K2", key_value="V2")
    db_session.add_all([key_g1, key_g2])
    db_session.commit()

    created_accounts = []

    class _ImportClient:
        """Fake remote client: records created accounts and reports none as pre-existing."""

        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *exc_info):
            pass

        def create_account(self, body):
            created_accounts.append(body)
            return {"id": f"remote-{len(created_accounts)}", "name": body["name"]}

        def extract_id(self, data):
            return data["id"]

        def account_exists(self, aid):
            return False

    monkeypatch.setattr("app.routers.websites._client", lambda website: _ImportClient())
    monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", _ImportClient)

    request = ImportAccountsRequest(
        upstream_key_ids=[key_g1.id, key_g2.id],
        target_group_map={},
        auto_priority_by_rate=True,
        priority=10,
        account_name_prefix="test",
        default_platform="openai",
    )
    import_upstream_keys_as_accounts(website.id, request, db_session)

    assert len(created_accounts) == 2
    # Cheaper rate wins across upstreams: G2 (1.0) -> priority 1, G1 (2.0) -> priority 2.
    priority_g1 = next(acc["priority"] for acc in created_accounts if "G1" in acc["name"])
    priority_g2 = next(acc["priority"] for acc in created_accounts if "G2" in acc["name"])
    assert priority_g2 == 1
    assert priority_g1 == 2