Files
SmartUp/backend/app/services/website_sync.py
T
liumangmang c8ba25f08e feat: live remote key list with auto-upsert and safe group name extraction
- list_generated_keys now fetches live keys from upstream API, merges with
  local DB: remote keys with plaintext values are auto-upserted (by
  group_id+managed_prefix), remote-only keys shown as unimportable
- Use _fetch_remote_managed_prefixes to support custom key prefixes
- Group remote keys by (group_id, prefix), pick latest by key_id
- Extract _remote_group_name helper for safe group name parsing
  (handles dict group field from Meow upstream)
- Frontend excludes orphaned keys from importable list
- Backend import endpoint reconciles upstream before importing
2026-06-01 14:53:40 +08:00

561 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations

import json
import logging
from datetime import datetime
from decimal import Decimal
from typing import Any

from sqlalchemy.orm import Session

from app.models.snapshot import UpstreamRateSnapshot
from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.services import webhook_service
from app.services.upstream_client import UpstreamClient
from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, calculate_target_rate, decimal_string
logger = logging.getLogger(__name__)
def binding_sources(binding: WebsiteGroupBinding) -> list[dict[str, Any]]:
    """Decode the source-group list stored on a binding.

    Args:
        binding: Object exposing ``source_groups_json`` (a JSON string, or a
            falsy value which is treated as an empty list).

    Returns:
        The decoded list restricted to its dict entries. An empty list is
        returned when the JSON cannot be decoded or does not hold a list.
    """
    try:
        data = json.loads(binding.source_groups_json or "[]")
    except (TypeError, ValueError, RecursionError):
        # JSONDecodeError is a ValueError; TypeError covers a non-text value.
        return []
    if not isinstance(data, list):
        return []
    # Callers immediately call ``.get`` on every entry, so drop anything that
    # is not a mapping instead of letting one bad entry crash the whole sync.
    return [item for item in data if isinstance(item, dict)]
def latest_rate_map(db: Session, upstream_id: int) -> dict[str, Any]:
    """Return the ``groups`` mapping of an upstream's most recent rate snapshot.

    Args:
        db: Session used to query ``UpstreamRateSnapshot``.
        upstream_id: Upstream whose snapshots are searched.

    Returns:
        The ``groups`` value of the snapshot with the latest ``captured_at``.
        An empty dict is returned when there is no snapshot, when its JSON
        cannot be decoded, or when the snapshot or its ``groups`` value is not
        a JSON object.
    """
    row = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == upstream_id)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .first()
    )
    if not row:
        return {}
    try:
        snapshot = json.loads(row.snapshot_json or "{}")
    except (TypeError, ValueError) as exc:
        # A corrupt snapshot is treated like a missing one so callers take
        # their "no rate data" path instead of crashing mid-sync.
        logger.warning("latest_rate_map: unreadable snapshot for upstream %s: %s", upstream_id, exc)
        return {}
    if not isinstance(snapshot, dict):
        # Same guard the ``groups`` value already gets below.
        return {}
    groups = snapshot.get("groups") or {}
    return groups if isinstance(groups, dict) else {}
def get_affected_bindings(db: Session, changes: list[dict[str, Any]], upstream_id: int) -> list[WebsiteGroupBinding]:
    """Find enabled bindings that use any group touched by ``changes``.

    Args:
        db: Session used to load the enabled ``WebsiteGroupBinding`` rows.
        changes: Change records; only their ``group_id`` entry is read.
        upstream_id: Upstream the changes belong to.

    Returns:
        Each enabled binding having at least one source whose ``upstream_id``
        equals ``upstream_id`` and whose ``group_id`` (compared as a string)
        is among the changed ids. Empty when no change carries a ``group_id``.
    """
    changed_ids = set()
    for change in changes:
        if change.get("group_id") is not None:
            changed_ids.add(str(change.get("group_id")))
    if not changed_ids:
        return []

    def _touches_change(source) -> bool:
        same_upstream = int(source.get("upstream_id") or 0) == upstream_id
        return same_upstream and str(source.get("group_id")) in changed_ids

    enabled_bindings = (
        db.query(WebsiteGroupBinding)
        .filter(WebsiteGroupBinding.enabled == True)  # noqa: E712 - SQL expression, not a Python comparison
        .all()
    )
    # any() stops at the first matching source, so a binding is listed once.
    return [
        candidate
        for candidate in enabled_bindings
        if any(_touches_change(source) for source in binding_sources(candidate))
    ]
