f17317b13c
P1 - Missing rate data now skips account instead of falling back to 1.0:
In sync_account_priorities_for_upstream(), the rated list now filters
out accounts whose upstream snapshot has no rate entry for their group_id.
If after filtering a competitive bucket has fewer than 2 accounts with
valid rate data, the entire bucket is silently skipped (no update_account
call, no webhook) rather than treating missing rates as 1.0 and
potentially triggering spurious notifications.
P2 - Re-importing an existing account now backfills imported_target_group_id:
In the exists-is-True idempotency branch of import_upstream_keys_as_accounts(),
if the current request supplies a target_group_id for the account's source group
and it differs from what is stored, the field is written back and committed.
This lets operators fix old data by simply re-running the import dialog.
Tests added:
- test_missing_rate_skips_entire_competitive_group: all accounts in the
competitive group lack snapshot rate data → bucket skipped, no update_account call
- test_partial_missing_rate_sufficient_accounts_still_updates: 3 accounts
in the same bucket, 1 missing its rate → the 2 with rates still compete normally
All 27 tests pass.
595 lines
25 KiB
Python
595 lines
25 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
from decimal import Decimal
|
||
from typing import Any
|
||
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.models.snapshot import UpstreamRateSnapshot
|
||
from app.models.upstream import Upstream
|
||
from app.models.upstream_key import UpstreamGeneratedKey
|
||
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
|
||
from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, calculate_target_rate, decimal_string
|
||
from app.services.upstream_client import UpstreamClient
|
||
from app.services import webhook_service
|
||
|
||
# Module-level logger used by the sync / reconcile helpers below.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def binding_sources(binding: WebsiteGroupBinding) -> list[dict[str, Any]]:
    """Decode a binding's ``source_groups_json`` into a list of source dicts.

    Callers read ``upstream_id`` / ``group_id`` (and optional names) from each
    entry with ``.get``, so only JSON objects are returned; any other list
    element is dropped.

    Returns:
        The decoded source entries. An empty list is returned when the column
        is empty, is not valid JSON, or does not decode to a list.
    """
    try:
        data = json.loads(binding.source_groups_json or "[]")
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError; TypeError covers a column
        # value that is not str/bytes.
        return []
    if not isinstance(data, list):
        return []
    # Non-dict entries would make every caller's ``source.get(...)`` raise.
    return [item for item in data if isinstance(item, dict)]
|
||
|
||
|
||
def latest_rate_map(db: Session, upstream_id: int) -> dict[str, Any]:
    """Return the ``groups`` mapping from an upstream's newest rate snapshot.

    The most recent ``UpstreamRateSnapshot`` for ``upstream_id`` (ordered by
    ``captured_at`` descending) is decoded from ``snapshot_json``. Its
    ``groups`` value maps group id (a JSON object key, therefore ``str``) to
    that group's snapshot entry.

    Returns:
        The ``groups`` dict, or ``{}`` when there is no snapshot, or when the
        decoded snapshot or its ``groups`` value is not a JSON object.

    Raises:
        json.JSONDecodeError: If the stored ``snapshot_json`` is not valid JSON.
    """
    row = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == upstream_id)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .first()
    )
    if not row:
        return {}
    snapshot = json.loads(row.snapshot_json or "{}")
    if not isinstance(snapshot, dict):
        # A list/scalar payload would otherwise fail with AttributeError on .get().
        return {}
    groups = snapshot.get("groups") or {}
    return groups if isinstance(groups, dict) else {}
|
||
|
||
|
||
def get_affected_bindings(db: Session, changes: list[dict[str, Any]], upstream_id: int) -> list[WebsiteGroupBinding]:
    """Return enabled bindings that use at least one changed group of ``upstream_id``.

    ``changes`` entries are matched on ``group_id`` (compared as strings);
    entries without a ``group_id`` are ignored. Each binding is returned at most
    once, in query order. Returns ``[]`` immediately when no change carries a
    ``group_id``.
    """
    wanted: set[str] = set()
    for change in changes:
        gid = change.get("group_id")
        if gid is not None:
            wanted.add(str(gid))
    if not wanted:
        return []

    def _touches_change(candidate: WebsiteGroupBinding) -> bool:
        # any() stops at the first matching source, just like an explicit break.
        return any(
            int(src.get("upstream_id") or 0) == upstream_id and str(src.get("group_id")) in wanted
            for src in binding_sources(candidate)
        )

    enabled = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.enabled == True).all()  # noqa: E712 (SQLAlchemy expression)
    return [candidate for candidate in enabled if _touches_change(candidate)]
