Files
SmartUp/backend/app/services/website_sync.py
T
2026-05-29 17:51:12 +08:00

426 lines
17 KiB
Python

from __future__ import annotations
import json
import logging
from decimal import Decimal
from typing import Any
from sqlalchemy.orm import Session
from app.models.snapshot import UpstreamRateSnapshot
from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, calculate_target_rate, decimal_string
from app.services import webhook_service
logger = logging.getLogger(__name__)
def binding_sources(binding: WebsiteGroupBinding) -> list[dict[str, Any]]:
    """Decode the source-group descriptors stored on a binding.

    ``binding.source_groups_json`` is parsed as JSON; a missing or empty
    value is treated as an empty list.

    Args:
        binding: Object exposing a ``source_groups_json`` attribute.

    Returns:
        The dict entries of the decoded list, in their original order.
        An empty list is returned when the field is empty, is not valid
        JSON, or does not decode to a list.
    """
    try:
        data = json.loads(binding.source_groups_json or "[]")
    except (TypeError, ValueError, RecursionError):
        # Malformed or non-text payload: behave as "no sources configured".
        return []
    if not isinstance(data, list):
        return []
    # Callers invoke ``.get`` on every entry, so anything that is not a JSON
    # object would crash them; drop such entries here instead.
    return [entry for entry in data if isinstance(entry, dict)]
def latest_rate_map(db: Session, upstream_id: int) -> dict[str, Any]:
    """Return the ``groups`` mapping of the newest rate snapshot of an upstream.

    The newest snapshot is the row with the greatest ``captured_at`` among
    those whose ``upstream_id`` matches.

    Args:
        db: SQLAlchemy session used for the lookup.
        upstream_id: Upstream whose snapshots are searched.

    Returns:
        The value stored under ``"groups"`` in the snapshot JSON when it is a
        dict; otherwise (no snapshot row, missing/empty key, or a non-dict
        value) an empty dict.
    """
    snapshots_for_upstream = db.query(UpstreamRateSnapshot).filter(
        UpstreamRateSnapshot.upstream_id == upstream_id
    )
    newest = snapshots_for_upstream.order_by(
        UpstreamRateSnapshot.captured_at.desc()
    ).first()
    if not newest:
        return {}

    payload = json.loads(newest.snapshot_json or "{}")
    group_section = payload.get("groups") or {}
    if isinstance(group_section, dict):
        return group_section
    return {}
def get_affected_bindings(db: Session, changes: list[dict[str, Any]], upstream_id: int) -> list[WebsiteGroupBinding]:
    """Find enabled bindings that reference any changed group of an upstream.

    Args:
        db: SQLAlchemy session used to load the enabled bindings.
        changes: Change records; only their ``group_id`` values are used, and
            records whose ``group_id`` is ``None`` are ignored.
        upstream_id: Upstream the changes belong to.

    Returns:
        Every enabled binding with at least one source whose ``upstream_id``
        equals ``upstream_id`` and whose ``group_id`` (compared as a string)
        is among the changed group ids. Each binding appears at most once.
    """
    changed_ids = {
        str(change.get("group_id"))
        for change in changes
        if change.get("group_id") is not None
    }
    if not changed_ids:
        return []

    def _references_change(source: Any) -> bool:
        """Tell whether one source descriptor points at a changed group."""
        if not isinstance(source, dict):
            return False
        try:
            source_upstream = int(source.get("upstream_id") or 0)
        except (TypeError, ValueError):
            # A malformed descriptor must not abort the lookup for every
            # other binding; it simply cannot match.
            return False
        return source_upstream == upstream_id and str(source.get("group_id")) in changed_ids

    # ``== True`` is intentional: it builds a SQL expression, not a Python bool.
    bindings = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.enabled == True).all()  # noqa: E712
    return [
        binding
        for binding in bindings
        if any(_references_change(source) for source in binding_sources(binding))
    ]
def _client_for(website: Website) -> Sub2ApiWebsiteClient:
    """Build a ``Sub2ApiWebsiteClient`` from a website row's connection fields.

    ``auth_config_json`` is decoded as JSON (empty/missing means ``{}``) and
    ``timeout_seconds`` is converted to ``float`` before being handed over.
    """
    connection_settings = {
        "base_url": website.base_url,
        "api_prefix": website.api_prefix,
        "auth_type": website.auth_type,
        "auth_config": json.loads(website.auth_config_json or "{}"),
        "timeout": float(website.timeout_seconds),
    }
    return Sub2ApiWebsiteClient(**connection_settings)