def _client_for(website: Website) -> Sub2ApiWebsiteClient:
    """Build a ``Sub2ApiWebsiteClient`` from a website row's connection settings.

    The auth configuration is decoded from ``auth_config_json`` (a falsy value
    is treated as ``{}``) and the timeout is coerced to ``float``.
    """
    settings = {
        "base_url": website.base_url,
        "api_prefix": website.api_prefix,
        "auth_type": website.auth_type,
        "auth_config": json.loads(website.auth_config_json or "{}"),
        "timeout": float(website.timeout_seconds),
    }
    return Sub2ApiWebsiteClient(**settings)
def _log(
    db: Session,
    binding: WebsiteGroupBinding,
    website: Website,
    source_rates: list[dict[str, Any]],
    status: str,
    message: str,
    old_rate: Any = None,
    new_rate: Any = None,
) -> WebsiteSyncLog:
    """Persist one ``WebsiteSyncLog`` row for a binding sync and return it.

    Args:
        db: Session the row is added to; it is committed and the row refreshed.
        binding: Supplies the binding id, target group, algorithm and percent.
        website: Supplies the website id.
        source_rates: Per-source details, stored as JSON.
        status: Outcome label stored on the row.
        message: Human-readable outcome stored on the row.
        old_rate: Previous rate; ``None`` or ``""`` is stored as ``None``.
        new_rate: Computed rate; ``None`` or ``""`` is stored as ``None``.

    Returns:
        The refreshed log row.
    """

    def _rate_text(value: Any):
        # Missing rates stay NULL; everything else goes through decimal_string.
        if value in (None, ""):
            return None
        return decimal_string(value)

    entry = WebsiteSyncLog(
        website_id=website.id,
        binding_id=binding.id,
        target_group_id=binding.target_group_id,
        target_group_name=binding.target_group_name,
        algorithm=binding.algorithm,
        percent=binding.percent,
        source_rates_json=json.dumps(source_rates, ensure_ascii=False),
        old_rate=_rate_text(old_rate),
        new_rate=_rate_text(new_rate),
        status=status,
        message=message,
    )
    db.add(entry)
    db.commit()
    db.refresh(entry)
    return entry