|
||
|
||
|
||
def _client_for(website: Website) -> Sub2ApiWebsiteClient:
    """Create a ``Sub2ApiWebsiteClient`` from a ``Website`` row's connection settings.

    ``auth_config_json`` is decoded as JSON (empty means ``{}``) and
    ``timeout_seconds`` is passed as a float.
    """
    settings = {
        "base_url": website.base_url,
        "api_prefix": website.api_prefix,
        "auth_type": website.auth_type,
        "auth_config": json.loads(website.auth_config_json or "{}"),
        "timeout": float(website.timeout_seconds),
    }
    return Sub2ApiWebsiteClient(**settings)
|
||
|
||
|
||
def _log(
    db: Session,
    binding: WebsiteGroupBinding,
    website: Website,
    source_rates: list[dict[str, Any]],
    status: str,
    message: str,
    old_rate: Any = None,
    new_rate: Any = None,
) -> WebsiteSyncLog:
    """Record one binding sync attempt as a ``WebsiteSyncLog`` row.

    The binding's target group, algorithm and percent are copied onto the row,
    and ``source_rates`` is stored as JSON. ``old_rate`` / ``new_rate`` are
    normalised with ``decimal_string``; ``None`` or ``""`` are stored as
    ``None``. The row is committed, refreshed and returned.
    """

    def _as_rate_text(value: Any) -> Any:
        if value in (None, ""):
            return None
        return decimal_string(value)

    entry = WebsiteSyncLog(
        website_id=website.id,
        binding_id=binding.id,
        target_group_id=binding.target_group_id,
        target_group_name=binding.target_group_name,
        algorithm=binding.algorithm,
        percent=binding.percent,
        source_rates_json=json.dumps(source_rates, ensure_ascii=False),
        old_rate=_as_rate_text(old_rate),
        new_rate=_as_rate_text(new_rate),
        status=status,
        message=message,
    )
    db.add(entry)
    db.commit()
    db.refresh(entry)
    return entry
