From 9600e4ceba45bc49682a788963415a597bc06ca7 Mon Sep 17 00:00:00 2001 From: SmartUp Developer Date: Wed, 3 Jun 2026 18:39:21 +0800 Subject: [PATCH] fix: backfill account groups before priority reorder --- backend/app/services/website_client.py | 11 +++- backend/app/services/website_sync.py | 91 +++++++++++++++++++++++++- backend/test_priority_sync.py | 44 +++++++++++++ 3 files changed, 143 insertions(+), 3 deletions(-) diff --git a/backend/app/services/website_client.py b/backend/app/services/website_client.py index 9c9cc28..1aaa106 100644 --- a/backend/app/services/website_client.py +++ b/backend/app/services/website_client.py @@ -250,8 +250,8 @@ class Sub2ApiWebsiteClient: return v return None - def _get_account_ids(self, endpoint: str = "/accounts") -> set[str] | None: - """拉取远端账号列表。成功返回 ID 集合(可能为空),解析失败返回 None。""" + def list_accounts(self, endpoint: str = "/accounts") -> list[dict[str, Any]] | None: + """拉取远端账号列表。成功返回账号 dict 列表,失败返回 None。""" try: resp = self._request("GET", endpoint) except Exception: @@ -261,6 +261,13 @@ class Sub2ApiWebsiteClient: if items is None: logger.warning("account list unexpected format for %s", endpoint) return None + return [item for item in items if isinstance(item, dict)] + + def _get_account_ids(self, endpoint: str = "/accounts") -> set[str] | None: + """拉取远端账号列表。成功返回 ID 集合(可能为空),解析失败返回 None。""" + items = self.list_accounts(endpoint) + if items is None: + return None ids: set[str] = set() for item in items: item_id = self.extract_id(item) diff --git a/backend/app/services/website_sync.py b/backend/app/services/website_sync.py index b955958..4656c9f 100644 --- a/backend/app/services/website_sync.py +++ b/backend/app/services/website_sync.py @@ -11,7 +11,7 @@ from app.models.snapshot import UpstreamRateSnapshot from app.models.upstream import Upstream from app.models.upstream_key import UpstreamGeneratedKey from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog -from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, calculate_target_rate, decimal_string +from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, _extract_id, calculate_target_rate, decimal_string from app.services.upstream_client import UpstreamClient from app.services import webhook_service @@ -225,6 +225,94 @@ def _priority_result(row, new_priority: int | None, status: str, message: str) - } +def _remote_account_single_group(account: dict[str, Any]) -> tuple[str | None, str | None]: + """Return the only remote group id/name if the account belongs to exactly one group.""" + group_ids: list[str] = [] + group_names: dict[str, str] = {} + + raw_group_ids = account.get("group_ids") + if isinstance(raw_group_ids, list): + for item in raw_group_ids: + group_id = item.get("id") if isinstance(item, dict) else item + if group_id is not None: + group_ids.append(str(group_id)) + elif raw_group_ids is not None: + group_ids.append(str(raw_group_ids)) + + raw_groups = account.get("groups") + if isinstance(raw_groups, list): + for group in raw_groups: + if not isinstance(group, dict): + continue + group_id = group.get("id") or group.get("group_id") + if group_id is None: + continue + gid = str(group_id) + group_ids.append(gid) + if group.get("name") or group.get("group_name"): + group_names[gid] = str(group.get("name") or group.get("group_name")) + + unique_ids = list(dict.fromkeys(gid for gid in group_ids if gid)) + if len(unique_ids) != 1: + return None, None + gid = unique_ids[0] + return gid, group_names.get(gid) + + +def _backfill_missing_target_groups_from_remote_accounts( + db: Session, + website: Website, + rows: list[UpstreamGeneratedKey], +) -> int: + """Backfill old imported rows whose target group is missing from the remote account list.""" + missing_rows = [ + row for row in rows + if row.imported_account_id and not row.imported_target_group_id + ] + if not missing_rows: + return 0 + + try: + with Sub2ApiWebsiteClient( + base_url=website.base_url, + api_prefix=website.api_prefix, + auth_type=website.auth_type, + auth_config=json.loads(website.auth_config_json or "{}"), + timeout=float(website.timeout_seconds), + ) as client: + accounts = client.list_accounts() or [] + except Exception as exc: + logger.info("skip imported target group backfill for website %s: %s", website.id, exc) + return 0 + + account_map: dict[str, dict[str, Any]] = {} + for account in accounts: + account_id = _extract_id(account) + if account_id: + account_map[str(account_id)] = account + + changed = 0 + for row in missing_rows: + account = account_map.get(str(row.imported_account_id)) + if not account: + continue + group_id, group_name = _remote_account_single_group(account) + if not group_id: + logger.info( + "skip imported target group backfill for account %s: remote account has zero or multiple groups", + row.imported_account_id, + ) + continue + row.imported_target_group_id = group_id + row.imported_target_group_name = group_name + changed += 1 + + if changed: + db.commit() + logger.info("backfilled %d imported target group(s) for website %s", changed, website.id) + return changed + + def _write_priority_sync_log_with_map( db: Session, wid: int, upstream_name: str, results: list[dict], priority_map: dict[str, int], @@ -348,6 +436,7 @@ def sync_account_priorities_for_upstream( ) .all() ) + _backfill_missing_target_groups_from_remote_accounts(db, website, all_website_keys) # ── 按竞争分组分桶 ────────────────────────────────────────────────── # 竞争分组键:imported_target_group_id(老数据为 NULL 时 fallback 到 group_id) diff --git a/backend/test_priority_sync.py b/backend/test_priority_sync.py index 0e5097b..490f0d2 100644 --- a/backend/test_priority_sync.py +++ b/backend/test_priority_sync.py @@ -249,6 +249,50 @@ def test_old_data_null_target_group_fallback(db_session, monkeypatch): assert len(update_calls) == 2 +def test_null_target_group_backfills_from_remote_account_group(db_session, monkeypatch): + """老导入数据 target_group_id=NULL,但远端账号有单一 group_ids → 回填后参与同组重排。""" + w = Website(name="W1", base_url="http://w1", enabled=True, + auth_config_json="{}", timeout_seconds=30) + u1 = Upstream(name="U1", base_url="http://u1") + u2 = Upstream(name="U2", base_url="http://u2") + u3 = Upstream(name="U3", base_url="http://u3") + db_session.add_all([w, u1, u2, u3]) + db_session.commit() + db_session.refresh(w); db_session.refresh(u1); db_session.refresh(u2); db_session.refresh(u3) + + _make_snapshot(db_session, u1.id, {"G2": 0.06}) + _make_snapshot(db_session, u2.id, {"G19": 0.07}) + _make_snapshot(db_session, u3.id, {"G21": 0.05}) + + legacy = _make_key(db_session, u1.id, "G2", "K2", "V2", w.id, "A2", imported_target_group_id=None) + _make_key(db_session, u2.id, "G19", "K19", "V19", w.id, "A19", imported_target_group_id="TG1") + _make_key(db_session, u3.id, "G21", "K21", "V21", w.id, "A21", imported_target_group_id="TG1") + + update_calls = [] + + class AccountListMockClient: + def __init__(self, **kwargs): pass + def __enter__(self): return self + def __exit__(self, *a): pass + def list_accounts(self): + return [ + {"id": "A2", "group_ids": ["TG1"]}, + {"id": "A19", "group_ids": ["TG1"]}, + {"id": "A21", "group_ids": ["TG1"]}, + ] + def update_account(self, account_id, data): + update_calls.append((account_id, data)) + + monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", AccountListMockClient) + + sync_account_priorities_for_upstream(db_session, u1.id, website_id=w.id) + + updated = {account_id: payload["priority"] for account_id, payload in update_calls} + assert updated == {"A21": 1, "A2": 11, "A19": 21} + db_session.refresh(legacy) + assert legacy.imported_target_group_id == "TG1" + + def test_single_account_in_mixed_website(db_session, monkeypatch): """同一网站:一个目标分组有 2 账号(竞争),另一个目标分组只有 1 账号(不参与)。