Files
SmartUp/backend/app/services/website_sync.py
T
liumangmang e519d1804b fix(priority-sync): narrow account priority update to competitive groups only
Root cause: sync_account_priorities_for_upstream() was doing a global
priority re-rank across ALL imported accounts on a website whenever any
upstream rate changed, triggering spurious account_priority_changed
notifications for accounts in different target groups with no competition.

Fix:
- Add imported_target_group_id / imported_target_group_name to
  UpstreamGeneratedKey (nullable; old data falls back to group_id)
- Write imported_target_group_id on account import in websites.py
- Rewrite sync_account_priorities_for_upstream():
  * bucket accounts by competition_group = imported_target_group_id or group_id
  * only process buckets with count > 1 (genuine competition)
  * each competitive bucket independently sorted by rate; priority starts at 1
  * single-account groups: completely skipped (no update_account, no notification)
  * no competitive groups at all: early return, no log, no notification
- Remove auto priority update in re-import idempotency path (was also
  incorrect; now fully delegated to sync_account_priorities_for_upstream)
- Fix Sub2ApiWebsiteClient local import in sync fn → use module-level name
  so monkeypatch works correctly in tests

Tests: rewrite test_priority_sync.py
- REMOVED: test_priority_sync_full_website_update (was asserting the buggy behavior)
- NEW: test_no_update_when_different_groups_single_account_each
- NEW: test_same_target_group_two_accounts_updated
- NEW: test_two_target_groups_independent_priority
- NEW: test_old_data_null_target_group_fallback
- NEW: test_single_account_in_mixed_website
- UPDATED: test_priority_sync_log_structure (now requires competitive group)
- KEPT: test_priority_sync_cross_upstream_group, test_import_auto_priority_by_rate

All 25 tests pass (8 priority_sync + 17 existing upstream tests).
2026-06-01 19:13:14 +08:00

587 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations

import json
import logging
from datetime import datetime
from decimal import Decimal
from typing import Any

from sqlalchemy.orm import Session

from app.models.snapshot import UpstreamRateSnapshot
from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.services import webhook_service
from app.services.upstream_client import UpstreamClient
from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, calculate_target_rate, decimal_string
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
def binding_sources(binding: WebsiteGroupBinding) -> list[dict[str, Any]]:
    """Return the source-group entries configured on ``binding``.

    ``binding.source_groups_json`` is decoded as JSON. A missing value,
    unparsable text, or a top-level value that is not a JSON array all yield
    an empty list. Array items that are not JSON objects are dropped, so every
    returned entry is a dict and callers can safely use ``.get`` on it.
    """
    try:
        data = json.loads(binding.source_groups_json or "[]")
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; TypeError covers non-text input.
        return []
    if not isinstance(data, list):
        return []
    return [item for item in data if isinstance(item, dict)]
def latest_rate_map(db: Session, upstream_id: int) -> dict[str, Any]:
    """Return the ``groups`` mapping stored in an upstream's newest rate snapshot.

    The ``UpstreamRateSnapshot`` row with the latest ``captured_at`` for
    ``upstream_id`` is loaded and its ``snapshot_json`` decoded. An empty dict
    is returned when no snapshot exists, or when the decoded ``groups`` value
    is empty or not a dict.
    """
    newest = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == upstream_id)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .first()
    )
    if not newest:
        return {}
    payload = json.loads(newest.snapshot_json or "{}")
    group_map = payload.get("groups") or {}
    if isinstance(group_map, dict):
        return group_map
    return {}
