"""Upstream management CRUD + test + check-now + snapshots.""" from __future__ import annotations import json import logging import re from datetime import datetime, timezone from typing import Any, List logger = logging.getLogger(__name__) from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from app.database import get_db from app.models.admin_user import AdminUser from app.models.upstream import Upstream from app.models.upstream_key import UpstreamGeneratedKey from app.models.snapshot import UpstreamRateSnapshot from app.schemas.upstream import ( GenerateKeysByGroupsRequest, GenerateKeysByGroupsResponse, GeneratedUpstreamKeyResponse, UpstreamCreate, UpstreamUpdate, UpstreamResponse, SnapshotResponse, TestResult, UpstreamBatchActionItem, UpstreamBatchActionSummary, UpstreamBatchActionResponse, ) from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot, mask_secret, _extract_key_value from app.services.snapshot_service import diff_snapshots from app.services import scheduler as sched_svc from app.services import webhook_service from app.services import website_sync from app.utils.auth import get_current_user router = APIRouter(prefix="/api/upstreams", tags=["upstreams"]) MASK = "***" SECRET_KEYS = {"password", "token", "key", "secret"} def _group_id(group: dict) -> str: for key in ("id", "group_id", "groupId"): value = group.get(key) if value is not None: return str(value) return str(group.get("name") or group.get("group_name") or "") def _group_name(group: dict, gid: str) -> str: return str(group.get("name") or group.get("group_name") or gid) def _key_response(row: UpstreamGeneratedKey, include_value: bool = False) -> GeneratedUpstreamKeyResponse: return GeneratedUpstreamKeyResponse( id=row.id, upstream_id=row.upstream_id, group_id=row.group_id, group_name=row.group_name, key_id=row.key_id, key_name=row.key_name, key_value=row.key_value if include_value else None, masked_key=row.masked_key, 
status=row.status, error=row.error, imported_website_id=row.imported_website_id, imported_account_id=row.imported_account_id, imported_at=row.imported_at, created_at=row.created_at, updated_at=row.updated_at, has_key_value=bool(row.key_value), ) def _mask_auth_config(auth_type: str, cfg: dict) -> dict: masked = {} for k, v in cfg.items(): if k.lower() in SECRET_KEYS and v: masked[k] = MASK else: masked[k] = v return masked def _extract_plaintext_key(payload: dict[str, Any] | None) -> str: if not isinstance(payload, dict): return "" key_value = _extract_key_value(payload) if not key_value: return "" if "*" in key_value: return "" return key_value def _remote_group_name(rk: dict) -> str: """从远端 Key 响应中安全提取分组名(兼容 group 为 dict 的场景)。""" if rk.get("group_name") and isinstance(rk["group_name"], str): return rk["group_name"] group = rk.get("group") if isinstance(group, dict): return str(group.get("name") or group.get("group_name") or group.get("id") or "") if group: return str(group) return "" def _generate_masked_key(key_name: str, key_id: str) -> str: """为远端无明文 Key 生成脱敏展示名。""" suffix = key_name[-8:] if len(key_name) > 8 else key_name return f"remote:{suffix}:{key_id[-6:] if len(key_id) > 6 else key_id}" def _to_response(u: Upstream) -> UpstreamResponse: cfg = json.loads(u.auth_config_json or "{}") return UpstreamResponse( id=u.id, name=u.name, base_url=u.base_url, api_prefix=u.api_prefix, auth_type=u.auth_type, auth_config_masked=_mask_auth_config(u.auth_type, cfg), rate_endpoint=u.rate_endpoint, groups_endpoint=u.groups_endpoint, enabled=u.enabled, check_interval_seconds=u.check_interval_seconds, timeout_seconds=u.timeout_seconds, last_status=u.last_status, last_checked_at=u.last_checked_at, last_error=u.last_error, balance=u.balance, balance_updated_at=u.balance_updated_at, balance_endpoint=u.balance_endpoint or "", balance_response_path=u.balance_response_path or "", balance_divisor=u.balance_divisor or 1.0, balance_alert_threshold=u.balance_alert_threshold, 
created_at=u.created_at, updated_at=u.updated_at, ) @router.get("", response_model=List[UpstreamResponse]) def list_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)): return [_to_response(u) for u in db.query(Upstream).order_by(Upstream.id).all()] @router.get("/{uid}/generated-keys", response_model=List[GeneratedUpstreamKeyResponse]) def list_generated_keys(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): """实时拉取上游 Key 列表,与本地数据合并后返回。 远端 Key 按 key_name 与本地 upstream_generated_keys 匹配: - 远端存在 + 本地有明文 → has_key_value=true,可导入 - 远端存在 + 本地无明文 → has_key_value=false,不可导入(提示重新生成) - 本地存在但远端不存在 → 不返回(已导入的标记 orphaned 用于审计) """ from datetime import datetime, timezone upstream = db.query(Upstream).filter(Upstream.id == uid).first() if not upstream: raise HTTPException(404, "upstream not found") auth_config = json.loads(upstream.auth_config_json or "{}") now = datetime.now(timezone.utc) # 1. 获取本地记录索引(key_name → row) local_rows = ( db.query(UpstreamGeneratedKey) .filter(UpstreamGeneratedKey.upstream_id == uid) .all() ) local_by_keyname: dict[str, UpstreamGeneratedKey] = {} local_by_keyid: dict[str, UpstreamGeneratedKey] = {} local_by_groupprefix: dict[tuple[str, str], UpstreamGeneratedKey] = {} for row in local_rows: if row.key_name: local_by_keyname[row.key_name] = row if row.key_id: local_by_keyid[row.key_id] = row if row.group_id and row.managed_prefix: local_by_groupprefix[(row.group_id, row.managed_prefix)] = row local_remaining = {row.id for row in local_rows} # 2. 
登录上游,拉取实时 Key 列表 remote_keys_list: list[dict] = [] try: with UpstreamClient( base_url=upstream.base_url, api_prefix=upstream.api_prefix, auth_type=upstream.auth_type, auth_config=auth_config, timeout=float(upstream.timeout_seconds), ) as client: client.login() for prefix in website_sync._fetch_remote_managed_prefixes(db, uid): # list_api_keys 的参数在 Sub2API 中通常是 search 参数 remote_keys_list.extend(client.list_api_keys(search=prefix, status="active")) except Exception as exc: logger.warning("list_generated_keys: upstream %s fetch failed: %s", uid, exc) # 远端不可达时回退到本地数据(已标记 orphaned 的也展示,不隐藏) rows = ( db.query(UpstreamGeneratedKey) .filter(UpstreamGeneratedKey.upstream_id == uid) .order_by(UpstreamGeneratedKey.id.desc()) .limit(200) .all() ) return [_key_response(row) for row in rows] results: list[GeneratedUpstreamKeyResponse] = [] # 按 (group_id, prefix) 分组,每组只保留最新的一条远端 Key from collections import defaultdict group_buckets: dict[tuple[str, str], list[dict]] = defaultdict(list) for rk in remote_keys_list: key_name = rk.get("name") or "" if not key_name: continue key_id = str(rk.get("id") or "") group_id = str(rk.get("group_id") or "") prefix = key_name.split("-")[0] if "-" in key_name else "SmartUp" group_buckets[(group_id, prefix)].append(rk) for (group_id, prefix), bucket in group_buckets.items(): # 选择最新的一条(id 最大) bucket.sort(key=lambda x: int(x.get("id") or 0), reverse=True) rk = bucket[0] key_name = rk.get("name") or "" key_id = str(rk.get("id") or "") group_name = _remote_group_name(rk) # 匹配本地记录 local = local_by_keyname.get(key_name) or local_by_keyid.get(key_id) or local_by_groupprefix.get((group_id, prefix)) if local: local_remaining.discard(local.id) # 远端存在 + 本地有记录 → 同步更新为选中的最新远端 Key key_value = _extract_plaintext_key(rk) local.key_id = key_id local.key_name = key_name local.group_name = group_name if key_value: local.key_value = key_value local.masked_key = mask_secret(key_value) local.raw_json = json.dumps(rk, ensure_ascii=False) local.updated_at = now if 
local.status in ("failed", "import_failed"): local.status = "exists" local.error = None results.append(_key_response(local)) else: # 远端存在但本地无记录 → 检查是否有明文 key_value = _extract_plaintext_key(rk) if key_value: # 有明文 → 新建本地行,并更新索引 local = UpstreamGeneratedKey( upstream_id=uid, group_id=group_id, group_name=group_name, key_id=key_id, key_name=key_name, key_value=key_value, masked_key=mask_secret(key_value), raw_json=json.dumps(rk, ensure_ascii=False), managed_prefix=prefix, status="created", created_at=now, updated_at=now, ) db.add(local) db.flush() db.refresh(local) local_by_keyname[key_name] = local local_by_groupprefix[(group_id, prefix)] = local if key_id: local_by_keyid[key_id] = local results.append(_key_response(local)) else: # 无明文 → 展示为不可导入 results.append(GeneratedUpstreamKeyResponse( id=None, upstream_id=uid, group_id=group_id, group_name=group_name, key_id=key_id, key_name=key_name, key_value=None, masked_key=_generate_masked_key(key_name, key_id), status="remote", error=None, imported_website_id=None, imported_account_id=None, imported_at=None, has_key_value=False, created_at=None, updated_at=None, )) # 3. 
清理:本地有但远端已不存在的记录 for row_id in local_remaining: row = next((r for r in local_rows if r.id == row_id), None) if not row: continue if row.imported_website_id and row.imported_account_id: row.status = "orphaned" row.error = "远端 Key 已不存在" row.updated_at = now logger.info("marked key %s orphaned (not found remotely)", row.id) else: db.delete(row) logger.info("removed key %s (not found remotely)", row.id) db.commit() return results[:200] _generate_key_lock = __import__("threading").Lock() def _ensure_group_key( db: Session, client: UpstreamClient, upstream: Upstream, group: dict[str, Any], prefix: str, body: GenerateKeysByGroupsRequest, ) -> GeneratedUpstreamKeyResponse: """确保一个上游分组有一个 SmartUp 前缀 Key:存在则 upsert,不存在则创建。""" gid = _group_id(group) gname = _group_name(group, gid) # 使用稳定的 upstream_id + group_id 而非可变名称,避免因改名产生重复 # 可读 Key 名:{prefix}-{upstream.id}-{安全的分组名}-{group_id} safe_group_name = re.sub(r"[^a-zA-Z0-9\u4e00-\u9fff_-]", "", gname)[:30] if gname else gid stable_name = f"{prefix}-{upstream.id}-{safe_group_name}-{gid}" with _generate_key_lock: try: # 1. 
先查本地是否已有该分组的托管 Key(兼容迁移前无 managed_prefix 的记录) row = ( db.query(UpstreamGeneratedKey) .filter( UpstreamGeneratedKey.upstream_id == upstream.id, UpstreamGeneratedKey.group_id == gid, (UpstreamGeneratedKey.managed_prefix == prefix) | ((UpstreamGeneratedKey.managed_prefix.is_(None)) & UpstreamGeneratedKey.key_name.like(f"{prefix}-%")), ) .first() ) if row and row.key_id: # 本地已有记录,检查远端是否仍存在 try: existing = client.find_smartup_group_key(gid, stable_name, prefix) except Exception: existing = None if existing: key_id = str(existing.get("id") or "") key_value = _extract_plaintext_key(existing) masked = mask_secret(key_value) if key_value else (existing.get("masked_key") or existing.get("key") or "") row.key_id = key_id or row.key_id if key_value: row.key_value = key_value row.masked_key = masked elif masked: row.masked_key = str(masked) row.raw_json = json.dumps(existing, ensure_ascii=False) row.status = "exists" row.updated_at = datetime.now(timezone.utc) db.add(row) db.commit() db.refresh(row) return _key_response(row, include_value=False) # 远端不存在,需要重新创建 row.status = "replaced" # 2. 
查远端是否有同名 Key(防止并发时另一个请求已创建) existing = client.find_smartup_group_key(gid, stable_name, prefix) if existing: key_id = str(existing.get("id") or "") key_value = _extract_plaintext_key(existing) masked = mask_secret(key_value) if key_value else (existing.get("masked_key") or existing.get("key") or "") if row: row.key_id = key_id or row.key_id if key_value: row.key_value = key_value row.masked_key = masked elif masked: row.masked_key = str(masked) row.raw_json = json.dumps(existing, ensure_ascii=False) row.status = "exists" row.updated_at = datetime.now(timezone.utc) else: row = UpstreamGeneratedKey( upstream_id=upstream.id, group_id=gid, group_name=gname, key_id=key_id or None, key_name=stable_name, key_value=key_value or "", masked_key=masked, raw_json=json.dumps(existing, ensure_ascii=False), managed_prefix=prefix, status="exists", ) db.add(row) db.commit() db.refresh(row) return _key_response(row, include_value=False) # 3. 远端不存在,创建新 Key created = client.create_api_key( stable_name, gid, quota=body.quota, expires_in_days=body.expires_in_days, rate_limit_5h=body.rate_limit_5h, rate_limit_1d=body.rate_limit_1d, rate_limit_7d=body.rate_limit_7d, endpoint=body.endpoint, ) if row: # 复用旧行 row.key_id = created.get("id") or None row.key_name = stable_name row.key_value = created["key"] row.masked_key = created.get("masked_key") or mask_secret(created["key"]) row.raw_json = json.dumps(created.get("raw") or {}, ensure_ascii=False) row.managed_prefix = prefix row.status = "created" row.error = None else: row = UpstreamGeneratedKey( upstream_id=upstream.id, group_id=gid, group_name=gname, key_id=created.get("id") or None, key_name=stable_name, key_value=created["key"], masked_key=created.get("masked_key") or mask_secret(created["key"]), raw_json=json.dumps(created.get("raw") or {}, ensure_ascii=False), managed_prefix=prefix, status="created", ) db.add(row) db.commit() db.refresh(row) return _key_response(row, include_value=True) except Exception as exc: logger.exception("ensure 
group key failed for upstream=%s group=%s", upstream.id, gid) return GeneratedUpstreamKeyResponse( upstream_id=upstream.id, group_id=gid, group_name=gname, key_name=stable_name, status="failed", error=str(exc), ) @router.post("/{uid}/keys/generate-by-groups", response_model=GenerateKeysByGroupsResponse) def generate_keys_by_groups( uid: int, body: GenerateKeysByGroupsRequest, db: Session = Depends(get_db), _=Depends(get_current_user), ): u = db.query(Upstream).filter(Upstream.id == uid).first() if not u: raise HTTPException(404, "upstream not found") if u.api_prefix.strip("/") != "api/v1": raise HTTPException(400, "首版仅支持 Sub2API 上游(API Prefix 应为 /api/v1)") # 生成前先对账,清理远端已删除的旧 Key try: website_sync.reconcile_upstream_keys_full(db, uid) except Exception as exc: logger.warning("generate_keys_by_groups reconcile failed for %s: %s", uid, exc) auth_config = json.loads(u.auth_config_json or "{}") selected = set(body.group_ids) prefix = body.name_prefix results: list[GeneratedUpstreamKeyResponse] = [] with UpstreamClient( base_url=u.base_url, api_prefix=u.api_prefix, auth_type=u.auth_type, auth_config=auth_config, timeout=float(u.timeout_seconds), ) as client: try: client.login() groups = client.get_available_groups(u.groups_endpoint) except Exception as exc: raise HTTPException(502, str(exc)) for group in groups: gid = _group_id(group) if not gid or (selected and gid not in selected): continue result = _ensure_group_key(db, client, u, group, prefix, body) results.append(result) created = len([item for item in results if item.status == "created"]) existed = len([item for item in results if item.status == "exists"]) total = len(results) msg_parts = [] if created: msg_parts.append(f"新创建 {created}") if existed: msg_parts.append(f"已存在 {existed}") msg = "、".join(msg_parts) + f" / 共 {total} 个分组" if msg_parts else f"共处理 {total} 个分组" return GenerateKeysByGroupsResponse( success=total > 0 and all(item.status != "failed" for item in results), message=msg, items=results, ) # ─── 
# ─── Shared core helpers ──────────────────────────────────────────────────────

