"""Upstream management CRUD + test + check-now + snapshots.""" from __future__ import annotations import json import logging from datetime import datetime, timezone from typing import List logger = logging.getLogger(__name__) from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from app.database import get_db from app.models.admin_user import AdminUser from app.models.upstream import Upstream from app.models.snapshot import UpstreamRateSnapshot from app.schemas.upstream import ( UpstreamCreate, UpstreamUpdate, UpstreamResponse, SnapshotResponse, TestResult ) from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot from app.services.snapshot_service import diff_snapshots from app.services import scheduler as sched_svc from app.services import webhook_service from app.services import website_sync from app.utils.auth import get_current_user router = APIRouter(prefix="/api/upstreams", tags=["upstreams"]) MASK = "***" SECRET_KEYS = {"password", "token", "key", "secret"} def _mask_auth_config(auth_type: str, cfg: dict) -> dict: masked = {} for k, v in cfg.items(): if k.lower() in SECRET_KEYS and v: masked[k] = MASK else: masked[k] = v return masked def _to_response(u: Upstream) -> UpstreamResponse: cfg = json.loads(u.auth_config_json or "{}") return UpstreamResponse( id=u.id, name=u.name, base_url=u.base_url, api_prefix=u.api_prefix, auth_type=u.auth_type, auth_config_masked=_mask_auth_config(u.auth_type, cfg), rate_endpoint=u.rate_endpoint, groups_endpoint=u.groups_endpoint, enabled=u.enabled, check_interval_seconds=u.check_interval_seconds, timeout_seconds=u.timeout_seconds, last_status=u.last_status, last_checked_at=u.last_checked_at, last_error=u.last_error, balance=u.balance, balance_updated_at=u.balance_updated_at, balance_endpoint=u.balance_endpoint or "", balance_response_path=u.balance_response_path or "", balance_divisor=u.balance_divisor or 1.0, created_at=u.created_at, updated_at=u.updated_at, ) @router.get("", response_model=List[UpstreamResponse]) def list_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)): return [_to_response(u) for u in db.query(Upstream).order_by(Upstream.id).all()] @router.post("", response_model=UpstreamResponse, status_code=201) def create_upstream( body: UpstreamCreate, db: Session = Depends(get_db), _=Depends(get_current_user), ): u = Upstream( name=body.name, base_url=body.base_url.rstrip("/"), api_prefix=body.api_prefix, auth_type=body.auth_type, auth_config_json=json.dumps(body.auth_config, ensure_ascii=False), rate_endpoint=body.rate_endpoint, groups_endpoint=body.groups_endpoint, enabled=body.enabled, check_interval_seconds=body.check_interval_seconds, timeout_seconds=body.timeout_seconds, balance_endpoint=body.balance_endpoint, balance_response_path=body.balance_response_path, balance_divisor=body.balance_divisor, ) db.add(u) db.commit() db.refresh(u) sched_svc.refresh_upstream(u.id, u.check_interval_seconds, u.enabled) return _to_response(u) @router.get("/{uid}", response_model=UpstreamResponse) def get_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): u = db.query(Upstream).filter(Upstream.id == uid).first() if not u: raise HTTPException(404, "upstream not found") return _to_response(u) @router.put("/{uid}", response_model=UpstreamResponse) def update_upstream( uid: int, body: UpstreamUpdate, db: Session = Depends(get_db), _=Depends(get_current_user), ): u = db.query(Upstream).filter(Upstream.id == uid).first() if not u: raise HTTPException(404, "upstream not found") data = body.model_dump(exclude_none=True) if "auth_config" in data: # merge with existing config to avoid overwriting masked fields existing = json.loads(u.auth_config_json or "{}") incoming = data.pop("auth_config") for k, v in incoming.items(): if v != MASK: # don't overwrite with mask placeholder existing[k] = v u.auth_config_json = json.dumps(existing, ensure_ascii=False) if "base_url" in data: data["base_url"] = data["base_url"].rstrip("/") for k, v in data.items(): setattr(u, k, v) # Reset failure counter on any update — the user may have fixed the issue u.consecutive_failures = 0 u.updated_at = datetime.now(timezone.utc) db.commit() db.refresh(u) sched_svc.refresh_upstream(u.id, u.check_interval_seconds, u.enabled) return _to_response(u) @router.delete("/{uid}", status_code=204) def delete_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): u = db.query(Upstream).filter(Upstream.id == uid).first() if not u: raise HTTPException(404, "upstream not found") sched_svc.refresh_upstream(uid, 0, False) # remove job db.delete(u) db.commit() @router.post("/{uid}/test", response_model=TestResult) def test_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): u = db.query(Upstream).filter(Upstream.id == uid).first() if not u: raise HTTPException(404, "upstream not found") auth_config = json.loads(u.auth_config_json or "{}") with UpstreamClient( base_url=u.base_url, api_prefix=u.api_prefix, auth_type=u.auth_type, auth_config=auth_config, timeout=float(u.timeout_seconds), ) as client: try: client.login() groups = client.get_available_groups(u.groups_endpoint) # Also try balance if configured if u.balance_endpoint and u.balance_response_path: try: raw_balance = client.get_balance(u.balance_endpoint, u.balance_response_path) if raw_balance is not None: divisor = u.balance_divisor or 1.0 u.balance = raw_balance / divisor u.balance_updated_at = datetime.now(timezone.utc) if raw_balance is not None else None except Exception as exc: logger.warning("upstream %s balance fetch failed during test: %s", u.name, exc) u.last_status = "healthy" u.last_error = None u.last_checked_at = datetime.now(timezone.utc) u.consecutive_failures = 0 db.commit() return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组") except Exception as exc: u.last_status = "unhealthy" u.last_error = str(exc) u.last_checked_at = datetime.now(timezone.utc) u.consecutive_failures = (u.consecutive_failures or 0) + 1 db.commit() return TestResult(success=False, message="连接失败", detail=str(exc)) @router.post("/{uid}/check-now", response_model=TestResult) def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): u = db.query(Upstream).filter(Upstream.id == uid).first() if not u: raise HTTPException(404, "upstream not found") auth_config = json.loads(u.auth_config_json or "{}") with UpstreamClient( base_url=u.base_url, api_prefix=u.api_prefix, auth_type=u.auth_type, auth_config=auth_config, timeout=float(u.timeout_seconds), ) as client: try: client.login() groups = client.get_available_groups(u.groups_endpoint) raw_rates = client.get_group_rates(u.rate_endpoint) snapshot = build_snapshot(u.id, u.base_url, u.api_prefix, groups, raw_rates) # Also try balance if configured if u.balance_endpoint and u.balance_response_path: try: raw_balance = client.get_balance(u.balance_endpoint, u.balance_response_path) if raw_balance is not None: divisor = u.balance_divisor or 1.0 u.balance = raw_balance / divisor u.balance_updated_at = datetime.now(timezone.utc) if raw_balance is not None else None except Exception as exc: logger.warning("upstream %s balance fetch failed during check-now: %s", u.name, exc) except Exception as exc: u.consecutive_failures = (u.consecutive_failures or 0) + 1 u.last_error = str(exc) u.last_checked_at = datetime.now(timezone.utc) db.commit() return TestResult(success=False, message="检测失败", detail=str(exc)) prev_row = ( db.query(UpstreamRateSnapshot) .filter(UpstreamRateSnapshot.upstream_id == uid) .order_by(UpstreamRateSnapshot.captured_at.desc()) .first() ) previous = json.loads(prev_row.snapshot_json) if prev_row else None changes = diff_snapshots(previous, snapshot) new_row = UpstreamRateSnapshot( upstream_id=uid, snapshot_json=json.dumps(snapshot, ensure_ascii=False), captured_at=datetime.now(timezone.utc), ) db.add(new_row) was_unhealthy = u.last_status == "unhealthy" u.last_status = "healthy" u.last_checked_at = datetime.now(timezone.utc) u.last_error = None u.consecutive_failures = 0 db.commit() if was_unhealthy: webhook_service.send_status_event(db, u.id, u.name, u.base_url, "upstream_recovered") if changes: webhook_service.send_rate_changed(db, u.id, u.name, u.base_url, changes) website_sync.sync_affected_bindings(db, u.id, changes) msg = f"检测成功,{len(groups)} 个分组" if changes: msg += f",发现 {len(changes)} 处倍率变化" elif previous is None: msg += ",初始化快照完成" else: msg += ",无变化" return TestResult(success=True, message=msg) @router.get("/{uid}/snapshots/latest", response_model=SnapshotResponse) def latest_snapshot(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): row = ( db.query(UpstreamRateSnapshot) .filter(UpstreamRateSnapshot.upstream_id == uid) .order_by(UpstreamRateSnapshot.captured_at.desc()) .first() ) if not row: raise HTTPException(404, "no snapshot found") return SnapshotResponse( id=row.id, upstream_id=row.upstream_id, snapshot=json.loads(row.snapshot_json), captured_at=row.captured_at, ) from fastapi import Query as QueryParam @router.get("/{uid}/snapshots", response_model=List[SnapshotResponse]) def list_snapshots( uid: int, limit: int = QueryParam(20, le=100), offset: int = QueryParam(0), db: Session = Depends(get_db), _=Depends(get_current_user), ): """Return paginated snapshot history with diff vs previous snapshot embedded.""" rows = ( db.query(UpstreamRateSnapshot) .filter(UpstreamRateSnapshot.upstream_id == uid) .order_by(UpstreamRateSnapshot.captured_at.desc()) .offset(offset) .limit(limit + 1) # fetch one extra to get the "previous" for diffing .all() ) # We need the snapshot just before each one to compute changes count. # rows are desc order; rows[i+1] is older than rows[i] results = [] for i, row in enumerate(rows[:limit]): snap = json.loads(row.snapshot_json) # try to diff against the next row (which is older) changes_count: int | None = None if i + 1 < len(rows): older = json.loads(rows[i + 1].snapshot_json) from app.services.snapshot_service import diff_snapshots ch = diff_snapshots(older, snap) changes_count = len(ch) groups_count = len(snap.get("groups", {})) # embed lightweight summary into snapshot dict so frontend can display it snap["_groups_count"] = groups_count snap["_changes_count"] = changes_count # None means first ever snapshot results.append(SnapshotResponse( id=row.id, upstream_id=row.upstream_id, snapshot=snap, captured_at=row.captured_at, )) return results