def sync_binding(db: Session, binding: WebsiteGroupBinding, write: bool = True) -> WebsiteSyncLog:
    """Recompute a binding's target rate and optionally write it to its website.

    The rate of every source group is read from the latest snapshot of its
    upstream, combined with ``calculate_target_rate`` and, when writing is
    allowed, pushed to the website's target group.

    Args:
        db: Session used for lookups, status updates and the sync log.
        binding: The binding to synchronise.
        write: When false the rate is only computed and logged.

    Returns:
        The ``WebsiteSyncLog`` row describing the outcome. Calculation and
        write-back errors are reported through a ``"failed"`` log row.

    Raises:
        WebsiteError: If the binding's website row does not exist.
    """
    website = db.query(Website).filter(Website.id == binding.website_id).first()
    if not website:
        raise WebsiteError("网站不存在")
    # Ignore malformed (non-dict) entries; every use below calls ``.get`` on them.
    sources = [s for s in binding_sources(binding) if isinstance(s, dict)]

    # Batch pre-fetch: collect every upstream id and load the rows in one query.
    upstream_ids = {int(s.get("upstream_id") or 0) for s in sources if s.get("upstream_id")}
    upstreams = {}
    if upstream_ids:
        rows = db.query(Upstream).filter(Upstream.id.in_(upstream_ids)).all()
        upstreams = {u.id: u for u in rows}

    # Snapshot cache scoped to this call, so each upstream is queried once.
    _snap_cache: dict[int, dict[str, Any]] = {}

    def _get_snap(upstream_id: int) -> dict[str, Any]:
        if upstream_id not in _snap_cache:
            _snap_cache[upstream_id] = latest_rate_map(db, upstream_id)
        return _snap_cache[upstream_id]

    source_rates: list[dict[str, Any]] = []
    for source in sources:
        upstream_id = int(source.get("upstream_id") or 0)
        group_id = str(source.get("group_id") or "")
        groups = _get_snap(upstream_id)
        group = groups.get(group_id) if group_id else None
        upstream = upstreams.get(upstream_id)
        source_rates.append({
            "upstream_id": upstream_id,
            "upstream_name": source.get("upstream_name") or (upstream.name if upstream else ""),
            "group_id": group_id,
            "group_name": source.get("group_name") or (group.get("group_name", "") if isinstance(group, dict) else ""),
            "rate": group.get("rate") if isinstance(group, dict) else None,
        })

    try:
        target_rate = calculate_target_rate([item.get("rate") for item in source_rates], binding.percent, binding.algorithm)
    except Exception as exc:
        return _log(db, binding, website, source_rates, "failed", str(exc))

    old_rate = None
    if write and website.enabled and website.auto_sync_enabled and binding.enabled:
        try:
            with _client_for(website) as client:
                groups = client.get_groups(website.groups_endpoint)
                target = next((item for item in groups if item.get("id") == binding.target_group_id), None)
                old_rate = target.get("rate_multiplier") if target else None
                client.update_group_rate(website.group_update_endpoint, binding.target_group_id, target_rate)
            website.last_status = "healthy"
            website.last_error = None
        except Exception as exc:
            website.last_status = "unhealthy"
            website.last_error = str(exc)
            db.commit()
            return _log(db, binding, website, source_rates, "failed", f"写回失败:{exc}", old_rate, target_rate)
        db.commit()
        log = _log(db, binding, website, source_rates, "success", "同步成功", old_rate, target_rate)
        old_rate_str = decimal_string(old_rate) if old_rate not in (None, "") else None
        new_rate_str = decimal_string(target_rate)
        if old_rate_str != new_rate_str:
            # The rate is already written and logged at this point; a failing
            # notification must not turn a successful sync into an error.
            try:
                webhook_service.send_website_rate_changed(
                    db,
                    website.id,
                    website.name,
                    website.base_url,
                    binding.id,
                    binding.target_group_id,
                    binding.target_group_name,
                    old_rate_str,
                    new_rate_str,
                    source_rates,
                )
            except Exception as exc:
                logger.warning("website_rate_changed webhook failed for binding %s: %s", binding.id, exc)
        return log

    # Nothing was written: record why, together with the suggested rate.
    message = "已计算建议倍率,未写回"
    if not website.enabled or not website.auto_sync_enabled:
        message = "网站未启用自动同步,未写回"
    elif not binding.enabled:
        message = "绑定未启用,未写回"
    return _log(db, binding, website, source_rates, "success", message, old_rate, target_rate)