from collections.abc import Callable

from fastapi import Query as QueryParam


def _get_upstream_or_404(db: Session, uid: int) -> Upstream:
    """Load upstream ``uid`` or raise HTTP 404."""
    u = db.query(Upstream).filter(Upstream.id == uid).first()
    if not u:
        raise HTTPException(404, "upstream not found")
    return u


def _refresh_balance(client: UpstreamClient, u: Upstream, log_msg: str) -> None:
    """Best-effort balance refresh on ``u``; failures are logged with ``log_msg``.

    Runs only when both ``balance_endpoint`` and ``balance_response_path`` are
    set. ``balance`` and ``balance_updated_at`` change together, and only when
    the upstream returned a value, so a missing value never pairs a stale
    balance with a cleared timestamp. Changes are not committed here.
    """
    if not (u.balance_endpoint and u.balance_response_path):
        return
    try:
        raw_balance = client.get_balance(u.balance_endpoint, u.balance_response_path)
        if raw_balance is not None:
            u.balance = raw_balance / (u.balance_divisor or 1.0)
            u.balance_updated_at = datetime.now(timezone.utc)
    except Exception as exc:
        logger.warning(log_msg, u.name, exc)


def _record_failure(db: Session, u: Upstream, exc: Exception, *, mark_unhealthy: bool) -> None:
    """Persist a failed test/sync attempt on ``u`` and commit.

    Rolls back first, so a failure raised by the database itself does not make
    this commit fail too (and partial writes from the failed attempt are
    discarded). Then bumps ``consecutive_failures`` and records the error.
    ``last_status`` is set to ``"unhealthy"`` only when ``mark_unhealthy`` is
    true (connection tests); full syncs leave it untouched.
    """
    db.rollback()
    if mark_unhealthy:
        u.last_status = "unhealthy"
    u.last_error = str(exc)
    u.last_checked_at = datetime.now(timezone.utc)
    u.consecutive_failures = (u.consecutive_failures or 0) + 1
    db.commit()


