"""Upstream management CRUD + test + check-now + snapshots.""" from __future__ import annotations import json import logging import re from datetime import datetime, timezone from typing import List logger = logging.getLogger(__name__) from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from app.database import get_db from app.models.admin_user import AdminUser from app.models.upstream import Upstream from app.models.upstream_key import UpstreamGeneratedKey from app.models.snapshot import UpstreamRateSnapshot from app.schemas.upstream import ( GenerateKeysByGroupsRequest, GenerateKeysByGroupsResponse, GeneratedUpstreamKeyResponse, UpstreamCreate, UpstreamUpdate, UpstreamResponse, SnapshotResponse, TestResult ) from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot, mask_secret from app.services.snapshot_service import diff_snapshots from app.services import scheduler as sched_svc from app.services import webhook_service from app.services import website_sync from app.utils.auth import get_current_user router = APIRouter(prefix="/api/upstreams", tags=["upstreams"]) MASK = "***" SECRET_KEYS = {"password", "token", "key", "secret"} def _group_id(group: dict) -> str: for key in ("id", "group_id", "groupId"): value = group.get(key) if value is not None: return str(value) return str(group.get("name") or group.get("group_name") or "") def _group_name(group: dict, gid: str) -> str: return str(group.get("name") or group.get("group_name") or gid) def _key_response(row: UpstreamGeneratedKey, include_value: bool = False) -> GeneratedUpstreamKeyResponse: return GeneratedUpstreamKeyResponse( id=row.id, upstream_id=row.upstream_id, group_id=row.group_id, group_name=row.group_name, key_id=row.key_id, key_name=row.key_name, key_value=row.key_value if include_value else None, masked_key=row.masked_key, status=row.status, error=row.error, imported_website_id=row.imported_website_id, imported_account_id=row.imported_account_id, imported_at=row.imported_at, created_at=row.created_at, updated_at=row.updated_at, ) def _mask_auth_config(auth_type: str, cfg: dict) -> dict: masked = {} for k, v in cfg.items(): if k.lower() in SECRET_KEYS and v: masked[k] = MASK else: masked[k] = v return masked def _to_response(u: Upstream) -> UpstreamResponse: cfg = json.loads(u.auth_config_json or "{}") return UpstreamResponse( id=u.id, name=u.name, base_url=u.base_url, api_prefix=u.api_prefix, auth_type=u.auth_type, auth_config_masked=_mask_auth_config(u.auth_type, cfg), rate_endpoint=u.rate_endpoint, groups_endpoint=u.groups_endpoint, enabled=u.enabled, check_interval_seconds=u.check_interval_seconds, timeout_seconds=u.timeout_seconds, last_status=u.last_status, last_checked_at=u.last_checked_at, last_error=u.last_error, balance=u.balance, balance_updated_at=u.balance_updated_at, balance_endpoint=u.balance_endpoint or "", balance_response_path=u.balance_response_path or "", balance_divisor=u.balance_divisor or 1.0, created_at=u.created_at, updated_at=u.updated_at, ) @router.get("", response_model=List[UpstreamResponse]) def list_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)): return [_to_response(u) for u in db.query(Upstream).order_by(Upstream.id).all()] @router.get("/{uid}/generated-keys", response_model=List[GeneratedUpstreamKeyResponse]) def list_generated_keys(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): if not db.query(Upstream.id).filter(Upstream.id == uid).first(): raise HTTPException(404, "upstream not found") rows = ( db.query(UpstreamGeneratedKey) .filter(UpstreamGeneratedKey.upstream_id == uid) .order_by(UpstreamGeneratedKey.id.desc()) .limit(200) .all() ) return [_key_response(row) for row in rows] _generate_key_lock = __import__("threading").Lock() def _ensure_group_key( db: Session, client: UpstreamClient, upstream: Upstream, group: dict[str, Any], prefix: str, body: GenerateKeysByGroupsRequest, ) -> GeneratedUpstreamKeyResponse: """确保一个上游分组有一个 SmartUp 前缀 Key:存在则 upsert,不存在则创建。""" gid = _group_id(group) gname = _group_name(group, gid) # 使用稳定的 upstream_id + group_id 而非可变名称,避免因改名产生重复 # 可读 Key 名:{prefix}-{upstream.id}-{安全的分组名}-{group_id} safe_group_name = re.sub(r"[^a-zA-Z0-9\u4e00-\u9fff_-]", "", gname)[:30] if gname else gid stable_name = f"{prefix}-{upstream.id}-{safe_group_name}-{gid}" with _generate_key_lock: try: # 1. 先查本地是否已有该分组的托管 Key(兼容迁移前无 managed_prefix 的记录) row = ( db.query(UpstreamGeneratedKey) .filter( UpstreamGeneratedKey.upstream_id == upstream.id, UpstreamGeneratedKey.group_id == gid, (UpstreamGeneratedKey.managed_prefix == prefix) | ((UpstreamGeneratedKey.managed_prefix.is_(None)) & UpstreamGeneratedKey.key_name.like(f"{prefix}-%")), ) .first() ) if row and row.key_id: # 本地已有记录,检查远端是否仍存在 try: existing = client.find_smartup_group_key(gid, stable_name, prefix) except Exception: existing = None if existing: row.status = "exists" row.updated_at = datetime.now(timezone.utc) db.commit() return _key_response(row, include_value=False) # 远端不存在,需要重新创建 row.status = "replaced" # 2. 查远端是否有同名 Key(防止并发时另一个请求已创建) existing = client.find_smartup_group_key(gid, stable_name, prefix) if existing: key_id = str(existing.get("id") or "") masked = existing.get("masked_key") or existing.get("key") or "" if row: row.key_id = key_id or row.key_id row.masked_key = masked or row.masked_key row.status = "exists" row.updated_at = datetime.now(timezone.utc) else: row = UpstreamGeneratedKey( upstream_id=upstream.id, group_id=gid, group_name=gname, key_id=key_id or None, key_name=stable_name, key_value="", masked_key=masked, raw_json=json.dumps(existing, ensure_ascii=False), managed_prefix=prefix, status="exists", ) db.add(row) db.commit() db.refresh(row) return _key_response(row, include_value=False) # 3. 远端不存在,创建新 Key created = client.create_api_key( stable_name, gid, quota=body.quota, expires_in_days=body.expires_in_days, rate_limit_5h=body.rate_limit_5h, rate_limit_1d=body.rate_limit_1d, rate_limit_7d=body.rate_limit_7d, endpoint=body.endpoint, ) if row: # 复用旧行 row.key_id = created.get("id") or None row.key_name = stable_name row.key_value = created["key"] row.masked_key = created.get("masked_key") or mask_secret(created["key"]) row.raw_json = json.dumps(created.get("raw") or {}, ensure_ascii=False) row.managed_prefix = prefix row.status = "created" row.error = None else: row = UpstreamGeneratedKey( upstream_id=upstream.id, group_id=gid, group_name=gname, key_id=created.get("id") or None, key_name=stable_name, key_value=created["key"], masked_key=created.get("masked_key") or mask_secret(created["key"]), raw_json=json.dumps(created.get("raw") or {}, ensure_ascii=False), managed_prefix=prefix, status="created", ) db.add(row) db.commit() db.refresh(row) return _key_response(row, include_value=True) except Exception as exc: logger.exception("ensure group key failed for upstream=%s group=%s", upstream.id, gid) return GeneratedUpstreamKeyResponse( upstream_id=upstream.id, group_id=gid, group_name=gname, key_name=stable_name, status="failed", error=str(exc), ) @router.post("/{uid}/keys/generate-by-groups", response_model=GenerateKeysByGroupsResponse) def generate_keys_by_groups( uid: int, body: GenerateKeysByGroupsRequest, db: Session = Depends(get_db), _=Depends(get_current_user), ): u = db.query(Upstream).filter(Upstream.id == uid).first() if not u: raise HTTPException(404, "upstream not found") if u.api_prefix.strip("/") != "api/v1": raise HTTPException(400, "首版仅支持 Sub2API 上游(API Prefix 应为 /api/v1)") auth_config = json.loads(u.auth_config_json or "{}") selected = set(body.group_ids) prefix = body.name_prefix results: list[GeneratedUpstreamKeyResponse] = [] with UpstreamClient( base_url=u.base_url, api_prefix=u.api_prefix, auth_type=u.auth_type, auth_config=auth_config, timeout=float(u.timeout_seconds), ) as client: try: client.login() groups = client.get_available_groups(u.groups_endpoint) except Exception as exc: raise HTTPException(502, str(exc)) for group in groups: gid = _group_id(group) if not gid or (selected and gid not in selected): continue result = _ensure_group_key(db, client, u, group, prefix, body) results.append(result) created = len([item for item in results if item.status == "created"]) existed = len([item for item in results if item.status == "exists"]) total = len(results) msg_parts = [] if created: msg_parts.append(f"新创建 {created}") if existed: msg_parts.append(f"已存在 {existed}") msg = "、".join(msg_parts) + f" / 共 {total} 个分组" if msg_parts else f"共处理 {total} 个分组" return GenerateKeysByGroupsResponse( success=total > 0 and all(item.status != "failed" for item in results), message=msg, items=results, ) @router.post("", response_model=UpstreamResponse, status_code=201) def create_upstream( body: UpstreamCreate, db: Session = Depends(get_db), _=Depends(get_current_user), ): u = Upstream( name=body.name, base_url=body.base_url.rstrip("/"), api_prefix=body.api_prefix, auth_type=body.auth_type, auth_config_json=json.dumps(body.auth_config, ensure_ascii=False), rate_endpoint=body.rate_endpoint, groups_endpoint=body.groups_endpoint, enabled=body.enabled, check_interval_seconds=body.check_interval_seconds, timeout_seconds=body.timeout_seconds, balance_endpoint=body.balance_endpoint, balance_response_path=body.balance_response_path, balance_divisor=body.balance_divisor, ) db.add(u) db.commit() db.refresh(u) sched_svc.refresh_upstream(u.id, u.check_interval_seconds, u.enabled) return _to_response(u) @router.get("/{uid}", response_model=UpstreamResponse) def get_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): u = db.query(Upstream).filter(Upstream.id == uid).first() if not u: raise HTTPException(404, "upstream not found") return _to_response(u) @router.put("/{uid}", response_model=UpstreamResponse) def update_upstream( uid: int, body: UpstreamUpdate, db: Session = Depends(get_db), _=Depends(get_current_user), ): u = db.query(Upstream).filter(Upstream.id == uid).first() if not u: raise HTTPException(404, "upstream not found") data = body.model_dump(exclude_none=True) if "auth_config" in data: # merge with existing config to avoid overwriting masked fields existing = json.loads(u.auth_config_json or "{}") incoming = data.pop("auth_config") for k, v in incoming.items(): if v != MASK: # don't overwrite with mask placeholder existing[k] = v u.auth_config_json = json.dumps(existing, ensure_ascii=False) if "base_url" in data: data["base_url"] = data["base_url"].rstrip("/") for k, v in data.items(): setattr(u, k, v) # Reset failure counter on any update — the user may have fixed the issue u.consecutive_failures = 0 u.updated_at = datetime.now(timezone.utc) db.commit() db.refresh(u) sched_svc.refresh_upstream(u.id, u.check_interval_seconds, u.enabled) return _to_response(u) @router.delete("/{uid}", status_code=204) def delete_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): u = db.query(Upstream).filter(Upstream.id == uid).first() if not u: raise HTTPException(404, "upstream not found") sched_svc.refresh_upstream(uid, 0, False) # remove job db.delete(u) db.commit() @router.post("/{uid}/test", response_model=TestResult) def test_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): u = db.query(Upstream).filter(Upstream.id == uid).first() if not u: raise HTTPException(404, "upstream not found") auth_config = json.loads(u.auth_config_json or "{}") with UpstreamClient( base_url=u.base_url, api_prefix=u.api_prefix, auth_type=u.auth_type, auth_config=auth_config, timeout=float(u.timeout_seconds), ) as client: try: client.login() groups = client.get_available_groups(u.groups_endpoint) # Also try balance if configured if u.balance_endpoint and u.balance_response_path: try: raw_balance = client.get_balance(u.balance_endpoint, u.balance_response_path) if raw_balance is not None: divisor = u.balance_divisor or 1.0 u.balance = raw_balance / divisor u.balance_updated_at = datetime.now(timezone.utc) if raw_balance is not None else None except Exception as exc: logger.warning("upstream %s balance fetch failed during test: %s", u.name, exc) u.last_status = "healthy" u.last_error = None u.last_checked_at = datetime.now(timezone.utc) u.consecutive_failures = 0 db.commit() return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组") except Exception as exc: u.last_status = "unhealthy" u.last_error = str(exc) u.last_checked_at = datetime.now(timezone.utc) u.consecutive_failures = (u.consecutive_failures or 0) + 1 db.commit() return TestResult(success=False, message="连接失败", detail=str(exc)) @router.post("/{uid}/check-now", response_model=TestResult) def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): u = db.query(Upstream).filter(Upstream.id == uid).first() if not u: raise HTTPException(404, "upstream not found") auth_config = json.loads(u.auth_config_json or "{}") with UpstreamClient( base_url=u.base_url, api_prefix=u.api_prefix, auth_type=u.auth_type, auth_config=auth_config, timeout=float(u.timeout_seconds), ) as client: try: client.login() groups = client.get_available_groups(u.groups_endpoint) raw_rates = client.get_group_rates(u.rate_endpoint) snapshot = build_snapshot(u.id, u.base_url, u.api_prefix, groups, raw_rates) # Also try balance if configured if u.balance_endpoint and u.balance_response_path: try: raw_balance = client.get_balance(u.balance_endpoint, u.balance_response_path) if raw_balance is not None: divisor = u.balance_divisor or 1.0 u.balance = raw_balance / divisor u.balance_updated_at = datetime.now(timezone.utc) if raw_balance is not None else None except Exception as exc: logger.warning("upstream %s balance fetch failed during check-now: %s", u.name, exc) except Exception as exc: u.consecutive_failures = (u.consecutive_failures or 0) + 1 u.last_error = str(exc) u.last_checked_at = datetime.now(timezone.utc) db.commit() return TestResult(success=False, message="检测失败", detail=str(exc)) prev_row = ( db.query(UpstreamRateSnapshot) .filter(UpstreamRateSnapshot.upstream_id == uid) .order_by(UpstreamRateSnapshot.captured_at.desc()) .first() ) previous = json.loads(prev_row.snapshot_json) if prev_row else None changes = diff_snapshots(previous, snapshot) new_row = UpstreamRateSnapshot( upstream_id=uid, snapshot_json=json.dumps(snapshot, ensure_ascii=False), captured_at=datetime.now(timezone.utc), ) db.add(new_row) was_unhealthy = u.last_status == "unhealthy" u.last_status = "healthy" u.last_checked_at = datetime.now(timezone.utc) u.last_error = None u.consecutive_failures = 0 db.commit() if was_unhealthy: webhook_service.send_status_event(db, u.id, u.name, u.base_url, "upstream_recovered") if changes: webhook_service.send_rate_changed(db, u.id, u.name, u.base_url, changes) website_sync.sync_affected_bindings(db, u.id, changes) # 同步 SmartUp Key 状态(使用实际快照入库时间,与定时任务一致) from app.services.scheduler import _sync_upstream_keys as _synck _synck(uid, snapshot, new_row.captured_at) msg = f"检测成功,{len(groups)} 个分组" if changes: msg += f",发现 {len(changes)} 处倍率变化" elif previous is None: msg += ",初始化快照完成" else: msg += ",无变化" return TestResult(success=True, message=msg) @router.get("/{uid}/snapshots/latest", response_model=SnapshotResponse) def latest_snapshot(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): row = ( db.query(UpstreamRateSnapshot) .filter(UpstreamRateSnapshot.upstream_id == uid) .order_by(UpstreamRateSnapshot.captured_at.desc()) .first() ) if not row: raise HTTPException(404, "no snapshot found") return SnapshotResponse( id=row.id, upstream_id=row.upstream_id, snapshot=json.loads(row.snapshot_json), captured_at=row.captured_at, ) from fastapi import Query as QueryParam @router.get("/{uid}/snapshots", response_model=List[SnapshotResponse]) def list_snapshots( uid: int, limit: int = QueryParam(20, le=100), offset: int = QueryParam(0), db: Session = Depends(get_db), _=Depends(get_current_user), ): """Return paginated snapshot history with diff vs previous snapshot embedded.""" rows = ( db.query(UpstreamRateSnapshot) .filter(UpstreamRateSnapshot.upstream_id == uid) .order_by(UpstreamRateSnapshot.captured_at.desc()) .offset(offset) .limit(limit + 1) # fetch one extra to get the "previous" for diffing .all() ) # We need the snapshot just before each one to compute changes count. # rows are desc order; rows[i+1] is older than rows[i] results = [] for i, row in enumerate(rows[:limit]): snap = json.loads(row.snapshot_json) # try to diff against the next row (which is older) changes_count: int | None = None if i + 1 < len(rows): older = json.loads(rows[i + 1].snapshot_json) from app.services.snapshot_service import diff_snapshots ch = diff_snapshots(older, snap) changes_count = len(ch) groups_count = len(snap.get("groups", {})) # embed lightweight summary into snapshot dict so frontend can display it snap["_groups_count"] = groups_count snap["_changes_count"] = changes_count # None means first ever snapshot results.append(SnapshotResponse( id=row.id, upstream_id=row.upstream_id, snapshot=snap, captured_at=row.captured_at, )) return results