Files
SmartUp/backend/app/routers/upstreams.py
T
2026-05-29 17:51:12 +08:00

594 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Upstream management CRUD + test + check-now + snapshots."""
from __future__ import annotations
import json
import logging
import re
import threading
from datetime import datetime, timezone
from typing import Any, List
logger = logging.getLogger(__name__)
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.admin_user import AdminUser
from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.snapshot import UpstreamRateSnapshot
from app.schemas.upstream import (
GenerateKeysByGroupsRequest,
GenerateKeysByGroupsResponse,
GeneratedUpstreamKeyResponse,
UpstreamCreate, UpstreamUpdate, UpstreamResponse, SnapshotResponse, TestResult
)
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot, mask_secret, _extract_key_value
from app.services.snapshot_service import diff_snapshots
from app.services import scheduler as sched_svc
from app.services import webhook_service
from app.services import website_sync
from app.utils.auth import get_current_user
# All routes defined in this module are mounted under /api/upstreams.
router = APIRouter(prefix="/api/upstreams", tags=["upstreams"])
# Placeholder returned instead of secret auth-config values. update_upstream
# ignores incoming auth_config values equal to it, so echoing a masked config
# back does not overwrite the stored secret.
MASK = "***"
# Auth-config keys (compared after lower-casing, exact match) whose truthy
# values are replaced by MASK in API responses.
SECRET_KEYS = {"password", "token", "key", "secret"}
def _group_id(group: dict) -> str:
for key in ("id", "group_id", "groupId"):
value = group.get(key)
if value is not None:
return str(value)
return str(group.get("name") or group.get("group_name") or "")
def _group_name(group: dict, gid: str) -> str:
return str(group.get("name") or group.get("group_name") or gid)
def _key_response(row: UpstreamGeneratedKey, include_value: bool = False) -> GeneratedUpstreamKeyResponse:
    """Convert a stored generated-key row into its API response model.

    The plaintext ``key_value`` is included only when *include_value* is true.
    ``has_key_value`` always reports whether a plaintext value is stored.
    """
    copied_fields = (
        "id", "upstream_id", "group_id", "group_name", "key_id", "key_name",
        "masked_key", "status", "error", "imported_website_id",
        "imported_account_id", "imported_at", "created_at", "updated_at",
    )
    payload = {name: getattr(row, name) for name in copied_fields}
    payload["key_value"] = row.key_value if include_value else None
    payload["has_key_value"] = bool(row.key_value)
    return GeneratedUpstreamKeyResponse(**payload)
def _mask_auth_config(auth_type: str, cfg: dict) -> dict:
    """Return a copy of *cfg* with secret values replaced by ``MASK``.

    A value is masked when its lower-cased key is one of ``SECRET_KEYS`` and
    the value is truthy. *auth_type* is accepted but not used.
    """
    return {
        name: (MASK if name.lower() in SECRET_KEYS and value else value)
        for name, value in cfg.items()
    }
def _extract_plaintext_key(payload: dict[str, Any] | None) -> str:
    """Return the plaintext API key found in *payload*, or ``""``.

    Non-dict payloads, payloads without a key, and key values containing
    ``*`` (treated as masked) all yield an empty string.
    """
    candidate = _extract_key_value(payload) if isinstance(payload, dict) else ""
    return candidate if candidate and "*" not in candidate else ""
def _to_response(u: Upstream) -> UpstreamResponse:
    """Build the API representation of an upstream, masking auth secrets.

    Missing balance endpoint/path become ``""`` and a falsy divisor becomes
    ``1.0``. All other columns are passed through unchanged.
    """
    stored_auth = json.loads(u.auth_config_json or "{}")
    passthrough = (
        "id", "name", "base_url", "api_prefix", "auth_type",
        "rate_endpoint", "groups_endpoint", "enabled",
        "check_interval_seconds", "timeout_seconds",
        "last_status", "last_checked_at", "last_error",
        "balance", "balance_updated_at", "balance_alert_threshold",
        "created_at", "updated_at",
    )
    payload = {attr: getattr(u, attr) for attr in passthrough}
    payload.update(
        auth_config_masked=_mask_auth_config(u.auth_type, stored_auth),
        balance_endpoint=u.balance_endpoint or "",
        balance_response_path=u.balance_response_path or "",
        balance_divisor=u.balance_divisor or 1.0,
    )
    return UpstreamResponse(**payload)
@router.get("", response_model=List[UpstreamResponse])
def list_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)):
    """List every configured upstream ordered by id, with auth secrets masked."""
    ordered = db.query(Upstream).order_by(Upstream.id).all()
    return list(map(_to_response, ordered))
@router.get("/{uid}/generated-keys", response_model=List[GeneratedUpstreamKeyResponse])
def list_generated_keys(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return at most 200 generated keys of an upstream, highest id first.

    Plaintext key values are never included. Raises 404 when the upstream
    does not exist.
    """
    upstream_exists = db.query(Upstream.id).filter(Upstream.id == uid).first() is not None
    if not upstream_exists:
        raise HTTPException(404, "upstream not found")
    newest_first = UpstreamGeneratedKey.id.desc()
    query = (
        db.query(UpstreamGeneratedKey)
        .filter(UpstreamGeneratedKey.upstream_id == uid)
        .order_by(newest_first)
        .limit(200)
    )
    return [_key_response(entry) for entry in query.all()]
