fix(priority-sync): narrow account priority update to competitive groups only

Root cause: sync_account_priorities_for_upstream() was doing a global
priority re-rank across ALL imported accounts on a website whenever any
upstream rate changed, triggering spurious account_priority_changed
notifications for accounts in different target groups with no competition.

Fix:
- Add imported_target_group_id / imported_target_group_name to
  UpstreamGeneratedKey (nullable; old data falls back to group_id)
- Writ imported_target_group_id on account import in websites.py
- Rewrite sync_account_priorities_for_upstream():
  * bucket accounts by competition_group = imported_target_group_id or group_id
  * only process buckets with count > 1 (genuine competition)
  * each competitive bucket independently sorted by rate; priority starts at 1
  * single-account groups: completely skipped (no update_account, no notification)
  * no competitive groups at all: early return, no log, no notification
- Remove auto priority update in re-import idempotency path (was also
  incorrect; now fully delegated to sync_account_priorities_for_upstream)
- Fix Sub2ApiWebsiteClient local import in sync fn → use module-level name
  so monkeypatch works correctly in tests

Tests: rewrite test_priority_sync.py
- REMOVED: test_priority_sync_full_website_update (was asserting the buggy behavior)
- NEW: test_no_update_when_different_groups_single_account_each
- NEW: test_same_target_group_two_accounts_updated
- NEW: test_two_target_groups_independent_priority
- NEW: test_old_data_null_target_group_fallback
- NEW: test_single_account_in_mixed_website
- UPDATED: test_priority_sync_log_structure (now requires competitive group)
- KEPT: test_priority_sync_cross_upstream_group, test_import_auto_priority_by_rate