def get_affected_bindings(db: Session, changes: list[dict[str, Any]], upstream_id: int) -> list[WebsiteGroupBinding]:
    """Find the enabled bindings that use a changed group of the given upstream.

    ``changes`` is scanned for non-None ``group_id`` values (compared as
    strings). A binding is selected when at least one of its sources names
    ``upstream_id`` together with one of those group ids. Each binding appears
    at most once, in query order. No query is issued when ``changes`` carries
    no group id.
    """
    changed_ids = set()
    for change in changes:
        gid = change.get("group_id")
        if gid is not None:
            changed_ids.add(str(gid))
    if not changed_ids:
        return []

    def _references_change(source) -> bool:
        # Same test as the per-source check: upstream matches and group changed.
        return (
            int(source.get("upstream_id") or 0) == upstream_id
            and str(source.get("group_id")) in changed_ids
        )

    enabled_bindings = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.enabled == True).all()
    return [
        candidate
        for candidate in enabled_bindings
        if any(_references_change(source) for source in binding_sources(candidate))
    ]
def _client_for(website: Website) -> Sub2ApiWebsiteClient:
    """Build a ``Sub2ApiWebsiteClient`` from the connection settings stored on ``website``."""
    settings = {
        "base_url": website.base_url,
        "api_prefix": website.api_prefix,
        "auth_type": website.auth_type,
        # Auth settings are stored as JSON text; an empty value means "no config".
        "auth_config": json.loads(website.auth_config_json or "{}"),
        "timeout": float(website.timeout_seconds),
    }
    return Sub2ApiWebsiteClient(**settings)
def _log(
    db: Session,
    binding: WebsiteGroupBinding,
    website: Website,
    source_rates: list[dict[str, Any]],
    status: str,
    message: str,
    old_rate: Any = None,
    new_rate: Any = None,
) -> WebsiteSyncLog:
    """Persist one ``WebsiteSyncLog`` row describing a binding sync attempt.

    Target group, algorithm and percent are copied from ``binding``;
    ``source_rates`` is stored as JSON. ``old_rate`` / ``new_rate`` are stored
    via ``decimal_string`` unless they are ``None`` or ``""``, in which case
    ``None`` is stored. The row is added, committed, refreshed and returned.
    """

    def _rate_text(value: Any):
        # None and the empty string both mean "no rate available".
        if value in (None, ""):
            return None
        return decimal_string(value)

    entry = WebsiteSyncLog(
        website_id=website.id,
        binding_id=binding.id,
        target_group_id=binding.target_group_id,
        target_group_name=binding.target_group_name,
        algorithm=binding.algorithm,
        percent=binding.percent,
        source_rates_json=json.dumps(source_rates, ensure_ascii=False),
        old_rate=_rate_text(old_rate),
        new_rate=_rate_text(new_rate),
        status=status,
        message=message,
    )
    db.add(entry)
    db.commit()
    db.refresh(entry)
    return entry
