Files
SmartUp/backend/app/services/upstream_client.py
T
2026-06-03 17:03:11 +08:00

730 lines
27 KiB
Python

"""Upstream HTTP client — ported from monitor_ai98pro_group_rates.py."""
from __future__ import annotations
import json
import time
from typing import Any, Optional
from urllib.parse import urljoin
import httpx
from app.utils.number import decimal_string
class UpstreamError(RuntimeError):
    """Error raised by this module for unusable upstream credentials or responses."""
# Fallback "quota per unit" factor, used when the upstream's /api/status
# response does not provide a usable ``quota_per_unit`` value.
NEW_API_DEFAULT_QUOTA_PER_UNIT = 500_000
def _find_token(value: Any) -> str:
if isinstance(value, str) and value.count(".") >= 2:
return value
if isinstance(value, dict):
for key in ("token", "access_token", "accessToken", "jwt", "auth_token", "authToken"):
candidate = value.get(key)
if isinstance(candidate, str) and candidate:
return candidate
for key in ("data", "result", "user", "session"):
tok = _find_token(value.get(key))
if tok:
return tok
return ""
def _clean_auth_header_value(value: Any, field_name: str) -> str:
text = str(value or "").strip()
if not text:
return ""
if text.startswith("Bearer "):
text = text[7:].strip()
# Try to sanitize non-latin-1 characters instead of hard-failing
try:
text.encode("latin-1")
except UnicodeEncodeError:
# Try stripping non-ASCII characters
cleaned = text.encode("ascii", errors="ignore").decode("ascii").strip()
if cleaned:
return cleaned
raise UpstreamError(
f"{field_name} 含有非 HTTP 标头字符(如中文或 emoji),"
f"请重新登录后再试"
) from None
return text
def _find_user_id(value: Any) -> str:
if isinstance(value, dict):
for key in ("id", "user_id", "userId"):
candidate = value.get(key)
if candidate is not None:
return str(candidate)
for key in ("data", "result", "user", "session"):
user_id = _find_user_id(value.get(key))
if user_id:
return user_id
return ""
def mask_secret(value: Any) -> str:
    """Return a masked rendering of *value* for display.

    Empty or falsy input gives "". Up to 4 characters are fully masked,
    5-8 characters keep two characters on each side, and longer values keep
    four characters on each side.
    """
    secret = str(value or "")
    size = len(secret)
    if size == 0:
        return ""
    if size <= 4:
        return "****"
    if size <= 8:
        return f"{secret[:2]}****{secret[-2:]}"
    return f"{secret[:4]}**********{secret[-4:]}"