|
||
|
||
|
||
def sync_binding(db: Session, binding: WebsiteGroupBinding, write: bool = True) -> WebsiteSyncLog:
    """Recompute a binding's target rate from its source groups and optionally write it back.

    Steps:
      1. Resolve each source ``(upstream_id, group_id)`` against the latest rate
         snapshot of its upstream (each upstream's snapshot is read at most once
         per call).
      2. Pass the collected rates to ``calculate_target_rate`` with the binding's
         ``percent`` and ``algorithm``. A calculation error is logged as ``failed``.
      3. When ``write`` is true and the website, its auto-sync flag and the
         binding are all enabled: read the current remote rate of the target
         group, push the new rate, update the website's health fields, and send a
         ``website_rate_changed`` webhook if the rate text changed. The webhook
         is best-effort.

    Args:
        db: SQLAlchemy session; this function commits it.
        binding: The binding to synchronise.
        write: ``False`` only computes and logs the suggested rate.

    Returns:
        The ``WebsiteSyncLog`` row recorded for this attempt.

    Raises:
        WebsiteError: If the binding's website does not exist.
    """
    website = db.query(Website).filter(Website.id == binding.website_id).first()
    if not website:
        raise WebsiteError("网站不存在")
    sources = binding_sources(binding)

    # Batch lookup: load every referenced Upstream in one query so that sources
    # without a stored upstream_name can fall back to the upstream's name.
    upstream_ids = {int(s.get("upstream_id") or 0) for s in sources if s.get("upstream_id")}
    upstreams = {}
    if upstream_ids:
        rows = db.query(Upstream).filter(Upstream.id.in_(upstream_ids)).all()
        upstreams = {u.id: u for u in rows}

    # Snapshot cache scoped to this call; several sources may share an upstream.
    _snap_cache: dict[int, dict[str, Any]] = {}

    def _get_snap(upstream_id: int) -> dict[str, Any]:
        if upstream_id not in _snap_cache:
            _snap_cache[upstream_id] = latest_rate_map(db, upstream_id)
        return _snap_cache[upstream_id]

    source_rates: list[dict[str, Any]] = []
    for source in sources:
        upstream_id = int(source.get("upstream_id") or 0)
        group_id = str(source.get("group_id") or "")
        groups = _get_snap(upstream_id)
        group = groups.get(group_id) if group_id else None
        upstream = upstreams.get(upstream_id)
        source_rates.append({
            "upstream_id": upstream_id,
            "upstream_name": source.get("upstream_name") or (upstream.name if upstream else ""),
            "group_id": group_id,
            "group_name": source.get("group_name") or (group.get("group_name", "") if isinstance(group, dict) else ""),
            # None when the latest snapshot has no dict entry for this group.
            "rate": group.get("rate") if isinstance(group, dict) else None,
        })
    try:
        target_rate = calculate_target_rate([item.get("rate") for item in source_rates], binding.percent, binding.algorithm)
    except Exception as exc:
        return _log(db, binding, website, source_rates, "failed", str(exc))

    old_rate = None
    if write and website.enabled and website.auto_sync_enabled and binding.enabled:
        # Compare group ids as strings: the remote API and the local column may
        # use different types (e.g. int vs str). A type mismatch would never find
        # the target, leave old_rate as None and fire a webhook on every sync.
        wanted_id = str(binding.target_group_id)
        try:
            with _client_for(website) as client:
                groups = client.get_groups(website.groups_endpoint)
                target = next((item for item in groups if str(item.get("id")) == wanted_id), None)
                old_rate = target.get("rate_multiplier") if target else None
                client.update_group_rate(website.group_update_endpoint, binding.target_group_id, target_rate)
            website.last_status = "healthy"
            website.last_error = None
        except Exception as exc:
            website.last_status = "unhealthy"
            website.last_error = str(exc)
            db.commit()
            return _log(db, binding, website, source_rates, "failed", f"写回失败:{exc}", old_rate, target_rate)
        db.commit()
        log = _log(db, binding, website, source_rates, "success", "同步成功", old_rate, target_rate)
        old_rate_str = decimal_string(old_rate) if old_rate not in (None, "") else None
        new_rate_str = decimal_string(target_rate)
        if old_rate_str != new_rate_str:
            try:
                webhook_service.send_website_rate_changed(
                    db,
                    website.id,
                    website.name,
                    website.base_url,
                    binding.id,
                    binding.target_group_id,
                    binding.target_group_name,
                    old_rate_str,
                    new_rate_str,
                    source_rates,
                )
            except Exception as exc:
                # The rate is already written and logged; a notification failure
                # must not turn a successful sync into an error for the caller.
                logger.warning("website_rate_changed webhook failed for binding %s: %s", binding.id, exc)
        return log

    message = "已计算建议倍率,未写回"
    if not website.enabled or not website.auto_sync_enabled:
        message = "网站未启用自动同步,未写回"
    elif not binding.enabled:
        message = "绑定未启用,未写回"
    return _log(db, binding, website, source_rates, "success", message, old_rate, target_rate)
|
||
|
||
|
||
def _snapshot_group_rate(group: dict) -> float:
|
||
"""从快照分组数据中提取倍率(兼容多个字段名)。"""
|
||
raw = group.get("rate") or group.get("default_rate") or group.get("rate_multiplier") or 1
|
||
try:
|
||
return float(raw)
|
||
except (TypeError, ValueError):
|
||
return 1.0
|
||
|
||
|
||
def build_rate_priority_map(db: Session, upstream_ids: set[int]) -> dict[str, int]:
    """Map ``"{upstream_id}:{group_id}"`` to a dense priority rank ordered by group rate.

    For each upstream in ``upstream_ids`` the latest snapshot is read, and every
    dict-shaped group entry contributes its rate (via ``_snapshot_group_rate``).
    The composite key keeps groups with the same id on different upstreams from
    overwriting each other.

    Ranks are dense and ascending: the lowest rate gets priority 1, the next
    distinct rate gets priority 2, and so on. Equal rates share a priority.
    """
    rate_by_key = {
        f"{uid}:{gid}": _snapshot_group_rate(entry)
        for uid in upstream_ids
        for gid, entry in latest_rate_map(db, uid).items()
        if isinstance(entry, dict)
    }
    rank_of: dict[float, int] = {}
    for position, rate in enumerate(sorted(set(rate_by_key.values())), start=1):
        rank_of[rate] = position
    return {key: rank_of[rate] for key, rate in rate_by_key.items()}