def _snapshot_group_rate(group: dict) -> float:
"""从快照分组数据中提取倍率(兼容多个字段名)。"""
raw = group.get("rate") or group.get("default_rate") or group.get("rate_multiplier") or 1
try:
return float(raw)
except (TypeError, ValueError):
return 1.0
def build_rate_priority_map(db: Session, upstream_ids: set[int]) -> dict[str, int]:
    """Rank upstream groups by rate and map ``"{upstream_id}:{group_id}"`` to a priority.

    The composite key keeps groups with the same id on different upstreams
    from overwriting each other. The latest snapshot of every given upstream
    is read, the distinct rates are sorted ascending and numbered from 1, so
    the lowest rate gets priority 1 and groups with equal rates share a
    priority.
    """
    rate_by_key: dict[str, float] = {}
    for uid in upstream_ids:
        snapshot_groups = latest_rate_map(db, uid)
        rate_by_key.update(
            (f"{uid}:{gid}", _snapshot_group_rate(info))
            for gid, info in snapshot_groups.items()
            if isinstance(info, dict)
        )
    ascending_rates = sorted(set(rate_by_key.values()))
    rank_of_rate = {rate: rank for rank, rate in enumerate(ascending_rates, start=1)}
    return {key: rank_of_rate[rate] for key, rate in rate_by_key.items()}
def _priority_result(row, new_priority: int | None, status: str, message: str) -> dict:
"""构建统一的优先级同步结果 dict。"""
return {
"account_id": row.imported_account_id,
"group_id": row.group_id,
"upstream_id": row.upstream_id,
"old_priority": None,
"new_priority": new_priority,
"status": status,
"message": message,
}
def _write_priority_sync_log_with_map(
    db: Session, wid: int, upstream_name: str,
    results: list[dict], priority_map: dict[str, int],
) -> None:
    """Write a ``priority_sync`` log row holding the per-account results and the priority map.

    ``source_rates_json`` is stored as
    ``[{"_meta": "priority_map", "data": {...}}, {"account_id": ..., ...}, ...]``
    so that it stays a list of dicts.

    Args:
        db: Session the log row is added to and committed on.
        wid: Website id the log belongs to.
        upstream_name: Upstream name shown in the log message.
        results: Result dicts; each must carry a ``"status"`` key.
        priority_map: Priority map snapshot stored as the first list entry.
    """
    log_results: list[dict] = [
        {"_meta": "priority_map", "data": dict(priority_map)},
    ]
    log_results.extend(results)

    # Count the three reported outcomes in a single pass.
    counts = {"success": 0, "failed": 0, "skipped": 0}
    for item in results:
        status = item["status"]
        if status in counts:
            counts[status] += 1

    parts = []
    if counts["success"]:
        parts.append(f"{counts['success']} 个更新成功")
    if counts["failed"]:
        parts.append(f"{counts['failed']} 个失败")
    if counts["skipped"]:
        parts.append(f"{counts['skipped']} 个跳过")
    # Separate the fragments; joining on "" ran them together
    # (e.g. "3 个更新成功1 个失败").
    summary = ",".join(parts)

    log = WebsiteSyncLog(
        website_id=wid,
        binding_id=None,
        target_group_id="",
        target_group_name="",
        algorithm="priority_sync",
        percent=0,
        source_rates_json=json.dumps(log_results, ensure_ascii=False, default=str),
        old_rate=None,
        new_rate=None,
        status="failed" if counts["failed"] else "success",
        message=f"优先级同步(上游={upstream_name}):{summary} / 共 {len(results)}",
    )
    db.add(log)
    db.commit()
