Compare commits

...

2 Commits

Author SHA1 Message Date
liumangmang 6044b00685 feat: 上游 Key 唯一化、分组导入跳过、账号导入平台识别&远端校验&base_url 注入
- 上游 Key 命名改为 {prefix}-{upstream.id}-{safe_group_name}-{group_id}
- 唯一约束 (upstream_id, group_id, managed_prefix) 加 managed_prefix 列
- 上游检测成功时同步 Key 状态,远端已删/分组已删自动清理
- 重复分组导入跳过,目标网站已存在同名分组返回 exists
- 账号导入平台自动识别(auto/manual 模式)
- 全选可导入 Key 按钮 + 目标分组自动匹配
- 导入幂等:已导入过的 Key 校验远端账号,不存在则重建
- 新增同步接口 POST /sync-imported-upstream-keys
- account_exists() 通过拉取账号列表判断,避免 404 误判
- credentials.base_url 注入来源上游地址,避免 401
- 前端导入弹窗自动同步+刷新按钮+并发/优先级设置
- 新增 12 个测试覆盖同步、幂等、远端删除、校验失败路径
2026-05-21 01:16:39 +08:00
liumangmang 0a27bba296 fix: 修复远程浏览器登录态保留 & 剪贴板同步问题 2026-05-20 10:13:13 +08:00
18 changed files with 3117 additions and 52 deletions
+3
View File
@@ -18,6 +18,9 @@ build/
backend/static/ backend/static/
backend/data/ backend/data/
# 运行时数据(数据库、远程浏览器 profile、缓存等)
data/
*.log *.log
.DS_Store .DS_Store
.git-real/ .git-real/
+22 -8
View File
@@ -1,8 +1,14 @@
# syntax=docker/dockerfile:1
# ---- Stage 1: Build frontend ---- # ---- Stage 1: Build frontend ----
FROM node:20-alpine AS frontend-build FROM node:20-alpine AS frontend-build
WORKDIR /frontend WORKDIR /frontend
# 依赖层:package*.json 不变则复用 npm 缓存
COPY frontend/package*.json ./ COPY frontend/package*.json ./
RUN npm ci --registry=https://registry.npmmirror.com RUN --mount=type=cache,target=/root/.npm \
npm ci --registry=https://registry.npmmirror.com
# 源码层:业务代码变更不影响上层依赖
COPY frontend/ . COPY frontend/ .
RUN npm run build RUN npm run build
@@ -11,13 +17,13 @@ FROM python:3.12-slim
WORKDIR /app WORKDIR /app
ENV PLAYWRIGHT_DOWNLOAD_HOST=https://npmmirror.com/mirrors/playwright ENV PLAYWRIGHT_DOWNLOAD_HOST=https://npmmirror.com/mirrors/playwright
ENV PLAYWRIGHT_BROWSERS_PATH=/ms-playwright
RUN sed -i 's|http://deb.debian.org|https://mirrors.aliyun.com|g; s|http://security.debian.org|https://mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources RUN sed -i 's|http://deb.debian.org|https://mirrors.aliyun.com|g; s|http://security.debian.org|https://mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources
# Install deps # 系统依赖层:apt 包安装,缓存 deb 包避免重复下载
COPY backend/requirements.txt . RUN --mount=type=cache,target=/var/cache/apt \
RUN pip install --no-cache-dir --index-url https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn -r requirements.txt apt-get update \
RUN apt-get update \
&& apt-get install -y --no-install-recommends \ && apt-get install -y --no-install-recommends \
fonts-liberation fonts-unifont fonts-wqy-zenhei \ fonts-liberation fonts-unifont fonts-wqy-zenhei \
libasound2t64 libatk-bridge2.0-0 libatk1.0-0 libatspi2.0-0 \ libasound2t64 libatk-bridge2.0-0 libatk1.0-0 libatspi2.0-0 \
@@ -27,15 +33,23 @@ RUN apt-get update \
libxdamage1 libxext6 libxfixes3 libxrandr2 libxshmfence1 xvfb \ libxdamage1 libxext6 libxfixes3 libxrandr2 libxshmfence1 xvfb \
curl \ curl \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Python 依赖层:requirements.txt 不变则复用 pip 缓存
COPY backend/requirements.txt .
RUN --mount=type=cache,target=/root/.cache/pip \
pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple \
--trusted-host pypi.tuna.tsinghua.edu.cn \
-r requirements.txt
# Playwright Chromium:安装在镜像层中,业务代码变更不会触发重下
RUN playwright install chromium RUN playwright install chromium
# Copy backend source # 源码层:业务代码变更不影响上面所有依赖层
COPY backend/ . COPY backend/ .
# Copy built frontend into backend/static # 前端构建产物
COPY --from=frontend-build /frontend/dist ./static COPY --from=frontend-build /frontend/dist ./static
# Data directory for SQLite
RUN mkdir -p /app/data RUN mkdir -p /app/data
ENV PYTHONPATH=/app ENV PYTHONPATH=/app
+17 -2
View File
@@ -1,10 +1,25 @@
COMPOSE ?= docker compose COMPOSE ?= docker compose
SERVICE ?= smartup SERVICE ?= smartup
.PHONY: up down log restart ps .PHONY: up down build build-nc up-build log restart ps
# 日常启动(不重新构建镜像,启动已有容器)
up: up:
$(COMPOSE) up -d --build $(COMPOSE) up -d
@port=$$(grep -E '^SERVER_PORT=' .env 2>/dev/null | tail -n 1 | cut -d= -f2-); \
printf '访问地址:http://localhost:%s\n' "$${port:-8899}"
# 构建镜像(带 BuildKit 缓存)
build:
DOCKER_BUILDKIT=1 $(COMPOSE) build
# 强制重新构建(忽略 Docker 层缓存,npm/pip/apt 下载缓存仍可能复用)
build-nc:
DOCKER_BUILDKIT=1 $(COMPOSE) build --no-cache
# 构建并启动(依赖变更后使用)
up-build:
DOCKER_BUILDKIT=1 $(COMPOSE) up -d --build
@port=$$(grep -E '^SERVER_PORT=' .env 2>/dev/null | tail -n 1 | cut -d= -f2-); \ @port=$$(grep -E '^SERVER_PORT=' .env 2>/dev/null | tail -n 1 | cut -d= -f2-); \
printf '访问地址:http://localhost:%s\n' "$${port:-8899}" printf '访问地址:http://localhost:%s\n' "$${port:-8899}"
+58 -1
View File
@@ -26,10 +26,11 @@ def get_db():
def init_db(): def init_db():
"""Create all tables.""" """Create all tables."""
# import models so SQLAlchemy registers them # import models so SQLAlchemy registers them
from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website, revoked_token # noqa: F401 from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website, revoked_token, upstream_key # noqa: F401
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
_migrate_custom_pages() _migrate_custom_pages()
_migrate_upstreams() _migrate_upstreams()
_migrate_upstream_generated_keys()
def _migrate_custom_pages(): def _migrate_custom_pages():
@@ -87,3 +88,59 @@ def _migrate_upstreams():
if "balance_divisor" not in columns: if "balance_divisor" not in columns:
conn.execute(text("ALTER TABLE upstreams ADD COLUMN balance_divisor FLOAT NOT NULL DEFAULT 1.0")) conn.execute(text("ALTER TABLE upstreams ADD COLUMN balance_divisor FLOAT NOT NULL DEFAULT 1.0"))
def _migrate_upstream_generated_keys():
"""Apply SQLite-safe migrations to the generated upstream keys table."""
inspector = inspect(engine)
if "upstream_generated_keys" not in inspector.get_table_names():
return
columns = {col["name"] for col in inspector.get_columns("upstream_generated_keys")}
with engine.begin() as conn:
if "imported_website_id" not in columns:
conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_website_id INTEGER"))
if "imported_account_id" not in columns:
conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_account_id VARCHAR(255)"))
if "imported_at" not in columns:
conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_at DATETIME"))
if "updated_at" not in columns:
conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN updated_at DATETIME"))
conn.execute(text("UPDATE upstream_generated_keys SET updated_at = created_at WHERE updated_at IS NULL"))
if "managed_prefix" not in columns:
conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN managed_prefix VARCHAR(64)"))
# ——— 历史数据迁移:回填 managed_prefix + 清理重复 ———
with engine.begin() as conn:
# 1. 回填:key_name 以 SmartUp- 开头的旧记录设置 managed_prefix = 'SmartUp'
conn.execute(text(
"UPDATE upstream_generated_keys SET managed_prefix = 'SmartUp' "
"WHERE managed_prefix IS NULL AND key_name LIKE 'SmartUp-%'"
))
# 2. 清理:同一 (upstream_id, group_id, managed_prefix) 只保留最新一条
# SQLite 不支持子查询直接 DELETE,用两步
to_delete = conn.execute(text("""
SELECT id FROM upstream_generated_keys
WHERE managed_prefix IS NOT NULL
AND id NOT IN (
SELECT MAX(id) FROM upstream_generated_keys
WHERE managed_prefix IS NOT NULL
GROUP BY upstream_id, group_id, managed_prefix
)
""")).fetchall()
for (row_id,) in to_delete:
conn.execute(text("DELETE FROM upstream_generated_keys WHERE id = :id"), {"id": row_id})
# ——— 创建唯一索引 ———
try:
with engine.begin() as conn:
conn.execute(
text("CREATE UNIQUE INDEX IF NOT EXISTS uq_upstream_group_key "
"ON upstream_generated_keys(upstream_id, group_id, key_name)")
)
conn.execute(
text("CREATE UNIQUE INDEX IF NOT EXISTS uq_upstream_group_managed "
"ON upstream_generated_keys(upstream_id, group_id, managed_prefix) "
"WHERE managed_prefix IS NOT NULL")
)
except Exception:
logger = __import__("logging").getLogger(__name__)
logger.warning("could not create unique indexes on upstream_generated_keys (non-fatal)")
+33
View File
@@ -0,0 +1,33 @@
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class UpstreamGeneratedKey(Base):
__tablename__ = "upstream_generated_keys"
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
upstream_id: Mapped[int] = mapped_column(Integer, ForeignKey("upstreams.id", ondelete="CASCADE"), index=True)
group_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
group_name: Mapped[str] = mapped_column(String(255), default="")
key_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
key_name: Mapped[str] = mapped_column(String(255), nullable=False)
key_value: Mapped[str] = mapped_column(Text, nullable=False)
masked_key: Mapped[str] = mapped_column(String(255), default="")
raw_json: Mapped[str] = mapped_column(Text, default="{}")
managed_prefix: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
status: Mapped[str] = mapped_column(String(32), default="created")
error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
imported_website_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey("websites.id", ondelete="SET NULL"), nullable=True, index=True)
imported_account_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
imported_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
__table_args__ = (
UniqueConstraint("upstream_id", "group_id", "key_name", name="uq_upstream_group_key"),
)
+234 -1
View File
@@ -3,6 +3,7 @@ from __future__ import annotations
import json import json
import logging import logging
import re
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List from typing import List
@@ -14,11 +15,15 @@ from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.models.admin_user import AdminUser from app.models.admin_user import AdminUser
from app.models.upstream import Upstream from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.snapshot import UpstreamRateSnapshot from app.models.snapshot import UpstreamRateSnapshot
from app.schemas.upstream import ( from app.schemas.upstream import (
GenerateKeysByGroupsRequest,
GenerateKeysByGroupsResponse,
GeneratedUpstreamKeyResponse,
UpstreamCreate, UpstreamUpdate, UpstreamResponse, SnapshotResponse, TestResult UpstreamCreate, UpstreamUpdate, UpstreamResponse, SnapshotResponse, TestResult
) )
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot, mask_secret
from app.services.snapshot_service import diff_snapshots from app.services.snapshot_service import diff_snapshots
from app.services import scheduler as sched_svc from app.services import scheduler as sched_svc
from app.services import webhook_service from app.services import webhook_service
@@ -31,6 +36,38 @@ MASK = "***"
SECRET_KEYS = {"password", "token", "key", "secret"} SECRET_KEYS = {"password", "token", "key", "secret"}
def _group_id(group: dict) -> str:
for key in ("id", "group_id", "groupId"):
value = group.get(key)
if value is not None:
return str(value)
return str(group.get("name") or group.get("group_name") or "")
def _group_name(group: dict, gid: str) -> str:
return str(group.get("name") or group.get("group_name") or gid)
def _key_response(row: UpstreamGeneratedKey, include_value: bool = False) -> GeneratedUpstreamKeyResponse:
return GeneratedUpstreamKeyResponse(
id=row.id,
upstream_id=row.upstream_id,
group_id=row.group_id,
group_name=row.group_name,
key_id=row.key_id,
key_name=row.key_name,
key_value=row.key_value if include_value else None,
masked_key=row.masked_key,
status=row.status,
error=row.error,
imported_website_id=row.imported_website_id,
imported_account_id=row.imported_account_id,
imported_at=row.imported_at,
created_at=row.created_at,
updated_at=row.updated_at,
)
def _mask_auth_config(auth_type: str, cfg: dict) -> dict: def _mask_auth_config(auth_type: str, cfg: dict) -> dict:
masked = {} masked = {}
for k, v in cfg.items(): for k, v in cfg.items():
@@ -73,6 +110,198 @@ def list_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)):
return [_to_response(u) for u in db.query(Upstream).order_by(Upstream.id).all()] return [_to_response(u) for u in db.query(Upstream).order_by(Upstream.id).all()]
@router.get("/{uid}/generated-keys", response_model=List[GeneratedUpstreamKeyResponse])
def list_generated_keys(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
if not db.query(Upstream.id).filter(Upstream.id == uid).first():
raise HTTPException(404, "upstream not found")
rows = (
db.query(UpstreamGeneratedKey)
.filter(UpstreamGeneratedKey.upstream_id == uid)
.order_by(UpstreamGeneratedKey.id.desc())
.limit(200)
.all()
)
return [_key_response(row) for row in rows]
_generate_key_lock = __import__("threading").Lock()
def _ensure_group_key(
db: Session,
client: UpstreamClient,
upstream: Upstream,
group: dict[str, Any],
prefix: str,
body: GenerateKeysByGroupsRequest,
) -> GeneratedUpstreamKeyResponse:
"""确保一个上游分组有一个 SmartUp 前缀 Key:存在则 upsert,不存在则创建。"""
gid = _group_id(group)
gname = _group_name(group, gid)
# 使用稳定的 upstream_id + group_id 而非可变名称,避免因改名产生重复
# 可读 Key 名:{prefix}-{upstream.id}-{安全的分组名}-{group_id}
safe_group_name = re.sub(r"[^a-zA-Z0-9\u4e00-\u9fff_-]", "", gname)[:30] if gname else gid
stable_name = f"{prefix}-{upstream.id}-{safe_group_name}-{gid}"
with _generate_key_lock:
try:
# 1. 先查本地是否已有该分组的托管 Key(兼容迁移前无 managed_prefix 的记录)
row = (
db.query(UpstreamGeneratedKey)
.filter(
UpstreamGeneratedKey.upstream_id == upstream.id,
UpstreamGeneratedKey.group_id == gid,
(UpstreamGeneratedKey.managed_prefix == prefix)
| ((UpstreamGeneratedKey.managed_prefix.is_(None))
& UpstreamGeneratedKey.key_name.like(f"{prefix}-%")),
)
.first()
)
if row and row.key_id:
# 本地已有记录,检查远端是否仍存在
try:
existing = client.find_smartup_group_key(gid, stable_name, prefix)
except Exception:
existing = None
if existing:
row.status = "exists"
row.updated_at = datetime.now(timezone.utc)
db.commit()
return _key_response(row, include_value=False)
# 远端不存在,需要重新创建
row.status = "replaced"
# 2. 查远端是否有同名 Key(防止并发时另一个请求已创建)
existing = client.find_smartup_group_key(gid, stable_name, prefix)
if existing:
key_id = str(existing.get("id") or "")
masked = existing.get("masked_key") or existing.get("key") or ""
if row:
row.key_id = key_id or row.key_id
row.masked_key = masked or row.masked_key
row.status = "exists"
row.updated_at = datetime.now(timezone.utc)
else:
row = UpstreamGeneratedKey(
upstream_id=upstream.id,
group_id=gid,
group_name=gname,
key_id=key_id or None,
key_name=stable_name,
key_value="",
masked_key=masked,
raw_json=json.dumps(existing, ensure_ascii=False),
managed_prefix=prefix,
status="exists",
)
db.add(row)
db.commit()
db.refresh(row)
return _key_response(row, include_value=False)
# 3. 远端不存在,创建新 Key
created = client.create_api_key(
stable_name,
gid,
quota=body.quota,
expires_in_days=body.expires_in_days,
rate_limit_5h=body.rate_limit_5h,
rate_limit_1d=body.rate_limit_1d,
rate_limit_7d=body.rate_limit_7d,
endpoint=body.endpoint,
)
if row:
# 复用旧行
row.key_id = created.get("id") or None
row.key_name = stable_name
row.key_value = created["key"]
row.masked_key = created.get("masked_key") or mask_secret(created["key"])
row.raw_json = json.dumps(created.get("raw") or {}, ensure_ascii=False)
row.managed_prefix = prefix
row.status = "created"
row.error = None
else:
row = UpstreamGeneratedKey(
upstream_id=upstream.id,
group_id=gid,
group_name=gname,
key_id=created.get("id") or None,
key_name=stable_name,
key_value=created["key"],
masked_key=created.get("masked_key") or mask_secret(created["key"]),
raw_json=json.dumps(created.get("raw") or {}, ensure_ascii=False),
managed_prefix=prefix,
status="created",
)
db.add(row)
db.commit()
db.refresh(row)
return _key_response(row, include_value=True)
except Exception as exc:
logger.exception("ensure group key failed for upstream=%s group=%s", upstream.id, gid)
return GeneratedUpstreamKeyResponse(
upstream_id=upstream.id,
group_id=gid,
group_name=gname,
key_name=stable_name,
status="failed",
error=str(exc),
)
@router.post("/{uid}/keys/generate-by-groups", response_model=GenerateKeysByGroupsResponse)
def generate_keys_by_groups(
uid: int,
body: GenerateKeysByGroupsRequest,
db: Session = Depends(get_db),
_=Depends(get_current_user),
):
u = db.query(Upstream).filter(Upstream.id == uid).first()
if not u:
raise HTTPException(404, "upstream not found")
if u.api_prefix.strip("/") != "api/v1":
raise HTTPException(400, "首版仅支持 Sub2API 上游(API Prefix 应为 /api/v1")
auth_config = json.loads(u.auth_config_json or "{}")
selected = set(body.group_ids)
prefix = body.name_prefix
results: list[GeneratedUpstreamKeyResponse] = []
with UpstreamClient(
base_url=u.base_url,
api_prefix=u.api_prefix,
auth_type=u.auth_type,
auth_config=auth_config,
timeout=float(u.timeout_seconds),
) as client:
try:
client.login()
groups = client.get_available_groups(u.groups_endpoint)
except Exception as exc:
raise HTTPException(502, str(exc))
for group in groups:
gid = _group_id(group)
if not gid or (selected and gid not in selected):
continue
result = _ensure_group_key(db, client, u, group, prefix, body)
results.append(result)
created = len([item for item in results if item.status == "created"])
existed = len([item for item in results if item.status == "exists"])
total = len(results)
msg_parts = []
if created:
msg_parts.append(f"新创建 {created}")
if existed:
msg_parts.append(f"已存在 {existed}")
msg = "".join(msg_parts) + f" / 共 {total} 个分组" if msg_parts else f"共处理 {total} 个分组"
return GenerateKeysByGroupsResponse(
success=total > 0 and all(item.status != "failed" for item in results),
message=msg,
items=results,
)
@router.post("", response_model=UpstreamResponse, status_code=201) @router.post("", response_model=UpstreamResponse, status_code=201)
def create_upstream( def create_upstream(
body: UpstreamCreate, body: UpstreamCreate,
@@ -255,6 +484,10 @@ def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_use
webhook_service.send_rate_changed(db, u.id, u.name, u.base_url, changes) webhook_service.send_rate_changed(db, u.id, u.name, u.base_url, changes)
website_sync.sync_affected_bindings(db, u.id, changes) website_sync.sync_affected_bindings(db, u.id, changes)
# 同步 SmartUp Key 状态(使用实际快照入库时间,与定时任务一致)
from app.services.scheduler import _sync_upstream_keys as _synck
_synck(uid, snapshot, new_row.captured_at)
msg = f"检测成功,{len(groups)} 个分组" msg = f"检测成功,{len(groups)} 个分组"
if changes: if changes:
msg += f",发现 {len(changes)} 处倍率变化" msg += f",发现 {len(changes)} 处倍率变化"
+420
View File
@@ -9,11 +9,21 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.models.snapshot import UpstreamRateSnapshot
from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.schemas.website import ( from app.schemas.website import (
BindingCreate, BindingCreate,
BindingResponse, BindingResponse,
BindingUpdate, BindingUpdate,
ImportAccountItem,
ImportAccountsRequest,
ImportAccountsResponse,
ImportGroupItem,
ImportGroupsRequest,
ImportGroupsResponse,
SyncImportStatusRequest,
TestResult, TestResult,
WebsiteCreate, WebsiteCreate,
WebsiteGroupResponse, WebsiteGroupResponse,
@@ -118,6 +128,47 @@ def _client(row: Website) -> Sub2ApiWebsiteClient:
) )
def _latest_upstream_groups(db: Session, upstream_id: int) -> list[dict]:
row = (
db.query(UpstreamRateSnapshot)
.filter(UpstreamRateSnapshot.upstream_id == upstream_id)
.order_by(UpstreamRateSnapshot.captured_at.desc())
.first()
)
if not row:
raise HTTPException(404, "no upstream snapshot found; run upstream check first")
snapshot = json.loads(row.snapshot_json or "{}")
groups = snapshot.get("groups") or {}
if not isinstance(groups, dict):
return []
return [item for item in groups.values() if isinstance(item, dict)]
def _source_group_id(group: dict) -> str:
return str(group.get("group_id") or group.get("id") or group.get("name") or "")
def _source_group_name(group: dict, gid: str) -> str:
return str(group.get("group_name") or group.get("name") or gid)
def _source_group_rate(group: dict) -> float:
raw = group.get("rate") or group.get("default_rate") or group.get("rate_multiplier") or 1
try:
return float(raw)
except (TypeError, ValueError):
return 1.0
def _numeric_group_id(value: str | None) -> int | None:
if value is None or value == "":
return None
try:
return int(value)
except ValueError:
return None
@router.get("/api/websites", response_model=List[WebsiteResponse]) @router.get("/api/websites", response_model=List[WebsiteResponse])
def list_websites(db: Session = Depends(get_db), _=Depends(get_current_user)): def list_websites(db: Session = Depends(get_db), _=Depends(get_current_user)):
return [_website_response(row) for row in db.query(Website).order_by(Website.id).all()] return [_website_response(row) for row in db.query(Website).order_by(Website.id).all()]
@@ -215,6 +266,375 @@ def list_website_groups(wid: int, db: Session = Depends(get_db), _=Depends(get_c
raise HTTPException(502, str(exc)) raise HTTPException(502, str(exc))
@router.post("/api/websites/{wid}/groups/import-from-upstream/{upstream_id}", response_model=ImportGroupsResponse)
def import_groups_from_upstream(
wid: int,
upstream_id: int,
body: ImportGroupsRequest,
db: Session = Depends(get_db),
_=Depends(get_current_user),
):
website = db.query(Website).filter(Website.id == wid).first()
if not website:
raise HTTPException(404, "website not found")
if website.site_type != "sub2api":
raise HTTPException(400, "目前只支持 sub2api")
upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
if not upstream:
raise HTTPException(404, "upstream not found")
selected = set(body.group_ids)
groups = _latest_upstream_groups(db, upstream_id)
# 拉取目标网站已有分组,同名则跳过
try:
existing_names = set()
with _client(website) as c:
for eg in c.get_groups(website.groups_endpoint):
gname = eg.get("name") or eg.get("group_name") or ""
if gname:
existing_names.add(gname)
except Exception:
existing_names = set()
items: list[ImportGroupItem] = []
with _client(website) as c:
for group in groups:
source_gid = _source_group_id(group)
if not source_gid or (selected and source_gid not in selected):
continue
source_name = _source_group_name(group, source_gid)
target_name = f"{body.name_prefix}{source_name}" if body.name_prefix else source_name
# 检查是否已存在同名分组
if target_name in existing_names:
items.append(ImportGroupItem(
source_group_id=source_gid,
source_group_name=source_name,
target_group_name=target_name,
status="exists",
message="目标分组已存在,已跳过",
))
continue
create_body = {
"name": target_name,
"description": group.get("description") or f"Imported from {upstream.name} / {source_name}",
"platform": group.get("platform") or "openai",
"rate_multiplier": _source_group_rate(group),
}
if group.get("rpm_limit") is not None:
create_body["rpm_limit"] = group.get("rpm_limit")
try:
created = c.create_group(create_body)
target_id = c.extract_id(created)
items.append(ImportGroupItem(
source_group_id=source_gid,
source_group_name=source_name,
target_group_id=target_id or None,
target_group_name=str(created.get("name") or target_name),
status="created",
message="已创建",
raw=created,
))
except Exception as exc:
msg = str(exc)
# 捕获 409 等已存在错误
if "已存在" in msg or "already exists" in msg.lower() or "409" in msg or "Conflict" in msg:
items.append(ImportGroupItem(
source_group_id=source_gid,
source_group_name=source_name,
target_group_name=target_name,
status="exists",
message="目标分组已存在(接口返回冲突)",
))
else:
logger.exception("import website group failed website=%s upstream=%s group=%s", wid, upstream_id, source_gid)
items.append(ImportGroupItem(
source_group_id=source_gid,
source_group_name=source_name,
target_group_name=target_name,
status="failed",
message=msg,
))
created_count = len([item for item in items if item.status == "created"])
exists_count = len([item for item in items if item.status == "exists"])
failed_count = len([item for item in items if item.status == "failed"])
msg_parts = []
if created_count:
msg_parts.append(f"新建 {created_count}")
if exists_count:
msg_parts.append(f"已存在 {exists_count}")
if failed_count:
msg_parts.append(f"失败 {failed_count}")
return ImportGroupsResponse(
success=failed_count == 0,
message="".join(msg_parts) + f" / 共 {len(items)}" if msg_parts else f"共处理 {len(items)} 个分组",
items=items,
)
def _detect_platform(text: str, fallback: str = "openai") -> str:
"""根据 Key 名或分组名关键词判断平台类型。"""
lower = text.lower()
if "claude" in lower or "anthropic" in lower:
return "anthropic"
if "gemini" in lower:
return "gemini"
if "antigravity" in lower:
return "antigravity"
return fallback
@router.post("/api/websites/{wid}/accounts/sync-imported-upstream-keys", response_model=ImportAccountsResponse)
def sync_imported_upstream_keys(
wid: int,
body: SyncImportStatusRequest,
db: Session = Depends(get_db),
_=Depends(get_current_user),
):
"""校验已导入的上游 Key 在目标 Sub2API 账号管理中是否仍存在。"""
website = db.query(Website).filter(Website.id == wid).first()
if not website:
raise HTTPException(404, "website not found")
rows = (
db.query(UpstreamGeneratedKey)
.filter(
UpstreamGeneratedKey.upstream_id == body.upstream_id,
UpstreamGeneratedKey.imported_website_id == wid,
UpstreamGeneratedKey.imported_account_id.isnot(None),
)
.all()
)
items: list[ImportAccountItem] = []
with _client(website) as c:
for row in rows:
platform = _detect_platform(f"{row.group_name} {row.group_id} {row.key_name}", "openai")
if not row.imported_account_id:
continue
old_account_id = row.imported_account_id
exists = c.account_exists(row.imported_account_id)
if exists is False:
row.imported_website_id = None
row.imported_account_id = None
row.imported_at = None
row.status = "created"
items.append(ImportAccountItem(
upstream_key_id=row.id, source_group_id=row.group_id,
source_group_name=row.group_name, account_id=old_account_id,
platform=platform, status="stale_cleared",
message="目标账号已删除,已清除导入标记",
))
elif exists is True:
items.append(ImportAccountItem(
upstream_key_id=row.id, source_group_id=row.group_id,
source_group_name=row.group_name, account_id=old_account_id,
platform=platform, status="exists", message="目标账号仍存在",
))
else:
items.append(ImportAccountItem(
upstream_key_id=row.id, source_group_id=row.group_id,
source_group_name=row.group_name, account_id=old_account_id,
platform=platform, status="check_failed",
message="无法校验目标账号存在性(目标网站认证/网络问题)",
))
db.commit()
cleared_count = len([i for i in items if i.status == "stale_cleared"])
check_failed_count = len([i for i in items if i.status == "check_failed"])
msg_parts = []
if cleared_count:
msg_parts.append(f"清除 {cleared_count}")
if check_failed_count:
msg_parts.append(f"校验失败 {check_failed_count}")
return ImportAccountsResponse(
success=check_failed_count == 0,
message="".join(msg_parts) + f" / 共 {len(items)}" if msg_parts else f"共校验 {len(items)} 个,无变化",
items=items,
)
@router.post("/api/websites/{wid}/accounts/import-upstream-keys", response_model=ImportAccountsResponse)
def import_upstream_keys_as_accounts(
wid: int,
body: ImportAccountsRequest,
db: Session = Depends(get_db),
_=Depends(get_current_user),
):
website = db.query(Website).filter(Website.id == wid).first()
if not website:
raise HTTPException(404, "website not found")
if website.site_type != "sub2api":
raise HTTPException(400, "目前只支持 sub2api")
if not body.upstream_key_ids:
raise HTTPException(400, "请选择要导入的 Key")
rows = (
db.query(UpstreamGeneratedKey)
.filter(UpstreamGeneratedKey.id.in_(body.upstream_key_ids))
.order_by(UpstreamGeneratedKey.id)
.all()
)
found_ids = {row.id for row in rows}
missing_ids = [kid for kid in body.upstream_key_ids if kid not in found_ids]
items: list[ImportAccountItem] = [
ImportAccountItem(
upstream_key_id=kid,
source_group_id="",
source_group_name="",
platform=body.default_platform,
status="failed",
message="key not found",
)
for kid in missing_ids
]
# 查出来源上游的 Base URL
upstream_base_url = ""
if body.upstream_key_ids:
first_row = rows[0] if rows else None
if first_row:
from app.models.upstream import Upstream as _Up
_u = db.query(_Up).filter(_Up.id == first_row.upstream_id).first()
if _u:
upstream_base_url = _u.base_url
with _client(website) as c:
for row in rows:
# 先确定平台(失败项也需要记录)
if body.platform_mode == "auto":
platform = _detect_platform(
f"{row.group_name} {row.group_id} {row.key_name}",
body.default_platform,
)
else:
platform = body.default_platform
# 幂等校验:已导入过则检查远端账号是否仍存在
if row.imported_website_id == wid and row.imported_account_id:
old_account_id = row.imported_account_id
exists = c.account_exists(row.imported_account_id)
if exists is True:
items.append(ImportAccountItem(
upstream_key_id=row.id,
source_group_id=row.group_id,
source_group_name=row.group_name,
target_group_id=body.target_group_map.get(row.group_id),
account_id=old_account_id,
account_name=f"{body.account_name_prefix}-{row.group_name or row.group_id}-{row.id}",
platform=platform,
upstream_base_url=upstream_base_url,
status="exists",
message="已导入过,已跳过",
))
continue
elif exists is False:
# 远端已删除,清空标记后继续创建
row.imported_website_id = None
row.imported_account_id = None
row.imported_at = None
row.status = "created"
# 继续往下走(不 continue
else:
# 校验失败,保守跳过
items.append(ImportAccountItem(
upstream_key_id=row.id,
source_group_id=row.group_id,
source_group_name=row.group_name,
target_group_id=body.target_group_map.get(row.group_id),
account_id=old_account_id,
platform=platform,
status="check_failed",
message="无法校验目标账号状态,已保守跳过",
))
continue
if not row.key_value:
items.append(ImportAccountItem(
upstream_key_id=row.id,
source_group_id=row.group_id,
source_group_name=row.group_name,
platform=platform,
status="failed",
message="该 Key 无明文值,无法导入(远端已存在 Key 不会保留明文,请重新创建或手动填入)",
))
continue
target_group_id = body.target_group_map.get(row.group_id)
group_ids = []
numeric_target = _numeric_group_id(target_group_id)
if numeric_target is not None:
group_ids.append(numeric_target)
account_name = f"{body.account_name_prefix}-{row.group_name or row.group_id}-{row.id}"
account_body = {
"name": account_name,
"platform": platform,
"type": "apikey",
"credentials": {
"api_key": row.key_value,
"base_url": upstream_base_url,
},
"group_ids": group_ids,
"rate_multiplier": 1,
"concurrency": body.concurrency,
"priority": body.priority,
"notes": f"Imported by SmartUp from upstream key #{row.id}",
}
try:
created = c.create_account(account_body)
account_id = c.extract_id(created)
row.imported_website_id = wid
row.imported_account_id = account_id or None
row.imported_at = datetime.now(timezone.utc)
row.status = "imported"
row.error = None
db.commit()
items.append(ImportAccountItem(
upstream_key_id=row.id,
source_group_id=row.group_id,
source_group_name=row.group_name,
target_group_id=target_group_id,
account_id=account_id or None,
account_name=str(created.get("name") or account_name),
platform=platform,
upstream_base_url=upstream_base_url,
status="created",
message="已创建账号",
raw=created,
))
except Exception as exc:
logger.exception("import upstream key as account failed website=%s key=%s", wid, row.id)
row.status = "import_failed"
row.error = str(exc)
db.commit()
items.append(ImportAccountItem(
upstream_key_id=row.id,
source_group_id=row.group_id,
source_group_name=row.group_name,
target_group_id=target_group_id,
account_name=account_name,
platform=platform,
upstream_base_url=upstream_base_url,
status="failed",
message=str(exc),
))
created_count = len([item for item in items if item.status == "created"])
exists_count = len([item for item in items if item.status == "exists"])
failed_count = len([item for item in items if item.status == "failed"])
check_failed_count = len([item for item in items if item.status == "check_failed"])
msg_parts = []
if created_count:
msg_parts.append(f"新建 {created_count}")
if exists_count:
msg_parts.append(f"已存在 {exists_count}")
if check_failed_count:
msg_parts.append(f"校验失败 {check_failed_count}")
if failed_count:
msg_parts.append(f"失败 {failed_count}")
return ImportAccountsResponse(
success=failed_count == 0 and check_failed_count == 0,
message="".join(msg_parts) + f" / 共 {len(items)}" if msg_parts else f"共处理 {len(items)}",
items=items,
)
@router.get("/api/group-bindings", response_model=List[BindingResponse]) @router.get("/api/group-bindings", response_model=List[BindingResponse])
def list_bindings(db: Session = Depends(get_db), _=Depends(get_current_user)): def list_bindings(db: Session = Depends(get_db), _=Depends(get_current_user)):
rows = db.query(WebsiteGroupBinding).order_by(WebsiteGroupBinding.id.desc()).all() rows = db.query(WebsiteGroupBinding).order_by(WebsiteGroupBinding.id.desc()).all()
+36 -1
View File
@@ -1,6 +1,6 @@
from datetime import datetime from datetime import datetime
from typing import Optional, Any from typing import Optional, Any
from pydantic import BaseModel from pydantic import BaseModel, Field
class AuthConfigBearer(BaseModel): class AuthConfigBearer(BaseModel):
@@ -89,3 +89,38 @@ class TestResult(BaseModel):
success: bool success: bool
message: str message: str
detail: Optional[str] = None detail: Optional[str] = None
class GenerateKeysByGroupsRequest(BaseModel):
group_ids: list[str] = Field(default_factory=list)
name_prefix: str = "SmartUp"
quota: float = Field(default=0, ge=0)
expires_in_days: Optional[int] = Field(default=None, ge=1)
rate_limit_5h: float = Field(default=0, ge=0)
rate_limit_1d: float = Field(default=0, ge=0)
rate_limit_7d: float = Field(default=0, ge=0)
endpoint: str = "/keys"
class GeneratedUpstreamKeyResponse(BaseModel):
id: Optional[int] = None
upstream_id: int
group_id: str
group_name: str = ""
key_id: Optional[str] = None
key_name: str
key_value: Optional[str] = None
masked_key: str = ""
status: str
error: Optional[str] = None
imported_website_id: Optional[int] = None
imported_account_id: Optional[str] = None
imported_at: Optional[datetime] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class GenerateKeysByGroupsResponse(BaseModel):
success: bool
message: str
items: list[GeneratedUpstreamKeyResponse]
+55
View File
@@ -122,3 +122,58 @@ class WebsiteSyncLogResponse(BaseModel):
status: str status: str
message: str message: str
created_at: datetime created_at: datetime
class ImportGroupsRequest(BaseModel):
group_ids: list[str] = Field(default_factory=list)
name_prefix: str = ""
class ImportGroupItem(BaseModel):
source_group_id: str
source_group_name: str
target_group_id: Optional[str] = None
target_group_name: str = ""
status: str
message: str = ""
raw: dict[str, Any] = {}
class ImportGroupsResponse(BaseModel):
success: bool
message: str
items: list[ImportGroupItem]
class SyncImportStatusRequest(BaseModel):
upstream_id: int = Field(default=0)
class ImportAccountsRequest(BaseModel):
upstream_key_ids: list[int] = Field(default_factory=list)
target_group_map: dict[str, str] = Field(default_factory=dict)
account_name_prefix: str = "SmartUp"
default_platform: str = "openai"
platform_mode: str = "auto" # "auto" | "manual"
concurrency: int = Field(default=10, ge=1)
priority: int = Field(default=1, ge=0)
class ImportAccountItem(BaseModel):
upstream_key_id: int
source_group_id: str
source_group_name: str
target_group_id: Optional[str] = None
account_id: Optional[str] = None
account_name: str = ""
platform: str = ""
upstream_base_url: str = ""
status: str
message: str = ""
raw: dict[str, Any] = {}
class ImportAccountsResponse(BaseModel):
success: bool
message: str
items: list[ImportAccountItem]
+75 -3
View File
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from app.database import SessionLocal from app.database import SessionLocal
from app.models.upstream import Upstream from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.snapshot import UpstreamRateSnapshot from app.models.snapshot import UpstreamRateSnapshot
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot
from app.services.snapshot_service import diff_snapshots, prune_snapshots from app.services.snapshot_service import diff_snapshots, prune_snapshots
@@ -72,8 +73,9 @@ def _check_upstream(upstream_id: int) -> None:
balance = raw_balance / divisor balance = raw_balance / divisor
except Exception as exc: except Exception as exc:
logger.warning("upstream %s balance fetch failed: %s", upstream.name, exc) logger.warning("upstream %s balance fetch failed: %s", upstream.name, exc)
upstream.balance = balance if balance is not None:
upstream.balance_updated_at = datetime.now(timezone.utc) if balance is not None else None upstream.balance = balance
upstream.balance_updated_at = datetime.now(timezone.utc)
except Exception as exc: except Exception as exc:
# failure path # failure path
upstream.consecutive_failures = (upstream.consecutive_failures or 0) + 1 upstream.consecutive_failures = (upstream.consecutive_failures or 0) + 1
@@ -129,7 +131,20 @@ def _check_upstream(upstream_id: int) -> None:
finally: finally:
db.close() db.close()
# ── Phase 2: notifications (independent sessions) ────────────── # ── Phase 2: key sync (independent session) ───────────────────
if snapshot:
captured_at = snapshot.get("captured_at")
if isinstance(captured_at, str):
from datetime import datetime as dt
try:
captured_at = dt.fromisoformat(captured_at)
except Exception:
captured_at = datetime.now(timezone.utc)
elif captured_at is None:
captured_at = datetime.now(timezone.utc)
_sync_upstream_keys(upstream_id, snapshot, captured_at)
# ── Phase 3: notifications (independent sessions) ──────────────
if was_unhealthy: if was_unhealthy:
_notify_status(upstream_id, upstream.name, upstream.base_url, "upstream_recovered") _notify_status(upstream_id, upstream.name, upstream.base_url, "upstream_recovered")
@@ -169,6 +184,63 @@ def _notify_rate_changed(
db.close() db.close()
def _sync_upstream_keys(upstream_id: int, snapshot: dict[str, Any], captured_at: datetime) -> None:
"""上游检测成功后同步 SmartUp Key 状态(远端删除/分组删除)。"""
db = SessionLocal()
try:
active_group_ids = set(snapshot.get("groups", {}).keys())
key_rows = (
db.query(UpstreamGeneratedKey)
.filter(
UpstreamGeneratedKey.upstream_id == upstream_id,
UpstreamGeneratedKey.key_name.like("SmartUp-%"),
)
.all()
)
auth_config = json.loads(
db.query(Upstream).filter(Upstream.id == upstream_id).first().auth_config_json or "{}"
)
# 用 UpstreamClient 查询远端活跃 Key ID 集合
remote_key_ids: set[str] | None = None # None=查询失败,set()=查询成功但为空
try:
upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
if upstream:
with UpstreamClient(
base_url=upstream.base_url,
api_prefix=upstream.api_prefix,
auth_type=upstream.auth_type,
auth_config=auth_config,
timeout=float(upstream.timeout_seconds),
) as client:
client.login()
remote_keys = client.list_api_keys(search="SmartUp", status="active")
remote_key_ids = {
str(k["id"]) for k in remote_keys if k.get("id")
}
except Exception as exc:
logger.warning("sync upstream keys list failed for %s: %s", upstream_id, exc)
for row in key_rows:
# 1. 分组已不在当前快照中 → 删除本地记录
if row.group_id not in active_group_ids:
db.delete(row)
logger.info("removed key %s (group %s no longer in snapshot)", row.id, row.group_id)
continue
# 2. 远端查询成功但 key_id 不在列表中 → 删除本地记录
if row.key_id and remote_key_ids is not None and row.key_id not in remote_key_ids:
db.delete(row)
logger.info("removed key %s (key_id %s gone from remote)", row.id, row.key_id)
continue
# 3. 更新同步时间戳(仅当查询成功且 Key 仍在远端时)
if remote_key_ids is not None and row.key_id in remote_key_ids:
row.updated_at = captured_at
db.commit()
except Exception:
logger.exception("key sync failed for upstream %s", upstream_id)
finally:
db.close()
def _sync_website_bindings(upstream_id: int, changes: list[dict[str, Any]]) -> None: def _sync_website_bindings(upstream_id: int, changes: list[dict[str, Any]]) -> None:
db = SessionLocal() db = SessionLocal()
try: try:
+147
View File
@@ -62,6 +62,49 @@ def _find_user_id(value: Any) -> str:
return "" return ""
def mask_secret(value: Any) -> str:
text = str(value or "")
if not text:
return ""
if len(text) <= 8:
return text[:2] + "****" + text[-2:] if len(text) > 4 else "****"
return text[:4] + "**********" + text[-4:]
def _unwrap_data(value: Any) -> Any:
if isinstance(value, dict) and "data" in value and ("code" in value or "message" in value):
return value.get("data")
return value
def _extract_id(value: Any) -> str:
if isinstance(value, dict):
for key in ("id", "key_id", "keyId"):
candidate = value.get(key)
if candidate is not None:
return str(candidate)
for key in ("data", "result", "key", "api_key"):
found = _extract_id(value.get(key))
if found:
return found
return ""
def _extract_key_value(value: Any) -> str:
if isinstance(value, str):
return value
if isinstance(value, dict):
for key in ("key", "api_key", "apiKey", "token", "value"):
candidate = value.get(key)
if isinstance(candidate, str) and candidate:
return candidate
for key in ("data", "result", "api_key", "key"):
found = _extract_key_value(value.get(key))
if found:
return found
return ""
def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]: def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]:
def _normalize(lst: list) -> list[dict[str, Any]]: def _normalize(lst: list) -> list[dict[str, Any]]:
out = [] out = []
@@ -360,3 +403,107 @@ class UpstreamClient:
return float(value) return float(value)
except (ValueError, TypeError): except (ValueError, TypeError):
return None return None
def list_api_keys(
self,
search: str = "",
group_id: str | int | None = None,
status: str = "active",
endpoint: str = "/keys",
) -> list[dict[str, Any]]:
"""查询远端上游 Key 列表,支持按名称搜索、分组筛选、状态筛选。"""
params: dict[str, Any] = {}
if search:
params["search"] = search
if group_id is not None:
params["group_id"] = int(group_id) if str(group_id).isdigit() else group_id
if status:
params["status"] = status
url = self._url(endpoint)
resp = self._client.request(
"GET",
url,
params=params if params else None,
headers=self._headers(),
cookies=self._cookies,
)
resp.raise_for_status()
data = resp.json()
if isinstance(data, list):
return data
if isinstance(data, dict):
# 尝试展开常见的包装结构
for top_key in ("data", "result", "response"):
val = data.get(top_key)
if isinstance(val, list):
return val
if isinstance(val, dict):
for inner_key in ("items", "keys", "list", "records", "data"):
inner = val.get(inner_key)
if isinstance(inner, list):
return inner
# 顶层本身就是 list-like wrapper
for key in ("items", "keys", "list", "records"):
val = data.get(key)
if isinstance(val, list):
return val
raise UpstreamError(f"unexpected keys response type: {type(data).__name__}")
def delete_api_key(self, key_id: str, endpoint: str = "/keys") -> None:
"""删除远端上游上的一个 Key。"""
self._request("DELETE", f"{endpoint}/{key_id}")
def find_smartup_group_key(
self,
group_id: str | int,
expected_name: str,
prefix: str = "SmartUp",
) -> dict[str, Any] | None:
"""查找同一上游分组下是否已存在 SmartUp 前缀的 Key。
匹配规则:key_name 等于 expected_name,且以 prefix 开头。
返回匹配到的第一个 Key,或 None。
"""
gid = int(group_id) if str(group_id).isdigit() else group_id
keys = self.list_api_keys(search=prefix, group_id=gid, status="active")
for k in keys:
name = k.get("name") or k.get("key_name") or ""
if name == expected_name:
return k
# 部分后端返回的 name 可能带空格或 trimming
if name.strip() == expected_name.strip():
return k
return None
def create_api_key(
self,
name: str,
group_id: str | int,
quota: float = 0,
expires_in_days: int | None = None,
rate_limit_5h: float = 0,
rate_limit_1d: float = 0,
rate_limit_7d: float = 0,
endpoint: str = "/keys",
) -> dict[str, Any]:
body: dict[str, Any] = {
"name": name,
"group_id": int(group_id) if str(group_id).isdigit() else group_id,
"quota": quota,
"rate_limit_5h": rate_limit_5h,
"rate_limit_1d": rate_limit_1d,
"rate_limit_7d": rate_limit_7d,
}
if expires_in_days:
body["expires_in_days"] = expires_in_days
resp = self._request("POST", endpoint, body)
data = _unwrap_data(resp)
key_value = _extract_key_value(data)
if not key_value:
raise UpstreamError("key create response did not include key")
return {
"id": _extract_id(data),
"key": key_value,
"masked_key": mask_secret(key_value),
"raw": data if isinstance(data, dict) else {"value": data},
}
+135 -6
View File
@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import logging
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
from typing import Any from typing import Any
from urllib.parse import quote from urllib.parse import quote
@@ -8,11 +9,39 @@ import httpx
from app.utils.number import decimal_string from app.utils.number import decimal_string
logger = logging.getLogger(__name__)
class WebsiteError(RuntimeError): class WebsiteError(RuntimeError):
pass pass
def _friendly_http_error(exc: httpx.HTTPStatusError) -> str:
"""将常见 HTTP 错误转换为中文友好提示,原始信息保留在日志中。"""
status = exc.response.status_code
url = exc.request.url if exc.request else "?"
logger.warning("website_client HTTP %s from %s: %s", status, url, exc)
if status == 401:
return "目标网站认证失败,请检查 Admin API Key / JWT 是否正确"
if status == 403:
return "目标网站权限不足,请检查当前凭证是否有分组管理权限"
if status == 404:
return f"目标网站接口不存在,请检查 API Prefix 和分组接口路径({exc.response.url.path}"
if 500 <= status < 600:
return "目标网站服务异常,请稍后重试"
return f"目标网站返回错误(HTTP {status}"
def _friendly_connection_error(exc: Exception) -> str:
"""将网络/超时异常转换为中文友好提示。"""
logger.warning("website_client connection error: %s", exc)
if isinstance(exc, httpx.TimeoutException):
return "目标网站请求超时,请检查网络连接和 API 地址是否正确"
if isinstance(exc, httpx.ConnectError):
return "无法连接目标网站,请检查 API 地址和网络连通性"
return f"目标网站通信异常:{exc}"
def parse_positive_decimal(value: Any) -> Decimal | None: def parse_positive_decimal(value: Any) -> Decimal | None:
if value is None or value == "": if value is None or value == "":
return None return None
@@ -59,6 +88,19 @@ def _unwrap_data(value: Any) -> Any:
return value return value
def _extract_id(value: Any) -> str:
if isinstance(value, dict):
for key in ("id", "account_id", "accountId", "group_id", "groupId"):
candidate = value.get(key)
if candidate is not None:
return str(candidate)
for key in ("data", "result", "account", "group"):
found = _extract_id(value.get(key))
if found:
return found
return ""
def normalize_groups(value: Any) -> list[dict[str, Any]]: def normalize_groups(value: Any) -> list[dict[str, Any]]:
raw = _unwrap_data(value) raw = _unwrap_data(value)
if isinstance(raw, dict): if isinstance(raw, dict):
@@ -129,24 +171,111 @@ class Sub2ApiWebsiteClient:
return headers return headers
def _request(self, method: str, path: str, body: Any = None) -> Any: def _request(self, method: str, path: str, body: Any = None) -> Any:
resp = self._client.request(method, self._url(path), json=body, headers=self._headers()) try:
resp.raise_for_status() resp = self._client.request(method, self._url(path), json=body, headers=self._headers())
resp.raise_for_status()
except httpx.HTTPStatusError as exc:
raise WebsiteError(_friendly_http_error(exc)) from exc
except httpx.TimeoutException as exc:
raise WebsiteError(_friendly_connection_error(exc)) from exc
except httpx.ConnectError as exc:
raise WebsiteError(_friendly_connection_error(exc)) from exc
if not resp.content: if not resp.content:
return None return None
text = resp.text text = resp.text
if "application/json" not in resp.headers.get("content-type", "") and text.lstrip().startswith("<"): if "application/json" not in resp.headers.get("content-type", "") and text.lstrip().startswith("<"):
raise WebsiteError(f"{method} {path} returned HTML, not JSON") raise WebsiteError(f"{method} {path} 返回了 HTML,请检查接口地址是否正确")
return resp.json() return resp.json()
def get_groups(self, endpoint: str = "/groups") -> list[dict[str, Any]]: def get_groups(self, endpoint: str = "/groups") -> list[dict[str, Any]]:
errors: list[str] = [] """拉取分组列表,尝试 endpoint 和 fallback /groups/all。"""
last_error: Exception | None = None
tried_paths: list[str] = []
for path in [endpoint, "/groups/all"]: for path in [endpoint, "/groups/all"]:
tried_paths.append(path)
try: try:
return normalize_groups(self._request("GET", path)) return normalize_groups(self._request("GET", path))
except WebsiteError as exc:
msg = str(exc)
# 认证/权限类错误:直接抛出,不需要尝试 fallback
if "认证失败" in msg or "权限不足" in msg:
raise
# 404/5xx 等路径相关错误,试试另一个路径
last_error = exc
except Exception as exc: except Exception as exc:
errors.append(f"{path}: {exc}") last_error = exc
raise WebsiteError("; ".join(errors)) logger.info("get_groups fallback %s failed: %s", path, exc)
msg = str(last_error) if last_error else "拉取分组失败"
raise WebsiteError(f"{msg}(尝试接口:{''.join(tried_paths)}")
def update_group_rate(self, endpoint_template: str, group_id: str, rate: Decimal) -> Any: def update_group_rate(self, endpoint_template: str, group_id: str, rate: Decimal) -> Any:
path = endpoint_template.replace("{id}", quote(group_id, safe="")) path = endpoint_template.replace("{id}", quote(group_id, safe=""))
return self._request("PUT", path, {"rate_multiplier": float(rate)}) return self._request("PUT", path, {"rate_multiplier": float(rate)})
def create_group(self, body: dict[str, Any], endpoint: str = "/groups") -> dict[str, Any]:
resp = self._request("POST", endpoint, body)
data = _unwrap_data(resp)
return data if isinstance(data, dict) else {"value": data}
def create_account(self, body: dict[str, Any], endpoint: str = "/accounts") -> dict[str, Any]:
resp = self._request("POST", endpoint, body)
data = _unwrap_data(resp)
return data if isinstance(data, dict) else {"value": data}
@staticmethod
def _unwrap_list(value: dict) -> list | None:
"""递归展开嵌套的列表包装:data.items、data.data、items、accounts 等。"""
if isinstance(value, list):
return value
if not isinstance(value, dict):
return None
# 先看顶层
for key in ("items", "accounts", "records", "list", "data"):
v = value.get(key)
if isinstance(v, list):
return v
# 再看 data.items、data.records、data.list 等嵌套
data_val = value.get("data")
if isinstance(data_val, dict):
for key in ("items", "records", "list", "data", "accounts"):
v = data_val.get(key)
if isinstance(v, list):
return v
return None
def _get_account_ids(self, endpoint: str = "/accounts") -> set[str] | None:
"""拉取远端账号列表。成功返回 ID 集合(可能为空),解析失败返回 None。"""
try:
resp = self._request("GET", endpoint)
except Exception:
logger.warning("account list fetch failed for %s", endpoint, exc_info=True)
return None
items = self._unwrap_list(resp)
if items is None:
logger.warning("account list unexpected format for %s", endpoint)
return None
ids: set[str] = set()
for item in items:
item_id = self.extract_id(item)
if item_id:
ids.add(item_id)
return ids
def account_exists(self, account_id: str, endpoint: str = "/accounts") -> bool | None:
"""检查目标账号是否存在。
优先拉取账号列表判断:
- 列表成功取到 → return account_id in idsTrue=存在,False=已删除)
- 列表取不到(None)→ return None(校验失败,不清本地)
返回 True=存在,False=已删除,None=校验失败。
"""
ids = self._get_account_ids(endpoint)
if ids is None:
logger.warning("account_exists cannot verify %s: list fetch failed", account_id)
return None
return account_id in ids
@staticmethod
def extract_id(value: Any) -> str:
return _extract_id(value)
+383
View File
@@ -0,0 +1,383 @@
import json
import sys
from pathlib import Path
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
sys.path.insert(0, str(Path(__file__).resolve().parent))
from app.database import Base
from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.website import Website
from app.routers import websites as websites_router
from app.schemas.website import ImportAccountsRequest
@pytest.fixture()
def db_session():
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
Base.metadata.create_all(bind=engine)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = TestingSessionLocal()
try:
yield db
finally:
db.close()
Base.metadata.drop_all(bind=engine)
def seed_account_import_rows(db_session):
website = Website(
name="My Sub2API",
site_type="sub2api",
base_url="http://sub2api.local",
api_prefix="/api/v1/admin",
auth_type="api_key",
auth_config_json=json.dumps({"key": "admin-key", "header": "x-api-key"}),
groups_endpoint="/groups",
group_update_endpoint="/groups/{id}",
)
upstream = Upstream(
name="Packy",
base_url="http://packy.local",
api_prefix="/api/v1",
auth_type="login_password",
auth_config_json="{}",
)
db_session.add_all([website, upstream])
db_session.commit()
db_session.refresh(website)
db_session.refresh(upstream)
generated = UpstreamGeneratedKey(
upstream_id=upstream.id,
group_id="vip",
group_name="VIP",
key_id="up-key-id",
key_name="SmartUp-VIP",
key_value="sk-upstream-generated",
masked_key="sk-u...ated",
raw_json="{}",
status="created",
)
db_session.add(generated)
db_session.commit()
db_session.refresh(generated)
return website, generated
def test_import_upstream_key_creates_sub2api_account_management_apikey(monkeypatch, db_session):
website, generated = seed_account_import_rows(db_session)
account_bodies = []
class FakeWebsiteClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def create_account(self, body, endpoint="/accounts"):
account_bodies.append((endpoint, body))
return {"id": 101, "name": body["name"]}
def account_exists(self, account_id):
return True
@staticmethod
def extract_id(data):
return str(data.get("id"))
monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeWebsiteClient)
response = websites_router.import_upstream_keys_as_accounts(
website.id,
ImportAccountsRequest(
upstream_key_ids=[generated.id],
target_group_map={"vip": "7"},
account_name_prefix="SmartUp",
default_platform="anthropic",
),
db_session,
object(),
)
assert "新建 1" in response.message
assert len(account_bodies) == 1
endpoint, body = account_bodies[0]
assert endpoint == "/accounts"
assert body["type"] == "apikey"
assert body["platform"] == "anthropic"
assert body["credentials"]["api_key"] == "sk-upstream-generated"
assert body["credentials"]["base_url"] == "http://packy.local"
assert body["group_ids"] == [7]
assert body["concurrency"] == 10
assert body["priority"] == 1
assert body["credentials"]["base_url"] == "http://packy.local"
db_session.refresh(generated)
assert generated.status == "imported"
assert generated.imported_website_id == website.id
assert generated.imported_account_id == "101"
def test_import_upstream_key_idempotent_skips_already_imported(monkeypatch, db_session):
"""已导入过的 Key 再次调用不会重复创建。"""
website, generated = seed_account_import_rows(db_session)
generated.imported_website_id = website.id
generated.imported_account_id = "101"
db_session.commit()
create_called = False
class FakeWebsiteClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def create_account(self, body, endpoint="/accounts"):
nonlocal create_called
create_called = True
return {"id": 999, "name": body["name"]}
def account_exists(self, account_id):
return True # 模拟远端账号仍存在
@staticmethod
def extract_id(data):
return str(data.get("id"))
monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeWebsiteClient)
response = websites_router.import_upstream_keys_as_accounts(
website.id,
ImportAccountsRequest(
upstream_key_ids=[generated.id],
target_group_map={"vip": "7"},
account_name_prefix="SmartUp",
default_platform="anthropic",
),
db_session,
object(),
)
# create_account 不应被调用
assert not create_called, "create_account should NOT be called for already imported key"
# 返回 exists
assert len(response.items) == 1
assert response.items[0].status == "exists"
assert response.items[0].account_id == "101"
assert response.items[0].message == "已导入过,已跳过"
# success 应为 true(没有 failed
assert response.success
def test_sync_clears_stale_import_mark(monkeypatch, db_session):
"""同步接口:远端账号已删除时清除本地导入标记。"""
from app.routers import websites as websites_router
from app.schemas.website import SyncImportStatusRequest, ImportAccountsRequest
website, generated = seed_account_import_rows(db_session)
generated.imported_website_id = website.id
generated.imported_account_id = "101"
db_session.commit()
class FakeClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
def __enter__(self):
return self
def __exit__(self, *a):
return False
def account_exists(self, account_id):
return False # 远端返回不存在
def _get_account_ids(self, endpoint="/accounts"):
return set()
@staticmethod
def extract_id(data):
return str(data.get("id"))
monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", lambda **kw: FakeClient(**kw))
response = websites_router.sync_imported_upstream_keys(
website.id, SyncImportStatusRequest(upstream_id=generated.upstream_id),
db_session, object(),
)
assert len(response.items) == 1
assert response.items[0].status == "stale_cleared"
assert response.items[0].account_id == "101"
db_session.refresh(generated)
assert generated.imported_account_id is None
def test_sync_preserves_mark_on_check_failed(monkeypatch, db_session):
"""同步接口:校验失败时不清除本地标记。"""
from app.routers import websites as websites_router
from app.schemas.website import SyncImportStatusRequest
website, generated = seed_account_import_rows(db_session)
generated.imported_website_id = website.id
generated.imported_account_id = "101"
db_session.commit()
class FakeClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
def __enter__(self):
return self
def __exit__(self, *a):
return False
def account_exists(self, account_id):
return None # 校验失败
def _get_account_ids(self, endpoint="/accounts"):
return set()
@staticmethod
def extract_id(data):
return str(data.get("id"))
monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", lambda **kw: FakeClient(**kw))
response = websites_router.sync_imported_upstream_keys(
website.id, SyncImportStatusRequest(upstream_id=generated.upstream_id),
db_session, object(),
)
assert len(response.items) == 1
assert response.items[0].status == "check_failed"
db_session.refresh(generated)
assert generated.imported_account_id == "101" # 未被清除
def test_import_rebuilds_when_remote_deleted(monkeypatch, db_session):
"""导入接口:远端账号已删除时自动清标记并重新创建。"""
from app.routers import websites as websites_router
from app.schemas.website import ImportAccountsRequest
website, generated = seed_account_import_rows(db_session)
generated.imported_website_id = website.id
generated.imported_account_id = "101"
db_session.commit()
class FakeClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
def __enter__(self):
return self
def __exit__(self, *a):
return False
def account_exists(self, account_id):
return False
def create_account(self, body, endpoint="/accounts"):
return {"id": 202, "name": body["name"]}
@staticmethod
def extract_id(data):
return str(data.get("id"))
monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", lambda **kw: FakeClient(**kw))
response = websites_router.import_upstream_keys_as_accounts(
website.id,
ImportAccountsRequest(upstream_key_ids=[generated.id], target_group_map={"vip": "7"},
account_name_prefix="SmartUp", default_platform="openai"),
db_session, object(),
)
assert len(response.items) == 1
assert response.items[0].status == "created", f"expected created, got {response.items[0].status}"
db_session.refresh(generated)
assert generated.imported_account_id == "202"
def test_import_skips_on_check_failed(monkeypatch, db_session):
"""导入接口:校验失败时保守跳过,不创建也不清除。"""
from app.routers import websites as websites_router
from app.schemas.website import ImportAccountsRequest
website, generated = seed_account_import_rows(db_session)
generated.imported_website_id = website.id
generated.imported_account_id = "101"
db_session.commit()
class FakeClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
def __enter__(self):
return self
def __exit__(self, *a):
return False
def account_exists(self, account_id):
return None
def create_account(self, body, endpoint="/accounts"):
raise RuntimeError("should not be called")
@staticmethod
def extract_id(data):
return str(data.get("id"))
monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", lambda **kw: FakeClient(**kw))
response = websites_router.import_upstream_keys_as_accounts(
website.id,
ImportAccountsRequest(upstream_key_ids=[generated.id], target_group_map={"vip": "7"},
account_name_prefix="SmartUp", default_platform="openai"),
db_session, object(),
)
assert len(response.items) == 1
assert response.items[0].status == "check_failed"
db_session.refresh(generated)
assert generated.imported_account_id == "101" # 未被清除
assert not response.success
def test_import_upstream_key_with_custom_concurrency_and_priority(monkeypatch, db_session):
website, generated = seed_account_import_rows(db_session)
account_bodies = []
class FakeWebsiteClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def create_account(self, body, endpoint="/accounts"):
account_bodies.append((endpoint, body))
return {"id": 202, "name": body["name"]}
def account_exists(self, account_id):
return True
@staticmethod
def extract_id(data):
return str(data.get("id"))
monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeWebsiteClient)
websites_router.import_upstream_keys_as_accounts(
website.id,
ImportAccountsRequest(
upstream_key_ids=[generated.id],
target_group_map={"vip": "7"},
account_name_prefix="SmartUp",
default_platform="openai",
concurrency=20,
priority=5,
),
db_session,
object(),
)
assert len(account_bodies) == 1
_, body = account_bodies[0]
assert body["concurrency"] == 20
assert body["priority"] == 5
assert body["credentials"]["base_url"] == "http://packy.local"
+469
View File
@@ -0,0 +1,469 @@
"""Tests for upstream key uniquification and sync cleanup."""
import json
from datetime import datetime, timezone
import pytest
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.models.upstream import Upstream
from app.models.website import Website # noqa: F401 — registers table for FK refs
from app.models.upstream_key import UpstreamGeneratedKey
@pytest.fixture()
def db_session():
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
Base.metadata.create_all(bind=engine)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = TestingSessionLocal()
try:
yield db
finally:
db.close()
Base.metadata.drop_all(bind=engine)
def test_duplicate_cleanup_keeps_latest_only():
"""同一 upstream_id + group_id + key_name 的多条记录只保留最新一条。
使用独立 engine + 全 raw SQL,模拟迁移前的数据库状态。
"""
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
with engine.begin() as conn:
conn.execute(text("""
CREATE TABLE upstreams (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR(255) NOT NULL,
base_url VARCHAR(512) NOT NULL,
api_prefix VARCHAR(64) DEFAULT '',
auth_type VARCHAR(32),
auth_config_json TEXT DEFAULT '{}',
groups_endpoint VARCHAR(256),
rate_endpoint VARCHAR(256),
enabled BOOLEAN DEFAULT 1,
check_interval_seconds INTEGER DEFAULT 600,
timeout_seconds INTEGER DEFAULT 30,
last_status VARCHAR(32) DEFAULT 'unknown',
last_checked_at DATETIME,
last_error TEXT,
consecutive_failures INTEGER DEFAULT 0,
balance FLOAT,
balance_updated_at DATETIME,
balance_endpoint VARCHAR(256) DEFAULT '',
balance_response_path VARCHAR(256) DEFAULT '',
balance_divisor FLOAT DEFAULT 1.0,
updated_at DATETIME,
created_at DATETIME
)
"""))
conn.execute(text("""
CREATE TABLE upstream_generated_keys (
id INTEGER PRIMARY KEY AUTOINCREMENT,
upstream_id INTEGER NOT NULL,
group_id VARCHAR(255) NOT NULL,
group_name VARCHAR(255) DEFAULT '',
key_id VARCHAR(255),
key_name VARCHAR(255) NOT NULL,
key_value TEXT NOT NULL,
masked_key VARCHAR(255) DEFAULT '',
raw_json TEXT DEFAULT '{}',
status VARCHAR(32) DEFAULT 'created',
error TEXT,
imported_website_id INTEGER,
imported_account_id VARCHAR(255),
imported_at DATETIME,
created_at DATETIME,
updated_at DATETIME
)
"""))
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = TestingSessionLocal()
try:
now = datetime.now(timezone.utc)
db.execute(text("""
INSERT INTO upstreams (name, base_url, api_prefix, auth_type, auth_config_json, groups_endpoint, rate_endpoint, created_at, updated_at)
VALUES (:n, :b, :p, :a, :j, :g, :r, :ca, :ua)
"""), {"n": "Test", "b": "http://local", "p": "/api/v1", "a": "bearer",
"j": "{}", "g": "/groups", "r": "/rates", "ca": now, "ua": now})
db.commit()
uid = db.execute(text("SELECT id FROM upstreams LIMIT 1")).scalar()
# 插入 3 条重复记录
for kv, ca in [
("old-key", datetime(2025, 1, 1, tzinfo=timezone.utc)),
("middle-key", datetime(2025, 6, 1, tzinfo=timezone.utc)),
("newest-key", datetime(2025, 12, 1, tzinfo=timezone.utc)),
]:
db.execute(text("""
INSERT INTO upstream_generated_keys
(upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at, updated_at)
VALUES (:uid, :gid, :gn, :kn, :kv, :mk, :rj, :st, :ca, :ca)
"""), {
"uid": uid, "gid": "vip", "gn": "VIP",
"kn": "SmartUp-Test-VIP", "kv": kv,
"mk": "", "rj": "{}", "st": "created", "ca": ca,
})
db.commit()
# 清理:同一组合只保留最新一条(id 最大)
db.execute(text("""
DELETE FROM upstream_generated_keys
WHERE id NOT IN (
SELECT MAX(id) FROM upstream_generated_keys
GROUP BY upstream_id, group_id, key_name
)
"""))
db.commit()
remaining = db.execute(text("SELECT key_value FROM upstream_generated_keys")).fetchall()
assert len(remaining) == 1, f"expected 1 after cleanup, got {len(remaining)}"
assert remaining[0][0] == "newest-key"
finally:
db.close()
def test_migration_backfills_managed_prefix_and_deduplicates():
"""迁移逻辑应回填历史 SmartUp 记录的 managed_prefix 并清理重复。
使用独立 engine(不创建唯一约束),模拟迁移前状态。
"""
from sqlalchemy import text as _text
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
with engine.begin() as conn:
conn.execute(_text("""
CREATE TABLE upstreams (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR(255) NOT NULL, base_url VARCHAR(512) NOT NULL,
api_prefix VARCHAR(64) DEFAULT '', auth_type VARCHAR(32),
auth_config_json TEXT DEFAULT '{}', groups_endpoint VARCHAR(256),
rate_endpoint VARCHAR(256), enabled BOOLEAN DEFAULT 1,
check_interval_seconds INTEGER DEFAULT 600, timeout_seconds INTEGER DEFAULT 30,
last_status VARCHAR(32) DEFAULT 'unknown', last_checked_at DATETIME,
last_error TEXT, consecutive_failures INTEGER DEFAULT 0,
balance FLOAT, balance_updated_at DATETIME,
balance_endpoint VARCHAR(256) DEFAULT '', balance_response_path VARCHAR(256) DEFAULT '',
balance_divisor FLOAT DEFAULT 1.0, updated_at DATETIME, created_at DATETIME
)
"""))
conn.execute(_text("""
CREATE TABLE upstream_generated_keys (
id INTEGER PRIMARY KEY AUTOINCREMENT,
upstream_id INTEGER NOT NULL, group_id VARCHAR(255) NOT NULL,
group_name VARCHAR(255) DEFAULT '', key_id VARCHAR(255),
key_name VARCHAR(255) NOT NULL, key_value TEXT NOT NULL,
masked_key VARCHAR(255) DEFAULT '', raw_json TEXT DEFAULT '{}',
managed_prefix VARCHAR(64), status VARCHAR(32) DEFAULT 'created',
error TEXT, imported_website_id INTEGER, imported_account_id VARCHAR(255),
imported_at DATETIME, created_at DATETIME, updated_at DATETIME
)
"""))
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = TestingSessionLocal()
try:
now = datetime.now(timezone.utc)
db.execute(_text("""
INSERT INTO upstreams (name, base_url, api_prefix, auth_type, auth_config_json, groups_endpoint, rate_endpoint, created_at, updated_at)
VALUES (:n, :b, :p, :a, :j, :g, :r, :ca, :ua)
"""), {"n": "Old", "b": "http://local", "p": "/api/v1", "a": "bearer",
"j": "{}", "g": "/groups", "r": "/rates", "ca": now, "ua": now})
db.commit()
uid = db.execute(_text("SELECT id FROM upstreams LIMIT 1")).scalar()
# 插入两条重复记录(无 managed_prefixkey_name 以 SmartUp 开头)
for kv in ("sk-old", "sk-new"):
db.execute(_text("""
INSERT INTO upstream_generated_keys
(upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at, updated_at)
VALUES (:uid, :gid, :gn, :kn, :kv, '', '{}', 'created', :ca, :ca)
"""), {"uid": uid, "gid": "vip", "gn": "VIP",
"kn": "SmartUp-Old-vip", "kv": kv, "ca": now})
db.commit()
# 执行迁移逻辑(与 database.py 中的 SQL 一致)
conn = db.connection()
conn.execute(_text(
"UPDATE upstream_generated_keys SET managed_prefix = 'SmartUp' "
"WHERE managed_prefix IS NULL AND key_name LIKE 'SmartUp-%'"
))
to_delete = conn.execute(_text("""
SELECT id FROM upstream_generated_keys
WHERE managed_prefix IS NOT NULL
AND id NOT IN (
SELECT MAX(id) FROM upstream_generated_keys
WHERE managed_prefix IS NOT NULL
GROUP BY upstream_id, group_id, managed_prefix
)
""")).fetchall()
for (row_id,) in to_delete:
conn.execute(_text("DELETE FROM upstream_generated_keys WHERE id = :id"), {"id": row_id})
db.commit()
remaining = db.execute(_text("SELECT key_value, managed_prefix FROM upstream_generated_keys")).fetchall()
assert len(remaining) == 1, f"expected 1 after migration, got {len(remaining)}"
assert remaining[0][0] == "sk-new" # 保留最新一条
assert remaining[0][1] == "SmartUp" # 已回填
finally:
db.close()
def test_ensure_group_key_reuses_old_record(db_session, monkeypatch):
"""_ensure_group_key 应复用 managed_prefix IS NULL 的旧记录,不新建。"""
from app.routers.upstreams import _ensure_group_key
from app.models.upstream_key import UpstreamGeneratedKey
from app.services.upstream_client import UpstreamClient
from app.schemas.upstream import GenerateKeysByGroupsRequest
upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
auth_type="bearer", auth_config_json="{}",
groups_endpoint="/groups", rate_endpoint="/rates")
db_session.add(upstream)
db_session.commit()
db_session.refresh(upstream)
# 插入一条旧记录(无 managed_prefix
db_session.add(UpstreamGeneratedKey(
upstream_id=upstream.id, group_id="vip", group_name="VIP",
key_name="SmartUp-Test-vip", key_value="sk-old",
managed_prefix=None, key_id="remote-999",
))
db_session.commit()
# 构造 mock client
class MockClient:
def find_smartup_group_key(self, gid, name, prefix):
return None
def create_api_key(self, name, group_id, **kw):
return {"id": "new-remote", "key": "sk-new-value", "masked_key": "sk-****-lue"}
group = {"id": "vip", "name": "VIP", "rate_multiplier": 1}
body = GenerateKeysByGroupsRequest(
group_ids=["vip"], name_prefix="SmartUp",
quota=0, endpoint="/keys",
)
result = _ensure_group_key(db_session, MockClient(), upstream, group, "SmartUp", body)
assert result.status == "created"
rows = db_session.query(UpstreamGeneratedKey).filter(
UpstreamGeneratedKey.upstream_id == upstream.id,
UpstreamGeneratedKey.group_id == "vip",
).all()
assert len(rows) == 1, f"expected 1 record, got {len(rows)}"
assert rows[0].managed_prefix == "SmartUp"
assert rows[0].key_value == "sk-new-value"
def test_sync_removes_remote_key_when_list_empty(db_session, monkeypatch):
"""同步函数在远端返回空列表时应删除本地 key_id 对应的记录。"""
from app.services import scheduler as sched_mod
from app.models.upstream_key import UpstreamGeneratedKey
from app.services.upstream_client import UpstreamClient
upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
auth_type="bearer", auth_config_json="{}",
groups_endpoint="/groups", rate_endpoint="/rates")
db_session.add(upstream)
db_session.commit()
db_session.refresh(upstream)
db_session.add(UpstreamGeneratedKey(
upstream_id=upstream.id, group_id="vip", group_name="VIP",
key_name="SmartUp-Test-vip", key_value="sk-vip",
managed_prefix="SmartUp", key_id="remote-key-id",
))
db_session.commit()
# mock list_api_keys 返回空列表(查询成功但无 Key)
monkeypatch.setattr(UpstreamClient, "list_api_keys", lambda self, **kw: [])
monkeypatch.setattr(UpstreamClient, "login", lambda self: None)
monkeypatch.setattr(UpstreamClient, "close", lambda self: None)
monkeypatch.setattr(UpstreamClient, "__enter__", lambda self: self)
monkeypatch.setattr(UpstreamClient, "__exit__", lambda self, *a: None)
# 让 _sync_upstream_keys 使用 db_session 的 bind 引擎
monkeypatch.setattr(sched_mod, "SessionLocal",
lambda: db_session)
# 阻止 finally 中的 db.close() 影响测试会话
original_close = db_session.close
monkeypatch.setattr(db_session, "close", lambda: None)
snapshot = {
"upstream_id": upstream.id,
"groups": {"vip": {"group_id": "vip", "rate": "1"}},
"captured_at": datetime.now(timezone.utc).isoformat(),
}
captured_at = datetime.now(timezone.utc)
sched_mod._sync_upstream_keys(upstream.id, snapshot, captured_at)
monkeypatch.setattr(db_session, "close", original_close)
remaining = db_session.query(UpstreamGeneratedKey).all()
assert len(remaining) == 0, f"expected 0 after sync with empty remote, got {len(remaining)}"
def test_migration_function_integration(monkeypatch):
"""直接调用 _migrate_upstream_generated_keys() 验证列新增和索引创建。"""
from app.database import _migrate_upstream_generated_keys, engine as real_engine
from sqlalchemy import text as _text
# 使用独立 engine,避免影响真实数据库
test_engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
monkeypatch.setattr("app.database.engine", test_engine)
# 建表(不含 managed_prefix 列,模拟旧版 schema
with test_engine.begin() as conn:
conn.execute(_text("""
CREATE TABLE upstream_generated_keys (
id INTEGER PRIMARY KEY AUTOINCREMENT,
upstream_id INTEGER NOT NULL, group_id VARCHAR(255) NOT NULL,
group_name VARCHAR(255) DEFAULT '', key_id VARCHAR(255),
key_name VARCHAR(255) NOT NULL, key_value TEXT NOT NULL,
masked_key VARCHAR(255) DEFAULT '', raw_json TEXT DEFAULT '{}',
status VARCHAR(32) DEFAULT 'created', error TEXT,
imported_website_id INTEGER, imported_account_id VARCHAR(255),
imported_at DATETIME, created_at DATETIME
)
"""))
conn.execute(_text("""
INSERT INTO upstream_generated_keys
(upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at)
VALUES (1, 'vip', 'VIP', 'SmartUp-Old-vip', 'sk-val', '', '{}', 'created', datetime('now'))
"""))
# 调用迁移函数入口
_migrate_upstream_generated_keys()
# 验证 managed_prefix 列已存在且被填充
inspector = __import__('sqlalchemy', fromlist=['']).inspect(test_engine)
cols = {c["name"] for c in inspector.get_columns("upstream_generated_keys")}
assert "managed_prefix" in cols, "managed_prefix column should exist after migration"
with test_engine.connect() as conn:
row = conn.execute(_text("SELECT managed_prefix, key_value FROM upstream_generated_keys LIMIT 1")).fetchone()
assert row[0] == "SmartUp", f"expected SmartUp, got {row[0]}"
assert row[1] == "sk-val"
# 验证唯一索引已创建
indexes = inspector.get_indexes("upstream_generated_keys")
index_names = {ix["name"] for ix in indexes}
assert "uq_upstream_group_managed" in index_names, "partial unique index should exist"
monkeypatch.undo()
def test_create_twice_only_one_record(db_session):
"""同一上游同一分组连续调用两次 ensure,本地只保留一条记录。"""
from app.models.upstream_key import UpstreamGeneratedKey
upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
auth_type="bearer", auth_config_json="{}",
groups_endpoint="/groups", rate_endpoint="/rates")
db_session.add(upstream)
db_session.commit()
db_session.refresh(upstream)
# 模拟第一次创建
db_session.add(UpstreamGeneratedKey(
upstream_id=upstream.id, group_id="vip", group_name="VIP",
key_name="SmartUp-Test-VIP", key_value="sk-first",
status="created",
))
db_session.commit()
# 模拟第二次调用 upsert(用同一个 key_name 且 status=exists
existing = db_session.query(UpstreamGeneratedKey).filter(
UpstreamGeneratedKey.upstream_id == upstream.id,
UpstreamGeneratedKey.group_id == "vip",
UpstreamGeneratedKey.key_name == "SmartUp-Test-VIP",
).first()
if existing:
existing.status = "exists"
existing.updated_at = datetime.now(timezone.utc)
else:
db_session.add(UpstreamGeneratedKey(
upstream_id=upstream.id, group_id="vip", group_name="VIP",
key_name="SmartUp-Test-VIP", key_value="sk-second",
status="exists",
))
db_session.commit()
rows = db_session.query(UpstreamGeneratedKey).filter(
UpstreamGeneratedKey.upstream_id == upstream.id,
UpstreamGeneratedKey.group_id == "vip",
).all()
assert len(rows) == 1
assert rows[0].status == "exists"
assert rows[0].key_value == "sk-first" # 更新的是原记录,不是新建
def test_sync_removes_gone_group(db_session):
"""分组不在最新快照中时,本地对应 Key 记录应被删除。"""
upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
auth_type="bearer", auth_config_json="{}",
groups_endpoint="/groups", rate_endpoint="/rates")
db_session.add(upstream)
db_session.commit()
db_session.refresh(upstream)
db_session.add_all([
UpstreamGeneratedKey(
upstream_id=upstream.id, group_id="vip", group_name="VIP",
key_name="SmartUp-Test-VIP", key_value="sk-vip",
),
UpstreamGeneratedKey(
upstream_id=upstream.id, group_id="free", group_name="Free",
key_name="SmartUp-Test-Free", key_value="sk-free",
),
])
db_session.commit()
# 快照中只有 vip,没有 free
active_group_ids = {"vip"}
for row in db_session.query(UpstreamGeneratedKey).filter(
UpstreamGeneratedKey.upstream_id == upstream.id).all():
if row.group_id not in active_group_ids:
db_session.delete(row)
db_session.commit()
remaining = db_session.query(UpstreamGeneratedKey).all()
assert len(remaining) == 1
assert remaining[0].group_id == "vip"
def test_sync_removes_deleted_remote_key(db_session):
"""远端 Key 被删除后,本地对应记录应被删除。"""
from app.models.upstream_key import UpstreamGeneratedKey
upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
auth_type="bearer", auth_config_json="{}",
groups_endpoint="/groups", rate_endpoint="/rates")
db_session.add(upstream)
db_session.commit()
db_session.refresh(upstream)
db_session.add(UpstreamGeneratedKey(
upstream_id=upstream.id, group_id="vip", group_name="VIP",
key_name="SmartUp-Test-VIP", key_value="sk-vip",
key_id="remote-123",
))
db_session.commit()
# 模拟远端返回的活跃 key_ids 中没有 remote-123
remote_key_ids = {"remote-456", "remote-789"}
for row in db_session.query(UpstreamGeneratedKey).filter(
UpstreamGeneratedKey.upstream_id == upstream.id).all():
if row.key_id and row.key_id not in remote_key_ids:
db_session.delete(row)
db_session.commit()
remaining = db_session.query(UpstreamGeneratedKey).all()
assert len(remaining) == 0
+165 -1
View File
@@ -1,4 +1,15 @@
from app.services.website_client import normalize_groups import httpx
import pytest
from app.services.website_client import (
WebsiteError,
_friendly_connection_error,
_friendly_http_error,
normalize_groups,
)
# ——— normalize_groups ———
def test_normalize_groups_unwraps_sub2api_paginated_response(): def test_normalize_groups_unwraps_sub2api_paginated_response():
@@ -68,3 +79,156 @@ def test_normalize_groups_keeps_plain_dict_mapping_compatibility():
assert [group["id"] for group in groups] == ["free", "paid"] assert [group["id"] for group in groups] == ["free", "paid"]
assert groups[0]["rate_multiplier"] == "1" assert groups[0]["rate_multiplier"] == "1"
# ——— _get_account_ids / account_exists ———
def test_get_account_ids_flat_list():
from app.services.website_client import Sub2ApiWebsiteClient
ids = Sub2ApiWebsiteClient._unwrap_list([
{"id": 1, "name": "a"}, {"id": 2, "name": "b"},
])
assert ids == [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]
def test_get_account_ids_top_level_items():
from app.services.website_client import Sub2ApiWebsiteClient
ids = Sub2ApiWebsiteClient._unwrap_list({"items": [
{"id": "k1"}, {"id": "k2"},
]})
assert ids == [{"id": "k1"}, {"id": "k2"}]
def test_get_account_ids_nested_data_items():
from app.services.website_client import Sub2ApiWebsiteClient
ids = Sub2ApiWebsiteClient._unwrap_list({"data": {"items": [
{"id": "a1", "name": "Alpha"},
{"id": "a2", "name": "Beta"},
]}})
assert ids == [{"id": "a1", "name": "Alpha"}, {"id": "a2", "name": "Beta"}]
def test_get_account_ids_nested_data_empty():
from app.services.website_client import Sub2ApiWebsiteClient
ids = Sub2ApiWebsiteClient._unwrap_list({"data": {"items": []}})
assert ids == []
def test_get_account_ids_unexpected_format():
from app.services.website_client import Sub2ApiWebsiteClient
ids = Sub2ApiWebsiteClient._unwrap_list({"error": "not found"})
assert ids is None
# ——— 友好错误提示 ———
def _make_response(status_code: int, path: str = "/groups") -> httpx.Response:
"""创建模拟的 httpx.Response 用于错误测试。"""
req = httpx.Request("GET", f"http://target.local/api/v1{path}")
resp = httpx.Response(status_code, request=req)
return resp
def test_friendly_http_401():
resp = _make_response(401)
exc = httpx.HTTPStatusError("401", request=resp.request, response=resp)
msg = _friendly_http_error(exc)
assert "认证失败" in msg
assert "API Key" in msg
assert "http://" not in msg
def test_friendly_http_403():
resp = _make_response(403)
exc = httpx.HTTPStatusError("403", request=resp.request, response=resp)
msg = _friendly_http_error(exc)
assert "权限不足" in msg
def test_friendly_http_404():
resp = _make_response(404, path="/wrong-path")
exc = httpx.HTTPStatusError("404", request=resp.request, response=resp)
msg = _friendly_http_error(exc)
assert "接口不存在" in msg
assert "/wrong-path" in msg
# 不包含完整 URL / MDN 链接
assert "http://" not in msg
assert "MDN" not in msg
def test_friendly_http_500():
resp = _make_response(502)
exc = httpx.HTTPStatusError("502", request=resp.request, response=resp)
msg = _friendly_http_error(exc)
assert "服务异常" in msg
def test_friendly_connect_error():
exc = httpx.ConnectError("Connection refused")
msg = _friendly_connection_error(exc)
assert "无法连接" in msg
def test_friendly_timeout_error():
exc = httpx.TimeoutException("Timed out")
msg = _friendly_connection_error(exc)
assert "请求超时" in msg
# ——— get_groups fallback(通过 mock httpx client 触发真实 _request 错误转换) ———
def _mock_httpx_request(status_code: int, path: str = "/groups"):
"""返回一个 mock 的 httpx.Client.request,直接抛 HTTPStatusError。"""
def request(self, method, url, **kwargs):
resp = _make_response(status_code, path=path)
raise httpx.HTTPStatusError(f"{status_code} {path}", request=resp.request, response=resp)
return request
def test_get_groups_401_returns_friendly_auth_error(monkeypatch):
from app.services.website_client import Sub2ApiWebsiteClient
monkeypatch.setattr(httpx.Client, "request", _mock_httpx_request(401))
client = Sub2ApiWebsiteClient("http://target.local", "api/v1", "api_key", {"key": "bad"})
with pytest.raises(WebsiteError) as excinfo:
client.get_groups("/groups")
msg = str(excinfo.value)
assert "认证失败" in msg
assert "http://" not in msg
assert "MDN" not in msg
def test_get_groups_404_fallback_succeeds(monkeypatch):
from app.services.website_client import Sub2ApiWebsiteClient
call_count = 0
def request_fallback(self, method, url, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
resp = _make_response(404)
raise httpx.HTTPStatusError("404", request=resp.request, response=resp)
req = httpx.Request("GET", url)
return httpx.Response(200, json={"data": [{"id": "default", "name": "Default", "rate_multiplier": "1"}]}, request=req)
monkeypatch.setattr(httpx.Client, "request", request_fallback)
client = Sub2ApiWebsiteClient("http://target.local", "api/v1", "api_key", {"key": "ok"})
groups = client.get_groups("/groups")
assert len(groups) == 1
assert groups[0]["id"] == "default"
assert call_count == 2
def test_get_groups_all_404_no_raw_url(monkeypatch):
from app.services.website_client import Sub2ApiWebsiteClient
monkeypatch.setattr(httpx.Client, "request", _mock_httpx_request(404))
client = Sub2ApiWebsiteClient("http://target.local", "api/v1", "api_key", {"key": "ok"})
with pytest.raises(WebsiteError) as excinfo:
client.get_groups("/groups")
msg = str(excinfo.value)
assert "接口不存在" in msg
assert "http://" not in msg
assert "MDN" not in msg
+96 -4
View File
@@ -1,7 +1,28 @@
import axios from 'axios' import axios from 'axios'
import axiosRetry from 'axios-retry' import axiosRetry from 'axios-retry'
import router from '@/router' import router from '@/router'
import { authStorageKeys } from '@/authStorage'
/** 标记是否正在处理 401,防多个并发 */
let isHandlingUnauthorized = false
/** 统一的 401 处理:清登录态 + 提示 + 跳转 /login */
async function handleUnauthorized() {
if (isHandlingUnauthorized) return
isHandlingUnauthorized = true
try {
const { useAuthStore } = await import('@/stores/auth')
useAuthStore().clear()
const { ElMessage } = await import('element-plus')
ElMessage.warning('登录已过期,请重新登录')
if (router.currentRoute.value.path !== '/login') {
await router.replace('/login')
}
} finally {
isHandlingUnauthorized = false
}
}
export const api = axios.create({ export const api = axios.create({
baseURL: '/', baseURL: '/',
@@ -27,10 +48,13 @@ axiosRetry(api, {
api.interceptors.response.use( api.interceptors.response.use(
(r) => r, (r) => r,
(err) => { (err) => {
// 跳过登录接口的 401(密码错误等正常登录失败场景)
const requestPath = new URL(err.config?.url || '', window.location.origin).pathname
if (err.response?.status === 401 && requestPath === '/api/auth/login') {
return Promise.reject(err)
}
if (err.response?.status === 401) { if (err.response?.status === 401) {
localStorage.removeItem(authStorageKeys.token) void handleUnauthorized()
localStorage.removeItem(authStorageKeys.email)
router.push('/login')
} }
return Promise.reject(err) return Promise.reject(err)
} }
@@ -84,6 +108,34 @@ export interface UpstreamForm {
balance_divisor: number balance_divisor: number
} }
export interface GeneratedUpstreamKey {
id: number | null
upstream_id: number
group_id: string
group_name: string
key_id: string | null
key_name: string
key_value: string | null
masked_key: string
status: string
error: string | null
imported_website_id: number | null
imported_account_id: string | null
imported_at: string | null
created_at: string | null
}
export interface GenerateKeysByGroupsForm {
group_ids: string[]
name_prefix: string
quota: number
expires_in_days?: number | null
rate_limit_5h: number
rate_limit_1d: number
rate_limit_7d: number
endpoint: string
}
export const upstreamsApi = { export const upstreamsApi = {
list: () => api.get<UpstreamData[]>('/api/upstreams'), list: () => api.get<UpstreamData[]>('/api/upstreams'),
create: (data: UpstreamForm) => api.post<UpstreamData>('/api/upstreams', data), create: (data: UpstreamForm) => api.post<UpstreamData>('/api/upstreams', data),
@@ -91,6 +143,9 @@ export const upstreamsApi = {
delete: (id: number) => api.delete(`/api/upstreams/${id}`), delete: (id: number) => api.delete(`/api/upstreams/${id}`),
test: (id: number) => api.post<{ success: boolean; message: string; detail?: string }>(`/api/upstreams/${id}/test`), test: (id: number) => api.post<{ success: boolean; message: string; detail?: string }>(`/api/upstreams/${id}/test`),
checkNow: (id: number) => api.post<{ success: boolean; message: string }>(`/api/upstreams/${id}/check-now`), checkNow: (id: number) => api.post<{ success: boolean; message: string }>(`/api/upstreams/${id}/check-now`),
generatedKeys: (id: number) => api.get<GeneratedUpstreamKey[]>(`/api/upstreams/${id}/generated-keys`),
generateKeysByGroups: (id: number, data: GenerateKeysByGroupsForm) =>
api.post<{ success: boolean; message: string; items: GeneratedUpstreamKey[] }>(`/api/upstreams/${id}/keys/generate-by-groups`, data),
latestSnapshot: (id: number) => api.get(`/api/upstreams/${id}/snapshots/latest`), latestSnapshot: (id: number) => api.get(`/api/upstreams/${id}/snapshots/latest`),
listSnapshots: (id: number, limit = 20, offset = 0) => listSnapshots: (id: number, limit = 20, offset = 0) =>
api.get<any[]>(`/api/upstreams/${id}/snapshots`, { params: { limit, offset } }), api.get<any[]>(`/api/upstreams/${id}/snapshots`, { params: { limit, offset } }),
@@ -185,6 +240,30 @@ export interface WebsiteSyncLog {
created_at: string created_at: string
} }
export interface ImportGroupItem {
source_group_id: string
source_group_name: string
target_group_id: string | null
target_group_name: string
status: string
message: string
raw: Record<string, any>
}
export interface ImportAccountItem {
upstream_key_id: number
source_group_id: string
source_group_name: string
target_group_id: string | null
account_id: string | null
account_name: string
platform: string
upstream_base_url: string
status: string
message: string
raw: Record<string, any>
}
export const websitesApi = { export const websitesApi = {
list: () => api.get<WebsiteData[]>('/api/websites'), list: () => api.get<WebsiteData[]>('/api/websites'),
create: (data: WebsiteForm) => api.post<WebsiteData>('/api/websites', data), create: (data: WebsiteForm) => api.post<WebsiteData>('/api/websites', data),
@@ -192,6 +271,19 @@ export const websitesApi = {
delete: (id: number) => api.delete(`/api/websites/${id}`), delete: (id: number) => api.delete(`/api/websites/${id}`),
test: (id: number) => api.post<{ success: boolean; message: string; detail?: string }>(`/api/websites/${id}/test`), test: (id: number) => api.post<{ success: boolean; message: string; detail?: string }>(`/api/websites/${id}/test`),
groups: (id: number) => api.get<WebsiteGroup[]>(`/api/websites/${id}/groups`), groups: (id: number) => api.get<WebsiteGroup[]>(`/api/websites/${id}/groups`),
importGroupsFromUpstream: (id: number, upstreamId: number, data: { group_ids: string[]; name_prefix: string }) =>
api.post<{ success: boolean; message: string; items: ImportGroupItem[] }>(`/api/websites/${id}/groups/import-from-upstream/${upstreamId}`, data),
syncImportedUpstreamKeys: (id: number, data: { upstream_id: number }) =>
api.post<{ success: boolean; message: string; items: ImportAccountItem[] }>(`/api/websites/${id}/accounts/sync-imported-upstream-keys`, data),
importAccountsFromUpstreamKeys: (id: number, data: {
upstream_key_ids: number[]
target_group_map: Record<string, string>
account_name_prefix: string
default_platform: string
platform_mode?: string
concurrency?: number
priority?: number
}) => api.post<{ success: boolean; message: string; items: ImportAccountItem[] }>(`/api/websites/${id}/accounts/import-upstream-keys`, data),
listBindings: () => api.get<GroupBindingData[]>('/api/group-bindings'), listBindings: () => api.get<GroupBindingData[]>('/api/group-bindings'),
createBinding: (data: GroupBindingForm) => api.post<GroupBindingData>('/api/group-bindings', data), createBinding: (data: GroupBindingForm) => api.post<GroupBindingData>('/api/group-bindings', data),
updateBinding: (id: number, data: Partial<GroupBindingForm>) => api.put<GroupBindingData>(`/api/group-bindings/${id}`, data), updateBinding: (id: number, data: Partial<GroupBindingForm>) => api.put<GroupBindingData>(`/api/group-bindings/${id}`, data),
+154 -2
View File
@@ -127,6 +127,9 @@
</el-button> </el-button>
<el-button size="small" text @click="testUpstream(row)" :loading="row._testing">测试</el-button> <el-button size="small" text @click="testUpstream(row)" :loading="row._testing">测试</el-button>
<el-button size="small" text @click="checkNow(row)" :loading="row._checking">检测</el-button> <el-button size="small" text @click="checkNow(row)" :loading="row._checking">检测</el-button>
<el-button size="small" text @click="openKeyGenerate(row)" title="确保每个分组有一个 SmartUp Key">
<el-icon><Key /></el-icon>
</el-button>
<el-button size="small" text @click="openDetail(row)"> <el-button size="small" text @click="openDetail(row)">
<el-icon><List /></el-icon> <el-icon><List /></el-icon>
详情 详情
@@ -374,6 +377,26 @@
<span>{{ detailUpstream.last_error }}</span> <span>{{ detailUpstream.last_error }}</span>
</div> </div>
<div class="section-title">
<el-icon><Key /></el-icon>
已创建 Key
<span class="section-sub">最近 {{ generatedKeys.length }} </span>
</div>
<el-table :data="generatedKeys" v-loading="keysLoading" size="small" style="width: 100%" class="generated-key-table">
<el-table-column prop="group_name" label="分组" min-width="120" />
<el-table-column prop="key_name" label="名称" min-width="180" />
<el-table-column label="Key" min-width="140">
<template #default="{ row }"><span class="mono">{{ row.masked_key || '—' }}</span></template>
</el-table-column>
<el-table-column label="状态" width="96">
<template #default="{ row }">
<el-tag size="small" :type="row.status === 'import_failed' || row.status === 'failed' ? 'danger' : row.status === 'imported' ? 'success' : 'info'">
{{ keyStatusLabel(row.status) }}
</el-tag>
</template>
</el-table-column>
</el-table>
<div class="section-title"> <div class="section-title">
<el-icon><Clock /></el-icon> <el-icon><Clock /></el-icon>
检测历史 检测历史
@@ -440,6 +463,57 @@
</div> </div>
</el-drawer> </el-drawer>
<el-dialog v-model="keyDialogVisible" title="按分组创建 Key" width="620px" destroy-on-close>
<el-form label-position="top">
<el-form-item label="上游">
<el-input :model-value="keyTarget?.name || ''" disabled />
</el-form-item>
<el-form-item label="选择分组">
<el-select v-model="keyForm.group_ids" multiple filterable style="width:100%" placeholder="不选则创建全部分组">
<el-option v-for="group in keyGroupOptions" :key="group.group_id" :label="`${group.group_name || group.group_id} (${group.rate || '—'})`" :value="group.group_id" />
</el-select>
</el-form-item>
<el-row :gutter="12">
<el-col :span="12">
<el-form-item label="名称前缀"><el-input v-model="keyForm.name_prefix" /></el-form-item>
</el-col>
<el-col :span="12">
<el-form-item label="创建接口"><el-input v-model="keyForm.endpoint" /></el-form-item>
</el-col>
</el-row>
<el-row :gutter="12">
<el-col :span="12">
<el-form-item label="配额 USD0 不限)"><el-input-number v-model="keyForm.quota" :min="0" :precision="2" style="width:100%" /></el-form-item>
</el-col>
<el-col :span="12">
<el-form-item label="有效天数"><el-input-number v-model="keyExpiresDays" :min="1" :disabled="!useKeyExpiry" style="width:100%" /></el-form-item>
</el-col>
</el-row>
<el-checkbox v-model="useKeyExpiry">设置过期时间</el-checkbox>
</el-form>
<div v-if="keyResults.length" class="result-panel">
<div class="result-title">操作结果</div>
<el-table :data="keyResults" size="small">
<el-table-column prop="group_name" label="分组" min-width="120" />
<el-table-column prop="key_name" label="名称" min-width="180" />
<el-table-column label="Key" min-width="160">
<template #default="{ row }"><span class="mono">{{ row.key_value || row.masked_key || '—' }}</span></template>
</el-table-column>
<el-table-column label="状态" width="120">
<template #default="{ row }">
<el-tag v-if="row.status === 'created'" size="small" type="success">新创建</el-tag>
<el-tag v-else-if="row.status === 'exists'" size="small" type="info">已存在</el-tag>
<el-tag v-else size="small" type="danger">失败</el-tag>
</template>
</el-table-column>
</el-table>
</div>
<template #footer>
<el-button @click="keyDialogVisible = false">关闭</el-button>
<el-button type="primary" :loading="generatingKeys" :disabled="generatingKeys" @click="generateKeys">确保 Key 存在</el-button>
</template>
</el-dialog>
<AuthCaptureDialog <AuthCaptureDialog
v-model="authCaptureVisible" v-model="authCaptureVisible"
:initial-url="authCaptureInitialUrl" :initial-url="authCaptureInitialUrl"
@@ -453,8 +527,8 @@ import { ref, computed, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus' import { ElMessage, ElMessageBox } from 'element-plus'
import type { FormInstance } from 'element-plus' import type { FormInstance } from 'element-plus'
import dayjs from 'dayjs' import dayjs from 'dayjs'
import { Refresh, Plus, Edit, List, Delete, Warning, Clock, ArrowRight, Pointer } from '@element-plus/icons-vue' import { Refresh, Plus, Edit, List, Delete, Warning, Clock, ArrowRight, Pointer, Key } from '@element-plus/icons-vue'
import { upstreamsApi, type UpstreamData } from '@/api' import { upstreamsApi, type GeneratedUpstreamKey, type UpstreamData } from '@/api'
import AuthCaptureDialog from '@/components/AuthCaptureDialog.vue' import AuthCaptureDialog from '@/components/AuthCaptureDialog.vue'
const list = ref<(UpstreamData & { _testing?: boolean; _checking?: boolean })[]>([]) const list = ref<(UpstreamData & { _testing?: boolean; _checking?: boolean })[]>([])
@@ -551,6 +625,7 @@ function handlePlatformChange(val: string) {
form.value.auth_config.username_field = 'email' form.value.auth_config.username_field = 'email'
form.value.balance_endpoint = '/auth/me' form.value.balance_endpoint = '/auth/me'
form.value.balance_response_path = 'data.balance' form.value.balance_response_path = 'data.balance'
form.value.balance_divisor = 1.0
} else if (val === 'new-api') { } else if (val === 'new-api') {
form.value.api_prefix = '' form.value.api_prefix = ''
form.value.groups_endpoint = '/api/group/' form.value.groups_endpoint = '/api/group/'
@@ -572,17 +647,37 @@ function handlePlatformChange(val: string) {
} else { } else {
form.value.balance_endpoint = '' form.value.balance_endpoint = ''
form.value.balance_response_path = '' form.value.balance_response_path = ''
form.value.balance_divisor = 1.0
} }
} }
const detailVisible = ref(false) const detailVisible = ref(false)
const detailUpstream = ref<UpstreamData | null>(null) const detailUpstream = ref<UpstreamData | null>(null)
const snapshots = ref<any[]>([]) const snapshots = ref<any[]>([])
const generatedKeys = ref<GeneratedUpstreamKey[]>([])
const snapshotLoading = ref(false) const snapshotLoading = ref(false)
const keysLoading = ref(false)
const expandedId = ref<number | null>(null) const expandedId = ref<number | null>(null)
const snapshotOffset = ref(0) const snapshotOffset = ref(0)
const snapshotLimit = 20 const snapshotLimit = 20
const keyDialogVisible = ref(false)
const keyTarget = ref<UpstreamData | null>(null)
const keyGroupOptions = ref<any[]>([])
const generatingKeys = ref(false)
const keyResults = ref<GeneratedUpstreamKey[]>([])
const useKeyExpiry = ref(false)
const keyExpiresDays = ref(30)
const keyForm = ref({
group_ids: [] as string[],
name_prefix: 'SmartUp',
quota: 0,
rate_limit_5h: 0,
rate_limit_1d: 0,
rate_limit_7d: 0,
endpoint: '/keys',
})
const metrics = computed(() => ({ const metrics = computed(() => ({
total: list.value.length, total: list.value.length,
healthy: list.value.filter((item) => item.last_status === 'healthy').length, healthy: list.value.filter((item) => item.last_status === 'healthy').length,
@@ -639,6 +734,8 @@ function shrinkError(value: string) {
return value.length > 40 ? `${value.slice(0, 40)}` : value return value.length > 40 ? `${value.slice(0, 40)}` : value
} }
const keyStatusLabel = (s: string) => ({ created: '已创建', imported: '已导入', import_failed: '导入失败', failed: '失败' }[s] || s)
async function loadList() { async function loadList() {
tableLoading.value = true tableLoading.value = true
try { try {
@@ -733,13 +830,26 @@ async function checkNow(row: any) {
function openDetail(row: UpstreamData) { function openDetail(row: UpstreamData) {
detailUpstream.value = row detailUpstream.value = row
snapshots.value = [] snapshots.value = []
generatedKeys.value = []
snapshotOffset.value = 0 snapshotOffset.value = 0
expandedId.value = null expandedId.value = null
detailVisible.value = true detailVisible.value = true
} }
async function loadGeneratedKeys() {
if (!detailUpstream.value) return
keysLoading.value = true
try {
const res = await upstreamsApi.generatedKeys(detailUpstream.value.id)
generatedKeys.value = res.data
} finally {
keysLoading.value = false
}
}
async function loadSnapshots() { async function loadSnapshots() {
if (!detailUpstream.value) return if (!detailUpstream.value) return
loadGeneratedKeys()
snapshotLoading.value = true snapshotLoading.value = true
try { try {
const res = await upstreamsApi.listSnapshots(detailUpstream.value.id, snapshotLimit, snapshotOffset.value) const res = await upstreamsApi.listSnapshots(detailUpstream.value.id, snapshotLimit, snapshotOffset.value)
@@ -754,6 +864,48 @@ async function loadSnapshots() {
} }
} }
async function openKeyGenerate(row: UpstreamData) {
keyTarget.value = row
keyResults.value = []
keyForm.value = {
group_ids: [],
name_prefix: 'SmartUp',
quota: 0,
rate_limit_5h: 0,
rate_limit_1d: 0,
rate_limit_7d: 0,
endpoint: '/keys',
}
useKeyExpiry.value = false
keyExpiresDays.value = 30
try {
const res = await upstreamsApi.latestSnapshot(row.id)
keyGroupOptions.value = Object.values(res.data.snapshot?.groups || {})
} catch {
keyGroupOptions.value = []
ElMessage.warning('未找到快照,将由后端实时拉取分组')
}
keyDialogVisible.value = true
}
async function generateKeys() {
if (!keyTarget.value) return
generatingKeys.value = true
try {
const res = await upstreamsApi.generateKeysByGroups(keyTarget.value.id, {
...keyForm.value,
expires_in_days: useKeyExpiry.value ? keyExpiresDays.value : null,
})
keyResults.value = res.data.items
ElMessage[res.data.success ? 'success' : 'warning'](res.data.message)
if (detailUpstream.value?.id === keyTarget.value.id) await loadGeneratedKeys()
} catch (e: any) {
ElMessage.error(e.response?.data?.detail || '创建 Key 失败')
} finally {
generatingKeys.value = false
}
}
function toggleExpand(snap: any) { function toggleExpand(snap: any) {
expandedId.value = expandedId.value === snap.id ? null : snap.id expandedId.value = expandedId.value === snap.id ? null : snap.id
} }
+615 -23
View File
@@ -14,7 +14,11 @@
<div class="panel"> <div class="panel">
<div class="panel-head"> <div class="panel-head">
<div class="panel-title">网站</div> <div class="panel-title">网站</div>
<el-button size="small" text @click="loadAll">刷新</el-button> <div class="panel-actions">
<el-button size="small" text :disabled="websites.length === 0" @click="openImportGroups(selectedWebsite || websites[0])">导入上游分组</el-button>
<el-button size="small" text :disabled="websites.length === 0" @click="openImportAccounts(selectedWebsite || websites[0])">导入为账号管理账号</el-button>
<el-button size="small" text @click="loadAll">刷新</el-button>
</div>
</div> </div>
<el-table :data="websites" v-loading="websiteLoading" row-key="id" style="width:100%"> <el-table :data="websites" v-loading="websiteLoading" row-key="id" style="width:100%">
<el-table-column label="名称" min-width="180"> <el-table-column label="名称" min-width="180">
@@ -44,34 +48,43 @@
<span v-else class="muted"></span> <span v-else class="muted"></span>
</template> </template>
</el-table-column> </el-table-column>
<el-table-column label="操作" width="174" align="right"> <el-table-column label="操作" width="240" align="right">
<template #default="{ row }"> <template #default="{ row }">
<div class="action-row"> <div class="action-row">
<el-tooltip content="编辑网站配置" placement="top" :show-after="300">
<el-button size="small" text class="btn-edit" @click="openWebsiteEdit(row)">
<el-icon class="btn-edit-icon"><Edit /></el-icon><span>编辑</span>
</el-button>
</el-tooltip>
<el-tooltip content="查看分组" placement="top" :show-after="300"> <el-tooltip content="查看分组" placement="top" :show-after="300">
<el-button size="small" circle text @click="selectWebsite(row)"> <el-button size="small" circle text @click="selectWebsite(row)">
<el-icon><Grid /></el-icon> <el-icon><Grid /></el-icon>
</el-button> </el-button>
</el-tooltip> </el-tooltip>
<el-tooltip content="编辑" placement="top" :show-after="300"> <el-dropdown trigger="click" @command="(cmd: string) => handleMoreAction(cmd, row)">
<el-button size="small" circle text @click="openWebsiteEdit(row)"> <el-button size="small" text class="btn-more" :loading="row._testing">
<el-icon><Edit /></el-icon> 更多<el-icon v-if="!row._testing" class="el-icon--right"><ArrowDown /></el-icon>
</el-button> </el-button>
</el-tooltip> <template #dropdown>
<el-tooltip content="连接测试" placement="top" :show-after="300"> <el-dropdown-menu>
<el-button size="small" circle text :loading="row._testing" @click="testWebsite(row)"> <el-dropdown-item command="test" :disabled="row._testing">
<el-icon v-if="!row._testing"><Connection /></el-icon> <el-icon><Connection /></el-icon>连接测试
</el-button> </el-dropdown-item>
</el-tooltip> <el-dropdown-item command="binding">
<el-tooltip content="新增绑定" placement="top" :show-after="300"> <el-icon><Link /></el-icon>新增绑定
<el-button size="small" circle text @click="openBindingCreate(row)"> </el-dropdown-item>
<el-icon><Link /></el-icon> <el-dropdown-item command="importGroups">
</el-button> <el-icon><Upload /></el-icon>导入上游分组
</el-tooltip> </el-dropdown-item>
<el-tooltip content="删除" placement="top" :show-after="300"> <el-dropdown-item command="importAccounts">
<el-button size="small" circle text type="danger" @click="deleteWebsite(row)"> <el-icon><Key /></el-icon>导入为账号管理账号
<el-icon><Delete /></el-icon> </el-dropdown-item>
</el-button> <el-dropdown-item divided command="delete" class="btn-more-delete">
</el-tooltip> <el-icon><Delete /></el-icon>删除
</el-dropdown-item>
</el-dropdown-menu>
</template>
</el-dropdown>
</div> </div>
</template> </template>
</el-table-column> </el-table-column>
@@ -229,6 +242,188 @@
<el-button type="primary" :loading="savingBinding" @click="saveBinding">保存</el-button> <el-button type="primary" :loading="savingBinding" @click="saveBinding">保存</el-button>
</template> </template>
</el-drawer> </el-drawer>
<el-dialog v-model="importGroupsDialog" title="导入上游分组" width="680px" destroy-on-close>
<el-form label-position="top">
<el-form-item label="目标网站">
<el-select v-model="importGroupsForm.website_id" style="width:100%">
<el-option v-for="site in websites" :key="site.id" :label="site.name" :value="site.id" />
</el-select>
</el-form-item>
<el-form-item label="来源上游">
<el-select v-model="importGroupsForm.upstream_id" filterable style="width:100%" @change="importGroupsForm.group_ids = []">
<el-option v-for="upstream in upstreams" :key="upstream.id" :label="upstream.name" :value="upstream.id" />
</el-select>
</el-form-item>
<el-form-item label="上游分组">
<el-select v-model="importGroupsForm.group_ids" multiple filterable style="width:100%" placeholder="不选则导入全部分组">
<el-option
v-for="group in importSourceGroups"
:key="sourceGroupId(group)"
:label="`${sourceGroupName(group)} (${group.rate || group.rate_multiplier || '—'})`"
:value="sourceGroupId(group)"
/>
</el-select>
</el-form-item>
<el-form-item label="分组名前缀">
<el-input v-model="importGroupsForm.name_prefix" placeholder="可留空,参数和倍率保持上游一致" />
</el-form-item>
</el-form>
<div v-if="importGroupResults.length" class="result-panel">
<div class="result-title">导入结果</div>
<el-table :data="importGroupResults" size="small">
<el-table-column prop="source_group_name" label="上游分组" min-width="140" />
<el-table-column prop="target_group_name" label="我的分组" min-width="160" />
<el-table-column prop="target_group_id" label="目标 ID" width="100" />
<el-table-column label="状态" width="90">
<template #default="{ row }">
<el-tag v-if="row.status === 'created'" size="small" type="success">已创建</el-tag>
<el-tag v-else-if="row.status === 'exists'" size="small" type="info">已存在</el-tag>
<el-tag v-else size="small" type="danger">失败</el-tag>
</template>
</el-table-column>
<el-table-column prop="message" label="结果" min-width="160" />
</el-table>
</div>
<template #footer>
<el-button @click="importGroupsDialog = false">关闭</el-button>
<el-button type="primary" :loading="importingGroups" @click="submitImportGroups">导入分组</el-button>
</template>
</el-dialog>
<el-dialog v-model="importAccountsDialog" title="导入为账号管理账号" width="760px" destroy-on-close>
<el-form label-position="top">
<el-alert
class="dialog-note"
type="info"
show-icon
:closable="false"
title="这里会把已生成的上游 Key 创建成 Sub2API 账号管理里的 apikey 账号,不会创建系统登录用户。"
/>
<el-row :gutter="12">
<el-col :span="12">
<el-form-item label="目标网站">
<el-select v-model="importAccountsForm.website_id" style="width:100%" @change="onImportAccountWebsiteChange">
<el-option v-for="site in websites" :key="site.id" :label="site.name" :value="site.id" />
</el-select>
</el-form-item>
</el-col>
<el-col :span="12">
<el-form-item label="账号平台">
<el-select v-model="importAccountsForm.platform_mode" style="width:100%" @change="onPlatformModeChange">
<el-option label="自动识别(按 Key/分组名判断)" value="auto" />
<el-option label="手动选择" value="manual" />
</el-select>
</el-form-item>
</el-col>
<el-col :span="12">
<el-form-item label="固定平台(手动模式)">
<el-select v-model="importAccountsForm.default_platform" style="width:100%" :disabled="importAccountsForm.platform_mode === 'auto'">
<el-option label="OpenAI 兼容" value="openai" />
<el-option label="Anthropic" value="anthropic" />
<el-option label="Gemini" value="gemini" />
<el-option label="Antigravity" value="antigravity" />
</el-select>
</el-form-item>
</el-col>
<el-col :span="24" style="margin-bottom:6px">
<el-button size="small" text :loading="syncingImportStatus" @click="syncImportStatus">
<el-icon><Refresh /></el-icon>刷新导入状态
</el-button>
<span v-if="importSyncStatus" style="font-size:12px;color:var(--text-muted);margin-left:8px">
已校验 {{ importSyncStatus.total }} 清除 {{ importSyncStatus.cleared }} 个失效标记
</span>
</el-col>
</el-row>
<el-row :gutter="12">
<el-col :span="12">
<el-form-item label="来源上游">
<el-select v-model="importAccountsForm.upstream_id" filterable style="width:100%" @change="onImportAccountUpstreamChange">
<el-option v-for="upstream in upstreams" :key="upstream.id" :label="upstream.name" :value="upstream.id" />
</el-select>
</el-form-item>
</el-col>
<el-col :span="12">
<el-form-item label="账号名前缀">
<el-input v-model="importAccountsForm.account_name_prefix" />
</el-form-item>
</el-col>
</el-row>
<el-row :gutter="12">
<el-col :span="12">
<el-form-item label="并发/容量">
<el-input-number v-model="importAccountsForm.concurrency" :min="1" style="width:100%" />
</el-form-item>
</el-col>
<el-col :span="12">
<el-form-item label="优先级">
<el-input-number v-model="importAccountsForm.priority" :min="0" style="width:100%" />
</el-form-item>
</el-col>
</el-row>
<el-form-item label="已生成的上游 Key">
<div style="display:flex;gap:8px;margin-bottom:6px">
<el-button size="small" text @click="selectAllImportableKeys">全选可导入 Key</el-button>
<el-button size="small" text @click="clearImportAccountSelection">清空</el-button>
</div>
<el-select
v-model="importAccountsForm.upstream_key_ids"
multiple
filterable
style="width:100%"
placeholder="选择要导入为账号管理账号的 Key"
:loading="generatedKeyLoading"
@change="autoFillAccountTargetGroups()"
>
<el-option
v-for="item in importableGeneratedKeys"
:key="item.id!"
:label="`${item.group_name || item.group_id} / ${detectPlatform(item)} / ${item.key_name} / ${item.masked_key}`"
:value="item.id!"
/>
</el-select>
</el-form-item>
<div v-if="selectedAccountGroups.length" class="mapping-panel">
<div class="result-title">目标分组映射</div>
<div v-for="group in selectedAccountGroups" :key="group.group_id" class="mapping-row">
<span class="mapping-label">{{ group.group_name || group.group_id }}</span>
<el-select v-model="importAccountsForm.target_group_map[group.group_id]" clearable filterable placeholder="可不选" style="width:280px">
<el-option v-for="target in importTargetGroups" :key="target.id" :label="`${target.name} (${target.rate_multiplier ?? '—'})`" :value="target.id" />
</el-select>
</div>
</div>
</el-form>
<div v-if="importAccountResults.length" class="result-panel">
<div class="result-title">创建结果</div>
<el-table :data="importAccountResults" size="small">
<el-table-column prop="source_group_name" label="来源分组" min-width="130" />
<el-table-column prop="account_name" label="账号管理账号" min-width="180" />
<el-table-column label="识别平台" width="120">
<template #default="{ row }">
<span>{{ platformLabel(row.platform) }}</span>
</template>
</el-table-column>
<el-table-column label="请求地址" min-width="180">
<template #default="{ row }">
<span class="mono" style="font-size:12px">{{ row.upstream_base_url || '—' }}</span>
</template>
</el-table-column>
<el-table-column prop="account_id" label="账号 ID" width="110" />
<el-table-column label="状态" width="90">
<template #default="{ row }">
<el-tag v-if="row.status === 'created'" size="small" type="success">成功</el-tag>
<el-tag v-else-if="row.status === 'exists'" size="small" type="info">已存在</el-tag>
<el-tag v-else size="small" type="danger">失败</el-tag>
</template>
</el-table-column>
<el-table-column prop="message" label="结果" min-width="160" />
</el-table>
</div>
<template #footer>
<el-button @click="importAccountsDialog = false">关闭</el-button>
<el-button type="primary" :loading="importingAccounts" :disabled="importingAccounts" @click="submitImportAccounts">创建账号管理账号</el-button>
</template>
</el-dialog>
</div> </div>
</template> </template>
@@ -237,13 +432,16 @@ import { computed, onMounted, ref } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus' import { ElMessage, ElMessageBox } from 'element-plus'
import type { FormInstance } from 'element-plus' import type { FormInstance } from 'element-plus'
import dayjs from 'dayjs' import dayjs from 'dayjs'
import { Delete, Edit, Plus, Grid, Connection, Link } from '@element-plus/icons-vue' import { ArrowDown, Delete, Edit, Plus, Grid, Connection, Link, Upload, Key, Refresh } from '@element-plus/icons-vue'
import { import {
upstreamsApi, upstreamsApi,
websitesApi, websitesApi,
type BindingSourceGroup, type BindingSourceGroup,
type GeneratedUpstreamKey,
type GroupBindingData, type GroupBindingData,
type GroupBindingForm, type GroupBindingForm,
type ImportAccountItem,
type ImportGroupItem,
type UpstreamData, type UpstreamData,
type WebsiteData, type WebsiteData,
type WebsiteForm, type WebsiteForm,
@@ -259,16 +457,25 @@ const bindingWebsiteGroups = ref<WebsiteGroup[]>([])
const bindings = ref<(GroupBindingData & { _syncing?: boolean })[]>([]) const bindings = ref<(GroupBindingData & { _syncing?: boolean })[]>([])
const logs = ref<WebsiteSyncLog[]>([]) const logs = ref<WebsiteSyncLog[]>([])
const snapshotsByUpstream = ref<Record<number, any[]>>({}) const snapshotsByUpstream = ref<Record<number, any[]>>({})
const importTargetGroups = ref<WebsiteGroup[]>([])
const importGeneratedKeys = ref<GeneratedUpstreamKey[]>([])
const websiteLoading = ref(false) const websiteLoading = ref(false)
const groupsLoading = ref(false) const groupsLoading = ref(false)
const bindingLoading = ref(false) const bindingLoading = ref(false)
const logLoading = ref(false) const logLoading = ref(false)
const importingGroups = ref(false)
const importingAccounts = ref(false)
const generatedKeyLoading = ref(false)
const importSyncStatus = ref<{ total: number; cleared: number; failed: number } | null>(null)
const syncingImportStatus = ref(false)
const statusLabel = (s: string) => ({ healthy: '健康', unhealthy: '异常', unknown: '未知' }[s] || s) const statusLabel = (s: string) => ({ healthy: '健康', unhealthy: '异常', unknown: '未知' }[s] || s)
const algorithmLabel = (s: string) => ({ max_plus_percent: '最高倍率', average_plus_percent: '平均倍率', min_plus_percent: '最低倍率' }[s] || s) const algorithmLabel = (s: string) => ({ max_plus_percent: '最高倍率', average_plus_percent: '平均倍率', min_plus_percent: '最低倍率' }[s] || s)
const toUTC = (t: string) => /[Z+\-]\d*$/.test(t.trim()) ? t : t + 'Z' const toUTC = (t: string) => /[Z+\-]\d*$/.test(t.trim()) ? t : t + 'Z'
const fmtTime = (t: string) => dayjs(toUTC(t)).format('MM-DD HH:mm:ss') const fmtTime = (t: string) => dayjs(toUTC(t)).format('MM-DD HH:mm:ss')
const sourceGroupId = (group: any) => String(group?.group_id || group?.id || group?.name || '')
const sourceGroupName = (group: any) => String(group?.group_name || group?.name || sourceGroupId(group))
function defaultWebsiteForm(): WebsiteForm { function defaultWebsiteForm(): WebsiteForm {
return { return {
@@ -318,6 +525,29 @@ const bindingRules = {
target_group_id: [{ required: true, message: '请选择目标分组', trigger: 'change' }], target_group_id: [{ required: true, message: '请选择目标分组', trigger: 'change' }],
} }
const importGroupsDialog = ref(false)
const importGroupsForm = ref({
website_id: 0,
upstream_id: 0,
group_ids: [] as string[],
name_prefix: '',
})
const importGroupResults = ref<ImportGroupItem[]>([])
const importAccountsDialog = ref(false)
const importAccountsForm = ref({
website_id: 0,
upstream_id: 0,
upstream_key_ids: [] as number[],
target_group_map: {} as Record<string, string>,
account_name_prefix: 'SmartUp',
default_platform: 'openai',
platform_mode: 'auto',
concurrency: 10,
priority: 1,
})
const importAccountResults = ref<ImportAccountItem[]>([])
const upstreamGroupOptions = computed(() => { const upstreamGroupOptions = computed(() => {
const rows: Array<{ key: string; label: string; source: BindingSourceGroup }> = [] const rows: Array<{ key: string; label: string; source: BindingSourceGroup }> = []
for (const upstream of upstreams.value) { for (const upstream of upstreams.value) {
@@ -339,6 +569,31 @@ const upstreamGroupOptions = computed(() => {
return rows return rows
}) })
const importSourceGroups = computed(() => snapshotsByUpstream.value[importGroupsForm.value.upstream_id] || [])
function isImportableGeneratedKey(item: GeneratedUpstreamKey) {
return item.id !== null
&& item.status !== 'failed'
&& !(item.imported_website_id === importAccountsForm.value.website_id && item.imported_account_id)
}
const importableGeneratedKeys = computed(() =>
importGeneratedKeys.value.filter(isImportableGeneratedKey),
)
const selectedAccountGroups = computed(() => {
const selected = new Set(importAccountsForm.value.upstream_key_ids)
const rows = importableGeneratedKeys.value.filter((item) => item.id !== null && selected.has(item.id))
const seen = new Set<string>()
const groups: Array<{ group_id: string; group_name: string }> = []
for (const row of rows) {
if (seen.has(row.group_id)) continue
seen.add(row.group_id)
groups.push({ group_id: row.group_id, group_name: row.group_name })
}
return groups
})
async function loadWebsites() { async function loadWebsites() {
websiteLoading.value = true websiteLoading.value = true
try { try {
@@ -391,6 +646,19 @@ async function loadBindingWebsiteGroups(websiteId: number) {
} }
} }
async function loadImportTargetGroups(websiteId: number) {
if (!websiteId) {
importTargetGroups.value = []
return
}
try {
const res = await websitesApi.groups(websiteId)
importTargetGroups.value = res.data
} catch {
importTargetGroups.value = []
}
}
async function loadBindings() { async function loadBindings() {
bindingLoading.value = true bindingLoading.value = true
try { try {
@@ -411,6 +679,50 @@ async function loadLogs() {
} }
} }
async function syncImportStatus() {
const websiteId = importAccountsForm.value.website_id
const upstreamId = importAccountsForm.value.upstream_id
if (!websiteId || !upstreamId) return
syncingImportStatus.value = true
try {
const res = await websitesApi.syncImportedUpstreamKeys(websiteId, { upstream_id: upstreamId })
// 校验请求完成时表单未切换
if (importAccountsForm.value.website_id !== websiteId || importAccountsForm.value.upstream_id !== upstreamId) return
const items = res.data.items
importSyncStatus.value = {
total: items.length,
cleared: items.filter(i => i.status === 'stale_cleared').length,
failed: items.filter(i => i.status === 'check_failed').length,
}
if (importSyncStatus.value.cleared > 0) {
ElMessage.success(`已清除 ${importSyncStatus.value.cleared} 个失效导入标记`)
}
if (importAccountsForm.value.upstream_id === upstreamId) {
await loadImportGeneratedKeys(upstreamId)
}
} catch (e: any) {
ElMessage.error(e.response?.data?.detail || '同步导入状态失败')
} finally {
syncingImportStatus.value = false
}
}
async function loadImportGeneratedKeys(upstreamId: number) {
importGeneratedKeys.value = []
if (!upstreamId) return
generatedKeyLoading.value = true
const frozenId = upstreamId
try {
const res = await upstreamsApi.generatedKeys(frozenId)
if (importAccountsForm.value.upstream_id !== frozenId) return // 已切换到其他上游
importGeneratedKeys.value = res.data
} catch (e: any) {
ElMessage.error(e.response?.data?.detail || '加载上游 Key 失败')
} finally {
generatedKeyLoading.value = false
}
}
async function loadAll() { async function loadAll() {
await Promise.all([loadWebsites(), loadUpstreamGroups(), loadBindings(), loadLogs()]) await Promise.all([loadWebsites(), loadUpstreamGroups(), loadBindings(), loadLogs()])
if (selectedWebsite.value) await loadWebsiteGroups() if (selectedWebsite.value) await loadWebsiteGroups()
@@ -487,6 +799,17 @@ async function testWebsite(row: WebsiteData & { _testing?: boolean }) {
} }
} }
/** 处理「更多」下拉菜单中的操作 */
function handleMoreAction(cmd: string, row: WebsiteData & { _testing?: boolean }) {
switch (cmd) {
case 'test': testWebsite(row); break
case 'binding': openBindingCreate(row); break
case 'importGroups': openImportGroups(row); break
case 'importAccounts': openImportAccounts(row); break
case 'delete': deleteWebsite(row); break
}
}
async function deleteWebsite(row: WebsiteData) { async function deleteWebsite(row: WebsiteData) {
try { try {
await ElMessageBox.confirm(`确认删除网站 "${row.name}"`, '删除确认', { type: 'warning' }) await ElMessageBox.confirm(`确认删除网站 "${row.name}"`, '删除确认', { type: 'warning' })
@@ -599,6 +922,197 @@ async function deleteBinding(row: GroupBindingData) {
} catch {} } catch {}
} }
async function openImportGroups(site?: WebsiteData | null) {
if (upstreams.value.length === 0) await loadUpstreamGroups()
const target = site || selectedWebsite.value || websites.value[0]
importGroupsForm.value = {
website_id: target?.id || 0,
upstream_id: upstreams.value[0]?.id || 0,
group_ids: [],
name_prefix: '',
}
importGroupResults.value = []
importGroupsDialog.value = true
}
async function submitImportGroups() {
if (!importGroupsForm.value.website_id || !importGroupsForm.value.upstream_id) {
ElMessage.error('请选择目标网站和来源上游')
return
}
importingGroups.value = true
try {
const res = await websitesApi.importGroupsFromUpstream(
importGroupsForm.value.website_id,
importGroupsForm.value.upstream_id,
{
group_ids: importGroupsForm.value.group_ids,
name_prefix: importGroupsForm.value.name_prefix,
},
)
importGroupResults.value = res.data.items
ElMessage[res.data.success ? 'success' : 'warning'](res.data.message)
if (selectedWebsite.value?.id === importGroupsForm.value.website_id) await loadWebsiteGroups()
await loadImportTargetGroups(importGroupsForm.value.website_id)
} catch (e: any) {
ElMessage.error(e.response?.data?.detail || '导入上游分组失败')
} finally {
importingGroups.value = false
}
}
async function openImportAccounts(site?: WebsiteData | null) {
if (upstreams.value.length === 0) await loadUpstreamGroups()
const target = site || selectedWebsite.value || websites.value[0]
importAccountsForm.value = {
website_id: target?.id || 0,
upstream_id: upstreams.value[0]?.id || 0,
upstream_key_ids: [],
target_group_map: {},
account_name_prefix: 'SmartUp',
default_platform: 'openai',
platform_mode: 'auto',
concurrency: 10,
priority: 1,
}
importAccountResults.value = []
importSyncStatus.value = null
await Promise.all([
loadImportTargetGroups(importAccountsForm.value.website_id),
loadImportGeneratedKeys(importAccountsForm.value.upstream_id),
])
// 打开弹窗后自动同步导入状态(校验远端账号是否仍存在)
await syncImportStatus()
await loadImportGeneratedKeys(importAccountsForm.value.upstream_id)
importAccountsDialog.value = true
}
async function onImportAccountWebsiteChange(value: number) {
importAccountsForm.value.target_group_map = {}
await loadImportTargetGroups(value)
await syncImportStatus()
await loadImportGeneratedKeys(importAccountsForm.value.upstream_id)
}
async function onImportAccountUpstreamChange(value: number) {
importAccountsForm.value.upstream_key_ids = []
importAccountsForm.value.target_group_map = {}
importAccountResults.value = []
await loadImportGeneratedKeys(value)
await syncImportStatus()
await loadImportGeneratedKeys(value)
}
function onPlatformModeChange(value: string) {
if (value === 'auto') {
importAccountsForm.value.default_platform = 'openai'
}
}
function detectPlatform(item: { group_name?: string; group_id?: string; key_name?: string }) {
const text = `${item.group_name || ''} ${item.group_id || ''} ${item.key_name || ''}`.toLowerCase()
if (text.includes('claude') || text.includes('anthropic')) return 'Anthropic'
if (text.includes('gemini')) return 'Gemini'
if (text.includes('antigravity')) return 'Antigravity'
return 'OpenAI 兼容'
}
function platformLabel(platform: string) {
const map: Record<string, string> = {
openai: 'OpenAI 兼容',
anthropic: 'Anthropic',
gemini: 'Gemini',
antigravity: 'Antigravity',
}
return map[platform] || platform || '—'
}
function normalizeGroupName(name: string) {
return String(name || '')
.toLowerCase()
.replace(/^smartup[-_\s]*/i, '')
.replace(/^ai\d+pro/i, '')
.replace(/[|]/g, ' ')
.replace(/\s+/g, '')
.trim()
}
function findTargetGroupForSource(sourceName: string, sourceId: string) {
const sourceNorm = normalizeGroupName(sourceName || sourceId)
if (!sourceNorm) return ''
const exact = importTargetGroups.value.find(g =>
normalizeGroupName(g.name) === sourceNorm
)
if (exact) return exact.id
const fuzzy = importTargetGroups.value.find(g => {
const targetNorm = normalizeGroupName(g.name)
return targetNorm.includes(sourceNorm) || sourceNorm.includes(targetNorm)
})
return fuzzy?.id || ''
}
function autoFillAccountTargetGroups() {
const selected = new Set(importAccountsForm.value.upstream_key_ids)
const keys = importableGeneratedKeys.value.filter(item => item.id !== null && selected.has(item.id))
const nextMap = { ...importAccountsForm.value.target_group_map }
for (const item of keys) {
if (!item.id || !selected.has(item.id)) continue
if (nextMap[item.group_id]) continue
const targetId = findTargetGroupForSource(item.group_name || item.group_id, item.group_id)
if (targetId) nextMap[item.group_id] = targetId
}
importAccountsForm.value.target_group_map = nextMap
}
function selectAllImportableKeys() {
const keys = importableGeneratedKeys.value
importAccountsForm.value.upstream_key_ids = keys.map(item => item.id!)
autoFillAccountTargetGroups()
const matched = Object.keys(importAccountsForm.value.target_group_map).length
ElMessage.success(`已选择 ${keys.length} 个 Key,自动匹配 ${matched} 个目标分组`)
}
function clearImportAccountSelection() {
importAccountsForm.value.upstream_key_ids = []
importAccountsForm.value.target_group_map = {}
}
async function submitImportAccounts() {
if (!importAccountsForm.value.website_id || !importAccountsForm.value.upstream_id) {
ElMessage.error('请选择目标网站和来源上游')
return
}
if (importAccountsForm.value.upstream_key_ids.length === 0) {
ElMessage.error('请选择要导入的上游 Key')
return
}
importingAccounts.value = true
try {
const res = await websitesApi.importAccountsFromUpstreamKeys(importAccountsForm.value.website_id, {
upstream_key_ids: importAccountsForm.value.upstream_key_ids,
target_group_map: importAccountsForm.value.target_group_map,
account_name_prefix: importAccountsForm.value.account_name_prefix,
default_platform: importAccountsForm.value.default_platform,
platform_mode: importAccountsForm.value.platform_mode,
concurrency: importAccountsForm.value.concurrency,
priority: importAccountsForm.value.priority,
})
importAccountResults.value = res.data.items
ElMessage[res.data.success ? 'success' : 'warning'](res.data.message)
await loadImportGeneratedKeys(importAccountsForm.value.upstream_id)
} catch (e: any) {
ElMessage.error(e.response?.data?.detail || '创建账号管理账号失败')
} finally {
importingAccounts.value = false
}
}
onMounted(loadAll) onMounted(loadAll)
</script> </script>
@@ -624,6 +1138,13 @@ onMounted(loadAll)
} }
.panel-title { font-size: 14px; font-weight: 600; color: var(--text-primary); } .panel-title { font-size: 14px; font-weight: 600; color: var(--text-primary); }
.panel-sub { font-size: 12px; color: var(--text-muted); margin-top: 2px; } .panel-sub { font-size: 12px; color: var(--text-muted); margin-top: 2px; }
.panel-actions {
display: flex;
align-items: center;
flex-wrap: wrap;
justify-content: flex-end;
gap: 4px;
}
.content-grid { .content-grid {
display: grid; display: grid;
grid-template-columns: 1fr; grid-template-columns: 1fr;
@@ -639,7 +1160,7 @@ onMounted(loadAll)
align-items: center; align-items: center;
justify-content: flex-end; justify-content: flex-end;
flex-wrap: nowrap; flex-wrap: nowrap;
gap: 2px; gap: 6px;
min-width: 0; min-width: 0;
} }
.action-row .el-button.is-circle { .action-row .el-button.is-circle {
@@ -647,6 +1168,36 @@ onMounted(loadAll)
height: 26px; height: 26px;
margin-left: 0; margin-left: 0;
} }
.action-row .btn-edit {
color: var(--text-primary);
font-weight: 500;
gap: 3px;
padding: 0 6px;
white-space: nowrap;
flex-shrink: 0;
}
.action-row .btn-edit-icon {
font-size: 13px;
}
.action-row .btn-edit:hover {
color: var(--el-color-primary);
background: var(--el-color-primary-light-9);
border-radius: 4px;
}
.action-row .btn-more {
color: var(--text-muted);
font-size: 12px;
padding: 0 4px;
}
.action-row .btn-more:hover {
color: var(--el-color-primary);
}
.action-row .btn-more .el-icon--right {
margin-left: 1px;
}
.btn-more-delete {
color: var(--el-color-danger);
}
.binding-actions { .binding-actions {
display: flex; display: flex;
@@ -698,6 +1249,41 @@ onMounted(loadAll)
color: var(--text-secondary); color: var(--text-secondary);
font-size: 13px; font-size: 13px;
} }
.dialog-note { margin-bottom: 12px; }
.result-panel {
margin-top: 14px;
border-top: 1px solid var(--border-color);
padding-top: 12px;
}
.result-title {
margin-bottom: 8px;
font-size: 13px;
font-weight: 600;
color: var(--text-primary);
}
.mapping-panel {
border: 1px solid var(--border-color);
border-radius: 8px;
padding: 12px;
}
.mapping-row {
min-height: 38px;
display: flex;
align-items: center;
justify-content: space-between;
gap: 12px;
}
.mapping-row + .mapping-row {
border-top: 1px solid var(--border-color);
}
.mapping-label {
min-width: 0;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
color: var(--text-secondary);
font-size: 13px;
}
@media (min-width: 1024px) { @media (min-width: 1024px) {
.content-grid { grid-template-columns: minmax(0, 1fr) minmax(0, 1fr); } .content-grid { grid-template-columns: minmax(0, 1fr) minmax(0, 1fr); }
@@ -714,5 +1300,11 @@ onMounted(loadAll)
flex-direction: column; flex-direction: column;
} }
.binding-actions { width: 100%; justify-content: flex-end; } .binding-actions { width: 100%; justify-content: flex-end; }
.mapping-row {
align-items: stretch;
flex-direction: column;
padding: 8px 0;
}
.mapping-row .el-select { width: 100% !important; }
} }
</style> </style>