def sync_binding(db: Session, binding: WebsiteGroupBinding, write: bool = True) -> WebsiteSyncLog:
    """Compute the target rate for ``binding`` and optionally write it to the website.

    Steps:
      1. Resolve every configured source group against the latest snapshot of
         its upstream and collect the source rates.
      2. Derive the target rate with ``calculate_target_rate``; a failure is
         logged as ``"failed"`` and returned.
      3. If ``write`` is true and the website, its auto-sync flag and the
         binding are all enabled: read the current remote rate, push the new
         one, record the website health, and send the rate-changed webhook
         when the old and new rate strings differ.
      4. Otherwise log the computed rate without writing it back.

    Returns:
        The ``WebsiteSyncLog`` row written for this attempt.

    Raises:
        WebsiteError: if the website referenced by the binding does not exist.
    """
    website = db.query(Website).filter(Website.id == binding.website_id).first()
    if not website:
        raise WebsiteError("网站不存在")
    sources = binding_sources(binding)

    # Batch pre-fetch: collect all upstream ids and load their rows in one query.
    upstream_ids = {int(s.get("upstream_id") or 0) for s in sources if s.get("upstream_id")}
    upstreams = {}
    if upstream_ids:
        rows = db.query(Upstream).filter(Upstream.id.in_(upstream_ids)).all()
        upstreams = {u.id: u for u in rows}

    # Snapshot cache scoped to this call, so each upstream is read only once.
    _snap_cache: dict[int, dict[str, Any]] = {}

    def _get_snap(upstream_id: int) -> dict[str, Any]:
        if upstream_id not in _snap_cache:
            _snap_cache[upstream_id] = latest_rate_map(db, upstream_id)
        return _snap_cache[upstream_id]

    source_rates: list[dict[str, Any]] = []
    for source in sources:
        upstream_id = int(source.get("upstream_id") or 0)
        group_id = str(source.get("group_id") or "")
        snapshot_groups = _get_snap(upstream_id)
        group = snapshot_groups.get(group_id) if group_id else None
        upstream = upstreams.get(upstream_id)
        source_rates.append({
            "upstream_id": upstream_id,
            "upstream_name": source.get("upstream_name") or (upstream.name if upstream else ""),
            "group_id": group_id,
            "group_name": source.get("group_name") or (group.get("group_name", "") if isinstance(group, dict) else ""),
            "rate": group.get("rate") if isinstance(group, dict) else None,
        })

    try:
        target_rate = calculate_target_rate([item.get("rate") for item in source_rates], binding.percent, binding.algorithm)
    except Exception as exc:
        return _log(db, binding, website, source_rates, "failed", str(exc))

    old_rate = None
    if write and website.enabled and website.auto_sync_enabled and binding.enabled:
        try:
            with _client_for(website) as client:
                remote_groups = client.get_groups(website.groups_endpoint)
                # Compare ids as strings so an int id from the remote payload
                # still matches a string-typed target_group_id (and vice versa).
                # NOTE(review): the actual types are not visible here - confirm.
                wanted_id = str(binding.target_group_id)
                target = next(
                    (item for item in remote_groups if str(item.get("id")) == wanted_id),
                    None,
                )
                old_rate = target.get("rate_multiplier") if target else None
                client.update_group_rate(website.group_update_endpoint, binding.target_group_id, target_rate)
            website.last_status = "healthy"
            website.last_error = None
        except Exception as exc:
            # Record the failure on the website before logging the attempt.
            website.last_status = "unhealthy"
            website.last_error = str(exc)
            db.commit()
            return _log(db, binding, website, source_rates, "failed", f"写回失败：{exc}", old_rate, target_rate)
        db.commit()
        log = _log(db, binding, website, source_rates, "success", "同步成功", old_rate, target_rate)
        old_rate_str = decimal_string(old_rate) if old_rate not in (None, "") else None
        new_rate_str = decimal_string(target_rate)
        # Only notify when the written rate actually differs from the old one.
        if old_rate_str != new_rate_str:
            webhook_service.send_website_rate_changed(
                db,
                website.id,
                website.name,
                website.base_url,
                binding.id,
                binding.target_group_id,
                binding.target_group_name,
                old_rate_str,
                new_rate_str,
                source_rates,
            )
        return log

    # Dry run or disabled target: report the suggested rate without writing it.
    message = "已计算建议倍率，未写回"
    if not website.enabled or not website.auto_sync_enabled:
        message = "网站未启用自动同步，未写回"
    elif not binding.enabled:
        message = "绑定未启用，未写回"
    return _log(db, binding, website, source_rates, "success", message, old_rate, target_rate)
def _snapshot_group_rate(group: dict) -> float:
    """Extract the rate multiplier from one snapshot group entry.

    The fields ``rate``, ``default_rate`` and ``rate_multiplier`` are checked
    in that order; the first one holding a value other than ``None`` or ``""``
    is converted with ``float``. A numeric zero is a real value and is
    returned as ``0.0`` rather than being treated as missing. When no field is
    set, or the chosen value cannot be converted, ``1.0`` is returned.
    """
    for field in ("rate", "default_rate", "rate_multiplier"):
        raw = group.get(field)
        if raw is None or raw == "":
            continue
        try:
            return float(raw)
        except (TypeError, ValueError):
            return 1.0
    return 1.0
