fix: backfill account groups before priority reorder
This commit is contained in:
@@ -250,8 +250,8 @@ class Sub2ApiWebsiteClient:
|
||||
return v
|
||||
return None
|
||||
|
||||
def _get_account_ids(self, endpoint: str = "/accounts") -> set[str] | None:
|
||||
"""拉取远端账号列表。成功返回 ID 集合(可能为空),解析失败返回 None。"""
|
||||
def list_accounts(self, endpoint: str = "/accounts") -> list[dict[str, Any]] | None:
|
||||
"""拉取远端账号列表。成功返回账号 dict 列表,失败返回 None。"""
|
||||
try:
|
||||
resp = self._request("GET", endpoint)
|
||||
except Exception:
|
||||
@@ -261,6 +261,13 @@ class Sub2ApiWebsiteClient:
|
||||
if items is None:
|
||||
logger.warning("account list unexpected format for %s", endpoint)
|
||||
return None
|
||||
return [item for item in items if isinstance(item, dict)]
|
||||
|
||||
def _get_account_ids(self, endpoint: str = "/accounts") -> set[str] | None:
|
||||
"""拉取远端账号列表。成功返回 ID 集合(可能为空),解析失败返回 None。"""
|
||||
items = self.list_accounts(endpoint)
|
||||
if items is None:
|
||||
return None
|
||||
ids: set[str] = set()
|
||||
for item in items:
|
||||
item_id = self.extract_id(item)
|
||||
|
||||
@@ -11,7 +11,7 @@ from app.models.snapshot import UpstreamRateSnapshot
|
||||
from app.models.upstream import Upstream
|
||||
from app.models.upstream_key import UpstreamGeneratedKey
|
||||
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
|
||||
from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, calculate_target_rate, decimal_string
|
||||
from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, _extract_id, calculate_target_rate, decimal_string
|
||||
from app.services.upstream_client import UpstreamClient
|
||||
from app.services import webhook_service
|
||||
|
||||
@@ -225,6 +225,94 @@ def _priority_result(row, new_priority: int | None, status: str, message: str) -
|
||||
}
|
||||
|
||||
|
||||
def _remote_account_single_group(account: dict[str, Any]) -> tuple[str | None, str | None]:
|
||||
"""Return the only remote group id/name if the account belongs to exactly one group."""
|
||||
group_ids: list[str] = []
|
||||
group_names: dict[str, str] = {}
|
||||
|
||||
raw_group_ids = account.get("group_ids")
|
||||
if isinstance(raw_group_ids, list):
|
||||
for item in raw_group_ids:
|
||||
group_id = item.get("id") if isinstance(item, dict) else item
|
||||
if group_id is not None:
|
||||
group_ids.append(str(group_id))
|
||||
elif raw_group_ids is not None:
|
||||
group_ids.append(str(raw_group_ids))
|
||||
|
||||
raw_groups = account.get("groups")
|
||||
if isinstance(raw_groups, list):
|
||||
for group in raw_groups:
|
||||
if not isinstance(group, dict):
|
||||
continue
|
||||
group_id = group.get("id") or group.get("group_id")
|
||||
if group_id is None:
|
||||
continue
|
||||
gid = str(group_id)
|
||||
group_ids.append(gid)
|
||||
if group.get("name") or group.get("group_name"):
|
||||
group_names[gid] = str(group.get("name") or group.get("group_name"))
|
||||
|
||||
unique_ids = list(dict.fromkeys(gid for gid in group_ids if gid))
|
||||
if len(unique_ids) != 1:
|
||||
return None, None
|
||||
gid = unique_ids[0]
|
||||
return gid, group_names.get(gid)
|
||||
|
||||
|
||||
def _backfill_missing_target_groups_from_remote_accounts(
|
||||
db: Session,
|
||||
website: Website,
|
||||
rows: list[UpstreamGeneratedKey],
|
||||
) -> int:
|
||||
"""Backfill old imported rows whose target group is missing from the remote account list."""
|
||||
missing_rows = [
|
||||
row for row in rows
|
||||
if row.imported_account_id and not row.imported_target_group_id
|
||||
]
|
||||
if not missing_rows:
|
||||
return 0
|
||||
|
||||
try:
|
||||
with Sub2ApiWebsiteClient(
|
||||
base_url=website.base_url,
|
||||
api_prefix=website.api_prefix,
|
||||
auth_type=website.auth_type,
|
||||
auth_config=json.loads(website.auth_config_json or "{}"),
|
||||
timeout=float(website.timeout_seconds),
|
||||
) as client:
|
||||
accounts = client.list_accounts() or []
|
||||
except Exception as exc:
|
||||
logger.info("skip imported target group backfill for website %s: %s", website.id, exc)
|
||||
return 0
|
||||
|
||||
account_map: dict[str, dict[str, Any]] = {}
|
||||
for account in accounts:
|
||||
account_id = _extract_id(account)
|
||||
if account_id:
|
||||
account_map[str(account_id)] = account
|
||||
|
||||
changed = 0
|
||||
for row in missing_rows:
|
||||
account = account_map.get(str(row.imported_account_id))
|
||||
if not account:
|
||||
continue
|
||||
group_id, group_name = _remote_account_single_group(account)
|
||||
if not group_id:
|
||||
logger.info(
|
||||
"skip imported target group backfill for account %s: remote account has zero or multiple groups",
|
||||
row.imported_account_id,
|
||||
)
|
||||
continue
|
||||
row.imported_target_group_id = group_id
|
||||
row.imported_target_group_name = group_name
|
||||
changed += 1
|
||||
|
||||
if changed:
|
||||
db.commit()
|
||||
logger.info("backfilled %d imported target group(s) for website %s", changed, website.id)
|
||||
return changed
|
||||
|
||||
|
||||
def _write_priority_sync_log_with_map(
|
||||
db: Session, wid: int, upstream_name: str,
|
||||
results: list[dict], priority_map: dict[str, int],
|
||||
@@ -348,6 +436,7 @@ def sync_account_priorities_for_upstream(
|
||||
)
|
||||
.all()
|
||||
)
|
||||
_backfill_missing_target_groups_from_remote_accounts(db, website, all_website_keys)
|
||||
|
||||
# ── 按竞争分组分桶 ──────────────────────────────────────────────────
|
||||
# 竞争分组键:imported_target_group_id(老数据为 NULL 时 fallback 到 group_id)
|
||||
|
||||
@@ -249,6 +249,50 @@ def test_old_data_null_target_group_fallback(db_session, monkeypatch):
|
||||
assert len(update_calls) == 2
|
||||
|
||||
|
||||
def test_null_target_group_backfills_from_remote_account_group(db_session, monkeypatch):
|
||||
"""老导入数据 target_group_id=NULL,但远端账号有单一 group_ids → 回填后参与同组重排。"""
|
||||
w = Website(name="W1", base_url="http://w1", enabled=True,
|
||||
auth_config_json="{}", timeout_seconds=30)
|
||||
u1 = Upstream(name="U1", base_url="http://u1")
|
||||
u2 = Upstream(name="U2", base_url="http://u2")
|
||||
u3 = Upstream(name="U3", base_url="http://u3")
|
||||
db_session.add_all([w, u1, u2, u3])
|
||||
db_session.commit()
|
||||
db_session.refresh(w); db_session.refresh(u1); db_session.refresh(u2); db_session.refresh(u3)
|
||||
|
||||
_make_snapshot(db_session, u1.id, {"G2": 0.06})
|
||||
_make_snapshot(db_session, u2.id, {"G19": 0.07})
|
||||
_make_snapshot(db_session, u3.id, {"G21": 0.05})
|
||||
|
||||
legacy = _make_key(db_session, u1.id, "G2", "K2", "V2", w.id, "A2", imported_target_group_id=None)
|
||||
_make_key(db_session, u2.id, "G19", "K19", "V19", w.id, "A19", imported_target_group_id="TG1")
|
||||
_make_key(db_session, u3.id, "G21", "K21", "V21", w.id, "A21", imported_target_group_id="TG1")
|
||||
|
||||
update_calls = []
|
||||
|
||||
class AccountListMockClient:
|
||||
def __init__(self, **kwargs): pass
|
||||
def __enter__(self): return self
|
||||
def __exit__(self, *a): pass
|
||||
def list_accounts(self):
|
||||
return [
|
||||
{"id": "A2", "group_ids": ["TG1"]},
|
||||
{"id": "A19", "group_ids": ["TG1"]},
|
||||
{"id": "A21", "group_ids": ["TG1"]},
|
||||
]
|
||||
def update_account(self, account_id, data):
|
||||
update_calls.append((account_id, data))
|
||||
|
||||
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", AccountListMockClient)
|
||||
|
||||
sync_account_priorities_for_upstream(db_session, u1.id, website_id=w.id)
|
||||
|
||||
updated = {account_id: payload["priority"] for account_id, payload in update_calls}
|
||||
assert updated == {"A21": 1, "A2": 11, "A19": 21}
|
||||
db_session.refresh(legacy)
|
||||
assert legacy.imported_target_group_id == "TG1"
|
||||
|
||||
|
||||
def test_single_account_in_mixed_website(db_session, monkeypatch):
|
||||
"""同一网站:一个目标分组有 2 账号(竞争),另一个目标分组只有 1 账号(不参与)。
|
||||
|
||||
|
||||
Reference in New Issue
Block a user