_generate_key_lock = __import__("threading").Lock()
def _find_managed_key_row(
    db: Session,
    upstream_id: int,
    group_id: str,
    prefix: str,
) -> UpstreamGeneratedKey | None:
    """Return the locally stored managed key row for one upstream group, if any.

    Rows written before ``managed_prefix`` existed (NULL) are matched by a
    ``{prefix}-`` key-name prefix instead. The prefix is LIKE-escaped so that
    ``_``/``%`` in a user-supplied prefix are not treated as wildcards.
    """
    return (
        db.query(UpstreamGeneratedKey)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream_id,
            UpstreamGeneratedKey.group_id == group_id,
            (UpstreamGeneratedKey.managed_prefix == prefix)
            | (
                UpstreamGeneratedKey.managed_prefix.is_(None)
                & UpstreamGeneratedKey.key_name.startswith(f"{prefix}-", autoescape=True)
            ),
        )
        .first()
    )


def _apply_remote_key(row: UpstreamGeneratedKey, remote: dict[str, Any]) -> None:
    """Copy id, plaintext/masked value and raw payload of a remote key onto *row*.

    The plaintext value (and its mask) is stored only when the remote payload
    carries an unmasked key; otherwise only the remote mask is kept. The row
    is marked ``"exists"``.
    """
    key_value = _extract_plaintext_key(remote)
    masked = mask_secret(key_value) if key_value else (remote.get("masked_key") or remote.get("key") or "")
    row.key_id = str(remote.get("id") or "") or row.key_id
    if key_value:
        row.key_value = key_value
        row.masked_key = masked
    elif masked:
        row.masked_key = str(masked)
    row.raw_json = json.dumps(remote, ensure_ascii=False)
    row.status = "exists"