def build_rate_priority_map(db: Session, upstream_ids: set[int]) -> dict[str, int]:
    """Map ``"{upstream_id}:{group_id}"`` keys to a priority derived from snapshot rates.

    The composite key keeps groups with the same id on different upstreams
    from overwriting each other. The latest snapshot of every upstream in
    ``upstream_ids`` is read and the rate of each dict-valued group collected.
    Distinct rates are sorted ascending and numbered from 1, so the lowest
    rate gets priority 1 and groups with equal rates share a priority.
    """
    rate_by_key = {
        f"{uid}:{gid}": _snapshot_group_rate(info)
        for uid in upstream_ids
        for gid, info in latest_rate_map(db, uid).items()
        if isinstance(info, dict)
    }
    rank_of_rate = {}
    for position, rate in enumerate(sorted(set(rate_by_key.values())), start=1):
        rank_of_rate[rate] = position
    return {key: rank_of_rate[rate] for key, rate in rate_by_key.items()}
def _priority_result(row, new_priority: int | None, status: str, message: str) -> dict:
    """Build the uniform result dict for one account in a priority sync.

    Account, group and upstream ids are read from ``row``; ``old_priority``
    is always ``None`` because the previous value is not looked up here.
    """
    return dict(
        account_id=row.imported_account_id,
        group_id=row.group_id,
        upstream_id=row.upstream_id,
        old_priority=None,
        new_priority=new_priority,
        status=status,
        message=message,
    )
def _write_priority_sync_log_with_map(
    db: Session, wid: int, upstream_name: str,
    results: list[dict], priority_map: dict[str, int],
) -> None:
    """Write a ``priority_sync`` log row with account details and a priority_map snapshot.

    ``source_rates_json`` is stored as
    ``[{"_meta": "priority_map", "data": {...}}, {"account_id": ..., ...}, ...]``
    so it stays a list of dicts, matching the ``list[dict]`` constraint of
    ``WebsiteSyncLogResponse.source_rates``.

    The row status is ``"failed"`` when any result failed, otherwise
    ``"success"``. The message summarises how many results succeeded, failed
    and were skipped, followed by the total count.
    """
    log_results: list[dict] = [
        {"_meta": "priority_map", "data": dict(priority_map)},
    ]
    log_results.extend(results)

    success = sum(1 for r in results if r["status"] == "success")
    failed = sum(1 for r in results if r["status"] == "failed")
    skipped = sum(1 for r in results if r["status"] == "skipped")
    parts = []
    if success:
        parts.append(f"{success} 个更新成功")
    if failed:
        parts.append(f"{failed} 个失败")
    if skipped:
        parts.append(f"{skipped} 个跳过")
    # Separate the parts explicitly; joining on "" ran the counts together.
    summary = "，".join(parts)

    log = WebsiteSyncLog(
        website_id=wid,
        binding_id=None,
        target_group_id="",
        target_group_name="",
        algorithm="priority_sync",
        percent=0,
        source_rates_json=json.dumps(log_results, ensure_ascii=False, default=str),
        old_rate=None,
        new_rate=None,
        status="failed" if failed else "success",
        message=f"优先级同步（上游={upstream_name}）：{summary} / 共 {len(results)}",
    )
    db.add(log)
    db.commit()
def _try_send_priority_webhook(
    db: Session, wid: int, website_name: str,
    upstream_id: int, upstream_name: str,
    updates: list[dict],
) -> None:
    """Send the ``account_priority_changed`` webhook without ever raising.

    Nothing is sent when ``updates`` is empty. If ``website_name`` is empty
    the name is looked up by ``wid``, with a placeholder used when the website
    row cannot be found. A failing webhook call is logged as a warning.
    """
    if not updates:
        return

    display_name = website_name
    if not display_name:
        # No name supplied by the caller: try the database.
        found = db.query(Website.name).filter(Website.id == wid).first()
        display_name = found[0] if found else f"网站#{wid}"

    try:
        webhook_service.send_account_priority_changed(
            db,
            website_id=wid,
            website_name=display_name,
            upstream_id=upstream_id,
            upstream_name=upstream_name,
            updates=updates,
        )
    except Exception as exc:
        logger.warning("account_priority_changed webhook failed for website %s: %s", wid, exc)
