feat: support real browser auth import

This commit is contained in:
liumangmang
2026-06-02 13:51:29 +08:00
parent f4d16a4c01
commit 84148f4a69
22 changed files with 1651 additions and 111 deletions
+156 -1
View File
@@ -2,6 +2,7 @@
from __future__ import annotations
import json
import time
from typing import Any, Optional
from urllib.parse import urljoin
import httpx
@@ -13,6 +14,9 @@ class UpstreamError(RuntimeError):
pass
NEW_API_DEFAULT_QUOTA_PER_UNIT = 500000
def _find_token(value: Any) -> str:
if isinstance(value, str) and value.count(".") >= 2:
return value
@@ -74,6 +78,8 @@ def mask_secret(value: Any) -> str:
def _unwrap_data(value: Any) -> Any:
if isinstance(value, dict) and "data" in value and ("code" in value or "message" in value):
return value.get("data")
if isinstance(value, dict) and "data" in value and "success" in value:
return value.get("data")
return value
@@ -105,6 +111,20 @@ def _extract_key_value(value: Any) -> str:
return ""
def _is_success_response(value: Any) -> bool:
if not isinstance(value, dict) or "success" not in value:
return True
return value.get("success") is True
def _response_message(value: Any, fallback: str = "") -> str:
if isinstance(value, dict):
msg = value.get("message") or value.get("detail")
if msg:
return str(msg)
return fallback
def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]:
def _normalize(lst: list) -> list[dict[str, Any]]:
out = []
@@ -288,6 +308,17 @@ class UpstreamClient:
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
return f"{self.base_url}{prefix}/{path.lstrip('/')}"
def _is_new_api_user_mode(self) -> bool:
login_path = str(self.auth_config.get("login_path") or "")
return (
self.api_prefix == ""
and (
bool(self.auth_config.get("new_api_user"))
or login_path == "/api/user/login"
or self.auth_type == "cookie"
)
)
def _headers(self, auth: bool = True) -> dict[str, str]:
headers: dict[str, str] = {
"Accept": "application/json",
@@ -348,6 +379,119 @@ class UpstreamClient:
raise UpstreamError(f"{method} {path} returned HTML, not JSON")
return resp.json()
def _ensure_api_success(self, payload: Any, action: str) -> None:
if not _is_success_response(payload):
raise UpstreamError(_response_message(payload, f"{action} failed"))
def _new_api_quota_per_unit(self) -> int:
try:
payload = self._request("GET", "/api/status", auth=False)
data = _unwrap_data(payload)
if isinstance(data, dict):
value = data.get("quota_per_unit")
quota_per_unit = int(float(value))
if quota_per_unit > 0:
return quota_per_unit
except Exception:
pass
return NEW_API_DEFAULT_QUOTA_PER_UNIT
@staticmethod
def _normalize_key_record(record: dict[str, Any]) -> dict[str, Any]:
out = dict(record)
if not out.get("name") and out.get("key_name"):
out["name"] = out.get("key_name")
if not out.get("group_id") and out.get("group"):
out["group_id"] = str(out.get("group"))
if not out.get("group_name") and out.get("group"):
out["group_name"] = str(out.get("group"))
return out
def _list_new_api_tokens(
self,
search: str = "",
group_id: str | int | None = None,
) -> list[dict[str, Any]]:
if search:
path = "/api/token/search"
params = {"keyword": search, "token": "", "p": 1, "size": 100}
else:
path = "/api/token/"
params = {"p": 1, "size": 100}
resp = self._client.request(
"GET",
self._url(path),
params=params,
headers=self._headers(),
cookies=self._cookies,
)
self._cookies.update(dict(resp.cookies))
resp.raise_for_status()
data = resp.json()
self._ensure_api_success(data, "list New-API tokens")
nested = _unwrap_data(data)
items: list[dict[str, Any]] | None = None
if isinstance(nested, list):
items = [i for i in nested if isinstance(i, dict)]
elif isinstance(nested, dict):
for key in ("items", "tokens", "list", "records"):
value = nested.get(key)
if isinstance(value, list):
items = [i for i in value if isinstance(i, dict)]
break
if items is None:
raise UpstreamError("unexpected New-API token list response")
normalized = [self._normalize_key_record(i) for i in items]
if group_id is not None:
gid = str(group_id)
normalized = [i for i in normalized if str(i.get("group_id") or i.get("group") or "") == gid]
return normalized
def _get_new_api_token_key(self, token_id: str | int) -> str:
payload = self._request("POST", f"/api/token/{token_id}/key")
self._ensure_api_success(payload, "get New-API token key")
key_value = _extract_key_value(_unwrap_data(payload))
if not key_value:
raise UpstreamError("New-API token key response did not include key")
return key_value
def _create_new_api_token(
self,
name: str,
group_id: str | int,
quota: float = 0,
expires_in_days: int | None = None,
) -> dict[str, Any]:
unlimited = quota <= 0
body: dict[str, Any] = {
"name": name,
"remain_quota": 0 if unlimited else int(round(quota * self._new_api_quota_per_unit())),
"unlimited_quota": unlimited,
"expired_time": int(time.time()) + expires_in_days * 86400 if expires_in_days else -1,
"model_limits_enabled": False,
"model_limits": "",
"allow_ips": "",
"group": str(group_id),
"cross_group_retry": False,
}
payload = self._request("POST", "/api/token/", body)
self._ensure_api_success(payload, "create New-API token")
matches = self._list_new_api_tokens(search=name, group_id=group_id)
token = next((i for i in matches if str(i.get("name") or "").strip() == name.strip()), None)
if not token:
raise UpstreamError("New-API token was created but could not be found by name")
token_id = token.get("id")
if token_id is None:
raise UpstreamError("New-API token list response did not include id")
key_value = self._get_new_api_token_key(token_id)
return {
"id": str(token_id),
"key": key_value,
"masked_key": mask_secret(key_value),
"raw": self._normalize_key_record(token),
}
def login(self) -> None:
if self.auth_type != "login_password":
return
@@ -412,6 +556,9 @@ class UpstreamClient:
endpoint: str = "/keys",
) -> list[dict[str, Any]]:
"""查询远端上游 Key 列表,支持按名称搜索、分组筛选、状态筛选。"""
if endpoint in {"/api/token", "/api/token/"} or (endpoint == "/keys" and self._is_new_api_user_mode()):
return self._list_new_api_tokens(search=search, group_id=group_id)
params: dict[str, Any] = {}
if search:
params["search"] = search
@@ -446,7 +593,7 @@ class UpstreamClient:
for key in ("items", "keys", "list", "records"):
val = data.get(key)
if isinstance(val, list):
return val
return [self._normalize_key_record(i) for i in val if isinstance(i, dict)]
raise UpstreamError(f"unexpected keys response type: {type(data).__name__}")
def delete_api_key(self, key_id: str, endpoint: str = "/keys") -> None:
@@ -486,6 +633,14 @@ class UpstreamClient:
rate_limit_7d: float = 0,
endpoint: str = "/keys",
) -> dict[str, Any]:
if endpoint in {"/api/token", "/api/token/"} or (endpoint == "/keys" and self._is_new_api_user_mode()):
return self._create_new_api_token(
name,
group_id,
quota=quota,
expires_in_days=expires_in_days,
)
body: dict[str, Any] = {
"name": name,
"group_id": int(group_id) if str(group_id).isdigit() else group_id,