Files
2026-06-02 13:51:29 +08:00

930 lines
36 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Upstream management CRUD + test + check-now + snapshots."""
from __future__ import annotations
import json
import logging
import re
from datetime import datetime, timezone
from typing import Any, List
logger = logging.getLogger(__name__)
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.admin_user import AdminUser
from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.snapshot import UpstreamRateSnapshot
from app.schemas.upstream import (
GenerateKeysByGroupsRequest,
GenerateKeysByGroupsResponse,
GeneratedUpstreamKeyResponse,
UpstreamCreate, UpstreamUpdate, UpstreamResponse, SnapshotResponse, TestResult,
UpstreamBatchActionItem, UpstreamBatchActionSummary, UpstreamBatchActionResponse,
)
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot, mask_secret, _extract_key_value
from app.services.snapshot_service import diff_snapshots
from app.services import scheduler as sched_svc
from app.services import webhook_service
from app.services import website_sync
from app.utils.auth import get_current_user
# Endpoints for upstream management, all served under /api/upstreams.
router = APIRouter(tags=["upstreams"], prefix="/api/upstreams")
# Placeholder returned instead of secret auth-config values; update_upstream
# ignores incoming auth-config values equal to it.
MASK = "***"
# Auth-config field names consulted by _mask_auth_config when deciding which
# values to hide.
SECRET_KEYS = {"key", "password", "secret", "token"}
def _group_id(group: dict) -> str:
    """Return a string identifier for an upstream group payload.

    The first non-None value among ``id``, ``group_id`` and ``groupId`` wins.
    Otherwise the group's ``name`` / ``group_name`` is used, or ``""`` when
    neither is set.
    """
    candidates = (group.get(field) for field in ("id", "group_id", "groupId"))
    found = next((value for value in candidates if value is not None), None)
    if found is not None:
        return str(found)
    return str(group.get("name") or group.get("group_name") or "")
def _group_name(group: dict, gid: str) -> str:
    """Return the display name of *group*, falling back to its id *gid*."""
    for field in ("name", "group_name"):
        label = group.get(field)
        if label:
            return str(label)
    return str(gid)
def _key_response(row: UpstreamGeneratedKey, include_value: bool = False) -> GeneratedUpstreamKeyResponse:
    """Convert a stored generated-key row into its API response model.

    The plaintext key is copied into the response only when *include_value* is
    true. ``has_key_value`` always reports whether the row stores a plaintext.
    """
    copied = {
        field: getattr(row, field)
        for field in (
            "id",
            "upstream_id",
            "group_id",
            "group_name",
            "key_id",
            "key_name",
            "masked_key",
            "status",
            "error",
            "imported_website_id",
            "imported_account_id",
            "imported_at",
            "created_at",
            "updated_at",
        )
    }
    plaintext = row.key_value
    return GeneratedUpstreamKeyResponse(
        **copied,
        key_value=plaintext if include_value else None,
        has_key_value=bool(plaintext),
    )
def _mask_auth_config(auth_type: str, cfg: dict) -> dict:
    """Return a copy of *cfg* with secret-looking values replaced by ``MASK``.

    A field counts as secret when its lower-cased name ends with one of
    ``SECRET_KEYS``. That covers plain names (``password``, ``token``) and
    compound ones (``api_key``, ``access_token``, ``apiKey``,
    ``client_secret``), which exact matching used to leak in plaintext.
    Over-masking is safe because ``update_upstream`` ignores incoming values
    equal to ``MASK``. Falsy values are returned unchanged.

    ``auth_type`` is unused and kept for call compatibility. A *cfg* that is
    not a dict (e.g. the JSON literal ``null``) yields an empty dict instead of
    raising.
    """
    if not isinstance(cfg, dict):
        return {}
    secret_suffixes = tuple(SECRET_KEYS)
    return {
        k: MASK if v and str(k).lower().endswith(secret_suffixes) else v
        for k, v in cfg.items()
    }
def _extract_plaintext_key(payload: dict[str, Any] | None) -> str:
    """Return the unmasked API key carried by *payload*, or ``""``.

    Non-dict payloads, payloads without a key, and keys containing ``*``
    (values the upstream already masked) all yield an empty string.
    """
    candidate = _extract_key_value(payload) if isinstance(payload, dict) else ""
    if not candidate or "*" in candidate:
        return ""
    return candidate
def _remote_group_name(rk: dict) -> str:
    """Safely extract a group name from a remote key payload.

    A non-empty string ``group_name`` wins. Otherwise ``group`` is used: as a
    dict its ``name`` / ``group_name`` / ``id`` is taken, and any other truthy
    value is stringified. Returns ``""`` when nothing usable is present.
    """
    direct = rk.get("group_name")
    if direct and isinstance(direct, str):
        return direct
    group = rk.get("group")
    if isinstance(group, dict):
        label = group.get("name") or group.get("group_name") or group.get("id") or ""
        return str(label)
    return str(group) if group else ""
def _generate_masked_key(key_name: str, key_id: str) -> str:
    """Build a display label for a remote key whose plaintext is unavailable.

    Format: ``remote:<last 8 chars of key_name>:<last 6 chars of key_id>``.
    Shorter inputs are used whole.
    """
    # Negative slices already return the whole string when it is shorter.
    name_tail = key_name[-8:]
    id_tail = key_id[-6:]
    return f"remote:{name_tail}:{id_tail}"
def _to_response(u: Upstream) -> UpstreamResponse:
    """Serialize an ``Upstream`` row into its API response, masking auth secrets.

    Missing balance settings are normalized: an empty endpoint/path becomes
    ``""`` and a missing or zero divisor becomes ``1.0``.
    """
    auth_cfg = json.loads(u.auth_config_json or "{}")
    passthrough = {
        name: getattr(u, name)
        for name in (
            "id",
            "name",
            "base_url",
            "api_prefix",
            "auth_type",
            "rate_endpoint",
            "groups_endpoint",
            "enabled",
            "check_interval_seconds",
            "timeout_seconds",
            "last_status",
            "last_checked_at",
            "last_error",
            "balance",
            "balance_updated_at",
            "balance_alert_threshold",
            "created_at",
            "updated_at",
        )
    }
    return UpstreamResponse(
        **passthrough,
        auth_config_masked=_mask_auth_config(u.auth_type, auth_cfg),
        balance_endpoint=u.balance_endpoint or "",
        balance_response_path=u.balance_response_path or "",
        balance_divisor=u.balance_divisor or 1.0,
    )
@router.get("", response_model=List[UpstreamResponse])
def list_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)):
    """List every upstream ordered by id, with auth secrets masked."""
    ordered = db.query(Upstream).order_by(Upstream.id).all()
    return list(map(_to_response, ordered))
@router.get("/{uid}/generated-keys", response_model=List[GeneratedUpstreamKeyResponse])
def list_generated_keys(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Fetch the upstream's live key list, merge it with local rows and return it.

    Remote keys are bucketed by ``(group_id, name prefix)`` and only the newest
    (largest id) key of each bucket is used. Each one is matched to a local
    ``UpstreamGeneratedKey`` row by key name, then key id, then
    ``(group_id, managed_prefix)``:

    - remote + local row -> the row is refreshed from the remote payload;
    - remote only, plaintext available -> a new local row is created;
    - remote only, no plaintext -> returned as a non-importable ``remote`` item
      (``id=None``, ``has_key_value=False``);
    - local row matched by no remote key -> marked ``orphaned`` if it was
      imported into a website account, otherwise deleted.

    If the upstream cannot be reached, the newest 200 local rows are returned
    unchanged. At most 200 items are returned.

    Raises:
        HTTPException(404): the upstream does not exist.
    """
    from collections import defaultdict

    def _remote_id_rank(item: dict) -> int:
        # Numeric id for "newest" ordering. Non-numeric ids rank lowest
        # instead of raising ValueError and failing the whole request.
        try:
            return int(item.get("id") or 0)
        except (TypeError, ValueError):
            return -1

    upstream = db.query(Upstream).filter(Upstream.id == uid).first()
    if not upstream:
        raise HTTPException(404, "upstream not found")
    auth_config = json.loads(upstream.auth_config_json or "{}")
    now = datetime.now(timezone.utc)

    # 1. Index local rows by key name, key id and (group_id, managed_prefix).
    local_rows = (
        db.query(UpstreamGeneratedKey)
        .filter(UpstreamGeneratedKey.upstream_id == uid)
        .all()
    )
    local_by_keyname: dict[str, UpstreamGeneratedKey] = {}
    local_by_keyid: dict[str, UpstreamGeneratedKey] = {}
    local_by_groupprefix: dict[tuple[str, str], UpstreamGeneratedKey] = {}
    for row in local_rows:
        if row.key_name:
            local_by_keyname[row.key_name] = row
        if row.key_id:
            local_by_keyid[row.key_id] = row
        if row.group_id and row.managed_prefix:
            local_by_groupprefix[(row.group_id, row.managed_prefix)] = row
    rows_by_id = {row.id: row for row in local_rows}
    # Ids of local rows not (yet) matched by any remote key.
    local_remaining = set(rows_by_id)

    # 2. Log in to the upstream and fetch the live key list per managed prefix.
    remote_keys_list: list[dict] = []
    try:
        with UpstreamClient(
            base_url=upstream.base_url,
            api_prefix=upstream.api_prefix,
            auth_type=upstream.auth_type,
            auth_config=auth_config,
            timeout=float(upstream.timeout_seconds),
        ) as client:
            client.login()
            for prefix in website_sync._fetch_remote_managed_prefixes(db, uid):
                # On Sub2API the prefix is usually applied via the search parameter.
                remote_keys_list.extend(client.list_api_keys(search=prefix, status="active"))
    except Exception as exc:
        logger.warning("list_generated_keys: upstream %s fetch failed: %s", uid, exc)
        # Upstream unreachable: fall back to local rows (orphaned ones included).
        rows = (
            db.query(UpstreamGeneratedKey)
            .filter(UpstreamGeneratedKey.upstream_id == uid)
            .order_by(UpstreamGeneratedKey.id.desc())
            .limit(200)
            .all()
        )
        return [_key_response(row) for row in rows]

    results: list[GeneratedUpstreamKeyResponse] = []

    # Bucket remote keys by (group_id, name prefix); only the newest of each
    # bucket is kept.
    group_buckets: dict[tuple[str, str], list[dict]] = defaultdict(list)
    for rk in remote_keys_list:
        key_name = rk.get("name") or ""
        if not key_name:
            continue
        group_id = str(rk.get("group_id") or "")
        prefix = key_name.split("-")[0] if "-" in key_name else "SmartUp"
        group_buckets[(group_id, prefix)].append(rk)

    for (group_id, prefix), bucket in group_buckets.items():
        # Newest = largest id; max() keeps the first of equal ids, like a
        # stable descending sort would.
        rk = max(bucket, key=_remote_id_rank)
        key_name = rk.get("name") or ""
        key_id = str(rk.get("id") or "")
        group_name = _remote_group_name(rk)
        key_value = _extract_plaintext_key(rk)

        local = (
            local_by_keyname.get(key_name)
            or local_by_keyid.get(key_id)
            or local_by_groupprefix.get((group_id, prefix))
        )
        if local:
            # Remote + local: re-point the row at the selected newest remote key.
            local_remaining.discard(local.id)
            local.key_id = key_id
            local.key_name = key_name
            local.group_name = group_name
            if key_value:
                local.key_value = key_value
                local.masked_key = mask_secret(key_value)
            local.raw_json = json.dumps(rk, ensure_ascii=False)
            local.updated_at = now
            if local.status in ("failed", "import_failed"):
                local.status = "exists"
                local.error = None
            results.append(_key_response(local))
        elif key_value:
            # Remote only, plaintext available: create a local row and index it.
            local = UpstreamGeneratedKey(
                upstream_id=uid,
                group_id=group_id,
                group_name=group_name,
                key_id=key_id,
                key_name=key_name,
                key_value=key_value,
                masked_key=mask_secret(key_value),
                raw_json=json.dumps(rk, ensure_ascii=False),
                managed_prefix=prefix,
                status="created",
                created_at=now,
                updated_at=now,
            )
            db.add(local)
            db.flush()
            db.refresh(local)
            local_by_keyname[key_name] = local
            local_by_groupprefix[(group_id, prefix)] = local
            if key_id:
                local_by_keyid[key_id] = local
            results.append(_key_response(local))
        else:
            # Remote only, no plaintext: show it, but it cannot be imported.
            results.append(GeneratedUpstreamKeyResponse(
                id=None,
                upstream_id=uid,
                group_id=group_id,
                group_name=group_name,
                key_id=key_id,
                key_name=key_name,
                key_value=None,
                masked_key=_generate_masked_key(key_name, key_id),
                status="remote",
                error=None,
                imported_website_id=None,
                imported_account_id=None,
                imported_at=None,
                has_key_value=False,
                created_at=None,
                updated_at=None,
            ))

    # 3. Clean up local rows whose remote key was not seen.
    # NOTE(review): if no remote keys were fetched (e.g. no managed prefixes, or
    # a paginated list), every non-imported local row - including stored
    # plaintext keys - is deleted here. Confirm this is intended.
    for row_id in local_remaining:
        row = rows_by_id[row_id]
        if row.imported_website_id and row.imported_account_id:
            # Imported rows are kept for auditing, flagged as orphaned.
            row.status = "orphaned"
            row.error = "远端 Key 已不存在"
            row.updated_at = now
            logger.info("marked key %s orphaned (not found remotely)", row.id)
        else:
            db.delete(row)
            logger.info("removed key %s (not found remotely)", row.id)

    db.commit()
    return results[:200]
