feat: support real browser auth import
This commit is contained in:
@@ -10,6 +10,11 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.services.auth_capture_service import extract_all
|
||||
from app.services.browser_import_service import (
|
||||
ImportSessionError,
|
||||
browser_imports,
|
||||
build_import_result,
|
||||
)
|
||||
from app.services.browser_session_service import (
|
||||
BrowserDependencyError,
|
||||
BrowserSessionError,
|
||||
@@ -43,6 +48,32 @@ class CaptureExtractResponse(BaseModel):
|
||||
candidates: list[dict] = []
|
||||
|
||||
|
||||
class ImportSessionCreate(BaseModel):
|
||||
target_url: str = Field(..., description="Target page URL opened in the user's real browser")
|
||||
|
||||
|
||||
class ImportSessionCreateResponse(BaseModel):
|
||||
session_id: str
|
||||
secret: str
|
||||
expires_in_seconds: int
|
||||
|
||||
|
||||
class ImportSessionStatusResponse(BaseModel):
|
||||
session_id: str
|
||||
ready: bool = False
|
||||
expires_at: float
|
||||
result: Optional[CaptureExtractResponse] = None
|
||||
|
||||
|
||||
class ImportSessionSubmit(BaseModel):
|
||||
secret: str
|
||||
page_url: str = ""
|
||||
cookies: list[dict] = []
|
||||
local_storage: dict[str, Any] = {}
|
||||
session_storage: dict[str, Any] = {}
|
||||
auth_headers: list[dict] = []
|
||||
|
||||
|
||||
def _sanitize_candidate(candidate: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
key: value
|
||||
@@ -134,3 +165,65 @@ async def close_capture_session(
|
||||
await browser_sessions.close(session_id)
|
||||
except Exception as exc:
|
||||
raise _browser_error(exc)
|
||||
|
||||
|
||||
@router.post("/import-sessions", response_model=ImportSessionCreateResponse, status_code=201)
|
||||
async def create_import_session(
|
||||
body: ImportSessionCreate,
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Create a one-time real-browser import session.
|
||||
|
||||
The returned secret is intended for the local browser extension only.
|
||||
"""
|
||||
target_url = body.target_url.strip()
|
||||
if not target_url.startswith(("http://", "https://")):
|
||||
raise HTTPException(400, "Only http/https URLs are allowed")
|
||||
session, secret = browser_imports.create(target_url=target_url, created_by=user.email)
|
||||
return ImportSessionCreateResponse(
|
||||
session_id=session.id,
|
||||
secret=secret,
|
||||
expires_in_seconds=600,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/import-sessions/{session_id}", response_model=ImportSessionStatusResponse)
|
||||
async def get_import_session(
|
||||
session_id: str,
|
||||
include_raw: bool = Query(default=False),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
session = browser_imports.get(session_id, created_by=user.email)
|
||||
except ImportSessionError as exc:
|
||||
raise HTTPException(404, str(exc))
|
||||
|
||||
result = None
|
||||
if session.payload is not None:
|
||||
full_result = build_import_result(session.payload)
|
||||
if not include_raw:
|
||||
candidates = [_sanitize_candidate(candidate) for candidate in full_result.get("candidates", [])]
|
||||
result = CaptureExtractResponse(candidates=candidates)
|
||||
else:
|
||||
result = CaptureExtractResponse(**full_result)
|
||||
return ImportSessionStatusResponse(
|
||||
session_id=session.id,
|
||||
ready=session.payload is not None,
|
||||
expires_at=session.expires_at,
|
||||
result=result,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/import-sessions/{session_id}/submit", status_code=204)
|
||||
async def submit_import_session(
|
||||
session_id: str,
|
||||
body: ImportSessionSubmit,
|
||||
):
|
||||
try:
|
||||
browser_imports.submit(
|
||||
session_id=session_id,
|
||||
secret=body.secret,
|
||||
payload=body.model_dump(exclude={"secret"}),
|
||||
)
|
||||
except ImportSessionError as exc:
|
||||
raise HTTPException(400, str(exc))
|
||||
|
||||
@@ -221,25 +221,45 @@ class RefreshAuthResponse(BaseModel):
|
||||
warning: Optional[str] = None
|
||||
|
||||
|
||||
def _pick_best_candidate(candidates: list[dict], preferred_auth_type: str) -> Optional[dict]:
|
||||
def _norm_path(value: Any) -> str:
|
||||
return str(value or "").strip().rstrip("/")
|
||||
|
||||
|
||||
def _detect_upstream_platform(upstream: Upstream, auth_config: dict) -> str:
|
||||
api_prefix = _norm_path(upstream.api_prefix)
|
||||
groups_endpoint = _norm_path(upstream.groups_endpoint)
|
||||
rate_endpoint = _norm_path(upstream.rate_endpoint)
|
||||
login_path = _norm_path(auth_config.get("login_path"))
|
||||
|
||||
if groups_endpoint == "/api/user/self/groups" or login_path == "/api/user/login":
|
||||
return "new-api-user"
|
||||
if api_prefix == "/api/v1" or groups_endpoint in {"/groups/available", "/groups/rates"} or login_path == "/auth/login":
|
||||
return "sub2api"
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _first_candidate(candidates: list[dict], *types: str) -> Optional[dict]:
|
||||
for c in candidates:
|
||||
if c.get("type") in types:
|
||||
return c
|
||||
return None
|
||||
|
||||
|
||||
def _pick_best_candidate(candidates: list[dict], preferred_auth_type: str, platform: str = "unknown") -> Optional[dict]:
|
||||
if not candidates:
|
||||
return None
|
||||
# cookie_bundle > cookie > bearer_token > api_key
|
||||
# preferred_auth_type="cookie" 时优先匹配 bundle,其次单 cookie
|
||||
|
||||
if platform == "sub2api":
|
||||
return _first_candidate(candidates, "bearer_token", "api_key")
|
||||
if platform == "new-api-user":
|
||||
return _first_candidate(candidates, "cookie_bundle", "cookie", "bearer_token", "api_key")
|
||||
if preferred_auth_type == "cookie":
|
||||
for c in candidates:
|
||||
if c["type"] == "cookie_bundle":
|
||||
return c
|
||||
for c in candidates:
|
||||
if c["type"] == "cookie":
|
||||
return c
|
||||
return _first_candidate(candidates, "cookie_bundle", "cookie")
|
||||
elif preferred_auth_type in ("bearer", "api_key"):
|
||||
type_map = {"bearer": "bearer_token", "api_key": "api_key"}
|
||||
preferred = type_map.get(preferred_auth_type)
|
||||
if preferred:
|
||||
for c in candidates:
|
||||
if c["type"] == preferred:
|
||||
return c
|
||||
return _first_candidate(candidates, preferred)
|
||||
# fallback:排序后取第一个
|
||||
return candidates[0]
|
||||
|
||||
@@ -268,11 +288,17 @@ async def refresh_auth(pid: int, db: Session = Depends(get_db), _=Depends(get_cu
|
||||
return RefreshAuthResponse(success=False, message=f"提取失败: {exc}")
|
||||
|
||||
candidates = result.get("candidates", [])
|
||||
candidate = _pick_best_candidate(candidates, upstream.auth_type)
|
||||
existing_config = _json.loads(upstream.auth_config_json or "{}")
|
||||
platform = _detect_upstream_platform(upstream, existing_config)
|
||||
candidate = _pick_best_candidate(candidates, upstream.auth_type, platform)
|
||||
if not candidate:
|
||||
if platform == "sub2api" and _first_candidate(candidates, "cookie_bundle", "cookie"):
|
||||
return RefreshAuthResponse(
|
||||
success=False,
|
||||
message="Sub2API 需要 Bearer Token;当前只提取到 Cookie。请在远程浏览器完成登录后刷新页面或触发一次接口请求,再重新提取。",
|
||||
)
|
||||
return RefreshAuthResponse(success=False, message="未提取到有效凭证,请确认已在远程浏览器中登录")
|
||||
|
||||
existing_config = _json.loads(upstream.auth_config_json or "{}")
|
||||
ctype = candidate["type"]
|
||||
|
||||
if ctype in ("cookie_bundle", "cookie"):
|
||||
@@ -281,6 +307,10 @@ async def refresh_auth(pid: int, db: Session = Depends(get_db), _=Depends(get_cu
|
||||
existing_config["cookie_string"] = candidate.get("value", "")
|
||||
if candidate.get("new_api_user"):
|
||||
existing_config["new_api_user"] = candidate["new_api_user"]
|
||||
if platform == "new-api-user":
|
||||
upstream.api_prefix = ""
|
||||
upstream.groups_endpoint = "/api/user/self/groups"
|
||||
upstream.rate_endpoint = "/api/user/self/groups"
|
||||
elif ctype == "bearer_token":
|
||||
upstream.auth_type = "bearer"
|
||||
raw = candidate.get("value", "")
|
||||
|
||||
@@ -315,6 +315,26 @@ def list_generated_keys(uid: int, db: Session = Depends(get_db), _=Depends(get_c
|
||||
_generate_key_lock = __import__("threading").Lock()
|
||||
|
||||
|
||||
def _is_sub2api_upstream(upstream: Upstream) -> bool:
|
||||
return upstream.api_prefix.strip("/") == "api/v1"
|
||||
|
||||
|
||||
def _is_new_api_user_upstream(upstream: Upstream) -> bool:
|
||||
auth_config = json.loads(upstream.auth_config_json or "{}")
|
||||
return (
|
||||
upstream.api_prefix.strip("/") == ""
|
||||
and (
|
||||
upstream.groups_endpoint == "/api/user/self/groups"
|
||||
or auth_config.get("login_path") == "/api/user/login"
|
||||
or bool(auth_config.get("new_api_user"))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _supports_key_generation(upstream: Upstream) -> bool:
|
||||
return _is_sub2api_upstream(upstream) or _is_new_api_user_upstream(upstream)
|
||||
|
||||
|
||||
def _ensure_group_key(
|
||||
db: Session,
|
||||
client: UpstreamClient,
|
||||
@@ -465,8 +485,8 @@ def generate_keys_by_groups(
|
||||
u = db.query(Upstream).filter(Upstream.id == uid).first()
|
||||
if not u:
|
||||
raise HTTPException(404, "upstream not found")
|
||||
if u.api_prefix.strip("/") != "api/v1":
|
||||
raise HTTPException(400, "首版仅支持 Sub2API 上游(API Prefix 应为 /api/v1)")
|
||||
if not _supports_key_generation(u):
|
||||
raise HTTPException(400, "仅支持 Sub2API 或 New-API 普通账号上游生成 Key")
|
||||
|
||||
# 生成前先对账,清理远端已删除的旧 Key
|
||||
try:
|
||||
|
||||
@@ -137,9 +137,12 @@ def _cookie_matches_hostname(cookie_domain: str, hostname: str) -> bool:
|
||||
"""判断 cookie domain 是否适用于给定 hostname。
|
||||
|
||||
支持带点前缀的 domain(如 `.saki.lat` 匹配 `api.saki.lat`)。
|
||||
注意:hostname 为空时,调用方应跳过 cookie 收集而不是调用此函数。
|
||||
"""
|
||||
if not cookie_domain or not hostname:
|
||||
return True # 无 domain 限制时视为全域
|
||||
if not cookie_domain:
|
||||
return True # 无 domain 限制的 cookie 对当前域有效
|
||||
if not hostname:
|
||||
return False # 无法确定当前域,保守拒绝
|
||||
# 去掉前缀点
|
||||
domain = cookie_domain.lstrip(".")
|
||||
return hostname == domain or hostname.endswith("." + domain)
|
||||
@@ -153,14 +156,22 @@ def _build_cookie_bundle(
|
||||
|
||||
返回 (cookie_string, cookie_names_list)。
|
||||
cookie_string 格式:name1=value1; name2=value2; ...
|
||||
过滤掉空值 cookie。
|
||||
过滤掉空值 cookie。若 page_url 为空或无法解析 hostname,返回空结果
|
||||
(不收集全域 cookie 以防误写入无关域凭证)。
|
||||
"""
|
||||
if not page_url:
|
||||
logger.debug("_build_cookie_bundle: no page_url, skipping cookie collection")
|
||||
return "", []
|
||||
|
||||
hostname = ""
|
||||
if page_url:
|
||||
try:
|
||||
hostname = urlparse(page_url).hostname or ""
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
hostname = urlparse(page_url).hostname or ""
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not hostname:
|
||||
logger.debug("_build_cookie_bundle: cannot parse hostname from %s, skipping", page_url[:80])
|
||||
return "", []
|
||||
|
||||
parts: list[str] = []
|
||||
names: list[str] = []
|
||||
@@ -170,7 +181,7 @@ def _build_cookie_bundle(
|
||||
domain = c.get("domain", "")
|
||||
if not name or not value:
|
||||
continue
|
||||
if hostname and not _cookie_matches_hostname(domain, hostname):
|
||||
if not _cookie_matches_hostname(domain, hostname):
|
||||
continue
|
||||
parts.append(f"{name}={value}")
|
||||
names.append(name)
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
"""One-time credential import sessions for real-browser auth capture."""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.services.auth_capture_service import _curate_candidates, _find_new_api_user
|
||||
|
||||
|
||||
IMPORT_SESSION_TTL_SECONDS = 600
|
||||
|
||||
|
||||
class ImportSessionError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrowserImportSession:
|
||||
id: str
|
||||
secret_hash: str
|
||||
target_url: str
|
||||
created_by: str
|
||||
expires_at: float
|
||||
payload: dict[str, Any] | None = None
|
||||
consumed: bool = False
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
def _hash_secret(secret: str) -> str:
|
||||
return hashlib.sha256(secret.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _normalize_storage(value: Any) -> dict[str, str]:
|
||||
if not isinstance(value, dict):
|
||||
return {}
|
||||
result: dict[str, str] = {}
|
||||
for key, item in value.items():
|
||||
if item is None:
|
||||
continue
|
||||
result[str(key)] = item if isinstance(item, str) else str(item)
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_cookies(value: Any) -> list[dict[str, Any]]:
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
result: list[dict[str, Any]] = []
|
||||
for item in value:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
name = str(item.get("name") or "").strip()
|
||||
cookie_value = str(item.get("value") or "")
|
||||
if not name or not cookie_value:
|
||||
continue
|
||||
result.append({
|
||||
"name": name,
|
||||
"value": cookie_value,
|
||||
"domain": str(item.get("domain") or ""),
|
||||
"path": str(item.get("path") or "/"),
|
||||
"httpOnly": bool(item.get("httpOnly", False)),
|
||||
"secure": bool(item.get("secure", False)),
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_headers(value: Any) -> list[dict[str, str]]:
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
result: list[dict[str, str]] = []
|
||||
for item in value:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
header_value = str(item.get("value") or "").strip()
|
||||
if not header_value:
|
||||
continue
|
||||
result.append({
|
||||
"type": str(item.get("type") or "authorization"),
|
||||
"value": header_value,
|
||||
"url": str(item.get("url") or ""),
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
def build_import_result(payload: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Convert extension-submitted payload into auth-capture result shape."""
|
||||
page_url = str(payload.get("page_url") or payload.get("url") or "")
|
||||
cookies = _normalize_cookies(payload.get("cookies"))
|
||||
storage = _normalize_storage(payload.get("local_storage") or payload.get("storage"))
|
||||
session_storage = _normalize_storage(payload.get("session_storage"))
|
||||
auth_headers = _normalize_headers(payload.get("auth_headers"))
|
||||
new_api_user = _find_new_api_user(storage, session_storage)
|
||||
candidates = _curate_candidates(
|
||||
cookies=cookies,
|
||||
local_storage=storage,
|
||||
session_storage=session_storage,
|
||||
auth_headers=auth_headers,
|
||||
new_api_user=new_api_user,
|
||||
page_url=page_url,
|
||||
)
|
||||
return {
|
||||
"cookies": cookies,
|
||||
"storage": storage,
|
||||
"session_storage": session_storage,
|
||||
"auth_headers": auth_headers,
|
||||
"candidates": candidates,
|
||||
}
|
||||
|
||||
|
||||
class BrowserImportService:
|
||||
def __init__(self) -> None:
|
||||
self._sessions: dict[str, BrowserImportSession] = {}
|
||||
|
||||
def create(self, target_url: str, created_by: str) -> tuple[BrowserImportSession, str]:
|
||||
self.cleanup()
|
||||
session_id = secrets.token_urlsafe(18)
|
||||
secret = secrets.token_urlsafe(24)
|
||||
session = BrowserImportSession(
|
||||
id=session_id,
|
||||
secret_hash=_hash_secret(secret),
|
||||
target_url=target_url,
|
||||
created_by=created_by,
|
||||
expires_at=time.time() + IMPORT_SESSION_TTL_SECONDS,
|
||||
)
|
||||
self._sessions[session_id] = session
|
||||
return session, secret
|
||||
|
||||
def get(self, session_id: str, created_by: str | None = None) -> BrowserImportSession:
|
||||
self.cleanup()
|
||||
session = self._sessions.get(session_id)
|
||||
if not session:
|
||||
raise ImportSessionError("import session not found")
|
||||
if created_by is not None and session.created_by != created_by:
|
||||
raise ImportSessionError("import session not found")
|
||||
if session.expires_at <= time.time():
|
||||
self._sessions.pop(session_id, None)
|
||||
raise ImportSessionError("import session expired")
|
||||
return session
|
||||
|
||||
def submit(self, session_id: str, secret: str, payload: dict[str, Any]) -> BrowserImportSession:
|
||||
session = self.get(session_id)
|
||||
if session.consumed:
|
||||
raise ImportSessionError("import session already consumed")
|
||||
if not secrets.compare_digest(session.secret_hash, _hash_secret(secret)):
|
||||
raise ImportSessionError("invalid import secret")
|
||||
session.payload = payload
|
||||
session.consumed = True
|
||||
return session
|
||||
|
||||
def cleanup(self) -> None:
|
||||
now = time.time()
|
||||
expired = [sid for sid, session in self._sessions.items() if session.expires_at <= now]
|
||||
for sid in expired:
|
||||
self._sessions.pop(sid, None)
|
||||
|
||||
|
||||
browser_imports = BrowserImportService()
|
||||
@@ -68,6 +68,34 @@ class BrowserSessionService:
|
||||
self._last_event_at: dict[str, float] = {}
|
||||
self._evict_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
def _browser_launch_kwargs(self, width: int, height: int) -> dict[str, Any]:
|
||||
return {
|
||||
"headless": get_settings().browser_headless,
|
||||
"viewport": {"width": width, "height": height},
|
||||
"color_scheme": "dark",
|
||||
"locale": "zh-CN",
|
||||
"timezone_id": get_settings().tz,
|
||||
"ignore_default_args": ["--enable-automation"],
|
||||
"args": [
|
||||
"--no-sandbox",
|
||||
"--disable-dev-shm-usage",
|
||||
"--disable-blink-features=AutomationControlled",
|
||||
"--window-size=%d,%d" % (width, height),
|
||||
],
|
||||
}
|
||||
|
||||
async def _install_browser_init_scripts(self, context: Any) -> None:
|
||||
await context.add_init_script("""
|
||||
(() => {
|
||||
try {
|
||||
Object.defineProperty(navigator, 'webdriver', { get: () => undefined });
|
||||
Object.defineProperty(navigator, 'languages', { get: () => ['zh-CN', 'zh', 'en-US', 'en'] });
|
||||
Object.defineProperty(navigator, 'plugins', { get: () => [1, 2, 3, 4, 5] });
|
||||
window.chrome = window.chrome || { runtime: {} };
|
||||
} catch (_) {}
|
||||
})();
|
||||
""")
|
||||
|
||||
async def create(
|
||||
self,
|
||||
custom_page_id: int,
|
||||
@@ -113,11 +141,9 @@ class BrowserSessionService:
|
||||
|
||||
context = await self._playwright.chromium.launch_persistent_context(
|
||||
str(self._profile_dir(profile_key)),
|
||||
headless=get_settings().browser_headless,
|
||||
viewport={"width": width, "height": height},
|
||||
color_scheme="dark",
|
||||
args=["--no-sandbox", "--disable-dev-shm-usage"],
|
||||
**self._browser_launch_kwargs(width, height),
|
||||
)
|
||||
await self._install_browser_init_scripts(context)
|
||||
await self._restore_session_state(context, profile_key)
|
||||
# Grant clipboard access for the page origin
|
||||
try:
|
||||
@@ -773,11 +799,9 @@ class BrowserSessionService:
|
||||
profile_key = f"auth-capture-{session_id[:12]}"
|
||||
context = await self._playwright.chromium.launch_persistent_context(
|
||||
str(self._profile_dir(profile_key)),
|
||||
headless=get_settings().browser_headless,
|
||||
viewport={"width": width, "height": height},
|
||||
color_scheme="dark",
|
||||
args=["--no-sandbox", "--disable-dev-shm-usage"],
|
||||
**self._browser_launch_kwargs(width, height),
|
||||
)
|
||||
await self._install_browser_init_scripts(context)
|
||||
# Grant clipboard access for the page origin
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urljoin
|
||||
import httpx
|
||||
@@ -13,6 +14,9 @@ class UpstreamError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
NEW_API_DEFAULT_QUOTA_PER_UNIT = 500000
|
||||
|
||||
|
||||
def _find_token(value: Any) -> str:
|
||||
if isinstance(value, str) and value.count(".") >= 2:
|
||||
return value
|
||||
@@ -74,6 +78,8 @@ def mask_secret(value: Any) -> str:
|
||||
def _unwrap_data(value: Any) -> Any:
|
||||
if isinstance(value, dict) and "data" in value and ("code" in value or "message" in value):
|
||||
return value.get("data")
|
||||
if isinstance(value, dict) and "data" in value and "success" in value:
|
||||
return value.get("data")
|
||||
return value
|
||||
|
||||
|
||||
@@ -105,6 +111,20 @@ def _extract_key_value(value: Any) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _is_success_response(value: Any) -> bool:
|
||||
if not isinstance(value, dict) or "success" not in value:
|
||||
return True
|
||||
return value.get("success") is True
|
||||
|
||||
|
||||
def _response_message(value: Any, fallback: str = "") -> str:
|
||||
if isinstance(value, dict):
|
||||
msg = value.get("message") or value.get("detail")
|
||||
if msg:
|
||||
return str(msg)
|
||||
return fallback
|
||||
|
||||
|
||||
def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]:
|
||||
def _normalize(lst: list) -> list[dict[str, Any]]:
|
||||
out = []
|
||||
@@ -288,6 +308,17 @@ class UpstreamClient:
|
||||
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
|
||||
return f"{self.base_url}{prefix}/{path.lstrip('/')}"
|
||||
|
||||
def _is_new_api_user_mode(self) -> bool:
|
||||
login_path = str(self.auth_config.get("login_path") or "")
|
||||
return (
|
||||
self.api_prefix == ""
|
||||
and (
|
||||
bool(self.auth_config.get("new_api_user"))
|
||||
or login_path == "/api/user/login"
|
||||
or self.auth_type == "cookie"
|
||||
)
|
||||
)
|
||||
|
||||
def _headers(self, auth: bool = True) -> dict[str, str]:
|
||||
headers: dict[str, str] = {
|
||||
"Accept": "application/json",
|
||||
@@ -348,6 +379,119 @@ class UpstreamClient:
|
||||
raise UpstreamError(f"{method} {path} returned HTML, not JSON")
|
||||
return resp.json()
|
||||
|
||||
def _ensure_api_success(self, payload: Any, action: str) -> None:
|
||||
if not _is_success_response(payload):
|
||||
raise UpstreamError(_response_message(payload, f"{action} failed"))
|
||||
|
||||
def _new_api_quota_per_unit(self) -> int:
|
||||
try:
|
||||
payload = self._request("GET", "/api/status", auth=False)
|
||||
data = _unwrap_data(payload)
|
||||
if isinstance(data, dict):
|
||||
value = data.get("quota_per_unit")
|
||||
quota_per_unit = int(float(value))
|
||||
if quota_per_unit > 0:
|
||||
return quota_per_unit
|
||||
except Exception:
|
||||
pass
|
||||
return NEW_API_DEFAULT_QUOTA_PER_UNIT
|
||||
|
||||
@staticmethod
|
||||
def _normalize_key_record(record: dict[str, Any]) -> dict[str, Any]:
|
||||
out = dict(record)
|
||||
if not out.get("name") and out.get("key_name"):
|
||||
out["name"] = out.get("key_name")
|
||||
if not out.get("group_id") and out.get("group"):
|
||||
out["group_id"] = str(out.get("group"))
|
||||
if not out.get("group_name") and out.get("group"):
|
||||
out["group_name"] = str(out.get("group"))
|
||||
return out
|
||||
|
||||
def _list_new_api_tokens(
|
||||
self,
|
||||
search: str = "",
|
||||
group_id: str | int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
if search:
|
||||
path = "/api/token/search"
|
||||
params = {"keyword": search, "token": "", "p": 1, "size": 100}
|
||||
else:
|
||||
path = "/api/token/"
|
||||
params = {"p": 1, "size": 100}
|
||||
resp = self._client.request(
|
||||
"GET",
|
||||
self._url(path),
|
||||
params=params,
|
||||
headers=self._headers(),
|
||||
cookies=self._cookies,
|
||||
)
|
||||
self._cookies.update(dict(resp.cookies))
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
self._ensure_api_success(data, "list New-API tokens")
|
||||
nested = _unwrap_data(data)
|
||||
items: list[dict[str, Any]] | None = None
|
||||
if isinstance(nested, list):
|
||||
items = [i for i in nested if isinstance(i, dict)]
|
||||
elif isinstance(nested, dict):
|
||||
for key in ("items", "tokens", "list", "records"):
|
||||
value = nested.get(key)
|
||||
if isinstance(value, list):
|
||||
items = [i for i in value if isinstance(i, dict)]
|
||||
break
|
||||
if items is None:
|
||||
raise UpstreamError("unexpected New-API token list response")
|
||||
normalized = [self._normalize_key_record(i) for i in items]
|
||||
if group_id is not None:
|
||||
gid = str(group_id)
|
||||
normalized = [i for i in normalized if str(i.get("group_id") or i.get("group") or "") == gid]
|
||||
return normalized
|
||||
|
||||
def _get_new_api_token_key(self, token_id: str | int) -> str:
|
||||
payload = self._request("POST", f"/api/token/{token_id}/key")
|
||||
self._ensure_api_success(payload, "get New-API token key")
|
||||
key_value = _extract_key_value(_unwrap_data(payload))
|
||||
if not key_value:
|
||||
raise UpstreamError("New-API token key response did not include key")
|
||||
return key_value
|
||||
|
||||
def _create_new_api_token(
|
||||
self,
|
||||
name: str,
|
||||
group_id: str | int,
|
||||
quota: float = 0,
|
||||
expires_in_days: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
unlimited = quota <= 0
|
||||
body: dict[str, Any] = {
|
||||
"name": name,
|
||||
"remain_quota": 0 if unlimited else int(round(quota * self._new_api_quota_per_unit())),
|
||||
"unlimited_quota": unlimited,
|
||||
"expired_time": int(time.time()) + expires_in_days * 86400 if expires_in_days else -1,
|
||||
"model_limits_enabled": False,
|
||||
"model_limits": "",
|
||||
"allow_ips": "",
|
||||
"group": str(group_id),
|
||||
"cross_group_retry": False,
|
||||
}
|
||||
payload = self._request("POST", "/api/token/", body)
|
||||
self._ensure_api_success(payload, "create New-API token")
|
||||
|
||||
matches = self._list_new_api_tokens(search=name, group_id=group_id)
|
||||
token = next((i for i in matches if str(i.get("name") or "").strip() == name.strip()), None)
|
||||
if not token:
|
||||
raise UpstreamError("New-API token was created but could not be found by name")
|
||||
token_id = token.get("id")
|
||||
if token_id is None:
|
||||
raise UpstreamError("New-API token list response did not include id")
|
||||
key_value = self._get_new_api_token_key(token_id)
|
||||
return {
|
||||
"id": str(token_id),
|
||||
"key": key_value,
|
||||
"masked_key": mask_secret(key_value),
|
||||
"raw": self._normalize_key_record(token),
|
||||
}
|
||||
|
||||
def login(self) -> None:
|
||||
if self.auth_type != "login_password":
|
||||
return
|
||||
@@ -412,6 +556,9 @@ class UpstreamClient:
|
||||
endpoint: str = "/keys",
|
||||
) -> list[dict[str, Any]]:
|
||||
"""查询远端上游 Key 列表,支持按名称搜索、分组筛选、状态筛选。"""
|
||||
if endpoint in {"/api/token", "/api/token/"} or (endpoint == "/keys" and self._is_new_api_user_mode()):
|
||||
return self._list_new_api_tokens(search=search, group_id=group_id)
|
||||
|
||||
params: dict[str, Any] = {}
|
||||
if search:
|
||||
params["search"] = search
|
||||
@@ -446,7 +593,7 @@ class UpstreamClient:
|
||||
for key in ("items", "keys", "list", "records"):
|
||||
val = data.get(key)
|
||||
if isinstance(val, list):
|
||||
return val
|
||||
return [self._normalize_key_record(i) for i in val if isinstance(i, dict)]
|
||||
raise UpstreamError(f"unexpected keys response type: {type(data).__name__}")
|
||||
|
||||
def delete_api_key(self, key_id: str, endpoint: str = "/keys") -> None:
|
||||
@@ -486,6 +633,14 @@ class UpstreamClient:
|
||||
rate_limit_7d: float = 0,
|
||||
endpoint: str = "/keys",
|
||||
) -> dict[str, Any]:
|
||||
if endpoint in {"/api/token", "/api/token/"} or (endpoint == "/keys" and self._is_new_api_user_mode()):
|
||||
return self._create_new_api_token(
|
||||
name,
|
||||
group_id,
|
||||
quota=quota,
|
||||
expires_in_days=expires_in_days,
|
||||
)
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"name": name,
|
||||
"group_id": int(group_id) if str(group_id).isdigit() else group_id,
|
||||
|
||||
Reference in New Issue
Block a user