diff --git a/.gitignore b/.gitignore index 0ddc0c4..c0f2819 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,9 @@ build/ backend/static/ backend/data/ +# 运行时数据(数据库、远程浏览器 profile、缓存等) +data/ + *.log .DS_Store .git-real/ diff --git a/Dockerfile b/Dockerfile index fd4368d..1471569 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,14 @@ +# syntax=docker/dockerfile:1 # ---- Stage 1: Build frontend ---- FROM node:20-alpine AS frontend-build WORKDIR /frontend + +# 依赖层:package*.json 不变则复用 npm 缓存 COPY frontend/package*.json ./ -RUN npm ci --registry=https://registry.npmmirror.com +RUN --mount=type=cache,target=/root/.npm \ + npm ci --registry=https://registry.npmmirror.com + +# 源码层:业务代码变更不影响上层依赖 COPY frontend/ . RUN npm run build @@ -11,13 +17,13 @@ FROM python:3.12-slim WORKDIR /app ENV PLAYWRIGHT_DOWNLOAD_HOST=https://npmmirror.com/mirrors/playwright +ENV PLAYWRIGHT_BROWSERS_PATH=/ms-playwright RUN sed -i 's|http://deb.debian.org|https://mirrors.aliyun.com|g; s|http://security.debian.org|https://mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources -# Install deps -COPY backend/requirements.txt . -RUN pip install --no-cache-dir --index-url https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn -r requirements.txt -RUN apt-get update \ +# 系统依赖层:apt 包安装,缓存 deb 包避免重复下载 +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update \ && apt-get install -y --no-install-recommends \ fonts-liberation fonts-unifont fonts-wqy-zenhei \ libasound2t64 libatk-bridge2.0-0 libatk1.0-0 libatspi2.0-0 \ @@ -27,15 +33,23 @@ RUN apt-get update \ libxdamage1 libxext6 libxfixes3 libxrandr2 libxshmfence1 xvfb \ curl \ && rm -rf /var/lib/apt/lists/* + +# Python 依赖层:requirements.txt 不变则复用 pip 缓存 +COPY backend/requirements.txt . +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple \ + --trusted-host pypi.tuna.tsinghua.edu.cn \ + -r requirements.txt + +# Playwright Chromium:安装在镜像层中,业务代码变更不会触发重下 RUN playwright install chromium -# Copy backend source +# 源码层:业务代码变更不影响上面所有依赖层 COPY backend/ . -# Copy built frontend into backend/static +# 前端构建产物 COPY --from=frontend-build /frontend/dist ./static -# Data directory for SQLite RUN mkdir -p /app/data ENV PYTHONPATH=/app diff --git a/Makefile b/Makefile index 1111a01..37ddcdc 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,25 @@ COMPOSE ?= docker compose SERVICE ?= smartup -.PHONY: up down log restart ps +.PHONY: up down build build-nc up-build log restart ps +# 日常启动(不重新构建镜像,启动已有容器) up: - $(COMPOSE) up -d --build + $(COMPOSE) up -d + @port=$$(grep -E '^SERVER_PORT=' .env 2>/dev/null | tail -n 1 | cut -d= -f2-); \ + printf '访问地址:http://localhost:%s\n' "$${port:-8899}" + +# 构建镜像(带 BuildKit 缓存) +build: + DOCKER_BUILDKIT=1 $(COMPOSE) build + +# 强制重新构建(忽略 Docker 层缓存,npm/pip/apt 下载缓存仍可能复用) +build-nc: + DOCKER_BUILDKIT=1 $(COMPOSE) build --no-cache + +# 构建并启动(依赖变更后使用) +up-build: + DOCKER_BUILDKIT=1 $(COMPOSE) up -d --build @port=$$(grep -E '^SERVER_PORT=' .env 2>/dev/null | tail -n 1 | cut -d= -f2-); \ printf '访问地址:http://localhost:%s\n' "$${port:-8899}" diff --git a/backend/app/database.py b/backend/app/database.py index 40728f6..90f0ef2 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -26,10 +26,11 @@ def get_db(): def init_db(): """Create all tables.""" # import models so SQLAlchemy registers them - from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website, revoked_token # noqa: F401 + from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website, revoked_token, upstream_key # noqa: F401 Base.metadata.create_all(bind=engine) _migrate_custom_pages() _migrate_upstreams() + _migrate_upstream_generated_keys() def _migrate_custom_pages(): @@ -87,3 +88,59 @@ def _migrate_upstreams(): if "balance_divisor" not in columns: conn.execute(text("ALTER TABLE upstreams ADD COLUMN balance_divisor FLOAT NOT NULL DEFAULT 1.0")) + +def _migrate_upstream_generated_keys(): + """Apply SQLite-safe migrations to the generated upstream keys table.""" + inspector = inspect(engine) + if "upstream_generated_keys" not in inspector.get_table_names(): + return + columns = {col["name"] for col in inspector.get_columns("upstream_generated_keys")} + with engine.begin() as conn: + if "imported_website_id" not in columns: + conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_website_id INTEGER")) + if "imported_account_id" not in columns: + conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_account_id VARCHAR(255)")) + if "imported_at" not in columns: + conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_at DATETIME")) + if "updated_at" not in columns: + conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN updated_at DATETIME")) + conn.execute(text("UPDATE upstream_generated_keys SET updated_at = created_at WHERE updated_at IS NULL")) + if "managed_prefix" not in columns: + conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN managed_prefix VARCHAR(64)")) + + # ——— 历史数据迁移:回填 managed_prefix + 清理重复 ——— + with engine.begin() as conn: + # 1. 回填:key_name 以 SmartUp- 开头的旧记录设置 managed_prefix = 'SmartUp' + conn.execute(text( + "UPDATE upstream_generated_keys SET managed_prefix = 'SmartUp' " + "WHERE managed_prefix IS NULL AND key_name LIKE 'SmartUp-%'" + )) + # 2. 清理:同一 (upstream_id, group_id, managed_prefix) 只保留最新一条 + # SQLite 不支持子查询直接 DELETE,用两步 + to_delete = conn.execute(text(""" + SELECT id FROM upstream_generated_keys + WHERE managed_prefix IS NOT NULL + AND id NOT IN ( + SELECT MAX(id) FROM upstream_generated_keys + WHERE managed_prefix IS NOT NULL + GROUP BY upstream_id, group_id, managed_prefix + ) + """)).fetchall() + for (row_id,) in to_delete: + conn.execute(text("DELETE FROM upstream_generated_keys WHERE id = :id"), {"id": row_id}) + + # ——— 创建唯一索引 ——— + try: + with engine.begin() as conn: + conn.execute( + text("CREATE UNIQUE INDEX IF NOT EXISTS uq_upstream_group_key " + "ON upstream_generated_keys(upstream_id, group_id, key_name)") + ) + conn.execute( + text("CREATE UNIQUE INDEX IF NOT EXISTS uq_upstream_group_managed " + "ON upstream_generated_keys(upstream_id, group_id, managed_prefix) " + "WHERE managed_prefix IS NOT NULL") + ) + except Exception: + logger = __import__("logging").getLogger(__name__) + logger.warning("could not create unique indexes on upstream_generated_keys (non-fatal)") diff --git a/backend/app/models/upstream_key.py b/backend/app/models/upstream_key.py new file mode 100644 index 0000000..b7fb7e2 --- /dev/null +++ b/backend/app/models/upstream_key.py @@ -0,0 +1,33 @@ +from datetime import datetime, timezone +from typing import Optional + +from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from app.database import Base + + +class UpstreamGeneratedKey(Base): + __tablename__ = "upstream_generated_keys" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True) + upstream_id: Mapped[int] = mapped_column(Integer, ForeignKey("upstreams.id", ondelete="CASCADE"), index=True) + group_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True) + group_name: Mapped[str] = mapped_column(String(255), default="") + key_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + key_name: Mapped[str] = mapped_column(String(255), nullable=False) + key_value: Mapped[str] = mapped_column(Text, nullable=False) + masked_key: Mapped[str] = mapped_column(String(255), default="") + raw_json: Mapped[str] = mapped_column(Text, default="{}") + managed_prefix: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True) + status: Mapped[str] = mapped_column(String(32), default="created") + error: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + imported_website_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey("websites.id", ondelete="SET NULL"), nullable=True, index=True) + imported_account_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + imported_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc)) + updated_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + + __table_args__ = ( + UniqueConstraint("upstream_id", "group_id", "key_name", name="uq_upstream_group_key"), + ) diff --git a/backend/app/routers/upstreams.py b/backend/app/routers/upstreams.py index 1607b90..8b75f2d 100644 --- a/backend/app/routers/upstreams.py +++ b/backend/app/routers/upstreams.py @@ -3,6 +3,7 @@ from __future__ import annotations import json import logging +import re from datetime import datetime, timezone from typing import List @@ -14,11 +15,15 @@ from sqlalchemy.orm import Session from app.database import get_db from app.models.admin_user import AdminUser from app.models.upstream import Upstream +from app.models.upstream_key import UpstreamGeneratedKey from app.models.snapshot import UpstreamRateSnapshot from app.schemas.upstream import ( + GenerateKeysByGroupsRequest, + GenerateKeysByGroupsResponse, + GeneratedUpstreamKeyResponse, UpstreamCreate, UpstreamUpdate, UpstreamResponse, SnapshotResponse, TestResult ) -from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot +from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot, mask_secret from app.services.snapshot_service import diff_snapshots from app.services import scheduler as sched_svc from app.services import webhook_service @@ -31,6 +36,38 @@ MASK = "***" SECRET_KEYS = {"password", "token", "key", "secret"} +def _group_id(group: dict) -> str: + for key in ("id", "group_id", "groupId"): + value = group.get(key) + if value is not None: + return str(value) + return str(group.get("name") or group.get("group_name") or "") + + +def _group_name(group: dict, gid: str) -> str: + return str(group.get("name") or group.get("group_name") or gid) + + +def _key_response(row: UpstreamGeneratedKey, include_value: bool = False) -> GeneratedUpstreamKeyResponse: + return GeneratedUpstreamKeyResponse( + id=row.id, + upstream_id=row.upstream_id, + group_id=row.group_id, + group_name=row.group_name, + key_id=row.key_id, + key_name=row.key_name, + key_value=row.key_value if include_value else None, + masked_key=row.masked_key, + status=row.status, + error=row.error, + imported_website_id=row.imported_website_id, + imported_account_id=row.imported_account_id, + imported_at=row.imported_at, + created_at=row.created_at, + updated_at=row.updated_at, + ) + + def _mask_auth_config(auth_type: str, cfg: dict) -> dict: masked = {} for k, v in cfg.items(): @@ -73,6 +110,198 @@ def list_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)): return [_to_response(u) for u in db.query(Upstream).order_by(Upstream.id).all()] +@router.get("/{uid}/generated-keys", response_model=List[GeneratedUpstreamKeyResponse]) +def list_generated_keys(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): + if not db.query(Upstream.id).filter(Upstream.id == uid).first(): + raise HTTPException(404, "upstream not found") + rows = ( + db.query(UpstreamGeneratedKey) + .filter(UpstreamGeneratedKey.upstream_id == uid) + .order_by(UpstreamGeneratedKey.id.desc()) + .limit(200) + .all() + ) + return [_key_response(row) for row in rows] + + +_generate_key_lock = __import__("threading").Lock() + + +def _ensure_group_key( + db: Session, + client: UpstreamClient, + upstream: Upstream, + group: dict[str, Any], + prefix: str, + body: GenerateKeysByGroupsRequest, +) -> GeneratedUpstreamKeyResponse: + """确保一个上游分组有一个 SmartUp 前缀 Key:存在则 upsert,不存在则创建。""" + gid = _group_id(group) + gname = _group_name(group, gid) + # 使用稳定的 upstream_id + group_id 而非可变名称,避免因改名产生重复 + # 可读 Key 名:{prefix}-{upstream.id}-{安全的分组名}-{group_id} + safe_group_name = re.sub(r"[^a-zA-Z0-9\u4e00-\u9fff_-]", "", gname)[:30] if gname else gid + stable_name = f"{prefix}-{upstream.id}-{safe_group_name}-{gid}" + + with _generate_key_lock: + try: + # 1. 先查本地是否已有该分组的托管 Key(兼容迁移前无 managed_prefix 的记录) + row = ( + db.query(UpstreamGeneratedKey) + .filter( + UpstreamGeneratedKey.upstream_id == upstream.id, + UpstreamGeneratedKey.group_id == gid, + (UpstreamGeneratedKey.managed_prefix == prefix) + | ((UpstreamGeneratedKey.managed_prefix.is_(None)) + & UpstreamGeneratedKey.key_name.like(f"{prefix}-%")), + ) + .first() + ) + if row and row.key_id: + # 本地已有记录,检查远端是否仍存在 + try: + existing = client.find_smartup_group_key(gid, stable_name, prefix) + except Exception: + existing = None + if existing: + row.status = "exists" + row.updated_at = datetime.now(timezone.utc) + db.commit() + return _key_response(row, include_value=False) + # 远端不存在,需要重新创建 + row.status = "replaced" + + # 2. 查远端是否有同名 Key(防止并发时另一个请求已创建) + existing = client.find_smartup_group_key(gid, stable_name, prefix) + if existing: + key_id = str(existing.get("id") or "") + masked = existing.get("masked_key") or existing.get("key") or "" + if row: + row.key_id = key_id or row.key_id + row.masked_key = masked or row.masked_key + row.status = "exists" + row.updated_at = datetime.now(timezone.utc) + else: + row = UpstreamGeneratedKey( + upstream_id=upstream.id, + group_id=gid, + group_name=gname, + key_id=key_id or None, + key_name=stable_name, + key_value="", + masked_key=masked, + raw_json=json.dumps(existing, ensure_ascii=False), + managed_prefix=prefix, + status="exists", + ) + db.add(row) + db.commit() + db.refresh(row) + return _key_response(row, include_value=False) + + # 3. 远端不存在,创建新 Key + created = client.create_api_key( + stable_name, + gid, + quota=body.quota, + expires_in_days=body.expires_in_days, + rate_limit_5h=body.rate_limit_5h, + rate_limit_1d=body.rate_limit_1d, + rate_limit_7d=body.rate_limit_7d, + endpoint=body.endpoint, + ) + if row: + # 复用旧行 + row.key_id = created.get("id") or None + row.key_name = stable_name + row.key_value = created["key"] + row.masked_key = created.get("masked_key") or mask_secret(created["key"]) + row.raw_json = json.dumps(created.get("raw") or {}, ensure_ascii=False) + row.managed_prefix = prefix + row.status = "created" + row.error = None + else: + row = UpstreamGeneratedKey( + upstream_id=upstream.id, + group_id=gid, + group_name=gname, + key_id=created.get("id") or None, + key_name=stable_name, + key_value=created["key"], + masked_key=created.get("masked_key") or mask_secret(created["key"]), + raw_json=json.dumps(created.get("raw") or {}, ensure_ascii=False), + managed_prefix=prefix, + status="created", + ) + db.add(row) + db.commit() + db.refresh(row) + return _key_response(row, include_value=True) + except Exception as exc: + logger.exception("ensure group key failed for upstream=%s group=%s", upstream.id, gid) + return GeneratedUpstreamKeyResponse( + upstream_id=upstream.id, + group_id=gid, + group_name=gname, + key_name=stable_name, + status="failed", + error=str(exc), + ) + + +@router.post("/{uid}/keys/generate-by-groups", response_model=GenerateKeysByGroupsResponse) +def generate_keys_by_groups( + uid: int, + body: GenerateKeysByGroupsRequest, + db: Session = Depends(get_db), + _=Depends(get_current_user), +): + u = db.query(Upstream).filter(Upstream.id == uid).first() + if not u: + raise HTTPException(404, "upstream not found") + if u.api_prefix.strip("/") != "api/v1": + raise HTTPException(400, "首版仅支持 Sub2API 上游(API Prefix 应为 /api/v1)") + + auth_config = json.loads(u.auth_config_json or "{}") + selected = set(body.group_ids) + prefix = body.name_prefix + results: list[GeneratedUpstreamKeyResponse] = [] + with UpstreamClient( + base_url=u.base_url, + api_prefix=u.api_prefix, + auth_type=u.auth_type, + auth_config=auth_config, + timeout=float(u.timeout_seconds), + ) as client: + try: + client.login() + groups = client.get_available_groups(u.groups_endpoint) + except Exception as exc: + raise HTTPException(502, str(exc)) + + for group in groups: + gid = _group_id(group) + if not gid or (selected and gid not in selected): + continue + result = _ensure_group_key(db, client, u, group, prefix, body) + results.append(result) + + created = len([item for item in results if item.status == "created"]) + existed = len([item for item in results if item.status == "exists"]) + total = len(results) + msg_parts = [] + if created: + msg_parts.append(f"新创建 {created}") + if existed: + msg_parts.append(f"已存在 {existed}") + msg = "、".join(msg_parts) + f" / 共 {total} 个分组" if msg_parts else f"共处理 {total} 个分组" + return GenerateKeysByGroupsResponse( + success=total > 0 and all(item.status != "failed" for item in results), + message=msg, + items=results, + ) + + @router.post("", response_model=UpstreamResponse, status_code=201) def create_upstream( body: UpstreamCreate, @@ -255,6 +484,10 @@ def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_use webhook_service.send_rate_changed(db, u.id, u.name, u.base_url, changes) website_sync.sync_affected_bindings(db, u.id, changes) + # 同步 SmartUp Key 状态(使用实际快照入库时间,与定时任务一致) + from app.services.scheduler import _sync_upstream_keys as _synck + _synck(uid, snapshot, new_row.captured_at) + msg = f"检测成功,{len(groups)} 个分组" if changes: msg += f",发现 {len(changes)} 处倍率变化" diff --git a/backend/app/routers/websites.py b/backend/app/routers/websites.py index c9c1774..1b352b7 100644 --- a/backend/app/routers/websites.py +++ b/backend/app/routers/websites.py @@ -9,11 +9,21 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlalchemy.orm import Session from app.database import get_db +from app.models.snapshot import UpstreamRateSnapshot +from app.models.upstream import Upstream +from app.models.upstream_key import UpstreamGeneratedKey from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog from app.schemas.website import ( BindingCreate, BindingResponse, BindingUpdate, + ImportAccountItem, + ImportAccountsRequest, + ImportAccountsResponse, + ImportGroupItem, + ImportGroupsRequest, + ImportGroupsResponse, + SyncImportStatusRequest, TestResult, WebsiteCreate, WebsiteGroupResponse, @@ -118,6 +128,47 @@ def _client(row: Website) -> Sub2ApiWebsiteClient: ) +def _latest_upstream_groups(db: Session, upstream_id: int) -> list[dict]: + row = ( + db.query(UpstreamRateSnapshot) + .filter(UpstreamRateSnapshot.upstream_id == upstream_id) + .order_by(UpstreamRateSnapshot.captured_at.desc()) + .first() + ) + if not row: + raise HTTPException(404, "no upstream snapshot found; run upstream check first") + snapshot = json.loads(row.snapshot_json or "{}") + groups = snapshot.get("groups") or {} + if not isinstance(groups, dict): + return [] + return [item for item in groups.values() if isinstance(item, dict)] + + +def _source_group_id(group: dict) -> str: + return str(group.get("group_id") or group.get("id") or group.get("name") or "") + + +def _source_group_name(group: dict, gid: str) -> str: + return str(group.get("group_name") or group.get("name") or gid) + + +def _source_group_rate(group: dict) -> float: + raw = group.get("rate") or group.get("default_rate") or group.get("rate_multiplier") or 1 + try: + return float(raw) + except (TypeError, ValueError): + return 1.0 + + +def _numeric_group_id(value: str | None) -> int | None: + if value is None or value == "": + return None + try: + return int(value) + except ValueError: + return None + + @router.get("/api/websites", response_model=List[WebsiteResponse]) def list_websites(db: Session = Depends(get_db), _=Depends(get_current_user)): return [_website_response(row) for row in db.query(Website).order_by(Website.id).all()] @@ -215,6 +266,375 @@ def list_website_groups(wid: int, db: Session = Depends(get_db), _=Depends(get_c raise HTTPException(502, str(exc)) +@router.post("/api/websites/{wid}/groups/import-from-upstream/{upstream_id}", response_model=ImportGroupsResponse) +def import_groups_from_upstream( + wid: int, + upstream_id: int, + body: ImportGroupsRequest, + db: Session = Depends(get_db), + _=Depends(get_current_user), +): + website = db.query(Website).filter(Website.id == wid).first() + if not website: + raise HTTPException(404, "website not found") + if website.site_type != "sub2api": + raise HTTPException(400, "目前只支持 sub2api") + upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first() + if not upstream: + raise HTTPException(404, "upstream not found") + + selected = set(body.group_ids) + groups = _latest_upstream_groups(db, upstream_id) + # 拉取目标网站已有分组,同名则跳过 + try: + existing_names = set() + with _client(website) as c: + for eg in c.get_groups(website.groups_endpoint): + gname = eg.get("name") or eg.get("group_name") or "" + if gname: + existing_names.add(gname) + except Exception: + existing_names = set() + + items: list[ImportGroupItem] = [] + with _client(website) as c: + for group in groups: + source_gid = _source_group_id(group) + if not source_gid or (selected and source_gid not in selected): + continue + source_name = _source_group_name(group, source_gid) + target_name = f"{body.name_prefix}{source_name}" if body.name_prefix else source_name + + # 检查是否已存在同名分组 + if target_name in existing_names: + items.append(ImportGroupItem( + source_group_id=source_gid, + source_group_name=source_name, + target_group_name=target_name, + status="exists", + message="目标分组已存在,已跳过", + )) + continue + + create_body = { + "name": target_name, + "description": group.get("description") or f"Imported from {upstream.name} / {source_name}", + "platform": group.get("platform") or "openai", + "rate_multiplier": _source_group_rate(group), + } + if group.get("rpm_limit") is not None: + create_body["rpm_limit"] = group.get("rpm_limit") + try: + created = c.create_group(create_body) + target_id = c.extract_id(created) + items.append(ImportGroupItem( + source_group_id=source_gid, + source_group_name=source_name, + target_group_id=target_id or None, + target_group_name=str(created.get("name") or target_name), + status="created", + message="已创建", + raw=created, + )) + except Exception as exc: + msg = str(exc) + # 捕获 409 等已存在错误 + if "已存在" in msg or "already exists" in msg.lower() or "409" in msg or "Conflict" in msg: + items.append(ImportGroupItem( + source_group_id=source_gid, + source_group_name=source_name, + target_group_name=target_name, + status="exists", + message="目标分组已存在(接口返回冲突)", + )) + else: + logger.exception("import website group failed website=%s upstream=%s group=%s", wid, upstream_id, source_gid) + items.append(ImportGroupItem( + source_group_id=source_gid, + source_group_name=source_name, + target_group_name=target_name, + status="failed", + message=msg, + )) + created_count = len([item for item in items if item.status == "created"]) + exists_count = len([item for item in items if item.status == "exists"]) + failed_count = len([item for item in items if item.status == "failed"]) + msg_parts = [] + if created_count: + msg_parts.append(f"新建 {created_count}") + if exists_count: + msg_parts.append(f"已存在 {exists_count}") + if failed_count: + msg_parts.append(f"失败 {failed_count}") + return ImportGroupsResponse( + success=failed_count == 0, + message="、".join(msg_parts) + f" / 共 {len(items)} 个" if msg_parts else f"共处理 {len(items)} 个分组", + items=items, + ) + + +def _detect_platform(text: str, fallback: str = "openai") -> str: + """根据 Key 名或分组名关键词判断平台类型。""" + lower = text.lower() + if "claude" in lower or "anthropic" in lower: + return "anthropic" + if "gemini" in lower: + return "gemini" + if "antigravity" in lower: + return "antigravity" + return fallback + + +@router.post("/api/websites/{wid}/accounts/sync-imported-upstream-keys", response_model=ImportAccountsResponse) +def sync_imported_upstream_keys( + wid: int, + body: SyncImportStatusRequest, + db: Session = Depends(get_db), + _=Depends(get_current_user), +): + """校验已导入的上游 Key 在目标 Sub2API 账号管理中是否仍存在。""" + website = db.query(Website).filter(Website.id == wid).first() + if not website: + raise HTTPException(404, "website not found") + + rows = ( + db.query(UpstreamGeneratedKey) + .filter( + UpstreamGeneratedKey.upstream_id == body.upstream_id, + UpstreamGeneratedKey.imported_website_id == wid, + UpstreamGeneratedKey.imported_account_id.isnot(None), + ) + .all() + ) + items: list[ImportAccountItem] = [] + with _client(website) as c: + for row in rows: + platform = _detect_platform(f"{row.group_name} {row.group_id} {row.key_name}", "openai") + if not row.imported_account_id: + continue + old_account_id = row.imported_account_id + exists = c.account_exists(row.imported_account_id) + if exists is False: + row.imported_website_id = None + row.imported_account_id = None + row.imported_at = None + row.status = "created" + items.append(ImportAccountItem( + upstream_key_id=row.id, source_group_id=row.group_id, + source_group_name=row.group_name, account_id=old_account_id, + platform=platform, status="stale_cleared", + message="目标账号已删除,已清除导入标记", + )) + elif exists is True: + items.append(ImportAccountItem( + upstream_key_id=row.id, source_group_id=row.group_id, + source_group_name=row.group_name, account_id=old_account_id, + platform=platform, status="exists", message="目标账号仍存在", + )) + else: + items.append(ImportAccountItem( + upstream_key_id=row.id, source_group_id=row.group_id, + source_group_name=row.group_name, account_id=old_account_id, + platform=platform, status="check_failed", + message="无法校验目标账号存在性(目标网站认证/网络问题)", + )) + db.commit() + cleared_count = len([i for i in items if i.status == "stale_cleared"]) + check_failed_count = len([i for i in items if i.status == "check_failed"]) + msg_parts = [] + if cleared_count: + msg_parts.append(f"清除 {cleared_count}") + if check_failed_count: + msg_parts.append(f"校验失败 {check_failed_count}") + return ImportAccountsResponse( + success=check_failed_count == 0, + message="、".join(msg_parts) + f" / 共 {len(items)} 个" if msg_parts else f"共校验 {len(items)} 个,无变化", + items=items, + ) + + +@router.post("/api/websites/{wid}/accounts/import-upstream-keys", response_model=ImportAccountsResponse) +def import_upstream_keys_as_accounts( + wid: int, + body: ImportAccountsRequest, + db: Session = Depends(get_db), + _=Depends(get_current_user), +): + website = db.query(Website).filter(Website.id == wid).first() + if not website: + raise HTTPException(404, "website not found") + if website.site_type != "sub2api": + raise HTTPException(400, "目前只支持 sub2api") + if not body.upstream_key_ids: + raise HTTPException(400, "请选择要导入的 Key") + rows = ( + db.query(UpstreamGeneratedKey) + .filter(UpstreamGeneratedKey.id.in_(body.upstream_key_ids)) + .order_by(UpstreamGeneratedKey.id) + .all() + ) + found_ids = {row.id for row in rows} + missing_ids = [kid for kid in body.upstream_key_ids if kid not in found_ids] + items: list[ImportAccountItem] = [ + ImportAccountItem( + upstream_key_id=kid, + source_group_id="", + source_group_name="", + platform=body.default_platform, + status="failed", + message="key not found", + ) + for kid in missing_ids + ] + # 查出来源上游的 Base URL + upstream_base_url = "" + if body.upstream_key_ids: + first_row = rows[0] if rows else None + if first_row: + from app.models.upstream import Upstream as _Up + _u = db.query(_Up).filter(_Up.id == first_row.upstream_id).first() + if _u: + upstream_base_url = _u.base_url + + with _client(website) as c: + for row in rows: + # 先确定平台(失败项也需要记录) + if body.platform_mode == "auto": + platform = _detect_platform( + f"{row.group_name} {row.group_id} {row.key_name}", + body.default_platform, + ) + else: + platform = body.default_platform + + # 幂等校验:已导入过则检查远端账号是否仍存在 + if row.imported_website_id == wid and row.imported_account_id: + old_account_id = row.imported_account_id + exists = c.account_exists(row.imported_account_id) + if exists is True: + items.append(ImportAccountItem( + upstream_key_id=row.id, + source_group_id=row.group_id, + source_group_name=row.group_name, + target_group_id=body.target_group_map.get(row.group_id), + account_id=old_account_id, + account_name=f"{body.account_name_prefix}-{row.group_name or row.group_id}-{row.id}", + platform=platform, + upstream_base_url=upstream_base_url, + status="exists", + message="已导入过,已跳过", + )) + continue + elif exists is False: + # 远端已删除,清空标记后继续创建 + row.imported_website_id = None + row.imported_account_id = None + row.imported_at = None + row.status = "created" + # 继续往下走(不 continue) + else: + # 校验失败,保守跳过 + items.append(ImportAccountItem( + upstream_key_id=row.id, + source_group_id=row.group_id, + source_group_name=row.group_name, + target_group_id=body.target_group_map.get(row.group_id), + account_id=old_account_id, + platform=platform, + status="check_failed", + message="无法校验目标账号状态,已保守跳过", + )) + continue + + if not row.key_value: + items.append(ImportAccountItem( + upstream_key_id=row.id, + source_group_id=row.group_id, + source_group_name=row.group_name, + platform=platform, + status="failed", + message="该 Key 无明文值,无法导入(远端已存在 Key 不会保留明文,请重新创建或手动填入)", + )) + continue + + target_group_id = body.target_group_map.get(row.group_id) + group_ids = [] + numeric_target = _numeric_group_id(target_group_id) + if numeric_target is not None: + group_ids.append(numeric_target) + account_name = f"{body.account_name_prefix}-{row.group_name or row.group_id}-{row.id}" + account_body = { + "name": account_name, + "platform": platform, + "type": "apikey", + "credentials": { + "api_key": row.key_value, + "base_url": upstream_base_url, + }, + "group_ids": group_ids, + "rate_multiplier": 1, + "concurrency": body.concurrency, + "priority": body.priority, + "notes": f"Imported by SmartUp from upstream key #{row.id}", + } + try: + created = c.create_account(account_body) + account_id = c.extract_id(created) + row.imported_website_id = wid + row.imported_account_id = account_id or None + row.imported_at = datetime.now(timezone.utc) + row.status = "imported" + row.error = None + db.commit() + items.append(ImportAccountItem( + upstream_key_id=row.id, + source_group_id=row.group_id, + source_group_name=row.group_name, + target_group_id=target_group_id, + account_id=account_id or None, + account_name=str(created.get("name") or account_name), + platform=platform, + upstream_base_url=upstream_base_url, + status="created", + message="已创建账号", + raw=created, + )) + except Exception as exc: + logger.exception("import upstream key as account failed website=%s key=%s", wid, row.id) + row.status = "import_failed" + row.error = str(exc) + db.commit() + items.append(ImportAccountItem( + upstream_key_id=row.id, + source_group_id=row.group_id, + source_group_name=row.group_name, + target_group_id=target_group_id, + account_name=account_name, + platform=platform, + upstream_base_url=upstream_base_url, + status="failed", + message=str(exc), + )) + created_count = len([item for item in items if item.status == "created"]) + exists_count = len([item for item in items if item.status == "exists"]) + failed_count = len([item for item in items if item.status == "failed"]) + check_failed_count = len([item for item in items if item.status == "check_failed"]) + msg_parts = [] + if created_count: + msg_parts.append(f"新建 {created_count}") + if exists_count: + msg_parts.append(f"已存在 {exists_count}") + if check_failed_count: + msg_parts.append(f"校验失败 {check_failed_count}") + if failed_count: + msg_parts.append(f"失败 {failed_count}") + return ImportAccountsResponse( + success=failed_count == 0 and check_failed_count == 0, + message="、".join(msg_parts) + f" / 共 {len(items)} 个" if msg_parts else f"共处理 {len(items)} 个", + items=items, + ) + + @router.get("/api/group-bindings", response_model=List[BindingResponse]) def list_bindings(db: Session = Depends(get_db), _=Depends(get_current_user)): rows = db.query(WebsiteGroupBinding).order_by(WebsiteGroupBinding.id.desc()).all() diff --git a/backend/app/schemas/upstream.py b/backend/app/schemas/upstream.py index 2c1725f..776a76a 100644 --- a/backend/app/schemas/upstream.py +++ b/backend/app/schemas/upstream.py @@ -1,6 +1,6 @@ from datetime import datetime from typing import Optional, Any -from pydantic import BaseModel +from pydantic import BaseModel, Field class AuthConfigBearer(BaseModel): @@ -89,3 +89,38 @@ class TestResult(BaseModel): success: bool message: str detail: Optional[str] = None + + +class GenerateKeysByGroupsRequest(BaseModel): + group_ids: list[str] = Field(default_factory=list) + name_prefix: str = "SmartUp" + quota: float = Field(default=0, ge=0) + expires_in_days: Optional[int] = Field(default=None, ge=1) + rate_limit_5h: float = Field(default=0, ge=0) + rate_limit_1d: float = Field(default=0, ge=0) + rate_limit_7d: float = Field(default=0, ge=0) + endpoint: str = "/keys" + + +class GeneratedUpstreamKeyResponse(BaseModel): + id: Optional[int] = None + upstream_id: int + group_id: str + group_name: str = "" + key_id: Optional[str] = None + key_name: str + key_value: Optional[str] = None + masked_key: str = "" + status: str + error: Optional[str] = None + imported_website_id: Optional[int] = None + imported_account_id: Optional[str] = None + imported_at: Optional[datetime] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + +class GenerateKeysByGroupsResponse(BaseModel): + success: bool + message: str + items: list[GeneratedUpstreamKeyResponse] diff --git a/backend/app/schemas/website.py b/backend/app/schemas/website.py index b29504b..efbf87e 100644 --- a/backend/app/schemas/website.py +++ b/backend/app/schemas/website.py @@ -122,3 +122,58 @@ class WebsiteSyncLogResponse(BaseModel): status: str message: str created_at: datetime + + +class ImportGroupsRequest(BaseModel): + group_ids: list[str] = Field(default_factory=list) + name_prefix: str = "" + + +class ImportGroupItem(BaseModel): + source_group_id: str + source_group_name: str + target_group_id: Optional[str] = None + target_group_name: str = "" + status: str + message: str = "" + raw: dict[str, Any] = {} + + +class ImportGroupsResponse(BaseModel): + success: bool + message: str + items: list[ImportGroupItem] + + +class SyncImportStatusRequest(BaseModel): + upstream_id: int = Field(default=0) + + +class ImportAccountsRequest(BaseModel): + upstream_key_ids: list[int] = Field(default_factory=list) + target_group_map: dict[str, str] = Field(default_factory=dict) + account_name_prefix: str = "SmartUp" + default_platform: str = "openai" + platform_mode: str = "auto" # "auto" | "manual" + concurrency: int = Field(default=10, ge=1) + priority: int = Field(default=1, ge=0) + + +class ImportAccountItem(BaseModel): + upstream_key_id: int + source_group_id: str + source_group_name: str + target_group_id: Optional[str] = None + account_id: Optional[str] = None + account_name: str = "" + platform: str = "" + upstream_base_url: str = "" + status: str + message: str = "" + raw: dict[str, Any] = {} + + +class ImportAccountsResponse(BaseModel): + success: bool + message: str + items: list[ImportAccountItem] diff --git a/backend/app/services/scheduler.py b/backend/app/services/scheduler.py index d54b115..5bab0e7 100644 --- a/backend/app/services/scheduler.py +++ b/backend/app/services/scheduler.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session from app.database import SessionLocal from app.models.upstream import Upstream +from app.models.upstream_key import UpstreamGeneratedKey from app.models.snapshot import UpstreamRateSnapshot from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot from app.services.snapshot_service import diff_snapshots, prune_snapshots @@ -130,7 +131,20 @@ def _check_upstream(upstream_id: int) -> None: finally: db.close() - # ── Phase 2: notifications (independent sessions) ────────────── + # ── Phase 2: key sync (independent session) ─────────────────── + if snapshot: + captured_at = snapshot.get("captured_at") + if isinstance(captured_at, str): + from datetime import datetime as dt + try: + captured_at = dt.fromisoformat(captured_at) + except Exception: + captured_at = datetime.now(timezone.utc) + elif captured_at is None: + captured_at = datetime.now(timezone.utc) + _sync_upstream_keys(upstream_id, snapshot, captured_at) + + # ── Phase 3: notifications (independent sessions) ────────────── if was_unhealthy: _notify_status(upstream_id, upstream.name, upstream.base_url, "upstream_recovered") @@ -170,6 +184,63 @@ def _notify_rate_changed( db.close() +def _sync_upstream_keys(upstream_id: int, snapshot: dict[str, Any], captured_at: datetime) -> None: + """上游检测成功后同步 SmartUp Key 状态(远端删除/分组删除)。""" + db = SessionLocal() + try: + active_group_ids = set(snapshot.get("groups", {}).keys()) + key_rows = ( + db.query(UpstreamGeneratedKey) + .filter( + UpstreamGeneratedKey.upstream_id == upstream_id, + UpstreamGeneratedKey.key_name.like("SmartUp-%"), + ) + .all() + ) + auth_config = json.loads( + db.query(Upstream).filter(Upstream.id == upstream_id).first().auth_config_json or "{}" + ) + # 用 UpstreamClient 查询远端活跃 Key ID 集合 + remote_key_ids: set[str] | None = None # None=查询失败,set()=查询成功但为空 + try: + upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first() + if upstream: + with UpstreamClient( + base_url=upstream.base_url, + api_prefix=upstream.api_prefix, + auth_type=upstream.auth_type, + auth_config=auth_config, + timeout=float(upstream.timeout_seconds), + ) as client: + client.login() + remote_keys = client.list_api_keys(search="SmartUp", status="active") + remote_key_ids = { + str(k["id"]) for k in remote_keys if k.get("id") + } + except Exception as exc: + logger.warning("sync upstream keys list failed for %s: %s", upstream_id, exc) + + for row in key_rows: + # 1. 分组已不在当前快照中 → 删除本地记录 + if row.group_id not in active_group_ids: + db.delete(row) + logger.info("removed key %s (group %s no longer in snapshot)", row.id, row.group_id) + continue + # 2. 远端查询成功但 key_id 不在列表中 → 删除本地记录 + if row.key_id and remote_key_ids is not None and row.key_id not in remote_key_ids: + db.delete(row) + logger.info("removed key %s (key_id %s gone from remote)", row.id, row.key_id) + continue + # 3. 更新同步时间戳(仅当查询成功且 Key 仍在远端时) + if remote_key_ids is not None and row.key_id in remote_key_ids: + row.updated_at = captured_at + db.commit() + except Exception: + logger.exception("key sync failed for upstream %s", upstream_id) + finally: + db.close() + + def _sync_website_bindings(upstream_id: int, changes: list[dict[str, Any]]) -> None: db = SessionLocal() try: diff --git a/backend/app/services/upstream_client.py b/backend/app/services/upstream_client.py index da7f0cd..866def4 100644 --- a/backend/app/services/upstream_client.py +++ b/backend/app/services/upstream_client.py @@ -62,6 +62,49 @@ def _find_user_id(value: Any) -> str: return "" +def mask_secret(value: Any) -> str: + text = str(value or "") + if not text: + return "" + if len(text) <= 8: + return text[:2] + "****" + text[-2:] if len(text) > 4 else "****" + return text[:4] + "**********" + text[-4:] + + +def _unwrap_data(value: Any) -> Any: + if isinstance(value, dict) and "data" in value and ("code" in value or "message" in value): + return value.get("data") + return value + + +def _extract_id(value: Any) -> str: + if isinstance(value, dict): + for key in ("id", "key_id", "keyId"): + candidate = value.get(key) + if candidate is not None: + return str(candidate) + for key in ("data", "result", "key", "api_key"): + found = _extract_id(value.get(key)) + if found: + return found + return "" + + +def _extract_key_value(value: Any) -> str: + if isinstance(value, str): + return value + if isinstance(value, dict): + for key in ("key", "api_key", "apiKey", "token", "value"): + candidate = value.get(key) + if isinstance(candidate, str) and candidate: + return candidate + for key in ("data", "result", "api_key", "key"): + found = _extract_key_value(value.get(key)) + if found: + return found + return "" + + def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]: def _normalize(lst: list) -> list[dict[str, Any]]: out = [] @@ -360,3 +403,107 @@ class UpstreamClient: return float(value) except (ValueError, TypeError): return None + + def list_api_keys( + self, + search: str = "", + group_id: str | int | None = None, + status: str = "active", + endpoint: str = "/keys", + ) -> list[dict[str, Any]]: + """查询远端上游 Key 列表,支持按名称搜索、分组筛选、状态筛选。""" + params: dict[str, Any] = {} + if search: + params["search"] = search + if group_id is not None: + params["group_id"] = int(group_id) if str(group_id).isdigit() else group_id + if status: + params["status"] = status + url = self._url(endpoint) + resp = self._client.request( + "GET", + url, + params=params if params else None, + headers=self._headers(), + cookies=self._cookies, + ) + resp.raise_for_status() + data = resp.json() + if isinstance(data, list): + return data + if isinstance(data, dict): + # 尝试展开常见的包装结构 + for top_key in ("data", "result", "response"): + val = data.get(top_key) + if isinstance(val, list): + return val + if isinstance(val, dict): + for inner_key in ("items", "keys", "list", "records", "data"): + inner = val.get(inner_key) + if isinstance(inner, list): + return inner + # 顶层本身就是 list-like wrapper + for key in ("items", "keys", "list", "records"): + val = data.get(key) + if isinstance(val, list): + return val + raise UpstreamError(f"unexpected keys response type: {type(data).__name__}") + + def delete_api_key(self, key_id: str, endpoint: str = "/keys") -> None: + """删除远端上游上的一个 Key。""" + self._request("DELETE", f"{endpoint}/{key_id}") + + def find_smartup_group_key( + self, + group_id: str | int, + expected_name: str, + prefix: str = "SmartUp", + ) -> dict[str, Any] | None: + """查找同一上游分组下是否已存在 SmartUp 前缀的 Key。 + + 匹配规则:key_name 等于 expected_name,且以 prefix 开头。 + 返回匹配到的第一个 Key,或 None。 + """ + gid = int(group_id) if str(group_id).isdigit() else group_id + keys = self.list_api_keys(search=prefix, group_id=gid, status="active") + for k in keys: + name = k.get("name") or k.get("key_name") or "" + if name == expected_name: + return k + # 部分后端返回的 name 可能带空格或 trimming + if name.strip() == expected_name.strip(): + return k + return None + + def create_api_key( + self, + name: str, + group_id: str | int, + quota: float = 0, + expires_in_days: int | None = None, + rate_limit_5h: float = 0, + rate_limit_1d: float = 0, + rate_limit_7d: float = 0, + endpoint: str = "/keys", + ) -> dict[str, Any]: + body: dict[str, Any] = { + "name": name, + "group_id": int(group_id) if str(group_id).isdigit() else group_id, + "quota": quota, + "rate_limit_5h": rate_limit_5h, + "rate_limit_1d": rate_limit_1d, + "rate_limit_7d": rate_limit_7d, + } + if expires_in_days: + body["expires_in_days"] = expires_in_days + resp = self._request("POST", endpoint, body) + data = _unwrap_data(resp) + key_value = _extract_key_value(data) + if not key_value: + raise UpstreamError("key create response did not include key") + return { + "id": _extract_id(data), + "key": key_value, + "masked_key": mask_secret(key_value), + "raw": data if isinstance(data, dict) else {"value": data}, + } diff --git a/backend/app/services/website_client.py b/backend/app/services/website_client.py index f84a701..a75f5b2 100644 --- a/backend/app/services/website_client.py +++ b/backend/app/services/website_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from decimal import Decimal, InvalidOperation, ROUND_HALF_UP from typing import Any from urllib.parse import quote @@ -8,11 +9,39 @@ import httpx from app.utils.number import decimal_string +logger = logging.getLogger(__name__) + class WebsiteError(RuntimeError): pass +def _friendly_http_error(exc: httpx.HTTPStatusError) -> str: + """将常见 HTTP 错误转换为中文友好提示,原始信息保留在日志中。""" + status = exc.response.status_code + url = exc.request.url if exc.request else "?" + logger.warning("website_client HTTP %s from %s: %s", status, url, exc) + if status == 401: + return "目标网站认证失败,请检查 Admin API Key / JWT 是否正确" + if status == 403: + return "目标网站权限不足,请检查当前凭证是否有分组管理权限" + if status == 404: + return f"目标网站接口不存在,请检查 API Prefix 和分组接口路径({exc.response.url.path})" + if 500 <= status < 600: + return "目标网站服务异常,请稍后重试" + return f"目标网站返回错误(HTTP {status})" + + +def _friendly_connection_error(exc: Exception) -> str: + """将网络/超时异常转换为中文友好提示。""" + logger.warning("website_client connection error: %s", exc) + if isinstance(exc, httpx.TimeoutException): + return "目标网站请求超时,请检查网络连接和 API 地址是否正确" + if isinstance(exc, httpx.ConnectError): + return "无法连接目标网站,请检查 API 地址和网络连通性" + return f"目标网站通信异常:{exc}" + + def parse_positive_decimal(value: Any) -> Decimal | None: if value is None or value == "": return None @@ -59,6 +88,19 @@ def _unwrap_data(value: Any) -> Any: return value +def _extract_id(value: Any) -> str: + if isinstance(value, dict): + for key in ("id", "account_id", "accountId", "group_id", "groupId"): + candidate = value.get(key) + if candidate is not None: + return str(candidate) + for key in ("data", "result", "account", "group"): + found = _extract_id(value.get(key)) + if found: + return found + return "" + + def normalize_groups(value: Any) -> list[dict[str, Any]]: raw = _unwrap_data(value) if isinstance(raw, dict): @@ -129,24 +171,111 @@ class Sub2ApiWebsiteClient: return headers def _request(self, method: str, path: str, body: Any = None) -> Any: - resp = self._client.request(method, self._url(path), json=body, headers=self._headers()) - resp.raise_for_status() + try: + resp = self._client.request(method, self._url(path), json=body, headers=self._headers()) + resp.raise_for_status() + except httpx.HTTPStatusError as exc: + raise WebsiteError(_friendly_http_error(exc)) from exc + except httpx.TimeoutException as exc: + raise WebsiteError(_friendly_connection_error(exc)) from exc + except httpx.ConnectError as exc: + raise WebsiteError(_friendly_connection_error(exc)) from exc if not resp.content: return None text = resp.text if "application/json" not in resp.headers.get("content-type", "") and text.lstrip().startswith("<"): - raise WebsiteError(f"{method} {path} returned HTML, not JSON") + raise WebsiteError(f"{method} {path} 返回了 HTML,请检查接口地址是否正确") return resp.json() def get_groups(self, endpoint: str = "/groups") -> list[dict[str, Any]]: - errors: list[str] = [] + """拉取分组列表,尝试 endpoint 和 fallback /groups/all。""" + last_error: Exception | None = None + tried_paths: list[str] = [] for path in [endpoint, "/groups/all"]: + tried_paths.append(path) try: return normalize_groups(self._request("GET", path)) + except WebsiteError as exc: + msg = str(exc) + # 认证/权限类错误:直接抛出,不需要尝试 fallback + if "认证失败" in msg or "权限不足" in msg: + raise + # 404/5xx 等路径相关错误,试试另一个路径 + last_error = exc except Exception as exc: - errors.append(f"{path}: {exc}") - raise WebsiteError("; ".join(errors)) + last_error = exc + logger.info("get_groups fallback %s failed: %s", path, exc) + + msg = str(last_error) if last_error else "拉取分组失败" + raise WebsiteError(f"{msg}(尝试接口:{'、'.join(tried_paths)})") def update_group_rate(self, endpoint_template: str, group_id: str, rate: Decimal) -> Any: path = endpoint_template.replace("{id}", quote(group_id, safe="")) return self._request("PUT", path, {"rate_multiplier": float(rate)}) + + def create_group(self, body: dict[str, Any], endpoint: str = "/groups") -> dict[str, Any]: + resp = self._request("POST", endpoint, body) + data = _unwrap_data(resp) + return data if isinstance(data, dict) else {"value": data} + + def create_account(self, body: dict[str, Any], endpoint: str = "/accounts") -> dict[str, Any]: + resp = self._request("POST", endpoint, body) + data = _unwrap_data(resp) + return data if isinstance(data, dict) else {"value": data} + + @staticmethod + def _unwrap_list(value: dict) -> list | None: + """递归展开嵌套的列表包装:data.items、data.data、items、accounts 等。""" + if isinstance(value, list): + return value + if not isinstance(value, dict): + return None + # 先看顶层 + for key in ("items", "accounts", "records", "list", "data"): + v = value.get(key) + if isinstance(v, list): + return v + # 再看 data.items、data.records、data.list 等嵌套 + data_val = value.get("data") + if isinstance(data_val, dict): + for key in ("items", "records", "list", "data", "accounts"): + v = data_val.get(key) + if isinstance(v, list): + return v + return None + + def _get_account_ids(self, endpoint: str = "/accounts") -> set[str] | None: + """拉取远端账号列表。成功返回 ID 集合(可能为空),解析失败返回 None。""" + try: + resp = self._request("GET", endpoint) + except Exception: + logger.warning("account list fetch failed for %s", endpoint, exc_info=True) + return None + items = self._unwrap_list(resp) + if items is None: + logger.warning("account list unexpected format for %s", endpoint) + return None + ids: set[str] = set() + for item in items: + item_id = self.extract_id(item) + if item_id: + ids.add(item_id) + return ids + + def account_exists(self, account_id: str, endpoint: str = "/accounts") -> bool | None: + """检查目标账号是否存在。 + + 优先拉取账号列表判断: + - 列表成功取到 → return account_id in ids(True=存在,False=已删除) + - 列表取不到(None)→ return None(校验失败,不清本地) + 返回 True=存在,False=已删除,None=校验失败。 + """ + ids = self._get_account_ids(endpoint) + if ids is None: + logger.warning("account_exists cannot verify %s: list fetch failed", account_id) + return None + return account_id in ids + + @staticmethod + def extract_id(value: Any) -> str: + return _extract_id(value) diff --git a/backend/test_upstream_key_account_import.py b/backend/test_upstream_key_account_import.py new file mode 100644 index 0000000..e62d54f --- /dev/null +++ b/backend/test_upstream_key_account_import.py @@ -0,0 +1,383 @@ +import json +import sys +from pathlib import Path + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from app.database import Base +from app.models.upstream import Upstream +from app.models.upstream_key import UpstreamGeneratedKey +from app.models.website import Website +from app.routers import websites as websites_router +from app.schemas.website import ImportAccountsRequest + + +@pytest.fixture() +def db_session(): + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(bind=engine) + TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + db = TestingSessionLocal() + try: + yield db + finally: + db.close() + Base.metadata.drop_all(bind=engine) + + +def seed_account_import_rows(db_session): + website = Website( + name="My Sub2API", + site_type="sub2api", + base_url="http://sub2api.local", + api_prefix="/api/v1/admin", + auth_type="api_key", + auth_config_json=json.dumps({"key": "admin-key", "header": "x-api-key"}), + groups_endpoint="/groups", + group_update_endpoint="/groups/{id}", + ) + upstream = Upstream( + name="Packy", + base_url="http://packy.local", + api_prefix="/api/v1", + auth_type="login_password", + auth_config_json="{}", + ) + db_session.add_all([website, upstream]) + db_session.commit() + db_session.refresh(website) + db_session.refresh(upstream) + generated = UpstreamGeneratedKey( + upstream_id=upstream.id, + group_id="vip", + group_name="VIP", + key_id="up-key-id", + key_name="SmartUp-VIP", + key_value="sk-upstream-generated", + masked_key="sk-u...ated", + raw_json="{}", + status="created", + ) + db_session.add(generated) + db_session.commit() + db_session.refresh(generated) + return website, generated + + +def test_import_upstream_key_creates_sub2api_account_management_apikey(monkeypatch, db_session): + website, generated = seed_account_import_rows(db_session) + account_bodies = [] + + class FakeWebsiteClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def create_account(self, body, endpoint="/accounts"): + account_bodies.append((endpoint, body)) + return {"id": 101, "name": body["name"]} + + def account_exists(self, account_id): + return True + + @staticmethod + def extract_id(data): + return str(data.get("id")) + + monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeWebsiteClient) + + response = websites_router.import_upstream_keys_as_accounts( + website.id, + ImportAccountsRequest( + upstream_key_ids=[generated.id], + target_group_map={"vip": "7"}, + account_name_prefix="SmartUp", + default_platform="anthropic", + ), + db_session, + object(), + ) + + assert "新建 1" in response.message + assert len(account_bodies) == 1 + endpoint, body = account_bodies[0] + assert endpoint == "/accounts" + assert body["type"] == "apikey" + assert body["platform"] == "anthropic" + assert body["credentials"]["api_key"] == "sk-upstream-generated" + assert body["credentials"]["base_url"] == "http://packy.local" + assert body["group_ids"] == [7] + assert body["concurrency"] == 10 + assert body["priority"] == 1 + assert body["credentials"]["base_url"] == "http://packy.local" + + db_session.refresh(generated) + assert generated.status == "imported" + assert generated.imported_website_id == website.id + assert generated.imported_account_id == "101" + + +def test_import_upstream_key_idempotent_skips_already_imported(monkeypatch, db_session): + """已导入过的 Key 再次调用不会重复创建。""" + website, generated = seed_account_import_rows(db_session) + generated.imported_website_id = website.id + generated.imported_account_id = "101" + db_session.commit() + + create_called = False + + class FakeWebsiteClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + def __enter__(self): + return self + def __exit__(self, exc_type, exc, tb): + return False + def create_account(self, body, endpoint="/accounts"): + nonlocal create_called + create_called = True + return {"id": 999, "name": body["name"]} + def account_exists(self, account_id): + return True # 模拟远端账号仍存在 + @staticmethod + def extract_id(data): + return str(data.get("id")) + + monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeWebsiteClient) + + response = websites_router.import_upstream_keys_as_accounts( + website.id, + ImportAccountsRequest( + upstream_key_ids=[generated.id], + target_group_map={"vip": "7"}, + account_name_prefix="SmartUp", + default_platform="anthropic", + ), + db_session, + object(), + ) + + # create_account 不应被调用 + assert not create_called, "create_account should NOT be called for already imported key" + # 返回 exists + assert len(response.items) == 1 + assert response.items[0].status == "exists" + assert response.items[0].account_id == "101" + assert response.items[0].message == "已导入过,已跳过" + # success 应为 true(没有 failed) + assert response.success + + +def test_sync_clears_stale_import_mark(monkeypatch, db_session): + """同步接口:远端账号已删除时清除本地导入标记。""" + from app.routers import websites as websites_router + from app.schemas.website import SyncImportStatusRequest, ImportAccountsRequest + + website, generated = seed_account_import_rows(db_session) + generated.imported_website_id = website.id + generated.imported_account_id = "101" + db_session.commit() + + class FakeClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + def __enter__(self): + return self + def __exit__(self, *a): + return False + def account_exists(self, account_id): + return False # 远端返回不存在 + def _get_account_ids(self, endpoint="/accounts"): + return set() + @staticmethod + def extract_id(data): + return str(data.get("id")) + + monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", lambda **kw: FakeClient(**kw)) + + response = websites_router.sync_imported_upstream_keys( + website.id, SyncImportStatusRequest(upstream_id=generated.upstream_id), + db_session, object(), + ) + + assert len(response.items) == 1 + assert response.items[0].status == "stale_cleared" + assert response.items[0].account_id == "101" + db_session.refresh(generated) + assert generated.imported_account_id is None + + +def test_sync_preserves_mark_on_check_failed(monkeypatch, db_session): + """同步接口:校验失败时不清除本地标记。""" + from app.routers import websites as websites_router + from app.schemas.website import SyncImportStatusRequest + + website, generated = seed_account_import_rows(db_session) + generated.imported_website_id = website.id + generated.imported_account_id = "101" + db_session.commit() + + class FakeClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + def __enter__(self): + return self + def __exit__(self, *a): + return False + def account_exists(self, account_id): + return None # 校验失败 + def _get_account_ids(self, endpoint="/accounts"): + return set() + @staticmethod + def extract_id(data): + return str(data.get("id")) + + monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", lambda **kw: FakeClient(**kw)) + + response = websites_router.sync_imported_upstream_keys( + website.id, SyncImportStatusRequest(upstream_id=generated.upstream_id), + db_session, object(), + ) + + assert len(response.items) == 1 + assert response.items[0].status == "check_failed" + db_session.refresh(generated) + assert generated.imported_account_id == "101" # 未被清除 + + +def test_import_rebuilds_when_remote_deleted(monkeypatch, db_session): + """导入接口:远端账号已删除时自动清标记并重新创建。""" + from app.routers import websites as websites_router + from app.schemas.website import ImportAccountsRequest + + website, generated = seed_account_import_rows(db_session) + generated.imported_website_id = website.id + generated.imported_account_id = "101" + db_session.commit() + + class FakeClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + def __enter__(self): + return self + def __exit__(self, *a): + return False + def account_exists(self, account_id): + return False + def create_account(self, body, endpoint="/accounts"): + return {"id": 202, "name": body["name"]} + @staticmethod + def extract_id(data): + return str(data.get("id")) + + monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", lambda **kw: FakeClient(**kw)) + + response = websites_router.import_upstream_keys_as_accounts( + website.id, + ImportAccountsRequest(upstream_key_ids=[generated.id], target_group_map={"vip": "7"}, + account_name_prefix="SmartUp", default_platform="openai"), + db_session, object(), + ) + + assert len(response.items) == 1 + assert response.items[0].status == "created", f"expected created, got {response.items[0].status}" + db_session.refresh(generated) + assert generated.imported_account_id == "202" + + +def test_import_skips_on_check_failed(monkeypatch, db_session): + """导入接口:校验失败时保守跳过,不创建也不清除。""" + from app.routers import websites as websites_router + from app.schemas.website import ImportAccountsRequest + + website, generated = seed_account_import_rows(db_session) + generated.imported_website_id = website.id + generated.imported_account_id = "101" + db_session.commit() + + class FakeClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + def __enter__(self): + return self + def __exit__(self, *a): + return False + def account_exists(self, account_id): + return None + def create_account(self, body, endpoint="/accounts"): + raise RuntimeError("should not be called") + @staticmethod + def extract_id(data): + return str(data.get("id")) + + monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", lambda **kw: FakeClient(**kw)) + + response = websites_router.import_upstream_keys_as_accounts( + website.id, + ImportAccountsRequest(upstream_key_ids=[generated.id], target_group_map={"vip": "7"}, + account_name_prefix="SmartUp", default_platform="openai"), + db_session, object(), + ) + + assert len(response.items) == 1 + assert response.items[0].status == "check_failed" + db_session.refresh(generated) + assert generated.imported_account_id == "101" # 未被清除 + assert not response.success + + +def test_import_upstream_key_with_custom_concurrency_and_priority(monkeypatch, db_session): + website, generated = seed_account_import_rows(db_session) + account_bodies = [] + + class FakeWebsiteClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + def __enter__(self): + return self + def __exit__(self, exc_type, exc, tb): + return False + def create_account(self, body, endpoint="/accounts"): + account_bodies.append((endpoint, body)) + return {"id": 202, "name": body["name"]} + def account_exists(self, account_id): + return True + @staticmethod + def extract_id(data): + return str(data.get("id")) + + monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeWebsiteClient) + + websites_router.import_upstream_keys_as_accounts( + website.id, + ImportAccountsRequest( + upstream_key_ids=[generated.id], + target_group_map={"vip": "7"}, + account_name_prefix="SmartUp", + default_platform="openai", + concurrency=20, + priority=5, + ), + db_session, + object(), + ) + + assert len(account_bodies) == 1 + _, body = account_bodies[0] + assert body["concurrency"] == 20 + assert body["priority"] == 5 + assert body["credentials"]["base_url"] == "http://packy.local" diff --git a/backend/test_upstream_key_sync.py b/backend/test_upstream_key_sync.py new file mode 100644 index 0000000..a5f883f --- /dev/null +++ b/backend/test_upstream_key_sync.py @@ -0,0 +1,469 @@ +"""Tests for upstream key uniquification and sync cleanup.""" +import json +from datetime import datetime, timezone + +import pytest +from sqlalchemy import create_engine, text +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.models.upstream import Upstream +from app.models.website import Website # noqa: F401 — registers table for FK refs +from app.models.upstream_key import UpstreamGeneratedKey + + +@pytest.fixture() +def db_session(): + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(bind=engine) + TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + db = TestingSessionLocal() + try: + yield db + finally: + db.close() + Base.metadata.drop_all(bind=engine) + + +def test_duplicate_cleanup_keeps_latest_only(): + """同一 upstream_id + group_id + key_name 的多条记录只保留最新一条。 + + 使用独立 engine + 全 raw SQL,模拟迁移前的数据库状态。 + """ + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + with engine.begin() as conn: + conn.execute(text(""" + CREATE TABLE upstreams ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(255) NOT NULL, + base_url VARCHAR(512) NOT NULL, + api_prefix VARCHAR(64) DEFAULT '', + auth_type VARCHAR(32), + auth_config_json TEXT DEFAULT '{}', + groups_endpoint VARCHAR(256), + rate_endpoint VARCHAR(256), + enabled BOOLEAN DEFAULT 1, + check_interval_seconds INTEGER DEFAULT 600, + timeout_seconds INTEGER DEFAULT 30, + last_status VARCHAR(32) DEFAULT 'unknown', + last_checked_at DATETIME, + last_error TEXT, + consecutive_failures INTEGER DEFAULT 0, + balance FLOAT, + balance_updated_at DATETIME, + balance_endpoint VARCHAR(256) DEFAULT '', + balance_response_path VARCHAR(256) DEFAULT '', + balance_divisor FLOAT DEFAULT 1.0, + updated_at DATETIME, + created_at DATETIME + ) + """)) + conn.execute(text(""" + CREATE TABLE upstream_generated_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + upstream_id INTEGER NOT NULL, + group_id VARCHAR(255) NOT NULL, + group_name VARCHAR(255) DEFAULT '', + key_id VARCHAR(255), + key_name VARCHAR(255) NOT NULL, + key_value TEXT NOT NULL, + masked_key VARCHAR(255) DEFAULT '', + raw_json TEXT DEFAULT '{}', + status VARCHAR(32) DEFAULT 'created', + error TEXT, + imported_website_id INTEGER, + imported_account_id VARCHAR(255), + imported_at DATETIME, + created_at DATETIME, + updated_at DATETIME + ) + """)) + TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + db = TestingSessionLocal() + try: + now = datetime.now(timezone.utc) + db.execute(text(""" + INSERT INTO upstreams (name, base_url, api_prefix, auth_type, auth_config_json, groups_endpoint, rate_endpoint, created_at, updated_at) + VALUES (:n, :b, :p, :a, :j, :g, :r, :ca, :ua) + """), {"n": "Test", "b": "http://local", "p": "/api/v1", "a": "bearer", + "j": "{}", "g": "/groups", "r": "/rates", "ca": now, "ua": now}) + db.commit() + uid = db.execute(text("SELECT id FROM upstreams LIMIT 1")).scalar() + + # 插入 3 条重复记录 + for kv, ca in [ + ("old-key", datetime(2025, 1, 1, tzinfo=timezone.utc)), + ("middle-key", datetime(2025, 6, 1, tzinfo=timezone.utc)), + ("newest-key", datetime(2025, 12, 1, tzinfo=timezone.utc)), + ]: + db.execute(text(""" + INSERT INTO upstream_generated_keys + (upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at, updated_at) + VALUES (:uid, :gid, :gn, :kn, :kv, :mk, :rj, :st, :ca, :ca) + """), { + "uid": uid, "gid": "vip", "gn": "VIP", + "kn": "SmartUp-Test-VIP", "kv": kv, + "mk": "", "rj": "{}", "st": "created", "ca": ca, + }) + db.commit() + + # 清理:同一组合只保留最新一条(id 最大) + db.execute(text(""" + DELETE FROM upstream_generated_keys + WHERE id NOT IN ( + SELECT MAX(id) FROM upstream_generated_keys + GROUP BY upstream_id, group_id, key_name + ) + """)) + db.commit() + + remaining = db.execute(text("SELECT key_value FROM upstream_generated_keys")).fetchall() + assert len(remaining) == 1, f"expected 1 after cleanup, got {len(remaining)}" + assert remaining[0][0] == "newest-key" + finally: + db.close() + + +def test_migration_backfills_managed_prefix_and_deduplicates(): + """迁移逻辑应回填历史 SmartUp 记录的 managed_prefix 并清理重复。 + + 使用独立 engine(不创建唯一约束),模拟迁移前状态。 + """ + from sqlalchemy import text as _text + + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + with engine.begin() as conn: + conn.execute(_text(""" + CREATE TABLE upstreams ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(255) NOT NULL, base_url VARCHAR(512) NOT NULL, + api_prefix VARCHAR(64) DEFAULT '', auth_type VARCHAR(32), + auth_config_json TEXT DEFAULT '{}', groups_endpoint VARCHAR(256), + rate_endpoint VARCHAR(256), enabled BOOLEAN DEFAULT 1, + check_interval_seconds INTEGER DEFAULT 600, timeout_seconds INTEGER DEFAULT 30, + last_status VARCHAR(32) DEFAULT 'unknown', last_checked_at DATETIME, + last_error TEXT, consecutive_failures INTEGER DEFAULT 0, + balance FLOAT, balance_updated_at DATETIME, + balance_endpoint VARCHAR(256) DEFAULT '', balance_response_path VARCHAR(256) DEFAULT '', + balance_divisor FLOAT DEFAULT 1.0, updated_at DATETIME, created_at DATETIME + ) + """)) + conn.execute(_text(""" + CREATE TABLE upstream_generated_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + upstream_id INTEGER NOT NULL, group_id VARCHAR(255) NOT NULL, + group_name VARCHAR(255) DEFAULT '', key_id VARCHAR(255), + key_name VARCHAR(255) NOT NULL, key_value TEXT NOT NULL, + masked_key VARCHAR(255) DEFAULT '', raw_json TEXT DEFAULT '{}', + managed_prefix VARCHAR(64), status VARCHAR(32) DEFAULT 'created', + error TEXT, imported_website_id INTEGER, imported_account_id VARCHAR(255), + imported_at DATETIME, created_at DATETIME, updated_at DATETIME + ) + """)) + + TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + db = TestingSessionLocal() + try: + now = datetime.now(timezone.utc) + db.execute(_text(""" + INSERT INTO upstreams (name, base_url, api_prefix, auth_type, auth_config_json, groups_endpoint, rate_endpoint, created_at, updated_at) + VALUES (:n, :b, :p, :a, :j, :g, :r, :ca, :ua) + """), {"n": "Old", "b": "http://local", "p": "/api/v1", "a": "bearer", + "j": "{}", "g": "/groups", "r": "/rates", "ca": now, "ua": now}) + db.commit() + uid = db.execute(_text("SELECT id FROM upstreams LIMIT 1")).scalar() + + # 插入两条重复记录(无 managed_prefix,key_name 以 SmartUp 开头) + for kv in ("sk-old", "sk-new"): + db.execute(_text(""" + INSERT INTO upstream_generated_keys + (upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at, updated_at) + VALUES (:uid, :gid, :gn, :kn, :kv, '', '{}', 'created', :ca, :ca) + """), {"uid": uid, "gid": "vip", "gn": "VIP", + "kn": "SmartUp-Old-vip", "kv": kv, "ca": now}) + db.commit() + + # 执行迁移逻辑(与 database.py 中的 SQL 一致) + conn = db.connection() + conn.execute(_text( + "UPDATE upstream_generated_keys SET managed_prefix = 'SmartUp' " + "WHERE managed_prefix IS NULL AND key_name LIKE 'SmartUp-%'" + )) + to_delete = conn.execute(_text(""" + SELECT id FROM upstream_generated_keys + WHERE managed_prefix IS NOT NULL + AND id NOT IN ( + SELECT MAX(id) FROM upstream_generated_keys + WHERE managed_prefix IS NOT NULL + GROUP BY upstream_id, group_id, managed_prefix + ) + """)).fetchall() + for (row_id,) in to_delete: + conn.execute(_text("DELETE FROM upstream_generated_keys WHERE id = :id"), {"id": row_id}) + db.commit() + + remaining = db.execute(_text("SELECT key_value, managed_prefix FROM upstream_generated_keys")).fetchall() + assert len(remaining) == 1, f"expected 1 after migration, got {len(remaining)}" + assert remaining[0][0] == "sk-new" # 保留最新一条 + assert remaining[0][1] == "SmartUp" # 已回填 + finally: + db.close() + + +def test_ensure_group_key_reuses_old_record(db_session, monkeypatch): + """_ensure_group_key 应复用 managed_prefix IS NULL 的旧记录,不新建。""" + from app.routers.upstreams import _ensure_group_key + from app.models.upstream_key import UpstreamGeneratedKey + from app.services.upstream_client import UpstreamClient + from app.schemas.upstream import GenerateKeysByGroupsRequest + + upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1", + auth_type="bearer", auth_config_json="{}", + groups_endpoint="/groups", rate_endpoint="/rates") + db_session.add(upstream) + db_session.commit() + db_session.refresh(upstream) + + # 插入一条旧记录(无 managed_prefix) + db_session.add(UpstreamGeneratedKey( + upstream_id=upstream.id, group_id="vip", group_name="VIP", + key_name="SmartUp-Test-vip", key_value="sk-old", + managed_prefix=None, key_id="remote-999", + )) + db_session.commit() + + # 构造 mock client + class MockClient: + def find_smartup_group_key(self, gid, name, prefix): + return None + def create_api_key(self, name, group_id, **kw): + return {"id": "new-remote", "key": "sk-new-value", "masked_key": "sk-****-lue"} + + group = {"id": "vip", "name": "VIP", "rate_multiplier": 1} + body = GenerateKeysByGroupsRequest( + group_ids=["vip"], name_prefix="SmartUp", + quota=0, endpoint="/keys", + ) + result = _ensure_group_key(db_session, MockClient(), upstream, group, "SmartUp", body) + + assert result.status == "created" + + rows = db_session.query(UpstreamGeneratedKey).filter( + UpstreamGeneratedKey.upstream_id == upstream.id, + UpstreamGeneratedKey.group_id == "vip", + ).all() + assert len(rows) == 1, f"expected 1 record, got {len(rows)}" + assert rows[0].managed_prefix == "SmartUp" + assert rows[0].key_value == "sk-new-value" + + +def test_sync_removes_remote_key_when_list_empty(db_session, monkeypatch): + """同步函数在远端返回空列表时应删除本地 key_id 对应的记录。""" + from app.services import scheduler as sched_mod + from app.models.upstream_key import UpstreamGeneratedKey + from app.services.upstream_client import UpstreamClient + + upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1", + auth_type="bearer", auth_config_json="{}", + groups_endpoint="/groups", rate_endpoint="/rates") + db_session.add(upstream) + db_session.commit() + db_session.refresh(upstream) + + db_session.add(UpstreamGeneratedKey( + upstream_id=upstream.id, group_id="vip", group_name="VIP", + key_name="SmartUp-Test-vip", key_value="sk-vip", + managed_prefix="SmartUp", key_id="remote-key-id", + )) + db_session.commit() + + # mock list_api_keys 返回空列表(查询成功但无 Key) + monkeypatch.setattr(UpstreamClient, "list_api_keys", lambda self, **kw: []) + monkeypatch.setattr(UpstreamClient, "login", lambda self: None) + monkeypatch.setattr(UpstreamClient, "close", lambda self: None) + monkeypatch.setattr(UpstreamClient, "__enter__", lambda self: self) + monkeypatch.setattr(UpstreamClient, "__exit__", lambda self, *a: None) + + # 让 _sync_upstream_keys 使用 db_session 的 bind 引擎 + monkeypatch.setattr(sched_mod, "SessionLocal", + lambda: db_session) + # 阻止 finally 中的 db.close() 影响测试会话 + original_close = db_session.close + monkeypatch.setattr(db_session, "close", lambda: None) + + snapshot = { + "upstream_id": upstream.id, + "groups": {"vip": {"group_id": "vip", "rate": "1"}}, + "captured_at": datetime.now(timezone.utc).isoformat(), + } + + captured_at = datetime.now(timezone.utc) + sched_mod._sync_upstream_keys(upstream.id, snapshot, captured_at) + + monkeypatch.setattr(db_session, "close", original_close) + remaining = db_session.query(UpstreamGeneratedKey).all() + assert len(remaining) == 0, f"expected 0 after sync with empty remote, got {len(remaining)}" + + +def test_migration_function_integration(monkeypatch): + """直接调用 _migrate_upstream_generated_keys() 验证列新增和索引创建。""" + from app.database import _migrate_upstream_generated_keys, engine as real_engine + from sqlalchemy import text as _text + + # 使用独立 engine,避免影响真实数据库 + test_engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + monkeypatch.setattr("app.database.engine", test_engine) + + # 建表(不含 managed_prefix 列,模拟旧版 schema) + with test_engine.begin() as conn: + conn.execute(_text(""" + CREATE TABLE upstream_generated_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + upstream_id INTEGER NOT NULL, group_id VARCHAR(255) NOT NULL, + group_name VARCHAR(255) DEFAULT '', key_id VARCHAR(255), + key_name VARCHAR(255) NOT NULL, key_value TEXT NOT NULL, + masked_key VARCHAR(255) DEFAULT '', raw_json TEXT DEFAULT '{}', + status VARCHAR(32) DEFAULT 'created', error TEXT, + imported_website_id INTEGER, imported_account_id VARCHAR(255), + imported_at DATETIME, created_at DATETIME + ) + """)) + conn.execute(_text(""" + INSERT INTO upstream_generated_keys + (upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at) + VALUES (1, 'vip', 'VIP', 'SmartUp-Old-vip', 'sk-val', '', '{}', 'created', datetime('now')) + """)) + + # 调用迁移函数入口 + _migrate_upstream_generated_keys() + + # 验证 managed_prefix 列已存在且被填充 + inspector = __import__('sqlalchemy', fromlist=['']).inspect(test_engine) + cols = {c["name"] for c in inspector.get_columns("upstream_generated_keys")} + assert "managed_prefix" in cols, "managed_prefix column should exist after migration" + + with test_engine.connect() as conn: + row = conn.execute(_text("SELECT managed_prefix, key_value FROM upstream_generated_keys LIMIT 1")).fetchone() + assert row[0] == "SmartUp", f"expected SmartUp, got {row[0]}" + assert row[1] == "sk-val" + + # 验证唯一索引已创建 + indexes = inspector.get_indexes("upstream_generated_keys") + index_names = {ix["name"] for ix in indexes} + assert "uq_upstream_group_managed" in index_names, "partial unique index should exist" + + monkeypatch.undo() + + +def test_create_twice_only_one_record(db_session): + """同一上游同一分组连续调用两次 ensure,本地只保留一条记录。""" + from app.models.upstream_key import UpstreamGeneratedKey + + upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1", + auth_type="bearer", auth_config_json="{}", + groups_endpoint="/groups", rate_endpoint="/rates") + db_session.add(upstream) + db_session.commit() + db_session.refresh(upstream) + + # 模拟第一次创建 + db_session.add(UpstreamGeneratedKey( + upstream_id=upstream.id, group_id="vip", group_name="VIP", + key_name="SmartUp-Test-VIP", key_value="sk-first", + status="created", + )) + db_session.commit() + + # 模拟第二次调用 upsert(用同一个 key_name 且 status=exists) + existing = db_session.query(UpstreamGeneratedKey).filter( + UpstreamGeneratedKey.upstream_id == upstream.id, + UpstreamGeneratedKey.group_id == "vip", + UpstreamGeneratedKey.key_name == "SmartUp-Test-VIP", + ).first() + if existing: + existing.status = "exists" + existing.updated_at = datetime.now(timezone.utc) + else: + db_session.add(UpstreamGeneratedKey( + upstream_id=upstream.id, group_id="vip", group_name="VIP", + key_name="SmartUp-Test-VIP", key_value="sk-second", + status="exists", + )) + db_session.commit() + + rows = db_session.query(UpstreamGeneratedKey).filter( + UpstreamGeneratedKey.upstream_id == upstream.id, + UpstreamGeneratedKey.group_id == "vip", + ).all() + assert len(rows) == 1 + assert rows[0].status == "exists" + assert rows[0].key_value == "sk-first" # 更新的是原记录,不是新建 + + +def test_sync_removes_gone_group(db_session): + """分组不在最新快照中时,本地对应 Key 记录应被删除。""" + upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1", + auth_type="bearer", auth_config_json="{}", + groups_endpoint="/groups", rate_endpoint="/rates") + db_session.add(upstream) + db_session.commit() + db_session.refresh(upstream) + + db_session.add_all([ + UpstreamGeneratedKey( + upstream_id=upstream.id, group_id="vip", group_name="VIP", + key_name="SmartUp-Test-VIP", key_value="sk-vip", + ), + UpstreamGeneratedKey( + upstream_id=upstream.id, group_id="free", group_name="Free", + key_name="SmartUp-Test-Free", key_value="sk-free", + ), + ]) + db_session.commit() + + # 快照中只有 vip,没有 free + active_group_ids = {"vip"} + for row in db_session.query(UpstreamGeneratedKey).filter( + UpstreamGeneratedKey.upstream_id == upstream.id).all(): + if row.group_id not in active_group_ids: + db_session.delete(row) + db_session.commit() + + remaining = db_session.query(UpstreamGeneratedKey).all() + assert len(remaining) == 1 + assert remaining[0].group_id == "vip" + + +def test_sync_removes_deleted_remote_key(db_session): + """远端 Key 被删除后,本地对应记录应被删除。""" + from app.models.upstream_key import UpstreamGeneratedKey + + upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1", + auth_type="bearer", auth_config_json="{}", + groups_endpoint="/groups", rate_endpoint="/rates") + db_session.add(upstream) + db_session.commit() + db_session.refresh(upstream) + + db_session.add(UpstreamGeneratedKey( + upstream_id=upstream.id, group_id="vip", group_name="VIP", + key_name="SmartUp-Test-VIP", key_value="sk-vip", + key_id="remote-123", + )) + db_session.commit() + + # 模拟远端返回的活跃 key_ids 中没有 remote-123 + remote_key_ids = {"remote-456", "remote-789"} + for row in db_session.query(UpstreamGeneratedKey).filter( + UpstreamGeneratedKey.upstream_id == upstream.id).all(): + if row.key_id and row.key_id not in remote_key_ids: + db_session.delete(row) + db_session.commit() + + remaining = db_session.query(UpstreamGeneratedKey).all() + assert len(remaining) == 0 diff --git a/backend/test_website_client.py b/backend/test_website_client.py index 8fdf050..2b97c08 100644 --- a/backend/test_website_client.py +++ b/backend/test_website_client.py @@ -1,4 +1,15 @@ -from app.services.website_client import normalize_groups +import httpx +import pytest + +from app.services.website_client import ( + WebsiteError, + _friendly_connection_error, + _friendly_http_error, + normalize_groups, +) + + +# ——— normalize_groups ——— def test_normalize_groups_unwraps_sub2api_paginated_response(): @@ -68,3 +79,156 @@ def test_normalize_groups_keeps_plain_dict_mapping_compatibility(): assert [group["id"] for group in groups] == ["free", "paid"] assert groups[0]["rate_multiplier"] == "1" + + +# ——— _get_account_ids / account_exists ——— + + +def test_get_account_ids_flat_list(): + from app.services.website_client import Sub2ApiWebsiteClient + ids = Sub2ApiWebsiteClient._unwrap_list([ + {"id": 1, "name": "a"}, {"id": 2, "name": "b"}, + ]) + assert ids == [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}] + + +def test_get_account_ids_top_level_items(): + from app.services.website_client import Sub2ApiWebsiteClient + ids = Sub2ApiWebsiteClient._unwrap_list({"items": [ + {"id": "k1"}, {"id": "k2"}, + ]}) + assert ids == [{"id": "k1"}, {"id": "k2"}] + + +def test_get_account_ids_nested_data_items(): + from app.services.website_client import Sub2ApiWebsiteClient + ids = Sub2ApiWebsiteClient._unwrap_list({"data": {"items": [ + {"id": "a1", "name": "Alpha"}, + {"id": "a2", "name": "Beta"}, + ]}}) + assert ids == [{"id": "a1", "name": "Alpha"}, {"id": "a2", "name": "Beta"}] + + +def test_get_account_ids_nested_data_empty(): + from app.services.website_client import Sub2ApiWebsiteClient + ids = Sub2ApiWebsiteClient._unwrap_list({"data": {"items": []}}) + assert ids == [] + + +def test_get_account_ids_unexpected_format(): + from app.services.website_client import Sub2ApiWebsiteClient + ids = Sub2ApiWebsiteClient._unwrap_list({"error": "not found"}) + assert ids is None + + +# ——— 友好错误提示 ——— + +def _make_response(status_code: int, path: str = "/groups") -> httpx.Response: + """创建模拟的 httpx.Response 用于错误测试。""" + req = httpx.Request("GET", f"http://target.local/api/v1{path}") + resp = httpx.Response(status_code, request=req) + return resp + + +def test_friendly_http_401(): + resp = _make_response(401) + exc = httpx.HTTPStatusError("401", request=resp.request, response=resp) + msg = _friendly_http_error(exc) + assert "认证失败" in msg + assert "API Key" in msg + assert "http://" not in msg + + +def test_friendly_http_403(): + resp = _make_response(403) + exc = httpx.HTTPStatusError("403", request=resp.request, response=resp) + msg = _friendly_http_error(exc) + assert "权限不足" in msg + + +def test_friendly_http_404(): + resp = _make_response(404, path="/wrong-path") + exc = httpx.HTTPStatusError("404", request=resp.request, response=resp) + msg = _friendly_http_error(exc) + assert "接口不存在" in msg + assert "/wrong-path" in msg + # 不包含完整 URL / MDN 链接 + assert "http://" not in msg + assert "MDN" not in msg + + +def test_friendly_http_500(): + resp = _make_response(502) + exc = httpx.HTTPStatusError("502", request=resp.request, response=resp) + msg = _friendly_http_error(exc) + assert "服务异常" in msg + + +def test_friendly_connect_error(): + exc = httpx.ConnectError("Connection refused") + msg = _friendly_connection_error(exc) + assert "无法连接" in msg + + +def test_friendly_timeout_error(): + exc = httpx.TimeoutException("Timed out") + msg = _friendly_connection_error(exc) + assert "请求超时" in msg + + +# ——— get_groups fallback(通过 mock httpx client 触发真实 _request 错误转换) ——— + +def _mock_httpx_request(status_code: int, path: str = "/groups"): + """返回一个 mock 的 httpx.Client.request,直接抛 HTTPStatusError。""" + def request(self, method, url, **kwargs): + resp = _make_response(status_code, path=path) + raise httpx.HTTPStatusError(f"{status_code} {path}", request=resp.request, response=resp) + return request + + +def test_get_groups_401_returns_friendly_auth_error(monkeypatch): + from app.services.website_client import Sub2ApiWebsiteClient + + monkeypatch.setattr(httpx.Client, "request", _mock_httpx_request(401)) + client = Sub2ApiWebsiteClient("http://target.local", "api/v1", "api_key", {"key": "bad"}) + with pytest.raises(WebsiteError) as excinfo: + client.get_groups("/groups") + msg = str(excinfo.value) + assert "认证失败" in msg + assert "http://" not in msg + assert "MDN" not in msg + + +def test_get_groups_404_fallback_succeeds(monkeypatch): + from app.services.website_client import Sub2ApiWebsiteClient + + call_count = 0 + + def request_fallback(self, method, url, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + resp = _make_response(404) + raise httpx.HTTPStatusError("404", request=resp.request, response=resp) + req = httpx.Request("GET", url) + return httpx.Response(200, json={"data": [{"id": "default", "name": "Default", "rate_multiplier": "1"}]}, request=req) + + monkeypatch.setattr(httpx.Client, "request", request_fallback) + client = Sub2ApiWebsiteClient("http://target.local", "api/v1", "api_key", {"key": "ok"}) + groups = client.get_groups("/groups") + assert len(groups) == 1 + assert groups[0]["id"] == "default" + assert call_count == 2 + + +def test_get_groups_all_404_no_raw_url(monkeypatch): + from app.services.website_client import Sub2ApiWebsiteClient + + monkeypatch.setattr(httpx.Client, "request", _mock_httpx_request(404)) + client = Sub2ApiWebsiteClient("http://target.local", "api/v1", "api_key", {"key": "ok"}) + with pytest.raises(WebsiteError) as excinfo: + client.get_groups("/groups") + msg = str(excinfo.value) + assert "接口不存在" in msg + assert "http://" not in msg + assert "MDN" not in msg diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index 8ee4d10..62d289b 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -1,7 +1,28 @@ import axios from 'axios' import axiosRetry from 'axios-retry' import router from '@/router' -import { authStorageKeys } from '@/authStorage' + +/** 标记是否正在处理 401,防多个并发 */ +let isHandlingUnauthorized = false + +/** 统一的 401 处理:清登录态 + 提示 + 跳转 /login */ +async function handleUnauthorized() { + if (isHandlingUnauthorized) return + isHandlingUnauthorized = true + try { + const { useAuthStore } = await import('@/stores/auth') + useAuthStore().clear() + + const { ElMessage } = await import('element-plus') + ElMessage.warning('登录已过期,请重新登录') + + if (router.currentRoute.value.path !== '/login') { + await router.replace('/login') + } + } finally { + isHandlingUnauthorized = false + } +} export const api = axios.create({ baseURL: '/', @@ -27,10 +48,13 @@ axiosRetry(api, { api.interceptors.response.use( (r) => r, (err) => { + // 跳过登录接口的 401(密码错误等正常登录失败场景) + const requestPath = new URL(err.config?.url || '', window.location.origin).pathname + if (err.response?.status === 401 && requestPath === '/api/auth/login') { + return Promise.reject(err) + } if (err.response?.status === 401) { - localStorage.removeItem(authStorageKeys.token) - localStorage.removeItem(authStorageKeys.email) - router.push('/login') + void handleUnauthorized() } return Promise.reject(err) } @@ -84,6 +108,34 @@ export interface UpstreamForm { balance_divisor: number } +export interface GeneratedUpstreamKey { + id: number | null + upstream_id: number + group_id: string + group_name: string + key_id: string | null + key_name: string + key_value: string | null + masked_key: string + status: string + error: string | null + imported_website_id: number | null + imported_account_id: string | null + imported_at: string | null + created_at: string | null +} + +export interface GenerateKeysByGroupsForm { + group_ids: string[] + name_prefix: string + quota: number + expires_in_days?: number | null + rate_limit_5h: number + rate_limit_1d: number + rate_limit_7d: number + endpoint: string +} + export const upstreamsApi = { list: () => api.get('/api/upstreams'), create: (data: UpstreamForm) => api.post('/api/upstreams', data), @@ -91,6 +143,9 @@ export const upstreamsApi = { delete: (id: number) => api.delete(`/api/upstreams/${id}`), test: (id: number) => api.post<{ success: boolean; message: string; detail?: string }>(`/api/upstreams/${id}/test`), checkNow: (id: number) => api.post<{ success: boolean; message: string }>(`/api/upstreams/${id}/check-now`), + generatedKeys: (id: number) => api.get(`/api/upstreams/${id}/generated-keys`), + generateKeysByGroups: (id: number, data: GenerateKeysByGroupsForm) => + api.post<{ success: boolean; message: string; items: GeneratedUpstreamKey[] }>(`/api/upstreams/${id}/keys/generate-by-groups`, data), latestSnapshot: (id: number) => api.get(`/api/upstreams/${id}/snapshots/latest`), listSnapshots: (id: number, limit = 20, offset = 0) => api.get(`/api/upstreams/${id}/snapshots`, { params: { limit, offset } }), @@ -185,6 +240,30 @@ export interface WebsiteSyncLog { created_at: string } +export interface ImportGroupItem { + source_group_id: string + source_group_name: string + target_group_id: string | null + target_group_name: string + status: string + message: string + raw: Record +} + +export interface ImportAccountItem { + upstream_key_id: number + source_group_id: string + source_group_name: string + target_group_id: string | null + account_id: string | null + account_name: string + platform: string + upstream_base_url: string + status: string + message: string + raw: Record +} + export const websitesApi = { list: () => api.get('/api/websites'), create: (data: WebsiteForm) => api.post('/api/websites', data), @@ -192,6 +271,19 @@ export const websitesApi = { delete: (id: number) => api.delete(`/api/websites/${id}`), test: (id: number) => api.post<{ success: boolean; message: string; detail?: string }>(`/api/websites/${id}/test`), groups: (id: number) => api.get(`/api/websites/${id}/groups`), + importGroupsFromUpstream: (id: number, upstreamId: number, data: { group_ids: string[]; name_prefix: string }) => + api.post<{ success: boolean; message: string; items: ImportGroupItem[] }>(`/api/websites/${id}/groups/import-from-upstream/${upstreamId}`, data), + syncImportedUpstreamKeys: (id: number, data: { upstream_id: number }) => + api.post<{ success: boolean; message: string; items: ImportAccountItem[] }>(`/api/websites/${id}/accounts/sync-imported-upstream-keys`, data), + importAccountsFromUpstreamKeys: (id: number, data: { + upstream_key_ids: number[] + target_group_map: Record + account_name_prefix: string + default_platform: string + platform_mode?: string + concurrency?: number + priority?: number + }) => api.post<{ success: boolean; message: string; items: ImportAccountItem[] }>(`/api/websites/${id}/accounts/import-upstream-keys`, data), listBindings: () => api.get('/api/group-bindings'), createBinding: (data: GroupBindingForm) => api.post('/api/group-bindings', data), updateBinding: (id: number, data: Partial) => api.put(`/api/group-bindings/${id}`, data), diff --git a/frontend/src/views/Upstreams.vue b/frontend/src/views/Upstreams.vue index a668448..85c5cd3 100644 --- a/frontend/src/views/Upstreams.vue +++ b/frontend/src/views/Upstreams.vue @@ -127,6 +127,9 @@ 测试 检测 + + + 详情 @@ -374,6 +377,26 @@ {{ detailUpstream.last_error }} +
+ + 已创建 Key + 最近 {{ generatedKeys.length }} 条 +
+ + + + + + + + + + +
检测历史 @@ -440,6 +463,57 @@
+ + + + + + + + + + + + + + + + + + + + + + + + + + + 设置过期时间 + +
+
操作结果
+ + + + + + + + + + +
+ +
+ ([]) @@ -580,11 +654,30 @@ function handlePlatformChange(val: string) { const detailVisible = ref(false) const detailUpstream = ref(null) const snapshots = ref([]) +const generatedKeys = ref([]) const snapshotLoading = ref(false) +const keysLoading = ref(false) const expandedId = ref(null) const snapshotOffset = ref(0) const snapshotLimit = 20 +const keyDialogVisible = ref(false) +const keyTarget = ref(null) +const keyGroupOptions = ref([]) +const generatingKeys = ref(false) +const keyResults = ref([]) +const useKeyExpiry = ref(false) +const keyExpiresDays = ref(30) +const keyForm = ref({ + group_ids: [] as string[], + name_prefix: 'SmartUp', + quota: 0, + rate_limit_5h: 0, + rate_limit_1d: 0, + rate_limit_7d: 0, + endpoint: '/keys', +}) + const metrics = computed(() => ({ total: list.value.length, healthy: list.value.filter((item) => item.last_status === 'healthy').length, @@ -641,6 +734,8 @@ function shrinkError(value: string) { return value.length > 40 ? `${value.slice(0, 40)}…` : value } +const keyStatusLabel = (s: string) => ({ created: '已创建', imported: '已导入', import_failed: '导入失败', failed: '失败' }[s] || s) + async function loadList() { tableLoading.value = true try { @@ -735,13 +830,26 @@ async function checkNow(row: any) { function openDetail(row: UpstreamData) { detailUpstream.value = row snapshots.value = [] + generatedKeys.value = [] snapshotOffset.value = 0 expandedId.value = null detailVisible.value = true } +async function loadGeneratedKeys() { + if (!detailUpstream.value) return + keysLoading.value = true + try { + const res = await upstreamsApi.generatedKeys(detailUpstream.value.id) + generatedKeys.value = res.data + } finally { + keysLoading.value = false + } +} + async function loadSnapshots() { if (!detailUpstream.value) return + loadGeneratedKeys() snapshotLoading.value = true try { const res = await upstreamsApi.listSnapshots(detailUpstream.value.id, snapshotLimit, snapshotOffset.value) @@ -756,6 +864,48 @@ async function loadSnapshots() { } } +async function openKeyGenerate(row: UpstreamData) { + keyTarget.value = row + keyResults.value = [] + keyForm.value = { + group_ids: [], + name_prefix: 'SmartUp', + quota: 0, + rate_limit_5h: 0, + rate_limit_1d: 0, + rate_limit_7d: 0, + endpoint: '/keys', + } + useKeyExpiry.value = false + keyExpiresDays.value = 30 + try { + const res = await upstreamsApi.latestSnapshot(row.id) + keyGroupOptions.value = Object.values(res.data.snapshot?.groups || {}) + } catch { + keyGroupOptions.value = [] + ElMessage.warning('未找到快照,将由后端实时拉取分组') + } + keyDialogVisible.value = true +} + +async function generateKeys() { + if (!keyTarget.value) return + generatingKeys.value = true + try { + const res = await upstreamsApi.generateKeysByGroups(keyTarget.value.id, { + ...keyForm.value, + expires_in_days: useKeyExpiry.value ? keyExpiresDays.value : null, + }) + keyResults.value = res.data.items + ElMessage[res.data.success ? 'success' : 'warning'](res.data.message) + if (detailUpstream.value?.id === keyTarget.value.id) await loadGeneratedKeys() + } catch (e: any) { + ElMessage.error(e.response?.data?.detail || '创建 Key 失败') + } finally { + generatingKeys.value = false + } +} + function toggleExpand(snap: any) { expandedId.value = expandedId.value === snap.id ? null : snap.id } diff --git a/frontend/src/views/Websites.vue b/frontend/src/views/Websites.vue index a5efdb5..817d509 100644 --- a/frontend/src/views/Websites.vue +++ b/frontend/src/views/Websites.vue @@ -14,7 +14,11 @@
网站
- 刷新 +
+ 导入上游分组 + 导入为账号管理账号 + 刷新 +
@@ -44,34 +48,43 @@ - + @@ -229,6 +242,188 @@ 保存 + + + + + + + + + + + + + + + + + + + + + + +
+
导入结果
+ + + + + + + + + +
+ +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 刷新导入状态 + + + 已校验 {{ importSyncStatus.total }} 个,清除 {{ importSyncStatus.cleared }} 个失效标记 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ 全选可导入 Key + 清空 +
+ + + +
+
+
目标分组映射
+
+ {{ group.group_name || group.group_id }} + + + +
+
+
+
+
创建结果
+ + + + + + + + + + + + + + + +
+ +
@@ -237,13 +432,16 @@ import { computed, onMounted, ref } from 'vue' import { ElMessage, ElMessageBox } from 'element-plus' import type { FormInstance } from 'element-plus' import dayjs from 'dayjs' -import { Delete, Edit, Plus, Grid, Connection, Link } from '@element-plus/icons-vue' +import { ArrowDown, Delete, Edit, Plus, Grid, Connection, Link, Upload, Key, Refresh } from '@element-plus/icons-vue' import { upstreamsApi, websitesApi, type BindingSourceGroup, + type GeneratedUpstreamKey, type GroupBindingData, type GroupBindingForm, + type ImportAccountItem, + type ImportGroupItem, type UpstreamData, type WebsiteData, type WebsiteForm, @@ -259,16 +457,25 @@ const bindingWebsiteGroups = ref([]) const bindings = ref<(GroupBindingData & { _syncing?: boolean })[]>([]) const logs = ref([]) const snapshotsByUpstream = ref>({}) +const importTargetGroups = ref([]) +const importGeneratedKeys = ref([]) const websiteLoading = ref(false) const groupsLoading = ref(false) const bindingLoading = ref(false) const logLoading = ref(false) +const importingGroups = ref(false) +const importingAccounts = ref(false) +const generatedKeyLoading = ref(false) +const importSyncStatus = ref<{ total: number; cleared: number; failed: number } | null>(null) +const syncingImportStatus = ref(false) const statusLabel = (s: string) => ({ healthy: '健康', unhealthy: '异常', unknown: '未知' }[s] || s) const algorithmLabel = (s: string) => ({ max_plus_percent: '最高倍率', average_plus_percent: '平均倍率', min_plus_percent: '最低倍率' }[s] || s) const toUTC = (t: string) => /[Z+\-]\d*$/.test(t.trim()) ? t : t + 'Z' const fmtTime = (t: string) => dayjs(toUTC(t)).format('MM-DD HH:mm:ss') +const sourceGroupId = (group: any) => String(group?.group_id || group?.id || group?.name || '') +const sourceGroupName = (group: any) => String(group?.group_name || group?.name || sourceGroupId(group)) function defaultWebsiteForm(): WebsiteForm { return { @@ -318,6 +525,29 @@ const bindingRules = { target_group_id: [{ required: true, message: '请选择目标分组', trigger: 'change' }], } +const importGroupsDialog = ref(false) +const importGroupsForm = ref({ + website_id: 0, + upstream_id: 0, + group_ids: [] as string[], + name_prefix: '', +}) +const importGroupResults = ref([]) + +const importAccountsDialog = ref(false) +const importAccountsForm = ref({ + website_id: 0, + upstream_id: 0, + upstream_key_ids: [] as number[], + target_group_map: {} as Record, + account_name_prefix: 'SmartUp', + default_platform: 'openai', + platform_mode: 'auto', + concurrency: 10, + priority: 1, +}) +const importAccountResults = ref([]) + const upstreamGroupOptions = computed(() => { const rows: Array<{ key: string; label: string; source: BindingSourceGroup }> = [] for (const upstream of upstreams.value) { @@ -339,6 +569,31 @@ const upstreamGroupOptions = computed(() => { return rows }) +const importSourceGroups = computed(() => snapshotsByUpstream.value[importGroupsForm.value.upstream_id] || []) + +function isImportableGeneratedKey(item: GeneratedUpstreamKey) { + return item.id !== null + && item.status !== 'failed' + && !(item.imported_website_id === importAccountsForm.value.website_id && item.imported_account_id) +} + +const importableGeneratedKeys = computed(() => + importGeneratedKeys.value.filter(isImportableGeneratedKey), +) + +const selectedAccountGroups = computed(() => { + const selected = new Set(importAccountsForm.value.upstream_key_ids) + const rows = importableGeneratedKeys.value.filter((item) => item.id !== null && selected.has(item.id)) + const seen = new Set() + const groups: Array<{ group_id: string; group_name: string }> = [] + for (const row of rows) { + if (seen.has(row.group_id)) continue + seen.add(row.group_id) + groups.push({ group_id: row.group_id, group_name: row.group_name }) + } + return groups +}) + async function loadWebsites() { websiteLoading.value = true try { @@ -391,6 +646,19 @@ async function loadBindingWebsiteGroups(websiteId: number) { } } +async function loadImportTargetGroups(websiteId: number) { + if (!websiteId) { + importTargetGroups.value = [] + return + } + try { + const res = await websitesApi.groups(websiteId) + importTargetGroups.value = res.data + } catch { + importTargetGroups.value = [] + } +} + async function loadBindings() { bindingLoading.value = true try { @@ -411,6 +679,50 @@ async function loadLogs() { } } +async function syncImportStatus() { + const websiteId = importAccountsForm.value.website_id + const upstreamId = importAccountsForm.value.upstream_id + if (!websiteId || !upstreamId) return + syncingImportStatus.value = true + try { + const res = await websitesApi.syncImportedUpstreamKeys(websiteId, { upstream_id: upstreamId }) + // 校验请求完成时表单未切换 + if (importAccountsForm.value.website_id !== websiteId || importAccountsForm.value.upstream_id !== upstreamId) return + const items = res.data.items + importSyncStatus.value = { + total: items.length, + cleared: items.filter(i => i.status === 'stale_cleared').length, + failed: items.filter(i => i.status === 'check_failed').length, + } + if (importSyncStatus.value.cleared > 0) { + ElMessage.success(`已清除 ${importSyncStatus.value.cleared} 个失效导入标记`) + } + if (importAccountsForm.value.upstream_id === upstreamId) { + await loadImportGeneratedKeys(upstreamId) + } + } catch (e: any) { + ElMessage.error(e.response?.data?.detail || '同步导入状态失败') + } finally { + syncingImportStatus.value = false + } +} + +async function loadImportGeneratedKeys(upstreamId: number) { + importGeneratedKeys.value = [] + if (!upstreamId) return + generatedKeyLoading.value = true + const frozenId = upstreamId + try { + const res = await upstreamsApi.generatedKeys(frozenId) + if (importAccountsForm.value.upstream_id !== frozenId) return // 已切换到其他上游 + importGeneratedKeys.value = res.data + } catch (e: any) { + ElMessage.error(e.response?.data?.detail || '加载上游 Key 失败') + } finally { + generatedKeyLoading.value = false + } +} + async function loadAll() { await Promise.all([loadWebsites(), loadUpstreamGroups(), loadBindings(), loadLogs()]) if (selectedWebsite.value) await loadWebsiteGroups() @@ -487,6 +799,17 @@ async function testWebsite(row: WebsiteData & { _testing?: boolean }) { } } +/** 处理「更多」下拉菜单中的操作 */ +function handleMoreAction(cmd: string, row: WebsiteData & { _testing?: boolean }) { + switch (cmd) { + case 'test': testWebsite(row); break + case 'binding': openBindingCreate(row); break + case 'importGroups': openImportGroups(row); break + case 'importAccounts': openImportAccounts(row); break + case 'delete': deleteWebsite(row); break + } +} + async function deleteWebsite(row: WebsiteData) { try { await ElMessageBox.confirm(`确认删除网站 "${row.name}"?`, '删除确认', { type: 'warning' }) @@ -599,6 +922,197 @@ async function deleteBinding(row: GroupBindingData) { } catch {} } +async function openImportGroups(site?: WebsiteData | null) { + if (upstreams.value.length === 0) await loadUpstreamGroups() + const target = site || selectedWebsite.value || websites.value[0] + importGroupsForm.value = { + website_id: target?.id || 0, + upstream_id: upstreams.value[0]?.id || 0, + group_ids: [], + name_prefix: '', + } + importGroupResults.value = [] + importGroupsDialog.value = true +} + +async function submitImportGroups() { + if (!importGroupsForm.value.website_id || !importGroupsForm.value.upstream_id) { + ElMessage.error('请选择目标网站和来源上游') + return + } + importingGroups.value = true + try { + const res = await websitesApi.importGroupsFromUpstream( + importGroupsForm.value.website_id, + importGroupsForm.value.upstream_id, + { + group_ids: importGroupsForm.value.group_ids, + name_prefix: importGroupsForm.value.name_prefix, + }, + ) + importGroupResults.value = res.data.items + ElMessage[res.data.success ? 'success' : 'warning'](res.data.message) + if (selectedWebsite.value?.id === importGroupsForm.value.website_id) await loadWebsiteGroups() + await loadImportTargetGroups(importGroupsForm.value.website_id) + } catch (e: any) { + ElMessage.error(e.response?.data?.detail || '导入上游分组失败') + } finally { + importingGroups.value = false + } +} + +async function openImportAccounts(site?: WebsiteData | null) { + if (upstreams.value.length === 0) await loadUpstreamGroups() + const target = site || selectedWebsite.value || websites.value[0] + importAccountsForm.value = { + website_id: target?.id || 0, + upstream_id: upstreams.value[0]?.id || 0, + upstream_key_ids: [], + target_group_map: {}, + account_name_prefix: 'SmartUp', + default_platform: 'openai', + platform_mode: 'auto', + concurrency: 10, + priority: 1, + } + importAccountResults.value = [] + importSyncStatus.value = null + await Promise.all([ + loadImportTargetGroups(importAccountsForm.value.website_id), + loadImportGeneratedKeys(importAccountsForm.value.upstream_id), + ]) + // 打开弹窗后自动同步导入状态(校验远端账号是否仍存在) + await syncImportStatus() + await loadImportGeneratedKeys(importAccountsForm.value.upstream_id) + importAccountsDialog.value = true +} + +async function onImportAccountWebsiteChange(value: number) { + importAccountsForm.value.target_group_map = {} + await loadImportTargetGroups(value) + await syncImportStatus() + await loadImportGeneratedKeys(importAccountsForm.value.upstream_id) +} + +async function onImportAccountUpstreamChange(value: number) { + importAccountsForm.value.upstream_key_ids = [] + importAccountsForm.value.target_group_map = {} + importAccountResults.value = [] + await loadImportGeneratedKeys(value) + await syncImportStatus() + await loadImportGeneratedKeys(value) +} + +function onPlatformModeChange(value: string) { + if (value === 'auto') { + importAccountsForm.value.default_platform = 'openai' + } +} + +function detectPlatform(item: { group_name?: string; group_id?: string; key_name?: string }) { + const text = `${item.group_name || ''} ${item.group_id || ''} ${item.key_name || ''}`.toLowerCase() + if (text.includes('claude') || text.includes('anthropic')) return 'Anthropic' + if (text.includes('gemini')) return 'Gemini' + if (text.includes('antigravity')) return 'Antigravity' + return 'OpenAI 兼容' +} + +function platformLabel(platform: string) { + const map: Record = { + openai: 'OpenAI 兼容', + anthropic: 'Anthropic', + gemini: 'Gemini', + antigravity: 'Antigravity', + } + return map[platform] || platform || '—' +} + +function normalizeGroupName(name: string) { + return String(name || '') + .toLowerCase() + .replace(/^smartup[-_\s]*/i, '') + .replace(/^ai\d+pro/i, '') + .replace(/[||]/g, ' ') + .replace(/\s+/g, '') + .trim() +} + +function findTargetGroupForSource(sourceName: string, sourceId: string) { + const sourceNorm = normalizeGroupName(sourceName || sourceId) + if (!sourceNorm) return '' + + const exact = importTargetGroups.value.find(g => + normalizeGroupName(g.name) === sourceNorm + ) + if (exact) return exact.id + + const fuzzy = importTargetGroups.value.find(g => { + const targetNorm = normalizeGroupName(g.name) + return targetNorm.includes(sourceNorm) || sourceNorm.includes(targetNorm) + }) + return fuzzy?.id || '' +} + +function autoFillAccountTargetGroups() { + const selected = new Set(importAccountsForm.value.upstream_key_ids) + const keys = importableGeneratedKeys.value.filter(item => item.id !== null && selected.has(item.id)) + const nextMap = { ...importAccountsForm.value.target_group_map } + + for (const item of keys) { + if (!item.id || !selected.has(item.id)) continue + if (nextMap[item.group_id]) continue + + const targetId = findTargetGroupForSource(item.group_name || item.group_id, item.group_id) + if (targetId) nextMap[item.group_id] = targetId + } + + importAccountsForm.value.target_group_map = nextMap +} + +function selectAllImportableKeys() { + const keys = importableGeneratedKeys.value + importAccountsForm.value.upstream_key_ids = keys.map(item => item.id!) + autoFillAccountTargetGroups() + + const matched = Object.keys(importAccountsForm.value.target_group_map).length + ElMessage.success(`已选择 ${keys.length} 个 Key,自动匹配 ${matched} 个目标分组`) +} + +function clearImportAccountSelection() { + importAccountsForm.value.upstream_key_ids = [] + importAccountsForm.value.target_group_map = {} +} + +async function submitImportAccounts() { + if (!importAccountsForm.value.website_id || !importAccountsForm.value.upstream_id) { + ElMessage.error('请选择目标网站和来源上游') + return + } + if (importAccountsForm.value.upstream_key_ids.length === 0) { + ElMessage.error('请选择要导入的上游 Key') + return + } + importingAccounts.value = true + try { + const res = await websitesApi.importAccountsFromUpstreamKeys(importAccountsForm.value.website_id, { + upstream_key_ids: importAccountsForm.value.upstream_key_ids, + target_group_map: importAccountsForm.value.target_group_map, + account_name_prefix: importAccountsForm.value.account_name_prefix, + default_platform: importAccountsForm.value.default_platform, + platform_mode: importAccountsForm.value.platform_mode, + concurrency: importAccountsForm.value.concurrency, + priority: importAccountsForm.value.priority, + }) + importAccountResults.value = res.data.items + ElMessage[res.data.success ? 'success' : 'warning'](res.data.message) + await loadImportGeneratedKeys(importAccountsForm.value.upstream_id) + } catch (e: any) { + ElMessage.error(e.response?.data?.detail || '创建账号管理账号失败') + } finally { + importingAccounts.value = false + } +} + onMounted(loadAll) @@ -624,6 +1138,13 @@ onMounted(loadAll) } .panel-title { font-size: 14px; font-weight: 600; color: var(--text-primary); } .panel-sub { font-size: 12px; color: var(--text-muted); margin-top: 2px; } +.panel-actions { + display: flex; + align-items: center; + flex-wrap: wrap; + justify-content: flex-end; + gap: 4px; +} .content-grid { display: grid; grid-template-columns: 1fr; @@ -639,7 +1160,7 @@ onMounted(loadAll) align-items: center; justify-content: flex-end; flex-wrap: nowrap; - gap: 2px; + gap: 6px; min-width: 0; } .action-row .el-button.is-circle { @@ -647,6 +1168,36 @@ onMounted(loadAll) height: 26px; margin-left: 0; } +.action-row .btn-edit { + color: var(--text-primary); + font-weight: 500; + gap: 3px; + padding: 0 6px; + white-space: nowrap; + flex-shrink: 0; +} +.action-row .btn-edit-icon { + font-size: 13px; +} +.action-row .btn-edit:hover { + color: var(--el-color-primary); + background: var(--el-color-primary-light-9); + border-radius: 4px; +} +.action-row .btn-more { + color: var(--text-muted); + font-size: 12px; + padding: 0 4px; +} +.action-row .btn-more:hover { + color: var(--el-color-primary); +} +.action-row .btn-more .el-icon--right { + margin-left: 1px; +} +.btn-more-delete { + color: var(--el-color-danger); +} .binding-actions { display: flex; @@ -698,6 +1249,41 @@ onMounted(loadAll) color: var(--text-secondary); font-size: 13px; } +.dialog-note { margin-bottom: 12px; } +.result-panel { + margin-top: 14px; + border-top: 1px solid var(--border-color); + padding-top: 12px; +} +.result-title { + margin-bottom: 8px; + font-size: 13px; + font-weight: 600; + color: var(--text-primary); +} +.mapping-panel { + border: 1px solid var(--border-color); + border-radius: 8px; + padding: 12px; +} +.mapping-row { + min-height: 38px; + display: flex; + align-items: center; + justify-content: space-between; + gap: 12px; +} +.mapping-row + .mapping-row { + border-top: 1px solid var(--border-color); +} +.mapping-label { + min-width: 0; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + color: var(--text-secondary); + font-size: 13px; +} @media (min-width: 1024px) { .content-grid { grid-template-columns: minmax(0, 1fr) minmax(0, 1fr); } @@ -714,5 +1300,11 @@ onMounted(loadAll) flex-direction: column; } .binding-actions { width: 100%; justify-content: flex-end; } + .mapping-row { + align-items: stretch; + flex-direction: column; + padding: 8px 0; + } + .mapping-row .el-select { width: 100% !important; } }