def _test_upstream_core(db: Session, u: Upstream) -> UpstreamBatchActionItem:
    """Run a connection test on ``u`` (login + group listing + optional balance).

    Marks ``u`` healthy and commits on success. Any connection error propagates
    to the caller. Shared by the single-upstream and batch test endpoints.
    """
    auth_config = json.loads(u.auth_config_json or "{}")
    with UpstreamClient(
        base_url=u.base_url,
        api_prefix=u.api_prefix,
        auth_type=u.auth_type,
        auth_config=auth_config,
        timeout=float(u.timeout_seconds),
    ) as client:
        client.login()
        groups = client.get_available_groups(u.groups_endpoint)
        _refresh_balance(client, u, "test-all: upstream %s balance failed: %s")

    u.last_status = "healthy"
    u.last_error = None
    u.last_checked_at = datetime.now(timezone.utc)
    u.consecutive_failures = 0
    db.commit()
    return UpstreamBatchActionItem(
        upstream_id=u.id,
        upstream_name=u.name,
        status="success",
        message=f"连接成功,获取到 {len(groups)} 个分组",
    )


def _check_now_core(db: Session, u: Upstream) -> tuple[str, bool]:
    """Run a full sync of ``u``.

    Steps: fetch groups and rates, write a snapshot, diff it against the
    previous one, send webhooks, then sync keys, bindings and priorities.

    Returns:
        ``(message, was_changed)``: a human-readable summary and whether any
        rate change was detected.
    """
    uid = u.id
    auth_config = json.loads(u.auth_config_json or "{}")
    with UpstreamClient(
        base_url=u.base_url,
        api_prefix=u.api_prefix,
        auth_type=u.auth_type,
        auth_config=auth_config,
        timeout=float(u.timeout_seconds),
    ) as client:
        client.login()
        groups = client.get_available_groups(u.groups_endpoint)
        raw_rates = client.get_group_rates(u.rate_endpoint)
        snapshot = build_snapshot(uid, u.base_url, u.api_prefix, groups, raw_rates)
        # Balance is optional.
        _refresh_balance(client, u, "check-now: upstream %s balance failed: %s")

    # Write the new snapshot and diff it against the latest stored one.
    prev_row = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == uid)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .first()
    )
    previous = json.loads(prev_row.snapshot_json) if prev_row else None
    changes = diff_snapshots(previous, snapshot)
    new_row = UpstreamRateSnapshot(
        upstream_id=uid,
        snapshot_json=json.dumps(snapshot, ensure_ascii=False),
        captured_at=datetime.now(timezone.utc),
    )
    db.add(new_row)

    was_unhealthy = u.last_status == "unhealthy"
    u.last_status = "healthy"
    u.last_checked_at = datetime.now(timezone.utc)
    u.last_error = None
    u.consecutive_failures = 0
    db.commit()

    if was_unhealthy:
        webhook_service.send_status_event(db, uid, u.name, u.base_url, "upstream_recovered")

    # Sync key status first (marks orphaned keys), then run the priority sync.
    sched_svc._sync_upstream_keys(uid, snapshot, new_row.captured_at)
    if changes:
        webhook_service.send_rate_changed(db, uid, u.name, u.base_url, changes)
        website_sync.sync_affected_bindings(db, uid, changes)
    # NOTE(review): runs on every check, not only when rates changed — confirm
    # (the original layout was ambiguous about this nesting).
    website_sync.sync_account_priorities_for_upstream(db, uid)

    msg = f"检测成功,{len(groups)} 个分组"
    if changes:
        msg += f",发现 {len(changes)} 处倍率变化"
    elif previous is None:
        msg += ",初始化快照完成"
    else:
        msg += ",无变化"
    return msg, bool(changes)