def _unwrap_data(value: Any) -> Any:
if isinstance(value, dict) and "data" in value and ("code" in value or "message" in value):
return value.get("data")
if isinstance(value, dict) and "data" in value and "success" in value:
return value.get("data")
return value
def _extract_id(value: Any) -> str:
if isinstance(value, dict):
for key in ("id", "key_id", "keyId"):
candidate = value.get(key)
if candidate is not None:
return str(candidate)
for key in ("data", "result", "key", "api_key"):
found = _extract_id(value.get(key))
if found:
return found
return ""
def _extract_key_value(value: Any) -> str:
if isinstance(value, str):
return value
if isinstance(value, dict):
for key in ("key", "api_key", "apiKey", "token", "value"):
candidate = value.get(key)
if isinstance(candidate, str) and candidate:
return candidate
for key in ("data", "result", "api_key", "key"):
found = _extract_key_value(value.get(key))
if found:
return found
return ""
def _is_success_response(value: Any) -> bool:
if not isinstance(value, dict) or "success" not in value:
return True
return value.get("success") is True
def _response_message(value: Any, fallback: str = "") -> str:
if isinstance(value, dict):
msg = value.get("message") or value.get("detail")
if msg:
return str(msg)
return fallback
def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]:
def _normalize(lst: list) -> list[dict[str, Any]]:
out = []
for i in lst:
if isinstance(i, dict):
out.append(i)
elif isinstance(i, str):
out.append({"id": i, "name": i})
return out
if isinstance(value, list):
return _normalize(value)
if isinstance(value, dict):
for key in ("data", "items", "groups", "available_groups", "availableGroups"):
nested = value.get(key)
if isinstance(nested, list):
return _normalize(nested)
elif isinstance(nested, dict):
# Handle /api/user/self/groups where data is a dict of group_name -> { desc, ratio }
out = []
for k in nested.keys():
out.append({"id": k, "name": k})
return out
return None
def _group_id(group: dict[str, Any]) -> str:
for key in ("id", "group_id", "groupId"):
v = group.get(key)
if v is not None:
return str(v)
name = str(group.get("name") or group.get("group_name") or "")
platform = str(group.get("platform") or "")
return f"{platform}:{name}"
def _rate_from_group(group: dict[str, Any]) -> str:
    """Return the first non-empty rate multiplier of a group record, or "".

    User-specific multipliers take precedence over effective ones, which take
    precedence over the plain multiplier; snake_case and camelCase spellings
    are both accepted. Values are formatted with ``decimal_string``.
    """
    rate_fields = (
        "user_rate_multiplier", "userRateMultiplier",
        "effective_rate_multiplier", "effectiveRateMultiplier",
        "rate_multiplier", "rateMultiplier",
    )
    # Lazy generator: formatting stops at the first field that yields a rate.
    formatted = (decimal_string(group.get(field)) for field in rate_fields)
    return next((rate for rate in formatted if rate), "")
def _parse_json_object(text: str) -> Optional[dict[str, Any]]:
    """Parse *text* as JSON; return the result only if it is a JSON object."""
    try:
        parsed = json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None
    return parsed if isinstance(parsed, dict) else None


def _nested_rate(entry: dict[str, Any]) -> str:
    """Return the first non-empty formatted rate found in a per-group dict."""
    # Each field is formatted before it is tested, as _rate_from_group does, so
    # a numeric 0 is not skipped merely because it is falsy. Whether a 0 is
    # then kept depends on decimal_string's output for it (TODO confirm).
    for field in (
        "rate_multiplier", "rateMultiplier",
        "user_rate_multiplier", "userRateMultiplier",
        "ratio",
    ):
        rate = decimal_string(entry.get(field))
        if rate:
            return rate
    return ""


def _rates_from_mapping(mapping: dict[Any, Any]) -> dict[str, str]:
    """Convert a ``group -> rate`` (or ``group -> details``) mapping to strings."""
    rates: dict[str, str] = {}
    for group, raw_rate in mapping.items():
        rate = _nested_rate(raw_rate) if isinstance(raw_rate, dict) else decimal_string(raw_rate)
        if rate:
            rates[str(group)] = rate
    return rates


def _extract_rates_map(raw: Any) -> dict[str, str]:
    """Build ``{group id: rate string}`` from a group-rates response.

    Supported shapes, tried in this order:
      1. one-api/new-api ``/api/option``: ``{"data": [{"key": "GroupRatio",
         "value": <JSON string or dict>}, ...]}``;
      2. a dict, optionally wrapping the mapping under ``data``, ``rates``,
         ``group_rates``, ``groupRates`` or ``GroupRatio`` (the latter may be
         a JSON string); values are plain rates or per-group dicts;
      3. a list of group records, resolved with ``_group_id`` and
         ``_rate_from_group``.

    Entries without a non-empty formatted rate are skipped. Unparseable
    ``GroupRatio`` JSON is ignored rather than raised. Anything else,
    including None, yields ``{}``.
    """
    if raw is None:
        return {}

    if isinstance(raw, dict):
        options = raw.get("data")
        if isinstance(options, list):
            for option in options:
                if not (isinstance(option, dict) and option.get("key") == "GroupRatio"):
                    continue
                option_value = option.get("value")
                if isinstance(option_value, str):
                    option_value = _parse_json_object(option_value)
                if isinstance(option_value, dict):
                    return _rates_from_mapping(option_value)

        # Fall back to treating the dict (or its first nested mapping) as the map.
        candidates = raw
        for key in ("data", "rates", "group_rates", "groupRates", "GroupRatio"):
            nested = raw.get(key)
            if isinstance(nested, str) and key == "GroupRatio":
                nested = _parse_json_object(nested)
            if isinstance(nested, dict):
                candidates = nested
                break
        return _rates_from_mapping(candidates)

    if isinstance(raw, list):
        rates: dict[str, str] = {}
        for item in raw:
            if not isinstance(item, dict):
                continue
            gid = _group_id(item)
            rate = _rate_from_group(item)
            if gid and rate:
                rates[gid] = rate
        return rates

    return {}