def _try_send_priority_webhook(
db: Session, wid: int, website_name: str,
upstream_id: int, upstream_name: str,
updates: list[dict],
) -> None:
"""发送 account_priority_changed webhook,失败不抛异常。"""
if not updates:
return
# 如果没传入名称,尝试从 DB 查
resolved_name = website_name
if not resolved_name:
row = db.query(Website.name).filter(Website.id == wid).first()
if row:
resolved_name = row[0]
else:
resolved_name = f"网站#{wid}"
try:
webhook_service.send_account_priority_changed(
db,
website_id=wid,
website_name=resolved_name,
upstream_id=upstream_id,
upstream_name=upstream_name,
updates=updates,
)
except Exception as exc:
logger.warning("account_priority_changed webhook failed for website %s: %s", wid, exc)
def _push_account_priority(client, wid: int, row, priority_map: dict[str, int]) -> dict | None:
    """Push the mapped priority for one imported key and return its result dict.

    Returns ``None`` when the row has no imported account id, a ``"skipped"``
    result when the map has no entry for the row's upstream/group pair, and a
    ``"success"`` or ``"failed"`` result after calling ``client.update_account``.
    """
    account_id = row.imported_account_id
    if not account_id:
        return None
    new_priority = priority_map.get(f"{row.upstream_id}:{row.group_id}")
    if new_priority is None:
        return _priority_result(row, None, "skipped", "无倍率数据,跳过")
    try:
        client.update_account(account_id, {"priority": new_priority})
    except Exception as exc:
        logger.warning(
            "failed to update priority for account %s (website=%s): %s",
            account_id, wid, exc,
        )
        return _priority_result(row, new_priority, "failed", str(exc))
    logger.info(
        "updated priority for account %s (website=%s, upstream=%s, group=%s): %s",
        account_id, wid, row.upstream_id, row.group_id, new_priority,
    )
    return _priority_result(row, new_priority, "success", f"优先级已更新为 {new_priority}")


def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[dict]:
    """Refresh the priority of imported downstream accounts after an upstream's rates changed.

    The imported, non-orphaned keys of the upstream are grouped by target
    website. For each website a priority map is rebuilt over all of that
    website's imported keys (across upstreams) and pushed through the client's
    ``update_account``. Every website gets a ``WebsiteSyncLog`` audit row and
    a webhook notification.

    Args:
        db: Session used for all queries and log writes.
        upstream_id: Upstream whose rate change triggered the sync.

    Returns:
        The result dicts of every processed key, across all websites. Empty
        when the upstream has no imported keys.
    """
    from app.services.website_client import Sub2ApiWebsiteClient as Client

    key_rows = (
        db.query(UpstreamGeneratedKey)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream_id,
            UpstreamGeneratedKey.imported_website_id.isnot(None),
            UpstreamGeneratedKey.imported_account_id.isnot(None),
            UpstreamGeneratedKey.status != "orphaned",
        )
        .all()
    )
    if not key_rows:
        return []
    upstream_name = db.query(Upstream.name).filter(Upstream.id == upstream_id).scalar() or f"#{upstream_id}"

    # Group the keys by the website they were imported into.
    website_groups: dict[int, list[UpstreamGeneratedKey]] = {}
    for row in key_rows:
        website_groups.setdefault(row.imported_website_id, []).append(row)

    all_results: list[dict] = []

    def _publish(wid: int, website_name: str, site_results: list[dict], priority_map: dict[str, int]) -> None:
        # Shared tail of every per-website path: collect, audit-log, notify.
        all_results.extend(site_results)
        _write_priority_sync_log_with_map(db, wid, upstream_name, site_results, priority_map)
        _try_send_priority_webhook(db, wid, website_name, upstream_id, upstream_name, site_results)

    for wid, rows in website_groups.items():
        website = db.query(Website).filter(Website.id == wid).first()
        if not website or not website.enabled:
            logger.info("skip account priority sync: website %s not found or disabled", wid)
            _publish(wid, "", [_priority_result(row, None, "failed", "网站不可用") for row in rows], {})
            continue

        # All imported keys of this website (across upstreams) take part, so
        # the resulting priorities are ranked globally for the website.
        all_website_keys = (
            db.query(UpstreamGeneratedKey)
            .filter(
                UpstreamGeneratedKey.imported_website_id == wid,
                UpstreamGeneratedKey.imported_account_id.isnot(None),
                UpstreamGeneratedKey.status != "orphaned",
            )
            .all()
        )
        all_upstream_ids = {k.upstream_id for k in all_website_keys}
        try:
            priority_map = build_rate_priority_map(db, all_upstream_ids)
        except Exception as exc:
            logger.warning("build_rate_priority_map failed for website %s: %s", wid, exc)
            _publish(
                wid, "",
                [_priority_result(row, None, "failed", f"构建优先级映射失败: {exc}") for row in all_website_keys],
                {},
            )
            continue
        if not priority_map:
            logger.info("skip account priority sync for website %s: empty priority map", wid)
            _publish(
                wid, "",
                [_priority_result(row, None, "skipped", "无上游倍率数据") for row in all_website_keys],
                {},
            )
            continue

        site_results: list[dict] = []
        processed = 0  # number of keys whose outcome is already settled
        try:
            with Client(
                base_url=website.base_url,
                api_prefix=website.api_prefix,
                auth_type=website.auth_type,
                auth_config=json.loads(website.auth_config_json or "{}"),
                timeout=float(website.timeout_seconds),
            ) as client:
                for row in all_website_keys:
                    outcome = _push_account_priority(client, wid, row, priority_map)
                    if outcome is not None:
                        site_results.append(outcome)
                    processed += 1
        except Exception as exc:
            logger.warning("failed to connect website %s for account priority sync: %s", wid, exc)
            # Only keys that were never attempted are reported as failed.
            # Outcomes already recorded stay untouched, so an error raised
            # late (e.g. while closing the client) no longer adds a duplicate
            # "failed" entry for every key.
            site_results.extend(
                _priority_result(row, None, "failed", f"连接网站失败: {exc}")
                for row in all_website_keys[processed:]
            )
        _publish(wid, website.name, site_results, priority_map)
    return all_results
