feat: support real browser auth import

This commit is contained in:
liumangmang
2026-06-02 13:51:29 +08:00
parent f4d16a4c01
commit 84148f4a69
22 changed files with 1651 additions and 111 deletions
+93
View File
@@ -10,6 +10,11 @@ from sqlalchemy.orm import Session
from app.database import get_db
from app.services.auth_capture_service import extract_all
from app.services.browser_import_service import (
ImportSessionError,
browser_imports,
build_import_result,
)
from app.services.browser_session_service import (
BrowserDependencyError,
BrowserSessionError,
@@ -43,6 +48,32 @@ class CaptureExtractResponse(BaseModel):
candidates: list[dict] = []
class ImportSessionCreate(BaseModel):
target_url: str = Field(..., description="Target page URL opened in the user's real browser")
class ImportSessionCreateResponse(BaseModel):
session_id: str
secret: str
expires_in_seconds: int
class ImportSessionStatusResponse(BaseModel):
session_id: str
ready: bool = False
expires_at: float
result: Optional[CaptureExtractResponse] = None
class ImportSessionSubmit(BaseModel):
secret: str
page_url: str = ""
cookies: list[dict] = []
local_storage: dict[str, Any] = {}
session_storage: dict[str, Any] = {}
auth_headers: list[dict] = []
def _sanitize_candidate(candidate: dict[str, Any]) -> dict[str, Any]:
return {
key: value
@@ -134,3 +165,65 @@ async def close_capture_session(
await browser_sessions.close(session_id)
except Exception as exc:
raise _browser_error(exc)
@router.post("/import-sessions", response_model=ImportSessionCreateResponse, status_code=201)
async def create_import_session(
body: ImportSessionCreate,
user=Depends(get_current_user),
):
"""Create a one-time real-browser import session.
The returned secret is intended for the local browser extension only.
"""
target_url = body.target_url.strip()
if not target_url.startswith(("http://", "https://")):
raise HTTPException(400, "Only http/https URLs are allowed")
session, secret = browser_imports.create(target_url=target_url, created_by=user.email)
return ImportSessionCreateResponse(
session_id=session.id,
secret=secret,
expires_in_seconds=600,
)
@router.get("/import-sessions/{session_id}", response_model=ImportSessionStatusResponse)
async def get_import_session(
session_id: str,
include_raw: bool = Query(default=False),
user=Depends(get_current_user),
):
try:
session = browser_imports.get(session_id, created_by=user.email)
except ImportSessionError as exc:
raise HTTPException(404, str(exc))
result = None
if session.payload is not None:
full_result = build_import_result(session.payload)
if not include_raw:
candidates = [_sanitize_candidate(candidate) for candidate in full_result.get("candidates", [])]
result = CaptureExtractResponse(candidates=candidates)
else:
result = CaptureExtractResponse(**full_result)
return ImportSessionStatusResponse(
session_id=session.id,
ready=session.payload is not None,
expires_at=session.expires_at,
result=result,
)
@router.post("/import-sessions/{session_id}/submit", status_code=204)
async def submit_import_session(
session_id: str,
body: ImportSessionSubmit,
):
try:
browser_imports.submit(
session_id=session_id,
secret=body.secret,
payload=body.model_dump(exclude={"secret"}),
)
except ImportSessionError as exc:
raise HTTPException(400, str(exc))
+44 -14
View File
@@ -221,25 +221,45 @@ class RefreshAuthResponse(BaseModel):
warning: Optional[str] = None
def _pick_best_candidate(candidates: list[dict], preferred_auth_type: str) -> Optional[dict]:
def _norm_path(value: Any) -> str:
return str(value or "").strip().rstrip("/")
def _detect_upstream_platform(upstream: Upstream, auth_config: dict) -> str:
api_prefix = _norm_path(upstream.api_prefix)
groups_endpoint = _norm_path(upstream.groups_endpoint)
rate_endpoint = _norm_path(upstream.rate_endpoint)
login_path = _norm_path(auth_config.get("login_path"))
if groups_endpoint == "/api/user/self/groups" or login_path == "/api/user/login":
return "new-api-user"
if api_prefix == "/api/v1" or groups_endpoint in {"/groups/available", "/groups/rates"} or login_path == "/auth/login":
return "sub2api"
return "unknown"
def _first_candidate(candidates: list[dict], *types: str) -> Optional[dict]:
for c in candidates:
if c.get("type") in types:
return c
return None
def _pick_best_candidate(candidates: list[dict], preferred_auth_type: str, platform: str = "unknown") -> Optional[dict]:
if not candidates:
return None
# cookie_bundle > cookie > bearer_token > api_key
# preferred_auth_type="cookie" 时优先匹配 bundle,其次单 cookie
if platform == "sub2api":
return _first_candidate(candidates, "bearer_token", "api_key")
if platform == "new-api-user":
return _first_candidate(candidates, "cookie_bundle", "cookie", "bearer_token", "api_key")
if preferred_auth_type == "cookie":
for c in candidates:
if c["type"] == "cookie_bundle":
return c
for c in candidates:
if c["type"] == "cookie":
return c
return _first_candidate(candidates, "cookie_bundle", "cookie")
elif preferred_auth_type in ("bearer", "api_key"):
type_map = {"bearer": "bearer_token", "api_key": "api_key"}
preferred = type_map.get(preferred_auth_type)
if preferred:
for c in candidates:
if c["type"] == preferred:
return c
return _first_candidate(candidates, preferred)
# fallback:排序后取第一个
return candidates[0]
@@ -268,11 +288,17 @@ async def refresh_auth(pid: int, db: Session = Depends(get_db), _=Depends(get_cu
return RefreshAuthResponse(success=False, message=f"提取失败: {exc}")
candidates = result.get("candidates", [])
candidate = _pick_best_candidate(candidates, upstream.auth_type)
existing_config = _json.loads(upstream.auth_config_json or "{}")
platform = _detect_upstream_platform(upstream, existing_config)
candidate = _pick_best_candidate(candidates, upstream.auth_type, platform)
if not candidate:
if platform == "sub2api" and _first_candidate(candidates, "cookie_bundle", "cookie"):
return RefreshAuthResponse(
success=False,
message="Sub2API 需要 Bearer Token;当前只提取到 Cookie。请在远程浏览器完成登录后刷新页面或触发一次接口请求,再重新提取。",
)
return RefreshAuthResponse(success=False, message="未提取到有效凭证,请确认已在远程浏览器中登录")
existing_config = _json.loads(upstream.auth_config_json or "{}")
ctype = candidate["type"]
if ctype in ("cookie_bundle", "cookie"):
@@ -281,6 +307,10 @@ async def refresh_auth(pid: int, db: Session = Depends(get_db), _=Depends(get_cu
existing_config["cookie_string"] = candidate.get("value", "")
if candidate.get("new_api_user"):
existing_config["new_api_user"] = candidate["new_api_user"]
if platform == "new-api-user":
upstream.api_prefix = ""
upstream.groups_endpoint = "/api/user/self/groups"
upstream.rate_endpoint = "/api/user/self/groups"
elif ctype == "bearer_token":
upstream.auth_type = "bearer"
raw = candidate.get("value", "")
+22 -2
View File
@@ -315,6 +315,26 @@ def list_generated_keys(uid: int, db: Session = Depends(get_db), _=Depends(get_c
_generate_key_lock = __import__("threading").Lock()
def _is_sub2api_upstream(upstream: Upstream) -> bool:
return upstream.api_prefix.strip("/") == "api/v1"
def _is_new_api_user_upstream(upstream: Upstream) -> bool:
auth_config = json.loads(upstream.auth_config_json or "{}")
return (
upstream.api_prefix.strip("/") == ""
and (
upstream.groups_endpoint == "/api/user/self/groups"
or auth_config.get("login_path") == "/api/user/login"
or bool(auth_config.get("new_api_user"))
)
)
def _supports_key_generation(upstream: Upstream) -> bool:
return _is_sub2api_upstream(upstream) or _is_new_api_user_upstream(upstream)
def _ensure_group_key(
db: Session,
client: UpstreamClient,
@@ -465,8 +485,8 @@ def generate_keys_by_groups(
u = db.query(Upstream).filter(Upstream.id == uid).first()
if not u:
raise HTTPException(404, "upstream not found")
if u.api_prefix.strip("/") != "api/v1":
raise HTTPException(400, "首版仅支持 Sub2API 上游(API Prefix 应为 /api/v1")
if not _supports_key_generation(u):
raise HTTPException(400, "仅支持 Sub2API 或 New-API 普通账号上游生成 Key")
# 生成前先对账,清理远端已删除的旧 Key
try:
+20 -9
View File
@@ -137,9 +137,12 @@ def _cookie_matches_hostname(cookie_domain: str, hostname: str) -> bool:
"""判断 cookie domain 是否适用于给定 hostname。
支持带点前缀的 domain(如 `.saki.lat` 匹配 `api.saki.lat`)。
注意:hostname 为空时,调用方应跳过 cookie 收集而不是调用此函数。
"""
if not cookie_domain or not hostname:
return True # 无 domain 限制时视为全域
if not cookie_domain:
return True # 无 domain 限制的 cookie 对当前域有效
if not hostname:
return False # 无法确定当前域,保守拒绝
# 去掉前缀点
domain = cookie_domain.lstrip(".")
return hostname == domain or hostname.endswith("." + domain)
@@ -153,14 +156,22 @@ def _build_cookie_bundle(
返回 (cookie_string, cookie_names_list)。
cookie_string 格式:name1=value1; name2=value2; ...
过滤掉空值 cookie。
过滤掉空值 cookie。若 page_url 为空或无法解析 hostname,返回空结果
(不收集全域 cookie 以防误写入无关域凭证)。
"""
if not page_url:
logger.debug("_build_cookie_bundle: no page_url, skipping cookie collection")
return "", []
hostname = ""
if page_url:
try:
hostname = urlparse(page_url).hostname or ""
except Exception:
pass
try:
hostname = urlparse(page_url).hostname or ""
except Exception:
pass
if not hostname:
logger.debug("_build_cookie_bundle: cannot parse hostname from %s, skipping", page_url[:80])
return "", []
parts: list[str] = []
names: list[str] = []
@@ -170,7 +181,7 @@ def _build_cookie_bundle(
domain = c.get("domain", "")
if not name or not value:
continue
if hostname and not _cookie_matches_hostname(domain, hostname):
if not _cookie_matches_hostname(domain, hostname):
continue
parts.append(f"{name}={value}")
names.append(name)
@@ -0,0 +1,159 @@
"""One-time credential import sessions for real-browser auth capture."""
from __future__ import annotations
import hashlib
import secrets
import time
from dataclasses import dataclass, field
from typing import Any
from app.services.auth_capture_service import _curate_candidates, _find_new_api_user
IMPORT_SESSION_TTL_SECONDS = 600
class ImportSessionError(RuntimeError):
pass
@dataclass
class BrowserImportSession:
id: str
secret_hash: str
target_url: str
created_by: str
expires_at: float
payload: dict[str, Any] | None = None
consumed: bool = False
created_at: float = field(default_factory=time.time)
def _hash_secret(secret: str) -> str:
return hashlib.sha256(secret.encode("utf-8")).hexdigest()
def _normalize_storage(value: Any) -> dict[str, str]:
if not isinstance(value, dict):
return {}
result: dict[str, str] = {}
for key, item in value.items():
if item is None:
continue
result[str(key)] = item if isinstance(item, str) else str(item)
return result
def _normalize_cookies(value: Any) -> list[dict[str, Any]]:
if not isinstance(value, list):
return []
result: list[dict[str, Any]] = []
for item in value:
if not isinstance(item, dict):
continue
name = str(item.get("name") or "").strip()
cookie_value = str(item.get("value") or "")
if not name or not cookie_value:
continue
result.append({
"name": name,
"value": cookie_value,
"domain": str(item.get("domain") or ""),
"path": str(item.get("path") or "/"),
"httpOnly": bool(item.get("httpOnly", False)),
"secure": bool(item.get("secure", False)),
})
return result
def _normalize_headers(value: Any) -> list[dict[str, str]]:
if not isinstance(value, list):
return []
result: list[dict[str, str]] = []
for item in value:
if not isinstance(item, dict):
continue
header_value = str(item.get("value") or "").strip()
if not header_value:
continue
result.append({
"type": str(item.get("type") or "authorization"),
"value": header_value,
"url": str(item.get("url") or ""),
})
return result
def build_import_result(payload: dict[str, Any]) -> dict[str, Any]:
"""Convert extension-submitted payload into auth-capture result shape."""
page_url = str(payload.get("page_url") or payload.get("url") or "")
cookies = _normalize_cookies(payload.get("cookies"))
storage = _normalize_storage(payload.get("local_storage") or payload.get("storage"))
session_storage = _normalize_storage(payload.get("session_storage"))
auth_headers = _normalize_headers(payload.get("auth_headers"))
new_api_user = _find_new_api_user(storage, session_storage)
candidates = _curate_candidates(
cookies=cookies,
local_storage=storage,
session_storage=session_storage,
auth_headers=auth_headers,
new_api_user=new_api_user,
page_url=page_url,
)
return {
"cookies": cookies,
"storage": storage,
"session_storage": session_storage,
"auth_headers": auth_headers,
"candidates": candidates,
}
class BrowserImportService:
def __init__(self) -> None:
self._sessions: dict[str, BrowserImportSession] = {}
def create(self, target_url: str, created_by: str) -> tuple[BrowserImportSession, str]:
self.cleanup()
session_id = secrets.token_urlsafe(18)
secret = secrets.token_urlsafe(24)
session = BrowserImportSession(
id=session_id,
secret_hash=_hash_secret(secret),
target_url=target_url,
created_by=created_by,
expires_at=time.time() + IMPORT_SESSION_TTL_SECONDS,
)
self._sessions[session_id] = session
return session, secret
def get(self, session_id: str, created_by: str | None = None) -> BrowserImportSession:
self.cleanup()
session = self._sessions.get(session_id)
if not session:
raise ImportSessionError("import session not found")
if created_by is not None and session.created_by != created_by:
raise ImportSessionError("import session not found")
if session.expires_at <= time.time():
self._sessions.pop(session_id, None)
raise ImportSessionError("import session expired")
return session
def submit(self, session_id: str, secret: str, payload: dict[str, Any]) -> BrowserImportSession:
session = self.get(session_id)
if session.consumed:
raise ImportSessionError("import session already consumed")
if not secrets.compare_digest(session.secret_hash, _hash_secret(secret)):
raise ImportSessionError("invalid import secret")
session.payload = payload
session.consumed = True
return session
def cleanup(self) -> None:
now = time.time()
expired = [sid for sid, session in self._sessions.items() if session.expires_at <= now]
for sid in expired:
self._sessions.pop(sid, None)
browser_imports = BrowserImportService()
@@ -68,6 +68,34 @@ class BrowserSessionService:
self._last_event_at: dict[str, float] = {}
self._evict_task: Optional[asyncio.Task[None]] = None
def _browser_launch_kwargs(self, width: int, height: int) -> dict[str, Any]:
return {
"headless": get_settings().browser_headless,
"viewport": {"width": width, "height": height},
"color_scheme": "dark",
"locale": "zh-CN",
"timezone_id": get_settings().tz,
"ignore_default_args": ["--enable-automation"],
"args": [
"--no-sandbox",
"--disable-dev-shm-usage",
"--disable-blink-features=AutomationControlled",
"--window-size=%d,%d" % (width, height),
],
}
async def _install_browser_init_scripts(self, context: Any) -> None:
await context.add_init_script("""
(() => {
try {
Object.defineProperty(navigator, 'webdriver', { get: () => undefined });
Object.defineProperty(navigator, 'languages', { get: () => ['zh-CN', 'zh', 'en-US', 'en'] });
Object.defineProperty(navigator, 'plugins', { get: () => [1, 2, 3, 4, 5] });
window.chrome = window.chrome || { runtime: {} };
} catch (_) {}
})();
""")
async def create(
self,
custom_page_id: int,
@@ -113,11 +141,9 @@ class BrowserSessionService:
context = await self._playwright.chromium.launch_persistent_context(
str(self._profile_dir(profile_key)),
headless=get_settings().browser_headless,
viewport={"width": width, "height": height},
color_scheme="dark",
args=["--no-sandbox", "--disable-dev-shm-usage"],
**self._browser_launch_kwargs(width, height),
)
await self._install_browser_init_scripts(context)
await self._restore_session_state(context, profile_key)
# Grant clipboard access for the page origin
try:
@@ -773,11 +799,9 @@ class BrowserSessionService:
profile_key = f"auth-capture-{session_id[:12]}"
context = await self._playwright.chromium.launch_persistent_context(
str(self._profile_dir(profile_key)),
headless=get_settings().browser_headless,
viewport={"width": width, "height": height},
color_scheme="dark",
args=["--no-sandbox", "--disable-dev-shm-usage"],
**self._browser_launch_kwargs(width, height),
)
await self._install_browser_init_scripts(context)
# Grant clipboard access for the page origin
try:
parsed = urlparse(url)
+156 -1
View File
@@ -2,6 +2,7 @@
from __future__ import annotations
import json
import time
from typing import Any, Optional
from urllib.parse import urljoin
import httpx
@@ -13,6 +14,9 @@ class UpstreamError(RuntimeError):
pass
NEW_API_DEFAULT_QUOTA_PER_UNIT = 500000
def _find_token(value: Any) -> str:
if isinstance(value, str) and value.count(".") >= 2:
return value
@@ -74,6 +78,8 @@ def mask_secret(value: Any) -> str:
def _unwrap_data(value: Any) -> Any:
if isinstance(value, dict) and "data" in value and ("code" in value or "message" in value):
return value.get("data")
if isinstance(value, dict) and "data" in value and "success" in value:
return value.get("data")
return value
@@ -105,6 +111,20 @@ def _extract_key_value(value: Any) -> str:
return ""
def _is_success_response(value: Any) -> bool:
if not isinstance(value, dict) or "success" not in value:
return True
return value.get("success") is True
def _response_message(value: Any, fallback: str = "") -> str:
if isinstance(value, dict):
msg = value.get("message") or value.get("detail")
if msg:
return str(msg)
return fallback
def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]:
def _normalize(lst: list) -> list[dict[str, Any]]:
out = []
@@ -288,6 +308,17 @@ class UpstreamClient:
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
return f"{self.base_url}{prefix}/{path.lstrip('/')}"
def _is_new_api_user_mode(self) -> bool:
login_path = str(self.auth_config.get("login_path") or "")
return (
self.api_prefix == ""
and (
bool(self.auth_config.get("new_api_user"))
or login_path == "/api/user/login"
or self.auth_type == "cookie"
)
)
def _headers(self, auth: bool = True) -> dict[str, str]:
headers: dict[str, str] = {
"Accept": "application/json",
@@ -348,6 +379,119 @@ class UpstreamClient:
raise UpstreamError(f"{method} {path} returned HTML, not JSON")
return resp.json()
def _ensure_api_success(self, payload: Any, action: str) -> None:
if not _is_success_response(payload):
raise UpstreamError(_response_message(payload, f"{action} failed"))
def _new_api_quota_per_unit(self) -> int:
try:
payload = self._request("GET", "/api/status", auth=False)
data = _unwrap_data(payload)
if isinstance(data, dict):
value = data.get("quota_per_unit")
quota_per_unit = int(float(value))
if quota_per_unit > 0:
return quota_per_unit
except Exception:
pass
return NEW_API_DEFAULT_QUOTA_PER_UNIT
@staticmethod
def _normalize_key_record(record: dict[str, Any]) -> dict[str, Any]:
out = dict(record)
if not out.get("name") and out.get("key_name"):
out["name"] = out.get("key_name")
if not out.get("group_id") and out.get("group"):
out["group_id"] = str(out.get("group"))
if not out.get("group_name") and out.get("group"):
out["group_name"] = str(out.get("group"))
return out
def _list_new_api_tokens(
self,
search: str = "",
group_id: str | int | None = None,
) -> list[dict[str, Any]]:
if search:
path = "/api/token/search"
params = {"keyword": search, "token": "", "p": 1, "size": 100}
else:
path = "/api/token/"
params = {"p": 1, "size": 100}
resp = self._client.request(
"GET",
self._url(path),
params=params,
headers=self._headers(),
cookies=self._cookies,
)
self._cookies.update(dict(resp.cookies))
resp.raise_for_status()
data = resp.json()
self._ensure_api_success(data, "list New-API tokens")
nested = _unwrap_data(data)
items: list[dict[str, Any]] | None = None
if isinstance(nested, list):
items = [i for i in nested if isinstance(i, dict)]
elif isinstance(nested, dict):
for key in ("items", "tokens", "list", "records"):
value = nested.get(key)
if isinstance(value, list):
items = [i for i in value if isinstance(i, dict)]
break
if items is None:
raise UpstreamError("unexpected New-API token list response")
normalized = [self._normalize_key_record(i) for i in items]
if group_id is not None:
gid = str(group_id)
normalized = [i for i in normalized if str(i.get("group_id") or i.get("group") or "") == gid]
return normalized
def _get_new_api_token_key(self, token_id: str | int) -> str:
payload = self._request("POST", f"/api/token/{token_id}/key")
self._ensure_api_success(payload, "get New-API token key")
key_value = _extract_key_value(_unwrap_data(payload))
if not key_value:
raise UpstreamError("New-API token key response did not include key")
return key_value
def _create_new_api_token(
self,
name: str,
group_id: str | int,
quota: float = 0,
expires_in_days: int | None = None,
) -> dict[str, Any]:
unlimited = quota <= 0
body: dict[str, Any] = {
"name": name,
"remain_quota": 0 if unlimited else int(round(quota * self._new_api_quota_per_unit())),
"unlimited_quota": unlimited,
"expired_time": int(time.time()) + expires_in_days * 86400 if expires_in_days else -1,
"model_limits_enabled": False,
"model_limits": "",
"allow_ips": "",
"group": str(group_id),
"cross_group_retry": False,
}
payload = self._request("POST", "/api/token/", body)
self._ensure_api_success(payload, "create New-API token")
matches = self._list_new_api_tokens(search=name, group_id=group_id)
token = next((i for i in matches if str(i.get("name") or "").strip() == name.strip()), None)
if not token:
raise UpstreamError("New-API token was created but could not be found by name")
token_id = token.get("id")
if token_id is None:
raise UpstreamError("New-API token list response did not include id")
key_value = self._get_new_api_token_key(token_id)
return {
"id": str(token_id),
"key": key_value,
"masked_key": mask_secret(key_value),
"raw": self._normalize_key_record(token),
}
def login(self) -> None:
if self.auth_type != "login_password":
return
@@ -412,6 +556,9 @@ class UpstreamClient:
endpoint: str = "/keys",
) -> list[dict[str, Any]]:
"""查询远端上游 Key 列表,支持按名称搜索、分组筛选、状态筛选。"""
if endpoint in {"/api/token", "/api/token/"} or (endpoint == "/keys" and self._is_new_api_user_mode()):
return self._list_new_api_tokens(search=search, group_id=group_id)
params: dict[str, Any] = {}
if search:
params["search"] = search
@@ -446,7 +593,7 @@ class UpstreamClient:
for key in ("items", "keys", "list", "records"):
val = data.get(key)
if isinstance(val, list):
return val
return [self._normalize_key_record(i) for i in val if isinstance(i, dict)]
raise UpstreamError(f"unexpected keys response type: {type(data).__name__}")
def delete_api_key(self, key_id: str, endpoint: str = "/keys") -> None:
@@ -486,6 +633,14 @@ class UpstreamClient:
rate_limit_7d: float = 0,
endpoint: str = "/keys",
) -> dict[str, Any]:
if endpoint in {"/api/token", "/api/token/"} or (endpoint == "/keys" and self._is_new_api_user_mode()):
return self._create_new_api_token(
name,
group_id,
quota=quota,
expires_in_days=expires_in_days,
)
body: dict[str, Any] = {
"name": name,
"group_id": int(group_id) if str(group_id).isdigit() else group_id,
+62 -2
View File
@@ -36,11 +36,17 @@ def test_dot_prefix_exact_match():
assert _cookie_matches_hostname(".saki.lat", "saki.lat")
def test_no_domain_matches_all():
"""domain 视为不限制"""
def test_no_domain_cookie_matches_any_hostname():
"""cookie domain(无限制)应对任意 hostname 返回 True"""
assert _cookie_matches_hostname("", "anything.example.com")
def test_empty_hostname_rejects_all():
"""hostname 为空时,所有有 domain 的 cookie 都应被保守拒绝。"""
assert not _cookie_matches_hostname(".saki.lat", "")
assert not _cookie_matches_hostname("saki.lat", "")
def test_different_domain_no_match():
assert not _cookie_matches_hostname(".example.com", "saki.lat")
@@ -210,3 +216,57 @@ def test_new_api_user_propagated_to_bundle():
)
bundle = next(c for c in candidates if c["type"] == "cookie_bundle")
assert bundle.get("new_api_user") == "42"
def test_browser_import_payload_builds_cookie_bundle_with_new_api_user():
from app.services.browser_import_service import build_import_result
result = build_import_result({
"page_url": "https://meow.example.com/panel",
"cookies": [
{"name": "cf_clearance", "value": "cf", "domain": ".example.com", "httpOnly": True},
{"name": "session", "value": "sess", "domain": ".example.com", "httpOnly": True},
],
"local_storage": {"uid": "7"},
"session_storage": {},
"auth_headers": [],
})
bundle = next(c for c in result["candidates"] if c["type"] == "cookie_bundle")
assert "cf_clearance=cf" in bundle["value"]
assert "session=sess" in bundle["value"]
assert bundle["new_api_user"] == "7"
def test_browser_import_payload_includes_auth_headers():
from app.services.browser_import_service import build_import_result
result = build_import_result({
"page_url": "https://sub2api.example.com/dashboard",
"cookies": [],
"local_storage": {},
"session_storage": {},
"auth_headers": [
{"type": "authorization", "value": "Bearer abc.def.ghi", "url": "https://sub2api.example.com/api/v1/groups"}
],
})
assert result["candidates"][0]["type"] == "bearer_token"
assert result["candidates"][0]["value"] == "Bearer abc.def.ghi"
def test_browser_import_session_secret_and_one_time_submit():
from app.services.browser_import_service import BrowserImportService, ImportSessionError
service = BrowserImportService()
session, secret = service.create("https://example.com/login", "admin@example.com")
with pytest.raises(ImportSessionError):
service.submit(session.id, "wrong", {"page_url": "https://example.com/"})
submitted = service.submit(session.id, secret, {"page_url": "https://example.com/"})
assert submitted.consumed is True
assert submitted.payload == {"page_url": "https://example.com/"}
with pytest.raises(ImportSessionError):
service.submit(session.id, secret, {"page_url": "https://example.com/again"})
+60 -59
View File
@@ -2,7 +2,6 @@ import sys
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
@@ -10,10 +9,9 @@ from sqlalchemy.pool import StaticPool
sys.path.insert(0, str(Path(__file__).resolve().parent))
from app import database as database_module
from app.database import Base, get_db
from app.main import app
from app.database import Base
from app.models.custom_page import CustomPage
from app.utils.auth import get_current_user
from app.routers import custom_pages
@pytest.fixture()
@@ -33,48 +31,41 @@ def db_session():
Base.metadata.drop_all(bind=engine)
@pytest.fixture()
def client(db_session):
def override_get_db():
yield db_session
def test_create_page_auto_enables_autofill_when_credentials_are_saved(db_session):
response = custom_pages.create_page(
custom_pages.CustomPageCreate(
name="Login page",
url="https://example.test/login",
access_mode="remote_browser",
login_username="alice",
login_password="secret",
),
db_session,
object(),
)
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = lambda: object()
try:
yield TestClient(app)
finally:
app.dependency_overrides.clear()
assert response.login_autofill_enabled is True
assert response.login_password_configured is True
def test_create_page_auto_enables_autofill_when_credentials_are_saved(client):
response = client.post("/api/custom-pages", json={
"name": "Login page",
"url": "https://example.test/login",
"access_mode": "remote_browser",
"login_username": "alice",
"login_password": "secret",
})
def test_create_page_respects_explicit_autofill_disable(db_session):
response = custom_pages.create_page(
custom_pages.CustomPageCreate(
name="Login page",
url="https://example.test/login",
access_mode="remote_browser",
login_username="alice",
login_password="secret",
login_autofill_enabled=False,
),
db_session,
object(),
)
assert response.status_code == 201
assert response.json()["login_autofill_enabled"] is True
assert response.json()["login_password_configured"] is True
assert response.login_autofill_enabled is False
def test_create_page_respects_explicit_autofill_disable(client):
response = client.post("/api/custom-pages", json={
"name": "Login page",
"url": "https://example.test/login",
"access_mode": "remote_browser",
"login_username": "alice",
"login_password": "secret",
"login_autofill_enabled": False,
})
assert response.status_code == 201
assert response.json()["login_autofill_enabled"] is False
def test_update_page_auto_enables_autofill_when_new_password_is_saved(client, db_session):
def test_update_page_auto_enables_autofill_when_new_password_is_saved(db_session):
page = CustomPage(
name="Login page",
url="https://example.test/login",
@@ -87,16 +78,20 @@ def test_update_page_auto_enables_autofill_when_new_password_is_saved(client, db
db_session.commit()
db_session.refresh(page)
response = client.put(f"/api/custom-pages/{page.id}", json={
"login_username": "alice@example.test",
"login_password": "new-secret",
})
response = custom_pages.update_page(
page.id,
custom_pages.CustomPageUpdate(
login_username="alice@example.test",
login_password="new-secret",
),
db_session,
object(),
)
assert response.status_code == 200
assert response.json()["login_autofill_enabled"] is True
assert response.login_autofill_enabled is True
def test_update_page_keeps_autofill_disabled_when_existing_password_is_kept(client, db_session):
def test_update_page_keeps_autofill_disabled_when_existing_password_is_kept(db_session):
page = CustomPage(
name="Login page",
url="https://example.test/login",
@@ -109,15 +104,17 @@ def test_update_page_keeps_autofill_disabled_when_existing_password_is_kept(clie
db_session.commit()
db_session.refresh(page)
response = client.put(f"/api/custom-pages/{page.id}", json={
"login_username": "alice@example.test",
})
response = custom_pages.update_page(
page.id,
custom_pages.CustomPageUpdate(login_username="alice@example.test"),
db_session,
object(),
)
assert response.status_code == 200
assert response.json()["login_autofill_enabled"] is False
assert response.login_autofill_enabled is False
def test_update_page_respects_explicit_autofill_disable(client, db_session):
def test_update_page_respects_explicit_autofill_disable(db_session):
page = CustomPage(
name="Login page",
url="https://example.test/login",
@@ -130,13 +127,17 @@ def test_update_page_respects_explicit_autofill_disable(client, db_session):
db_session.commit()
db_session.refresh(page)
response = client.put(f"/api/custom-pages/{page.id}", json={
"login_username": "alice@example.test",
"login_autofill_enabled": False,
})
response = custom_pages.update_page(
page.id,
custom_pages.CustomPageUpdate(
login_username="alice@example.test",
login_autofill_enabled=False,
),
db_session,
object(),
)
assert response.status_code == 200
assert response.json()["login_autofill_enabled"] is False
assert response.login_autofill_enabled is False
def test_custom_page_migration_backfills_autofill_once(monkeypatch):
+177
View File
@@ -0,0 +1,177 @@
import asyncio
import json
import sys
from pathlib import Path
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
sys.path.insert(0, str(Path(__file__).resolve().parent))
from app.database import Base
from app.models.custom_page import CustomPage
from app.models.upstream import Upstream
from app.routers import custom_pages
@pytest.fixture()
def db_session():
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
Base.metadata.create_all(bind=engine)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = TestingSessionLocal()
try:
yield db
finally:
db.close()
Base.metadata.drop_all(bind=engine)
def _linked_page(db, upstream: Upstream) -> CustomPage:
db.add(upstream)
db.commit()
db.refresh(upstream)
page = CustomPage(
name="Login",
url="https://meow.example/login",
access_mode="remote_browser",
linked_upstream_id=upstream.id,
)
db.add(page)
db.commit()
db.refresh(page)
return page
def _install_refresh_fakes(monkeypatch, candidates: list[dict], calls: list[dict]):
monkeypatch.setattr(custom_pages.browser_sessions, "find_by_page_id", lambda _pid: object())
async def fake_extract_all(_session):
return {"candidates": candidates}
monkeypatch.setattr(custom_pages, "extract_all", fake_extract_all)
class FakeUpstreamClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
def __enter__(self):
return self
def __exit__(self, *_args):
return False
def get_available_groups(self, endpoint):
calls.append({"endpoint": endpoint, **self.kwargs})
return [{"id": "default", "name": "Default"}]
import app.services.upstream_client as upstream_client_module
monkeypatch.setattr(upstream_client_module, "UpstreamClient", FakeUpstreamClient)
def test_refresh_auth_sub2api_prefers_bearer_over_cookie_bundle(db_session, monkeypatch):
upstream = Upstream(
name="Meow",
base_url="https://api.saki.lat",
api_prefix="/api/v1",
auth_type="login_password",
auth_config_json=json.dumps({
"email": "alice@example.test",
"password": "secret",
"login_path": "/auth/login",
}),
groups_endpoint="/groups/available",
rate_endpoint="/groups/rates",
)
page = _linked_page(db_session, upstream)
calls: list[dict] = []
_install_refresh_fakes(monkeypatch, [
{"type": "cookie_bundle", "value": "cf_clearance=cf; session=s", "cookie_count": 2},
{"type": "bearer_token", "value": "jwt.header.payload", "source": "localStorage.auth_token"},
], calls)
response = asyncio.run(custom_pages.refresh_auth(page.id, db_session, object()))
db_session.refresh(upstream)
cfg = json.loads(upstream.auth_config_json)
assert response.success is True
assert upstream.auth_type == "bearer"
assert cfg["token"] == "jwt.header.payload"
assert "cookie_string" not in cfg
assert upstream.api_prefix == "/api/v1"
assert upstream.groups_endpoint == "/groups/available"
assert calls[0]["auth_type"] == "bearer"
assert calls[0]["endpoint"] == "/groups/available"
def test_refresh_auth_sub2api_rejects_cookie_only_capture(db_session, monkeypatch):
upstream = Upstream(
name="Meow",
base_url="https://api.saki.lat",
api_prefix="/api/v1",
auth_type="login_password",
auth_config_json=json.dumps({"login_path": "/auth/login"}),
groups_endpoint="/groups/available",
rate_endpoint="/groups/rates",
)
page = _linked_page(db_session, upstream)
_install_refresh_fakes(monkeypatch, [
{"type": "cookie_bundle", "value": "cf_clearance=cf; session=s", "cookie_count": 2},
], [])
response = asyncio.run(custom_pages.refresh_auth(page.id, db_session, object()))
db_session.refresh(upstream)
cfg = json.loads(upstream.auth_config_json)
assert response.success is False
assert "Sub2API 需要 Bearer Token" in response.message
assert upstream.auth_type == "login_password"
assert "cookie_string" not in cfg
def test_refresh_auth_new_api_user_uses_cookie_bundle_and_resets_user_endpoints(db_session, monkeypatch):
upstream = Upstream(
name="New API User",
base_url="https://newapi.example",
api_prefix="/api/v1",
auth_type="login_password",
auth_config_json=json.dumps({
"email": "alice",
"password": "secret",
"login_path": "/api/user/login",
}),
groups_endpoint="/groups/available",
rate_endpoint="/groups/rates",
)
page = _linked_page(db_session, upstream)
calls: list[dict] = []
_install_refresh_fakes(monkeypatch, [
{
"type": "cookie_bundle",
"value": "cf_clearance=cf; session=s",
"cookie_count": 2,
"new_api_user": "42",
},
{"type": "bearer_token", "value": "jwt.header.payload"},
], calls)
response = asyncio.run(custom_pages.refresh_auth(page.id, db_session, object()))
db_session.refresh(upstream)
cfg = json.loads(upstream.auth_config_json)
assert response.success is True
assert upstream.auth_type == "cookie"
assert cfg["cookie_string"] == "cf_clearance=cf; session=s"
assert cfg["new_api_user"] == "42"
assert upstream.api_prefix == ""
assert upstream.groups_endpoint == "/api/user/self/groups"
assert upstream.rate_endpoint == "/api/user/self/groups"
assert calls[0]["auth_type"] == "cookie"
assert calls[0]["endpoint"] == "/api/user/self/groups"
+103
View File
@@ -590,3 +590,106 @@ def test_sync_removes_deleted_remote_key(db_session):
remaining = db_session.query(UpstreamGeneratedKey).all()
assert len(remaining) == 0
def test_new_api_create_token_fetches_plaintext_key(monkeypatch):
"""New-API 创建 token 后需按 id 再取一次明文 key。"""
from app.services.upstream_client import UpstreamClient
client = UpstreamClient(
base_url="http://newapi.local",
api_prefix="",
auth_type="cookie",
auth_config={"cookie_string": "session=abc", "new_api_user": "7"},
)
created_bodies = []
def fake_request(method, path, body=None, auth=True):
if method == "GET" and path == "/api/status":
return {"success": True, "data": {"quota_per_unit": 500000}}
if method == "POST" and path == "/api/token/":
created_bodies.append(body)
return {"success": True, "message": ""}
if method == "POST" and path == "/api/token/123/key":
return {"success": True, "data": {"key": "new-api-plain-key"}}
raise AssertionError(f"unexpected request {method} {path}")
monkeypatch.setattr(client, "_request", fake_request)
monkeypatch.setattr(
client,
"_list_new_api_tokens",
lambda search="", group_id=None: [{"id": 123, "name": search, "group": group_id, "key": "new-****-key"}],
)
result = client.create_api_key(
"SmartUp-1-VIP-vip",
"vip",
quota=2,
expires_in_days=3,
endpoint="/api/token",
)
assert result["id"] == "123"
assert result["key"] == "new-api-plain-key"
assert created_bodies[0]["group"] == "vip"
assert created_bodies[0]["remain_quota"] == 1000000
assert created_bodies[0]["unlimited_quota"] is False
assert created_bodies[0]["expired_time"] > 0
def test_generate_keys_allows_new_api_user_upstream(db_session, monkeypatch):
"""New-API 普通账号上游应允许按分组生成 token。"""
from app.routers import upstreams as upstreams_router
from app.schemas.upstream import GenerateKeysByGroupsRequest
upstream = Upstream(
name="NewAPI",
base_url="http://newapi.local",
api_prefix="",
auth_type="cookie",
auth_config_json=json.dumps({"cookie_string": "session=abc", "new_api_user": "7"}),
groups_endpoint="/api/user/self/groups",
rate_endpoint="/api/user/self/groups",
)
db_session.add(upstream)
db_session.commit()
db_session.refresh(upstream)
monkeypatch.setattr(upstreams_router.website_sync, "reconcile_upstream_keys_full", lambda db, uid: True)
class FakeClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
def __enter__(self):
return self
def __exit__(self, *args):
return None
def login(self):
return None
def get_available_groups(self, endpoint):
assert endpoint == "/api/user/self/groups"
return [{"id": "vip", "name": "VIP"}]
def find_smartup_group_key(self, gid, expected_name, prefix):
return None
def create_api_key(self, name, group_id, **kwargs):
assert kwargs["endpoint"] == "/api/token"
return {"id": "123", "key": "new-api-plain-key", "masked_key": "new-****-key", "raw": {"id": 123}}
monkeypatch.setattr(upstreams_router, "UpstreamClient", FakeClient)
response = upstreams_router.generate_keys_by_groups(
upstream.id,
GenerateKeysByGroupsRequest(group_ids=["vip"], endpoint="/api/token"),
db_session,
object(),
)
assert response.success is True
assert response.items[0].status == "created"
assert response.items[0].key_value == "new-api-plain-key"