feat: support real browser auth import

This commit is contained in:
liumangmang
2026-06-02 13:51:29 +08:00
parent f4d16a4c01
commit 84148f4a69
22 changed files with 1651 additions and 111 deletions
+20 -9
View File
@@ -137,9 +137,12 @@ def _cookie_matches_hostname(cookie_domain: str, hostname: str) -> bool:
"""判断 cookie domain 是否适用于给定 hostname。
支持带点前缀的 domain(如 `.saki.lat` 匹配 `api.saki.lat`)。
注意:hostname 为空时,调用方应跳过 cookie 收集而不是调用此函数。
"""
if not cookie_domain or not hostname:
return True # 无 domain 限制时视为全域
if not cookie_domain:
return True # 无 domain 限制的 cookie 对当前域有效
if not hostname:
return False # 无法确定当前域,保守拒绝
# 去掉前缀点
domain = cookie_domain.lstrip(".")
return hostname == domain or hostname.endswith("." + domain)
@@ -153,14 +156,22 @@ def _build_cookie_bundle(
返回 (cookie_string, cookie_names_list)。
cookie_string 格式:name1=value1; name2=value2; ...
过滤掉空值 cookie。
过滤掉空值 cookie。若 page_url 为空或无法解析 hostname,返回空结果
(不收集全域 cookie 以防误写入无关域凭证)。
"""
if not page_url:
logger.debug("_build_cookie_bundle: no page_url, skipping cookie collection")
return "", []
hostname = ""
if page_url:
try:
hostname = urlparse(page_url).hostname or ""
except Exception:
pass
try:
hostname = urlparse(page_url).hostname or ""
except Exception:
pass
if not hostname:
logger.debug("_build_cookie_bundle: cannot parse hostname from %s, skipping", page_url[:80])
return "", []
parts: list[str] = []
names: list[str] = []
@@ -170,7 +181,7 @@ def _build_cookie_bundle(
domain = c.get("domain", "")
if not name or not value:
continue
if hostname and not _cookie_matches_hostname(domain, hostname):
if not _cookie_matches_hostname(domain, hostname):
continue
parts.append(f"{name}={value}")
names.append(name)
@@ -0,0 +1,159 @@
"""One-time credential import sessions for real-browser auth capture."""
from __future__ import annotations
import hashlib
import secrets
import time
from dataclasses import dataclass, field
from typing import Any
from app.services.auth_capture_service import _curate_candidates, _find_new_api_user
# Lifetime of an import session in seconds; it is added to the creation time
# to compute each session's expiry timestamp.
IMPORT_SESSION_TTL_SECONDS = 600
class ImportSessionError(RuntimeError):
    """Raised when an import session is missing, expired, already consumed, or given a wrong secret."""

    pass
@dataclass
class BrowserImportSession:
    """In-memory record describing one browser credential-import session."""

    id: str  # key under which the service stores this session
    secret_hash: str  # SHA-256 hex digest of the one-time secret (plaintext is not kept)
    target_url: str  # URL supplied when the session was created
    created_by: str  # owner identifier; compared on lookup when the caller provides one
    expires_at: float  # epoch seconds at or after which the session is treated as expired
    payload: dict[str, Any] | None = None  # submitted data; stays None until a submit succeeds
    consumed: bool = False  # set to True once a payload has been accepted
    created_at: float = field(default_factory=time.time)  # epoch seconds at construction
def _hash_secret(secret: str) -> str:
return hashlib.sha256(secret.encode("utf-8")).hexdigest()
def _normalize_storage(value: Any) -> dict[str, str]:
if not isinstance(value, dict):
return {}
result: dict[str, str] = {}
for key, item in value.items():
if item is None:
continue
result[str(key)] = item if isinstance(item, str) else str(item)
return result
def _normalize_cookies(value: Any) -> list[dict[str, Any]]:
if not isinstance(value, list):
return []
result: list[dict[str, Any]] = []
for item in value:
if not isinstance(item, dict):
continue
name = str(item.get("name") or "").strip()
cookie_value = str(item.get("value") or "")
if not name or not cookie_value:
continue
result.append({
"name": name,
"value": cookie_value,
"domain": str(item.get("domain") or ""),
"path": str(item.get("path") or "/"),
"httpOnly": bool(item.get("httpOnly", False)),
"secure": bool(item.get("secure", False)),
})
return result
def _normalize_headers(value: Any) -> list[dict[str, str]]:
if not isinstance(value, list):
return []
result: list[dict[str, str]] = []
for item in value:
if not isinstance(item, dict):
continue
header_value = str(item.get("value") or "").strip()
if not header_value:
continue
result.append({
"type": str(item.get("type") or "authorization"),
"value": header_value,
"url": str(item.get("url") or ""),
})
return result
def build_import_result(payload: dict[str, Any]) -> dict[str, Any]:
    """Convert an extension-submitted payload into the auth-capture result shape.

    The page URL is read from ``page_url`` (falling back to ``url``) and local
    storage from ``local_storage`` (falling back to ``storage``). Cookies,
    storage maps and auth headers are normalized before being handed to the
    candidate curation helper.

    Returns a dict with the keys ``cookies``, ``storage``,
    ``session_storage``, ``auth_headers`` and ``candidates``.
    """
    page_url = str(payload.get("page_url") or payload.get("url") or "")

    result: dict[str, Any] = {
        "cookies": _normalize_cookies(payload.get("cookies")),
        "storage": _normalize_storage(payload.get("local_storage") or payload.get("storage")),
        "session_storage": _normalize_storage(payload.get("session_storage")),
        "auth_headers": _normalize_headers(payload.get("auth_headers")),
    }

    # "candidates" is added last so the result keeps its original key order.
    result["candidates"] = _curate_candidates(
        cookies=result["cookies"],
        local_storage=result["storage"],
        session_storage=result["session_storage"],
        auth_headers=result["auth_headers"],
        new_api_user=_find_new_api_user(result["storage"], result["session_storage"]),
        page_url=page_url,
    )
    return result
class BrowserImportService:
    """In-memory registry of one-time browser credential-import sessions.

    Sessions live in a plain dict keyed by session id and are purged once
    their ``expires_at`` timestamp has passed.
    """

    def __init__(self) -> None:
        self._sessions: dict[str, BrowserImportSession] = {}

    def create(self, target_url: str, created_by: str) -> tuple[BrowserImportSession, str]:
        """Register a new session and return ``(session, plaintext_secret)``.

        Only the hash of the secret is stored on the session, so the returned
        plaintext is the sole copy.
        """
        self.cleanup()
        session_id = secrets.token_urlsafe(18)
        secret = secrets.token_urlsafe(24)
        session = BrowserImportSession(
            id=session_id,
            secret_hash=_hash_secret(secret),
            target_url=target_url,
            created_by=created_by,
            expires_at=time.time() + IMPORT_SESSION_TTL_SECONDS,
        )
        self._sessions[session_id] = session
        return session, secret

    def get(self, session_id: str, created_by: str | None = None) -> BrowserImportSession:
        """Return the live session with this id.

        When ``created_by`` is given, a session owned by someone else is
        reported as "not found" rather than revealing that it exists.

        Raises:
            ImportSessionError: if the session is unknown, owned by a
                different creator, or expired.
        """
        self.cleanup()
        session = self._sessions.get(session_id)
        if not session:
            raise ImportSessionError("import session not found")
        if created_by is not None and session.created_by != created_by:
            raise ImportSessionError("import session not found")
        # cleanup() ran a moment ago; re-check in case the deadline passed since.
        if session.expires_at <= time.time():
            self._sessions.pop(session_id, None)
            raise ImportSessionError("import session expired")
        return session

    def submit(self, session_id: str, secret: str, payload: dict[str, Any]) -> BrowserImportSession:
        """Attach ``payload`` to the session and mark it consumed.

        Raises:
            ImportSessionError: if the session cannot be fetched, the secret
                does not match, or a payload was already accepted.
        """
        session = self.get(session_id)
        # Authenticate first so a caller without the secret cannot learn
        # whether the session has already been consumed.
        if not secrets.compare_digest(session.secret_hash, _hash_secret(secret)):
            raise ImportSessionError("invalid import secret")
        if session.consumed:
            raise ImportSessionError("import session already consumed")
        session.payload = payload
        session.consumed = True
        return session

    def cleanup(self) -> None:
        """Drop every session whose expiry timestamp has been reached."""
        now = time.time()
        expired = [sid for sid, session in self._sessions.items() if session.expires_at <= now]
        for sid in expired:
            self._sessions.pop(sid, None)
# Module-level shared service instance; its sessions are held only in memory.
browser_imports = BrowserImportService()
@@ -68,6 +68,34 @@ class BrowserSessionService:
self._last_event_at: dict[str, float] = {}
self._evict_task: Optional[asyncio.Task[None]] = None
def _browser_launch_kwargs(self, width: int, height: int) -> dict[str, Any]:
    """Build the keyword arguments used when launching a browser context.

    Headless mode and timezone come from the application settings; the
    viewport and the ``--window-size`` flag are both derived from ``width``
    and ``height``.
    """
    return dict(
        headless=get_settings().browser_headless,
        viewport=dict(width=width, height=height),
        color_scheme="dark",
        locale="zh-CN",
        timezone_id=get_settings().tz,
        # Drop the default automation switch and the Blink automation feature.
        ignore_default_args=["--enable-automation"],
        args=[
            "--no-sandbox",
            "--disable-dev-shm-usage",
            "--disable-blink-features=AutomationControlled",
            "--window-size=%d,%d" % (width, height),
        ],
    )
async def _install_browser_init_scripts(self, context: Any) -> None:
    """Register an init script on ``context`` that overrides automation-revealing properties.

    The script redefines ``navigator.webdriver`` (to ``undefined``),
    ``navigator.languages`` and ``navigator.plugins``, and makes sure
    ``window.chrome`` exists. Failures inside the script are swallowed by its
    own ``try``/``catch``.
    """
    await context.add_init_script("""
(() => {
try {
Object.defineProperty(navigator, 'webdriver', { get: () => undefined });
Object.defineProperty(navigator, 'languages', { get: () => ['zh-CN', 'zh', 'en-US', 'en'] });
Object.defineProperty(navigator, 'plugins', { get: () => [1, 2, 3, 4, 5] });
window.chrome = window.chrome || { runtime: {} };
} catch (_) {}
})();
""")
async def create(
self,
custom_page_id: int,
@@ -113,11 +141,9 @@ class BrowserSessionService:
context = await self._playwright.chromium.launch_persistent_context(
str(self._profile_dir(profile_key)),
headless=get_settings().browser_headless,
viewport={"width": width, "height": height},
color_scheme="dark",
args=["--no-sandbox", "--disable-dev-shm-usage"],
**self._browser_launch_kwargs(width, height),
)
await self._install_browser_init_scripts(context)
await self._restore_session_state(context, profile_key)
# Grant clipboard access for the page origin
try:
@@ -773,11 +799,9 @@ class BrowserSessionService:
profile_key = f"auth-capture-{session_id[:12]}"
context = await self._playwright.chromium.launch_persistent_context(
str(self._profile_dir(profile_key)),
headless=get_settings().browser_headless,
viewport={"width": width, "height": height},
color_scheme="dark",
args=["--no-sandbox", "--disable-dev-shm-usage"],
**self._browser_launch_kwargs(width, height),
)
await self._install_browser_init_scripts(context)
# Grant clipboard access for the page origin
try:
parsed = urlparse(url)
+156 -1
View File
@@ -2,6 +2,7 @@
from __future__ import annotations
import json
import time
from typing import Any, Optional
from urllib.parse import urljoin
import httpx
@@ -13,6 +14,9 @@ class UpstreamError(RuntimeError):
pass
# Fallback quota-per-unit multiplier, used when the upstream /api/status
# response does not yield a positive "quota_per_unit" value.
NEW_API_DEFAULT_QUOTA_PER_UNIT = 500000
def _find_token(value: Any) -> str:
if isinstance(value, str) and value.count(".") >= 2:
return value
@@ -74,6 +78,8 @@ def mask_secret(value: Any) -> str:
def _unwrap_data(value: Any) -> Any:
if isinstance(value, dict) and "data" in value and ("code" in value or "message" in value):
return value.get("data")
if isinstance(value, dict) and "data" in value and "success" in value:
return value.get("data")
return value
@@ -105,6 +111,20 @@ def _extract_key_value(value: Any) -> str:
return ""
def _is_success_response(value: Any) -> bool:
if not isinstance(value, dict) or "success" not in value:
return True
return value.get("success") is True
def _response_message(value: Any, fallback: str = "") -> str:
if isinstance(value, dict):
msg = value.get("message") or value.get("detail")
if msg:
return str(msg)
return fallback
def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]:
def _normalize(lst: list) -> list[dict[str, Any]]:
out = []
@@ -288,6 +308,17 @@ class UpstreamClient:
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
return f"{self.base_url}{prefix}/{path.lstrip('/')}"
def _is_new_api_user_mode(self) -> bool:
login_path = str(self.auth_config.get("login_path") or "")
return (
self.api_prefix == ""
and (
bool(self.auth_config.get("new_api_user"))
or login_path == "/api/user/login"
or self.auth_type == "cookie"
)
)
def _headers(self, auth: bool = True) -> dict[str, str]:
headers: dict[str, str] = {
"Accept": "application/json",
@@ -348,6 +379,119 @@ class UpstreamClient:
raise UpstreamError(f"{method} {path} returned HTML, not JSON")
return resp.json()
def _ensure_api_success(self, payload: Any, action: str) -> None:
    """Raise ``UpstreamError`` when ``payload`` reports a failed call.

    The error text is the response's own message when it has one, otherwise
    ``"<action> failed"``.
    """
    if _is_success_response(payload):
        return
    raise UpstreamError(_response_message(payload, f"{action} failed"))
def _new_api_quota_per_unit(self) -> int:
    """Return the upstream's ``quota_per_unit``, or the module default.

    Queries ``GET /api/status`` without authentication. Any failure (request
    error, missing field, non-numeric or non-positive value) falls back to
    ``NEW_API_DEFAULT_QUOTA_PER_UNIT``.
    """
    try:
        status = _unwrap_data(self._request("GET", "/api/status", auth=False))
        if isinstance(status, dict):
            parsed = int(float(status.get("quota_per_unit")))
            if parsed > 0:
                return parsed
    except Exception:
        # Best-effort lookup: every error path ends at the default below.
        pass
    return NEW_API_DEFAULT_QUOTA_PER_UNIT
@staticmethod
def _normalize_key_record(record: dict[str, Any]) -> dict[str, Any]:
out = dict(record)
if not out.get("name") and out.get("key_name"):
out["name"] = out.get("key_name")
if not out.get("group_id") and out.get("group"):
out["group_id"] = str(out.get("group"))
if not out.get("group_name") and out.get("group"):
out["group_name"] = str(out.get("group"))
return out
def _list_new_api_tokens(
    self,
    search: str = "",
    group_id: str | int | None = None,
) -> list[dict[str, Any]]:
    """Fetch the first page (up to 100) of New-API tokens as normalized records.

    With ``search`` set, the keyword search endpoint is used; otherwise the
    plain listing endpoint. When ``group_id`` is given, only records whose
    ``group_id`` (or ``group``) stringifies to the same value are returned.

    Raises:
        UpstreamError: if the response reports failure or has no recognizable
            list of items. HTTP status errors from the client propagate as-is.
    """
    path, params = (
        ("/api/token/search", {"keyword": search, "token": "", "p": 1, "size": 100})
        if search
        else ("/api/token/", {"p": 1, "size": 100})
    )

    response = self._client.request(
        "GET",
        self._url(path),
        params=params,
        headers=self._headers(),
        cookies=self._cookies,
    )
    # Keep any cookies the upstream set, even if the status check below fails.
    self._cookies.update(dict(response.cookies))
    response.raise_for_status()

    body = response.json()
    self._ensure_api_success(body, "list New-API tokens")

    inner = _unwrap_data(body)
    if isinstance(inner, list):
        raw_items = inner
    elif isinstance(inner, dict):
        # The first list-valued wrapper key wins.
        raw_items = next(
            (
                inner[key]
                for key in ("items", "tokens", "list", "records")
                if isinstance(inner.get(key), list)
            ),
            None,
        )
    else:
        raw_items = None
    if raw_items is None:
        raise UpstreamError("unexpected New-API token list response")

    records = [self._normalize_key_record(entry) for entry in raw_items if isinstance(entry, dict)]
    if group_id is None:
        return records

    wanted = str(group_id)
    return [
        entry
        for entry in records
        if str(entry.get("group_id") or entry.get("group") or "") == wanted
    ]
def _get_new_api_token_key(self, token_id: str | int) -> str:
    """Request the key value of token ``token_id`` from the upstream.

    Raises:
        UpstreamError: if the call reports failure or the response carries no
            key value.
    """
    response = self._request("POST", f"/api/token/{token_id}/key")
    self._ensure_api_success(response, "get New-API token key")
    secret = _extract_key_value(_unwrap_data(response))
    if secret:
        return secret
    raise UpstreamError("New-API token key response did not include key")
def _create_new_api_token(
    self,
    name: str,
    group_id: str | int,
    quota: float = 0,
    expires_in_days: int | None = None,
) -> dict[str, Any]:
    """Create a New-API token, then look it up and fetch its key.

    The create call's response is only checked for success, so the new token
    is located afterwards by searching for ``name`` within ``group_id``.

    Args:
        name: Token name; also used to find the token after creation.
        group_id: Group the token is created in (sent as a string).
        quota: Quota in display units; ``<= 0`` creates an unlimited token.
            Positive values are multiplied by the upstream's quota-per-unit.
        expires_in_days: Days until expiry; a falsy value sends ``-1``
            (presumably the upstream's "never expires" marker; confirm).

    Returns:
        A dict with ``id``, ``key``, ``masked_key`` and the normalized ``raw``
        token record.

    Raises:
        UpstreamError: if any upstream call reports failure, or the created
            token cannot be found, has no id, or yields no key.
    """
    unlimited = quota <= 0
    remain_quota = 0 if unlimited else int(round(quota * self._new_api_quota_per_unit()))
    expired_time = (int(time.time()) + expires_in_days * 86400) if expires_in_days else -1
    body: dict[str, Any] = {
        "name": name,
        "remain_quota": remain_quota,
        "unlimited_quota": unlimited,
        "expired_time": expired_time,
        "model_limits_enabled": False,
        "model_limits": "",
        "allow_ips": "",
        "group": str(group_id),
        "cross_group_retry": False,
    }
    payload = self._request("POST", "/api/token/", body)
    self._ensure_api_success(payload, "create New-API token")

    wanted = name.strip()
    matches = self._list_new_api_tokens(search=name, group_id=group_id)
    exact = [i for i in matches if str(i.get("name") or "").strip() == wanted]
    if not exact:
        raise UpstreamError("New-API token was created but could not be found by name")

    def _numeric_id(record: dict[str, Any]) -> int:
        """Rank a record by numeric id; non-numeric or missing ids sort lowest."""
        try:
            return int(record.get("id"))
        except (TypeError, ValueError):
            return -1

    # An older token may share this name. Prefer the highest numeric id
    # (assumes ids grow with creation order; TODO confirm against upstream).
    # On ties, max() keeps the first match.
    token = max(exact, key=_numeric_id)

    token_id = token.get("id")
    if token_id is None:
        raise UpstreamError("New-API token list response did not include id")
    key_value = self._get_new_api_token_key(token_id)
    return {
        "id": str(token_id),
        "key": key_value,
        "masked_key": mask_secret(key_value),
        "raw": self._normalize_key_record(token),
    }
def login(self) -> None:
if self.auth_type != "login_password":
return
@@ -412,6 +556,9 @@ class UpstreamClient:
endpoint: str = "/keys",
) -> list[dict[str, Any]]:
"""查询远端上游 Key 列表,支持按名称搜索、分组筛选、状态筛选。"""
if endpoint in {"/api/token", "/api/token/"} or (endpoint == "/keys" and self._is_new_api_user_mode()):
return self._list_new_api_tokens(search=search, group_id=group_id)
params: dict[str, Any] = {}
if search:
params["search"] = search
@@ -446,7 +593,7 @@ class UpstreamClient:
for key in ("items", "keys", "list", "records"):
val = data.get(key)
if isinstance(val, list):
return val
return [self._normalize_key_record(i) for i in val if isinstance(i, dict)]
raise UpstreamError(f"unexpected keys response type: {type(data).__name__}")
def delete_api_key(self, key_id: str, endpoint: str = "/keys") -> None:
@@ -486,6 +633,14 @@ class UpstreamClient:
rate_limit_7d: float = 0,
endpoint: str = "/keys",
) -> dict[str, Any]:
if endpoint in {"/api/token", "/api/token/"} or (endpoint == "/keys" and self._is_new_api_user_mode()):
return self._create_new_api_token(
name,
group_id,
quota=quota,
expires_in_days=expires_in_days,
)
body: dict[str, Any] = {
"name": name,
"group_id": int(group_id) if str(group_id).isdigit() else group_id,