feat: support real browser auth import
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urljoin
|
||||
import httpx
|
||||
@@ -13,6 +14,9 @@ class UpstreamError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
NEW_API_DEFAULT_QUOTA_PER_UNIT = 500000
|
||||
|
||||
|
||||
def _find_token(value: Any) -> str:
|
||||
if isinstance(value, str) and value.count(".") >= 2:
|
||||
return value
|
||||
@@ -74,6 +78,8 @@ def mask_secret(value: Any) -> str:
|
||||
def _unwrap_data(value: Any) -> Any:
|
||||
if isinstance(value, dict) and "data" in value and ("code" in value or "message" in value):
|
||||
return value.get("data")
|
||||
if isinstance(value, dict) and "data" in value and "success" in value:
|
||||
return value.get("data")
|
||||
return value
|
||||
|
||||
|
||||
@@ -105,6 +111,20 @@ def _extract_key_value(value: Any) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _is_success_response(value: Any) -> bool:
|
||||
if not isinstance(value, dict) or "success" not in value:
|
||||
return True
|
||||
return value.get("success") is True
|
||||
|
||||
|
||||
def _response_message(value: Any, fallback: str = "") -> str:
|
||||
if isinstance(value, dict):
|
||||
msg = value.get("message") or value.get("detail")
|
||||
if msg:
|
||||
return str(msg)
|
||||
return fallback
|
||||
|
||||
|
||||
def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]:
|
||||
def _normalize(lst: list) -> list[dict[str, Any]]:
|
||||
out = []
|
||||
@@ -288,6 +308,17 @@ class UpstreamClient:
|
||||
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
|
||||
return f"{self.base_url}{prefix}/{path.lstrip('/')}"
|
||||
|
||||
def _is_new_api_user_mode(self) -> bool:
|
||||
login_path = str(self.auth_config.get("login_path") or "")
|
||||
return (
|
||||
self.api_prefix == ""
|
||||
and (
|
||||
bool(self.auth_config.get("new_api_user"))
|
||||
or login_path == "/api/user/login"
|
||||
or self.auth_type == "cookie"
|
||||
)
|
||||
)
|
||||
|
||||
def _headers(self, auth: bool = True) -> dict[str, str]:
|
||||
headers: dict[str, str] = {
|
||||
"Accept": "application/json",
|
||||
@@ -348,6 +379,119 @@ class UpstreamClient:
|
||||
raise UpstreamError(f"{method} {path} returned HTML, not JSON")
|
||||
return resp.json()
|
||||
|
||||
def _ensure_api_success(self, payload: Any, action: str) -> None:
|
||||
if not _is_success_response(payload):
|
||||
raise UpstreamError(_response_message(payload, f"{action} failed"))
|
||||
|
||||
def _new_api_quota_per_unit(self) -> int:
|
||||
try:
|
||||
payload = self._request("GET", "/api/status", auth=False)
|
||||
data = _unwrap_data(payload)
|
||||
if isinstance(data, dict):
|
||||
value = data.get("quota_per_unit")
|
||||
quota_per_unit = int(float(value))
|
||||
if quota_per_unit > 0:
|
||||
return quota_per_unit
|
||||
except Exception:
|
||||
pass
|
||||
return NEW_API_DEFAULT_QUOTA_PER_UNIT
|
||||
|
||||
@staticmethod
|
||||
def _normalize_key_record(record: dict[str, Any]) -> dict[str, Any]:
|
||||
out = dict(record)
|
||||
if not out.get("name") and out.get("key_name"):
|
||||
out["name"] = out.get("key_name")
|
||||
if not out.get("group_id") and out.get("group"):
|
||||
out["group_id"] = str(out.get("group"))
|
||||
if not out.get("group_name") and out.get("group"):
|
||||
out["group_name"] = str(out.get("group"))
|
||||
return out
|
||||
|
||||
def _list_new_api_tokens(
|
||||
self,
|
||||
search: str = "",
|
||||
group_id: str | int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
if search:
|
||||
path = "/api/token/search"
|
||||
params = {"keyword": search, "token": "", "p": 1, "size": 100}
|
||||
else:
|
||||
path = "/api/token/"
|
||||
params = {"p": 1, "size": 100}
|
||||
resp = self._client.request(
|
||||
"GET",
|
||||
self._url(path),
|
||||
params=params,
|
||||
headers=self._headers(),
|
||||
cookies=self._cookies,
|
||||
)
|
||||
self._cookies.update(dict(resp.cookies))
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
self._ensure_api_success(data, "list New-API tokens")
|
||||
nested = _unwrap_data(data)
|
||||
items: list[dict[str, Any]] | None = None
|
||||
if isinstance(nested, list):
|
||||
items = [i for i in nested if isinstance(i, dict)]
|
||||
elif isinstance(nested, dict):
|
||||
for key in ("items", "tokens", "list", "records"):
|
||||
value = nested.get(key)
|
||||
if isinstance(value, list):
|
||||
items = [i for i in value if isinstance(i, dict)]
|
||||
break
|
||||
if items is None:
|
||||
raise UpstreamError("unexpected New-API token list response")
|
||||
normalized = [self._normalize_key_record(i) for i in items]
|
||||
if group_id is not None:
|
||||
gid = str(group_id)
|
||||
normalized = [i for i in normalized if str(i.get("group_id") or i.get("group") or "") == gid]
|
||||
return normalized
|
||||
|
||||
def _get_new_api_token_key(self, token_id: str | int) -> str:
|
||||
payload = self._request("POST", f"/api/token/{token_id}/key")
|
||||
self._ensure_api_success(payload, "get New-API token key")
|
||||
key_value = _extract_key_value(_unwrap_data(payload))
|
||||
if not key_value:
|
||||
raise UpstreamError("New-API token key response did not include key")
|
||||
return key_value
|
||||
|
||||
def _create_new_api_token(
|
||||
self,
|
||||
name: str,
|
||||
group_id: str | int,
|
||||
quota: float = 0,
|
||||
expires_in_days: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
unlimited = quota <= 0
|
||||
body: dict[str, Any] = {
|
||||
"name": name,
|
||||
"remain_quota": 0 if unlimited else int(round(quota * self._new_api_quota_per_unit())),
|
||||
"unlimited_quota": unlimited,
|
||||
"expired_time": int(time.time()) + expires_in_days * 86400 if expires_in_days else -1,
|
||||
"model_limits_enabled": False,
|
||||
"model_limits": "",
|
||||
"allow_ips": "",
|
||||
"group": str(group_id),
|
||||
"cross_group_retry": False,
|
||||
}
|
||||
payload = self._request("POST", "/api/token/", body)
|
||||
self._ensure_api_success(payload, "create New-API token")
|
||||
|
||||
matches = self._list_new_api_tokens(search=name, group_id=group_id)
|
||||
token = next((i for i in matches if str(i.get("name") or "").strip() == name.strip()), None)
|
||||
if not token:
|
||||
raise UpstreamError("New-API token was created but could not be found by name")
|
||||
token_id = token.get("id")
|
||||
if token_id is None:
|
||||
raise UpstreamError("New-API token list response did not include id")
|
||||
key_value = self._get_new_api_token_key(token_id)
|
||||
return {
|
||||
"id": str(token_id),
|
||||
"key": key_value,
|
||||
"masked_key": mask_secret(key_value),
|
||||
"raw": self._normalize_key_record(token),
|
||||
}
|
||||
|
||||
def login(self) -> None:
|
||||
if self.auth_type != "login_password":
|
||||
return
|
||||
@@ -412,6 +556,9 @@ class UpstreamClient:
|
||||
endpoint: str = "/keys",
|
||||
) -> list[dict[str, Any]]:
|
||||
"""查询远端上游 Key 列表,支持按名称搜索、分组筛选、状态筛选。"""
|
||||
if endpoint in {"/api/token", "/api/token/"} or (endpoint == "/keys" and self._is_new_api_user_mode()):
|
||||
return self._list_new_api_tokens(search=search, group_id=group_id)
|
||||
|
||||
params: dict[str, Any] = {}
|
||||
if search:
|
||||
params["search"] = search
|
||||
@@ -446,7 +593,7 @@ class UpstreamClient:
|
||||
for key in ("items", "keys", "list", "records"):
|
||||
val = data.get(key)
|
||||
if isinstance(val, list):
|
||||
return val
|
||||
return [self._normalize_key_record(i) for i in val if isinstance(i, dict)]
|
||||
raise UpstreamError(f"unexpected keys response type: {type(data).__name__}")
|
||||
|
||||
def delete_api_key(self, key_id: str, endpoint: str = "/keys") -> None:
|
||||
@@ -486,6 +633,14 @@ class UpstreamClient:
|
||||
rate_limit_7d: float = 0,
|
||||
endpoint: str = "/keys",
|
||||
) -> dict[str, Any]:
|
||||
if endpoint in {"/api/token", "/api/token/"} or (endpoint == "/keys" and self._is_new_api_user_mode()):
|
||||
return self._create_new_api_token(
|
||||
name,
|
||||
group_id,
|
||||
quota=quota,
|
||||
expires_in_days=expires_in_days,
|
||||
)
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"name": name,
|
||||
"group_id": int(group_id) if str(group_id).isdigit() else group_id,
|
||||
|
||||
Reference in New Issue
Block a user