from __future__ import annotations

import json
import logging
from collections import Counter, defaultdict
from collections.abc import Iterable
from datetime import datetime, timezone
from decimal import Decimal
from typing import Any

from sqlalchemy.orm import Session

from app.models.snapshot import UpstreamRateSnapshot
from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, calculate_target_rate, decimal_string
from app.services.upstream_client import UpstreamClient
from app.services import webhook_service

logger = logging.getLogger(__name__)

# Priority of the lowest-rate tier, and the gap between consecutive rate tiers.
PRIORITY_BASE = 1
PRIORITY_STEP = 10


def _as_int(value: Any, default: int = 0) -> int:
    """Convert a loosely-typed JSON field to int, returning ``default`` when it is empty or invalid."""
    try:
        return int(value or default)
    except (TypeError, ValueError):
        return default


def priority_for_rate_rank(rank: int) -> int:
    """Convert a zero-based sorted rate rank to an account priority.

    Rank 0 maps to ``PRIORITY_BASE``; each following rank adds ``PRIORITY_STEP``.
    """
    return PRIORITY_BASE + rank * PRIORITY_STEP


def binding_sources(binding: WebsiteGroupBinding) -> list[dict[str, Any]]:
    """Parse ``binding.source_groups_json`` into a list of source dicts.

    Malformed JSON, a non-list payload and non-dict entries are dropped instead of raised,
    so every returned item supports ``.get``.
    """
    try:
        data = json.loads(binding.source_groups_json or "[]")
    except (TypeError, ValueError):
        return []
    if not isinstance(data, list):
        return []
    return [item for item in data if isinstance(item, dict)]


def latest_rate_map(db: Session, upstream_id: int) -> dict[str, Any]:
    """Return the ``groups`` mapping of the newest rate snapshot for ``upstream_id``.

    Returns ``{}`` when there is no snapshot, or when its JSON is malformed or not shaped
    as ``{"groups": {...}}``.
    """
    row = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == upstream_id)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .first()
    )
    if not row:
        return {}
    try:
        snapshot = json.loads(row.snapshot_json or "{}")
    except (TypeError, ValueError):
        logger.warning("malformed rate snapshot for upstream %s", upstream_id)
        return {}
    if not isinstance(snapshot, dict):
        return {}
    groups = snapshot.get("groups") or {}
    return groups if isinstance(groups, dict) else {}


def get_affected_bindings(db: Session, changes: list[dict[str, Any]], upstream_id: int) -> list[WebsiteGroupBinding]:
    """Return enabled bindings with at least one source on ``upstream_id`` whose group appears in ``changes``."""
    changed_ids = {str(change.get("group_id")) for change in changes if change.get("group_id") is not None}
    if not changed_ids:
        return []
    bindings = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.enabled == True).all()  # noqa: E712
    return [
        binding
        for binding in bindings
        if any(
            _as_int(source.get("upstream_id")) == upstream_id and str(source.get("group_id")) in changed_ids
            for source in binding_sources(binding)
        )
    ]


def _client_for(website: Website) -> Sub2ApiWebsiteClient:
    """Build an API client for ``website`` from its stored connection settings."""
    return Sub2ApiWebsiteClient(
        base_url=website.base_url,
        api_prefix=website.api_prefix,
        auth_type=website.auth_type,
        auth_config=json.loads(website.auth_config_json or "{}"),
        timeout=float(website.timeout_seconds),
    )


def _log(
    db: Session,
    binding: WebsiteGroupBinding,
    website: Website,
    source_rates: list[dict[str, Any]],
    status: str,
    message: str,
    old_rate: Any = None,
    new_rate: Any = None,
) -> WebsiteSyncLog:
    """Insert and commit a WebsiteSyncLog row for ``binding``, then return the refreshed row.

    Empty or ``None`` rates are stored as NULL; others are normalized with ``decimal_string``.
    """
    row = WebsiteSyncLog(
        website_id=website.id,
        binding_id=binding.id,
        target_group_id=binding.target_group_id,
        target_group_name=binding.target_group_name,
        algorithm=binding.algorithm,
        percent=binding.percent,
        source_rates_json=json.dumps(source_rates, ensure_ascii=False),
        old_rate=decimal_string(old_rate) if old_rate not in (None, "") else None,
        new_rate=decimal_string(new_rate) if new_rate not in (None, "") else None,
        status=status,
        message=message,
    )
    db.add(row)
    db.commit()
    db.refresh(row)
    return row