import threading

# Lock held by _ensure_group_key around its lookup/create sequence, so requests
# handled by threads of this process do not create duplicate group keys.
# A threading.Lock is process-local: it does not coordinate separate workers.
_generate_key_lock = threading.Lock()
def _is_sub2api_upstream(upstream: Upstream) -> bool:
    """Return True when the upstream uses the Sub2API ``api/v1`` prefix.

    Leading/trailing slashes are ignored. A missing (``None``) prefix counts as
    empty instead of raising ``AttributeError``.
    """
    return (upstream.api_prefix or "").strip("/") == "api/v1"
def _is_new_api_user_upstream(upstream: Upstream) -> bool:
    """Return True for a New-API upstream accessed with a regular user account.

    Requires an empty API prefix (slashes ignored, ``None`` treated as empty)
    plus at least one marker:

    - the user-level groups endpoint ``/api/user/self/groups``;
    - an auth-config ``login_path`` of ``/api/user/login``;
    - a truthy auth-config ``new_api_user`` flag.

    An auth config that is not a JSON object (e.g. ``null``) has no markers.
    """
    if (upstream.api_prefix or "").strip("/") != "":
        return False
    if upstream.groups_endpoint == "/api/user/self/groups":
        return True
    auth_config = json.loads(upstream.auth_config_json or "{}")
    if not isinstance(auth_config, dict):
        return False
    return auth_config.get("login_path") == "/api/user/login" or bool(auth_config.get("new_api_user"))