def _ensure_group_key(
    db: Session,
    client: UpstreamClient,
    upstream: Upstream,
    group: dict[str, Any],
    prefix: str,
    body: GenerateKeysByGroupsRequest,
) -> GeneratedUpstreamKeyResponse:
    """Ensure one upstream group has a SmartUp-prefixed API key.

    If the upstream already has a matching key, it is recorded locally with
    status ``"exists"`` and no plaintext is returned. Otherwise a key is
    created with the limits from *body*, recorded with status ``"created"``,
    and its plaintext is returned once.

    Never raises. Any error is logged, the session is rolled back, and a
    ``"failed"`` item is returned, so one bad group does not abort the batch
    or poison the shared session for the following groups.
    """
    gid = _group_id(group)
    gname = _group_name(group, gid)
    # Readable key name: {prefix}-{upstream.id}-{sanitised group name}-{group_id}.
    # The upstream id and group id make the name unique per group.
    safe_group_name = re.sub(r"[^a-zA-Z0-9\u4e00-\u9fff_-]", "", gname)[:30] if gname else gid
    stable_name = f"{prefix}-{upstream.id}-{safe_group_name}-{gid}"
    with _generate_key_lock:
        try:
            row = _find_managed_key_row(db, upstream.id, gid, prefix)
            # A single remote lookup covers both cases: a local row whose key
            # may have been deleted upstream, and a key created concurrently.
            # Lookup errors propagate so we never create a key blindly
            # (which could duplicate it).
            existing = client.find_smartup_group_key(gid, stable_name, prefix)
            if existing:
                if row is None:
                    row = UpstreamGeneratedKey(
                        upstream_id=upstream.id,
                        group_id=gid,
                        group_name=gname,
                        key_id=None,
                        key_name=stable_name,
                        key_value="",
                        masked_key="",
                        managed_prefix=prefix,
                    )
                else:
                    row.updated_at = datetime.now(timezone.utc)
                _apply_remote_key(row, existing)
                db.add(row)
                db.commit()
                db.refresh(row)
                return _key_response(row, include_value=False)

            # The key does not exist upstream: create it.
            created = client.create_api_key(
                stable_name,
                gid,
                quota=body.quota,
                expires_in_days=body.expires_in_days,
                rate_limit_5h=body.rate_limit_5h,
                rate_limit_1d=body.rate_limit_1d,
                rate_limit_7d=body.rate_limit_7d,
                endpoint=body.endpoint,
            )
            created_id = created.get("id")
            plaintext = created["key"]
            fields = {
                # Stored as text, like the id taken from a remote lookup.
                "key_id": str(created_id) if created_id else None,
                "key_name": stable_name,
                "key_value": plaintext,
                "masked_key": created.get("masked_key") or mask_secret(plaintext),
                "raw_json": json.dumps(created.get("raw") or {}, ensure_ascii=False),
                "managed_prefix": prefix,
                "status": "created",
            }
            if row is None:
                row = UpstreamGeneratedKey(
                    upstream_id=upstream.id,
                    group_id=gid,
                    group_name=gname,
                    **fields,
                )
            else:
                # Reuse the stale local row instead of inserting a duplicate.
                for attr, value in fields.items():
                    setattr(row, attr, value)
                row.error = None
                row.updated_at = datetime.now(timezone.utc)
            db.add(row)
            db.commit()
            db.refresh(row)
            return _key_response(row, include_value=True)
        except Exception as exc:
            logger.exception("ensure group key failed for upstream=%s group=%s", upstream.id, gid)
            # Discard pending or failed changes so the session stays usable
            # for the remaining groups processed with the same session.
            db.rollback()
            return GeneratedUpstreamKeyResponse(
                upstream_id=upstream.id,
                group_id=gid,
                group_name=gname,
                key_name=stable_name,
                status="failed",
                error=str(exc),
            )