def _log(
    db: Session,
    binding: WebsiteGroupBinding,
    website: Website,
    source_rates: list[dict[str, Any]],
    status: str,
    message: str,
    old_rate: Any = None,
    new_rate: Any = None,
) -> WebsiteSyncLog:
    """Persist one ``WebsiteSyncLog`` row describing a binding sync attempt.

    Args:
        db: Session the row is added to; it is committed here.
        binding: Supplies the target group, algorithm and percent fields.
        website: Supplies ``website_id``.
        source_rates: Serialized as JSON into ``source_rates_json``.
        status: Stored verbatim.
        message: Stored verbatim.
        old_rate: Formatted with ``decimal_string`` unless ``None`` or ``""``.
        new_rate: Formatted with ``decimal_string`` unless ``None`` or ``""``.

    Returns:
        The committed and refreshed log row.
    """

    def _rate_text(value: Any) -> Any:
        # ``None`` and the empty string both mean "no rate recorded".
        if value in (None, ""):
            return None
        return decimal_string(value)

    entry = WebsiteSyncLog(
        website_id=website.id,
        binding_id=binding.id,
        target_group_id=binding.target_group_id,
        target_group_name=binding.target_group_name,
        algorithm=binding.algorithm,
        percent=binding.percent,
        source_rates_json=json.dumps(source_rates, ensure_ascii=False),
        old_rate=_rate_text(old_rate),
        new_rate=_rate_text(new_rate),
        status=status,
        message=message,
    )
    db.add(entry)
    db.commit()
    db.refresh(entry)
    return entry
def sync_binding(db: Session, binding: WebsiteGroupBinding, write: bool = True) -> WebsiteSyncLog:
    """Recompute the target rate of one binding and optionally push it downstream.

    The function collects the latest snapshot rate of every source group of
    the binding, derives the target rate with ``calculate_target_rate``, and,
    when writing is allowed, updates the group on the website. Every outcome
    is recorded through ``_log``.

    Args:
        db: SQLAlchemy session; committed several times along the way.
        binding: Binding to synchronise.
        write: When false, the rate is only computed and logged.

    Returns:
        The ``WebsiteSyncLog`` row written for this run. Its status is
        ``"failed"`` when the calculation or the remote write failed and
        ``"success"`` otherwise (including the "not written" cases).

    Raises:
        WebsiteError: If the website referenced by the binding does not exist.
    """
    website = db.query(Website).filter(Website.id == binding.website_id).first()
    if not website:
        raise WebsiteError("网站不存在")

    sources = binding_sources(binding)

    # Batch prefetch: gather every referenced upstream id and resolve the
    # upstream rows with a single query.
    upstream_ids = {int(s.get("upstream_id") or 0) for s in sources if s.get("upstream_id")}
    upstreams = {}
    if upstream_ids:
        rows = db.query(Upstream).filter(Upstream.id.in_(upstream_ids)).all()
        upstreams = {u.id: u for u in rows}

    # Per-call snapshot cache so each upstream's snapshot is loaded at most
    # once during this sync; it is discarded when the function returns.
    snapshot_cache: dict[int, dict[str, Any]] = {}

    def _get_snap(upstream_id: int) -> dict[str, Any]:
        if upstream_id not in snapshot_cache:
            snapshot_cache[upstream_id] = latest_rate_map(db, upstream_id)
        return snapshot_cache[upstream_id]

    source_rates: list[dict[str, Any]] = []
    for source in sources:
        upstream_id = int(source.get("upstream_id") or 0)
        group_id = str(source.get("group_id") or "")
        snapshot_groups = _get_snap(upstream_id)
        group = snapshot_groups.get(group_id) if group_id else None
        group_is_dict = isinstance(group, dict)
        upstream = upstreams.get(upstream_id)
        source_rates.append({
            "upstream_id": upstream_id,
            "upstream_name": source.get("upstream_name") or (upstream.name if upstream else ""),
            "group_id": group_id,
            "group_name": source.get("group_name") or (group.get("group_name", "") if group_is_dict else ""),
            "rate": group.get("rate") if group_is_dict else None,
        })

    try:
        target_rate = calculate_target_rate(
            [item.get("rate") for item in source_rates], binding.percent, binding.algorithm
        )
    except Exception as exc:
        return _log(db, binding, website, source_rates, "failed", str(exc))

    old_rate = None
    if write and website.enabled and website.auto_sync_enabled and binding.enabled:
        try:
            with _client_for(website) as client:
                remote_groups = client.get_groups(website.groups_endpoint)
                # NOTE(review): ids are compared without type coercion; if the
                # website returns ints while the binding stores strings the
                # old rate is never found — confirm against the client.
                target = next(
                    (item for item in remote_groups if item.get("id") == binding.target_group_id),
                    None,
                )
                old_rate = target.get("rate_multiplier") if target else None
                client.update_group_rate(website.group_update_endpoint, binding.target_group_id, target_rate)
            website.last_status = "healthy"
            website.last_error = None
        except Exception as exc:
            website.last_status = "unhealthy"
            website.last_error = str(exc)
            db.commit()
            return _log(db, binding, website, source_rates, "failed", f"写回失败:{exc}", old_rate, target_rate)

        db.commit()
        log = _log(db, binding, website, source_rates, "success", "同步成功", old_rate, target_rate)

        old_rate_str = decimal_string(old_rate) if old_rate not in (None, "") else None
        new_rate_str = decimal_string(target_rate)
        if old_rate_str != new_rate_str:
            binding_id = binding.id
            try:
                webhook_service.send_website_rate_changed(
                    db,
                    website.id,
                    website.name,
                    website.base_url,
                    binding_id,
                    binding.target_group_id,
                    binding.target_group_name,
                    old_rate_str,
                    new_rate_str,
                    source_rates,
                )
            except Exception as exc:
                # The rate is already written and logged at this point; a
                # notification failure must not turn the sync into an error.
                logger.warning("website_rate_changed webhook failed for binding %s: %s", binding_id, exc)
        return log

    # Nothing was written: explain why in the log message.
    message = "已计算建议倍率,未写回"
    if not website.enabled or not website.auto_sync_enabled:
        message = "网站未启用自动同步,未写回"
    elif not binding.enabled:
        message = "绑定未启用,未写回"
    return _log(db, binding, website, source_rates, "success", message, old_rate, target_rate)