|
||
|
||
|
||
def _priority_result(row, new_priority: int | None, status: str, message: str) -> dict:
|
||
"""构建统一的优先级同步结果 dict。"""
|
||
return {
|
||
"account_id": row.imported_account_id,
|
||
"group_id": row.group_id,
|
||
"upstream_id": row.upstream_id,
|
||
"old_priority": None,
|
||
"new_priority": new_priority,
|
||
"status": status,
|
||
"message": message,
|
||
}
|
||
|
||
|
||
def _write_priority_sync_log_with_map(
    db: Session, wid: int, upstream_name: str,
    results: list[dict], priority_map: dict[str, int],
) -> None:
    """Persist one ``priority_sync`` log row holding account results and a priority-map snapshot.

    ``source_rates_json`` stores a list whose first element is
    ``{"_meta": "priority_map", "data": {...}}``, followed by the per-account
    result dicts. This keeps it compatible with the
    ``WebsiteSyncLogResponse.source_rates: list[dict]`` type.

    The row status is ``failed`` if any result failed, otherwise ``success``.
    The message summarises the success / failed / skipped counts. The session
    is committed.
    """

    def _count(wanted: str) -> int:
        return sum(1 for r in results if r["status"] == wanted)

    n_success, n_failed, n_skipped = _count("success"), _count("failed"), _count("skipped")
    summary = [
        text
        for amount, text in (
            (n_success, f"{n_success} 个更新成功"),
            (n_failed, f"{n_failed} 个失败"),
            (n_skipped, f"{n_skipped} 个跳过"),
        )
        if amount
    ]
    payload = [{"_meta": "priority_map", "data": dict(priority_map)}, *results]
    db.add(
        WebsiteSyncLog(
            website_id=wid,
            binding_id=None,
            target_group_id="",
            target_group_name="",
            algorithm="priority_sync",
            percent=0,
            source_rates_json=json.dumps(payload, ensure_ascii=False, default=str),
            old_rate=None,
            new_rate=None,
            status="failed" if n_failed else "success",
            message=f"优先级同步(上游={upstream_name}):{'、'.join(summary)} / 共 {len(results)} 个",
        )
    )
    db.commit()
|
||
|
||
|
||
def _try_send_priority_webhook(
    db: Session, wid: int, website_name: str,
    upstream_id: int, upstream_name: str,
    updates: list[dict],
) -> None:
    """Send an ``account_priority_changed`` webhook on a best-effort basis.

    Does nothing when ``updates`` is empty. If ``website_name`` is empty, the
    name is looked up in the ``Website`` table, falling back to ``网站#<wid>``
    when the row is missing.

    Any exception, whether from the name lookup or from the webhook call, is
    logged as a warning and never propagated to the caller.
    """
    if not updates:
        return
    try:
        # The name lookup sits inside the try as well, so a DB error here
        # cannot break the "never raises" contract.
        resolved_name = website_name
        if not resolved_name:
            row = db.query(Website.name).filter(Website.id == wid).first()
            resolved_name = row[0] if row else f"网站#{wid}"
        webhook_service.send_account_priority_changed(
            db,
            website_id=wid,
            website_name=resolved_name,
            upstream_id=upstream_id,
            upstream_name=upstream_name,
            updates=updates,
        )
    except Exception as exc:
        logger.warning("account_priority_changed webhook failed for website %s: %s", wid, exc)
|
||
|
||
|
||
def _build_snapshot_rate_lookup(db: Session, upstream_ids: set[int]) -> dict[str, float]:
    """Return ``"{upstream_id}:{group_id}" -> rate`` from each upstream's latest snapshot.

    Only dict-shaped group entries are included. A group absent from the
    snapshot has no key at all, which is how callers detect missing rate data.
    """
    lookup: dict[str, float] = {}
    for uid in upstream_ids:
        for gid, entry in latest_rate_map(db, uid).items():
            if isinstance(entry, dict):
                lookup[f"{uid}:{gid}"] = _snapshot_group_rate(entry)
    return lookup