@router.post("/{uid}/keys/generate-by-groups", response_model=GenerateKeysByGroupsResponse)
def generate_keys_by_groups(
    uid: int,
    body: GenerateKeysByGroupsRequest,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Create or reuse one SmartUp-prefixed API key per group of a Sub2API upstream.

    Only groups whose id is listed in ``body.group_ids`` are processed; an
    empty list processes every group. Per-group failures are reported in
    ``items`` (status ``"failed"``) rather than aborting the batch, and make
    ``success`` false.

    Raises:
        HTTPException 404: the upstream does not exist.
        HTTPException 400: the upstream's API prefix is not ``/api/v1``.
        HTTPException 502: logging in or listing groups on the upstream failed.
    """
    u = db.query(Upstream).filter(Upstream.id == uid).first()
    if not u:
        raise HTTPException(404, "upstream not found")
    if (u.api_prefix or "").strip("/") != "api/v1":
        raise HTTPException(400, "首版仅支持 Sub2API 上游(API Prefix 应为 /api/v1)")
    auth_config = json.loads(u.auth_config_json or "{}")
    selected = set(body.group_ids)
    prefix = body.name_prefix
    results: list[GeneratedUpstreamKeyResponse] = []
    with UpstreamClient(
        base_url=u.base_url,
        api_prefix=u.api_prefix,
        auth_type=u.auth_type,
        auth_config=auth_config,
        timeout=float(u.timeout_seconds),
    ) as client:
        try:
            client.login()
            groups = client.get_available_groups(u.groups_endpoint)
        except Exception as exc:
            raise HTTPException(502, str(exc)) from exc
        for group in groups:
            gid = _group_id(group)
            if not gid or (selected and gid not in selected):
                continue
            results.append(_ensure_group_key(db, client, u, group, prefix, body))

    created = sum(item.status == "created" for item in results)
    existed = sum(item.status == "exists" for item in results)
    failed = sum(item.status == "failed" for item in results)
    total = len(results)
    msg_parts = []
    if created:
        msg_parts.append(f"新创建 {created}")
    if existed:
        msg_parts.append(f"已存在 {existed}")
    if failed:
        msg_parts.append(f"失败 {failed}")
    # Separate the counts; they were previously concatenated with no delimiter.
    msg = "、".join(msg_parts) + f" / 共 {total} 个分组" if msg_parts else f"共处理 {total} 个分组"
    return GenerateKeysByGroupsResponse(
        success=total > 0 and failed == 0,
        message=msg,
        items=results,
    )
@router.post("", response_model=UpstreamResponse, status_code=201)
def create_upstream(
    body: UpstreamCreate,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Persist a new upstream, register its scheduled check and return it.

    The base URL is stored without trailing slashes and the auth config is
    serialised to JSON.
    """
    copied_fields = (
        "name", "api_prefix", "auth_type", "rate_endpoint", "groups_endpoint",
        "enabled", "check_interval_seconds", "timeout_seconds",
        "balance_endpoint", "balance_response_path", "balance_divisor",
        "balance_alert_threshold",
    )
    columns = {field: getattr(body, field) for field in copied_fields}
    columns["base_url"] = body.base_url.rstrip("/")
    columns["auth_config_json"] = json.dumps(body.auth_config, ensure_ascii=False)
    record = Upstream(**columns)
    db.add(record)
    db.commit()
    db.refresh(record)
    sched_svc.refresh_upstream(record.id, record.check_interval_seconds, record.enabled)
    return _to_response(record)
@router.get("/{uid}", response_model=UpstreamResponse)
def get_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return a single upstream by id; 404 if it does not exist."""
    found = db.query(Upstream).filter(Upstream.id == uid).first()
    if found is None:
        raise HTTPException(404, "upstream not found")
    return _to_response(found)
@router.put("/{uid}", response_model=UpstreamResponse)
def update_upstream(
    uid: int,
    body: UpstreamUpdate,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Partially update an upstream and refresh its scheduled check.

    Fields that are ``None`` in *body* are left untouched. ``auth_config`` is
    merged key by key into the stored config. Incoming values equal to
    ``MASK`` are skipped, so sending back a masked config does not overwrite
    secrets. The consecutive-failure counter is reset on every update.
    """
    target = db.query(Upstream).filter(Upstream.id == uid).first()
    if not target:
        raise HTTPException(404, "upstream not found")
    changes = body.model_dump(exclude_none=True)
    incoming_auth = changes.pop("auth_config", None)
    if incoming_auth is not None:
        merged = json.loads(target.auth_config_json or "{}")
        merged.update({name: value for name, value in incoming_auth.items() if value != MASK})
        target.auth_config_json = json.dumps(merged, ensure_ascii=False)
    if "base_url" in changes:
        changes["base_url"] = changes["base_url"].rstrip("/")
    for attr, value in changes.items():
        setattr(target, attr, value)
    # The user may have fixed whatever caused earlier failures.
    target.consecutive_failures = 0
    target.updated_at = datetime.now(timezone.utc)
    db.commit()
    db.refresh(target)
    sched_svc.refresh_upstream(target.id, target.check_interval_seconds, target.enabled)
    return _to_response(target)
@router.delete("/{uid}", status_code=204)
def delete_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Delete an upstream and unregister its scheduled check job.

    The job is removed before the row is deleted. If the delete or commit
    fails, the transaction is rolled back and the job is re-registered with
    the upstream's previous settings. This keeps an upstream that still exists
    from silently losing its monitoring. The original error is then re-raised.
    """
    u = db.query(Upstream).filter(Upstream.id == uid).first()
    if not u:
        raise HTTPException(404, "upstream not found")
    # Capture the schedule settings now; the instance is unusable after a failed delete.
    interval, enabled = u.check_interval_seconds, u.enabled
    sched_svc.refresh_upstream(uid, 0, False)  # remove job
    try:
        db.delete(u)
        db.commit()
    except Exception:
        db.rollback()
        sched_svc.refresh_upstream(uid, interval, enabled)
        raise
@router.post("/{uid}/test", response_model=TestResult)
def test_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Check connectivity to an upstream and record the outcome.

    Logs in and lists the upstream's groups. On success, and when both a
    balance endpoint and a response path are configured, it also refreshes
    the stored balance (divided by ``balance_divisor``). Balance errors are
    only logged. Health fields (status, error, check time, failure counter)
    are updated and committed for both outcomes.
    """
    u = db.query(Upstream).filter(Upstream.id == uid).first()
    if not u:
        raise HTTPException(404, "upstream not found")
    auth_config = json.loads(u.auth_config_json or "{}")
    with UpstreamClient(
        base_url=u.base_url,
        api_prefix=u.api_prefix,
        auth_type=u.auth_type,
        auth_config=auth_config,
        timeout=float(u.timeout_seconds),
    ) as client:
        # Only the connectivity calls are guarded here. A DB error on the
        # success path must not be reported as a connection failure.
        try:
            client.login()
            groups = client.get_available_groups(u.groups_endpoint)
        except Exception as exc:
            u.last_status = "unhealthy"
            u.last_error = str(exc)
            u.last_checked_at = datetime.now(timezone.utc)
            u.consecutive_failures = (u.consecutive_failures or 0) + 1
            db.commit()
            return TestResult(success=False, message="连接失败", detail=str(exc))
        # The balance is optional: any error (including a non-numeric value)
        # is logged and never fails the test.
        if u.balance_endpoint and u.balance_response_path:
            try:
                raw_balance = client.get_balance(u.balance_endpoint, u.balance_response_path)
                if raw_balance is not None:
                    u.balance = raw_balance / (u.balance_divisor or 1.0)
                    # Stamp only when a balance value was actually stored.
                    u.balance_updated_at = datetime.now(timezone.utc)
            except Exception as exc:
                logger.warning("upstream %s balance fetch failed during test: %s", u.name, exc)
    u.last_status = "healthy"
    u.last_error = None
    u.last_checked_at = datetime.now(timezone.utc)
    u.consecutive_failures = 0
    db.commit()
    return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组")
@router.post("/{uid}/check-now", response_model=TestResult)
def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Run an immediate rate check for an upstream.

    Fetches groups and group rates and builds a snapshot. The snapshot is
    diffed against the previous one and stored, and the upstream is marked
    healthy. The follow-up work then runs: a recovery webhook (if it was
    unhealthy), generated-key sync, a rate-change webhook plus binding sync
    (if rates changed), and account-priority sync.

    If fetching fails, only the failure counter, last error and check time
    are updated (``last_status`` is left unchanged) and a failed result is
    returned.
    """
    u = db.query(Upstream).filter(Upstream.id == uid).first()
    if not u:
        raise HTTPException(404, "upstream not found")
    auth_config = json.loads(u.auth_config_json or "{}")
    with UpstreamClient(
        base_url=u.base_url,
        api_prefix=u.api_prefix,
        auth_type=u.auth_type,
        auth_config=auth_config,
        timeout=float(u.timeout_seconds),
    ) as client:
        try:
            client.login()
            groups = client.get_available_groups(u.groups_endpoint)
            raw_rates = client.get_group_rates(u.rate_endpoint)
            snapshot = build_snapshot(u.id, u.base_url, u.api_prefix, groups, raw_rates)
        except Exception as exc:
            u.consecutive_failures = (u.consecutive_failures or 0) + 1
            u.last_error = str(exc)
            u.last_checked_at = datetime.now(timezone.utc)
            db.commit()
            return TestResult(success=False, message="检测失败", detail=str(exc))
        # The balance is optional: any error (including a non-numeric value)
        # is logged and never fails the check.
        if u.balance_endpoint and u.balance_response_path:
            try:
                raw_balance = client.get_balance(u.balance_endpoint, u.balance_response_path)
                if raw_balance is not None:
                    u.balance = raw_balance / (u.balance_divisor or 1.0)
                    # Stamp only when a balance value was actually stored.
                    u.balance_updated_at = datetime.now(timezone.utc)
            except Exception as exc:
                logger.warning("upstream %s balance fetch failed during check-now: %s", u.name, exc)

    prev_row = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == uid)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .first()
    )
    previous = json.loads(prev_row.snapshot_json) if prev_row else None
    changes = diff_snapshots(previous, snapshot)
    new_row = UpstreamRateSnapshot(
        upstream_id=uid,
        snapshot_json=json.dumps(snapshot, ensure_ascii=False),
        captured_at=datetime.now(timezone.utc),
    )
    db.add(new_row)
    was_unhealthy = u.last_status == "unhealthy"
    u.last_status = "healthy"
    u.last_checked_at = datetime.now(timezone.utc)
    u.last_error = None
    u.consecutive_failures = 0
    db.commit()
    if was_unhealthy:
        webhook_service.send_status_event(db, u.id, u.name, u.base_url, "upstream_recovered")
    # Sync key state first (this marks orphaned keys) and only then run the
    # priority sync, so keys that are not marked yet do not take part in it.
    sched_svc._sync_upstream_keys(uid, snapshot, new_row.captured_at)
    if changes:
        webhook_service.send_rate_changed(db, u.id, u.name, u.base_url, changes)
        website_sync.sync_affected_bindings(db, u.id, changes)
    website_sync.sync_account_priorities_for_upstream(db, u.id)
    msg = f"检测成功,{len(groups)} 个分组"
    if changes:
        msg += f",发现 {len(changes)} 处倍率变化"
    elif previous is None:
        msg += ",初始化快照完成"
    else:
        msg += ",无变化"
    return TestResult(success=True, message=msg)
@router.get("/{uid}/snapshots/latest", response_model=SnapshotResponse)
def latest_snapshot(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return the most recently captured rate snapshot of an upstream.

    Raises 404 when the upstream has no snapshot yet.
    """
    newest_first = UpstreamRateSnapshot.captured_at.desc()
    newest = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == uid)
        .order_by(newest_first)
        .first()
    )
    if newest is None:
        raise HTTPException(404, "no snapshot found")
    decoded = json.loads(newest.snapshot_json)
    return SnapshotResponse(
        id=newest.id,
        upstream_id=newest.upstream_id,
        snapshot=decoded,
        captured_at=newest.captured_at,
    )