def _snapshot_group_rate(group: dict) -> float:
"""从快照分组数据中提取倍率(兼容多个字段名)。"""
raw = group.get("rate") or group.get("default_rate") or group.get("rate_multiplier") or 1
try:
return float(raw)
except (TypeError, ValueError):
return 1.0
def build_rate_priority_map(db: Session, upstream_ids: set[int]) -> dict[str, int]:
    """Rank upstream groups by rate and map each one to a priority.

    Keys have the form ``"{upstream_id}:{group_id}"``; the composite key keeps
    groups with the same id on different upstreams from overwriting each
    other. Rates come from the latest snapshot of every given upstream and
    are sorted ascending: the lowest rate gets priority 1, the next distinct
    rate priority 2, and so on. Groups with equal rates share a priority.
    Snapshot entries that are not dicts are ignored.
    """
    rates_by_key: dict[str, float] = {}
    for uid in upstream_ids:
        for gid, group in latest_rate_map(db, uid).items():
            if isinstance(group, dict):
                rates_by_key[f"{uid}:{gid}"] = _snapshot_group_rate(group)

    priority_of_rate: dict[float, int] = {}
    for position, rate in enumerate(sorted(set(rates_by_key.values())), start=1):
        priority_of_rate[rate] = position

    return {key: priority_of_rate[rate] for key, rate in rates_by_key.items()}