def _assign_bucket_priorities(
    buckets: dict[str, list[UpstreamGeneratedKey]],
    rate_lookup: dict[str, float],
    website_id: int,
) -> dict[str, int]:
    """Rank the accounts inside each competition bucket by ascending rate.

    Accounts whose ``upstream_id:group_id`` is missing from ``rate_lookup`` are
    excluded. A bucket left with fewer than two rated accounts is skipped
    entirely. Within a bucket the lowest rate gets priority 1, and equal rates
    share a priority.

    Returns:
        ``imported_account_id -> new priority`` for every ranked account.
    """
    assignment: dict[str, int] = {}
    for comp_key, comp_rows in buckets.items():
        rated = [
            (row, rate_lookup[f"{row.upstream_id}:{row.group_id}"])
            for row in comp_rows
            if f"{row.upstream_id}:{row.group_id}" in rate_lookup
        ]
        if len(rated) < 2:
            # Fewer than two comparable accounts: no competition to resolve.
            logger.info(
                "skip competitive bucket %s for website %s: only %d account(s) have rate data",
                comp_key, website_id, len(rated),
            )
            continue
        # Lower rate -> smaller priority number -> preferred.
        rate_to_prio = {rate: idx + 1 for idx, rate in enumerate(sorted({r for _, r in rated}))}
        for row, rate in rated:
            assignment[row.imported_account_id] = rate_to_prio[rate]
    return assignment


def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[dict]:
    """Recompute the priority of imported downstream accounts after an upstream's rates change.

    For every website holding accounts imported from ``upstream_id`` (not
    orphaned, with an imported account id), all of that website's imported
    keys, across every upstream, are grouped into competition buckets:

    - Bucket key: ``imported_target_group_id``; older rows where it is empty
      fall back to ``group_id``.
    - Only buckets with more than one account are considered.
    - Accounts whose ``upstream_id:group_id`` has no entry in the latest
      snapshot are excluded. A bucket left with fewer than two rated accounts
      is skipped.
    - Within a bucket, a lower rate gives a smaller priority number starting at
      1, and equal rates share a priority.

    A website is skipped silently (no client connection, no log row, no
    webhook) if it is missing or disabled, has no competitive bucket, or has no
    bucket with enough rate data. Otherwise ``update_account`` is called for
    each ranked account, one ``priority_sync`` log row is written, and an
    ``account_priority_changed`` webhook is attempted.

    Returns:
        The per-account result dicts (see ``_priority_result``) for all websites.
    """
    from collections import defaultdict

    key_rows = (
        db.query(UpstreamGeneratedKey)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream_id,
            UpstreamGeneratedKey.imported_website_id.isnot(None),
            UpstreamGeneratedKey.imported_account_id.isnot(None),
            UpstreamGeneratedKey.status != "orphaned",
        )
        .all()
    )
    if not key_rows:
        return []

    upstream_name = db.query(Upstream.name).filter(Upstream.id == upstream_id).scalar() or f"#{upstream_id}"

    # Distinct website ids in first-seen order.
    website_ids = list(dict.fromkeys(row.imported_website_id for row in key_rows))

    all_results: list[dict] = []
    for wid in website_ids:
        website = db.query(Website).filter(Website.id == wid).first()
        if not website or not website.enabled:
            # No log or webhook: accounts on an unavailable website cannot be updated.
            logger.info("skip account priority sync: website %s not found or disabled", wid)
            continue

        # Every imported key on this website (across upstreams) may compete.
        all_website_keys = (
            db.query(UpstreamGeneratedKey)
            .filter(
                UpstreamGeneratedKey.imported_website_id == wid,
                UpstreamGeneratedKey.imported_account_id.isnot(None),
                UpstreamGeneratedKey.status != "orphaned",
            )
            .all()
        )

        buckets: dict[str, list[UpstreamGeneratedKey]] = defaultdict(list)
        for row in all_website_keys:
            buckets[row.imported_target_group_id or row.group_id].append(row)
        competitive_buckets = {k: v for k, v in buckets.items() if len(v) > 1}
        if not competitive_buckets:
            logger.info(
                "skip account priority sync for website %s: no competitive groups (all single-account)",
                wid,
            )
            continue

        try:
            raw_rate_map = _build_snapshot_rate_lookup(db, {k.upstream_id for k in all_website_keys})
        except Exception as exc:
            logger.warning("build rate map failed for website %s: %s", wid, exc)
            raw_rate_map = {}

        priority_assignment = _assign_bucket_priorities(competitive_buckets, raw_rate_map, wid)
        if not priority_assignment:
            # Every competitive bucket lacked rate data: nothing to update, so do
            # not open a client, write a log or send a webhook.
            logger.info("skip account priority sync for website %s: no bucket has enough rate data", wid)
            continue

        # Only accounts that actually received a priority are updated and reported.
        target_rows = [
            row
            for comp_rows in competitive_buckets.values()
            for row in comp_rows
            if row.imported_account_id in priority_assignment
        ]

        site_results: list[dict] = []
        try:
            with _client_for(website) as client:
                for row in target_rows:
                    account_id = row.imported_account_id
                    new_priority = priority_assignment[account_id]
                    try:
                        client.update_account(account_id, {"priority": new_priority})
                        logger.info(
                            "updated priority for account %s (website=%s, upstream=%s, group=%s"
                            ", comp_group=%s): %s",
                            account_id, wid, row.upstream_id, row.group_id,
                            row.imported_target_group_id or row.group_id, new_priority,
                        )
                        site_results.append(
                            _priority_result(row, new_priority, "success", f"优先级已更新为 {new_priority}")
                        )
                    except Exception as exc:
                        logger.warning(
                            "failed to update priority for account %s (website=%s): %s",
                            account_id, wid, exc,
                        )
                        site_results.append(_priority_result(row, new_priority, "failed", str(exc)))
        except Exception as exc:
            logger.warning("failed to connect website %s for account priority sync: %s", wid, exc)
            # Accounts already attempted keep their own result; only the rest are
            # reported as connection failures, so no account is listed twice.
            pending = target_rows[len(site_results):]
            site_results.extend(
                _priority_result(row, None, "failed", f"连接网站失败: {exc}") for row in pending
            )

        notify_updates = [r for r in site_results if r["status"] in ("success", "failed")]
        if notify_updates:
            # Reference snapshot for the log, covering accounts in competitive buckets.
            priority_map_snapshot = {
                f"{row.upstream_id}:{row.group_id}": priority_assignment.get(row.imported_account_id)
                for comp_rows in competitive_buckets.values()
                for row in comp_rows
            }
            _write_priority_sync_log_with_map(db, wid, upstream_name, site_results, priority_map_snapshot)
            _try_send_priority_webhook(db, wid, website.name, upstream_id, upstream_name, notify_updates)

        all_results.extend(site_results)

    return all_results
