fix: backfill account groups before priority reorder

This commit is contained in:
SmartUp Developer
2026-06-03 18:39:21 +08:00
parent b866b387e0
commit 9600e4ceba
3 changed files with 143 additions and 3 deletions
+9 -2
View File
@@ -250,8 +250,8 @@ class Sub2ApiWebsiteClient:
return v return v
return None return None
def _get_account_ids(self, endpoint: str = "/accounts") -> set[str] | None: def list_accounts(self, endpoint: str = "/accounts") -> list[dict[str, Any]] | None:
"""拉取远端账号列表。成功返回 ID 集合(可能为空),解析失败返回 None。""" """拉取远端账号列表。成功返回账号 dict 列表,失败返回 None。"""
try: try:
resp = self._request("GET", endpoint) resp = self._request("GET", endpoint)
except Exception: except Exception:
@@ -261,6 +261,13 @@ class Sub2ApiWebsiteClient:
if items is None: if items is None:
logger.warning("account list unexpected format for %s", endpoint) logger.warning("account list unexpected format for %s", endpoint)
return None return None
return [item for item in items if isinstance(item, dict)]
def _get_account_ids(self, endpoint: str = "/accounts") -> set[str] | None:
"""拉取远端账号列表。成功返回 ID 集合(可能为空),解析失败返回 None。"""
items = self.list_accounts(endpoint)
if items is None:
return None
ids: set[str] = set() ids: set[str] = set()
for item in items: for item in items:
item_id = self.extract_id(item) item_id = self.extract_id(item)
+90 -1
View File
@@ -11,7 +11,7 @@ from app.models.snapshot import UpstreamRateSnapshot
from app.models.upstream import Upstream from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey from app.models.upstream_key import UpstreamGeneratedKey
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, calculate_target_rate, decimal_string from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, _extract_id, calculate_target_rate, decimal_string
from app.services.upstream_client import UpstreamClient from app.services.upstream_client import UpstreamClient
from app.services import webhook_service from app.services import webhook_service
@@ -225,6 +225,94 @@ def _priority_result(row, new_priority: int | None, status: str, message: str) -
} }
def _remote_account_single_group(account: dict[str, Any]) -> tuple[str | None, str | None]:
"""Return the only remote group id/name if the account belongs to exactly one group."""
group_ids: list[str] = []
group_names: dict[str, str] = {}
raw_group_ids = account.get("group_ids")
if isinstance(raw_group_ids, list):
for item in raw_group_ids:
group_id = item.get("id") if isinstance(item, dict) else item
if group_id is not None:
group_ids.append(str(group_id))
elif raw_group_ids is not None:
group_ids.append(str(raw_group_ids))
raw_groups = account.get("groups")
if isinstance(raw_groups, list):
for group in raw_groups:
if not isinstance(group, dict):
continue
group_id = group.get("id") or group.get("group_id")
if group_id is None:
continue
gid = str(group_id)
group_ids.append(gid)
if group.get("name") or group.get("group_name"):
group_names[gid] = str(group.get("name") or group.get("group_name"))
unique_ids = list(dict.fromkeys(gid for gid in group_ids if gid))
if len(unique_ids) != 1:
return None, None
gid = unique_ids[0]
return gid, group_names.get(gid)
def _backfill_missing_target_groups_from_remote_accounts(
db: Session,
website: Website,
rows: list[UpstreamGeneratedKey],
) -> int:
"""Backfill old imported rows whose target group is missing from the remote account list."""
missing_rows = [
row for row in rows
if row.imported_account_id and not row.imported_target_group_id
]
if not missing_rows:
return 0
try:
with Sub2ApiWebsiteClient(
base_url=website.base_url,
api_prefix=website.api_prefix,
auth_type=website.auth_type,
auth_config=json.loads(website.auth_config_json or "{}"),
timeout=float(website.timeout_seconds),
) as client:
accounts = client.list_accounts() or []
except Exception as exc:
logger.info("skip imported target group backfill for website %s: %s", website.id, exc)
return 0
account_map: dict[str, dict[str, Any]] = {}
for account in accounts:
account_id = _extract_id(account)
if account_id:
account_map[str(account_id)] = account
changed = 0
for row in missing_rows:
account = account_map.get(str(row.imported_account_id))
if not account:
continue
group_id, group_name = _remote_account_single_group(account)
if not group_id:
logger.info(
"skip imported target group backfill for account %s: remote account has zero or multiple groups",
row.imported_account_id,
)
continue
row.imported_target_group_id = group_id
row.imported_target_group_name = group_name
changed += 1
if changed:
db.commit()
logger.info("backfilled %d imported target group(s) for website %s", changed, website.id)
return changed
def _write_priority_sync_log_with_map( def _write_priority_sync_log_with_map(
db: Session, wid: int, upstream_name: str, db: Session, wid: int, upstream_name: str,
results: list[dict], priority_map: dict[str, int], results: list[dict], priority_map: dict[str, int],
@@ -348,6 +436,7 @@ def sync_account_priorities_for_upstream(
) )
.all() .all()
) )
_backfill_missing_target_groups_from_remote_accounts(db, website, all_website_keys)
# ── 按竞争分组分桶 ────────────────────────────────────────────────── # ── 按竞争分组分桶 ──────────────────────────────────────────────────
# 竞争分组键:imported_target_group_id(老数据为 NULL 时 fallback 到 group_id # 竞争分组键:imported_target_group_id(老数据为 NULL 时 fallback 到 group_id
+44
View File
@@ -249,6 +249,50 @@ def test_old_data_null_target_group_fallback(db_session, monkeypatch):
assert len(update_calls) == 2 assert len(update_calls) == 2
def test_null_target_group_backfills_from_remote_account_group(db_session, monkeypatch):
"""老导入数据 target_group_id=NULL,但远端账号有单一 group_ids → 回填后参与同组重排。"""
w = Website(name="W1", base_url="http://w1", enabled=True,
auth_config_json="{}", timeout_seconds=30)
u1 = Upstream(name="U1", base_url="http://u1")
u2 = Upstream(name="U2", base_url="http://u2")
u3 = Upstream(name="U3", base_url="http://u3")
db_session.add_all([w, u1, u2, u3])
db_session.commit()
db_session.refresh(w); db_session.refresh(u1); db_session.refresh(u2); db_session.refresh(u3)
_make_snapshot(db_session, u1.id, {"G2": 0.06})
_make_snapshot(db_session, u2.id, {"G19": 0.07})
_make_snapshot(db_session, u3.id, {"G21": 0.05})
legacy = _make_key(db_session, u1.id, "G2", "K2", "V2", w.id, "A2", imported_target_group_id=None)
_make_key(db_session, u2.id, "G19", "K19", "V19", w.id, "A19", imported_target_group_id="TG1")
_make_key(db_session, u3.id, "G21", "K21", "V21", w.id, "A21", imported_target_group_id="TG1")
update_calls = []
class AccountListMockClient:
def __init__(self, **kwargs): pass
def __enter__(self): return self
def __exit__(self, *a): pass
def list_accounts(self):
return [
{"id": "A2", "group_ids": ["TG1"]},
{"id": "A19", "group_ids": ["TG1"]},
{"id": "A21", "group_ids": ["TG1"]},
]
def update_account(self, account_id, data):
update_calls.append((account_id, data))
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", AccountListMockClient)
sync_account_priorities_for_upstream(db_session, u1.id, website_id=w.id)
updated = {account_id: payload["priority"] for account_id, payload in update_calls}
assert updated == {"A21": 1, "A2": 11, "A19": 21}
db_session.refresh(legacy)
assert legacy.imported_target_group_id == "TG1"
def test_single_account_in_mixed_website(db_session, monkeypatch): def test_single_account_in_mixed_website(db_session, monkeypatch):
"""同一网站:一个目标分组有 2 账号(竞争),另一个目标分组只有 1 账号(不参与)。 """同一网站:一个目标分组有 2 账号(竞争),另一个目标分组只有 1 账号(不参与)。