def build_snapshot(upstream_id: int, base_url: str, api_prefix: str,
                   groups: list[dict[str, Any]], raw_rates: Any) -> dict[str, Any]:
    """Assemble a snapshot of the rates of every group of one upstream.

    Each group gets its default rate (from the group record) and, when
    *raw_rates* has an entry for the group id, an override rate that becomes
    the effective ``rate``. ``captured_at`` is the current local time with
    UTC offset, at second precision.
    """
    from datetime import datetime, timezone

    overrides = _extract_rates_map(raw_rates)

    def _entry(group: dict[str, Any]) -> tuple[str, dict[str, Any]]:
        gid = _group_id(group)
        base_rate = _rate_from_group(group)
        return gid, {
            "group_id": gid,
            "group_name": group.get("name") or group.get("group_name") or "",
            "platform": group.get("platform") or "",
            "rate": overrides.get(gid, base_rate),
            "default_rate": base_rate,
            "override_rate": overrides.get(gid, ""),
        }

    captured_at = datetime.now(timezone.utc).astimezone().isoformat(timespec="seconds")
    return {
        "upstream_id": upstream_id,
        "base_url": base_url.rstrip("/"),
        "api_prefix": api_prefix,
        "captured_at": captured_at,
        # Later groups with a duplicate id replace earlier ones.
        "groups": dict(_entry(group) for group in groups),
    }
class UpstreamClient:
"""Sync HTTP client that handles all auth types."""
def __init__(
    self,
    base_url: str,
    api_prefix: str,
    auth_type: str,
    auth_config: dict[str, Any],
    timeout: float = 30.0,
) -> None:
    """Store the connection settings and create the underlying HTTP client.

    Args:
        base_url: Upstream root URL; a trailing slash is removed.
        api_prefix: Path prefix put before every request path; surrounding
            slashes are removed, and "" means no prefix.
        auth_type: One of the auth modes handled by ``_headers``
            ("bearer", "api_key", "cookie", "login_password").
        auth_config: Mode-specific credentials and options.
        timeout: Timeout passed to ``httpx.Client``.
    """
    # Where to connect.
    self.base_url = base_url.rstrip("/")
    self.api_prefix = api_prefix.strip("/")
    self.timeout = timeout

    # How to authenticate.
    self.auth_type = auth_type
    self.auth_config = auth_config

    # Session state filled in by login() and by responses.
    self._token: str = ""
    self._cookies: dict[str, str] = {}
    self._new_api_user: str = ""

    self._client = httpx.Client(timeout=timeout)
def close(self) -> None:
    """Close the underlying HTTP client."""
    self._client.close()