def _fetch_remote_managed_prefixes(db: Session, upstream_id: int) -> list[str]:
    """List the distinct, non-NULL ``managed_prefix`` values stored locally for an upstream.

    Falls back to ``["SmartUp"]`` when none are stored, for compatibility
    with older data.
    """
    distinct_rows = (
        db.query(UpstreamGeneratedKey.managed_prefix)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream_id,
            UpstreamGeneratedKey.managed_prefix.isnot(None),
        )
        .distinct()
        .all()
    )
    found = [record[0] for record in distinct_rows]
    if found:
        return found
    return ["SmartUp"]
def _fetch_remote_managed_key_ids(db: Session, client, upstream_id: int) -> set[str]:
    """Collect the ids of active remote keys for every locally known managed prefix.

    For each prefix returned by ``_fetch_remote_managed_prefixes`` the client
    is asked for active keys matching it; the truthy ``id`` values of all
    answers are merged into one set of strings.
    """
    collected: set[str] = set()
    for prefix in _fetch_remote_managed_prefixes(db, upstream_id):
        for remote_key in client.list_api_keys(search=prefix, status="active"):
            if remote_key.get("id"):
                collected.add(str(remote_key["id"]))
    return collected
def reconcile_upstream_keys(
    db: Session,
    upstream_id: int,
    active_group_ids: set[str] | None,
    remote_key_ids: set[str] | None,
    captured_at: datetime,
) -> None:
    """Reconcile the locally cached keys of an upstream with the remote state.

    A key whose group is no longer active, or whose remote key id is gone, is
    marked ``"orphaned"`` when it was already imported into a website and
    deleted otherwise. Keys still present remotely get ``updated_at``
    refreshed. Nothing is committed here.

    Args:
        db: Session used to load and delete ``UpstreamGeneratedKey`` rows.
        upstream_id: Upstream whose keys are reconciled.
        active_group_ids: Ids of the groups that still exist. ``None`` skips
            the group-level clean-up (avoids wrong deletions after a failed
            login).
        remote_key_ids: Ids of the keys that still exist remotely. ``None``
            skips the key-level clean-up (safe when the remote query failed).
        captured_at: Timestamp written to ``updated_at`` of touched rows.

    When both id sets are ``None`` the reconciliation is skipped entirely.
    """
    if active_group_ids is None and remote_key_ids is None:
        return

    # Compare ids as strings on both sides: these checks delete rows, so a
    # str/int mismatch must never make an existing key look like it is gone.
    group_ids = None if active_group_ids is None else {str(g) for g in active_group_ids}
    key_ids = None if remote_key_ids is None else {str(k) for k in remote_key_ids}

    def _retire(row, reason: str, detail: str) -> None:
        # Imported keys are kept but flagged, so the downstream account stays
        # traceable; keys that were never imported are simply removed.
        if row.imported_website_id and row.imported_account_id:
            row.status = "orphaned"
            row.error = reason
            row.updated_at = captured_at
            logger.info("marked key %s orphaned (%s)", row.id, detail)
        else:
            db.delete(row)
            logger.info("removed key %s (%s)", row.id, detail)

    key_rows = (
        db.query(UpstreamGeneratedKey)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream_id,
        )
        .all()
    )
    for row in key_rows:
        if group_ids is not None and (row.group_id is None or str(row.group_id) not in group_ids):
            _retire(row, "来源分组已不存在", "group %s no longer in snapshot" % (row.group_id,))
            continue
        if row.key_id and key_ids is not None and str(row.key_id) not in key_ids:
            _retire(row, "远端 Key 已不存在", "key_id %s gone from remote" % (row.key_id,))
            continue
        if key_ids is not None and row.key_id and str(row.key_id) in key_ids:
            row.updated_at = captured_at