def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[dict]:
    """Re-rank imported downstream account priorities after an upstream's rates change.

    For every website holding at least one non-orphaned imported key of
    ``upstream_id``:

    * All imported keys of that website (from any upstream) are bucketed by
      competition group: ``imported_target_group_id``, falling back to
      ``group_id`` when that value is empty (legacy rows).
    * Only buckets with more than one account are processed. Within a bucket,
      accounts are ranked by snapshot rate in ascending order starting at
      priority 1; equal rates share a priority. A rate missing from the
      snapshots counts as 1.0.
    * Accounts in single-account buckets are skipped entirely: no
      ``update_account`` call and no notification.
    * Websites that are missing, disabled, or have no competitive bucket are
      skipped without writing a log or sending a webhook.

    Returns:
        One result dict (see ``_priority_result``) per account for which an
        update was attempted or could not be attempted, across all websites.
    """
    from collections import defaultdict

    key_rows = (
        db.query(UpstreamGeneratedKey)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream_id,
            UpstreamGeneratedKey.imported_website_id.isnot(None),
            UpstreamGeneratedKey.imported_account_id.isnot(None),
            UpstreamGeneratedKey.status != "orphaned",
        )
        .all()
    )
    if not key_rows:
        return []
    upstream_name = db.query(Upstream.name).filter(Upstream.id == upstream_id).scalar() or f"#{upstream_id}"

    # Distinct websites touched by this upstream, in first-seen order. Only
    # the ids are needed; each website's keys are re-queried across upstreams.
    website_ids = list(dict.fromkeys(row.imported_website_id for row in key_rows))

    all_results: list[dict] = []
    for wid in website_ids:
        website = db.query(Website).filter(Website.id == wid).first()
        if not website or not website.enabled:
            # Silent skip: accounts cannot be updated on an unavailable website.
            logger.info("skip account priority sync: website %s not found or disabled", wid)
            continue

        # Every imported key of this website, across upstreams, for rate lookup.
        all_website_keys = (
            db.query(UpstreamGeneratedKey)
            .filter(
                UpstreamGeneratedKey.imported_website_id == wid,
                UpstreamGeneratedKey.imported_account_id.isnot(None),
                UpstreamGeneratedKey.status != "orphaned",
            )
            .all()
        )

        # Bucket by competition group (legacy rows fall back to group_id).
        buckets: dict[str, list[UpstreamGeneratedKey]] = defaultdict(list)
        for row in all_website_keys:
            buckets[row.imported_target_group_id or row.group_id].append(row)

        # Ranking only matters where more than one account competes.
        # NOTE(review): a competitive bucket is re-ranked even when it holds
        # no key of `upstream_id` - confirm this is intended.
        competitive_buckets = {k: v for k, v in buckets.items() if len(v) > 1}
        if not competitive_buckets:
            logger.info(
                "skip account priority sync for website %s: no competitive groups (all single-account)",
                wid,
            )
            continue  # no log, no notification

        # Pre-fetch snapshot rates as "{upstream_id}:{group_id}" -> rate.
        all_upstream_ids = {k.upstream_id for k in all_website_keys}
        try:
            raw_rate_map: dict[str, float] = {}
            for uid in all_upstream_ids:
                for gid, g in latest_rate_map(db, uid).items():
                    if isinstance(g, dict):
                        raw_rate_map[f"{uid}:{gid}"] = _snapshot_group_rate(g)
        except Exception as exc:
            logger.warning("build rate map failed for website %s: %s", wid, exc)
            raw_rate_map = {}

        # Rank each competitive bucket independently: account_id -> priority.
        priority_assignment: dict[str, int] = {}
        for comp_rows in competitive_buckets.values():
            rated = [
                (row, raw_rate_map.get(f"{row.upstream_id}:{row.group_id}", 1.0))
                for row in comp_rows
            ]
            # Lower rate -> smaller priority number -> preferred.
            unique_rates = sorted({rate for _, rate in rated})
            rate_to_prio = {rate: idx + 1 for idx, rate in enumerate(unique_rates)}
            for row, rate in rated:
                priority_assignment[row.imported_account_id] = rate_to_prio[rate]

        # Push the new priorities (competitive accounts only).
        targets = [row for comp_rows in competitive_buckets.values() for row in comp_rows]
        site_results: list[dict] = []
        attempted = 0  # rows already visited, so an outer failure is not double-reported
        try:
            with _client_for(website) as client:
                for row in targets:
                    attempted += 1
                    account_id = row.imported_account_id
                    new_priority = priority_assignment.get(account_id)
                    if new_priority is None:
                        continue
                    try:
                        client.update_account(account_id, {"priority": new_priority})
                        logger.info(
                            "updated priority for account %s (website=%s, upstream=%s, group=%s"
                            ", comp_group=%s): %s",
                            account_id, wid, row.upstream_id, row.group_id,
                            row.imported_target_group_id or row.group_id, new_priority,
                        )
                        site_results.append(
                            _priority_result(row, new_priority, "success", f"优先级已更新为 {new_priority}")
                        )
                    except Exception as exc:
                        logger.warning(
                            "failed to update priority for account %s (website=%s): %s",
                            account_id, wid, exc,
                        )
                        site_results.append(
                            _priority_result(row, new_priority, "failed", str(exc))
                        )
        except Exception as exc:
            logger.warning("failed to connect website %s for account priority sync: %s", wid, exc)
            # Mark only the accounts that were never reached as failed.
            for row in targets[attempted:]:
                site_results.append(
                    _priority_result(row, None, "failed", f"连接网站失败: {exc}")
                )

        # Notify only about accounts that were actually attempted.
        notify_updates = [r for r in site_results if r["status"] in ("success", "failed")]
        if notify_updates:
            # Simplified priority map for the log (competitive accounts only).
            priority_map_snapshot = {
                f"{row.upstream_id}:{row.group_id}": priority_assignment.get(row.imported_account_id)
                for row in targets
            }
            _write_priority_sync_log_with_map(db, wid, upstream_name, site_results, priority_map_snapshot)
            _try_send_priority_webhook(db, wid, website.name, upstream_id, upstream_name, notify_updates)
        all_results.extend(site_results)
    return all_results