def _supports_key_generation(upstream: Upstream) -> bool:
    """Return True when keys can be generated on this upstream.

    That is the case for Sub2API upstreams and New-API regular-user upstreams.
    """
    return any(check(upstream) for check in (_is_sub2api_upstream, _is_new_api_user_upstream))
def _remote_key_fields(existing: dict[str, Any]) -> tuple[str, str, Any]:
    """Extract ``(key_id, plaintext_key, masked_key)`` from a remote key payload.

    ``key_id`` is the stringified remote id (``""`` when absent). The plaintext
    is ``""`` when the upstream only returned a masked value. The mask is
    derived from the plaintext when available, otherwise taken from the
    payload's ``masked_key`` / ``key``.
    """
    key_id = str(existing.get("id") or "")
    key_value = _extract_plaintext_key(existing)
    masked = mask_secret(key_value) if key_value else (existing.get("masked_key") or existing.get("key") or "")
    return key_id, key_value, masked


def _apply_remote_key(
    row: UpstreamGeneratedKey,
    existing: dict[str, Any],
    group_name: str,
    prefix: str,
) -> None:
    """Refresh *row* in place from an already-existing remote key; mark it ``exists``."""
    key_id, key_value, masked = _remote_key_fields(existing)
    row.key_id = key_id or row.key_id
    if key_value:
        row.key_value = key_value
        row.masked_key = masked
    elif masked:
        row.masked_key = str(masked)
    row.group_name = group_name
    # Backfill the prefix on legacy rows that were matched only by key-name pattern.
    row.managed_prefix = prefix
    row.raw_json = json.dumps(existing, ensure_ascii=False)
    row.status = "exists"
    row.updated_at = datetime.now(timezone.utc)