def reconcile_upstream_keys_full(db: Session, upstream_id: int) -> bool:
    """Run a full key reconciliation for one upstream.

    Active group ids come from the latest snapshot rather than a live API
    call, to avoid format differences. The remote key list is fetched by
    logging in to the upstream. Both are handed to ``reconcile_upstream_keys``
    and the session is committed.

    Safety rules:
    - Group-level clean-up only happens when the snapshot yields groups.
    - Key-level clean-up only happens when the remote key list was fetched.
    - If both are unavailable, nothing is deleted or marked.

    Custom ``managed_prefix`` values are supported: the distinct local
    prefixes are each queried remotely.

    Returns:
        ``True`` when the remote key list was fetched; ``False`` otherwise,
        including when the upstream row does not exist.
    """
    from datetime import datetime, timezone

    upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
    if not upstream:
        return False
    auth_config = json.loads(upstream.auth_config_json or "{}")
    now = datetime.now(timezone.utc)

    # ``None`` means "unknown" and makes the reconciler skip that check.
    snapshot_groups = latest_rate_map(db, upstream_id)
    active_group_ids: set[str] | None = set(snapshot_groups.keys()) if snapshot_groups else None

    remote_key_ids: set[str] | None = None
    try:
        with UpstreamClient(
            base_url=upstream.base_url,
            api_prefix=upstream.api_prefix,
            auth_type=upstream.auth_type,
            auth_config=auth_config,
            timeout=float(upstream.timeout_seconds),
        ) as client:
            client.login()
            # Fetch the remote key list (custom managed_prefix values included).
            remote_key_ids = _fetch_remote_managed_key_ids(db, client, upstream_id)
    except Exception as exc:
        logger.warning("reconcile_upstream_keys_full: upstream %s failed: %s", upstream_id, exc)

    # Only data that was actually obtained is passed on.
    reconcile_upstream_keys(db, upstream_id, active_group_ids, remote_key_ids, now)
    db.commit()
    return remote_key_ids is not None
def sync_affected_bindings(db: Session, upstream_id: int, changes: list[dict[str, Any]]) -> None:
    """Sync every binding affected by ``changes``, writing the new rates back.

    An exception raised for one binding is logged with its traceback and does
    not stop the remaining bindings.
    """
    affected = get_affected_bindings(db, changes, upstream_id)
    for affected_binding in affected:
        try:
            sync_binding(db, affected_binding, write=True)
        except Exception as exc:
            logger.exception("website sync failed for binding %s: %s", affected_binding.id, exc)