def _fetch_remote_managed_prefixes(db: Session, upstream_id: int) -> list[str]:
    """Return the distinct non-NULL ``managed_prefix`` values stored locally for an upstream.

    Despite the name, only the local database is queried. When no prefix is
    stored, ``["SmartUp"]`` is returned for compatibility with old data.
    """
    prefix_query = (
        db.query(UpstreamGeneratedKey.managed_prefix)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream_id,
            UpstreamGeneratedKey.managed_prefix.isnot(None),
        )
        .distinct()
    )
    found = [record[0] for record in prefix_query.all()]
    return found or ["SmartUp"]
def _fetch_remote_managed_key_ids(db: Session, client, upstream_id: int) -> set[str]:
    """Collect the ids of active remote keys for every locally known managed prefix.

    ``client.list_api_keys`` is called once per prefix with ``status="active"``.
    Ids from all prefixes are merged into one set of strings; entries without
    a truthy ``id`` are ignored.
    """
    collected: set[str] = set()
    for prefix in _fetch_remote_managed_prefixes(db, upstream_id):
        for remote_key in client.list_api_keys(search=prefix, status="active"):
            if remote_key.get("id"):
                collected.add(str(remote_key["id"]))
    return collected
def reconcile_upstream_keys(
    db: Session,
    upstream_id: int,
    active_group_ids: set[str] | None,
    remote_key_ids: set[str] | None,
    captured_at: datetime,
) -> None:
    """Reconcile the locally cached keys of an upstream with the remote state.

    Args:
        active_group_ids: ids of groups that still exist. ``None`` skips the
            group-level cleanup (avoids wrongly removing keys after a failed login).
        remote_key_ids: ids of keys that still exist remotely. ``None`` skips
            the key-level cleanup (safe when the remote listing failed).
        captured_at: timestamp written to ``updated_at`` on rows that are
            marked orphaned or confirmed to still exist remotely.

    When both id sets are ``None`` nothing is done. A row whose group or key
    is gone is marked ``"orphaned"`` if it was imported into a website
    (both ``imported_website_id`` and ``imported_account_id`` set), otherwise
    it is deleted. Changes are not committed here.
    """
    if active_group_ids is None and remote_key_ids is None:
        return

    def _retire(row, error: str, detail: str) -> None:
        """Orphan an imported row, or delete one that was never imported."""
        if row.imported_website_id and row.imported_account_id:
            row.status = "orphaned"
            row.error = error
            row.updated_at = captured_at
            logger.info("marked key %s orphaned (%s)", row.id, detail)
        else:
            db.delete(row)
            logger.info("removed key %s (%s)", row.id, detail)

    key_rows = (
        db.query(UpstreamGeneratedKey)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream_id,
        )
        .all()
    )
    for row in key_rows:
        # Both id sets hold strings, so compare the row values as strings too.
        if active_group_ids is not None and str(row.group_id) not in active_group_ids:
            _retire(row, "来源分组已不存在", f"group {row.group_id} no longer in snapshot")
            continue
        if remote_key_ids is None or not row.key_id:
            # Key-level check is disabled, or the row has no remote key id.
            continue
        if str(row.key_id) in remote_key_ids:
            row.updated_at = captured_at
        else:
            _retire(row, "远端 Key 已不存在", f"key_id {row.key_id} gone from remote")