def _priority_result(row, new_priority: int | None, status: str, message: str) -> dict:
"""构建统一的优先级同步结果 dict。"""
return {
"account_id": row.imported_account_id,
"group_id": row.group_id,
"upstream_id": row.upstream_id,
"old_priority": None,
"new_priority": new_priority,
"status": status,
"message": message,
}
def _write_priority_sync_log_with_map(
    db: Session, wid: int, upstream_name: str,
    results: list[dict], priority_map: dict[str, int],
) -> None:
    """Write a ``priority_sync`` audit log with account details and the map used.

    ``source_rates_json`` is stored as a list of dicts: the first element is
    ``{"_meta": "priority_map", "data": {...}}`` (a copy of ``priority_map``)
    and the per-account result dicts follow. The original author notes that
    the list-of-dicts shape is kept for compatibility with
    ``WebsiteSyncLogResponse.source_rates``.

    Args:
        db: Session the log row is added to; it is committed here.
        wid: Website id the log belongs to.
        upstream_name: Name shown in the log message.
        results: Per-account result dicts; each must contain ``"status"``.
        priority_map: Priority mapping snapshot to embed in the log.
    """
    log_results: list[dict] = [{"_meta": "priority_map", "data": dict(priority_map)}, *results]

    success = sum(1 for r in results if r["status"] == "success")
    failed = sum(1 for r in results if r["status"] == "failed")
    skipped = sum(1 for r in results if r["status"] == "skipped")

    parts = []
    if success:
        parts.append(f"{success} 个更新成功")
    if failed:
        parts.append(f"{failed} 个失败")
    if skipped:
        parts.append(f"{skipped} 个跳过")
    # Separate the counters; joining with an empty string ran them together
    # (e.g. "3 个更新成功1 个失败").
    summary = ",".join(parts)

    log = WebsiteSyncLog(
        website_id=wid,
        binding_id=None,
        target_group_id="",
        target_group_name="",
        algorithm="priority_sync",
        percent=0,
        # default=str stringifies any value json cannot encode natively.
        source_rates_json=json.dumps(log_results, ensure_ascii=False, default=str),
        old_rate=None,
        new_rate=None,
        # A single failed account marks the whole run as failed.
        status="failed" if failed else "success",
        message=f"优先级同步(上游={upstream_name}):{summary} / 共 {len(results)}",
    )
    db.add(log)
    db.commit()
def _try_send_priority_webhook(
db: Session, wid: int, website_name: str,
upstream_id: int, upstream_name: str,
updates: list[dict],
) -> None:
"""发送 account_priority_changed webhook,失败不抛异常。"""
if not updates:
return
# 如果没传入名称,尝试从 DB 查
resolved_name = website_name
if not resolved_name:
row = db.query(Website.name).filter(Website.id == wid).first()
if row:
resolved_name = row[0]
else:
resolved_name = f"网站#{wid}"
try:
webhook_service.send_account_priority_changed(
db,
website_id=wid,
website_name=resolved_name,
upstream_id=upstream_id,
upstream_name=upstream_name,
updates=updates,
)
except Exception as exc:
logger.warning("account_priority_changed webhook failed for website %s: %s", wid, exc)
def _record_uniform_priority_results(
    db: Session, wid: int, website_name: str,
    upstream_id: int, upstream_name: str,
    rows: list, status: str, message: str,
) -> list[dict]:
    """Give every row the same outcome, then write the audit log and notify the webhook."""
    site_results = [_priority_result(row, None, status, message) for row in rows]
    _write_priority_sync_log_with_map(db, wid, upstream_name, site_results, {})
    _try_send_priority_webhook(db, wid, website_name, upstream_id, upstream_name, site_results)
    return site_results