def _run_batch(
    db: Session,
    action: Callable[[Upstream], UpstreamBatchActionItem],
    *,
    failure_message: str,
    mark_unhealthy: bool,
) -> UpstreamBatchActionResponse:
    """Apply ``action`` to every enabled upstream and summarise the outcome.

    Disabled upstreams are reported as ``skipped``. An exception from
    ``action`` is recorded on the upstream via ``_record_failure`` and reported
    as a ``failed`` item with ``failure_message``; the batch continues.
    """
    upstreams = db.query(Upstream).order_by(Upstream.id).all()
    items: list[UpstreamBatchActionItem] = []
    success_count = failed_count = skipped_count = 0

    for u in upstreams:
        if not u.enabled:
            items.append(UpstreamBatchActionItem(
                upstream_id=u.id,
                upstream_name=u.name,
                status="skipped",
                message="上游已停用,跳过",
            ))
            skipped_count += 1
            continue
        try:
            items.append(action(u))
            success_count += 1
        except Exception as exc:
            _record_failure(db, u, exc, mark_unhealthy=mark_unhealthy)
            items.append(UpstreamBatchActionItem(
                upstream_id=u.id,
                upstream_name=u.name,
                status="failed",
                message=failure_message,
                detail=str(exc),
            ))
            failed_count += 1

    msg = f"完成:{success_count} 成功 / {failed_count} 失败 / {skipped_count} 跳过"
    return UpstreamBatchActionResponse(
        success=failed_count == 0,
        message=msg,
        summary=UpstreamBatchActionSummary(
            total=len(upstreams),
            success=success_count,
            failed=failed_count,
            skipped=skipped_count,
        ),
        items=items,
    )