def reconcile_upstream_keys_full(db: Session, upstream_id: int) -> bool:
    """Run a full key reconciliation for one upstream.

    Active group ids come from the latest snapshot (the same source the
    scheduler uses) rather than a live API call, to avoid format mismatches.
    Remote key ids are fetched by logging in to the upstream and listing keys
    for every locally known ``managed_prefix``.

    Safety rules:
      - group-level cleanup only happens when a snapshot with groups exists;
      - key-level cleanup only happens when the remote key listing succeeded;
      - when both are unavailable nothing is deleted or marked.

    Returns:
        ``True`` when the remote key listing succeeded; ``False`` when it
        failed or the upstream does not exist.
    """
    from datetime import datetime, timezone

    upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
    if not upstream:
        return False
    auth_config = json.loads(upstream.auth_config_json or "{}")
    now = datetime.now(timezone.utc)

    # Active groups from the newest snapshot; None disables the group check.
    snapshot_groups = latest_rate_map(db, upstream_id)
    active_group_ids: set[str] | None = set(snapshot_groups.keys()) if snapshot_groups else None

    # Remote key ids; stays None (check disabled) if login or listing fails.
    remote_key_ids: set[str] | None = None
    try:
        with UpstreamClient(
            base_url=upstream.base_url,
            api_prefix=upstream.api_prefix,
            auth_type=upstream.auth_type,
            auth_config=auth_config,
            timeout=float(upstream.timeout_seconds),
        ) as client:
            client.login()
            remote_key_ids = _fetch_remote_managed_key_ids(db, client, upstream_id)
    except Exception as exc:
        logger.warning("reconcile_upstream_keys_full: upstream %s failed: %s", upstream_id, exc)

    reconcile_upstream_keys(db, upstream_id, active_group_ids, remote_key_ids, now)
    db.commit()
    return remote_key_ids is not None
def sync_affected_bindings(db: Session, upstream_id: int, changes: list[dict[str, Any]]) -> None:
    """Sync, with write-back, every binding affected by the given upstream changes.

    A failure while syncing one binding is logged with its traceback and does
    not stop the remaining bindings from being processed.
    """
    affected = get_affected_bindings(db, changes, upstream_id)
    for binding in affected:
        try:
            sync_binding(db, binding, write=True)
        except Exception as exc:
            logger.exception("website sync failed for binding %s: %s", binding.id, exc)