feat: support real browser auth import

This commit is contained in:
liumangmang
2026-06-02 13:51:29 +08:00
parent f4d16a4c01
commit 84148f4a69
22 changed files with 1651 additions and 111 deletions
+93
View File
@@ -10,6 +10,11 @@ from sqlalchemy.orm import Session
from app.database import get_db
from app.services.auth_capture_service import extract_all
from app.services.browser_import_service import (
ImportSessionError,
browser_imports,
build_import_result,
)
from app.services.browser_session_service import (
BrowserDependencyError,
BrowserSessionError,
@@ -43,6 +48,32 @@ class CaptureExtractResponse(BaseModel):
candidates: list[dict] = []
class ImportSessionCreate(BaseModel):
    """Request body for creating a real-browser import session."""

    # Required. The create endpoint rejects values that are not http/https URLs.
    target_url: str = Field(..., description="Target page URL opened in the user's real browser")
class ImportSessionCreateResponse(BaseModel):
    """Response returned once an import session has been created."""

    session_id: str
    # Plaintext one-time secret; the import service stores only its SHA-256 hash.
    secret: str
    # Lifetime of the session, in seconds.
    expires_in_seconds: int
class ImportSessionStatusResponse(BaseModel):
    """Status of an import session as returned by the polling endpoint."""

    session_id: str
    # True once a payload has been submitted for this session.
    ready: bool = False
    # Expiry as a Unix timestamp in seconds (derived from time.time()).
    expires_at: float
    # Extraction result; remains None until a payload is available.
    result: Optional[CaptureExtractResponse] = None
class ImportSessionSubmit(BaseModel):
    """Payload posted to the submit endpoint of an import session."""

    # One-time secret issued at creation; the submit endpoint excludes it
    # from the payload that gets stored on the session.
    secret: str
    page_url: str = ""
    cookies: list[dict] = []
    local_storage: dict[str, Any] = {}
    session_storage: dict[str, Any] = {}
    # The import service reads the "type", "value" and "url" keys of each entry.
    auth_headers: list[dict] = []
def _sanitize_candidate(candidate: dict[str, Any]) -> dict[str, Any]:
return {
key: value
@@ -134,3 +165,65 @@ async def close_capture_session(
await browser_sessions.close(session_id)
except Exception as exc:
raise _browser_error(exc)
@router.post("/import-sessions", response_model=ImportSessionCreateResponse, status_code=201)
async def create_import_session(
    body: ImportSessionCreate,
    user=Depends(get_current_user),
):
    """Create a one-time real-browser import session.

    The returned secret is intended for the local browser extension only.

    Returns:
        The new session id, its plaintext secret and the number of seconds
        left before the session expires.

    Raises:
        HTTPException: 400 when ``target_url`` is not an http/https URL.
    """
    # Function-scope import so this handler does not depend on a module-level
    # ``import time`` being present in the file.
    import time

    target_url = body.target_url.strip()
    # URL schemes are case-insensitive, so "HTTPS://..." must be accepted too.
    if not target_url.lower().startswith(("http://", "https://")):
        raise HTTPException(400, "Only http/https URLs are allowed")

    session, secret = browser_imports.create(target_url=target_url, created_by=user.email)
    # Derive the lifetime from the session itself instead of repeating the
    # service's TTL as a literal, so the two can never drift apart.
    remaining_seconds = max(0, round(session.expires_at - time.time()))
    return ImportSessionCreateResponse(
        session_id=session.id,
        secret=secret,
        expires_in_seconds=remaining_seconds,
    )
@router.get("/import-sessions/{session_id}", response_model=ImportSessionStatusResponse)
async def get_import_session(
    session_id: str,
    include_raw: bool = Query(default=False),
    user=Depends(get_current_user),
):
    """Return the status of an import session owned by the current user.

    While no payload has been submitted ``result`` is ``None``. Afterwards the
    payload is converted with ``build_import_result``; with ``include_raw``
    the full result is returned, otherwise only the candidates, each passed
    through ``_sanitize_candidate``.

    Raises:
        HTTPException: 404 when the import service reports an
            ``ImportSessionError`` for this id / user.
    """
    try:
        session = browser_imports.get(session_id, created_by=user.email)
    except ImportSessionError as exc:
        raise HTTPException(404, str(exc))

    has_payload = session.payload is not None
    result = None
    if has_payload:
        extracted = build_import_result(session.payload)
        if include_raw:
            result = CaptureExtractResponse(**extracted)
        else:
            safe_candidates = [
                _sanitize_candidate(entry) for entry in extracted.get("candidates", [])
            ]
            result = CaptureExtractResponse(candidates=safe_candidates)

    return ImportSessionStatusResponse(
        session_id=session.id,
        ready=has_payload,
        expires_at=session.expires_at,
        result=result,
    )
@router.post("/import-sessions/{session_id}/submit", status_code=204)
async def submit_import_session(
    session_id: str,
    body: ImportSessionSubmit,
):
    """Store the captured browser data on an import session.

    This route declares no user dependency; the request is authorised by the
    one-time ``secret`` in the body, which the import service verifies. The
    secret itself is left out of the stored payload.

    Raises:
        HTTPException: 400 for any ``ImportSessionError`` raised by the service.
    """
    captured = body.model_dump(exclude={"secret"})
    try:
        browser_imports.submit(session_id=session_id, secret=body.secret, payload=captured)
    except ImportSessionError as exc:
        raise HTTPException(400, str(exc))
+44 -14
View File
@@ -221,25 +221,45 @@ class RefreshAuthResponse(BaseModel):
warning: Optional[str] = None
def _pick_best_candidate(candidates: list[dict], preferred_auth_type: str) -> Optional[dict]:
def _norm_path(value: Any) -> str:
return str(value or "").strip().rstrip("/")
def _detect_upstream_platform(upstream: Upstream, auth_config: dict) -> str:
    """Classify an upstream as ``"new-api-user"``, ``"sub2api"`` or ``"unknown"``.

    The decision is based on the upstream's API prefix and groups endpoint
    together with the ``login_path`` entry of ``auth_config``. The New-API
    checks run first, so they win when both sets of markers are present.
    """
    api_prefix = _norm_path(upstream.api_prefix)
    groups_endpoint = _norm_path(upstream.groups_endpoint)
    login_path = _norm_path(auth_config.get("login_path"))

    if groups_endpoint == "/api/user/self/groups" or login_path == "/api/user/login":
        return "new-api-user"

    # Ignore the leading slash as well, so "api/v1" and "/api/v1" are treated
    # alike (matching the prefix check used for key generation).
    is_v1_prefix = api_prefix.lstrip("/") == "api/v1"
    if is_v1_prefix or groups_endpoint in {"/groups/available", "/groups/rates"} or login_path == "/auth/login":
        return "sub2api"
    return "unknown"
def _first_candidate(candidates: list[dict], *types: str) -> Optional[dict]:
for c in candidates:
if c.get("type") in types:
return c
return None
def _pick_best_candidate(candidates: list[dict], preferred_auth_type: str, platform: str = "unknown") -> Optional[dict]:
if not candidates:
return None
# cookie_bundle > cookie > bearer_token > api_key
# preferred_auth_type="cookie" 时优先匹配 bundle,其次单 cookie
if platform == "sub2api":
return _first_candidate(candidates, "bearer_token", "api_key")
if platform == "new-api-user":
return _first_candidate(candidates, "cookie_bundle", "cookie", "bearer_token", "api_key")
if preferred_auth_type == "cookie":
for c in candidates:
if c["type"] == "cookie_bundle":
return c
for c in candidates:
if c["type"] == "cookie":
return c
return _first_candidate(candidates, "cookie_bundle", "cookie")
elif preferred_auth_type in ("bearer", "api_key"):
type_map = {"bearer": "bearer_token", "api_key": "api_key"}
preferred = type_map.get(preferred_auth_type)
if preferred:
for c in candidates:
if c["type"] == preferred:
return c
return _first_candidate(candidates, preferred)
# fallback:排序后取第一个
return candidates[0]
@@ -268,11 +288,17 @@ async def refresh_auth(pid: int, db: Session = Depends(get_db), _=Depends(get_cu
return RefreshAuthResponse(success=False, message=f"提取失败: {exc}")
candidates = result.get("candidates", [])
candidate = _pick_best_candidate(candidates, upstream.auth_type)
existing_config = _json.loads(upstream.auth_config_json or "{}")
platform = _detect_upstream_platform(upstream, existing_config)
candidate = _pick_best_candidate(candidates, upstream.auth_type, platform)
if not candidate:
if platform == "sub2api" and _first_candidate(candidates, "cookie_bundle", "cookie"):
return RefreshAuthResponse(
success=False,
message="Sub2API 需要 Bearer Token;当前只提取到 Cookie。请在远程浏览器完成登录后刷新页面或触发一次接口请求,再重新提取。",
)
return RefreshAuthResponse(success=False, message="未提取到有效凭证,请确认已在远程浏览器中登录")
existing_config = _json.loads(upstream.auth_config_json or "{}")
ctype = candidate["type"]
if ctype in ("cookie_bundle", "cookie"):
@@ -281,6 +307,10 @@ async def refresh_auth(pid: int, db: Session = Depends(get_db), _=Depends(get_cu
existing_config["cookie_string"] = candidate.get("value", "")
if candidate.get("new_api_user"):
existing_config["new_api_user"] = candidate["new_api_user"]
if platform == "new-api-user":
upstream.api_prefix = ""
upstream.groups_endpoint = "/api/user/self/groups"
upstream.rate_endpoint = "/api/user/self/groups"
elif ctype == "bearer_token":
upstream.auth_type = "bearer"
raw = candidate.get("value", "")
+22 -2
View File
@@ -315,6 +315,26 @@ def list_generated_keys(uid: int, db: Session = Depends(get_db), _=Depends(get_c
# Module-level threading.Lock, obtained through __import__ so that no
# top-level ``import threading`` is required.
# NOTE(review): presumably serialises key generation requests — the code that
# acquires it is not visible here; confirm.
_generate_key_lock = __import__("threading").Lock()
def _is_sub2api_upstream(upstream: Upstream) -> bool:
return upstream.api_prefix.strip("/") == "api/v1"
def _is_new_api_user_upstream(upstream: Upstream) -> bool:
auth_config = json.loads(upstream.auth_config_json or "{}")
return (
upstream.api_prefix.strip("/") == ""
and (
upstream.groups_endpoint == "/api/user/self/groups"
or auth_config.get("login_path") == "/api/user/login"
or bool(auth_config.get("new_api_user"))
)
)
def _supports_key_generation(upstream: Upstream) -> bool:
    """Return True for upstreams recognised as Sub2API or New-API user accounts."""
    if _is_sub2api_upstream(upstream):
        return True
    return _is_new_api_user_upstream(upstream)
def _ensure_group_key(
db: Session,
client: UpstreamClient,
@@ -465,8 +485,8 @@ def generate_keys_by_groups(
u = db.query(Upstream).filter(Upstream.id == uid).first()
if not u:
raise HTTPException(404, "upstream not found")
if u.api_prefix.strip("/") != "api/v1":
raise HTTPException(400, "首版仅支持 Sub2API 上游(API Prefix 应为 /api/v1")
if not _supports_key_generation(u):
raise HTTPException(400, "仅支持 Sub2API 或 New-API 普通账号上游生成 Key")
# 生成前先对账,清理远端已删除的旧 Key
try:
+20 -9
View File
@@ -137,9 +137,12 @@ def _cookie_matches_hostname(cookie_domain: str, hostname: str) -> bool:
"""判断 cookie domain 是否适用于给定 hostname。
支持带点前缀的 domain(如 `.saki.lat` 匹配 `api.saki.lat`)。
注意:hostname 为空时,调用方应跳过 cookie 收集而不是调用此函数。
"""
if not cookie_domain or not hostname:
return True # 无 domain 限制时视为全域
if not cookie_domain:
return True # 无 domain 限制的 cookie 对当前域有效
if not hostname:
return False # 无法确定当前域,保守拒绝
# 去掉前缀点
domain = cookie_domain.lstrip(".")
return hostname == domain or hostname.endswith("." + domain)
@@ -153,14 +156,22 @@ def _build_cookie_bundle(
返回 (cookie_string, cookie_names_list)。
cookie_string 格式:name1=value1; name2=value2; ...
过滤掉空值 cookie。
过滤掉空值 cookie。若 page_url 为空或无法解析 hostname,返回空结果
(不收集全域 cookie 以防误写入无关域凭证)。
"""
if not page_url:
logger.debug("_build_cookie_bundle: no page_url, skipping cookie collection")
return "", []
hostname = ""
if page_url:
try:
hostname = urlparse(page_url).hostname or ""
except Exception:
pass
try:
hostname = urlparse(page_url).hostname or ""
except Exception:
pass
if not hostname:
logger.debug("_build_cookie_bundle: cannot parse hostname from %s, skipping", page_url[:80])
return "", []
parts: list[str] = []
names: list[str] = []
@@ -170,7 +181,7 @@ def _build_cookie_bundle(
domain = c.get("domain", "")
if not name or not value:
continue
if hostname and not _cookie_matches_hostname(domain, hostname):
if not _cookie_matches_hostname(domain, hostname):
continue
parts.append(f"{name}={value}")
names.append(name)
@@ -0,0 +1,159 @@
"""One-time credential import sessions for real-browser auth capture."""
from __future__ import annotations
import hashlib
import secrets
import time
from dataclasses import dataclass, field
from typing import Any
from app.services.auth_capture_service import _curate_candidates, _find_new_api_user
# Lifetime of an import session in seconds; added to time.time() when a
# session is created to compute its expires_at.
IMPORT_SESSION_TTL_SECONDS = 600
class ImportSessionError(RuntimeError):
    """Raised when an import session is missing, expired, consumed or given a bad secret."""
@dataclass
class BrowserImportSession:
    """State of one browser import session held by the import service."""

    id: str
    # SHA-256 hex digest of the one-time secret; the plaintext is not stored.
    secret_hash: str
    target_url: str
    # Identifier of the creating user (the router passes the user's email).
    created_by: str
    # Unix timestamp (seconds) after which the session is treated as expired.
    expires_at: float
    # Data submitted for the session; None until a submit succeeds.
    payload: dict[str, Any] | None = None
    # Set to True by a successful submit; further submits are rejected.
    consumed: bool = False
    created_at: float = field(default_factory=time.time)
def _hash_secret(secret: str) -> str:
return hashlib.sha256(secret.encode("utf-8")).hexdigest()
def _normalize_storage(value: Any) -> dict[str, str]:
if not isinstance(value, dict):
return {}
result: dict[str, str] = {}
for key, item in value.items():
if item is None:
continue
result[str(key)] = item if isinstance(item, str) else str(item)
return result
def _normalize_cookies(value: Any) -> list[dict[str, Any]]:
if not isinstance(value, list):
return []
result: list[dict[str, Any]] = []
for item in value:
if not isinstance(item, dict):
continue
name = str(item.get("name") or "").strip()
cookie_value = str(item.get("value") or "")
if not name or not cookie_value:
continue
result.append({
"name": name,
"value": cookie_value,
"domain": str(item.get("domain") or ""),
"path": str(item.get("path") or "/"),
"httpOnly": bool(item.get("httpOnly", False)),
"secure": bool(item.get("secure", False)),
})
return result
def _normalize_headers(value: Any) -> list[dict[str, str]]:
if not isinstance(value, list):
return []
result: list[dict[str, str]] = []
for item in value:
if not isinstance(item, dict):
continue
header_value = str(item.get("value") or "").strip()
if not header_value:
continue
result.append({
"type": str(item.get("type") or "authorization"),
"value": header_value,
"url": str(item.get("url") or ""),
})
return result
def build_import_result(payload: dict[str, Any]) -> dict[str, Any]:
    """Turn an extension-submitted payload into the auth-capture result shape.

    The payload's cookies, storages and auth headers are normalised, then
    handed to ``_curate_candidates`` together with the page URL and the value
    returned by ``_find_new_api_user``. ``url`` and ``storage`` are accepted
    as alternative keys for ``page_url`` and ``local_storage``.

    Returns:
        A dict with the keys ``cookies``, ``storage``, ``session_storage``,
        ``auth_headers`` and ``candidates``.
    """
    page_url = str(payload.get("page_url") or payload.get("url") or "")
    result: dict[str, Any] = {
        "cookies": _normalize_cookies(payload.get("cookies")),
        "storage": _normalize_storage(payload.get("local_storage") or payload.get("storage")),
        "session_storage": _normalize_storage(payload.get("session_storage")),
        "auth_headers": _normalize_headers(payload.get("auth_headers")),
    }
    new_api_user = _find_new_api_user(result["storage"], result["session_storage"])
    result["candidates"] = _curate_candidates(
        cookies=result["cookies"],
        local_storage=result["storage"],
        session_storage=result["session_storage"],
        auth_headers=result["auth_headers"],
        new_api_user=new_api_user,
        page_url=page_url,
    )
    return result
class BrowserImportService:
    """Registry of one-time browser import sessions kept in a dict on the instance.

    Expired sessions are purged whenever a session is created or looked up.
    """

    def __init__(self) -> None:
        self._sessions: dict[str, BrowserImportSession] = {}

    def create(self, target_url: str, created_by: str) -> tuple[BrowserImportSession, str]:
        """Register a new session and return it with its plaintext secret.

        Only the SHA-256 hash of the secret is kept on the session, so the
        returned string is the only copy of the plaintext.
        """
        self.cleanup()
        session_id = secrets.token_urlsafe(18)
        secret = secrets.token_urlsafe(24)
        session = BrowserImportSession(
            id=session_id,
            secret_hash=_hash_secret(secret),
            target_url=target_url,
            created_by=created_by,
            expires_at=time.time() + IMPORT_SESSION_TTL_SECONDS,
        )
        self._sessions[session_id] = session
        return session, secret

    def get(self, session_id: str, created_by: str | None = None) -> BrowserImportSession:
        """Return the live session with ``session_id``.

        When ``created_by`` is given, a session owned by someone else is
        reported exactly like a missing one.

        Raises:
            ImportSessionError: the session is unknown, owned by another
                user, or expired.
        """
        self.cleanup()
        session = self._sessions.get(session_id)
        if not session:
            raise ImportSessionError("import session not found")
        if created_by is not None and session.created_by != created_by:
            raise ImportSessionError("import session not found")
        # cleanup() ran a moment ago; this catches a session that expired
        # between that purge and now.
        if session.expires_at <= time.time():
            self._sessions.pop(session_id, None)
            raise ImportSessionError("import session expired")
        return session

    def submit(self, session_id: str, secret: str, payload: dict[str, Any]) -> BrowserImportSession:
        """Attach ``payload`` to the session and mark it consumed.

        Raises:
            ImportSessionError: the session is missing or expired, the secret
                does not match, or a payload was already submitted.
        """
        session = self.get(session_id)
        # Verify the secret first: a caller without it always gets the same
        # error and cannot probe whether the session has already been used.
        if not secrets.compare_digest(session.secret_hash, _hash_secret(secret)):
            raise ImportSessionError("invalid import secret")
        if session.consumed:
            raise ImportSessionError("import session already consumed")
        session.payload = payload
        session.consumed = True
        return session

    def cleanup(self) -> None:
        """Drop every session whose ``expires_at`` is not in the future."""
        now = time.time()
        expired = [sid for sid, session in self._sessions.items() if session.expires_at <= now]
        for sid in expired:
            self._sessions.pop(sid, None)
# Module-level service instance; callers import this name rather than
# constructing their own, so session state lives in this one object.
browser_imports = BrowserImportService()
@@ -68,6 +68,34 @@ class BrowserSessionService:
self._last_event_at: dict[str, float] = {}
self._evict_task: Optional[asyncio.Task[None]] = None
def _browser_launch_kwargs(self, width: int, height: int) -> dict[str, Any]:
    """Build the keyword arguments shared by the persistent-context launches.

    Headless mode and timezone come from the application settings; the
    viewport and the ``--window-size`` flag both use ``width`` x ``height``.
    """
    headless = get_settings().browser_headless
    timezone_id = get_settings().tz
    chromium_args = [
        "--no-sandbox",
        "--disable-dev-shm-usage",
        "--disable-blink-features=AutomationControlled",
        "--window-size=%d,%d" % (width, height),
    ]
    launch_kwargs: dict[str, Any] = {
        "headless": headless,
        "viewport": {"width": width, "height": height},
        "color_scheme": "dark",
        "locale": "zh-CN",
        "timezone_id": timezone_id,
        "ignore_default_args": ["--enable-automation"],
        "args": chromium_args,
    }
    return launch_kwargs
async def _install_browser_init_scripts(self, context: Any) -> None:
    """Register an init script on ``context``.

    The script overrides the ``navigator.webdriver``, ``navigator.languages``
    and ``navigator.plugins`` getters and defines ``window.chrome`` when it is
    missing. Everything runs inside a try/catch, so a failure in the page is
    ignored by the script itself.
    """
    await context.add_init_script("""
(() => {
try {
Object.defineProperty(navigator, 'webdriver', { get: () => undefined });
Object.defineProperty(navigator, 'languages', { get: () => ['zh-CN', 'zh', 'en-US', 'en'] });
Object.defineProperty(navigator, 'plugins', { get: () => [1, 2, 3, 4, 5] });
window.chrome = window.chrome || { runtime: {} };
} catch (_) {}
})();
""")
async def create(
self,
custom_page_id: int,
@@ -113,11 +141,9 @@ class BrowserSessionService:
context = await self._playwright.chromium.launch_persistent_context(
str(self._profile_dir(profile_key)),
headless=get_settings().browser_headless,
viewport={"width": width, "height": height},
color_scheme="dark",
args=["--no-sandbox", "--disable-dev-shm-usage"],
**self._browser_launch_kwargs(width, height),
)
await self._install_browser_init_scripts(context)
await self._restore_session_state(context, profile_key)
# Grant clipboard access for the page origin
try:
@@ -773,11 +799,9 @@ class BrowserSessionService:
profile_key = f"auth-capture-{session_id[:12]}"
context = await self._playwright.chromium.launch_persistent_context(
str(self._profile_dir(profile_key)),
headless=get_settings().browser_headless,
viewport={"width": width, "height": height},
color_scheme="dark",
args=["--no-sandbox", "--disable-dev-shm-usage"],
**self._browser_launch_kwargs(width, height),
)
await self._install_browser_init_scripts(context)
# Grant clipboard access for the page origin
try:
parsed = urlparse(url)
+156 -1
View File
@@ -2,6 +2,7 @@
from __future__ import annotations
import json
import time
from typing import Any, Optional
from urllib.parse import urljoin
import httpx
@@ -13,6 +14,9 @@ class UpstreamError(RuntimeError):
pass
# Fallback quota-per-unit factor, used when GET /api/status does not yield a
# positive ``quota_per_unit`` value.
NEW_API_DEFAULT_QUOTA_PER_UNIT = 500000
def _find_token(value: Any) -> str:
if isinstance(value, str) and value.count(".") >= 2:
return value
@@ -74,6 +78,8 @@ def mask_secret(value: Any) -> str:
def _unwrap_data(value: Any) -> Any:
if isinstance(value, dict) and "data" in value and ("code" in value or "message" in value):
return value.get("data")
if isinstance(value, dict) and "data" in value and "success" in value:
return value.get("data")
return value
@@ -105,6 +111,20 @@ def _extract_key_value(value: Any) -> str:
return ""
def _is_success_response(value: Any) -> bool:
if not isinstance(value, dict) or "success" not in value:
return True
return value.get("success") is True
def _response_message(value: Any, fallback: str = "") -> str:
if isinstance(value, dict):
msg = value.get("message") or value.get("detail")
if msg:
return str(msg)
return fallback
def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]:
def _normalize(lst: list) -> list[dict[str, Any]]:
out = []
@@ -288,6 +308,17 @@ class UpstreamClient:
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
return f"{self.base_url}{prefix}/{path.lstrip('/')}"
def _is_new_api_user_mode(self) -> bool:
login_path = str(self.auth_config.get("login_path") or "")
return (
self.api_prefix == ""
and (
bool(self.auth_config.get("new_api_user"))
or login_path == "/api/user/login"
or self.auth_type == "cookie"
)
)
def _headers(self, auth: bool = True) -> dict[str, str]:
headers: dict[str, str] = {
"Accept": "application/json",
@@ -348,6 +379,119 @@ class UpstreamClient:
raise UpstreamError(f"{method} {path} returned HTML, not JSON")
return resp.json()
def _ensure_api_success(self, payload: Any, action: str) -> None:
    """Raise ``UpstreamError`` when ``payload`` reports a failed call.

    The error text is the payload's own message when it has one, otherwise
    ``"<action> failed"``.
    """
    if _is_success_response(payload):
        return
    raise UpstreamError(_response_message(payload, f"{action} failed"))
def _new_api_quota_per_unit(self) -> int:
    """Return the upstream's ``quota_per_unit`` from ``GET /api/status``.

    Best effort: any failure (request error, unexpected shape, non-numeric
    or non-positive value) yields ``NEW_API_DEFAULT_QUOTA_PER_UNIT`` instead.
    """
    try:
        data = _unwrap_data(self._request("GET", "/api/status", auth=False))
        if not isinstance(data, dict):
            return NEW_API_DEFAULT_QUOTA_PER_UNIT
        # float() first so numeric strings such as "500000.0" are accepted.
        parsed = int(float(data.get("quota_per_unit")))
    except Exception:
        return NEW_API_DEFAULT_QUOTA_PER_UNIT
    return parsed if parsed > 0 else NEW_API_DEFAULT_QUOTA_PER_UNIT
@staticmethod
def _normalize_key_record(record: dict[str, Any]) -> dict[str, Any]:
out = dict(record)
if not out.get("name") and out.get("key_name"):
out["name"] = out.get("key_name")
if not out.get("group_id") and out.get("group"):
out["group_id"] = str(out.get("group"))
if not out.get("group_name") and out.get("group"):
out["group_name"] = str(out.get("group"))
return out
def _list_new_api_tokens(
    self,
    search: str = "",
    group_id: str | int | None = None,
) -> list[dict[str, Any]]:
    """List tokens from a New-API upstream as normalised key records.

    Uses ``/api/token/search`` when ``search`` is given, otherwise
    ``/api/token/``. Only page 1 with a page size of 100 is requested; no
    further pages are fetched. ``group_id`` filtering happens client-side on
    the normalised records.

    Raises:
        UpstreamError: the upstream reports failure, or the response holds no
            recognisable token list.
    """
    if search:
        path = "/api/token/search"
        params = {"keyword": search, "token": "", "p": 1, "size": 100}
    else:
        path = "/api/token/"
        params = {"p": 1, "size": 100}
    # The request goes through self._client directly (not self._request) so
    # that query parameters can be passed.
    resp = self._client.request(
        "GET",
        self._url(path),
        params=params,
        headers=self._headers(),
        cookies=self._cookies,
    )
    # Merge response cookies before the status check, so they are kept even
    # when the call fails.
    self._cookies.update(dict(resp.cookies))
    resp.raise_for_status()
    data = resp.json()
    self._ensure_api_success(data, "list New-API tokens")
    nested = _unwrap_data(data)
    # The list may be the payload itself or sit under one of several keys.
    items: list[dict[str, Any]] | None = None
    if isinstance(nested, list):
        items = [i for i in nested if isinstance(i, dict)]
    elif isinstance(nested, dict):
        for key in ("items", "tokens", "list", "records"):
            value = nested.get(key)
            if isinstance(value, list):
                items = [i for i in value if isinstance(i, dict)]
                break
    if items is None:
        raise UpstreamError("unexpected New-API token list response")
    normalized = [self._normalize_key_record(i) for i in items]
    if group_id is not None:
        # Compare as strings so int and str group ids match each other.
        gid = str(group_id)
        normalized = [i for i in normalized if str(i.get("group_id") or i.get("group") or "") == gid]
    return normalized
def _get_new_api_token_key(self, token_id: str | int) -> str:
    """Fetch the key value of a New-API token via ``POST /api/token/<id>/key``.

    Raises:
        UpstreamError: the upstream reports failure or the response carries
            no key value.
    """
    response = self._request("POST", f"/api/token/{token_id}/key")
    self._ensure_api_success(response, "get New-API token key")
    key_value = _extract_key_value(_unwrap_data(response))
    if key_value:
        return key_value
    raise UpstreamError("New-API token key response did not include key")
def _create_new_api_token(
    self,
    name: str,
    group_id: str | int,
    quota: float = 0,
    expires_in_days: int | None = None,
) -> dict[str, Any]:
    """Create a token on a New-API upstream and return its id and key.

    A ``quota`` of zero or less creates an unlimited token; otherwise the
    quota is multiplied by the upstream's quota-per-unit factor. A falsy
    ``expires_in_days`` creates a token with ``expired_time`` -1.

    After the create call the token is looked up again by name within the
    group, and its key is fetched in a separate request.

    Returns:
        A dict with ``id``, ``key``, ``masked_key`` and ``raw`` (the
        normalised token record).

    Raises:
        UpstreamError: a request fails, the new token cannot be found by
            name, or the record has no id.
    """
    unlimited = quota <= 0
    if unlimited:
        remain_quota = 0
    else:
        remain_quota = int(round(quota * self._new_api_quota_per_unit()))
    if expires_in_days:
        expired_time = int(time.time()) + expires_in_days * 86400
    else:
        expired_time = -1

    body: dict[str, Any] = {
        "name": name,
        "remain_quota": remain_quota,
        "unlimited_quota": unlimited,
        "expired_time": expired_time,
        "model_limits_enabled": False,
        "model_limits": "",
        "allow_ips": "",
        "group": str(group_id),
        "cross_group_retry": False,
    }
    created = self._request("POST", "/api/token/", body)
    self._ensure_api_success(created, "create New-API token")

    # The token is found again by exact (stripped) name; the first match wins.
    token = None
    for entry in self._list_new_api_tokens(search=name, group_id=group_id):
        if str(entry.get("name") or "").strip() == name.strip():
            token = entry
            break
    if not token:
        raise UpstreamError("New-API token was created but could not be found by name")

    token_id = token.get("id")
    if token_id is None:
        raise UpstreamError("New-API token list response did not include id")

    key_value = self._get_new_api_token_key(token_id)
    return {
        "id": str(token_id),
        "key": key_value,
        "masked_key": mask_secret(key_value),
        "raw": self._normalize_key_record(token),
    }
def login(self) -> None:
if self.auth_type != "login_password":
return
@@ -412,6 +556,9 @@ class UpstreamClient:
endpoint: str = "/keys",
) -> list[dict[str, Any]]:
"""查询远端上游 Key 列表,支持按名称搜索、分组筛选、状态筛选。"""
if endpoint in {"/api/token", "/api/token/"} or (endpoint == "/keys" and self._is_new_api_user_mode()):
return self._list_new_api_tokens(search=search, group_id=group_id)
params: dict[str, Any] = {}
if search:
params["search"] = search
@@ -446,7 +593,7 @@ class UpstreamClient:
for key in ("items", "keys", "list", "records"):
val = data.get(key)
if isinstance(val, list):
return val
return [self._normalize_key_record(i) for i in val if isinstance(i, dict)]
raise UpstreamError(f"unexpected keys response type: {type(data).__name__}")
def delete_api_key(self, key_id: str, endpoint: str = "/keys") -> None:
@@ -486,6 +633,14 @@ class UpstreamClient:
rate_limit_7d: float = 0,
endpoint: str = "/keys",
) -> dict[str, Any]:
if endpoint in {"/api/token", "/api/token/"} or (endpoint == "/keys" and self._is_new_api_user_mode()):
return self._create_new_api_token(
name,
group_id,
quota=quota,
expires_in_days=expires_in_days,
)
body: dict[str, Any] = {
"name": name,
"group_id": int(group_id) if str(group_id).isdigit() else group_id,