@router.post("/test-all", response_model=UpstreamBatchActionResponse)
def test_all_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Test the connection of every enabled upstream (no snapshot, no webhook)."""
    return _run_batch(
        db,
        lambda u: _test_upstream_core(db, u),
        failure_message="连接失败",
        mark_unhealthy=True,
    )


@router.post("/check-now-all", response_model=UpstreamBatchActionResponse)
def check_now_all_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Fully sync every enabled upstream.

    Steps: fetch rates, write a snapshot, diff, send webhooks, sync keys.
    """

    def _sync_one(u: Upstream) -> UpstreamBatchActionItem:
        detail_msg, _changed = _check_now_core(db, u)
        return UpstreamBatchActionItem(
            upstream_id=u.id,
            upstream_name=u.name,
            status="success",
            message=detail_msg,
        )

    return _run_batch(db, _sync_one, failure_message="同步失败", mark_unhealthy=False)


# ─── CRUD ─────────────────────────────────────────────────────────────────────


@router.post("", response_model=UpstreamResponse, status_code=201)
def create_upstream(
    body: UpstreamCreate,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Create an upstream and register its scheduler job."""
    u = Upstream(
        name=body.name,
        base_url=body.base_url.rstrip("/"),
        api_prefix=body.api_prefix,
        auth_type=body.auth_type,
        auth_config_json=json.dumps(body.auth_config, ensure_ascii=False),
        rate_endpoint=body.rate_endpoint,
        groups_endpoint=body.groups_endpoint,
        enabled=body.enabled,
        check_interval_seconds=body.check_interval_seconds,
        timeout_seconds=body.timeout_seconds,
        balance_endpoint=body.balance_endpoint,
        balance_response_path=body.balance_response_path,
        balance_divisor=body.balance_divisor,
        balance_alert_threshold=body.balance_alert_threshold,
    )
    db.add(u)
    db.commit()
    db.refresh(u)
    sched_svc.refresh_upstream(u.id, u.check_interval_seconds, u.enabled)
    return _to_response(u)


@router.get("/{uid}", response_model=UpstreamResponse)
def get_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return one upstream (secrets masked) or 404."""
    return _to_response(_get_upstream_or_404(db, uid))


@router.put("/{uid}", response_model=UpstreamResponse)
def update_upstream(
    uid: int,
    body: UpstreamUpdate,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Partially update an upstream; fields sent as None are left unchanged.

    Incoming ``auth_config`` is merged into the stored config, and values equal
    to ``MASK`` are ignored so masked secrets echoed back are not overwritten.
    """
    u = _get_upstream_or_404(db, uid)
    data = body.model_dump(exclude_none=True)
    if "auth_config" in data:
        # Merge with the existing config to avoid overwriting masked fields.
        existing = json.loads(u.auth_config_json or "{}")
        incoming = data.pop("auth_config")
        for k, v in incoming.items():
            if v != MASK:  # don't overwrite with the mask placeholder
                existing[k] = v
        u.auth_config_json = json.dumps(existing, ensure_ascii=False)
    if "base_url" in data:
        data["base_url"] = data["base_url"].rstrip("/")
    for k, v in data.items():
        setattr(u, k, v)
    # Reset failure counter on any update — the user may have fixed the issue
    u.consecutive_failures = 0
    u.updated_at = datetime.now(timezone.utc)
    db.commit()
    db.refresh(u)
    sched_svc.refresh_upstream(u.id, u.check_interval_seconds, u.enabled)
    return _to_response(u)


@router.delete("/{uid}", status_code=204)
def delete_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Delete an upstream after removing its scheduler job."""
    u = _get_upstream_or_404(db, uid)
    sched_svc.refresh_upstream(uid, 0, False)  # remove job
    db.delete(u)
    db.commit()


@router.post("/{uid}/test", response_model=TestResult)
def test_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Test one upstream's connection; failures mark it unhealthy."""
    u = _get_upstream_or_404(db, uid)
    try:
        item = _test_upstream_core(db, u)
    except Exception as exc:
        _record_failure(db, u, exc, mark_unhealthy=True)
        return TestResult(success=False, message="连接失败", detail=str(exc))
    return TestResult(success=True, message=item.message)


@router.post("/{uid}/check-now", response_model=TestResult)
def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Run a full sync of one upstream immediately."""
    u = _get_upstream_or_404(db, uid)
    try:
        msg, _changed = _check_now_core(db, u)
    except Exception as exc:
        _record_failure(db, u, exc, mark_unhealthy=False)
        return TestResult(success=False, message="检测失败", detail=str(exc))
    return TestResult(success=True, message=msg)


@router.get("/{uid}/snapshots/latest", response_model=SnapshotResponse)
def latest_snapshot(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return the most recent rate snapshot for an upstream, or 404."""
    row = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == uid)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .first()
    )
    if not row:
        raise HTTPException(404, "no snapshot found")
    return SnapshotResponse(
        id=row.id,
        upstream_id=row.upstream_id,
        snapshot=json.loads(row.snapshot_json),
        captured_at=row.captured_at,
    )