def __enter__(self) -> UpstreamClient:
return self
def __exit__(self, *args: Any) -> None:
self.close()
def _url(self, path: str) -> str:
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
return f"{self.base_url}{prefix}/{path.lstrip('/')}"
def _is_new_api_user_mode(self) -> bool:
login_path = str(self.auth_config.get("login_path") or "")
return (
self.api_prefix == ""
and (
bool(self.auth_config.get("new_api_user"))
or login_path == "/api/user/login"
or self.auth_type == "cookie"
)
)
def _headers(self, auth: bool = True) -> dict[str, str]:
headers: dict[str, str] = {
"Accept": "application/json",
"User-Agent": "SmartUp/1.0",
}
if not auth:
return headers
if self.auth_type == "bearer":
token = _clean_auth_header_value(self.auth_config.get("token", ""), "Bearer token")
if token:
headers["Authorization"] = f"Bearer {token}"
elif self.auth_type == "api_key":
key = _clean_auth_header_value(self.auth_config.get("key", ""), "API key")
header = self.auth_config.get("header", "Authorization")
if key:
headers[header] = key
elif self.auth_type == "cookie":
cookie_str = _clean_auth_header_value(self.auth_config.get("cookie_string", ""), "Cookie")
if cookie_str:
headers["Cookie"] = cookie_str
new_api_user = _clean_auth_header_value(self.auth_config.get("new_api_user", ""), "New-Api-User")
if new_api_user:
headers["New-Api-User"] = new_api_user
elif self.auth_type == "login_password" and self._token:
token = _clean_auth_header_value(self._token, "Login token")
if token:
headers["Authorization"] = f"Bearer {token}"
if self.auth_type == "login_password" and self._new_api_user:
headers["New-Api-User"] = self._new_api_user
return headers
def _request(self, method: str, path: str, body: Any = None, auth: bool = True) -> Any:
if auth and self.auth_type == "cookie" and "user/self" in path and not self.auth_config.get("new_api_user"):
raise UpstreamError("New-API user endpoint requires New-Api-User; re-extract the session cookie after login and save the upstream")
url = self._url(path)
if body is not None:
resp = self._client.request(
method,
url,
json=body,
headers=self._headers(auth),
cookies=self._cookies,
)
else:
resp = self._client.request(
method,
url,
headers=self._headers(auth),
cookies=self._cookies,
)
self._cookies.update(dict(resp.cookies))
resp.raise_for_status()
ct = resp.headers.get("content-type", "")
if not resp.content:
return None
text = resp.text
if "application/json" not in ct and text.lstrip().startswith("<"):
raise UpstreamError(f"{method} {path} returned HTML, not JSON")
return resp.json()
def _ensure_api_success(self, payload: Any, action: str) -> None:
    """Raise UpstreamError when *payload* reports ``success`` other than True.

    The error text is the payload's message/detail, or ``"<action> failed"``.
    """
    if _is_success_response(payload):
        return
    raise UpstreamError(_response_message(payload, f"{action} failed"))
def _new_api_quota_per_unit(self) -> int:
    """Return the upstream's ``quota_per_unit`` from ``/api/status``.

    Best effort: any failure (request error, missing or non-numeric value,
    non-positive value) falls back to ``NEW_API_DEFAULT_QUOTA_PER_UNIT``.
    """
    try:
        status = _unwrap_data(self._request("GET", "/api/status", auth=False))
        if isinstance(status, dict):
            # Go through float first so values like "500000.0" are accepted.
            units = int(float(status.get("quota_per_unit")))
            if units > 0:
                return units
    except Exception:
        pass
    return NEW_API_DEFAULT_QUOTA_PER_UNIT
@staticmethod
def _normalize_key_record(record: dict[str, Any]) -> dict[str, Any]:
out = dict(record)
if not out.get("name") and out.get("key_name"):
out["name"] = out.get("key_name")
if not out.get("group_id") and out.get("group"):
out["group_id"] = str(out.get("group"))
if not out.get("group_name") and out.get("group"):
out["group_name"] = str(out.get("group"))
return out
@staticmethod
def _extract_new_api_token_items(payload: Any) -> tuple[list[dict[str, Any]], dict[str, Any]]:
    """Split a New-API token list response into ``(items, meta)``.

    After unwrapping the envelope, a bare list yields ``(dict items, {})``.
    A dict yields the first list found under ``items``, ``tokens``, ``list``
    or ``records`` together with the dict itself as meta. Non-dict entries
    are dropped.

    Raises:
        UpstreamError: when no list can be located.
    """
    nested = _unwrap_data(payload)
    if isinstance(nested, list):
        return [row for row in nested if isinstance(row, dict)], {}
    if isinstance(nested, dict):
        for field in ("items", "tokens", "list", "records"):
            rows = nested.get(field)
            if isinstance(rows, list):
                return [row for row in rows if isinstance(row, dict)], nested
    raise UpstreamError("unexpected New-API token list response")
