feat: 上游 Key 唯一化、分组导入跳过、账号导入平台识别&远端校验&base_url 注入

- 上游 Key 命名改为 {prefix}-{upstream.id}-{safe_group_name}-{group_id}
- 唯一约束 (upstream_id, group_id, managed_prefix) 加 managed_prefix 列
- 上游检测成功时同步 Key 状态,远端已删/分组已删自动清理
- 重复分组导入跳过,目标网站已存在同名分组返回 exists
- 账号导入平台自动识别(auto/manual 模式)
- 全选可导入 Key 按钮 + 目标分组自动匹配
- 导入幂等:已导入过的 Key 校验远端账号,不存在则重建
- 新增同步接口 POST /sync-imported-upstream-keys
- account_exists() 通过拉取账号列表判断,避免 404 误判
- credentials.base_url 注入来源上游地址,避免 401
- 前端导入弹窗自动同步+刷新按钮+并发/优先级设置
- 新增 12 个测试覆盖同步、幂等、远端删除、校验失败路径
This commit is contained in:
liumangmang
2026-05-21 01:16:39 +08:00
parent 0a27bba296
commit 6044b00685
18 changed files with 3112 additions and 50 deletions
+420
View File
@@ -9,11 +9,21 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.snapshot import UpstreamRateSnapshot
from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.schemas.website import (
BindingCreate,
BindingResponse,
BindingUpdate,
ImportAccountItem,
ImportAccountsRequest,
ImportAccountsResponse,
ImportGroupItem,
ImportGroupsRequest,
ImportGroupsResponse,
SyncImportStatusRequest,
TestResult,
WebsiteCreate,
WebsiteGroupResponse,
@@ -118,6 +128,47 @@ def _client(row: Website) -> Sub2ApiWebsiteClient:
)
def _latest_upstream_groups(db: Session, upstream_id: int) -> list[dict]:
row = (
db.query(UpstreamRateSnapshot)
.filter(UpstreamRateSnapshot.upstream_id == upstream_id)
.order_by(UpstreamRateSnapshot.captured_at.desc())
.first()
)
if not row:
raise HTTPException(404, "no upstream snapshot found; run upstream check first")
snapshot = json.loads(row.snapshot_json or "{}")
groups = snapshot.get("groups") or {}
if not isinstance(groups, dict):
return []
return [item for item in groups.values() if isinstance(item, dict)]
def _source_group_id(group: dict) -> str:
return str(group.get("group_id") or group.get("id") or group.get("name") or "")
def _source_group_name(group: dict, gid: str) -> str:
return str(group.get("group_name") or group.get("name") or gid)
def _source_group_rate(group: dict) -> float:
raw = group.get("rate") or group.get("default_rate") or group.get("rate_multiplier") or 1
try:
return float(raw)
except (TypeError, ValueError):
return 1.0
def _numeric_group_id(value: str | None) -> int | None:
if value is None or value == "":
return None
try:
return int(value)
except ValueError:
return None
@router.get("/api/websites", response_model=List[WebsiteResponse])
def list_websites(db: Session = Depends(get_db), _=Depends(get_current_user)):
return [_website_response(row) for row in db.query(Website).order_by(Website.id).all()]
@@ -215,6 +266,375 @@ def list_website_groups(wid: int, db: Session = Depends(get_db), _=Depends(get_c
raise HTTPException(502, str(exc))
@router.post("/api/websites/{wid}/groups/import-from-upstream/{upstream_id}", response_model=ImportGroupsResponse)
def import_groups_from_upstream(
wid: int,
upstream_id: int,
body: ImportGroupsRequest,
db: Session = Depends(get_db),
_=Depends(get_current_user),
):
website = db.query(Website).filter(Website.id == wid).first()
if not website:
raise HTTPException(404, "website not found")
if website.site_type != "sub2api":
raise HTTPException(400, "目前只支持 sub2api")
upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
if not upstream:
raise HTTPException(404, "upstream not found")
selected = set(body.group_ids)
groups = _latest_upstream_groups(db, upstream_id)
# 拉取目标网站已有分组,同名则跳过
try:
existing_names = set()
with _client(website) as c:
for eg in c.get_groups(website.groups_endpoint):
gname = eg.get("name") or eg.get("group_name") or ""
if gname:
existing_names.add(gname)
except Exception:
existing_names = set()
items: list[ImportGroupItem] = []
with _client(website) as c:
for group in groups:
source_gid = _source_group_id(group)
if not source_gid or (selected and source_gid not in selected):
continue
source_name = _source_group_name(group, source_gid)
target_name = f"{body.name_prefix}{source_name}" if body.name_prefix else source_name
# 检查是否已存在同名分组
if target_name in existing_names:
items.append(ImportGroupItem(
source_group_id=source_gid,
source_group_name=source_name,
target_group_name=target_name,
status="exists",
message="目标分组已存在,已跳过",
))
continue
create_body = {
"name": target_name,
"description": group.get("description") or f"Imported from {upstream.name} / {source_name}",
"platform": group.get("platform") or "openai",
"rate_multiplier": _source_group_rate(group),
}
if group.get("rpm_limit") is not None:
create_body["rpm_limit"] = group.get("rpm_limit")
try:
created = c.create_group(create_body)
target_id = c.extract_id(created)
items.append(ImportGroupItem(
source_group_id=source_gid,
source_group_name=source_name,
target_group_id=target_id or None,
target_group_name=str(created.get("name") or target_name),
status="created",
message="已创建",
raw=created,
))
except Exception as exc:
msg = str(exc)
# 捕获 409 等已存在错误
if "已存在" in msg or "already exists" in msg.lower() or "409" in msg or "Conflict" in msg:
items.append(ImportGroupItem(
source_group_id=source_gid,
source_group_name=source_name,
target_group_name=target_name,
status="exists",
message="目标分组已存在(接口返回冲突)",
))
else:
logger.exception("import website group failed website=%s upstream=%s group=%s", wid, upstream_id, source_gid)
items.append(ImportGroupItem(
source_group_id=source_gid,
source_group_name=source_name,
target_group_name=target_name,
status="failed",
message=msg,
))
created_count = len([item for item in items if item.status == "created"])
exists_count = len([item for item in items if item.status == "exists"])
failed_count = len([item for item in items if item.status == "failed"])
msg_parts = []
if created_count:
msg_parts.append(f"新建 {created_count}")
if exists_count:
msg_parts.append(f"已存在 {exists_count}")
if failed_count:
msg_parts.append(f"失败 {failed_count}")
return ImportGroupsResponse(
success=failed_count == 0,
message="".join(msg_parts) + f" / 共 {len(items)}" if msg_parts else f"共处理 {len(items)} 个分组",
items=items,
)
def _detect_platform(text: str, fallback: str = "openai") -> str:
"""根据 Key 名或分组名关键词判断平台类型。"""
lower = text.lower()
if "claude" in lower or "anthropic" in lower:
return "anthropic"
if "gemini" in lower:
return "gemini"
if "antigravity" in lower:
return "antigravity"
return fallback
@router.post("/api/websites/{wid}/accounts/sync-imported-upstream-keys", response_model=ImportAccountsResponse)
def sync_imported_upstream_keys(
wid: int,
body: SyncImportStatusRequest,
db: Session = Depends(get_db),
_=Depends(get_current_user),
):
"""校验已导入的上游 Key 在目标 Sub2API 账号管理中是否仍存在。"""
website = db.query(Website).filter(Website.id == wid).first()
if not website:
raise HTTPException(404, "website not found")
rows = (
db.query(UpstreamGeneratedKey)
.filter(
UpstreamGeneratedKey.upstream_id == body.upstream_id,
UpstreamGeneratedKey.imported_website_id == wid,
UpstreamGeneratedKey.imported_account_id.isnot(None),
)
.all()
)
items: list[ImportAccountItem] = []
with _client(website) as c:
for row in rows:
platform = _detect_platform(f"{row.group_name} {row.group_id} {row.key_name}", "openai")
if not row.imported_account_id:
continue
old_account_id = row.imported_account_id
exists = c.account_exists(row.imported_account_id)
if exists is False:
row.imported_website_id = None
row.imported_account_id = None
row.imported_at = None
row.status = "created"
items.append(ImportAccountItem(
upstream_key_id=row.id, source_group_id=row.group_id,
source_group_name=row.group_name, account_id=old_account_id,
platform=platform, status="stale_cleared",
message="目标账号已删除,已清除导入标记",
))
elif exists is True:
items.append(ImportAccountItem(
upstream_key_id=row.id, source_group_id=row.group_id,
source_group_name=row.group_name, account_id=old_account_id,
platform=platform, status="exists", message="目标账号仍存在",
))
else:
items.append(ImportAccountItem(
upstream_key_id=row.id, source_group_id=row.group_id,
source_group_name=row.group_name, account_id=old_account_id,
platform=platform, status="check_failed",
message="无法校验目标账号存在性(目标网站认证/网络问题)",
))
db.commit()
cleared_count = len([i for i in items if i.status == "stale_cleared"])
check_failed_count = len([i for i in items if i.status == "check_failed"])
msg_parts = []
if cleared_count:
msg_parts.append(f"清除 {cleared_count}")
if check_failed_count:
msg_parts.append(f"校验失败 {check_failed_count}")
return ImportAccountsResponse(
success=check_failed_count == 0,
message="".join(msg_parts) + f" / 共 {len(items)}" if msg_parts else f"共校验 {len(items)} 个,无变化",
items=items,
)
@router.post("/api/websites/{wid}/accounts/import-upstream-keys", response_model=ImportAccountsResponse)
def import_upstream_keys_as_accounts(
wid: int,
body: ImportAccountsRequest,
db: Session = Depends(get_db),
_=Depends(get_current_user),
):
website = db.query(Website).filter(Website.id == wid).first()
if not website:
raise HTTPException(404, "website not found")
if website.site_type != "sub2api":
raise HTTPException(400, "目前只支持 sub2api")
if not body.upstream_key_ids:
raise HTTPException(400, "请选择要导入的 Key")
rows = (
db.query(UpstreamGeneratedKey)
.filter(UpstreamGeneratedKey.id.in_(body.upstream_key_ids))
.order_by(UpstreamGeneratedKey.id)
.all()
)
found_ids = {row.id for row in rows}
missing_ids = [kid for kid in body.upstream_key_ids if kid not in found_ids]
items: list[ImportAccountItem] = [
ImportAccountItem(
upstream_key_id=kid,
source_group_id="",
source_group_name="",
platform=body.default_platform,
status="failed",
message="key not found",
)
for kid in missing_ids
]
# 查出来源上游的 Base URL
upstream_base_url = ""
if body.upstream_key_ids:
first_row = rows[0] if rows else None
if first_row:
from app.models.upstream import Upstream as _Up
_u = db.query(_Up).filter(_Up.id == first_row.upstream_id).first()
if _u:
upstream_base_url = _u.base_url
with _client(website) as c:
for row in rows:
# 先确定平台(失败项也需要记录)
if body.platform_mode == "auto":
platform = _detect_platform(
f"{row.group_name} {row.group_id} {row.key_name}",
body.default_platform,
)
else:
platform = body.default_platform
# 幂等校验:已导入过则检查远端账号是否仍存在
if row.imported_website_id == wid and row.imported_account_id:
old_account_id = row.imported_account_id
exists = c.account_exists(row.imported_account_id)
if exists is True:
items.append(ImportAccountItem(
upstream_key_id=row.id,
source_group_id=row.group_id,
source_group_name=row.group_name,
target_group_id=body.target_group_map.get(row.group_id),
account_id=old_account_id,
account_name=f"{body.account_name_prefix}-{row.group_name or row.group_id}-{row.id}",
platform=platform,
upstream_base_url=upstream_base_url,
status="exists",
message="已导入过,已跳过",
))
continue
elif exists is False:
# 远端已删除,清空标记后继续创建
row.imported_website_id = None
row.imported_account_id = None
row.imported_at = None
row.status = "created"
# 继续往下走(不 continue
else:
# 校验失败,保守跳过
items.append(ImportAccountItem(
upstream_key_id=row.id,
source_group_id=row.group_id,
source_group_name=row.group_name,
target_group_id=body.target_group_map.get(row.group_id),
account_id=old_account_id,
platform=platform,
status="check_failed",
message="无法校验目标账号状态,已保守跳过",
))
continue
if not row.key_value:
items.append(ImportAccountItem(
upstream_key_id=row.id,
source_group_id=row.group_id,
source_group_name=row.group_name,
platform=platform,
status="failed",
message="该 Key 无明文值,无法导入(远端已存在 Key 不会保留明文,请重新创建或手动填入)",
))
continue
target_group_id = body.target_group_map.get(row.group_id)
group_ids = []
numeric_target = _numeric_group_id(target_group_id)
if numeric_target is not None:
group_ids.append(numeric_target)
account_name = f"{body.account_name_prefix}-{row.group_name or row.group_id}-{row.id}"
account_body = {
"name": account_name,
"platform": platform,
"type": "apikey",
"credentials": {
"api_key": row.key_value,
"base_url": upstream_base_url,
},
"group_ids": group_ids,
"rate_multiplier": 1,
"concurrency": body.concurrency,
"priority": body.priority,
"notes": f"Imported by SmartUp from upstream key #{row.id}",
}
try:
created = c.create_account(account_body)
account_id = c.extract_id(created)
row.imported_website_id = wid
row.imported_account_id = account_id or None
row.imported_at = datetime.now(timezone.utc)
row.status = "imported"
row.error = None
db.commit()
items.append(ImportAccountItem(
upstream_key_id=row.id,
source_group_id=row.group_id,
source_group_name=row.group_name,
target_group_id=target_group_id,
account_id=account_id or None,
account_name=str(created.get("name") or account_name),
platform=platform,
upstream_base_url=upstream_base_url,
status="created",
message="已创建账号",
raw=created,
))
except Exception as exc:
logger.exception("import upstream key as account failed website=%s key=%s", wid, row.id)
row.status = "import_failed"
row.error = str(exc)
db.commit()
items.append(ImportAccountItem(
upstream_key_id=row.id,
source_group_id=row.group_id,
source_group_name=row.group_name,
target_group_id=target_group_id,
account_name=account_name,
platform=platform,
upstream_base_url=upstream_base_url,
status="failed",
message=str(exc),
))
created_count = len([item for item in items if item.status == "created"])
exists_count = len([item for item in items if item.status == "exists"])
failed_count = len([item for item in items if item.status == "failed"])
check_failed_count = len([item for item in items if item.status == "check_failed"])
msg_parts = []
if created_count:
msg_parts.append(f"新建 {created_count}")
if exists_count:
msg_parts.append(f"已存在 {exists_count}")
if check_failed_count:
msg_parts.append(f"校验失败 {check_failed_count}")
if failed_count:
msg_parts.append(f"失败 {failed_count}")
return ImportAccountsResponse(
success=failed_count == 0 and check_failed_count == 0,
message="".join(msg_parts) + f" / 共 {len(items)}" if msg_parts else f"共处理 {len(items)}",
items=items,
)
@router.get("/api/group-bindings", response_model=List[BindingResponse])
def list_bindings(db: Session = Depends(get_db), _=Depends(get_current_user)):
rows = db.query(WebsiteGroupBinding).order_by(WebsiteGroupBinding.id.desc()).all()