feat: 上游 Key 唯一化、分组导入跳过、账号导入平台识别&远端校验&base_url 注入

- 上游 Key 命名改为 {prefix}-{upstream.id}-{safe_group_name}-{group_id}
- 唯一约束 (upstream_id, group_id, managed_prefix) 加 managed_prefix 列
- 上游检测成功时同步 Key 状态,远端已删/分组已删自动清理
- 重复分组导入跳过,目标网站已存在同名分组返回 exists
- 账号导入平台自动识别(auto/manual 模式)
- 全选可导入 Key 按钮 + 目标分组自动匹配
- 导入幂等:已导入过的 Key 校验远端账号,不存在则重建
- 新增同步接口 POST /sync-imported-upstream-keys
- account_exists() 通过拉取账号列表判断,避免 404 误判
- credentials.base_url 注入来源上游地址,避免 401
- 前端导入弹窗自动同步+刷新按钮+并发/优先级设置
- 新增 12 个测试覆盖同步、幂等、远端删除、校验失败路径
This commit is contained in:
liumangmang
2026-05-21 01:16:39 +08:00
parent 0a27bba296
commit 6044b00685
18 changed files with 3112 additions and 50 deletions
+234 -1
View File
@@ -3,6 +3,7 @@ from __future__ import annotations
import json
import logging
import re
from datetime import datetime, timezone
from typing import List
@@ -14,11 +15,15 @@ from sqlalchemy.orm import Session
from app.database import get_db
from app.models.admin_user import AdminUser
from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.snapshot import UpstreamRateSnapshot
from app.schemas.upstream import (
GenerateKeysByGroupsRequest,
GenerateKeysByGroupsResponse,
GeneratedUpstreamKeyResponse,
UpstreamCreate, UpstreamUpdate, UpstreamResponse, SnapshotResponse, TestResult
)
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot, mask_secret
from app.services.snapshot_service import diff_snapshots
from app.services import scheduler as sched_svc
from app.services import webhook_service
@@ -31,6 +36,38 @@ MASK = "***"
SECRET_KEYS = {"password", "token", "key", "secret"}
def _group_id(group: dict) -> str:
for key in ("id", "group_id", "groupId"):
value = group.get(key)
if value is not None:
return str(value)
return str(group.get("name") or group.get("group_name") or "")
def _group_name(group: dict, gid: str) -> str:
return str(group.get("name") or group.get("group_name") or gid)
def _key_response(row: UpstreamGeneratedKey, include_value: bool = False) -> GeneratedUpstreamKeyResponse:
return GeneratedUpstreamKeyResponse(
id=row.id,
upstream_id=row.upstream_id,
group_id=row.group_id,
group_name=row.group_name,
key_id=row.key_id,
key_name=row.key_name,
key_value=row.key_value if include_value else None,
masked_key=row.masked_key,
status=row.status,
error=row.error,
imported_website_id=row.imported_website_id,
imported_account_id=row.imported_account_id,
imported_at=row.imported_at,
created_at=row.created_at,
updated_at=row.updated_at,
)
def _mask_auth_config(auth_type: str, cfg: dict) -> dict:
masked = {}
for k, v in cfg.items():
@@ -73,6 +110,198 @@ def list_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)):
return [_to_response(u) for u in db.query(Upstream).order_by(Upstream.id).all()]
@router.get("/{uid}/generated-keys", response_model=List[GeneratedUpstreamKeyResponse])
def list_generated_keys(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
if not db.query(Upstream.id).filter(Upstream.id == uid).first():
raise HTTPException(404, "upstream not found")
rows = (
db.query(UpstreamGeneratedKey)
.filter(UpstreamGeneratedKey.upstream_id == uid)
.order_by(UpstreamGeneratedKey.id.desc())
.limit(200)
.all()
)
return [_key_response(row) for row in rows]
_generate_key_lock = __import__("threading").Lock()
def _ensure_group_key(
db: Session,
client: UpstreamClient,
upstream: Upstream,
group: dict[str, Any],
prefix: str,
body: GenerateKeysByGroupsRequest,
) -> GeneratedUpstreamKeyResponse:
"""确保一个上游分组有一个 SmartUp 前缀 Key:存在则 upsert,不存在则创建。"""
gid = _group_id(group)
gname = _group_name(group, gid)
# 使用稳定的 upstream_id + group_id 而非可变名称,避免因改名产生重复
# 可读 Key 名:{prefix}-{upstream.id}-{安全的分组名}-{group_id}
safe_group_name = re.sub(r"[^a-zA-Z0-9\u4e00-\u9fff_-]", "", gname)[:30] if gname else gid
stable_name = f"{prefix}-{upstream.id}-{safe_group_name}-{gid}"
with _generate_key_lock:
try:
# 1. 先查本地是否已有该分组的托管 Key(兼容迁移前无 managed_prefix 的记录)
row = (
db.query(UpstreamGeneratedKey)
.filter(
UpstreamGeneratedKey.upstream_id == upstream.id,
UpstreamGeneratedKey.group_id == gid,
(UpstreamGeneratedKey.managed_prefix == prefix)
| ((UpstreamGeneratedKey.managed_prefix.is_(None))
& UpstreamGeneratedKey.key_name.like(f"{prefix}-%")),
)
.first()
)
if row and row.key_id:
# 本地已有记录,检查远端是否仍存在
try:
existing = client.find_smartup_group_key(gid, stable_name, prefix)
except Exception:
existing = None
if existing:
row.status = "exists"
row.updated_at = datetime.now(timezone.utc)
db.commit()
return _key_response(row, include_value=False)
# 远端不存在,需要重新创建
row.status = "replaced"
# 2. 查远端是否有同名 Key(防止并发时另一个请求已创建)
existing = client.find_smartup_group_key(gid, stable_name, prefix)
if existing:
key_id = str(existing.get("id") or "")
masked = existing.get("masked_key") or existing.get("key") or ""
if row:
row.key_id = key_id or row.key_id
row.masked_key = masked or row.masked_key
row.status = "exists"
row.updated_at = datetime.now(timezone.utc)
else:
row = UpstreamGeneratedKey(
upstream_id=upstream.id,
group_id=gid,
group_name=gname,
key_id=key_id or None,
key_name=stable_name,
key_value="",
masked_key=masked,
raw_json=json.dumps(existing, ensure_ascii=False),
managed_prefix=prefix,
status="exists",
)
db.add(row)
db.commit()
db.refresh(row)
return _key_response(row, include_value=False)
# 3. 远端不存在,创建新 Key
created = client.create_api_key(
stable_name,
gid,
quota=body.quota,
expires_in_days=body.expires_in_days,
rate_limit_5h=body.rate_limit_5h,
rate_limit_1d=body.rate_limit_1d,
rate_limit_7d=body.rate_limit_7d,
endpoint=body.endpoint,
)
if row:
# 复用旧行
row.key_id = created.get("id") or None
row.key_name = stable_name
row.key_value = created["key"]
row.masked_key = created.get("masked_key") or mask_secret(created["key"])
row.raw_json = json.dumps(created.get("raw") or {}, ensure_ascii=False)
row.managed_prefix = prefix
row.status = "created"
row.error = None
else:
row = UpstreamGeneratedKey(
upstream_id=upstream.id,
group_id=gid,
group_name=gname,
key_id=created.get("id") or None,
key_name=stable_name,
key_value=created["key"],
masked_key=created.get("masked_key") or mask_secret(created["key"]),
raw_json=json.dumps(created.get("raw") or {}, ensure_ascii=False),
managed_prefix=prefix,
status="created",
)
db.add(row)
db.commit()
db.refresh(row)
return _key_response(row, include_value=True)
except Exception as exc:
logger.exception("ensure group key failed for upstream=%s group=%s", upstream.id, gid)
return GeneratedUpstreamKeyResponse(
upstream_id=upstream.id,
group_id=gid,
group_name=gname,
key_name=stable_name,
status="failed",
error=str(exc),
)
@router.post("/{uid}/keys/generate-by-groups", response_model=GenerateKeysByGroupsResponse)
def generate_keys_by_groups(
uid: int,
body: GenerateKeysByGroupsRequest,
db: Session = Depends(get_db),
_=Depends(get_current_user),
):
u = db.query(Upstream).filter(Upstream.id == uid).first()
if not u:
raise HTTPException(404, "upstream not found")
if u.api_prefix.strip("/") != "api/v1":
raise HTTPException(400, "首版仅支持 Sub2API 上游(API Prefix 应为 /api/v1")
auth_config = json.loads(u.auth_config_json or "{}")
selected = set(body.group_ids)
prefix = body.name_prefix
results: list[GeneratedUpstreamKeyResponse] = []
with UpstreamClient(
base_url=u.base_url,
api_prefix=u.api_prefix,
auth_type=u.auth_type,
auth_config=auth_config,
timeout=float(u.timeout_seconds),
) as client:
try:
client.login()
groups = client.get_available_groups(u.groups_endpoint)
except Exception as exc:
raise HTTPException(502, str(exc))
for group in groups:
gid = _group_id(group)
if not gid or (selected and gid not in selected):
continue
result = _ensure_group_key(db, client, u, group, prefix, body)
results.append(result)
created = len([item for item in results if item.status == "created"])
existed = len([item for item in results if item.status == "exists"])
total = len(results)
msg_parts = []
if created:
msg_parts.append(f"新创建 {created}")
if existed:
msg_parts.append(f"已存在 {existed}")
msg = "".join(msg_parts) + f" / 共 {total} 个分组" if msg_parts else f"共处理 {total} 个分组"
return GenerateKeysByGroupsResponse(
success=total > 0 and all(item.status != "failed" for item in results),
message=msg,
items=results,
)
@router.post("", response_model=UpstreamResponse, status_code=201)
def create_upstream(
body: UpstreamCreate,
@@ -255,6 +484,10 @@ def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_use
webhook_service.send_rate_changed(db, u.id, u.name, u.base_url, changes)
website_sync.sync_affected_bindings(db, u.id, changes)
# 同步 SmartUp Key 状态(使用实际快照入库时间,与定时任务一致)
from app.services.scheduler import _sync_upstream_keys as _synck
_synck(uid, snapshot, new_row.captured_at)
msg = f"检测成功,{len(groups)} 个分组"
if changes:
msg += f",发现 {len(changes)} 处倍率变化"