from __future__ import annotations

import json
import logging
from decimal import Decimal
from typing import Any

from sqlalchemy.orm import Session

from app.models.snapshot import UpstreamRateSnapshot
from app.models.upstream import Upstream
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, calculate_target_rate, decimal_string
from app.services import webhook_service

logger = logging.getLogger(__name__)


def binding_sources(binding: WebsiteGroupBinding) -> list[dict[str, Any]]:
    """Return the source-group descriptors stored on *binding*.

    ``binding.source_groups_json`` is parsed as JSON. Unreadable JSON or a
    value that is not a JSON list yields ``[]``; list entries that are not
    JSON objects are dropped so every returned item supports ``.get``.
    """
    try:
        data = json.loads(binding.source_groups_json or "[]")
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; TypeError covers a non-str/bytes column value.
        logger.warning("unreadable source_groups_json on binding %s", binding.id)
        return []
    if not isinstance(data, list):
        return []
    return [item for item in data if isinstance(item, dict)]


def latest_rate_map(db: Session, upstream_id: int) -> dict[str, Any]:
    """Return the ``groups`` mapping from the newest rate snapshot of an upstream.

    Snapshots are ordered by ``captured_at`` descending and only the first row
    is used. Returns ``{}`` when there is no snapshot, when its JSON cannot be
    parsed or is not an object, or when ``groups`` is missing or not a dict.
    """
    row = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == upstream_id)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .first()
    )
    if not row:
        return {}
    try:
        snapshot = json.loads(row.snapshot_json or "{}")
    except (TypeError, ValueError):
        logger.warning("unreadable rate snapshot for upstream %s", upstream_id)
        return {}
    if not isinstance(snapshot, dict):
        return {}
    groups = snapshot.get("groups") or {}
    return groups if isinstance(groups, dict) else {}


def get_affected_bindings(db: Session, changes: list[dict[str, Any]], upstream_id: int) -> list[WebsiteGroupBinding]:
    """Return enabled bindings that use any changed group of *upstream_id* as a source.

    Group ids from *changes* (entries without ``group_id`` are ignored) and
    from binding sources are compared as strings.
    """
    changed_ids = {str(change.get("group_id")) for change in changes if change.get("group_id") is not None}
    if not changed_ids:
        return []
    # `== True` is intentional: SQLAlchemy overloads it to build the SQL comparison.
    bindings = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.enabled == True).all()  # noqa: E712
    return [
        binding
        for binding in bindings
        if any(
            int(source.get("upstream_id") or 0) == upstream_id and str(source.get("group_id")) in changed_ids
            for source in binding_sources(binding)
        )
    ]


def _client_for(website: Website) -> Sub2ApiWebsiteClient:
    """Build an API client from the website's connection and auth settings."""
    return Sub2ApiWebsiteClient(
        base_url=website.base_url,
        api_prefix=website.api_prefix,
        auth_type=website.auth_type,
        auth_config=json.loads(website.auth_config_json or "{}"),
        timeout=float(website.timeout_seconds),
    )


def _log(
    db: Session,
    binding: WebsiteGroupBinding,
    website: Website,
    source_rates: list[dict[str, Any]],
    status: str,
    message: str,
    old_rate: Any = None,
    new_rate: Any = None,
) -> WebsiteSyncLog:
    """Persist and return a ``WebsiteSyncLog`` row describing one sync attempt.

    Rates are normalised with ``decimal_string``; ``None`` or ``""`` are stored
    as ``None``. The row is committed and refreshed before being returned.
    """
    row = WebsiteSyncLog(
        website_id=website.id,
        binding_id=binding.id,
        target_group_id=binding.target_group_id,
        target_group_name=binding.target_group_name,
        algorithm=binding.algorithm,
        percent=binding.percent,
        source_rates_json=json.dumps(source_rates, ensure_ascii=False),
        old_rate=decimal_string(old_rate) if old_rate not in (None, "") else None,
        new_rate=decimal_string(new_rate) if new_rate not in (None, "") else None,
        status=status,
        message=message,
    )
    db.add(row)
    db.commit()
    db.refresh(row)
    return row