def _collect_source_rates(db: Session, sources: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Resolve each binding source to its latest snapshot rate plus display names.

    Upstream names are fetched in one query. Each upstream's snapshot is loaded at most once
    per call. A missing group yields ``rate=None``.
    """
    upstream_ids = {_as_int(s.get("upstream_id")) for s in sources} - {0}
    upstreams: dict[int, Upstream] = {}
    if upstream_ids:
        rows = db.query(Upstream).filter(Upstream.id.in_(upstream_ids)).all()
        upstreams = {u.id: u for u in rows}

    # Per-call snapshot cache: several sources often share one upstream.
    snapshot_cache: dict[int, dict[str, Any]] = {}
    source_rates: list[dict[str, Any]] = []
    for source in sources:
        upstream_id = _as_int(source.get("upstream_id"))
        group_id = str(source.get("group_id") or "")
        if upstream_id not in snapshot_cache:
            snapshot_cache[upstream_id] = latest_rate_map(db, upstream_id)
        group = snapshot_cache[upstream_id].get(group_id) if group_id else None
        group_dict = group if isinstance(group, dict) else None
        upstream = upstreams.get(upstream_id)
        source_rates.append({
            "upstream_id": upstream_id,
            "upstream_name": source.get("upstream_name") or (upstream.name if upstream else ""),
            "group_id": group_id,
            "group_name": source.get("group_name") or (group_dict.get("group_name", "") if group_dict else ""),
            "rate": group_dict.get("rate") if group_dict else None,
        })
    return source_rates


def _notify_website_rate_changed(
    db: Session,
    website: Website,
    binding: WebsiteGroupBinding,
    source_rates: list[dict[str, Any]],
    old_rate: Any,
    new_rate: Any,
) -> None:
    """Send the website-rate-changed webhook when the written rate differs from the previous one.

    This is best-effort. The rate has already been written and logged when this runs, so
    failures are logged instead of turning a successful sync into an exception.
    """
    try:
        old_rate_str = decimal_string(old_rate) if old_rate not in (None, "") else None
        new_rate_str = decimal_string(new_rate)
        if old_rate_str == new_rate_str:
            return
        webhook_service.send_website_rate_changed(
            db,
            website.id,
            website.name,
            website.base_url,
            binding.id,
            binding.target_group_id,
            binding.target_group_name,
            old_rate_str,
            new_rate_str,
            source_rates,
        )
    except Exception as exc:
        logger.warning("website_rate_changed webhook failed for binding %s: %s", binding.id, exc)


def sync_binding(db: Session, binding: WebsiteGroupBinding, write: bool = True) -> WebsiteSyncLog:
    """Recompute ``binding``'s target rate from its sources and, when allowed, write it to the website.

    The rate is written only when ``write`` is true and the website, its auto-sync flag and the
    binding are all enabled. Every outcome is recorded as a committed WebsiteSyncLog row, which
    is returned.

    Raises:
        WebsiteError: the binding's website no longer exists.
    """
    website = db.query(Website).filter(Website.id == binding.website_id).first()
    if not website:
        raise WebsiteError("网站不存在")

    source_rates = _collect_source_rates(db, binding_sources(binding))

    try:
        target_rate = calculate_target_rate(
            [item.get("rate") for item in source_rates], binding.percent, binding.algorithm
        )
    except Exception as exc:
        return _log(db, binding, website, source_rates, "failed", str(exc))

    old_rate = None
    if write and website.enabled and website.auto_sync_enabled and binding.enabled:
        try:
            with _client_for(website) as client:
                remote_groups = client.get_groups(website.groups_endpoint)
                # Compare ids as strings: the remote JSON id and the stored target id may differ in type.
                target_key = str(binding.target_group_id)
                target = next((item for item in remote_groups if str(item.get("id")) == target_key), None)
                old_rate = target.get("rate_multiplier") if target else None
                client.update_group_rate(website.group_update_endpoint, binding.target_group_id, target_rate)
                website.last_status = "healthy"
                website.last_error = None
        except Exception as exc:
            website.last_status = "unhealthy"
            website.last_error = str(exc)
            db.commit()
            return _log(db, binding, website, source_rates, "failed", f"写回失败:{exc}", old_rate, target_rate)
        db.commit()
        log = _log(db, binding, website, source_rates, "success", "同步成功", old_rate, target_rate)
        _notify_website_rate_changed(db, website, binding, source_rates, old_rate, target_rate)
        return log

    message = "已计算建议倍率,未写回"
    if not website.enabled or not website.auto_sync_enabled:
        message = "网站未启用自动同步,未写回"
    elif not binding.enabled:
        message = "绑定未启用,未写回"
    return _log(db, binding, website, source_rates, "success", message, old_rate, target_rate)


def _snapshot_group_rate(group: dict) -> float:
    """Extract a group's rate from snapshot data, accepting several field names.

    Tries ``rate``, then ``default_rate``, then ``rate_multiplier``. Returns 1.0 when none
    of them is usable.
    """
    # NOTE(review): the ``or`` chain treats every falsy value, including a rate of 0, as
    # missing, so a 0 rate ranks as 1.0. Confirm whether 0 is a legitimate upstream rate.
    raw = group.get("rate") or group.get("default_rate") or group.get("rate_multiplier") or 1
    try:
        return float(raw)
    except (TypeError, ValueError):
        return 1.0


def _snapshot_rate_lookup(db: Session, upstream_ids: Iterable[int]) -> dict[str, float]:
    """Map ``"{upstream_id}:{group_id}"`` to the latest snapshot rate of every group of the given upstreams."""
    rates: dict[str, float] = {}
    for uid in upstream_ids:
        for gid, group in latest_rate_map(db, uid).items():
            if isinstance(group, dict):
                rates[f"{uid}:{gid}"] = _snapshot_group_rate(group)
    return rates


def _priorities_by_rate(rates: Iterable[float]) -> dict[float, int]:
    """Map each distinct rate to a priority, ascending by rate; equal rates share one priority."""
    return {rate: priority_for_rate_rank(rank) for rank, rate in enumerate(sorted(set(rates)))}


def build_rate_priority_map(db: Session, upstream_ids: set[int]) -> dict[str, int]:
    """Build a ``f"{upstream_id}:{group_id}"`` -> priority map from upstream group rates.

    The composite key keeps same-named groups of different upstreams from overwriting each
    other. The function reads the latest snapshot of every given upstream and ranks all group
    rates in ascending order. The lowest rate gets priority 1, the next 11, and so on. Equal
    rates share a priority.
    """
    group_rates = _snapshot_rate_lookup(db, upstream_ids)
    rate_to_priority = _priorities_by_rate(group_rates.values())
    return {key: rate_to_priority[rate] for key, rate in group_rates.items()}


def _priority_result(row, new_priority: int | None, status: str, message: str) -> dict:
    """Build the per-account result dict shared by all priority-sync outcomes (``old_priority`` is not tracked)."""
    return {
        "account_id": row.imported_account_id,
        "group_id": row.group_id,
        "upstream_id": row.upstream_id,
        "old_priority": None,
        "new_priority": new_priority,
        "status": status,
        "message": message,
    }


def _write_priority_sync_log_with_map(
    db: Session,
    wid: int,
    upstream_name: str,
    results: list[dict],
    priority_map: dict[str, int],
) -> None:
    """Commit a ``priority_sync`` WebsiteSyncLog holding per-account results and a priority_map snapshot.

    ``source_rates_json`` layout is
    ``[{"_meta": "priority_map", "data": {...}}, {"account_id": ..., ...}, ...]``.
    It stays a list of dicts so it still fits the
    ``WebsiteSyncLogResponse.source_rates: list[dict]`` type. The log is marked failed if any
    account failed.
    """
    log_results: list[dict] = [{"_meta": "priority_map", "data": dict(priority_map)}]
    log_results.extend(results)

    counts = Counter(r["status"] for r in results)
    success, failed, skipped = counts["success"], counts["failed"], counts["skipped"]
    parts = []
    if success:
        parts.append(f"{success} 个更新成功")
    if failed:
        parts.append(f"{failed} 个失败")
    if skipped:
        parts.append(f"{skipped} 个跳过")

    log = WebsiteSyncLog(
        website_id=wid,
        binding_id=None,
        target_group_id="",
        target_group_name="",
        algorithm="priority_sync",
        percent=0,
        source_rates_json=json.dumps(log_results, ensure_ascii=False, default=str),
        old_rate=None,
        new_rate=None,
        status="failed" if failed else "success",
        message=f"优先级同步(上游={upstream_name}):{'、'.join(parts)} / 共 {len(results)} 个",
    )
    db.add(log)
    db.commit()


def _try_send_priority_webhook(
    db: Session,
    wid: int,
    website_name: str,
    upstream_id: int,
    upstream_name: str,
    updates: list[dict],
) -> None:
    """Send the account_priority_changed webhook; failures are logged, never raised."""
    if not updates:
        return
    # Fall back to the DB (then a placeholder) when no website name was passed in.
    resolved_name = website_name
    if not resolved_name:
        row = db.query(Website.name).filter(Website.id == wid).first()
        resolved_name = row[0] if row else f"网站#{wid}"
    try:
        webhook_service.send_account_priority_changed(
            db,
            website_id=wid,
            website_name=resolved_name,
            upstream_id=upstream_id,
            upstream_name=upstream_name,
            updates=updates,
        )
    except Exception as exc:
        logger.warning("account_priority_changed webhook failed for website %s: %s", wid, exc)


def _assign_competitive_priorities(
    buckets: dict[str, list[UpstreamGeneratedKey]],
    rate_lookup: dict[str, float],
    wid: int,
) -> dict[str, int]:
    """Rank accounts by rate inside each competitive bucket; return ``account_id -> priority``.

    Accounts without snapshot rate data are left out. A bucket with fewer than two rated
    accounts has nothing to rank and is skipped.
    """
    assignment: dict[str, int] = {}
    for comp_key, comp_rows in buckets.items():
        rated = []
        for row in comp_rows:
            rate = rate_lookup.get(f"{row.upstream_id}:{row.group_id}")
            if rate is not None:
                rated.append((row, rate))
        if len(rated) < 2:
            logger.info(
                "skip competitive bucket %s for website %s: only %d account(s) have rate data",
                comp_key, wid, len(rated),
            )
            continue
        # Lower rate -> smaller priority number -> preferred.
        rate_to_prio = _priorities_by_rate(rate for _, rate in rated)
        for row, rate in rated:
            assignment[row.imported_account_id] = rate_to_prio[rate]
    return assignment


def _push_account_priorities(
    website: Website,
    buckets: dict[str, list[UpstreamGeneratedKey]],
    assignment: dict[str, int],
) -> list[dict]:
    """Call ``update_account`` for every account with an assigned priority and collect the results.

    If the client itself fails (connection, auth or config), every assigned account that has
    not been reported yet is recorded as failed. Accounts without an assignment never appear.
    """
    wid = website.id
    results: list[dict] = []
    try:
        with _client_for(website) as client:
            for comp_rows in buckets.values():
                for row in comp_rows:
                    account_id = row.imported_account_id
                    new_priority = assignment.get(account_id)
                    if new_priority is None:
                        continue
                    try:
                        client.update_account(account_id, {"priority": new_priority})
                    except Exception as exc:
                        logger.warning(
                            "failed to update priority for account %s (website=%s): %s",
                            account_id, wid, exc,
                        )
                        results.append(_priority_result(row, new_priority, "failed", str(exc)))
                        continue
                    logger.info(
                        "updated priority for account %s (website=%s, upstream=%s, group=%s"
                        ", comp_group=%s): %s",
                        account_id, wid, row.upstream_id, row.group_id,
                        row.imported_target_group_id or row.group_id, new_priority,
                    )
                    results.append(_priority_result(row, new_priority, "success", f"优先级已更新为 {new_priority}"))
    except Exception as exc:
        logger.warning("failed to connect website %s for account priority sync: %s", wid, exc)
        reported = {r["account_id"] for r in results}
        for comp_rows in buckets.values():
            for row in comp_rows:
                account_id = row.imported_account_id
                if account_id in assignment and account_id not in reported:
                    results.append(_priority_result(row, None, "failed", f"连接网站失败: {exc}"))
    return results


def _sync_website_priorities(
    db: Session,
    website: Website,
    upstream_id: int,
    upstream_name: str,
) -> list[dict]:
    """Recompute and push account priorities for one website; write a log and webhook if anything ran."""
    wid = website.id
    # All imported keys of this website, across every upstream, because buckets compete across upstreams.
    all_website_keys = (
        db.query(UpstreamGeneratedKey)
        .filter(
            UpstreamGeneratedKey.imported_website_id == wid,
            UpstreamGeneratedKey.imported_account_id.isnot(None),
            UpstreamGeneratedKey.status != "orphaned",
        )
        .all()
    )

    # Competition key: imported_target_group_id, falling back to group_id for legacy NULL rows.
    buckets: dict[str, list[UpstreamGeneratedKey]] = defaultdict(list)
    for row in all_website_keys:
        buckets[row.imported_target_group_id or row.group_id].append(row)
    competitive = {key: rows for key, rows in buckets.items() if len(rows) > 1}
    if not competitive:
        logger.info(
            "skip account priority sync for website %s: no competitive groups (all single-account)",
            wid,
        )
        return []  # no log, no webhook

    try:
        rate_lookup = _snapshot_rate_lookup(db, {k.upstream_id for k in all_website_keys})
    except Exception as exc:
        logger.warning("build rate map failed for website %s: %s", wid, exc)
        rate_lookup = {}

    assignment = _assign_competitive_priorities(competitive, rate_lookup, wid)
    if not assignment:
        # Nothing to push: skip contacting the website, and do not log or notify.
        return []

    site_results = _push_account_priorities(website, competitive, assignment)

    notify_updates = [r for r in site_results if r["status"] in ("success", "failed")]
    if notify_updates:
        # Snapshot of the computed priorities (competitive accounts only), kept for the log.
        priority_map_snapshot = {
            f"{row.upstream_id}:{row.group_id}": assignment.get(row.imported_account_id)
            for comp_rows in competitive.values()
            for row in comp_rows
        }
        _write_priority_sync_log_with_map(db, wid, upstream_name, site_results, priority_map_snapshot)
        _try_send_priority_webhook(db, wid, website.name, upstream_id, upstream_name, notify_updates)
    return site_results


def sync_account_priorities_for_upstream(
    db: Session,
    upstream_id: int,
    website_id: int | None = None,
) -> list[dict]:
    """Update ``priority`` of imported downstream accounts after an upstream's rates change.

    Only target groups holding more than one account (accounts that compete) are touched:

    - Competition key: ``imported_target_group_id``; legacy rows with NULL fall back to ``group_id``.
    - Within a competitive group, accounts are sorted by rate ascending. Priority starts at
      ``PRIORITY_BASE`` and grows by ``PRIORITY_STEP`` per distinct rate; equal rates share a priority.
    - Accounts without snapshot rate data are left out. A group with fewer than two rated
      accounts is skipped.
    - Single-account groups are skipped entirely: no ``update_account`` call, no notification.
    - A website with nothing to update gets no sync log and no webhook.

    Affected websites are those holding non-orphaned imported keys of ``upstream_id``
    (optionally only ``website_id``). Missing or disabled websites are skipped. For each
    affected website, every competitive group is recomputed, including groups without a key
    from ``upstream_id``, because ranking uses the website's keys across all upstreams.

    Returns:
        Per-account result dicts (see ``_priority_result``) for all processed websites.
    """
    key_query = db.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream_id,
        UpstreamGeneratedKey.imported_website_id.isnot(None),
        UpstreamGeneratedKey.imported_account_id.isnot(None),
        UpstreamGeneratedKey.status != "orphaned",
    )
    if website_id is not None:
        key_query = key_query.filter(UpstreamGeneratedKey.imported_website_id == website_id)
    key_rows = key_query.all()
    if not key_rows:
        return []

    upstream_name = db.query(Upstream.name).filter(Upstream.id == upstream_id).scalar() or f"#{upstream_id}"

    # Distinct affected websites, in first-seen order.
    affected_website_ids = list(dict.fromkeys(row.imported_website_id for row in key_rows))

    all_results: list[dict] = []
    for wid in affected_website_ids:
        website = db.query(Website).filter(Website.id == wid).first()
        if not website or not website.enabled:
            # Silent skip: accounts cannot be updated on a missing or disabled website.
            logger.info("skip account priority sync: website %s not found or disabled", wid)
            continue
        all_results.extend(_sync_website_priorities(db, website, upstream_id, upstream_name))
    return all_results


def _fetch_remote_managed_prefixes(db: Session, upstream_id: int) -> list[str]:
    """Return the distinct non-NULL ``managed_prefix`` values stored locally for ``upstream_id``.

    Falls back to ``["SmartUp"]`` when none are recorded, for compatibility with older rows.
    """
    prefixes = [
        row[0]
        for row in db.query(UpstreamGeneratedKey.managed_prefix)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream_id,
            UpstreamGeneratedKey.managed_prefix.isnot(None),
        )
        .distinct()
        .all()
    ]
    return prefixes if prefixes else ["SmartUp"]


def _fetch_remote_managed_key_ids(db: Session, client, upstream_id: int) -> set[str]:
    """Collect the string IDs of active remote keys matching any locally known managed prefix.

    Results of the per-prefix searches are merged.
    """
    # NOTE(review): assumes list_api_keys returns every match (no pagination). A truncated
    # list would make reconcile_upstream_keys retire keys that still exist. Verify.
    all_ids: set[str] = set()
    for prefix in _fetch_remote_managed_prefixes(db, upstream_id):
        remote_keys = client.list_api_keys(search=prefix, status="active")
        all_ids.update(str(k["id"]) for k in remote_keys if k.get("id"))
    return all_ids


def _retire_key(db: Session, row: UpstreamGeneratedKey, captured_at: datetime, error: str, detail: str) -> None:
    """Mark a stale key orphaned if it was imported into a website; otherwise delete it."""
    if row.imported_website_id and row.imported_account_id:
        row.status = "orphaned"
        row.error = error
        row.updated_at = captured_at
        logger.info("marked key %s orphaned (%s)", row.id, detail)
    else:
        db.delete(row)
        logger.info("removed key %s (%s)", row.id, detail)


def reconcile_upstream_keys(
    db: Session,
    upstream_id: int,
    active_group_ids: set[str] | None,
    remote_key_ids: set[str] | None,
    captured_at: datetime,
) -> None:
    """Reconcile the locally cached keys of ``upstream_id`` with remote state. Does not commit.

    - ``active_group_ids is None``: skip group-level cleanup, so a failed login never wipes rows.
    - ``remote_key_ids is None``: skip key_id-level cleanup, which is safe when the remote query failed.
    - Both ``None``: nothing is done.

    A stale row that was imported into a website is marked ``orphaned``; any other stale row is
    deleted. Rows whose key still exists remotely get ``updated_at`` refreshed. IDs are compared
    as strings on both sides, so int/str column types cannot make every row look stale.
    """
    if active_group_ids is None and remote_key_ids is None:
        return
    active_groups = {str(g) for g in active_group_ids} if active_group_ids is not None else None
    remote_keys = {str(k) for k in remote_key_ids} if remote_key_ids is not None else None

    key_rows = db.query(UpstreamGeneratedKey).filter(UpstreamGeneratedKey.upstream_id == upstream_id).all()
    for row in key_rows:
        if active_groups is not None and str(row.group_id) not in active_groups:
            _retire_key(db, row, captured_at, "来源分组已不存在", f"group {row.group_id} no longer in snapshot")
            continue
        if remote_keys is None or not row.key_id:
            continue
        if str(row.key_id) in remote_keys:
            row.updated_at = captured_at
        else:
            _retire_key(db, row, captured_at, "远端 Key 已不存在", f"key_id {row.key_id} gone from remote")


def reconcile_upstream_keys_full(db: Session, upstream_id: int) -> bool:
    """Run a full key reconciliation for ``upstream_id`` and commit.

    The active group IDs come from the latest snapshot, the same source the scheduler uses,
    rather than the live API. The remote key list comes from logging in to the upstream and
    searching every locally known managed prefix.

    Safety rules:

    - Group-level cleanup runs only if a snapshot exists.
    - key_id-level cleanup runs only if the remote key list was fetched. A failure anywhere in
      the client phase, including malformed auth config, just skips this step.
    - If both are unavailable, nothing is deleted or marked.

    Returns:
        True if the remote key list was fetched; False if it was not, or the upstream is unknown.
    """
    upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
    if not upstream:
        return False

    now = datetime.now(timezone.utc)

    groups = latest_rate_map(db, upstream_id)
    active_group_ids: set[str] | None = set(groups) if groups else None

    remote_key_ids: set[str] | None = None
    try:
        auth_config = json.loads(upstream.auth_config_json or "{}")
        with UpstreamClient(
            base_url=upstream.base_url,
            api_prefix=upstream.api_prefix,
            auth_type=upstream.auth_type,
            auth_config=auth_config,
            timeout=float(upstream.timeout_seconds),
        ) as client:
            client.login()
            remote_key_ids = _fetch_remote_managed_key_ids(db, client, upstream_id)
    except Exception as exc:
        logger.warning("reconcile_upstream_keys_full: upstream %s failed: %s", upstream_id, exc)

    # None for either input disables the matching check.
    reconcile_upstream_keys(db, upstream_id, active_group_ids, remote_key_ids, now)
    db.commit()
    return remote_key_ids is not None


def sync_affected_bindings(db: Session, upstream_id: int, changes: list[dict[str, Any]]) -> None:
    """Re-sync, with write-back, every enabled binding affected by ``changes`` on ``upstream_id``.

    A failure on one binding is logged and does not stop the remaining bindings.
    """
    for binding in get_affected_bindings(db, changes, upstream_id):
        try:
            sync_binding(db, binding, write=True)
        except Exception as exc:
            logger.exception("website sync failed for binding %s: %s", binding.id, exc)