All 25 tests pass (8 priority_sync + 17 existing upstream tests).
This commit is contained in:
liumangmang
2026-06-01 19:13:14 +08:00
parent 871557e4ae
commit e519d1804b
5 changed files with 415 additions and 216 deletions
+4
View File
@@ -150,6 +150,10 @@ def _migrate_upstream_generated_keys():
conn.execute(text("UPDATE upstream_generated_keys SET updated_at = created_at WHERE updated_at IS NULL")) conn.execute(text("UPDATE upstream_generated_keys SET updated_at = created_at WHERE updated_at IS NULL"))
if "managed_prefix" not in columns: if "managed_prefix" not in columns:
conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN managed_prefix VARCHAR(64)")) conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN managed_prefix VARCHAR(64)"))
if "imported_target_group_id" not in columns:
conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_target_group_id VARCHAR(255)"))
if "imported_target_group_name" not in columns:
conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_target_group_name VARCHAR(255)"))
# ——— 历史数据迁移:回填 managed_prefix + 清理重复 ——— # ——— 历史数据迁移:回填 managed_prefix + 清理重复 ———
with engine.begin() as conn: with engine.begin() as conn:
+2
View File
@@ -25,6 +25,8 @@ class UpstreamGeneratedKey(Base):
imported_website_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey("websites.id", ondelete="SET NULL"), nullable=True, index=True) imported_website_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey("websites.id", ondelete="SET NULL"), nullable=True, index=True)
imported_account_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) imported_account_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
imported_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) imported_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
imported_target_group_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
imported_target_group_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc)) created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) updated_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
+3 -11
View File
@@ -558,16 +558,6 @@ def import_upstream_keys_as_accounts(
old_account_id = row.imported_account_id old_account_id = row.imported_account_id
exists = c.account_exists(row.imported_account_id) exists = c.account_exists(row.imported_account_id)
if exists is True: if exists is True:
# 自动更新已有账号的 priority(分步导入时全局倍率排序可能已变)
new_priority = rate_priority_map.get(f"{row.upstream_id}:{row.group_id}") if body.auto_priority_by_rate else None
priority_msg = "已导入过,已跳过"
if new_priority is not None:
try:
c.update_account(old_account_id, {"priority": new_priority})
priority_msg = f"已导入过,优先级已更新为 {new_priority}"
except Exception as exc:
logger.warning("update priority failed account=%s: %s", old_account_id, exc)
priority_msg = f"已导入过,优先级更新失败: {exc}"
items.append(ImportAccountItem( items.append(ImportAccountItem(
upstream_key_id=row.id, upstream_key_id=row.id,
source_group_id=row.group_id, source_group_id=row.group_id,
@@ -578,7 +568,7 @@ def import_upstream_keys_as_accounts(
platform=platform, platform=platform,
upstream_base_url=upstream_base_url, upstream_base_url=upstream_base_url,
status="exists", status="exists",
message=priority_msg, message="已导入过,已跳过",
)) ))
continue continue
elif exists is False: elif exists is False:
@@ -639,6 +629,8 @@ def import_upstream_keys_as_accounts(
row.imported_website_id = wid row.imported_website_id = wid
row.imported_account_id = account_id or None row.imported_account_id = account_id or None
row.imported_at = datetime.now(timezone.utc) row.imported_at = datetime.now(timezone.utc)
row.imported_target_group_id = target_group_id or None
row.imported_target_group_name = None # target_group_map 只存 IDname 展示用可留 NULL
row.status = "imported" row.status = "imported"
row.error = None row.error = None
db.commit() db.commit()
+75 -49
View File
@@ -288,12 +288,13 @@ def _try_send_priority_webhook(
def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[dict]: def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[dict]:
"""上游倍率变化后,自动更新已导入下游账号的 priority。 """上游倍率变化后,自动更新已导入下游账号的 priority。
查询该上游下所有已导入(非 orphaned)的 Key,按目标网站分组后重新计算全局优先级, 只处理同一目标分组内有多个账号(存在竞争)的情况:
并通过 update_account API 推送到下游网站。返回详细结果列表。 - 竞争分组键:imported_target_group_id(老数据 fallback 到 group_id
- 同一竞争分组内按倍率升序排序,priority 从 1 开始(相同倍率共享)
同时写入 WebsiteSyncLog 持久化审计日志,并通过 webhook 发送通知 - 单账号分组:完全跳过,不调用 update_account,不发通知
- 无竞争分组:直接返回,不写日志,不发通知
""" """
from app.services.website_client import Sub2ApiWebsiteClient as Client from collections import defaultdict
key_rows = ( key_rows = (
db.query(UpstreamGeneratedKey) db.query(UpstreamGeneratedKey)
@@ -314,9 +315,7 @@ def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[
website_groups: dict[int, list[UpstreamGeneratedKey]] = {} website_groups: dict[int, list[UpstreamGeneratedKey]] = {}
for row in key_rows: for row in key_rows:
wid = row.imported_website_id wid = row.imported_website_id
if wid not in website_groups: website_groups.setdefault(wid, []).append(row)
website_groups[wid] = []
website_groups[wid].append(row)
all_results: list[dict] = [] all_results: list[dict] = []
@@ -324,16 +323,10 @@ def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[
website = db.query(Website).filter(Website.id == wid).first() website = db.query(Website).filter(Website.id == wid).first()
if not website or not website.enabled: if not website or not website.enabled:
logger.info("skip account priority sync: website %s not found or disabled", wid) logger.info("skip account priority sync: website %s not found or disabled", wid)
site_results = [] # 不写日志、不发通知:网站不可用时账号无法更新,沉默跳过
for row in rows:
r = _priority_result(row, None, "failed", "网站不可用")
site_results.append(r)
all_results.append(r)
_write_priority_sync_log_with_map(db, wid, upstream_name, site_results, {})
_try_send_priority_webhook(db, wid, "", upstream_id, upstream_name, site_results)
continue continue
# 查询该网站所有已导入 Key(跨上游),实现全局优先级排序 # 查询该网站所有已导入 Key(跨上游),用于倍率查询
all_website_keys = ( all_website_keys = (
db.query(UpstreamGeneratedKey) db.query(UpstreamGeneratedKey)
.filter( .filter(
@@ -343,55 +336,76 @@ def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[
) )
.all() .all()
) )
# ── 按竞争分组分桶 ──────────────────────────────────────────────────
# 竞争分组键:imported_target_group_id(老数据为 NULL 时 fallback 到 group_id
buckets: dict[str, list[UpstreamGeneratedKey]] = defaultdict(list)
for row in all_website_keys:
comp_key = row.imported_target_group_id or row.group_id
buckets[comp_key].append(row)
# 只保留账号数 > 1 的分组(有竞争才需要排序)
competitive_buckets = {k: v for k, v in buckets.items() if len(v) > 1}
if not competitive_buckets:
logger.info(
"skip account priority sync for website %s: no competitive groups (all single-account)",
wid,
)
continue # 不写日志,不发通知
# ── 预取快照倍率 ────────────────────────────────────────────────────
all_upstream_ids = {k.upstream_id for k in all_website_keys} all_upstream_ids = {k.upstream_id for k in all_website_keys}
try: try:
priority_map = build_rate_priority_map(db, all_upstream_ids) # 构建 "{upstream_id}:{group_id}" → rate 查询表
raw_rate_map: dict[str, float] = {}
for uid in all_upstream_ids:
groups = latest_rate_map(db, uid)
for gid, g in groups.items():
if isinstance(g, dict):
raw_rate_map[f"{uid}:{gid}"] = _snapshot_group_rate(g)
except Exception as exc: except Exception as exc:
logger.warning("build_rate_priority_map failed for website %s: %s", wid, exc) logger.warning("build rate map failed for website %s: %s", wid, exc)
site_results = [] raw_rate_map = {}
for row in all_website_keys:
r = _priority_result(row, None, "failed", f"构建优先级映射失败: {exc}")
site_results.append(r)
all_results.append(r)
_write_priority_sync_log_with_map(db, wid, upstream_name, site_results, {})
_try_send_priority_webhook(db, wid, "", upstream_id, upstream_name, site_results)
continue
if not priority_map: # ── 每个竞争分组内独立计算 priority ────────────────────────────────
logger.info("skip account priority sync for website %s: empty priority map", wid) # priority_assignment: account_id → new_priority
site_results = [] priority_assignment: dict[str, int] = {}
for row in all_website_keys: for comp_key, comp_rows in competitive_buckets.items():
r = _priority_result(row, None, "skipped", "无上游倍率数据") # 取每行的倍率(查不到则 fallback 1.0
site_results.append(r) rated = [
all_results.append(r) (row, raw_rate_map.get(f"{row.upstream_id}:{row.group_id}", 1.0))
_write_priority_sync_log_with_map(db, wid, upstream_name, site_results, {}) for row in comp_rows
_try_send_priority_webhook(db, wid, "", upstream_id, upstream_name, site_results) ]
continue # 组内按倍率升序排序(倍率低 → priority 小 → 优先)
unique_rates = sorted(set(r for _, r in rated))
rate_to_prio = {rate: idx + 1 for idx, rate in enumerate(unique_rates)}
for row, rate in rated:
priority_assignment[row.imported_account_id] = rate_to_prio[rate]
# ── 调用 update_account(仅竞争分组的账号)───────────────────────
site_results: list[dict] = [] site_results: list[dict] = []
try: try:
with Client( with Sub2ApiWebsiteClient(
base_url=website.base_url, base_url=website.base_url,
api_prefix=website.api_prefix, api_prefix=website.api_prefix,
auth_type=website.auth_type, auth_type=website.auth_type,
auth_config=json.loads(website.auth_config_json or "{}"), auth_config=json.loads(website.auth_config_json or "{}"),
timeout=float(website.timeout_seconds), timeout=float(website.timeout_seconds),
) as client: ) as client:
for row in all_website_keys: for comp_rows in competitive_buckets.values():
for row in comp_rows:
account_id = row.imported_account_id account_id = row.imported_account_id
if not account_id: new_priority = priority_assignment.get(account_id)
continue
new_priority = priority_map.get(f"{row.upstream_id}:{row.group_id}")
if new_priority is None: if new_priority is None:
site_results.append(
_priority_result(row, None, "skipped", "无倍率数据,跳过")
)
continue continue
try: try:
client.update_account(account_id, {"priority": new_priority}) client.update_account(account_id, {"priority": new_priority})
logger.info( logger.info(
"updated priority for account %s (website=%s, upstream=%s, group=%s): %s", "updated priority for account %s (website=%s, upstream=%s, group=%s"
account_id, wid, row.upstream_id, row.group_id, new_priority, ", comp_group=%s): %s",
account_id, wid, row.upstream_id, row.group_id,
row.imported_target_group_id or row.group_id, new_priority,
) )
site_results.append( site_results.append(
_priority_result(row, new_priority, "success", f"优先级已更新为 {new_priority}") _priority_result(row, new_priority, "success", f"优先级已更新为 {new_priority}")
@@ -406,18 +420,30 @@ def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[
) )
except Exception as exc: except Exception as exc:
logger.warning("failed to connect website %s for account priority sync: %s", wid, exc) logger.warning("failed to connect website %s for account priority sync: %s", wid, exc)
for row in all_website_keys: for comp_rows in competitive_buckets.values():
for row in comp_rows:
site_results.append( site_results.append(
_priority_result(row, None, "failed", f"连接网站失败: {exc}") _priority_result(row, None, "failed", f"连接网站失败: {exc}")
) )
# 只发送有实际成功/失败的通知(不包含单账号跳过项)
notify_updates = [r for r in site_results if r["status"] in ("success", "failed")]
if notify_updates:
# 构建简化的 priority_map 供日志参考(只包含竞争分组的账号)
priority_map_snapshot = {
f"{row.upstream_id}:{row.group_id}": priority_assignment.get(row.imported_account_id)
for comp_rows in competitive_buckets.values()
for row in comp_rows
}
_write_priority_sync_log_with_map(db, wid, upstream_name, site_results, priority_map_snapshot)
_try_send_priority_webhook(db, wid, website.name, upstream_id, upstream_name, notify_updates)
all_results.extend(site_results) all_results.extend(site_results)
_write_priority_sync_log_with_map(db, wid, upstream_name, site_results, priority_map)
_try_send_priority_webhook(db, wid, website.name, upstream_id, upstream_name, site_results)
return all_results return all_results
def _fetch_remote_managed_prefixes(db: Session, upstream_id: int) -> list[str]: def _fetch_remote_managed_prefixes(db: Session, upstream_id: int) -> list[str]:
"""查询本地 distinct managed_prefix。 """查询本地 distinct managed_prefix。
+284 -109
View File
@@ -1,3 +1,14 @@
"""
优先级同步测试套件 — 分组内竞争逻辑
核心规则(测试覆盖):
- 竞争分组键 = imported_target_group_id or group_id(老数据 fallback
- 只有同一竞争分组内账号数 > 1 时才更新 priority / 发通知
- 不同分组各 1 个账号:不调用 update_account,不发通知
- 同一目标分组多账号:组内按倍率升序独立排序,priority 从 1 开始
- 两个目标分组各有多账号:彼此独立,每组内 priority 都从 1 开始
- 老数据 imported_target_group_id=NULLfallback group_id,不报错
"""
import json import json
from datetime import datetime, timezone from datetime import datetime, timezone
import pytest import pytest
@@ -12,10 +23,11 @@ from app.models.upstream_key import UpstreamGeneratedKey
from app.models.snapshot import UpstreamRateSnapshot from app.models.snapshot import UpstreamRateSnapshot
from app.services.website_sync import ( from app.services.website_sync import (
build_rate_priority_map, build_rate_priority_map,
sync_account_priorities_for_upstream sync_account_priorities_for_upstream,
) )
from app.services.website_client import Sub2ApiWebsiteClient from app.services.website_client import Sub2ApiWebsiteClient
@pytest.fixture() @pytest.fixture()
def db_session(): def db_session():
engine = create_engine( engine = create_engine(
@@ -32,117 +44,266 @@ def db_session():
db.close() db.close()
Base.metadata.drop_all(bind=engine) Base.metadata.drop_all(bind=engine)
def test_priority_sync_cross_upstream_group(db_session):
# Setup 2 upstreams
u1 = Upstream(name="U1", base_url="http://u1")
u2 = Upstream(name="U2", base_url="http://u2")
db_session.add_all([u1, u2])
db_session.commit()
db_session.refresh(u1)
db_session.refresh(u2)
# Setup snapshots for both with same group ID "VIP" but different rates # ── 辅助 ─────────────────────────────────────────────────────────────────────
s1 = UpstreamRateSnapshot(
upstream_id=u1.id, def _make_snapshot(db, upstream_id: int, groups: dict):
snapshot_json=json.dumps({"groups": {"VIP": {"group_name": "VIP", "rate": 1.0}}}), """插入快照 groups = {group_id: rate}。"""
captured_at=datetime.now(timezone.utc) snap = {
"groups": {
gid: {"group_name": gid, "rate": rate}
for gid, rate in groups.items()
}
}
db.add(UpstreamRateSnapshot(
upstream_id=upstream_id,
snapshot_json=json.dumps(snap),
captured_at=datetime.now(timezone.utc),
))
db.commit()
def _make_key(db, upstream_id, group_id, key_name, key_value,
website_id, account_id, imported_target_group_id=None):
k = UpstreamGeneratedKey(
upstream_id=upstream_id,
group_id=group_id,
group_name=group_id,
key_name=key_name,
key_value=key_value,
imported_website_id=website_id,
imported_account_id=account_id,
imported_target_group_id=imported_target_group_id,
) )
s2 = UpstreamRateSnapshot( db.add(k)
upstream_id=u2.id, db.commit()
snapshot_json=json.dumps({"groups": {"VIP": {"group_name": "VIP", "rate": 2.0}}}), return k
captured_at=datetime.now(timezone.utc)
)
db_session.add_all([s1, s2])
db_session.commit()
priority_map = build_rate_priority_map(db_session, {u1.id, u2.id})
assert priority_map[f"{u1.id}:VIP"] == 1 class MockClient:
assert priority_map[f"{u2.id}:VIP"] == 2 """可注入的 update_account 记录器。"""
assert len(priority_map) == 2 _calls: list # 由子类/工厂绑定
def test_priority_sync_full_website_update(db_session, monkeypatch): def __init__(self, **kwargs):
# Setup website and upstreams pass # calls 由工厂方法注入
w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}", timeout_seconds=30)
def __enter__(self): return self
def __exit__(self, *a): pass
def update_account(self, account_id, data):
type(self)._shared_calls.append((account_id, data))
def make_mock_client(calls: list):
"""返回一个与 calls 列表绑定的 MockClient 类。"""
class _Client(MockClient):
_shared_calls = calls
return _Client
# ── 用例 ─────────────────────────────────────────────────────────────────────
def test_no_update_when_different_groups_single_account_each(db_session, monkeypatch):
"""不同分组各 1 个账号 → 无竞争 → 不调用 update_account,不发通知。"""
w = Website(name="W1", base_url="http://w1", enabled=True,
auth_config_json="{}", timeout_seconds=30)
u1 = Upstream(name="U1", base_url="http://u1") u1 = Upstream(name="U1", base_url="http://u1")
u2 = Upstream(name="U2", base_url="http://u2") u2 = Upstream(name="U2", base_url="http://u2")
db_session.add_all([w, u1, u2]) db_session.add_all([w, u1, u2])
db_session.commit() db_session.commit()
db_session.refresh(w) db_session.refresh(w); db_session.refresh(u1); db_session.refresh(u2)
db_session.refresh(u1)
db_session.refresh(u2)
# Setup snapshots _make_snapshot(db_session, u1.id, {"G1": 1.0})
db_session.add(UpstreamRateSnapshot( _make_snapshot(db_session, u2.id, {"G2": 2.0})
upstream_id=u1.id,
snapshot_json=json.dumps({"groups": {"G1": {"rate": 1.0}}}),
captured_at=datetime.now(timezone.utc)
))
db_session.add(UpstreamRateSnapshot(
upstream_id=u2.id,
snapshot_json=json.dumps({"groups": {"G2": {"rate": 2.0}}}),
captured_at=datetime.now(timezone.utc)
))
db_session.commit()
# Setup keys imported to website # 不同分组,各 1 个账号,imported_target_group_id=None (fallback group_id)
k1 = UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", key_name="K1", key_value="V1", _make_key(db_session, u1.id, "G1", "K1", "V1", w.id, "A1")
imported_website_id=w.id, imported_account_id="A1") _make_key(db_session, u2.id, "G2", "K2", "V2", w.id, "A2")
k2 = UpstreamGeneratedKey(upstream_id=u2.id, group_id="G2", key_name="K2", key_value="V2",
imported_website_id=w.id, imported_account_id="A2")
db_session.add_all([k1, k2])
db_session.commit()
# Mock Sub2ApiWebsiteClient
update_calls = [] update_calls = []
class MockClient: monkeypatch.setattr(
def __init__(self, **kwargs): pass "app.services.website_sync.Sub2ApiWebsiteClient",
def __enter__(self): return self make_mock_client(update_calls),
def __exit__(self, *args): pass )
def update_account(self, account_id, data):
update_calls.append((account_id, data))
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", MockClient) results = sync_account_priorities_for_upstream(db_session, u1.id)
monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockClient)
# Trigger sync for U1 # 无竞争分组 → 直接返回空列表,不调用 update_account
sync_account_priorities_for_upstream(db_session, u1.id) assert update_calls == [], f"不应有更新调用,实际:{update_calls}"
assert results == []
# 不写日志
logs = db_session.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == w.id).all()
assert logs == []
def test_same_target_group_two_accounts_updated(db_session, monkeypatch):
"""同一目标分组 2 个账号 → 有竞争 → 按倍率更新 priority。"""
w = Website(name="W1", base_url="http://w1", enabled=True,
auth_config_json="{}", timeout_seconds=30)
u1 = Upstream(name="U1", base_url="http://u1")
u2 = Upstream(name="U2", base_url="http://u2")
db_session.add_all([w, u1, u2])
db_session.commit()
db_session.refresh(w); db_session.refresh(u1); db_session.refresh(u2)
_make_snapshot(db_session, u1.id, {"G1": 1.0})
_make_snapshot(db_session, u2.id, {"G2": 2.0})
# 两个账号都属于目标分组 "TG1"
_make_key(db_session, u1.id, "G1", "K1", "V1", w.id, "A1", imported_target_group_id="TG1")
_make_key(db_session, u2.id, "G2", "K2", "V2", w.id, "A2", imported_target_group_id="TG1")
update_calls = []
monkeypatch.setattr(
"app.services.website_sync.Sub2ApiWebsiteClient",
make_mock_client(update_calls),
)
results = sync_account_priorities_for_upstream(db_session, u1.id)
# Verify BOTH A1 and A2 were updated because they belong to the same website
assert len(update_calls) == 2 assert len(update_calls) == 2
account_ids = {c[0] for c in update_calls} priority_map = {aid: data["priority"] for aid, data in update_calls}
assert account_ids == {"A1", "A2"} # G1 rate=1.0 → priority=1(低倍率优先);G2 rate=2.0 → priority=2
assert priority_map["A1"] == 1
assert priority_map["A2"] == 2
# Priority check: G1(1.0) -> 1, G2(2.0) -> 2 # 写了日志
for aid, data in update_calls: log = db_session.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == w.id).first()
if aid == "A1": assert data["priority"] == 1 assert log is not None
if aid == "A2": assert data["priority"] == 2 assert log.algorithm == "priority_sync"
def test_priority_sync_log_structure(db_session, monkeypatch):
w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}", timeout_seconds=30) def test_two_target_groups_independent_priority(db_session, monkeypatch):
"""两个目标分组各有多账号 → 每组内独立从 1 开始排序,不互相影响。"""
w = Website(name="W1", base_url="http://w1", enabled=True,
auth_config_json="{}", timeout_seconds=30)
u1 = Upstream(name="U1", base_url="http://u1") u1 = Upstream(name="U1", base_url="http://u1")
db_session.add_all([w, u1]) db_session.add_all([w, u1])
db_session.commit() db_session.commit()
db_session.refresh(w) db_session.refresh(w); db_session.refresh(u1)
db_session.refresh(u1)
db_session.add(UpstreamRateSnapshot( _make_snapshot(db_session, u1.id, {
upstream_id=u1.id, "G1": 1.0, # → TG1 中排 priority=1
snapshot_json=json.dumps({"groups": {"G1": {"rate": 1.0}}}), "G2": 2.0, # → TG1 中排 priority=2
captured_at=datetime.now(timezone.utc) "G3": 0.5, # → TG2 中排 priority=1
)) "G4": 3.0, # → TG2 中排 priority=2
db_session.add(UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", key_name="K1", key_value="V1", })
imported_website_id=w.id, imported_account_id="A1"))
# TG1: G1(1.0), G2(2.0)
_make_key(db_session, u1.id, "G1", "K1", "V1", w.id, "A1", imported_target_group_id="TG1")
_make_key(db_session, u1.id, "G2", "K2", "V2", w.id, "A2", imported_target_group_id="TG1")
# TG2: G3(0.5), G4(3.0)
_make_key(db_session, u1.id, "G3", "K3", "V3", w.id, "A3", imported_target_group_id="TG2")
_make_key(db_session, u1.id, "G4", "K4", "V4", w.id, "A4", imported_target_group_id="TG2")
update_calls = []
monkeypatch.setattr(
"app.services.website_sync.Sub2ApiWebsiteClient",
make_mock_client(update_calls),
)
sync_account_priorities_for_upstream(db_session, u1.id)
priority_map = {aid: data["priority"] for aid, data in update_calls}
assert len(update_calls) == 4
# TG1 内部:G1(1.0)→p1, G2(2.0)→p2
assert priority_map["A1"] == 1
assert priority_map["A2"] == 2
# TG2 内部:G3(0.5)→p1, G4(3.0)→p2(独立从 1 开始)
assert priority_map["A3"] == 1
assert priority_map["A4"] == 2
def test_old_data_null_target_group_fallback(db_session, monkeypatch):
"""老数据 imported_target_group_id=NULL → fallback group_id,不报错。
两个账号同 group_id(极端边界),视为同一竞争分组 → 有竞争 → 更新。
"""
w = Website(name="W1", base_url="http://w1", enabled=True,
auth_config_json="{}", timeout_seconds=30)
u1 = Upstream(name="U1", base_url="http://u1")
u2 = Upstream(name="U2", base_url="http://u2")
db_session.add_all([w, u1, u2])
db_session.commit() db_session.commit()
db_session.refresh(w); db_session.refresh(u1); db_session.refresh(u2)
class MockClient: _make_snapshot(db_session, u1.id, {"GSHARED": 1.0})
_make_snapshot(db_session, u2.id, {"GSHARED": 2.0})
# 老数据:imported_target_group_id=None,两账号 group_id 相同
_make_key(db_session, u1.id, "GSHARED", "K1", "V1", w.id, "A1", imported_target_group_id=None)
_make_key(db_session, u2.id, "GSHARED", "K2", "V2", w.id, "A2", imported_target_group_id=None)
update_calls = []
monkeypatch.setattr(
"app.services.website_sync.Sub2ApiWebsiteClient",
make_mock_client(update_calls),
)
# 不应抛异常
results = sync_account_priorities_for_upstream(db_session, u1.id)
# 同 group_id="GSHARED" → 竞争分组,两账号都更新
assert len(update_calls) == 2
def test_single_account_in_mixed_website(db_session, monkeypatch):
"""同一网站:一个目标分组有 2 账号(竞争),另一个目标分组只有 1 账号(不参与)。
只有竞争分组的 2 个账号被更新,单账号分组不调用 update_account。
"""
w = Website(name="W1", base_url="http://w1", enabled=True,
auth_config_json="{}", timeout_seconds=30)
u1 = Upstream(name="U1", base_url="http://u1")
db_session.add_all([w, u1])
db_session.commit()
db_session.refresh(w); db_session.refresh(u1)
_make_snapshot(db_session, u1.id, {"G1": 1.0, "G2": 2.0, "G3": 3.0})
# TG1: G1 + G2 → 竞争
_make_key(db_session, u1.id, "G1", "K1", "V1", w.id, "A1", imported_target_group_id="TG1")
_make_key(db_session, u1.id, "G2", "K2", "V2", w.id, "A2", imported_target_group_id="TG1")
# TG2: 只有 G3 → 单账号,跳过
_make_key(db_session, u1.id, "G3", "K3", "V3", w.id, "A3", imported_target_group_id="TG2")
update_calls = []
monkeypatch.setattr(
"app.services.website_sync.Sub2ApiWebsiteClient",
make_mock_client(update_calls),
)
sync_account_priorities_for_upstream(db_session, u1.id)
updated_ids = {c[0] for c in update_calls}
assert "A3" not in updated_ids, "单账号分组不应被更新"
assert updated_ids == {"A1", "A2"}
def test_priority_sync_log_structure(db_session, monkeypatch):
"""日志写入格式验证(需要竞争分组才会写日志)。"""
w = Website(name="W1", base_url="http://w1", enabled=True,
auth_config_json="{}", timeout_seconds=30)
u1 = Upstream(name="U1", base_url="http://u1")
db_session.add_all([w, u1])
db_session.commit()
db_session.refresh(w); db_session.refresh(u1)
_make_snapshot(db_session, u1.id, {"G1": 1.0, "G2": 2.0})
# 同一目标分组 2 账号 → 有日志
_make_key(db_session, u1.id, "G1", "K1", "V1", w.id, "A1", imported_target_group_id="TG1")
_make_key(db_session, u1.id, "G2", "K2", "V2", w.id, "A2", imported_target_group_id="TG1")
class LoggingMockClient:
def __init__(self, **kwargs): pass def __init__(self, **kwargs): pass
def __enter__(self): return self def __enter__(self): return self
def __exit__(self, *args): pass def __exit__(self, *a): pass
def update_account(self, account_id, data): pass def update_account(self, account_id, data): pass
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", MockClient) monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", LoggingMockClient)
monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockClient)
sync_account_priorities_for_upstream(db_session, u1.id) sync_account_priorities_for_upstream(db_session, u1.id)
@@ -151,13 +312,35 @@ def test_priority_sync_log_structure(db_session, monkeypatch):
assert log.algorithm == "priority_sync" assert log.algorithm == "priority_sync"
data = json.loads(log.source_rates_json) data = json.loads(log.source_rates_json)
# The first item should be the priority map metadata # 第一项是 priority_map 元数据
assert data[0]["_meta"] == "priority_map" assert data[0]["_meta"] == "priority_map"
assert f"{u1.id}:G1" in data[0]["data"] # 后续项是账号结果
# The second item should be the account result account_ids = {item["account_id"] for item in data[1:] if "account_id" in item}
assert data[1]["account_id"] == "A1" assert {"A1", "A2"} == account_ids
# ── 保留:build_rate_priority_map 单元测试(供初始导入使用)────────────────
def test_priority_sync_cross_upstream_group(db_session):
"""build_rate_priority_map:相同 group_id 不同上游不互相覆盖。"""
u1 = Upstream(name="U1", base_url="http://u1")
u2 = Upstream(name="U2", base_url="http://u2")
db_session.add_all([u1, u2])
db_session.commit()
db_session.refresh(u1); db_session.refresh(u2)
_make_snapshot(db_session, u1.id, {"VIP": 1.0})
_make_snapshot(db_session, u2.id, {"VIP": 2.0})
priority_map = build_rate_priority_map(db_session, {u1.id, u2.id})
assert priority_map[f"{u1.id}:VIP"] == 1
assert priority_map[f"{u2.id}:VIP"] == 2
assert len(priority_map) == 2
def test_import_auto_priority_by_rate(db_session, monkeypatch): def test_import_auto_priority_by_rate(db_session, monkeypatch):
"""初始导入时 auto_priority_by_rate=True 按全局倍率分配 priority。"""
from app.routers.websites import import_upstream_keys_as_accounts from app.routers.websites import import_upstream_keys_as_accounts
from app.schemas.website import ImportAccountsRequest from app.schemas.website import ImportAccountsRequest
@@ -167,39 +350,32 @@ def test_import_auto_priority_by_rate(db_session, monkeypatch):
u2 = Upstream(name="U2", base_url="http://u2") u2 = Upstream(name="U2", base_url="http://u2")
db_session.add_all([w, u1, u2]) db_session.add_all([w, u1, u2])
db_session.commit() db_session.commit()
db_session.refresh(w) db_session.refresh(w); db_session.refresh(u1); db_session.refresh(u2)
db_session.refresh(u1)
db_session.refresh(u2)
db_session.add(UpstreamRateSnapshot( _make_snapshot(db_session, u1.id, {"G1": 2.0})
upstream_id=u1.id, _make_snapshot(db_session, u2.id, {"G2": 1.0})
snapshot_json=json.dumps({"groups": {"G1": {"rate": 2.0}}}),
captured_at=datetime.now(timezone.utc)
))
db_session.add(UpstreamRateSnapshot(
upstream_id=u2.id,
snapshot_json=json.dumps({"groups": {"G2": {"rate": 1.0}}}),
captured_at=datetime.now(timezone.utc)
))
k1 = UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", group_name="G1", key_name="K1", key_value="V1") k1 = UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", group_name="G1",
k2 = UpstreamGeneratedKey(upstream_id=u2.id, group_id="G2", group_name="G2", key_name="K2", key_value="V2") key_name="K1", key_value="V1")
k2 = UpstreamGeneratedKey(upstream_id=u2.id, group_id="G2", group_name="G2",
key_name="K2", key_value="V2")
db_session.add_all([k1, k2]) db_session.add_all([k1, k2])
db_session.commit() db_session.commit()
created_accounts = [] created_accounts = []
class MockClient:
class MockImportClient:
def __init__(self, **kwargs): pass def __init__(self, **kwargs): pass
def __enter__(self): return self def __enter__(self): return self
def __exit__(self, *args): pass def __exit__(self, *a): pass
def create_account(self, body): def create_account(self, body):
created_accounts.append(body) created_accounts.append(body)
return {"id": f"remote-{len(created_accounts)}", "name": body["name"]} return {"id": f"remote-{len(created_accounts)}", "name": body["name"]}
def extract_id(self, data): return data["id"] def extract_id(self, data): return data["id"]
def account_exists(self, aid): return False def account_exists(self, aid): return False
monkeypatch.setattr("app.routers.websites._client", lambda website: MockClient()) monkeypatch.setattr("app.routers.websites._client", lambda website: MockImportClient())
monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockClient) monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockImportClient)
req = ImportAccountsRequest( req = ImportAccountsRequest(
upstream_key_ids=[k1.id, k2.id], upstream_key_ids=[k1.id, k2.id],
@@ -207,14 +383,13 @@ def test_import_auto_priority_by_rate(db_session, monkeypatch):
auto_priority_by_rate=True, auto_priority_by_rate=True,
priority=10, priority=10,
account_name_prefix="test", account_name_prefix="test",
default_platform="openai" default_platform="openai",
) )
import_upstream_keys_as_accounts(w.id, req, db_session) import_upstream_keys_as_accounts(w.id, req, db_session)
assert len(created_accounts) == 2 assert len(created_accounts) == 2
# G2 has rate 1.0 -> priority 1 # G2 rate=1.0 priority 1G1 rate=2.0 → priority 2
# G1 has rate 2.0 -> priority 2
p1 = next(a["priority"] for a in created_accounts if "G1" in a["name"]) p1 = next(a["priority"] for a in created_accounts if "G1" in a["name"])
p2 = next(a["priority"] for a in created_accounts if "G2" in a["name"]) p2 = next(a["priority"] for a in created_accounts if "G2" in a["name"])
assert p2 == 1 assert p2 == 1