diff --git a/backend/app/database.py b/backend/app/database.py index ae18562..73d151d 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -150,6 +150,10 @@ def _migrate_upstream_generated_keys(): conn.execute(text("UPDATE upstream_generated_keys SET updated_at = created_at WHERE updated_at IS NULL")) if "managed_prefix" not in columns: conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN managed_prefix VARCHAR(64)")) + if "imported_target_group_id" not in columns: + conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_target_group_id VARCHAR(255)")) + if "imported_target_group_name" not in columns: + conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_target_group_name VARCHAR(255)")) # ——— 历史数据迁移:回填 managed_prefix + 清理重复 ——— with engine.begin() as conn: diff --git a/backend/app/models/upstream_key.py b/backend/app/models/upstream_key.py index 2d414a2..6e47cb9 100644 --- a/backend/app/models/upstream_key.py +++ b/backend/app/models/upstream_key.py @@ -25,6 +25,8 @@ class UpstreamGeneratedKey(Base): imported_website_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey("websites.id", ondelete="SET NULL"), nullable=True, index=True) imported_account_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) imported_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + imported_target_group_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + imported_target_group_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc)) updated_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) diff --git a/backend/app/routers/websites.py b/backend/app/routers/websites.py index f186fe6..de5234e 100644 --- a/backend/app/routers/websites.py +++ b/backend/app/routers/websites.py @@ -558,16 +558,6 @@ def import_upstream_keys_as_accounts( old_account_id = row.imported_account_id exists = c.account_exists(row.imported_account_id) if exists is True: - # 自动更新已有账号的 priority(分步导入时全局倍率排序可能已变) - new_priority = rate_priority_map.get(f"{row.upstream_id}:{row.group_id}") if body.auto_priority_by_rate else None - priority_msg = "已导入过,已跳过" - if new_priority is not None: - try: - c.update_account(old_account_id, {"priority": new_priority}) - priority_msg = f"已导入过,优先级已更新为 {new_priority}" - except Exception as exc: - logger.warning("update priority failed account=%s: %s", old_account_id, exc) - priority_msg = f"已导入过,优先级更新失败: {exc}" items.append(ImportAccountItem( upstream_key_id=row.id, source_group_id=row.group_id, @@ -578,7 +568,7 @@ def import_upstream_keys_as_accounts( platform=platform, upstream_base_url=upstream_base_url, status="exists", - message=priority_msg, + message="已导入过,已跳过", )) continue elif exists is False: @@ -639,6 +629,8 @@ def import_upstream_keys_as_accounts( row.imported_website_id = wid row.imported_account_id = account_id or None row.imported_at = datetime.now(timezone.utc) + row.imported_target_group_id = target_group_id or None + row.imported_target_group_name = None # target_group_map 只存 ID,name 展示用可留 NULL row.status = "imported" row.error = None db.commit() diff --git a/backend/app/services/website_sync.py b/backend/app/services/website_sync.py index b066610..a02320f 100644 --- a/backend/app/services/website_sync.py +++ b/backend/app/services/website_sync.py @@ -288,12 +288,13 @@ def _try_send_priority_webhook( def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[dict]: """上游倍率变化后,自动更新已导入下游账号的 priority。 - 查询该上游下所有已导入(非 orphaned)的 Key,按目标网站分组后重新计算全局优先级, - 并通过 update_account API 推送到下游网站。返回详细结果列表。 - - 同时写入 WebsiteSyncLog 持久化审计日志,并通过 webhook 发送通知。 + 只处理同一目标分组内有多个账号(存在竞争)的情况: + - 竞争分组键:imported_target_group_id(老数据 fallback 到 group_id) + - 同一竞争分组内按倍率升序排序,priority 从 1 开始(相同倍率共享) + - 单账号分组:完全跳过,不调用 update_account,不发通知 + - 无竞争分组:直接返回,不写日志,不发通知 """ - from app.services.website_client import Sub2ApiWebsiteClient as Client + from collections import defaultdict key_rows = ( db.query(UpstreamGeneratedKey) @@ -314,9 +315,7 @@ def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[ website_groups: dict[int, list[UpstreamGeneratedKey]] = {} for row in key_rows: wid = row.imported_website_id - if wid not in website_groups: - website_groups[wid] = [] - website_groups[wid].append(row) + website_groups.setdefault(wid, []).append(row) all_results: list[dict] = [] @@ -324,16 +323,10 @@ def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[ website = db.query(Website).filter(Website.id == wid).first() if not website or not website.enabled: logger.info("skip account priority sync: website %s not found or disabled", wid) - site_results = [] - for row in rows: - r = _priority_result(row, None, "failed", "网站不可用") - site_results.append(r) - all_results.append(r) - _write_priority_sync_log_with_map(db, wid, upstream_name, site_results, {}) - _try_send_priority_webhook(db, wid, "", upstream_id, upstream_name, site_results) + # 不写日志、不发通知:网站不可用时账号无法更新,沉默跳过 continue - # 查询该网站所有已导入 Key(跨上游),实现全局优先级排序 + # 查询该网站所有已导入 Key(跨上游),用于倍率查询 all_website_keys = ( db.query(UpstreamGeneratedKey) .filter( @@ -343,81 +336,114 @@ def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[ ) .all() ) + + # ── 按竞争分组分桶 ────────────────────────────────────────────────── + # 竞争分组键:imported_target_group_id(老数据为 NULL 时 fallback 到 group_id) + buckets: dict[str, list[UpstreamGeneratedKey]] = defaultdict(list) + for row in all_website_keys: + comp_key = row.imported_target_group_id or row.group_id + buckets[comp_key].append(row) + + # 只保留账号数 > 1 的分组(有竞争才需要排序) + competitive_buckets = {k: v for k, v in buckets.items() if len(v) > 1} + + if not competitive_buckets: + logger.info( + "skip account priority sync for website %s: no competitive groups (all single-account)", + wid, + ) + continue # 不写日志,不发通知 + + # ── 预取快照倍率 ──────────────────────────────────────────────────── all_upstream_ids = {k.upstream_id for k in all_website_keys} try: - priority_map = build_rate_priority_map(db, all_upstream_ids) + # 构建 "{upstream_id}:{group_id}" → rate 查询表 + raw_rate_map: dict[str, float] = {} + for uid in all_upstream_ids: + groups = latest_rate_map(db, uid) + for gid, g in groups.items(): + if isinstance(g, dict): + raw_rate_map[f"{uid}:{gid}"] = _snapshot_group_rate(g) except Exception as exc: - logger.warning("build_rate_priority_map failed for website %s: %s", wid, exc) - site_results = [] - for row in all_website_keys: - r = _priority_result(row, None, "failed", f"构建优先级映射失败: {exc}") - site_results.append(r) - all_results.append(r) - _write_priority_sync_log_with_map(db, wid, upstream_name, site_results, {}) - _try_send_priority_webhook(db, wid, "", upstream_id, upstream_name, site_results) - continue + logger.warning("build rate map failed for website %s: %s", wid, exc) + raw_rate_map = {} - if not priority_map: - logger.info("skip account priority sync for website %s: empty priority map", wid) - site_results = [] - for row in all_website_keys: - r = _priority_result(row, None, "skipped", "无上游倍率数据") - site_results.append(r) - all_results.append(r) - _write_priority_sync_log_with_map(db, wid, upstream_name, site_results, {}) - _try_send_priority_webhook(db, wid, "", upstream_id, upstream_name, site_results) - continue + # ── 每个竞争分组内独立计算 priority ──────────────────────────────── + # priority_assignment: account_id → new_priority + priority_assignment: dict[str, int] = {} + for comp_key, comp_rows in competitive_buckets.items(): + # 取每行的倍率(查不到则 fallback 1.0) + rated = [ + (row, raw_rate_map.get(f"{row.upstream_id}:{row.group_id}", 1.0)) + for row in comp_rows + ] + # 组内按倍率升序排序(倍率低 → priority 小 → 优先) + unique_rates = sorted(set(r for _, r in rated)) + rate_to_prio = {rate: idx + 1 for idx, rate in enumerate(unique_rates)} + for row, rate in rated: + priority_assignment[row.imported_account_id] = rate_to_prio[rate] + # ── 调用 update_account(仅竞争分组的账号)─────────────────────── site_results: list[dict] = [] try: - with Client( + with Sub2ApiWebsiteClient( base_url=website.base_url, api_prefix=website.api_prefix, auth_type=website.auth_type, auth_config=json.loads(website.auth_config_json or "{}"), timeout=float(website.timeout_seconds), ) as client: - for row in all_website_keys: - account_id = row.imported_account_id - if not account_id: - continue - new_priority = priority_map.get(f"{row.upstream_id}:{row.group_id}") - if new_priority is None: - site_results.append( - _priority_result(row, None, "skipped", "无倍率数据,跳过") - ) - continue - try: - client.update_account(account_id, {"priority": new_priority}) - logger.info( - "updated priority for account %s (website=%s, upstream=%s, group=%s): %s", - account_id, wid, row.upstream_id, row.group_id, new_priority, - ) - site_results.append( - _priority_result(row, new_priority, "success", f"优先级已更新为 {new_priority}") - ) - except Exception as exc: - logger.warning( - "failed to update priority for account %s (website=%s): %s", - account_id, wid, exc, - ) - site_results.append( - _priority_result(row, new_priority, "failed", str(exc)) - ) + for comp_rows in competitive_buckets.values(): + for row in comp_rows: + account_id = row.imported_account_id + new_priority = priority_assignment.get(account_id) + if new_priority is None: + continue + try: + client.update_account(account_id, {"priority": new_priority}) + logger.info( + "updated priority for account %s (website=%s, upstream=%s, group=%s" + ", comp_group=%s): %s", + account_id, wid, row.upstream_id, row.group_id, + row.imported_target_group_id or row.group_id, new_priority, + ) + site_results.append( + _priority_result(row, new_priority, "success", f"优先级已更新为 {new_priority}") + ) + except Exception as exc: + logger.warning( + "failed to update priority for account %s (website=%s): %s", + account_id, wid, exc, + ) + site_results.append( + _priority_result(row, new_priority, "failed", str(exc)) + ) except Exception as exc: logger.warning("failed to connect website %s for account priority sync: %s", wid, exc) - for row in all_website_keys: - site_results.append( - _priority_result(row, None, "failed", f"连接网站失败: {exc}") - ) + for comp_rows in competitive_buckets.values(): + for row in comp_rows: + site_results.append( + _priority_result(row, None, "failed", f"连接网站失败: {exc}") + ) + + # 只发送有实际成功/失败的通知(不包含单账号跳过项) + notify_updates = [r for r in site_results if r["status"] in ("success", "failed")] + if notify_updates: + # 构建简化的 priority_map 供日志参考(只包含竞争分组的账号) + priority_map_snapshot = { + f"{row.upstream_id}:{row.group_id}": priority_assignment.get(row.imported_account_id) + for comp_rows in competitive_buckets.values() + for row in comp_rows + } + _write_priority_sync_log_with_map(db, wid, upstream_name, site_results, priority_map_snapshot) + _try_send_priority_webhook(db, wid, website.name, upstream_id, upstream_name, notify_updates) all_results.extend(site_results) - _write_priority_sync_log_with_map(db, wid, upstream_name, site_results, priority_map) - _try_send_priority_webhook(db, wid, website.name, upstream_id, upstream_name, site_results) return all_results + def _fetch_remote_managed_prefixes(db: Session, upstream_id: int) -> list[str]: """查询本地 distinct managed_prefix。 diff --git a/backend/test_priority_sync.py b/backend/test_priority_sync.py index 833a573..931133f 100644 --- a/backend/test_priority_sync.py +++ b/backend/test_priority_sync.py @@ -1,3 +1,14 @@ +""" +优先级同步测试套件 — 分组内竞争逻辑 + +核心规则(测试覆盖): +- 竞争分组键 = imported_target_group_id or group_id(老数据 fallback) +- 只有同一竞争分组内账号数 > 1 时才更新 priority / 发通知 +- 不同分组各 1 个账号:不调用 update_account,不发通知 +- 同一目标分组多账号:组内按倍率升序独立排序,priority 从 1 开始 +- 两个目标分组各有多账号:彼此独立,每组内 priority 都从 1 开始 +- 老数据 imported_target_group_id=NULL:fallback group_id,不报错 +""" import json from datetime import datetime, timezone import pytest @@ -12,10 +23,11 @@ from app.models.upstream_key import UpstreamGeneratedKey from app.models.snapshot import UpstreamRateSnapshot from app.services.website_sync import ( build_rate_priority_map, - sync_account_priorities_for_upstream + sync_account_priorities_for_upstream, ) from app.services.website_client import Sub2ApiWebsiteClient + @pytest.fixture() def db_session(): engine = create_engine( @@ -32,28 +44,293 @@ def db_session(): db.close() Base.metadata.drop_all(bind=engine) + +# ── 辅助 ───────────────────────────────────────────────────────────────────── + +def _make_snapshot(db, upstream_id: int, groups: dict): + """插入快照 groups = {group_id: rate}。""" + snap = { + "groups": { + gid: {"group_name": gid, "rate": rate} + for gid, rate in groups.items() + } + } + db.add(UpstreamRateSnapshot( + upstream_id=upstream_id, + snapshot_json=json.dumps(snap), + captured_at=datetime.now(timezone.utc), + )) + db.commit() + + +def _make_key(db, upstream_id, group_id, key_name, key_value, + website_id, account_id, imported_target_group_id=None): + k = UpstreamGeneratedKey( + upstream_id=upstream_id, + group_id=group_id, + group_name=group_id, + key_name=key_name, + key_value=key_value, + imported_website_id=website_id, + imported_account_id=account_id, + imported_target_group_id=imported_target_group_id, + ) + db.add(k) + db.commit() + return k + + +class MockClient: + """可注入的 update_account 记录器。""" + _calls: list # 由子类/工厂绑定 + + def __init__(self, **kwargs): + pass # calls 由工厂方法注入 + + def __enter__(self): return self + def __exit__(self, *a): pass + + def update_account(self, account_id, data): + type(self)._shared_calls.append((account_id, data)) + + +def make_mock_client(calls: list): + """返回一个与 calls 列表绑定的 MockClient 类。""" + class _Client(MockClient): + _shared_calls = calls + return _Client + + +# ── 用例 ───────────────────────────────────────────────────────────────────── + +def test_no_update_when_different_groups_single_account_each(db_session, monkeypatch): + """不同分组各 1 个账号 → 无竞争 → 不调用 update_account,不发通知。""" + w = Website(name="W1", base_url="http://w1", enabled=True, + auth_config_json="{}", timeout_seconds=30) + u1 = Upstream(name="U1", base_url="http://u1") + u2 = Upstream(name="U2", base_url="http://u2") + db_session.add_all([w, u1, u2]) + db_session.commit() + db_session.refresh(w); db_session.refresh(u1); db_session.refresh(u2) + + _make_snapshot(db_session, u1.id, {"G1": 1.0}) + _make_snapshot(db_session, u2.id, {"G2": 2.0}) + + # 不同分组,各 1 个账号,imported_target_group_id=None (fallback group_id) + _make_key(db_session, u1.id, "G1", "K1", "V1", w.id, "A1") + _make_key(db_session, u2.id, "G2", "K2", "V2", w.id, "A2") + + update_calls = [] + monkeypatch.setattr( + "app.services.website_sync.Sub2ApiWebsiteClient", + make_mock_client(update_calls), + ) + + results = sync_account_priorities_for_upstream(db_session, u1.id) + + # 无竞争分组 → 直接返回空列表,不调用 update_account + assert update_calls == [], f"不应有更新调用,实际:{update_calls}" + assert results == [] + + # 不写日志 + logs = db_session.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == w.id).all() + assert logs == [] + + +def test_same_target_group_two_accounts_updated(db_session, monkeypatch): + """同一目标分组 2 个账号 → 有竞争 → 按倍率更新 priority。""" + w = Website(name="W1", base_url="http://w1", enabled=True, + auth_config_json="{}", timeout_seconds=30) + u1 = Upstream(name="U1", base_url="http://u1") + u2 = Upstream(name="U2", base_url="http://u2") + db_session.add_all([w, u1, u2]) + db_session.commit() + db_session.refresh(w); db_session.refresh(u1); db_session.refresh(u2) + + _make_snapshot(db_session, u1.id, {"G1": 1.0}) + _make_snapshot(db_session, u2.id, {"G2": 2.0}) + + # 两个账号都属于目标分组 "TG1" + _make_key(db_session, u1.id, "G1", "K1", "V1", w.id, "A1", imported_target_group_id="TG1") + _make_key(db_session, u2.id, "G2", "K2", "V2", w.id, "A2", imported_target_group_id="TG1") + + update_calls = [] + monkeypatch.setattr( + "app.services.website_sync.Sub2ApiWebsiteClient", + make_mock_client(update_calls), + ) + + results = sync_account_priorities_for_upstream(db_session, u1.id) + + assert len(update_calls) == 2 + priority_map = {aid: data["priority"] for aid, data in update_calls} + # G1 rate=1.0 → priority=1(低倍率优先);G2 rate=2.0 → priority=2 + assert priority_map["A1"] == 1 + assert priority_map["A2"] == 2 + + # 写了日志 + log = db_session.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == w.id).first() + assert log is not None + assert log.algorithm == "priority_sync" + + +def test_two_target_groups_independent_priority(db_session, monkeypatch): + """两个目标分组各有多账号 → 每组内独立从 1 开始排序,不互相影响。""" + w = Website(name="W1", base_url="http://w1", enabled=True, + auth_config_json="{}", timeout_seconds=30) + u1 = Upstream(name="U1", base_url="http://u1") + db_session.add_all([w, u1]) + db_session.commit() + db_session.refresh(w); db_session.refresh(u1) + + _make_snapshot(db_session, u1.id, { + "G1": 1.0, # → TG1 中排 priority=1 + "G2": 2.0, # → TG1 中排 priority=2 + "G3": 0.5, # → TG2 中排 priority=1 + "G4": 3.0, # → TG2 中排 priority=2 + }) + + # TG1: G1(1.0), G2(2.0) + _make_key(db_session, u1.id, "G1", "K1", "V1", w.id, "A1", imported_target_group_id="TG1") + _make_key(db_session, u1.id, "G2", "K2", "V2", w.id, "A2", imported_target_group_id="TG1") + # TG2: G3(0.5), G4(3.0) + _make_key(db_session, u1.id, "G3", "K3", "V3", w.id, "A3", imported_target_group_id="TG2") + _make_key(db_session, u1.id, "G4", "K4", "V4", w.id, "A4", imported_target_group_id="TG2") + + update_calls = [] + monkeypatch.setattr( + "app.services.website_sync.Sub2ApiWebsiteClient", + make_mock_client(update_calls), + ) + + sync_account_priorities_for_upstream(db_session, u1.id) + + priority_map = {aid: data["priority"] for aid, data in update_calls} + assert len(update_calls) == 4 + + # TG1 内部:G1(1.0)→p1, G2(2.0)→p2 + assert priority_map["A1"] == 1 + assert priority_map["A2"] == 2 + # TG2 内部:G3(0.5)→p1, G4(3.0)→p2(独立从 1 开始) + assert priority_map["A3"] == 1 + assert priority_map["A4"] == 2 + + +def test_old_data_null_target_group_fallback(db_session, monkeypatch): + """老数据 imported_target_group_id=NULL → fallback group_id,不报错。 + + 两个账号同 group_id(极端边界),视为同一竞争分组 → 有竞争 → 更新。 + """ + w = Website(name="W1", base_url="http://w1", enabled=True, + auth_config_json="{}", timeout_seconds=30) + u1 = Upstream(name="U1", base_url="http://u1") + u2 = Upstream(name="U2", base_url="http://u2") + db_session.add_all([w, u1, u2]) + db_session.commit() + db_session.refresh(w); db_session.refresh(u1); db_session.refresh(u2) + + _make_snapshot(db_session, u1.id, {"GSHARED": 1.0}) + _make_snapshot(db_session, u2.id, {"GSHARED": 2.0}) + + # 老数据:imported_target_group_id=None,两账号 group_id 相同 + _make_key(db_session, u1.id, "GSHARED", "K1", "V1", w.id, "A1", imported_target_group_id=None) + _make_key(db_session, u2.id, "GSHARED", "K2", "V2", w.id, "A2", imported_target_group_id=None) + + update_calls = [] + monkeypatch.setattr( + "app.services.website_sync.Sub2ApiWebsiteClient", + make_mock_client(update_calls), + ) + + # 不应抛异常 + results = sync_account_priorities_for_upstream(db_session, u1.id) + + # 同 group_id="GSHARED" → 竞争分组,两账号都更新 + assert len(update_calls) == 2 + + +def test_single_account_in_mixed_website(db_session, monkeypatch): + """同一网站:一个目标分组有 2 账号(竞争),另一个目标分组只有 1 账号(不参与)。 + + 只有竞争分组的 2 个账号被更新,单账号分组不调用 update_account。 + """ + w = Website(name="W1", base_url="http://w1", enabled=True, + auth_config_json="{}", timeout_seconds=30) + u1 = Upstream(name="U1", base_url="http://u1") + db_session.add_all([w, u1]) + db_session.commit() + db_session.refresh(w); db_session.refresh(u1) + + _make_snapshot(db_session, u1.id, {"G1": 1.0, "G2": 2.0, "G3": 3.0}) + + # TG1: G1 + G2 → 竞争 + _make_key(db_session, u1.id, "G1", "K1", "V1", w.id, "A1", imported_target_group_id="TG1") + _make_key(db_session, u1.id, "G2", "K2", "V2", w.id, "A2", imported_target_group_id="TG1") + # TG2: 只有 G3 → 单账号,跳过 + _make_key(db_session, u1.id, "G3", "K3", "V3", w.id, "A3", imported_target_group_id="TG2") + + update_calls = [] + monkeypatch.setattr( + "app.services.website_sync.Sub2ApiWebsiteClient", + make_mock_client(update_calls), + ) + + sync_account_priorities_for_upstream(db_session, u1.id) + + updated_ids = {c[0] for c in update_calls} + assert "A3" not in updated_ids, "单账号分组不应被更新" + assert updated_ids == {"A1", "A2"} + + +def test_priority_sync_log_structure(db_session, monkeypatch): + """日志写入格式验证(需要竞争分组才会写日志)。""" + w = Website(name="W1", base_url="http://w1", enabled=True, + auth_config_json="{}", timeout_seconds=30) + u1 = Upstream(name="U1", base_url="http://u1") + db_session.add_all([w, u1]) + db_session.commit() + db_session.refresh(w); db_session.refresh(u1) + + _make_snapshot(db_session, u1.id, {"G1": 1.0, "G2": 2.0}) + + # 同一目标分组 2 账号 → 有日志 + _make_key(db_session, u1.id, "G1", "K1", "V1", w.id, "A1", imported_target_group_id="TG1") + _make_key(db_session, u1.id, "G2", "K2", "V2", w.id, "A2", imported_target_group_id="TG1") + + class LoggingMockClient: + def __init__(self, **kwargs): pass + def __enter__(self): return self + def __exit__(self, *a): pass + def update_account(self, account_id, data): pass + + monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", LoggingMockClient) + + sync_account_priorities_for_upstream(db_session, u1.id) + + log = db_session.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == w.id).first() + assert log is not None + assert log.algorithm == "priority_sync" + + data = json.loads(log.source_rates_json) + # 第一项是 priority_map 元数据 + assert data[0]["_meta"] == "priority_map" + # 后续项是账号结果 + account_ids = {item["account_id"] for item in data[1:] if "account_id" in item} + assert {"A1", "A2"} == account_ids + + +# ── 保留:build_rate_priority_map 单元测试(供初始导入使用)──────────────── + def test_priority_sync_cross_upstream_group(db_session): - # Setup 2 upstreams + """build_rate_priority_map:相同 group_id 不同上游不互相覆盖。""" u1 = Upstream(name="U1", base_url="http://u1") u2 = Upstream(name="U2", base_url="http://u2") db_session.add_all([u1, u2]) db_session.commit() - db_session.refresh(u1) - db_session.refresh(u2) + db_session.refresh(u1); db_session.refresh(u2) - # Setup snapshots for both with same group ID "VIP" but different rates - s1 = UpstreamRateSnapshot( - upstream_id=u1.id, - snapshot_json=json.dumps({"groups": {"VIP": {"group_name": "VIP", "rate": 1.0}}}), - captured_at=datetime.now(timezone.utc) - ) - s2 = UpstreamRateSnapshot( - upstream_id=u2.id, - snapshot_json=json.dumps({"groups": {"VIP": {"group_name": "VIP", "rate": 2.0}}}), - captured_at=datetime.now(timezone.utc) - ) - db_session.add_all([s1, s2]) - db_session.commit() + _make_snapshot(db_session, u1.id, {"VIP": 1.0}) + _make_snapshot(db_session, u2.id, {"VIP": 2.0}) priority_map = build_rate_priority_map(db_session, {u1.id, u2.id}) @@ -61,145 +338,44 @@ def test_priority_sync_cross_upstream_group(db_session): assert priority_map[f"{u2.id}:VIP"] == 2 assert len(priority_map) == 2 -def test_priority_sync_full_website_update(db_session, monkeypatch): - # Setup website and upstreams - w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}", timeout_seconds=30) - u1 = Upstream(name="U1", base_url="http://u1") - u2 = Upstream(name="U2", base_url="http://u2") - db_session.add_all([w, u1, u2]) - db_session.commit() - db_session.refresh(w) - db_session.refresh(u1) - db_session.refresh(u2) - - # Setup snapshots - db_session.add(UpstreamRateSnapshot( - upstream_id=u1.id, - snapshot_json=json.dumps({"groups": {"G1": {"rate": 1.0}}}), - captured_at=datetime.now(timezone.utc) - )) - db_session.add(UpstreamRateSnapshot( - upstream_id=u2.id, - snapshot_json=json.dumps({"groups": {"G2": {"rate": 2.0}}}), - captured_at=datetime.now(timezone.utc) - )) - db_session.commit() - - # Setup keys imported to website - k1 = UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", key_name="K1", key_value="V1", - imported_website_id=w.id, imported_account_id="A1") - k2 = UpstreamGeneratedKey(upstream_id=u2.id, group_id="G2", key_name="K2", key_value="V2", - imported_website_id=w.id, imported_account_id="A2") - db_session.add_all([k1, k2]) - db_session.commit() - - # Mock Sub2ApiWebsiteClient - update_calls = [] - class MockClient: - def __init__(self, **kwargs): pass - def __enter__(self): return self - def __exit__(self, *args): pass - def update_account(self, account_id, data): - update_calls.append((account_id, data)) - - monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", MockClient) - monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockClient) - - # Trigger sync for U1 - sync_account_priorities_for_upstream(db_session, u1.id) - - # Verify BOTH A1 and A2 were updated because they belong to the same website - assert len(update_calls) == 2 - account_ids = {c[0] for c in update_calls} - assert account_ids == {"A1", "A2"} - - # Priority check: G1(1.0) -> 1, G2(2.0) -> 2 - for aid, data in update_calls: - if aid == "A1": assert data["priority"] == 1 - if aid == "A2": assert data["priority"] == 2 - -def test_priority_sync_log_structure(db_session, monkeypatch): - w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}", timeout_seconds=30) - u1 = Upstream(name="U1", base_url="http://u1") - db_session.add_all([w, u1]) - db_session.commit() - db_session.refresh(w) - db_session.refresh(u1) - - db_session.add(UpstreamRateSnapshot( - upstream_id=u1.id, - snapshot_json=json.dumps({"groups": {"G1": {"rate": 1.0}}}), - captured_at=datetime.now(timezone.utc) - )) - db_session.add(UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", key_name="K1", key_value="V1", - imported_website_id=w.id, imported_account_id="A1")) - db_session.commit() - - class MockClient: - def __init__(self, **kwargs): pass - def __enter__(self): return self - def __exit__(self, *args): pass - def update_account(self, account_id, data): pass - - monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", MockClient) - monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockClient) - - sync_account_priorities_for_upstream(db_session, u1.id) - - log = db_session.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == w.id).first() - assert log is not None - assert log.algorithm == "priority_sync" - - data = json.loads(log.source_rates_json) - # The first item should be the priority map metadata - assert data[0]["_meta"] == "priority_map" - assert f"{u1.id}:G1" in data[0]["data"] - # The second item should be the account result - assert data[1]["account_id"] == "A1" def test_import_auto_priority_by_rate(db_session, monkeypatch): + """初始导入时 auto_priority_by_rate=True 按全局倍率分配 priority。""" from app.routers.websites import import_upstream_keys_as_accounts from app.schemas.website import ImportAccountsRequest - w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}", + w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}", groups_endpoint="/groups", group_update_endpoint="/groups/{id}", timeout_seconds=30) u1 = Upstream(name="U1", base_url="http://u1") u2 = Upstream(name="U2", base_url="http://u2") db_session.add_all([w, u1, u2]) db_session.commit() - db_session.refresh(w) - db_session.refresh(u1) - db_session.refresh(u2) + db_session.refresh(w); db_session.refresh(u1); db_session.refresh(u2) - db_session.add(UpstreamRateSnapshot( - upstream_id=u1.id, - snapshot_json=json.dumps({"groups": {"G1": {"rate": 2.0}}}), - captured_at=datetime.now(timezone.utc) - )) - db_session.add(UpstreamRateSnapshot( - upstream_id=u2.id, - snapshot_json=json.dumps({"groups": {"G2": {"rate": 1.0}}}), - captured_at=datetime.now(timezone.utc) - )) + _make_snapshot(db_session, u1.id, {"G1": 2.0}) + _make_snapshot(db_session, u2.id, {"G2": 1.0}) - k1 = UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", group_name="G1", key_name="K1", key_value="V1") - k2 = UpstreamGeneratedKey(upstream_id=u2.id, group_id="G2", group_name="G2", key_name="K2", key_value="V2") + k1 = UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", group_name="G1", + key_name="K1", key_value="V1") + k2 = UpstreamGeneratedKey(upstream_id=u2.id, group_id="G2", group_name="G2", + key_name="K2", key_value="V2") db_session.add_all([k1, k2]) db_session.commit() created_accounts = [] - class MockClient: + + class MockImportClient: def __init__(self, **kwargs): pass def __enter__(self): return self - def __exit__(self, *args): pass + def __exit__(self, *a): pass def create_account(self, body): created_accounts.append(body) return {"id": f"remote-{len(created_accounts)}", "name": body["name"]} def extract_id(self, data): return data["id"] def account_exists(self, aid): return False - monkeypatch.setattr("app.routers.websites._client", lambda website: MockClient()) - monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockClient) + monkeypatch.setattr("app.routers.websites._client", lambda website: MockImportClient()) + monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockImportClient) req = ImportAccountsRequest( upstream_key_ids=[k1.id, k2.id], @@ -207,14 +383,13 @@ def test_import_auto_priority_by_rate(db_session, monkeypatch): auto_priority_by_rate=True, priority=10, account_name_prefix="test", - default_platform="openai" + default_platform="openai", ) import_upstream_keys_as_accounts(w.id, req, db_session) assert len(created_accounts) == 2 - # G2 has rate 1.0 -> priority 1 - # G1 has rate 2.0 -> priority 2 + # G2 rate=1.0 → priority 1;G1 rate=2.0 → priority 2 p1 = next(a["priority"] for a in created_accounts if "G1" in a["name"]) p2 = next(a["priority"] for a in created_accounts if "G2" in a["name"]) assert p2 == 1