|
||
|
||
|
||
|
||
def _fetch_remote_managed_prefixes(db: Session, upstream_id: int) -> list[str]:
    """Return the distinct non-NULL ``managed_prefix`` values stored locally for an upstream.

    Despite the name, this only queries the local database. When no prefix is
    recorded it falls back to ``["SmartUp"]`` for compatibility with older data.
    """
    query = (
        db.query(UpstreamGeneratedKey.managed_prefix)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream_id,
            UpstreamGeneratedKey.managed_prefix.isnot(None),
        )
        .distinct()
    )
    found = [record[0] for record in query.all()]
    return found or ["SmartUp"]
|
||
|
||
|
||
def _fetch_remote_managed_key_ids(db: Session, client, upstream_id: int) -> set[str]:
    """Collect the ids of active remote API keys matching any managed prefix of an upstream.

    For each prefix returned by ``_fetch_remote_managed_prefixes``, the upstream
    is searched with ``client.list_api_keys(search=prefix, status="active")``.
    The ids are converted to ``str`` and merged; keys with a falsy ``id`` are
    ignored. Client errors propagate to the caller.
    """
    return {
        str(remote_key["id"])
        for prefix in _fetch_remote_managed_prefixes(db, upstream_id)
        for remote_key in client.list_api_keys(search=prefix, status="active")
        if remote_key.get("id")
    }