def _collect_source_rates(db: Session, sources: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Resolve each source group to its rate in the upstream's latest snapshot.

    Returns one dict per source, in order, with keys ``upstream_id``,
    ``upstream_name``, ``group_id``, ``group_name`` and ``rate``. ``rate`` is
    ``None`` when the group is absent from the snapshot.
    """
    # Batch pre-fetch: collect all upstream ids and load their names in one query.
    upstream_ids = {int(s.get("upstream_id") or 0) for s in sources if s.get("upstream_id")}
    upstreams: dict[int, Upstream] = {}
    if upstream_ids:
        rows = db.query(Upstream).filter(Upstream.id.in_(upstream_ids)).all()
        upstreams = {u.id: u for u in rows}

    # Snapshot cache scoped to this call, so sources sharing an upstream share one query.
    snapshots: dict[int, dict[str, Any]] = {}

    source_rates: list[dict[str, Any]] = []
    for source in sources:
        upstream_id = int(source.get("upstream_id") or 0)
        group_id = str(source.get("group_id") or "")
        if upstream_id not in snapshots:
            snapshots[upstream_id] = latest_rate_map(db, upstream_id)
        group = snapshots[upstream_id].get(group_id) if group_id else None
        upstream = upstreams.get(upstream_id)
        source_rates.append({
            "upstream_id": upstream_id,
            "upstream_name": source.get("upstream_name") or (upstream.name if upstream else ""),
            "group_id": group_id,
            "group_name": source.get("group_name") or (group.get("group_name", "") if isinstance(group, dict) else ""),
            "rate": group.get("rate") if isinstance(group, dict) else None,
        })
    return source_rates


def _notify_rate_changed(
    db: Session,
    website: Website,
    binding: WebsiteGroupBinding,
    old_rate: Any,
    target_rate: Any,
    source_rates: list[dict[str, Any]],
) -> None:
    """Send the rate-changed webhook if the written rate differs from the previous one.

    Best-effort: the new rate is already written and its success log committed,
    so a webhook failure is logged rather than propagated.
    """
    old_rate_str = decimal_string(old_rate) if old_rate not in (None, "") else None
    new_rate_str = decimal_string(target_rate)
    if old_rate_str == new_rate_str:
        return
    try:
        webhook_service.send_website_rate_changed(
            db,
            website.id,
            website.name,
            website.base_url,
            binding.id,
            binding.target_group_id,
            binding.target_group_name,
            old_rate_str,
            new_rate_str,
            source_rates,
        )
    except Exception:
        logger.exception("rate-changed webhook failed for binding %s", binding.id)
        # Leave the session usable in case the webhook failed mid-transaction.
        db.rollback()


def sync_binding(db: Session, binding: WebsiteGroupBinding, write: bool = True) -> WebsiteSyncLog:
    """Compute the target rate for *binding* and optionally write it to its website.

    The rate is written only when *write* is true and the website (``enabled``,
    ``auto_sync_enabled``) and the binding (``enabled``) all allow it. Every
    outcome, including calculation or write-back failures, is recorded and
    returned as a ``WebsiteSyncLog``. On write-back the website's
    ``last_status``/``last_error`` are updated. After a successful write the
    rate-changed webhook is sent if the rate actually changed.

    Raises:
        WebsiteError: if the binding's website does not exist.
    """
    website = db.query(Website).filter(Website.id == binding.website_id).first()
    if not website:
        raise WebsiteError("网站不存在")

    source_rates = _collect_source_rates(db, binding_sources(binding))

    try:
        target_rate = calculate_target_rate([item.get("rate") for item in source_rates], binding.percent, binding.algorithm)
    except Exception as exc:
        return _log(db, binding, website, source_rates, "failed", str(exc))

    old_rate = None
    if write and website.enabled and website.auto_sync_enabled and binding.enabled:
        try:
            with _client_for(website) as client:
                remote_groups = client.get_groups(website.groups_endpoint)
                # Compare ids as strings, like every other group-id comparison in this
                # module, so an int/str mismatch between DB and API still finds the group.
                target_key = str(binding.target_group_id)
                target = next((item for item in remote_groups if str(item.get("id")) == target_key), None)
                old_rate = target.get("rate_multiplier") if target else None
                client.update_group_rate(website.group_update_endpoint, binding.target_group_id, target_rate)
            website.last_status = "healthy"
            website.last_error = None
        except Exception as exc:
            website.last_status = "unhealthy"
            website.last_error = str(exc)
            db.commit()
            return _log(db, binding, website, source_rates, "failed", f"写回失败:{exc}", old_rate, target_rate)
        db.commit()
        log = _log(db, binding, website, source_rates, "success", "同步成功", old_rate, target_rate)
        _notify_rate_changed(db, website, binding, old_rate, target_rate, source_rates)
        return log

    message = "已计算建议倍率,未写回"
    if not website.enabled or not website.auto_sync_enabled:
        message = "网站未启用自动同步,未写回"
    elif not binding.enabled:
        message = "绑定未启用,未写回"
    return _log(db, binding, website, source_rates, "success", message, old_rate, target_rate)


def sync_affected_bindings(db: Session, upstream_id: int, changes: list[dict[str, Any]]) -> None:
    """Re-sync, with write-back, every enabled binding affected by *changes*.

    Failures are isolated per binding. The exception is logged and the session
    rolled back so the next binding starts from a clean transaction.
    """
    for binding in get_affected_bindings(db, changes, upstream_id):
        # Capture the id first: after a rollback the instance is expired.
        binding_id = binding.id
        try:
            sync_binding(db, binding, write=True)
        except Exception as exc:
            logger.exception("website sync failed for binding %s: %s", binding_id, exc)
            db.rollback()