@router.get("/{uid}/snapshots", response_model=List[SnapshotResponse])
def list_snapshots(
    uid: int,
    limit: int = QueryParam(20, ge=0, le=100),
    offset: int = QueryParam(0, ge=0),
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Return paginated snapshot history, newest first, with diff summaries.

    Each snapshot dict gets two extra keys for the frontend:
    ``_groups_count`` (number of entries under ``groups``) and
    ``_changes_count`` (number of diffs vs the next-older snapshot, or ``None``
    for the oldest snapshot). ``limit`` and ``offset`` must be non-negative; a
    negative limit would otherwise reach SQL and could disable the row cap.
    """
    rows = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == uid)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .offset(offset)
        .limit(limit + 1)  # one extra row to diff the page's oldest entry against
        .all()
    )

    # Rows are newest-first, so rows[i + 1] is the snapshot just before rows[i].
    results = []
    for i, row in enumerate(rows[:limit]):
        snap = json.loads(row.snapshot_json)
        changes_count: int | None = None
        if i + 1 < len(rows):
            older = json.loads(rows[i + 1].snapshot_json)
            changes_count = len(diff_snapshots(older, snap))
        snap["_groups_count"] = len(snap.get("groups", {}))
        snap["_changes_count"] = changes_count  # None means first ever snapshot
        results.append(SnapshotResponse(
            id=row.id,
            upstream_id=row.upstream_id,
            snapshot=snap,
            captured_at=row.captured_at,
        ))
    return results