def _request_new_api_token_list(self, path: str, params: dict[str, Any]) -> tuple[list[dict[str, Any]], dict[str, Any]]:
resp = self._client.request(
"GET",
self._url(path),
params=params,
headers=self._headers(),
cookies=self._cookies,
)
self._cookies.update(dict(resp.cookies))
resp.raise_for_status()
data = resp.json()
self._ensure_api_success(data, "list New-API tokens")
return self._extract_new_api_token_items(data)
def _list_all_new_api_tokens(self, page_size: int = 100, max_pages: int = 20) -> list[dict[str, Any]]:
all_items: list[dict[str, Any]] = []
for page in range(1, max_pages + 1):
items, meta = self._request_new_api_token_list(
"/api/token/",
{"p": page, "size": page_size},
)
all_items.extend(items)
total = meta.get("total") if isinstance(meta, dict) else None
if isinstance(total, int) and len(all_items) >= total:
break
if len(items) < page_size:
break
return [self._normalize_key_record(i) for i in all_items]
@staticmethod
def _matches_new_api_token_search(record: dict[str, Any], search: str) -> bool:
if not search:
return True
needle = search.strip()
if not needle:
return True
name = str(record.get("name") or record.get("key_name") or "")
key = str(record.get("key") or record.get("api_key") or record.get("apiKey") or record.get("token") or "")
return name.startswith(needle) or name == needle or key == needle
def _hydrate_new_api_token_key(self, record: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *record* with a plaintext ``key`` and a ``masked_key``.

    If the record already carries an unmasked key (no ``*`` in it), that key
    is used. Otherwise the key is fetched by token id; when the record has no
    id or the fetch fails, the copy is returned without changes.
    """
    hydrated = dict(record)
    visible_key = _extract_key_value(hydrated)
    if visible_key and "*" not in visible_key:
        hydrated["key"] = visible_key
        hydrated["masked_key"] = hydrated.get("masked_key") or mask_secret(visible_key)
        return hydrated

    token_id = hydrated.get("id")
    if token_id is not None:
        try:
            plaintext = self._get_new_api_token_key(token_id)
        except Exception:
            # Best effort: keep the record as listed when the key is unavailable.
            pass
        else:
            hydrated["key"] = plaintext
            hydrated["masked_key"] = mask_secret(plaintext)
    return hydrated
def _list_new_api_tokens(
self,
search: str = "",
group_id: str | int | None = None,
) -> list[dict[str, Any]]:
normalized: list[dict[str, Any]] = []
if search:
search_items, _ = self._request_new_api_token_list(
"/api/token/search",
{"keyword": search, "token": "", "p": 1, "size": 100},
)
normalized.extend(self._normalize_key_record(i) for i in search_items)
all_tokens = self._list_all_new_api_tokens()
seen_ids = {str(i.get("id")) for i in normalized if i.get("id") is not None}
for item in all_tokens:
item_id = item.get("id")
if item_id is not None and str(item_id) in seen_ids:
continue
if self._matches_new_api_token_search(item, search):
normalized.append(item)
if item_id is not None:
seen_ids.add(str(item_id))
if group_id is not None:
gid = str(group_id)
normalized = [i for i in normalized if str(i.get("group_id") or i.get("group") or "") == gid]
return [self._hydrate_new_api_token_key(i) for i in normalized]
def _get_new_api_token_key(self, token_id: str | int) -> str:
    """Fetch the plaintext key of a New-API token via ``POST /api/token/<id>/key``.

    Raises:
        UpstreamError: when the payload reports failure or contains no key.
    """
    payload = self._request("POST", f"/api/token/{token_id}/key")
    self._ensure_api_success(payload, "get New-API token key")
    plaintext = _extract_key_value(_unwrap_data(payload))
    if plaintext:
        return plaintext
    raise UpstreamError("New-API token key response did not include key")
def _create_new_api_token(
    self,
    name: str,
    group_id: str | int,
    quota: float = 0,
    expires_in_days: int | None = None,
) -> dict[str, Any]:
    """Create a New-API token, then look it up by name to obtain its id and key.

    Args:
        name: Token name; also used to find the token after creation.
        group_id: Group the token is bound to (sent as a string).
        quota: Quota in units; values <= 0 create an unlimited token.
        expires_in_days: Lifetime in days; a falsy value means no expiry (-1).

    Returns:
        Dict with ``id``, ``key``, ``masked_key`` and the normalized ``raw``
        token record.

    Raises:
        UpstreamError: when creation fails, the token cannot be found by
            name afterwards, or it has no id.
    """
    unlimited = quota <= 0
    if unlimited:
        remain_quota = 0
    else:
        # Convert units to the upstream's internal quota scale.
        remain_quota = int(round(quota * self._new_api_quota_per_unit()))
    if expires_in_days:
        expired_time = int(time.time()) + expires_in_days * 86400
    else:
        expired_time = -1

    body: dict[str, Any] = {
        "name": name,
        "remain_quota": remain_quota,
        "unlimited_quota": unlimited,
        "expired_time": expired_time,
        "model_limits_enabled": False,
        "model_limits": "",
        "allow_ips": "",
        "group": str(group_id),
        "cross_group_retry": False,
    }
    payload = self._request("POST", "/api/token/", body)
    self._ensure_api_success(payload, "create New-API token")

    # The token is located again by its exact (stripped) name.
    wanted_name = name.strip()
    token = None
    for candidate in self._list_new_api_tokens(search=name, group_id=group_id):
        if str(candidate.get("name") or "").strip() == wanted_name:
            token = candidate
            break
    if not token:
        raise UpstreamError("New-API token was created but could not be found by name")

    token_id = token.get("id")
    if token_id is None:
        raise UpstreamError("New-API token list response did not include id")

    key_value = self._get_new_api_token_key(token_id)
    return {
        "id": str(token_id),
        "key": key_value,
        "masked_key": mask_secret(key_value),
        "raw": self._normalize_key_record(token),
    }
def login(self) -> None:
    """Log in with email and password when the auth type is ``login_password``.

    Other auth types return immediately. On success either the token found in
    the response is stored, or, if there is no token but session cookies were
    set, the New-API user id is stored (configured value first, else the id
    found in the response).

    Raises:
        UpstreamError: when email or password is missing, or the response
            yields neither a token nor a session cookie.
    """
    if self.auth_type != "login_password":
        return

    config = self.auth_config
    email = config.get("email", "")
    password = config.get("password", "")
    login_path = config.get("login_path", "/auth/login")
    username_field = config.get("username_field", "email")
    if not (email and password):
        raise UpstreamError("login_password auth requires email and password in auth_config")

    credentials = {username_field: email, "password": password}
    resp = self._request("POST", login_path, credentials, auth=False)

    token = _find_token(resp)
    if token:
        self._token = token
    elif self._cookies:
        self._new_api_user = config.get("new_api_user", "") or _find_user_id(resp)
    else:
        raise UpstreamError("login succeeded but no token or session cookie found in response")
def get_available_groups(self, endpoint: str) -> list[dict[str, Any]]:
    """GET *endpoint* and return its groups as a list of dicts.

    Raises:
        UpstreamError: when no list can be extracted from the response.
    """
    groups = _unwrap_list(self._request("GET", endpoint))
    if groups is not None:
        return groups
    raise UpstreamError(f"{endpoint} did not return a list")
def get_group_rates(self, endpoint: str) -> Any:
    """GET *endpoint* and return the decoded response without interpreting it."""
    payload = self._request("GET", endpoint)
    return payload
def get_balance(self, endpoint: str, response_path: str) -> Optional[float]:
    """Call the balance endpoint and extract a numeric value using a dot-separated JSON path.

    response_path examples:
        "balance"             -> resp["balance"]
        "data.quota"          -> resp["data"]["quota"]
        "data.total_balance"  -> resp["data"]["total_balance"]

    Returns:
        The value as a float, or None when endpoint/path is empty, the
        response is not a dict, the path cannot be followed, or the value is
        missing or not convertible to float.
    """
    if not (endpoint and response_path):
        return None

    payload = self._request("GET", endpoint)
    if not isinstance(payload, dict):
        return None

    current: Any = payload
    for segment in response_path.split("."):
        if not isinstance(current, dict):
            return None
        current = current.get(segment)

    if current is None:
        return None
    try:
        return float(current)
    except (ValueError, TypeError):
        return None
def list_api_keys(
    self,
    search: str = "",
    group_id: str | int | None = None,
    status: str = "active",
    endpoint: str = "/keys",
) -> list[dict[str, Any]]:
    """List keys on the remote upstream, with optional name search, group and status filters.

    New-API upstreams (token endpoint given explicitly, or the default
    endpoint while in New-API user mode) are served by
    ``_list_new_api_tokens``. Otherwise *endpoint* is queried with the
    filters as query parameters and the key list is located in the response.

    Whatever wrapper the list was found in, every record is normalized with
    ``_normalize_key_record`` and non-dict entries are dropped, so callers
    always get the same record shape.

    Raises:
        UpstreamError: when no key list can be located in the response.
        Whatever ``raise_for_status()`` / ``json()`` of the client raise.
    """
    if endpoint in {"/api/token", "/api/token/"} or (endpoint == "/keys" and self._is_new_api_user_mode()):
        return self._list_new_api_tokens(search=search, group_id=group_id)

    params: dict[str, Any] = {}
    if search:
        params["search"] = search
    if group_id is not None:
        # Numeric ids are sent as integers, anything else verbatim.
        params["group_id"] = int(group_id) if str(group_id).isdigit() else group_id
    if status:
        params["status"] = status

    resp = self._client.request(
        "GET",
        self._url(endpoint),
        params=params or None,
        headers=self._headers(),
        cookies=self._cookies,
    )
    resp.raise_for_status()
    data = resp.json()

    records: Optional[list] = None
    if isinstance(data, list):
        records = data
    elif isinstance(data, dict):
        # Try the common envelope structures first.
        for top_key in ("data", "result", "response"):
            wrapped = data.get(top_key)
            if isinstance(wrapped, list):
                records = wrapped
            elif isinstance(wrapped, dict):
                for inner_key in ("items", "keys", "list", "records", "data"):
                    inner = wrapped.get(inner_key)
                    if isinstance(inner, list):
                        records = inner
                        break
            if records is not None:
                break
        if records is None:
            # The top level itself may be the list-like wrapper.
            for key in ("items", "keys", "list", "records"):
                candidate = data.get(key)
                if isinstance(candidate, list):
                    records = candidate
                    break

    if records is None:
        raise UpstreamError(f"unexpected keys response type: {type(data).__name__}")
    return [self._normalize_key_record(row) for row in records if isinstance(row, dict)]
def delete_api_key(self, key_id: str, endpoint: str = "/keys") -> None:
    """Delete one key on the remote upstream.

    The New-API mapping is the one ``list_api_keys`` and ``create_api_key``
    use: an explicit token endpoint, or the default endpoint while in New-API
    user mode, targets ``/api/token/<id>``, and a payload reporting failure
    is raised. For other upstreams a trailing slash on *endpoint* is ignored
    so the id is not appended after a double slash.

    Raises:
        UpstreamError: when a New-API upstream reports ``success`` other
            than True, plus anything ``_request`` raises.
    """
    is_new_api = endpoint in {"/api/token", "/api/token/"} or (
        endpoint == "/keys" and self._is_new_api_user_mode()
    )
    if is_new_api:
        payload = self._request("DELETE", f"/api/token/{key_id}")
        self._ensure_api_success(payload, "delete New-API token")
        return
    self._request("DELETE", f"{endpoint.rstrip('/')}/{key_id}")
def find_smartup_group_key(
    self,
    group_id: str | int,
    expected_name: str,
    prefix: str = "SmartUp",
    endpoint: str = "/keys",
) -> dict[str, Any] | None:
    """Find an existing key named *expected_name* in the given upstream group.

    Active keys of the group are listed with *prefix* as the search term.
    The first one whose name (``name`` or ``key_name``) equals
    *expected_name*, ignoring surrounding whitespace, is returned. The prefix
    only narrows the listing; it is not checked again here.

    Args:
        group_id: Upstream group id; numeric strings are passed as integers.
        expected_name: Exact key name to look for.
        prefix: Search term sent to the listing.
        endpoint: Key listing endpoint, forwarded to ``list_api_keys``.

    Returns:
        The first matching key record, or None.
    """
    gid = int(group_id) if str(group_id).isdigit() else group_id
    keys = self.list_api_keys(search=prefix, group_id=gid, status="active", endpoint=endpoint)

    # Some backends return names with stray whitespace, so compare stripped.
    wanted = expected_name.strip()
    for record in keys:
        if not isinstance(record, dict):
            # Listings may contain entries that are not key records.
            continue
        name = str(record.get("name") or record.get("key_name") or "")
        if name.strip() == wanted:
            return record
    return None
def create_api_key(
    self,
    name: str,
    group_id: str | int,
    quota: float = 0,
    expires_in_days: int | None = None,
    rate_limit_5h: float = 0,
    rate_limit_1d: float = 0,
    rate_limit_7d: float = 0,
    endpoint: str = "/keys",
) -> dict[str, Any]:
    """Create a key on the upstream and return its id, key and masked key.

    New-API upstreams (token endpoint given explicitly, or the default
    endpoint while in New-API user mode) are served by
    ``_create_new_api_token``, which uses only name, group, quota and expiry.
    Otherwise the key is created with a POST to *endpoint*.

    Returns:
        Dict with ``id``, ``key``, ``masked_key`` and ``raw`` (the unwrapped
        response, wrapped as ``{"value": ...}`` when it is not a dict).

    Raises:
        UpstreamError: when the creation response contains no key.
    """
    uses_new_api = endpoint in {"/api/token", "/api/token/"} or (
        endpoint == "/keys" and self._is_new_api_user_mode()
    )
    if uses_new_api:
        return self._create_new_api_token(
            name,
            group_id,
            quota=quota,
            expires_in_days=expires_in_days,
        )

    # Numeric group ids are sent as integers, anything else verbatim.
    numeric_group = str(group_id).isdigit()
    body: dict[str, Any] = {
        "name": name,
        "group_id": int(group_id) if numeric_group else group_id,
        "quota": quota,
        "rate_limit_5h": rate_limit_5h,
        "rate_limit_1d": rate_limit_1d,
        "rate_limit_7d": rate_limit_7d,
    }
    if expires_in_days:
        body["expires_in_days"] = expires_in_days

    data = _unwrap_data(self._request("POST", endpoint, body))
    key_value = _extract_key_value(data)
    if not key_value:
        raise UpstreamError("key create response did not include key")

    raw = data if isinstance(data, dict) else {"value": data}
    return {
        "id": _extract_id(data),
        "key": key_value,
        "masked_key": mask_secret(key_value),
        "raw": raw,
    }