def _ensure_group_key(
    db: Session,
    client: UpstreamClient,
    upstream: Upstream,
    group: dict[str, Any],
    prefix: str,
    body: GenerateKeysByGroupsRequest,
) -> GeneratedUpstreamKeyResponse:
    """Ensure an upstream group has a managed key named with *prefix*.

    Runs under ``_generate_key_lock``, in this order:

    1. A local row for this upstream/group/prefix that has a ``key_id``, whose
       key is still found remotely, is refreshed and returned (``exists``).
    2. Otherwise a remote key found for the stable name is adopted, updating
       the local row or inserting a new one (``exists``).
    3. Otherwise a new key is created remotely and stored (``created``). Only
       this response includes the plaintext ``key_value``.

    Any exception is logged and the DB session is rolled back, so later calls
    sharing the session keep working and half-applied row changes are
    discarded. A ``failed`` response with the error text is then returned
    instead of raising.
    """
    upstream_id = upstream.id
    gid = _group_id(group)
    gname = _group_name(group, gid)
    # Readable key name: {prefix}-{upstream id}-{sanitized group name}-{group id}.
    # It is what find_smartup_group_key is queried with, so keep the format stable.
    safe_group_name = re.sub(r"[^a-zA-Z0-9\u4e00-\u9fff_-]", "", gname)[:30] if gname else gid
    stable_name = f"{prefix}-{upstream_id}-{safe_group_name}-{gid}"

    with _generate_key_lock:
        try:
            # 1. Local managed row for this group. Legacy rows without
            #    managed_prefix are matched by their key-name prefix instead.
            row = (
                db.query(UpstreamGeneratedKey)
                .filter(
                    UpstreamGeneratedKey.upstream_id == upstream_id,
                    UpstreamGeneratedKey.group_id == gid,
                    (UpstreamGeneratedKey.managed_prefix == prefix)
                    | (
                        UpstreamGeneratedKey.managed_prefix.is_(None)
                        & UpstreamGeneratedKey.key_name.like(f"{prefix}-%")
                    ),
                )
                .first()
            )
            if row and row.key_id:
                # The row points at a remote key; check that it still exists.
                try:
                    existing = client.find_smartup_group_key(gid, stable_name, prefix)
                except Exception:
                    # A failed lookup is treated as "not found"; step 2 retries it.
                    existing = None
                if existing:
                    _apply_remote_key(row, existing, gname, prefix)
                    db.add(row)
                    db.commit()
                    db.refresh(row)
                    return _key_response(row, include_value=False)
                # Remote key is gone; the row is re-pointed or recreated below.
                row.status = "replaced"

            # 2. Look the key up remotely (another request may already have created it).
            existing = client.find_smartup_group_key(gid, stable_name, prefix)
            if existing:
                if row:
                    _apply_remote_key(row, existing, gname, prefix)
                else:
                    key_id, key_value, masked = _remote_key_fields(existing)
                    row = UpstreamGeneratedKey(
                        upstream_id=upstream_id,
                        group_id=gid,
                        group_name=gname,
                        key_id=key_id or None,
                        key_name=stable_name,
                        key_value=key_value or "",
                        masked_key=masked,
                        raw_json=json.dumps(existing, ensure_ascii=False),
                        managed_prefix=prefix,
                        status="exists",
                    )
                db.add(row)
                db.commit()
                db.refresh(row)
                return _key_response(row, include_value=False)

            # 3. Nothing usable remotely: create a new key.
            created = client.create_api_key(
                stable_name,
                gid,
                quota=body.quota,
                expires_in_days=body.expires_in_days,
                rate_limit_5h=body.rate_limit_5h,
                rate_limit_1d=body.rate_limit_1d,
                rate_limit_7d=body.rate_limit_7d,
                endpoint=body.endpoint,
            )
            created_id = created.get("id")
            # Store the remote id as a string, like steps 1-2 do.
            key_id = str(created_id) if created_id else None
            key_value = created["key"]
            masked = created.get("masked_key") or mask_secret(key_value)
            raw_json = json.dumps(created.get("raw") or {}, ensure_ascii=False)
            if row:
                # Reuse the existing row for the newly created key.
                row.key_id = key_id
                row.key_name = stable_name
                row.group_name = gname
                row.key_value = key_value
                row.masked_key = masked
                row.raw_json = raw_json
                row.managed_prefix = prefix
                row.status = "created"
                row.error = None
                row.updated_at = datetime.now(timezone.utc)
            else:
                row = UpstreamGeneratedKey(
                    upstream_id=upstream_id,
                    group_id=gid,
                    group_name=gname,
                    key_id=key_id,
                    key_name=stable_name,
                    key_value=key_value,
                    masked_key=masked,
                    raw_json=raw_json,
                    managed_prefix=prefix,
                    status="created",
                )
            db.add(row)
            db.commit()
            db.refresh(row)
            # Only a freshly created key exposes its plaintext in the response.
            return _key_response(row, include_value=True)
        except Exception as exc:
            logger.exception("ensure group key failed for upstream=%s group=%s", upstream_id, gid)
            # Discard half-applied changes (e.g. status="replaced") and clear a
            # session left unusable by a failed flush/commit, so the caller's
            # next group does not fail or commit stale state.
            db.rollback()
            return GeneratedUpstreamKeyResponse(
                upstream_id=upstream_id,
                group_id=gid,
                group_name=gname,
                key_name=stable_name,
                status="failed",
                error=str(exc),
            )
