feat: 上游 Key 唯一化、分组导入跳过、账号导入平台识别&远端校验&base_url 注入
- 上游 Key 命名改为 {prefix}-{upstream.id}-{safe_group_name}-{group_id}
- 唯一约束 (upstream_id, group_id, managed_prefix) 加 managed_prefix 列
- 上游检测成功时同步 Key 状态,远端已删/分组已删自动清理
- 重复分组导入跳过,目标网站已存在同名分组返回 exists
- 账号导入平台自动识别(auto/manual 模式)
- 全选可导入 Key 按钮 + 目标分组自动匹配
- 导入幂等:已导入过的 Key 校验远端账号,不存在则重建
- 新增同步接口 POST /sync-imported-upstream-keys
- account_exists() 通过拉取账号列表判断,避免 404 误判
- credentials.base_url 注入来源上游地址,避免 401
- 前端导入弹窗自动同步+刷新按钮+并发/优先级设置
- 新增 12 个测试覆盖同步、幂等、远端删除、校验失败路径
This commit is contained in:
+58
-1
@@ -26,10 +26,11 @@ def get_db():
|
||||
def init_db():
|
||||
"""Create all tables."""
|
||||
# import models so SQLAlchemy registers them
|
||||
from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website, revoked_token # noqa: F401
|
||||
from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website, revoked_token, upstream_key # noqa: F401
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_migrate_custom_pages()
|
||||
_migrate_upstreams()
|
||||
_migrate_upstream_generated_keys()
|
||||
|
||||
|
||||
def _migrate_custom_pages():
|
||||
@@ -87,3 +88,59 @@ def _migrate_upstreams():
|
||||
if "balance_divisor" not in columns:
|
||||
conn.execute(text("ALTER TABLE upstreams ADD COLUMN balance_divisor FLOAT NOT NULL DEFAULT 1.0"))
|
||||
|
||||
|
||||
def _migrate_upstream_generated_keys():
|
||||
"""Apply SQLite-safe migrations to the generated upstream keys table."""
|
||||
inspector = inspect(engine)
|
||||
if "upstream_generated_keys" not in inspector.get_table_names():
|
||||
return
|
||||
columns = {col["name"] for col in inspector.get_columns("upstream_generated_keys")}
|
||||
with engine.begin() as conn:
|
||||
if "imported_website_id" not in columns:
|
||||
conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_website_id INTEGER"))
|
||||
if "imported_account_id" not in columns:
|
||||
conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_account_id VARCHAR(255)"))
|
||||
if "imported_at" not in columns:
|
||||
conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_at DATETIME"))
|
||||
if "updated_at" not in columns:
|
||||
conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN updated_at DATETIME"))
|
||||
conn.execute(text("UPDATE upstream_generated_keys SET updated_at = created_at WHERE updated_at IS NULL"))
|
||||
if "managed_prefix" not in columns:
|
||||
conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN managed_prefix VARCHAR(64)"))
|
||||
|
||||
# ——— 历史数据迁移:回填 managed_prefix + 清理重复 ———
|
||||
with engine.begin() as conn:
|
||||
# 1. 回填:key_name 以 SmartUp- 开头的旧记录设置 managed_prefix = 'SmartUp'
|
||||
conn.execute(text(
|
||||
"UPDATE upstream_generated_keys SET managed_prefix = 'SmartUp' "
|
||||
"WHERE managed_prefix IS NULL AND key_name LIKE 'SmartUp-%'"
|
||||
))
|
||||
# 2. 清理:同一 (upstream_id, group_id, managed_prefix) 只保留最新一条
|
||||
# SQLite 不支持子查询直接 DELETE,用两步
|
||||
to_delete = conn.execute(text("""
|
||||
SELECT id FROM upstream_generated_keys
|
||||
WHERE managed_prefix IS NOT NULL
|
||||
AND id NOT IN (
|
||||
SELECT MAX(id) FROM upstream_generated_keys
|
||||
WHERE managed_prefix IS NOT NULL
|
||||
GROUP BY upstream_id, group_id, managed_prefix
|
||||
)
|
||||
""")).fetchall()
|
||||
for (row_id,) in to_delete:
|
||||
conn.execute(text("DELETE FROM upstream_generated_keys WHERE id = :id"), {"id": row_id})
|
||||
|
||||
# ——— 创建唯一索引 ———
|
||||
try:
|
||||
with engine.begin() as conn:
|
||||
conn.execute(
|
||||
text("CREATE UNIQUE INDEX IF NOT EXISTS uq_upstream_group_key "
|
||||
"ON upstream_generated_keys(upstream_id, group_id, key_name)")
|
||||
)
|
||||
conn.execute(
|
||||
text("CREATE UNIQUE INDEX IF NOT EXISTS uq_upstream_group_managed "
|
||||
"ON upstream_generated_keys(upstream_id, group_id, managed_prefix) "
|
||||
"WHERE managed_prefix IS NOT NULL")
|
||||
)
|
||||
except Exception:
|
||||
logger = __import__("logging").getLogger(__name__)
|
||||
logger.warning("could not create unique indexes on upstream_generated_keys (non-fatal)")
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class UpstreamGeneratedKey(Base):
|
||||
__tablename__ = "upstream_generated_keys"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
||||
upstream_id: Mapped[int] = mapped_column(Integer, ForeignKey("upstreams.id", ondelete="CASCADE"), index=True)
|
||||
group_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
||||
group_name: Mapped[str] = mapped_column(String(255), default="")
|
||||
key_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
key_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
key_value: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
masked_key: Mapped[str] = mapped_column(String(255), default="")
|
||||
raw_json: Mapped[str] = mapped_column(Text, default="{}")
|
||||
managed_prefix: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
|
||||
status: Mapped[str] = mapped_column(String(32), default="created")
|
||||
error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
imported_website_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey("websites.id", ondelete="SET NULL"), nullable=True, index=True)
|
||||
imported_account_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
imported_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("upstream_id", "group_id", "key_name", name="uq_upstream_group_key"),
|
||||
)
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
|
||||
@@ -14,11 +15,15 @@ from sqlalchemy.orm import Session
|
||||
from app.database import get_db
|
||||
from app.models.admin_user import AdminUser
|
||||
from app.models.upstream import Upstream
|
||||
from app.models.upstream_key import UpstreamGeneratedKey
|
||||
from app.models.snapshot import UpstreamRateSnapshot
|
||||
from app.schemas.upstream import (
|
||||
GenerateKeysByGroupsRequest,
|
||||
GenerateKeysByGroupsResponse,
|
||||
GeneratedUpstreamKeyResponse,
|
||||
UpstreamCreate, UpstreamUpdate, UpstreamResponse, SnapshotResponse, TestResult
|
||||
)
|
||||
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot
|
||||
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot, mask_secret
|
||||
from app.services.snapshot_service import diff_snapshots
|
||||
from app.services import scheduler as sched_svc
|
||||
from app.services import webhook_service
|
||||
@@ -31,6 +36,38 @@ MASK = "***"
|
||||
SECRET_KEYS = {"password", "token", "key", "secret"}
|
||||
|
||||
|
||||
def _group_id(group: dict) -> str:
|
||||
for key in ("id", "group_id", "groupId"):
|
||||
value = group.get(key)
|
||||
if value is not None:
|
||||
return str(value)
|
||||
return str(group.get("name") or group.get("group_name") or "")
|
||||
|
||||
|
||||
def _group_name(group: dict, gid: str) -> str:
|
||||
return str(group.get("name") or group.get("group_name") or gid)
|
||||
|
||||
|
||||
def _key_response(row: UpstreamGeneratedKey, include_value: bool = False) -> GeneratedUpstreamKeyResponse:
|
||||
return GeneratedUpstreamKeyResponse(
|
||||
id=row.id,
|
||||
upstream_id=row.upstream_id,
|
||||
group_id=row.group_id,
|
||||
group_name=row.group_name,
|
||||
key_id=row.key_id,
|
||||
key_name=row.key_name,
|
||||
key_value=row.key_value if include_value else None,
|
||||
masked_key=row.masked_key,
|
||||
status=row.status,
|
||||
error=row.error,
|
||||
imported_website_id=row.imported_website_id,
|
||||
imported_account_id=row.imported_account_id,
|
||||
imported_at=row.imported_at,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _mask_auth_config(auth_type: str, cfg: dict) -> dict:
|
||||
masked = {}
|
||||
for k, v in cfg.items():
|
||||
@@ -73,6 +110,198 @@ def list_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
return [_to_response(u) for u in db.query(Upstream).order_by(Upstream.id).all()]
|
||||
|
||||
|
||||
@router.get("/{uid}/generated-keys", response_model=List[GeneratedUpstreamKeyResponse])
|
||||
def list_generated_keys(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
if not db.query(Upstream.id).filter(Upstream.id == uid).first():
|
||||
raise HTTPException(404, "upstream not found")
|
||||
rows = (
|
||||
db.query(UpstreamGeneratedKey)
|
||||
.filter(UpstreamGeneratedKey.upstream_id == uid)
|
||||
.order_by(UpstreamGeneratedKey.id.desc())
|
||||
.limit(200)
|
||||
.all()
|
||||
)
|
||||
return [_key_response(row) for row in rows]
|
||||
|
||||
|
||||
_generate_key_lock = __import__("threading").Lock()
|
||||
|
||||
|
||||
def _ensure_group_key(
|
||||
db: Session,
|
||||
client: UpstreamClient,
|
||||
upstream: Upstream,
|
||||
group: dict[str, Any],
|
||||
prefix: str,
|
||||
body: GenerateKeysByGroupsRequest,
|
||||
) -> GeneratedUpstreamKeyResponse:
|
||||
"""确保一个上游分组有一个 SmartUp 前缀 Key:存在则 upsert,不存在则创建。"""
|
||||
gid = _group_id(group)
|
||||
gname = _group_name(group, gid)
|
||||
# 使用稳定的 upstream_id + group_id 而非可变名称,避免因改名产生重复
|
||||
# 可读 Key 名:{prefix}-{upstream.id}-{安全的分组名}-{group_id}
|
||||
safe_group_name = re.sub(r"[^a-zA-Z0-9\u4e00-\u9fff_-]", "", gname)[:30] if gname else gid
|
||||
stable_name = f"{prefix}-{upstream.id}-{safe_group_name}-{gid}"
|
||||
|
||||
with _generate_key_lock:
|
||||
try:
|
||||
# 1. 先查本地是否已有该分组的托管 Key(兼容迁移前无 managed_prefix 的记录)
|
||||
row = (
|
||||
db.query(UpstreamGeneratedKey)
|
||||
.filter(
|
||||
UpstreamGeneratedKey.upstream_id == upstream.id,
|
||||
UpstreamGeneratedKey.group_id == gid,
|
||||
(UpstreamGeneratedKey.managed_prefix == prefix)
|
||||
| ((UpstreamGeneratedKey.managed_prefix.is_(None))
|
||||
& UpstreamGeneratedKey.key_name.like(f"{prefix}-%")),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if row and row.key_id:
|
||||
# 本地已有记录,检查远端是否仍存在
|
||||
try:
|
||||
existing = client.find_smartup_group_key(gid, stable_name, prefix)
|
||||
except Exception:
|
||||
existing = None
|
||||
if existing:
|
||||
row.status = "exists"
|
||||
row.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
return _key_response(row, include_value=False)
|
||||
# 远端不存在,需要重新创建
|
||||
row.status = "replaced"
|
||||
|
||||
# 2. 查远端是否有同名 Key(防止并发时另一个请求已创建)
|
||||
existing = client.find_smartup_group_key(gid, stable_name, prefix)
|
||||
if existing:
|
||||
key_id = str(existing.get("id") or "")
|
||||
masked = existing.get("masked_key") or existing.get("key") or ""
|
||||
if row:
|
||||
row.key_id = key_id or row.key_id
|
||||
row.masked_key = masked or row.masked_key
|
||||
row.status = "exists"
|
||||
row.updated_at = datetime.now(timezone.utc)
|
||||
else:
|
||||
row = UpstreamGeneratedKey(
|
||||
upstream_id=upstream.id,
|
||||
group_id=gid,
|
||||
group_name=gname,
|
||||
key_id=key_id or None,
|
||||
key_name=stable_name,
|
||||
key_value="",
|
||||
masked_key=masked,
|
||||
raw_json=json.dumps(existing, ensure_ascii=False),
|
||||
managed_prefix=prefix,
|
||||
status="exists",
|
||||
)
|
||||
db.add(row)
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
return _key_response(row, include_value=False)
|
||||
|
||||
# 3. 远端不存在,创建新 Key
|
||||
created = client.create_api_key(
|
||||
stable_name,
|
||||
gid,
|
||||
quota=body.quota,
|
||||
expires_in_days=body.expires_in_days,
|
||||
rate_limit_5h=body.rate_limit_5h,
|
||||
rate_limit_1d=body.rate_limit_1d,
|
||||
rate_limit_7d=body.rate_limit_7d,
|
||||
endpoint=body.endpoint,
|
||||
)
|
||||
if row:
|
||||
# 复用旧行
|
||||
row.key_id = created.get("id") or None
|
||||
row.key_name = stable_name
|
||||
row.key_value = created["key"]
|
||||
row.masked_key = created.get("masked_key") or mask_secret(created["key"])
|
||||
row.raw_json = json.dumps(created.get("raw") or {}, ensure_ascii=False)
|
||||
row.managed_prefix = prefix
|
||||
row.status = "created"
|
||||
row.error = None
|
||||
else:
|
||||
row = UpstreamGeneratedKey(
|
||||
upstream_id=upstream.id,
|
||||
group_id=gid,
|
||||
group_name=gname,
|
||||
key_id=created.get("id") or None,
|
||||
key_name=stable_name,
|
||||
key_value=created["key"],
|
||||
masked_key=created.get("masked_key") or mask_secret(created["key"]),
|
||||
raw_json=json.dumps(created.get("raw") or {}, ensure_ascii=False),
|
||||
managed_prefix=prefix,
|
||||
status="created",
|
||||
)
|
||||
db.add(row)
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
return _key_response(row, include_value=True)
|
||||
except Exception as exc:
|
||||
logger.exception("ensure group key failed for upstream=%s group=%s", upstream.id, gid)
|
||||
return GeneratedUpstreamKeyResponse(
|
||||
upstream_id=upstream.id,
|
||||
group_id=gid,
|
||||
group_name=gname,
|
||||
key_name=stable_name,
|
||||
status="failed",
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{uid}/keys/generate-by-groups", response_model=GenerateKeysByGroupsResponse)
|
||||
def generate_keys_by_groups(
|
||||
uid: int,
|
||||
body: GenerateKeysByGroupsRequest,
|
||||
db: Session = Depends(get_db),
|
||||
_=Depends(get_current_user),
|
||||
):
|
||||
u = db.query(Upstream).filter(Upstream.id == uid).first()
|
||||
if not u:
|
||||
raise HTTPException(404, "upstream not found")
|
||||
if u.api_prefix.strip("/") != "api/v1":
|
||||
raise HTTPException(400, "首版仅支持 Sub2API 上游(API Prefix 应为 /api/v1)")
|
||||
|
||||
auth_config = json.loads(u.auth_config_json or "{}")
|
||||
selected = set(body.group_ids)
|
||||
prefix = body.name_prefix
|
||||
results: list[GeneratedUpstreamKeyResponse] = []
|
||||
with UpstreamClient(
|
||||
base_url=u.base_url,
|
||||
api_prefix=u.api_prefix,
|
||||
auth_type=u.auth_type,
|
||||
auth_config=auth_config,
|
||||
timeout=float(u.timeout_seconds),
|
||||
) as client:
|
||||
try:
|
||||
client.login()
|
||||
groups = client.get_available_groups(u.groups_endpoint)
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, str(exc))
|
||||
|
||||
for group in groups:
|
||||
gid = _group_id(group)
|
||||
if not gid or (selected and gid not in selected):
|
||||
continue
|
||||
result = _ensure_group_key(db, client, u, group, prefix, body)
|
||||
results.append(result)
|
||||
|
||||
created = len([item for item in results if item.status == "created"])
|
||||
existed = len([item for item in results if item.status == "exists"])
|
||||
total = len(results)
|
||||
msg_parts = []
|
||||
if created:
|
||||
msg_parts.append(f"新创建 {created}")
|
||||
if existed:
|
||||
msg_parts.append(f"已存在 {existed}")
|
||||
msg = "、".join(msg_parts) + f" / 共 {total} 个分组" if msg_parts else f"共处理 {total} 个分组"
|
||||
return GenerateKeysByGroupsResponse(
|
||||
success=total > 0 and all(item.status != "failed" for item in results),
|
||||
message=msg,
|
||||
items=results,
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=UpstreamResponse, status_code=201)
|
||||
def create_upstream(
|
||||
body: UpstreamCreate,
|
||||
@@ -255,6 +484,10 @@ def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_use
|
||||
webhook_service.send_rate_changed(db, u.id, u.name, u.base_url, changes)
|
||||
website_sync.sync_affected_bindings(db, u.id, changes)
|
||||
|
||||
# 同步 SmartUp Key 状态(使用实际快照入库时间,与定时任务一致)
|
||||
from app.services.scheduler import _sync_upstream_keys as _synck
|
||||
_synck(uid, snapshot, new_row.captured_at)
|
||||
|
||||
msg = f"检测成功,{len(groups)} 个分组"
|
||||
if changes:
|
||||
msg += f",发现 {len(changes)} 处倍率变化"
|
||||
|
||||
@@ -9,11 +9,21 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.snapshot import UpstreamRateSnapshot
|
||||
from app.models.upstream import Upstream
|
||||
from app.models.upstream_key import UpstreamGeneratedKey
|
||||
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
|
||||
from app.schemas.website import (
|
||||
BindingCreate,
|
||||
BindingResponse,
|
||||
BindingUpdate,
|
||||
ImportAccountItem,
|
||||
ImportAccountsRequest,
|
||||
ImportAccountsResponse,
|
||||
ImportGroupItem,
|
||||
ImportGroupsRequest,
|
||||
ImportGroupsResponse,
|
||||
SyncImportStatusRequest,
|
||||
TestResult,
|
||||
WebsiteCreate,
|
||||
WebsiteGroupResponse,
|
||||
@@ -118,6 +128,47 @@ def _client(row: Website) -> Sub2ApiWebsiteClient:
|
||||
)
|
||||
|
||||
|
||||
def _latest_upstream_groups(db: Session, upstream_id: int) -> list[dict]:
|
||||
row = (
|
||||
db.query(UpstreamRateSnapshot)
|
||||
.filter(UpstreamRateSnapshot.upstream_id == upstream_id)
|
||||
.order_by(UpstreamRateSnapshot.captured_at.desc())
|
||||
.first()
|
||||
)
|
||||
if not row:
|
||||
raise HTTPException(404, "no upstream snapshot found; run upstream check first")
|
||||
snapshot = json.loads(row.snapshot_json or "{}")
|
||||
groups = snapshot.get("groups") or {}
|
||||
if not isinstance(groups, dict):
|
||||
return []
|
||||
return [item for item in groups.values() if isinstance(item, dict)]
|
||||
|
||||
|
||||
def _source_group_id(group: dict) -> str:
|
||||
return str(group.get("group_id") or group.get("id") or group.get("name") or "")
|
||||
|
||||
|
||||
def _source_group_name(group: dict, gid: str) -> str:
|
||||
return str(group.get("group_name") or group.get("name") or gid)
|
||||
|
||||
|
||||
def _source_group_rate(group: dict) -> float:
|
||||
raw = group.get("rate") or group.get("default_rate") or group.get("rate_multiplier") or 1
|
||||
try:
|
||||
return float(raw)
|
||||
except (TypeError, ValueError):
|
||||
return 1.0
|
||||
|
||||
|
||||
def _numeric_group_id(value: str | None) -> int | None:
|
||||
if value is None or value == "":
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/api/websites", response_model=List[WebsiteResponse])
|
||||
def list_websites(db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
return [_website_response(row) for row in db.query(Website).order_by(Website.id).all()]
|
||||
@@ -215,6 +266,375 @@ def list_website_groups(wid: int, db: Session = Depends(get_db), _=Depends(get_c
|
||||
raise HTTPException(502, str(exc))
|
||||
|
||||
|
||||
@router.post("/api/websites/{wid}/groups/import-from-upstream/{upstream_id}", response_model=ImportGroupsResponse)
|
||||
def import_groups_from_upstream(
|
||||
wid: int,
|
||||
upstream_id: int,
|
||||
body: ImportGroupsRequest,
|
||||
db: Session = Depends(get_db),
|
||||
_=Depends(get_current_user),
|
||||
):
|
||||
website = db.query(Website).filter(Website.id == wid).first()
|
||||
if not website:
|
||||
raise HTTPException(404, "website not found")
|
||||
if website.site_type != "sub2api":
|
||||
raise HTTPException(400, "目前只支持 sub2api")
|
||||
upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
|
||||
if not upstream:
|
||||
raise HTTPException(404, "upstream not found")
|
||||
|
||||
selected = set(body.group_ids)
|
||||
groups = _latest_upstream_groups(db, upstream_id)
|
||||
# 拉取目标网站已有分组,同名则跳过
|
||||
try:
|
||||
existing_names = set()
|
||||
with _client(website) as c:
|
||||
for eg in c.get_groups(website.groups_endpoint):
|
||||
gname = eg.get("name") or eg.get("group_name") or ""
|
||||
if gname:
|
||||
existing_names.add(gname)
|
||||
except Exception:
|
||||
existing_names = set()
|
||||
|
||||
items: list[ImportGroupItem] = []
|
||||
with _client(website) as c:
|
||||
for group in groups:
|
||||
source_gid = _source_group_id(group)
|
||||
if not source_gid or (selected and source_gid not in selected):
|
||||
continue
|
||||
source_name = _source_group_name(group, source_gid)
|
||||
target_name = f"{body.name_prefix}{source_name}" if body.name_prefix else source_name
|
||||
|
||||
# 检查是否已存在同名分组
|
||||
if target_name in existing_names:
|
||||
items.append(ImportGroupItem(
|
||||
source_group_id=source_gid,
|
||||
source_group_name=source_name,
|
||||
target_group_name=target_name,
|
||||
status="exists",
|
||||
message="目标分组已存在,已跳过",
|
||||
))
|
||||
continue
|
||||
|
||||
create_body = {
|
||||
"name": target_name,
|
||||
"description": group.get("description") or f"Imported from {upstream.name} / {source_name}",
|
||||
"platform": group.get("platform") or "openai",
|
||||
"rate_multiplier": _source_group_rate(group),
|
||||
}
|
||||
if group.get("rpm_limit") is not None:
|
||||
create_body["rpm_limit"] = group.get("rpm_limit")
|
||||
try:
|
||||
created = c.create_group(create_body)
|
||||
target_id = c.extract_id(created)
|
||||
items.append(ImportGroupItem(
|
||||
source_group_id=source_gid,
|
||||
source_group_name=source_name,
|
||||
target_group_id=target_id or None,
|
||||
target_group_name=str(created.get("name") or target_name),
|
||||
status="created",
|
||||
message="已创建",
|
||||
raw=created,
|
||||
))
|
||||
except Exception as exc:
|
||||
msg = str(exc)
|
||||
# 捕获 409 等已存在错误
|
||||
if "已存在" in msg or "already exists" in msg.lower() or "409" in msg or "Conflict" in msg:
|
||||
items.append(ImportGroupItem(
|
||||
source_group_id=source_gid,
|
||||
source_group_name=source_name,
|
||||
target_group_name=target_name,
|
||||
status="exists",
|
||||
message="目标分组已存在(接口返回冲突)",
|
||||
))
|
||||
else:
|
||||
logger.exception("import website group failed website=%s upstream=%s group=%s", wid, upstream_id, source_gid)
|
||||
items.append(ImportGroupItem(
|
||||
source_group_id=source_gid,
|
||||
source_group_name=source_name,
|
||||
target_group_name=target_name,
|
||||
status="failed",
|
||||
message=msg,
|
||||
))
|
||||
created_count = len([item for item in items if item.status == "created"])
|
||||
exists_count = len([item for item in items if item.status == "exists"])
|
||||
failed_count = len([item for item in items if item.status == "failed"])
|
||||
msg_parts = []
|
||||
if created_count:
|
||||
msg_parts.append(f"新建 {created_count}")
|
||||
if exists_count:
|
||||
msg_parts.append(f"已存在 {exists_count}")
|
||||
if failed_count:
|
||||
msg_parts.append(f"失败 {failed_count}")
|
||||
return ImportGroupsResponse(
|
||||
success=failed_count == 0,
|
||||
message="、".join(msg_parts) + f" / 共 {len(items)} 个" if msg_parts else f"共处理 {len(items)} 个分组",
|
||||
items=items,
|
||||
)
|
||||
|
||||
|
||||
def _detect_platform(text: str, fallback: str = "openai") -> str:
|
||||
"""根据 Key 名或分组名关键词判断平台类型。"""
|
||||
lower = text.lower()
|
||||
if "claude" in lower or "anthropic" in lower:
|
||||
return "anthropic"
|
||||
if "gemini" in lower:
|
||||
return "gemini"
|
||||
if "antigravity" in lower:
|
||||
return "antigravity"
|
||||
return fallback
|
||||
|
||||
|
||||
@router.post("/api/websites/{wid}/accounts/sync-imported-upstream-keys", response_model=ImportAccountsResponse)
|
||||
def sync_imported_upstream_keys(
|
||||
wid: int,
|
||||
body: SyncImportStatusRequest,
|
||||
db: Session = Depends(get_db),
|
||||
_=Depends(get_current_user),
|
||||
):
|
||||
"""校验已导入的上游 Key 在目标 Sub2API 账号管理中是否仍存在。"""
|
||||
website = db.query(Website).filter(Website.id == wid).first()
|
||||
if not website:
|
||||
raise HTTPException(404, "website not found")
|
||||
|
||||
rows = (
|
||||
db.query(UpstreamGeneratedKey)
|
||||
.filter(
|
||||
UpstreamGeneratedKey.upstream_id == body.upstream_id,
|
||||
UpstreamGeneratedKey.imported_website_id == wid,
|
||||
UpstreamGeneratedKey.imported_account_id.isnot(None),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
items: list[ImportAccountItem] = []
|
||||
with _client(website) as c:
|
||||
for row in rows:
|
||||
platform = _detect_platform(f"{row.group_name} {row.group_id} {row.key_name}", "openai")
|
||||
if not row.imported_account_id:
|
||||
continue
|
||||
old_account_id = row.imported_account_id
|
||||
exists = c.account_exists(row.imported_account_id)
|
||||
if exists is False:
|
||||
row.imported_website_id = None
|
||||
row.imported_account_id = None
|
||||
row.imported_at = None
|
||||
row.status = "created"
|
||||
items.append(ImportAccountItem(
|
||||
upstream_key_id=row.id, source_group_id=row.group_id,
|
||||
source_group_name=row.group_name, account_id=old_account_id,
|
||||
platform=platform, status="stale_cleared",
|
||||
message="目标账号已删除,已清除导入标记",
|
||||
))
|
||||
elif exists is True:
|
||||
items.append(ImportAccountItem(
|
||||
upstream_key_id=row.id, source_group_id=row.group_id,
|
||||
source_group_name=row.group_name, account_id=old_account_id,
|
||||
platform=platform, status="exists", message="目标账号仍存在",
|
||||
))
|
||||
else:
|
||||
items.append(ImportAccountItem(
|
||||
upstream_key_id=row.id, source_group_id=row.group_id,
|
||||
source_group_name=row.group_name, account_id=old_account_id,
|
||||
platform=platform, status="check_failed",
|
||||
message="无法校验目标账号存在性(目标网站认证/网络问题)",
|
||||
))
|
||||
db.commit()
|
||||
cleared_count = len([i for i in items if i.status == "stale_cleared"])
|
||||
check_failed_count = len([i for i in items if i.status == "check_failed"])
|
||||
msg_parts = []
|
||||
if cleared_count:
|
||||
msg_parts.append(f"清除 {cleared_count}")
|
||||
if check_failed_count:
|
||||
msg_parts.append(f"校验失败 {check_failed_count}")
|
||||
return ImportAccountsResponse(
|
||||
success=check_failed_count == 0,
|
||||
message="、".join(msg_parts) + f" / 共 {len(items)} 个" if msg_parts else f"共校验 {len(items)} 个,无变化",
|
||||
items=items,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/websites/{wid}/accounts/import-upstream-keys", response_model=ImportAccountsResponse)
|
||||
def import_upstream_keys_as_accounts(
|
||||
wid: int,
|
||||
body: ImportAccountsRequest,
|
||||
db: Session = Depends(get_db),
|
||||
_=Depends(get_current_user),
|
||||
):
|
||||
website = db.query(Website).filter(Website.id == wid).first()
|
||||
if not website:
|
||||
raise HTTPException(404, "website not found")
|
||||
if website.site_type != "sub2api":
|
||||
raise HTTPException(400, "目前只支持 sub2api")
|
||||
if not body.upstream_key_ids:
|
||||
raise HTTPException(400, "请选择要导入的 Key")
|
||||
rows = (
|
||||
db.query(UpstreamGeneratedKey)
|
||||
.filter(UpstreamGeneratedKey.id.in_(body.upstream_key_ids))
|
||||
.order_by(UpstreamGeneratedKey.id)
|
||||
.all()
|
||||
)
|
||||
found_ids = {row.id for row in rows}
|
||||
missing_ids = [kid for kid in body.upstream_key_ids if kid not in found_ids]
|
||||
items: list[ImportAccountItem] = [
|
||||
ImportAccountItem(
|
||||
upstream_key_id=kid,
|
||||
source_group_id="",
|
||||
source_group_name="",
|
||||
platform=body.default_platform,
|
||||
status="failed",
|
||||
message="key not found",
|
||||
)
|
||||
for kid in missing_ids
|
||||
]
|
||||
# 查出来源上游的 Base URL
|
||||
upstream_base_url = ""
|
||||
if body.upstream_key_ids:
|
||||
first_row = rows[0] if rows else None
|
||||
if first_row:
|
||||
from app.models.upstream import Upstream as _Up
|
||||
_u = db.query(_Up).filter(_Up.id == first_row.upstream_id).first()
|
||||
if _u:
|
||||
upstream_base_url = _u.base_url
|
||||
|
||||
with _client(website) as c:
|
||||
for row in rows:
|
||||
# 先确定平台(失败项也需要记录)
|
||||
if body.platform_mode == "auto":
|
||||
platform = _detect_platform(
|
||||
f"{row.group_name} {row.group_id} {row.key_name}",
|
||||
body.default_platform,
|
||||
)
|
||||
else:
|
||||
platform = body.default_platform
|
||||
|
||||
# 幂等校验:已导入过则检查远端账号是否仍存在
|
||||
if row.imported_website_id == wid and row.imported_account_id:
|
||||
old_account_id = row.imported_account_id
|
||||
exists = c.account_exists(row.imported_account_id)
|
||||
if exists is True:
|
||||
items.append(ImportAccountItem(
|
||||
upstream_key_id=row.id,
|
||||
source_group_id=row.group_id,
|
||||
source_group_name=row.group_name,
|
||||
target_group_id=body.target_group_map.get(row.group_id),
|
||||
account_id=old_account_id,
|
||||
account_name=f"{body.account_name_prefix}-{row.group_name or row.group_id}-{row.id}",
|
||||
platform=platform,
|
||||
upstream_base_url=upstream_base_url,
|
||||
status="exists",
|
||||
message="已导入过,已跳过",
|
||||
))
|
||||
continue
|
||||
elif exists is False:
|
||||
# 远端已删除,清空标记后继续创建
|
||||
row.imported_website_id = None
|
||||
row.imported_account_id = None
|
||||
row.imported_at = None
|
||||
row.status = "created"
|
||||
# 继续往下走(不 continue)
|
||||
else:
|
||||
# 校验失败,保守跳过
|
||||
items.append(ImportAccountItem(
|
||||
upstream_key_id=row.id,
|
||||
source_group_id=row.group_id,
|
||||
source_group_name=row.group_name,
|
||||
target_group_id=body.target_group_map.get(row.group_id),
|
||||
account_id=old_account_id,
|
||||
platform=platform,
|
||||
status="check_failed",
|
||||
message="无法校验目标账号状态,已保守跳过",
|
||||
))
|
||||
continue
|
||||
|
||||
if not row.key_value:
|
||||
items.append(ImportAccountItem(
|
||||
upstream_key_id=row.id,
|
||||
source_group_id=row.group_id,
|
||||
source_group_name=row.group_name,
|
||||
platform=platform,
|
||||
status="failed",
|
||||
message="该 Key 无明文值,无法导入(远端已存在 Key 不会保留明文,请重新创建或手动填入)",
|
||||
))
|
||||
continue
|
||||
|
||||
target_group_id = body.target_group_map.get(row.group_id)
|
||||
group_ids = []
|
||||
numeric_target = _numeric_group_id(target_group_id)
|
||||
if numeric_target is not None:
|
||||
group_ids.append(numeric_target)
|
||||
account_name = f"{body.account_name_prefix}-{row.group_name or row.group_id}-{row.id}"
|
||||
account_body = {
|
||||
"name": account_name,
|
||||
"platform": platform,
|
||||
"type": "apikey",
|
||||
"credentials": {
|
||||
"api_key": row.key_value,
|
||||
"base_url": upstream_base_url,
|
||||
},
|
||||
"group_ids": group_ids,
|
||||
"rate_multiplier": 1,
|
||||
"concurrency": body.concurrency,
|
||||
"priority": body.priority,
|
||||
"notes": f"Imported by SmartUp from upstream key #{row.id}",
|
||||
}
|
||||
try:
|
||||
created = c.create_account(account_body)
|
||||
account_id = c.extract_id(created)
|
||||
row.imported_website_id = wid
|
||||
row.imported_account_id = account_id or None
|
||||
row.imported_at = datetime.now(timezone.utc)
|
||||
row.status = "imported"
|
||||
row.error = None
|
||||
db.commit()
|
||||
items.append(ImportAccountItem(
|
||||
upstream_key_id=row.id,
|
||||
source_group_id=row.group_id,
|
||||
source_group_name=row.group_name,
|
||||
target_group_id=target_group_id,
|
||||
account_id=account_id or None,
|
||||
account_name=str(created.get("name") or account_name),
|
||||
platform=platform,
|
||||
upstream_base_url=upstream_base_url,
|
||||
status="created",
|
||||
message="已创建账号",
|
||||
raw=created,
|
||||
))
|
||||
except Exception as exc:
|
||||
logger.exception("import upstream key as account failed website=%s key=%s", wid, row.id)
|
||||
row.status = "import_failed"
|
||||
row.error = str(exc)
|
||||
db.commit()
|
||||
items.append(ImportAccountItem(
|
||||
upstream_key_id=row.id,
|
||||
source_group_id=row.group_id,
|
||||
source_group_name=row.group_name,
|
||||
target_group_id=target_group_id,
|
||||
account_name=account_name,
|
||||
platform=platform,
|
||||
upstream_base_url=upstream_base_url,
|
||||
status="failed",
|
||||
message=str(exc),
|
||||
))
|
||||
created_count = len([item for item in items if item.status == "created"])
|
||||
exists_count = len([item for item in items if item.status == "exists"])
|
||||
failed_count = len([item for item in items if item.status == "failed"])
|
||||
check_failed_count = len([item for item in items if item.status == "check_failed"])
|
||||
msg_parts = []
|
||||
if created_count:
|
||||
msg_parts.append(f"新建 {created_count}")
|
||||
if exists_count:
|
||||
msg_parts.append(f"已存在 {exists_count}")
|
||||
if check_failed_count:
|
||||
msg_parts.append(f"校验失败 {check_failed_count}")
|
||||
if failed_count:
|
||||
msg_parts.append(f"失败 {failed_count}")
|
||||
return ImportAccountsResponse(
|
||||
success=failed_count == 0 and check_failed_count == 0,
|
||||
message="、".join(msg_parts) + f" / 共 {len(items)} 个" if msg_parts else f"共处理 {len(items)} 个",
|
||||
items=items,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/group-bindings", response_model=List[BindingResponse])
|
||||
def list_bindings(db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
rows = db.query(WebsiteGroupBinding).order_by(WebsiteGroupBinding.id.desc()).all()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AuthConfigBearer(BaseModel):
|
||||
@@ -89,3 +89,38 @@ class TestResult(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
detail: Optional[str] = None
|
||||
|
||||
|
||||
class GenerateKeysByGroupsRequest(BaseModel):
|
||||
group_ids: list[str] = Field(default_factory=list)
|
||||
name_prefix: str = "SmartUp"
|
||||
quota: float = Field(default=0, ge=0)
|
||||
expires_in_days: Optional[int] = Field(default=None, ge=1)
|
||||
rate_limit_5h: float = Field(default=0, ge=0)
|
||||
rate_limit_1d: float = Field(default=0, ge=0)
|
||||
rate_limit_7d: float = Field(default=0, ge=0)
|
||||
endpoint: str = "/keys"
|
||||
|
||||
|
||||
class GeneratedUpstreamKeyResponse(BaseModel):
|
||||
id: Optional[int] = None
|
||||
upstream_id: int
|
||||
group_id: str
|
||||
group_name: str = ""
|
||||
key_id: Optional[str] = None
|
||||
key_name: str
|
||||
key_value: Optional[str] = None
|
||||
masked_key: str = ""
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
imported_website_id: Optional[int] = None
|
||||
imported_account_id: Optional[str] = None
|
||||
imported_at: Optional[datetime] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class GenerateKeysByGroupsResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
items: list[GeneratedUpstreamKeyResponse]
|
||||
|
||||
@@ -122,3 +122,58 @@ class WebsiteSyncLogResponse(BaseModel):
|
||||
status: str
|
||||
message: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ImportGroupsRequest(BaseModel):
|
||||
group_ids: list[str] = Field(default_factory=list)
|
||||
name_prefix: str = ""
|
||||
|
||||
|
||||
class ImportGroupItem(BaseModel):
|
||||
source_group_id: str
|
||||
source_group_name: str
|
||||
target_group_id: Optional[str] = None
|
||||
target_group_name: str = ""
|
||||
status: str
|
||||
message: str = ""
|
||||
raw: dict[str, Any] = {}
|
||||
|
||||
|
||||
class ImportGroupsResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
items: list[ImportGroupItem]
|
||||
|
||||
|
||||
class SyncImportStatusRequest(BaseModel):
|
||||
upstream_id: int = Field(default=0)
|
||||
|
||||
|
||||
class ImportAccountsRequest(BaseModel):
|
||||
upstream_key_ids: list[int] = Field(default_factory=list)
|
||||
target_group_map: dict[str, str] = Field(default_factory=dict)
|
||||
account_name_prefix: str = "SmartUp"
|
||||
default_platform: str = "openai"
|
||||
platform_mode: str = "auto" # "auto" | "manual"
|
||||
concurrency: int = Field(default=10, ge=1)
|
||||
priority: int = Field(default=1, ge=0)
|
||||
|
||||
|
||||
class ImportAccountItem(BaseModel):
|
||||
upstream_key_id: int
|
||||
source_group_id: str
|
||||
source_group_name: str
|
||||
target_group_id: Optional[str] = None
|
||||
account_id: Optional[str] = None
|
||||
account_name: str = ""
|
||||
platform: str = ""
|
||||
upstream_base_url: str = ""
|
||||
status: str
|
||||
message: str = ""
|
||||
raw: dict[str, Any] = {}
|
||||
|
||||
|
||||
class ImportAccountsResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
items: list[ImportAccountItem]
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import SessionLocal
|
||||
from app.models.upstream import Upstream
|
||||
from app.models.upstream_key import UpstreamGeneratedKey
|
||||
from app.models.snapshot import UpstreamRateSnapshot
|
||||
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot
|
||||
from app.services.snapshot_service import diff_snapshots, prune_snapshots
|
||||
@@ -130,7 +131,20 @@ def _check_upstream(upstream_id: int) -> None:
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# ── Phase 2: notifications (independent sessions) ──────────────
|
||||
# ── Phase 2: key sync (independent session) ───────────────────
|
||||
if snapshot:
|
||||
captured_at = snapshot.get("captured_at")
|
||||
if isinstance(captured_at, str):
|
||||
from datetime import datetime as dt
|
||||
try:
|
||||
captured_at = dt.fromisoformat(captured_at)
|
||||
except Exception:
|
||||
captured_at = datetime.now(timezone.utc)
|
||||
elif captured_at is None:
|
||||
captured_at = datetime.now(timezone.utc)
|
||||
_sync_upstream_keys(upstream_id, snapshot, captured_at)
|
||||
|
||||
# ── Phase 3: notifications (independent sessions) ──────────────
|
||||
if was_unhealthy:
|
||||
_notify_status(upstream_id, upstream.name, upstream.base_url, "upstream_recovered")
|
||||
|
||||
@@ -170,6 +184,63 @@ def _notify_rate_changed(
|
||||
db.close()
|
||||
|
||||
|
||||
def _sync_upstream_keys(upstream_id: int, snapshot: dict[str, Any], captured_at: datetime) -> None:
|
||||
"""上游检测成功后同步 SmartUp Key 状态(远端删除/分组删除)。"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
active_group_ids = set(snapshot.get("groups", {}).keys())
|
||||
key_rows = (
|
||||
db.query(UpstreamGeneratedKey)
|
||||
.filter(
|
||||
UpstreamGeneratedKey.upstream_id == upstream_id,
|
||||
UpstreamGeneratedKey.key_name.like("SmartUp-%"),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
auth_config = json.loads(
|
||||
db.query(Upstream).filter(Upstream.id == upstream_id).first().auth_config_json or "{}"
|
||||
)
|
||||
# 用 UpstreamClient 查询远端活跃 Key ID 集合
|
||||
remote_key_ids: set[str] | None = None # None=查询失败,set()=查询成功但为空
|
||||
try:
|
||||
upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
|
||||
if upstream:
|
||||
with UpstreamClient(
|
||||
base_url=upstream.base_url,
|
||||
api_prefix=upstream.api_prefix,
|
||||
auth_type=upstream.auth_type,
|
||||
auth_config=auth_config,
|
||||
timeout=float(upstream.timeout_seconds),
|
||||
) as client:
|
||||
client.login()
|
||||
remote_keys = client.list_api_keys(search="SmartUp", status="active")
|
||||
remote_key_ids = {
|
||||
str(k["id"]) for k in remote_keys if k.get("id")
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.warning("sync upstream keys list failed for %s: %s", upstream_id, exc)
|
||||
|
||||
for row in key_rows:
|
||||
# 1. 分组已不在当前快照中 → 删除本地记录
|
||||
if row.group_id not in active_group_ids:
|
||||
db.delete(row)
|
||||
logger.info("removed key %s (group %s no longer in snapshot)", row.id, row.group_id)
|
||||
continue
|
||||
# 2. 远端查询成功但 key_id 不在列表中 → 删除本地记录
|
||||
if row.key_id and remote_key_ids is not None and row.key_id not in remote_key_ids:
|
||||
db.delete(row)
|
||||
logger.info("removed key %s (key_id %s gone from remote)", row.id, row.key_id)
|
||||
continue
|
||||
# 3. 更新同步时间戳(仅当查询成功且 Key 仍在远端时)
|
||||
if remote_key_ids is not None and row.key_id in remote_key_ids:
|
||||
row.updated_at = captured_at
|
||||
db.commit()
|
||||
except Exception:
|
||||
logger.exception("key sync failed for upstream %s", upstream_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _sync_website_bindings(upstream_id: int, changes: list[dict[str, Any]]) -> None:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
||||
@@ -62,6 +62,49 @@ def _find_user_id(value: Any) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def mask_secret(value: Any) -> str:
|
||||
text = str(value or "")
|
||||
if not text:
|
||||
return ""
|
||||
if len(text) <= 8:
|
||||
return text[:2] + "****" + text[-2:] if len(text) > 4 else "****"
|
||||
return text[:4] + "**********" + text[-4:]
|
||||
|
||||
|
||||
def _unwrap_data(value: Any) -> Any:
|
||||
if isinstance(value, dict) and "data" in value and ("code" in value or "message" in value):
|
||||
return value.get("data")
|
||||
return value
|
||||
|
||||
|
||||
def _extract_id(value: Any) -> str:
|
||||
if isinstance(value, dict):
|
||||
for key in ("id", "key_id", "keyId"):
|
||||
candidate = value.get(key)
|
||||
if candidate is not None:
|
||||
return str(candidate)
|
||||
for key in ("data", "result", "key", "api_key"):
|
||||
found = _extract_id(value.get(key))
|
||||
if found:
|
||||
return found
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_key_value(value: Any) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
for key in ("key", "api_key", "apiKey", "token", "value"):
|
||||
candidate = value.get(key)
|
||||
if isinstance(candidate, str) and candidate:
|
||||
return candidate
|
||||
for key in ("data", "result", "api_key", "key"):
|
||||
found = _extract_key_value(value.get(key))
|
||||
if found:
|
||||
return found
|
||||
return ""
|
||||
|
||||
|
||||
def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]:
|
||||
def _normalize(lst: list) -> list[dict[str, Any]]:
|
||||
out = []
|
||||
@@ -360,3 +403,107 @@ class UpstreamClient:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
def list_api_keys(
|
||||
self,
|
||||
search: str = "",
|
||||
group_id: str | int | None = None,
|
||||
status: str = "active",
|
||||
endpoint: str = "/keys",
|
||||
) -> list[dict[str, Any]]:
|
||||
"""查询远端上游 Key 列表,支持按名称搜索、分组筛选、状态筛选。"""
|
||||
params: dict[str, Any] = {}
|
||||
if search:
|
||||
params["search"] = search
|
||||
if group_id is not None:
|
||||
params["group_id"] = int(group_id) if str(group_id).isdigit() else group_id
|
||||
if status:
|
||||
params["status"] = status
|
||||
url = self._url(endpoint)
|
||||
resp = self._client.request(
|
||||
"GET",
|
||||
url,
|
||||
params=params if params else None,
|
||||
headers=self._headers(),
|
||||
cookies=self._cookies,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
# 尝试展开常见的包装结构
|
||||
for top_key in ("data", "result", "response"):
|
||||
val = data.get(top_key)
|
||||
if isinstance(val, list):
|
||||
return val
|
||||
if isinstance(val, dict):
|
||||
for inner_key in ("items", "keys", "list", "records", "data"):
|
||||
inner = val.get(inner_key)
|
||||
if isinstance(inner, list):
|
||||
return inner
|
||||
# 顶层本身就是 list-like wrapper
|
||||
for key in ("items", "keys", "list", "records"):
|
||||
val = data.get(key)
|
||||
if isinstance(val, list):
|
||||
return val
|
||||
raise UpstreamError(f"unexpected keys response type: {type(data).__name__}")
|
||||
|
||||
def delete_api_key(self, key_id: str, endpoint: str = "/keys") -> None:
|
||||
"""删除远端上游上的一个 Key。"""
|
||||
self._request("DELETE", f"{endpoint}/{key_id}")
|
||||
|
||||
def find_smartup_group_key(
|
||||
self,
|
||||
group_id: str | int,
|
||||
expected_name: str,
|
||||
prefix: str = "SmartUp",
|
||||
) -> dict[str, Any] | None:
|
||||
"""查找同一上游分组下是否已存在 SmartUp 前缀的 Key。
|
||||
|
||||
匹配规则:key_name 等于 expected_name,且以 prefix 开头。
|
||||
返回匹配到的第一个 Key,或 None。
|
||||
"""
|
||||
gid = int(group_id) if str(group_id).isdigit() else group_id
|
||||
keys = self.list_api_keys(search=prefix, group_id=gid, status="active")
|
||||
for k in keys:
|
||||
name = k.get("name") or k.get("key_name") or ""
|
||||
if name == expected_name:
|
||||
return k
|
||||
# 部分后端返回的 name 可能带空格或 trimming
|
||||
if name.strip() == expected_name.strip():
|
||||
return k
|
||||
return None
|
||||
|
||||
def create_api_key(
|
||||
self,
|
||||
name: str,
|
||||
group_id: str | int,
|
||||
quota: float = 0,
|
||||
expires_in_days: int | None = None,
|
||||
rate_limit_5h: float = 0,
|
||||
rate_limit_1d: float = 0,
|
||||
rate_limit_7d: float = 0,
|
||||
endpoint: str = "/keys",
|
||||
) -> dict[str, Any]:
|
||||
body: dict[str, Any] = {
|
||||
"name": name,
|
||||
"group_id": int(group_id) if str(group_id).isdigit() else group_id,
|
||||
"quota": quota,
|
||||
"rate_limit_5h": rate_limit_5h,
|
||||
"rate_limit_1d": rate_limit_1d,
|
||||
"rate_limit_7d": rate_limit_7d,
|
||||
}
|
||||
if expires_in_days:
|
||||
body["expires_in_days"] = expires_in_days
|
||||
resp = self._request("POST", endpoint, body)
|
||||
data = _unwrap_data(resp)
|
||||
key_value = _extract_key_value(data)
|
||||
if not key_value:
|
||||
raise UpstreamError("key create response did not include key")
|
||||
return {
|
||||
"id": _extract_id(data),
|
||||
"key": key_value,
|
||||
"masked_key": mask_secret(key_value),
|
||||
"raw": data if isinstance(data, dict) else {"value": data},
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
@@ -8,11 +9,39 @@ import httpx
|
||||
|
||||
from app.utils.number import decimal_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebsiteError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _friendly_http_error(exc: httpx.HTTPStatusError) -> str:
|
||||
"""将常见 HTTP 错误转换为中文友好提示,原始信息保留在日志中。"""
|
||||
status = exc.response.status_code
|
||||
url = exc.request.url if exc.request else "?"
|
||||
logger.warning("website_client HTTP %s from %s: %s", status, url, exc)
|
||||
if status == 401:
|
||||
return "目标网站认证失败,请检查 Admin API Key / JWT 是否正确"
|
||||
if status == 403:
|
||||
return "目标网站权限不足,请检查当前凭证是否有分组管理权限"
|
||||
if status == 404:
|
||||
return f"目标网站接口不存在,请检查 API Prefix 和分组接口路径({exc.response.url.path})"
|
||||
if 500 <= status < 600:
|
||||
return "目标网站服务异常,请稍后重试"
|
||||
return f"目标网站返回错误(HTTP {status})"
|
||||
|
||||
|
||||
def _friendly_connection_error(exc: Exception) -> str:
|
||||
"""将网络/超时异常转换为中文友好提示。"""
|
||||
logger.warning("website_client connection error: %s", exc)
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
return "目标网站请求超时,请检查网络连接和 API 地址是否正确"
|
||||
if isinstance(exc, httpx.ConnectError):
|
||||
return "无法连接目标网站,请检查 API 地址和网络连通性"
|
||||
return f"目标网站通信异常:{exc}"
|
||||
|
||||
|
||||
def parse_positive_decimal(value: Any) -> Decimal | None:
|
||||
if value is None or value == "":
|
||||
return None
|
||||
@@ -59,6 +88,19 @@ def _unwrap_data(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _extract_id(value: Any) -> str:
|
||||
if isinstance(value, dict):
|
||||
for key in ("id", "account_id", "accountId", "group_id", "groupId"):
|
||||
candidate = value.get(key)
|
||||
if candidate is not None:
|
||||
return str(candidate)
|
||||
for key in ("data", "result", "account", "group"):
|
||||
found = _extract_id(value.get(key))
|
||||
if found:
|
||||
return found
|
||||
return ""
|
||||
|
||||
|
||||
def normalize_groups(value: Any) -> list[dict[str, Any]]:
|
||||
raw = _unwrap_data(value)
|
||||
if isinstance(raw, dict):
|
||||
@@ -129,24 +171,111 @@ class Sub2ApiWebsiteClient:
|
||||
return headers
|
||||
|
||||
def _request(self, method: str, path: str, body: Any = None) -> Any:
|
||||
resp = self._client.request(method, self._url(path), json=body, headers=self._headers())
|
||||
resp.raise_for_status()
|
||||
try:
|
||||
resp = self._client.request(method, self._url(path), json=body, headers=self._headers())
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
raise WebsiteError(_friendly_http_error(exc)) from exc
|
||||
except httpx.TimeoutException as exc:
|
||||
raise WebsiteError(_friendly_connection_error(exc)) from exc
|
||||
except httpx.ConnectError as exc:
|
||||
raise WebsiteError(_friendly_connection_error(exc)) from exc
|
||||
if not resp.content:
|
||||
return None
|
||||
text = resp.text
|
||||
if "application/json" not in resp.headers.get("content-type", "") and text.lstrip().startswith("<"):
|
||||
raise WebsiteError(f"{method} {path} returned HTML, not JSON")
|
||||
raise WebsiteError(f"{method} {path} 返回了 HTML,请检查接口地址是否正确")
|
||||
return resp.json()
|
||||
|
||||
def get_groups(self, endpoint: str = "/groups") -> list[dict[str, Any]]:
|
||||
errors: list[str] = []
|
||||
"""拉取分组列表,尝试 endpoint 和 fallback /groups/all。"""
|
||||
last_error: Exception | None = None
|
||||
tried_paths: list[str] = []
|
||||
for path in [endpoint, "/groups/all"]:
|
||||
tried_paths.append(path)
|
||||
try:
|
||||
return normalize_groups(self._request("GET", path))
|
||||
except WebsiteError as exc:
|
||||
msg = str(exc)
|
||||
# 认证/权限类错误:直接抛出,不需要尝试 fallback
|
||||
if "认证失败" in msg or "权限不足" in msg:
|
||||
raise
|
||||
# 404/5xx 等路径相关错误,试试另一个路径
|
||||
last_error = exc
|
||||
except Exception as exc:
|
||||
errors.append(f"{path}: {exc}")
|
||||
raise WebsiteError("; ".join(errors))
|
||||
last_error = exc
|
||||
logger.info("get_groups fallback %s failed: %s", path, exc)
|
||||
|
||||
msg = str(last_error) if last_error else "拉取分组失败"
|
||||
raise WebsiteError(f"{msg}(尝试接口:{'、'.join(tried_paths)})")
|
||||
|
||||
def update_group_rate(self, endpoint_template: str, group_id: str, rate: Decimal) -> Any:
|
||||
path = endpoint_template.replace("{id}", quote(group_id, safe=""))
|
||||
return self._request("PUT", path, {"rate_multiplier": float(rate)})
|
||||
|
||||
def create_group(self, body: dict[str, Any], endpoint: str = "/groups") -> dict[str, Any]:
|
||||
resp = self._request("POST", endpoint, body)
|
||||
data = _unwrap_data(resp)
|
||||
return data if isinstance(data, dict) else {"value": data}
|
||||
|
||||
def create_account(self, body: dict[str, Any], endpoint: str = "/accounts") -> dict[str, Any]:
|
||||
resp = self._request("POST", endpoint, body)
|
||||
data = _unwrap_data(resp)
|
||||
return data if isinstance(data, dict) else {"value": data}
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_list(value: dict) -> list | None:
|
||||
"""递归展开嵌套的列表包装:data.items、data.data、items、accounts 等。"""
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
if not isinstance(value, dict):
|
||||
return None
|
||||
# 先看顶层
|
||||
for key in ("items", "accounts", "records", "list", "data"):
|
||||
v = value.get(key)
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
# 再看 data.items、data.records、data.list 等嵌套
|
||||
data_val = value.get("data")
|
||||
if isinstance(data_val, dict):
|
||||
for key in ("items", "records", "list", "data", "accounts"):
|
||||
v = data_val.get(key)
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
return None
|
||||
|
||||
def _get_account_ids(self, endpoint: str = "/accounts") -> set[str] | None:
|
||||
"""拉取远端账号列表。成功返回 ID 集合(可能为空),解析失败返回 None。"""
|
||||
try:
|
||||
resp = self._request("GET", endpoint)
|
||||
except Exception:
|
||||
logger.warning("account list fetch failed for %s", endpoint, exc_info=True)
|
||||
return None
|
||||
items = self._unwrap_list(resp)
|
||||
if items is None:
|
||||
logger.warning("account list unexpected format for %s", endpoint)
|
||||
return None
|
||||
ids: set[str] = set()
|
||||
for item in items:
|
||||
item_id = self.extract_id(item)
|
||||
if item_id:
|
||||
ids.add(item_id)
|
||||
return ids
|
||||
|
||||
def account_exists(self, account_id: str, endpoint: str = "/accounts") -> bool | None:
|
||||
"""检查目标账号是否存在。
|
||||
|
||||
优先拉取账号列表判断:
|
||||
- 列表成功取到 → return account_id in ids(True=存在,False=已删除)
|
||||
- 列表取不到(None)→ return None(校验失败,不清本地)
|
||||
返回 True=存在,False=已删除,None=校验失败。
|
||||
"""
|
||||
ids = self._get_account_ids(endpoint)
|
||||
if ids is None:
|
||||
logger.warning("account_exists cannot verify %s: list fetch failed", account_id)
|
||||
return None
|
||||
return account_id in ids
|
||||
|
||||
@staticmethod
|
||||
def extract_id(value: Any) -> str:
|
||||
return _extract_id(value)
|
||||
|
||||
Reference in New Issue
Block a user