|
||
|
||
|
||
def reconcile_upstream_keys(
    db: Session,
    upstream_id: int,
    active_group_ids: set[str] | None,
    remote_key_ids: set[str] | None,
    captured_at: datetime,
) -> None:
    """Reconcile locally cached upstream keys with the upstream's current state.

    Args:
        db: SQLAlchemy session. Rows are modified or deleted but NOT committed.
        upstream_id: Upstream whose ``UpstreamGeneratedKey`` rows are checked.
        active_group_ids: Group ids present upstream, or ``None`` to skip
            group-level cleanup (e.g. when the groups could not be determined).
        remote_key_ids: Active remote key ids, or ``None`` to skip key-level
            cleanup (e.g. when the remote query failed).
        captured_at: Timestamp written to ``updated_at`` of touched rows.

    A key whose group or remote key has disappeared is marked ``orphaned`` if it
    has been imported (``imported_website_id`` and ``imported_account_id`` are
    set), and deleted otherwise. Keys still present remotely get ``updated_at``
    refreshed. When both id sets are ``None`` nothing is done.

    Note: ``datetime`` is not imported at module level. The annotation only
    resolves lazily because of ``from __future__ import annotations``.
    """
    if active_group_ids is None and remote_key_ids is None:
        return

    # Compare ids as strings on both sides. Snapshot group ids are JSON object
    # keys (always str), and _fetch_remote_managed_key_ids stringifies remote
    # ids. A non-str column value compared raw would never match, and every row
    # would then be orphaned or deleted.
    active_groups = None if active_group_ids is None else {str(g) for g in active_group_ids}
    remote_keys = None if remote_key_ids is None else {str(k) for k in remote_key_ids}

    key_rows = (
        db.query(UpstreamGeneratedKey)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream_id,
        )
        .all()
    )
    for row in key_rows:
        key_text = None if row.key_id is None else str(row.key_id)

        if active_groups is not None and str(row.group_id) not in active_groups:
            if row.imported_website_id and row.imported_account_id:
                row.status = "orphaned"
                row.error = "来源分组已不存在"
                row.updated_at = captured_at
                logger.info("marked key %s orphaned (group %s no longer in snapshot)", row.id, row.group_id)
            else:
                db.delete(row)
                logger.info("removed key %s (group %s no longer in snapshot)", row.id, row.group_id)
            continue

        if row.key_id and remote_keys is not None and key_text not in remote_keys:
            if row.imported_website_id and row.imported_account_id:
                row.status = "orphaned"
                row.error = "远端 Key 已不存在"
                row.updated_at = captured_at
                logger.info("marked key %s orphaned (key_id %s gone from remote)", row.id, row.key_id)
            else:
                db.delete(row)
                logger.info("removed key %s (key_id %s gone from remote)", row.id, row.key_id)
            continue

        if remote_keys is not None and key_text in remote_keys:
            row.updated_at = captured_at
|
||
|
||
|
||
def reconcile_upstream_keys_full(db: Session, upstream_id: int) -> bool:
    """Run a full key reconciliation for one upstream and commit the result.

    Active group ids come from the latest rate snapshot (the same source the
    scheduler's ``_sync_upstream_keys`` uses) rather than a live API call. The
    remote key ids are fetched by logging into the upstream and searching each
    locally known ``managed_prefix``.

    Safety rules:
    - Group-level cleanup only runs when a non-empty snapshot exists.
    - Key-level cleanup only runs when the remote key list was fetched.
    - If both are unavailable, nothing is deleted or marked.

    Returns:
        ``False`` if the upstream does not exist or the remote key list could
        not be fetched, otherwise ``True``.
    """
    from datetime import datetime, timezone

    upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
    if not upstream:
        return False

    auth_config = json.loads(upstream.auth_config_json or "{}")
    now = datetime.now(timezone.utc)

    snapshot_groups = latest_rate_map(db, upstream_id)
    active_group_ids = set(snapshot_groups.keys()) if snapshot_groups else None

    remote_key_ids: set[str] | None = None
    keys_fetched = False
    try:
        with UpstreamClient(
            base_url=upstream.base_url,
            api_prefix=upstream.api_prefix,
            auth_type=upstream.auth_type,
            auth_config=auth_config,
            timeout=float(upstream.timeout_seconds),
        ) as client:
            client.login()
            remote_key_ids = _fetch_remote_managed_key_ids(db, client, upstream_id)
            keys_fetched = True
    except Exception as exc:
        logger.warning("reconcile_upstream_keys_full: upstream %s failed: %s", upstream_id, exc)

    # Whatever could not be fetched is passed as None so its check is skipped.
    reconcile_upstream_keys(
        db,
        upstream_id,
        active_group_ids,
        remote_key_ids if keys_fetched else None,
        now,
    )
    db.commit()
    return keys_fetched
|
||
|
||
|
||
def sync_affected_bindings(db: Session, upstream_id: int, changes: list[dict[str, Any]]) -> None:
    """Sync and write back every enabled binding affected by ``changes`` on ``upstream_id``.

    Each binding is synced independently. A failure is logged with its
    traceback and does not stop the remaining bindings. Errors raised while
    selecting the bindings propagate.
    """
    affected = get_affected_bindings(db, changes, upstream_id)
    for binding in affected:
        try:
            sync_binding(db, binding, write=True)
        except Exception as exc:
            logger.exception("website sync failed for binding %s: %s", binding.id, exc)
|