@router.post("/{uid}/keys/generate-by-groups", response_model=GenerateKeysByGroupsResponse)
def generate_keys_by_groups(
    uid: int,
    body: GenerateKeysByGroupsRequest,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Ensure a managed key exists for each selected group of upstream *uid*.

    Keys are first reconciled (best-effort) so keys deleted remotely are
    cleaned up. The upstream's available groups are then fetched, and
    ``_ensure_group_key`` runs for every group in ``body.group_ids`` (all groups
    when that list is empty). The response summarizes created/existing counts;
    ``success`` is false when nothing was processed or any group failed.

    Raises:
        HTTPException(404): the upstream does not exist.
        HTTPException(400): the upstream type does not support key generation.
        HTTPException(502): login or group listing on the upstream failed.
    """
    u = db.query(Upstream).filter(Upstream.id == uid).first()
    if not u:
        raise HTTPException(404, "upstream not found")
    if not _supports_key_generation(u):
        raise HTTPException(400, "仅支持 Sub2API 或 New-API 普通账号上游生成 Key")

    # Reconcile first so keys already deleted remotely are cleaned up.
    try:
        website_sync.reconcile_upstream_keys_full(db, uid)
    except Exception as exc:
        logger.warning("generate_keys_by_groups reconcile failed for %s: %s", uid, exc)
        # Best-effort step: reset the session so a DB error here does not
        # break the queries and commits below.
        db.rollback()

    auth_config = json.loads(u.auth_config_json or "{}")
    selected = set(body.group_ids)
    prefix = body.name_prefix
    results: list[GeneratedUpstreamKeyResponse] = []
    with UpstreamClient(
        base_url=u.base_url,
        api_prefix=u.api_prefix,
        auth_type=u.auth_type,
        auth_config=auth_config,
        timeout=float(u.timeout_seconds),
    ) as client:
        try:
            client.login()
            groups = client.get_available_groups(u.groups_endpoint)
        except Exception as exc:
            raise HTTPException(502, str(exc)) from exc
        for group in groups:
            gid = _group_id(group)
            if not gid or (selected and gid not in selected):
                continue
            results.append(_ensure_group_key(db, client, u, group, prefix, body))

    created = sum(1 for item in results if item.status == "created")
    existed = sum(1 for item in results if item.status == "exists")
    total = len(results)
    msg_parts = []
    if created:
        msg_parts.append(f"新创建 {created}")
    if existed:
        msg_parts.append(f"已存在 {existed}")
    if msg_parts:
        # Separate the parts; joining with "" produced "新创建 2已存在 3".
        msg = ",".join(msg_parts) + f" / 共 {total} 个分组"
    else:
        msg = f"共处理 {total} 个分组"
    return GenerateKeysByGroupsResponse(
        success=total > 0 and all(item.status != "failed" for item in results),
        message=msg,
        items=results,
    )
# ─── Shared core helpers ──────────────────────────────────────────────────────
def _test_upstream_core(db: Session, u: Upstream) -> UpstreamBatchActionItem:
    """Connection-test *u*: log in, list groups and optionally refresh the balance.

    Shared by the single-upstream and batch test endpoints. On success the
    upstream is marked ``healthy``, with its error and failure counter cleared
    and ``last_checked_at`` stamped, and the session is committed.

    The balance step is best-effort: errors are logged and ignored. The stored
    balance and its timestamp change only when a value was actually obtained
    (previously a ``None`` result cleared the timestamp but kept the stale
    balance). Errors from login / group listing propagate to the caller, which
    records the failure.
    """
    auth_config = json.loads(u.auth_config_json or "{}")
    with UpstreamClient(
        base_url=u.base_url,
        api_prefix=u.api_prefix,
        auth_type=u.auth_type,
        auth_config=auth_config,
        timeout=float(u.timeout_seconds),
    ) as client:
        client.login()
        groups = client.get_available_groups(u.groups_endpoint)
        # Balance (same behavior as _check_now_core).
        if u.balance_endpoint and u.balance_response_path:
            try:
                raw_balance = client.get_balance(u.balance_endpoint, u.balance_response_path)
                if raw_balance is not None:
                    u.balance = raw_balance / (u.balance_divisor or 1.0)
                    u.balance_updated_at = datetime.now(timezone.utc)
            except Exception as exc:
                logger.warning("test-all: upstream %s balance failed: %s", u.name, exc)

    u.last_status = "healthy"
    u.last_error = None
    u.last_checked_at = datetime.now(timezone.utc)
    u.consecutive_failures = 0
    db.commit()
    return UpstreamBatchActionItem(
        upstream_id=u.id,
        upstream_name=u.name,
        status="success",
        message=f"连接成功,获取到 {len(groups)} 个分组",
    )
def _check_now_core(db: Session, u: Upstream) -> tuple[str, bool]:
    """Run a full sync for *u*: snapshot rates, diff, notify and sync keys/priorities.

    Steps:

    1. Log in, fetch groups and rates, and build a rate snapshot.
    2. Optionally refresh the balance. This is best-effort, and the stored
       balance/timestamp change only when a value was obtained.
    3. Store the snapshot, mark the upstream ``healthy`` and commit.
    4. Send ``upstream_recovered`` if it was ``unhealthy`` before, then sync
       generated-key state.
    5. If rates changed, send a rate-change webhook and sync affected
       bindings. Account priorities are always synced last.

    Returns:
        ``(message, was_changed)``: a human-readable summary for the caller's
        response, and whether any rate change was detected.

    Raises:
        Errors from the upstream client or the sync steps propagate; callers
        record the failure.
    """
    uid = u.id
    auth_config = json.loads(u.auth_config_json or "{}")
    with UpstreamClient(
        base_url=u.base_url,
        api_prefix=u.api_prefix,
        auth_type=u.auth_type,
        auth_config=auth_config,
        timeout=float(u.timeout_seconds),
    ) as client:
        client.login()
        groups = client.get_available_groups(u.groups_endpoint)
        raw_rates = client.get_group_rates(u.rate_endpoint)
        snapshot = build_snapshot(uid, u.base_url, u.api_prefix, groups, raw_rates)
        # Balance is optional and best-effort.
        if u.balance_endpoint and u.balance_response_path:
            try:
                raw_balance = client.get_balance(u.balance_endpoint, u.balance_response_path)
                if raw_balance is not None:
                    u.balance = raw_balance / (u.balance_divisor or 1.0)
                    u.balance_updated_at = datetime.now(timezone.utc)
            except Exception as exc:
                logger.warning("check-now: upstream %s balance failed: %s", u.name, exc)

    # Store the new snapshot and diff it against the previous one.
    prev_row = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == uid)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .first()
    )
    previous = json.loads(prev_row.snapshot_json) if prev_row else None
    changes = diff_snapshots(previous, snapshot)
    new_row = UpstreamRateSnapshot(
        upstream_id=uid,
        snapshot_json=json.dumps(snapshot, ensure_ascii=False),
        captured_at=datetime.now(timezone.utc),
    )
    db.add(new_row)

    was_unhealthy = u.last_status == "unhealthy"
    u.last_status = "healthy"
    u.last_checked_at = datetime.now(timezone.utc)
    u.last_error = None
    u.consecutive_failures = 0
    db.commit()

    if was_unhealthy:
        webhook_service.send_status_event(db, uid, u.name, u.base_url, "upstream_recovered")

    # Sync key state first (marks orphaned keys), then run the priority sync.
    sched_svc._sync_upstream_keys(uid, snapshot, new_row.captured_at)

    if changes:
        webhook_service.send_rate_changed(db, uid, u.name, u.base_url, changes)
        website_sync.sync_affected_bindings(db, uid, changes)
    website_sync.sync_account_priorities_for_upstream(db, uid)

    msg = f"检测成功,{len(groups)} 个分组"
    if changes:
        msg += f",发现 {len(changes)} 处倍率变化"
    elif previous is None:
        msg += ",初始化快照完成"
    else:
        msg += ",无变化"
    return msg, bool(changes)
@router.post("/test-all", response_model=UpstreamBatchActionResponse)
def test_all_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Connection-test every enabled upstream; no snapshots are written, no webhooks sent.

    Disabled upstreams are reported as ``skipped``. A failing upstream is
    marked ``unhealthy``, with its error recorded and its failure counter
    incremented; the remaining upstreams are still processed. ``success`` is
    true when no upstream failed.
    """
    upstreams = db.query(Upstream).order_by(Upstream.id).all()
    items: list[UpstreamBatchActionItem] = []
    success_count = failed_count = skipped_count = 0
    for u in upstreams:
        if not u.enabled:
            items.append(UpstreamBatchActionItem(
                upstream_id=u.id,
                upstream_name=u.name,
                status="skipped",
                message="上游已停用,跳过",
            ))
            skipped_count += 1
            continue
        try:
            items.append(_test_upstream_core(db, u))
            success_count += 1
        except Exception as exc:
            # Reset the session first: if the failure came from the database,
            # committing without a rollback would fail here and for every
            # following upstream.
            db.rollback()
            u.last_status = "unhealthy"
            u.last_error = str(exc)
            u.last_checked_at = datetime.now(timezone.utc)
            u.consecutive_failures = (u.consecutive_failures or 0) + 1
            db.commit()
            items.append(UpstreamBatchActionItem(
                upstream_id=u.id,
                upstream_name=u.name,
                status="failed",
                message="连接失败",
                detail=str(exc),
            ))
            failed_count += 1
    total = len(upstreams)
    overall_ok = failed_count == 0
    msg = f"完成:{success_count} 成功 / {failed_count} 失败 / {skipped_count} 跳过"
    return UpstreamBatchActionResponse(
        success=overall_ok,
        message=msg,
        summary=UpstreamBatchActionSummary(
            total=total,
            success=success_count,
            failed=failed_count,
            skipped=skipped_count,
        ),
        items=items,
    )
@router.post("/check-now-all", response_model=UpstreamBatchActionResponse)
def check_now_all_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Run a full sync for every enabled upstream (see ``_check_now_core``).

    Disabled upstreams are reported as ``skipped``. For a failing upstream the
    error is recorded, the check time stamped and the failure counter
    incremented (its ``last_status`` is left unchanged); the remaining
    upstreams are still processed. ``success`` is true when no upstream failed.
    """
    upstreams = db.query(Upstream).order_by(Upstream.id).all()
    items: list[UpstreamBatchActionItem] = []
    success_count = failed_count = skipped_count = 0
    for u in upstreams:
        if not u.enabled:
            items.append(UpstreamBatchActionItem(
                upstream_id=u.id,
                upstream_name=u.name,
                status="skipped",
                message="上游已停用,跳过",
            ))
            skipped_count += 1
            continue
        try:
            # Use a dedicated name instead of shadowing the `_` dependency parameter.
            detail_msg, _changed = _check_now_core(db, u)
            items.append(UpstreamBatchActionItem(
                upstream_id=u.id,
                upstream_name=u.name,
                status="success",
                message=detail_msg,
            ))
            success_count += 1
        except Exception as exc:
            # Reset the session first: a DB error (or a half-finished sync
            # step) would otherwise break this commit and every later upstream.
            db.rollback()
            u.consecutive_failures = (u.consecutive_failures or 0) + 1
            u.last_error = str(exc)
            u.last_checked_at = datetime.now(timezone.utc)
            db.commit()
            items.append(UpstreamBatchActionItem(
                upstream_id=u.id,
                upstream_name=u.name,
                status="failed",
                message="同步失败",
                detail=str(exc),
            ))
            failed_count += 1
    total = len(upstreams)
    overall_ok = failed_count == 0
    msg = f"完成:{success_count} 成功 / {failed_count} 失败 / {skipped_count} 跳过"
    return UpstreamBatchActionResponse(
        success=overall_ok,
        message=msg,
        summary=UpstreamBatchActionSummary(
            total=total,
            success=success_count,
            failed=failed_count,
            skipped=skipped_count,
        ),
        items=items,
    )
# ─── CRUD ─────────────────────────────────────────────────────────────────────
@router.post("", response_model=UpstreamResponse, status_code=201)
def create_upstream(
    body: UpstreamCreate,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Create an upstream, persist it and register its scheduled check job.

    A trailing slash on ``base_url`` is stripped, and ``auth_config`` is stored
    as JSON.
    """
    columns = dict(
        name=body.name,
        base_url=body.base_url.rstrip("/"),
        api_prefix=body.api_prefix,
        auth_type=body.auth_type,
        auth_config_json=json.dumps(body.auth_config, ensure_ascii=False),
        rate_endpoint=body.rate_endpoint,
        groups_endpoint=body.groups_endpoint,
        enabled=body.enabled,
        check_interval_seconds=body.check_interval_seconds,
        timeout_seconds=body.timeout_seconds,
        balance_endpoint=body.balance_endpoint,
        balance_response_path=body.balance_response_path,
        balance_divisor=body.balance_divisor,
        balance_alert_threshold=body.balance_alert_threshold,
    )
    fresh = Upstream(**columns)
    db.add(fresh)
    db.commit()
    db.refresh(fresh)
    sched_svc.refresh_upstream(fresh.id, fresh.check_interval_seconds, fresh.enabled)
    return _to_response(fresh)
@router.get("/{uid}", response_model=UpstreamResponse)
def get_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return one upstream by id, with auth secrets masked.

    Raises:
        HTTPException(404): no upstream has this id.
    """
    found = db.query(Upstream).filter(Upstream.id == uid).first()
    if not found:
        raise HTTPException(404, "upstream not found")
    return _to_response(found)
@router.put("/{uid}", response_model=UpstreamResponse)
def update_upstream(
    uid: int,
    body: UpstreamUpdate,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Apply a partial update to an upstream.

    Only fields that are present (non-None) in *body* are written.
    ``auth_config`` is merged into the stored config key by key; incoming
    values equal to ``MASK`` are skipped, so masked secrets echoed back by the
    client are preserved. The failure counter is reset, ``updated_at`` is
    stamped and the scheduler job is refreshed.

    Raises:
        HTTPException(404): no upstream has this id.
    """
    target = db.query(Upstream).filter(Upstream.id == uid).first()
    if not target:
        raise HTTPException(404, "upstream not found")
    changes = body.model_dump(exclude_none=True)

    incoming_auth = changes.pop("auth_config", None)
    if incoming_auth is not None:
        merged = json.loads(target.auth_config_json or "{}")
        merged.update({k: v for k, v in incoming_auth.items() if v != MASK})
        target.auth_config_json = json.dumps(merged, ensure_ascii=False)

    if "base_url" in changes:
        changes["base_url"] = changes["base_url"].rstrip("/")
    for attr, value in changes.items():
        setattr(target, attr, value)

    # Any edit may have fixed the underlying problem, so restart failure counting.
    target.consecutive_failures = 0
    target.updated_at = datetime.now(timezone.utc)
    db.commit()
    db.refresh(target)
    sched_svc.refresh_upstream(target.id, target.check_interval_seconds, target.enabled)
    return _to_response(target)
@router.delete("/{uid}", status_code=204)
def delete_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Remove an upstream's scheduled job, then delete the upstream.

    Raises:
        HTTPException(404): no upstream has this id.
    """
    doomed = db.query(Upstream).filter(Upstream.id == uid).first()
    if not doomed:
        raise HTTPException(404, "upstream not found")
    # Interval 0 + disabled removes the scheduled check job.
    sched_svc.refresh_upstream(uid, 0, False)
    db.delete(doomed)
    db.commit()
@router.post("/{uid}/test", response_model=TestResult)
def test_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Connection-test one upstream and report the outcome.

    Test failures are not raised. They are returned as ``success=False``
    after the upstream is marked ``unhealthy``, with the error recorded and
    its failure counter incremented.

    Raises:
        HTTPException(404): no upstream has this id.
    """
    u = db.query(Upstream).filter(Upstream.id == uid).first()
    if not u:
        raise HTTPException(404, "upstream not found")
    try:
        item = _test_upstream_core(db, u)
    except Exception as exc:
        # Reset the session first so a DB-level failure does not make the
        # commit below raise as well.
        db.rollback()
        u.last_status = "unhealthy"
        u.last_error = str(exc)
        u.last_checked_at = datetime.now(timezone.utc)
        u.consecutive_failures = (u.consecutive_failures or 0) + 1
        db.commit()
        return TestResult(success=False, message="连接失败", detail=str(exc))
    return TestResult(success=True, message=item.message)
@router.post("/{uid}/check-now", response_model=TestResult)
def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Run a full sync for one upstream (see ``_check_now_core``) and report the outcome.

    Sync errors are not raised. The error text, check time and an incremented
    failure counter are recorded, and ``success=False`` is returned. Unlike
    the connection test, ``last_status`` is left unchanged here.

    Raises:
        HTTPException(404): no upstream has this id.
    """
    u = db.query(Upstream).filter(Upstream.id == uid).first()
    if not u:
        raise HTTPException(404, "upstream not found")
    try:
        # Use a dedicated name instead of shadowing the `_` dependency parameter.
        msg, _changed = _check_now_core(db, u)
    except Exception as exc:
        # Reset the session first: a DB error or a half-finished sync step
        # would otherwise break (or leak into) the commit below.
        db.rollback()
        u.consecutive_failures = (u.consecutive_failures or 0) + 1
        u.last_error = str(exc)
        u.last_checked_at = datetime.now(timezone.utc)
        db.commit()
        return TestResult(success=False, message="检测失败", detail=str(exc))
    return TestResult(success=True, message=msg)
@router.get("/{uid}/snapshots/latest", response_model=SnapshotResponse)
def latest_snapshot(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return the most recently captured rate snapshot of upstream *uid*.

    Raises:
        HTTPException(404): no snapshot has been captured for the upstream.
    """
    history = db.query(UpstreamRateSnapshot).filter(UpstreamRateSnapshot.upstream_id == uid)
    newest = history.order_by(UpstreamRateSnapshot.captured_at.desc()).first()
    if not newest:
        raise HTTPException(404, "no snapshot found")
    payload = json.loads(newest.snapshot_json)
    return SnapshotResponse(
        id=newest.id,
        upstream_id=newest.upstream_id,
        snapshot=payload,
        captured_at=newest.captured_at,
    )
from fastapi import Query as QueryParam
@router.get("/{uid}/snapshots", response_model=List[SnapshotResponse])
def list_snapshots(
    uid: int,
    limit: int = QueryParam(20, ge=1, le=100),
    offset: int = QueryParam(0, ge=0),
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Return a page of snapshot history (newest first) with change summaries.

    Each returned snapshot dict gets two extra keys:

    - ``_groups_count``: the number of entries under ``groups``.
    - ``_changes_count``: the number of differences from the next older
      snapshot, or ``None`` when no older snapshot exists.

    ``limit`` must be 1-100 and ``offset`` non-negative. A negative limit used
    to reach the database as a negative (unbounded) LIMIT.
    """
    rows = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == uid)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .offset(offset)
        .limit(limit + 1)  # one extra row so the page's oldest entry can be diffed
        .all()
    )
    # Parse each snapshot once. Rows are newest-first, so parsed[i + 1] is the
    # snapshot captured just before parsed[i]. A row is only annotated after it
    # has served as the newer side of its diff, and before it serves as the
    # older side, so diffs always see unannotated dicts.
    parsed = [json.loads(row.snapshot_json) for row in rows]
    results = []
    for i, row in enumerate(rows[:limit]):
        snap = parsed[i]
        changes_count: int | None = None
        if i + 1 < len(parsed):
            changes_count = len(diff_snapshots(parsed[i + 1], snap))
        snap["_groups_count"] = len(snap.get("groups", {}))
        snap["_changes_count"] = changes_count  # None means no older snapshot exists
        results.append(SnapshotResponse(
            id=row.id,
            upstream_id=row.upstream_id,
            snapshot=snap,
            captured_at=row.captured_at,
        ))
    return results