from fastapi import Query as QueryParam
@router.get("/{uid}/snapshots", response_model=List[SnapshotResponse])
def list_snapshots(
    uid: int,
    limit: int = QueryParam(20, ge=0, le=100),
    offset: int = QueryParam(0, ge=0),
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Return a page of snapshot history, newest first.

    Each returned snapshot dict carries two summary fields for the frontend:

    * ``_groups_count``: number of entries under the snapshot's ``groups``.
    * ``_changes_count``: number of differences from the next older snapshot,
      or ``None`` for the oldest snapshot overall.

    ``limit`` and ``offset`` must be non-negative. A negative LIMIT would
    otherwise be unbounded on some databases (e.g. SQLite).
    """
    rows = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == uid)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .offset(offset)
        .limit(limit + 1)  # one extra row: the "previous" of the oldest item on the page
        .all()
    )
    # Parse each row once. Rows are newest first, so parsed[i + 1] is the
    # snapshot captured just before parsed[i].
    parsed = [json.loads(row.snapshot_json) for row in rows]
    results: list[SnapshotResponse] = []
    for i, row in enumerate(rows[:limit]):
        snap = parsed[i]
        changes_count: int | None = None
        if i + 1 < len(parsed):
            changes_count = len(diff_snapshots(parsed[i + 1], snap))
        # Shallow copy so the summary keys never leak into the dict that is
        # reused as the "older" side of the next diff.
        summary = dict(snap)
        summary["_groups_count"] = len(snap.get("groups", {}))
        summary["_changes_count"] = changes_count
        results.append(SnapshotResponse(
            id=row.id,
            upstream_id=row.upstream_id,
            snapshot=summary,
            captured_at=row.captured_at,
        ))
    return results