def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[dict]:
    """Refresh the priority of imported downstream accounts after an upstream rate change.

    All imported, non-orphaned keys of the upstream are loaded and grouped by
    the website they were imported into. For each website the priority map is
    rebuilt from every upstream that has imported keys on that website (so the
    ranking is global per website) and pushed through ``update_account``.
    Each website gets one ``WebsiteSyncLog`` audit entry and one webhook
    notification.

    Args:
        db: SQLAlchemy session used for all lookups and audit logging.
        upstream_id: Upstream whose rates changed.

    Returns:
        The per-account result dicts (see ``_priority_result``) of all
        processed websites; an empty list when the upstream has no imported
        keys.
    """
    from app.services.website_client import Sub2ApiWebsiteClient as Client

    key_rows = (
        db.query(UpstreamGeneratedKey)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream_id,
            UpstreamGeneratedKey.imported_website_id.isnot(None),
            UpstreamGeneratedKey.imported_account_id.isnot(None),
            UpstreamGeneratedKey.status != "orphaned",
        )
        .all()
    )
    if not key_rows:
        return []

    upstream_name = db.query(Upstream.name).filter(Upstream.id == upstream_id).scalar() or f"#{upstream_id}"

    # Group the keys by the website they were imported into.
    website_groups: dict[int, list[UpstreamGeneratedKey]] = {}
    for row in key_rows:
        website_groups.setdefault(row.imported_website_id, []).append(row)

    all_results: list[dict] = []
    for wid, rows in website_groups.items():
        website = db.query(Website).filter(Website.id == wid).first()
        if not website or not website.enabled:
            logger.info("skip account priority sync: website %s not found or disabled", wid)
            all_results.extend(_record_uniform_priority_results(
                db, wid, website.name if website else "",
                upstream_id, upstream_name, rows, "failed", "网站不可用",
            ))
            continue

        # Load every imported key of this website (across upstreams) so the
        # priority ranking is global for the website.
        all_website_keys = (
            db.query(UpstreamGeneratedKey)
            .filter(
                UpstreamGeneratedKey.imported_website_id == wid,
                UpstreamGeneratedKey.imported_account_id.isnot(None),
                UpstreamGeneratedKey.status != "orphaned",
            )
            .all()
        )
        all_upstream_ids = {k.upstream_id for k in all_website_keys}

        try:
            priority_map = build_rate_priority_map(db, all_upstream_ids)
        except Exception as exc:
            logger.warning("build_rate_priority_map failed for website %s: %s", wid, exc)
            all_results.extend(_record_uniform_priority_results(
                db, wid, website.name, upstream_id, upstream_name,
                all_website_keys, "failed", f"构建优先级映射失败: {exc}",
            ))
            continue

        if not priority_map:
            logger.info("skip account priority sync for website %s: empty priority map", wid)
            all_results.extend(_record_uniform_priority_results(
                db, wid, website.name, upstream_id, upstream_name,
                all_website_keys, "skipped", "无上游倍率数据",
            ))
            continue

        site_results: list[dict] = []
        # Rows that already have a result; used to avoid reporting the same
        # account twice if the client fails after the loop has run.
        reported: set[int] = set()
        try:
            with Client(
                base_url=website.base_url,
                api_prefix=website.api_prefix,
                auth_type=website.auth_type,
                auth_config=json.loads(website.auth_config_json or "{}"),
                timeout=float(website.timeout_seconds),
            ) as client:
                for row in all_website_keys:
                    account_id = row.imported_account_id
                    if not account_id:
                        continue
                    new_priority = priority_map.get(f"{row.upstream_id}:{row.group_id}")
                    if new_priority is None:
                        site_results.append(
                            _priority_result(row, None, "skipped", "无倍率数据,跳过")
                        )
                        reported.add(id(row))
                        continue
                    try:
                        client.update_account(account_id, {"priority": new_priority})
                    except Exception as exc:
                        # One failing account must not stop the others.
                        logger.warning(
                            "failed to update priority for account %s (website=%s): %s",
                            account_id, wid, exc,
                        )
                        site_results.append(
                            _priority_result(row, new_priority, "failed", str(exc))
                        )
                    else:
                        logger.info(
                            "updated priority for account %s (website=%s, upstream=%s, group=%s): %s",
                            account_id, wid, row.upstream_id, row.group_id, new_priority,
                        )
                        site_results.append(
                            _priority_result(row, new_priority, "success", f"优先级已更新为 {new_priority}")
                        )
                    reported.add(id(row))
        except Exception as exc:
            logger.warning("failed to connect website %s for account priority sync: %s", wid, exc)
            for row in all_website_keys:
                if id(row) in reported:
                    continue
                site_results.append(
                    _priority_result(row, None, "failed", f"连接网站失败: {exc}")
                )

        all_results.extend(site_results)
        _write_priority_sync_log_with_map(db, wid, upstream_name, site_results, priority_map)
        _try_send_priority_webhook(db, wid, website.name, upstream_id, upstream_name, site_results)

    return all_results
def sync_affected_bindings(db: Session, upstream_id: int, changes: list[dict[str, Any]]) -> None:
    """Sync every enabled binding affected by the given upstream changes.

    Bindings are processed independently: a failure in one is logged with its
    traceback and does not prevent the remaining bindings from being synced.

    Args:
        db: SQLAlchemy session shared by all binding syncs.
        upstream_id: Upstream the changes belong to.
        changes: Change records passed to ``get_affected_bindings``.
    """
    for binding in get_affected_bindings(db, changes, upstream_id):
        # Read the id up front: after a failure the attribute may be expired
        # and reloading it on a broken session would raise inside the handler.
        binding_id = None
        try:
            binding_id = binding.id
            sync_binding(db, binding, write=True)
        except Exception as exc:
            logger.exception("website sync failed for binding %s: %s", binding_id, exc)
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; reset it so the next binding can still be synced.
            try:
                db.rollback()
            except Exception:
                logger.exception("rollback failed after website sync error for binding %s", binding_id)