feat: support real browser auth import
This commit is contained in:
@@ -10,6 +10,11 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.services.auth_capture_service import extract_all
|
||||
from app.services.browser_import_service import (
|
||||
ImportSessionError,
|
||||
browser_imports,
|
||||
build_import_result,
|
||||
)
|
||||
from app.services.browser_session_service import (
|
||||
BrowserDependencyError,
|
||||
BrowserSessionError,
|
||||
@@ -43,6 +48,32 @@ class CaptureExtractResponse(BaseModel):
|
||||
candidates: list[dict] = []
|
||||
|
||||
|
||||
class ImportSessionCreate(BaseModel):
|
||||
target_url: str = Field(..., description="Target page URL opened in the user's real browser")
|
||||
|
||||
|
||||
class ImportSessionCreateResponse(BaseModel):
|
||||
session_id: str
|
||||
secret: str
|
||||
expires_in_seconds: int
|
||||
|
||||
|
||||
class ImportSessionStatusResponse(BaseModel):
|
||||
session_id: str
|
||||
ready: bool = False
|
||||
expires_at: float
|
||||
result: Optional[CaptureExtractResponse] = None
|
||||
|
||||
|
||||
class ImportSessionSubmit(BaseModel):
|
||||
secret: str
|
||||
page_url: str = ""
|
||||
cookies: list[dict] = []
|
||||
local_storage: dict[str, Any] = {}
|
||||
session_storage: dict[str, Any] = {}
|
||||
auth_headers: list[dict] = []
|
||||
|
||||
|
||||
def _sanitize_candidate(candidate: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
key: value
|
||||
@@ -134,3 +165,65 @@ async def close_capture_session(
|
||||
await browser_sessions.close(session_id)
|
||||
except Exception as exc:
|
||||
raise _browser_error(exc)
|
||||
|
||||
|
||||
@router.post("/import-sessions", response_model=ImportSessionCreateResponse, status_code=201)
|
||||
async def create_import_session(
|
||||
body: ImportSessionCreate,
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Create a one-time real-browser import session.
|
||||
|
||||
The returned secret is intended for the local browser extension only.
|
||||
"""
|
||||
target_url = body.target_url.strip()
|
||||
if not target_url.startswith(("http://", "https://")):
|
||||
raise HTTPException(400, "Only http/https URLs are allowed")
|
||||
session, secret = browser_imports.create(target_url=target_url, created_by=user.email)
|
||||
return ImportSessionCreateResponse(
|
||||
session_id=session.id,
|
||||
secret=secret,
|
||||
expires_in_seconds=600,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/import-sessions/{session_id}", response_model=ImportSessionStatusResponse)
|
||||
async def get_import_session(
|
||||
session_id: str,
|
||||
include_raw: bool = Query(default=False),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
session = browser_imports.get(session_id, created_by=user.email)
|
||||
except ImportSessionError as exc:
|
||||
raise HTTPException(404, str(exc))
|
||||
|
||||
result = None
|
||||
if session.payload is not None:
|
||||
full_result = build_import_result(session.payload)
|
||||
if not include_raw:
|
||||
candidates = [_sanitize_candidate(candidate) for candidate in full_result.get("candidates", [])]
|
||||
result = CaptureExtractResponse(candidates=candidates)
|
||||
else:
|
||||
result = CaptureExtractResponse(**full_result)
|
||||
return ImportSessionStatusResponse(
|
||||
session_id=session.id,
|
||||
ready=session.payload is not None,
|
||||
expires_at=session.expires_at,
|
||||
result=result,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/import-sessions/{session_id}/submit", status_code=204)
|
||||
async def submit_import_session(
|
||||
session_id: str,
|
||||
body: ImportSessionSubmit,
|
||||
):
|
||||
try:
|
||||
browser_imports.submit(
|
||||
session_id=session_id,
|
||||
secret=body.secret,
|
||||
payload=body.model_dump(exclude={"secret"}),
|
||||
)
|
||||
except ImportSessionError as exc:
|
||||
raise HTTPException(400, str(exc))
|
||||
|
||||
@@ -221,25 +221,45 @@ class RefreshAuthResponse(BaseModel):
|
||||
warning: Optional[str] = None
|
||||
|
||||
|
||||
def _pick_best_candidate(candidates: list[dict], preferred_auth_type: str) -> Optional[dict]:
|
||||
def _norm_path(value: Any) -> str:
|
||||
return str(value or "").strip().rstrip("/")
|
||||
|
||||
|
||||
def _detect_upstream_platform(upstream: Upstream, auth_config: dict) -> str:
|
||||
api_prefix = _norm_path(upstream.api_prefix)
|
||||
groups_endpoint = _norm_path(upstream.groups_endpoint)
|
||||
rate_endpoint = _norm_path(upstream.rate_endpoint)
|
||||
login_path = _norm_path(auth_config.get("login_path"))
|
||||
|
||||
if groups_endpoint == "/api/user/self/groups" or login_path == "/api/user/login":
|
||||
return "new-api-user"
|
||||
if api_prefix == "/api/v1" or groups_endpoint in {"/groups/available", "/groups/rates"} or login_path == "/auth/login":
|
||||
return "sub2api"
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _first_candidate(candidates: list[dict], *types: str) -> Optional[dict]:
|
||||
for c in candidates:
|
||||
if c.get("type") in types:
|
||||
return c
|
||||
return None
|
||||
|
||||
|
||||
def _pick_best_candidate(candidates: list[dict], preferred_auth_type: str, platform: str = "unknown") -> Optional[dict]:
|
||||
if not candidates:
|
||||
return None
|
||||
# cookie_bundle > cookie > bearer_token > api_key
|
||||
# preferred_auth_type="cookie" 时优先匹配 bundle,其次单 cookie
|
||||
|
||||
if platform == "sub2api":
|
||||
return _first_candidate(candidates, "bearer_token", "api_key")
|
||||
if platform == "new-api-user":
|
||||
return _first_candidate(candidates, "cookie_bundle", "cookie", "bearer_token", "api_key")
|
||||
if preferred_auth_type == "cookie":
|
||||
for c in candidates:
|
||||
if c["type"] == "cookie_bundle":
|
||||
return c
|
||||
for c in candidates:
|
||||
if c["type"] == "cookie":
|
||||
return c
|
||||
return _first_candidate(candidates, "cookie_bundle", "cookie")
|
||||
elif preferred_auth_type in ("bearer", "api_key"):
|
||||
type_map = {"bearer": "bearer_token", "api_key": "api_key"}
|
||||
preferred = type_map.get(preferred_auth_type)
|
||||
if preferred:
|
||||
for c in candidates:
|
||||
if c["type"] == preferred:
|
||||
return c
|
||||
return _first_candidate(candidates, preferred)
|
||||
# fallback:排序后取第一个
|
||||
return candidates[0]
|
||||
|
||||
@@ -268,11 +288,17 @@ async def refresh_auth(pid: int, db: Session = Depends(get_db), _=Depends(get_cu
|
||||
return RefreshAuthResponse(success=False, message=f"提取失败: {exc}")
|
||||
|
||||
candidates = result.get("candidates", [])
|
||||
candidate = _pick_best_candidate(candidates, upstream.auth_type)
|
||||
existing_config = _json.loads(upstream.auth_config_json or "{}")
|
||||
platform = _detect_upstream_platform(upstream, existing_config)
|
||||
candidate = _pick_best_candidate(candidates, upstream.auth_type, platform)
|
||||
if not candidate:
|
||||
if platform == "sub2api" and _first_candidate(candidates, "cookie_bundle", "cookie"):
|
||||
return RefreshAuthResponse(
|
||||
success=False,
|
||||
message="Sub2API 需要 Bearer Token;当前只提取到 Cookie。请在远程浏览器完成登录后刷新页面或触发一次接口请求,再重新提取。",
|
||||
)
|
||||
return RefreshAuthResponse(success=False, message="未提取到有效凭证,请确认已在远程浏览器中登录")
|
||||
|
||||
existing_config = _json.loads(upstream.auth_config_json or "{}")
|
||||
ctype = candidate["type"]
|
||||
|
||||
if ctype in ("cookie_bundle", "cookie"):
|
||||
@@ -281,6 +307,10 @@ async def refresh_auth(pid: int, db: Session = Depends(get_db), _=Depends(get_cu
|
||||
existing_config["cookie_string"] = candidate.get("value", "")
|
||||
if candidate.get("new_api_user"):
|
||||
existing_config["new_api_user"] = candidate["new_api_user"]
|
||||
if platform == "new-api-user":
|
||||
upstream.api_prefix = ""
|
||||
upstream.groups_endpoint = "/api/user/self/groups"
|
||||
upstream.rate_endpoint = "/api/user/self/groups"
|
||||
elif ctype == "bearer_token":
|
||||
upstream.auth_type = "bearer"
|
||||
raw = candidate.get("value", "")
|
||||
|
||||
@@ -315,6 +315,26 @@ def list_generated_keys(uid: int, db: Session = Depends(get_db), _=Depends(get_c
|
||||
_generate_key_lock = __import__("threading").Lock()
|
||||
|
||||
|
||||
def _is_sub2api_upstream(upstream: Upstream) -> bool:
|
||||
return upstream.api_prefix.strip("/") == "api/v1"
|
||||
|
||||
|
||||
def _is_new_api_user_upstream(upstream: Upstream) -> bool:
|
||||
auth_config = json.loads(upstream.auth_config_json or "{}")
|
||||
return (
|
||||
upstream.api_prefix.strip("/") == ""
|
||||
and (
|
||||
upstream.groups_endpoint == "/api/user/self/groups"
|
||||
or auth_config.get("login_path") == "/api/user/login"
|
||||
or bool(auth_config.get("new_api_user"))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _supports_key_generation(upstream: Upstream) -> bool:
|
||||
return _is_sub2api_upstream(upstream) or _is_new_api_user_upstream(upstream)
|
||||
|
||||
|
||||
def _ensure_group_key(
|
||||
db: Session,
|
||||
client: UpstreamClient,
|
||||
@@ -465,8 +485,8 @@ def generate_keys_by_groups(
|
||||
u = db.query(Upstream).filter(Upstream.id == uid).first()
|
||||
if not u:
|
||||
raise HTTPException(404, "upstream not found")
|
||||
if u.api_prefix.strip("/") != "api/v1":
|
||||
raise HTTPException(400, "首版仅支持 Sub2API 上游(API Prefix 应为 /api/v1)")
|
||||
if not _supports_key_generation(u):
|
||||
raise HTTPException(400, "仅支持 Sub2API 或 New-API 普通账号上游生成 Key")
|
||||
|
||||
# 生成前先对账,清理远端已删除的旧 Key
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user