Files
SmartUp/backend/app/services/upstream_client.py
T
liumangmang 6044b00685 feat: 上游 Key 唯一化、分组导入跳过、账号导入平台识别&远端校验&base_url 注入
- 上游 Key 命名改为 {prefix}-{upstream.id}-{safe_group_name}-{group_id}
- 唯一约束 (upstream_id, group_id, managed_prefix) 加 managed_prefix 列
- 上游检测成功时同步 Key 状态,远端已删/分组已删自动清理
- 重复分组导入跳过,目标网站已存在同名分组返回 exists
- 账号导入平台自动识别(auto/manual 模式)
- 全选可导入 Key 按钮 + 目标分组自动匹配
- 导入幂等:已导入过的 Key 校验远端账号,不存在则重建
- 新增同步接口 POST /sync-imported-upstream-keys
- account_exists() 通过拉取账号列表判断,避免 404 误判
- credentials.base_url 注入来源上游地址,避免 401
- 前端导入弹窗自动同步+刷新按钮+并发/优先级设置
- 新增 12 个测试覆盖同步、幂等、远端删除、校验失败路径
2026-05-21 01:16:39 +08:00

510 lines
18 KiB
Python

"""Upstream HTTP client — ported from monitor_ai98pro_group_rates.py."""
from __future__ import annotations
import json
from typing import Any, Optional
from urllib.parse import urljoin
import httpx
from app.utils.number import decimal_string
class UpstreamError(RuntimeError):
    """Error raised by this module for upstream responses or settings it cannot use."""
def _find_token(value: Any) -> str:
    """Return the first plausible auth token found in a login response, or "".

    A bare string counts only when it contains at least two dots (JWT-like).
    For dicts, the well-known token fields are checked first (non-empty
    strings only), then the wrapper fields ``data``/``result``/``user``/
    ``session`` are searched recursively in that order.
    """
    if isinstance(value, str):
        return value if value.count(".") >= 2 else ""
    if not isinstance(value, dict):
        return ""
    for field in ("token", "access_token", "accessToken", "jwt", "auth_token", "authToken"):
        found = value.get(field)
        if isinstance(found, str) and found:
            return found
    for wrapper in ("data", "result", "user", "session"):
        nested = _find_token(value.get(wrapper))
        if nested:
            return nested
    return ""
def _clean_auth_header_value(value: Any, field_name: str) -> str:
    """Normalize a credential so it can be sent as an HTTP header value.

    Surrounding whitespace and a leading ``"Bearer "`` prefix are removed.
    httpx encodes ``str`` header values as ASCII, so checking for Latin-1
    alone would let characters such as ``"\\xa0"`` (non-breaking space) or
    ``"é"`` through, and the request would later fail with a bare
    ``UnicodeEncodeError``. Any non-ASCII characters are therefore dropped.

    Args:
        value: Raw credential; ``None``/empty yields ``""``.
        field_name: Human-readable name used in the error message.

    Returns:
        The cleaned ASCII value, or ``""`` when *value* is empty.

    Raises:
        UpstreamError: nothing usable remains after dropping non-ASCII
            characters (e.g. the value is entirely Chinese or emoji).
    """
    text = str(value or "").strip()
    if not text:
        return ""
    if text.startswith("Bearer "):
        text = text[7:].strip()
    if text.isascii():
        return text
    # Best-effort sanitizing instead of hard-failing: keep the ASCII part.
    cleaned = text.encode("ascii", errors="ignore").decode("ascii").strip()
    if cleaned:
        return cleaned
    raise UpstreamError(
        f"{field_name} 含有非 HTTP 标头字符(如中文或 emoji),"
        f"请重新登录后再试"
    ) from None
def _find_user_id(value: Any) -> str:
    """Return the first user id found in a (possibly nested) response dict, or "".

    ``id``/``user_id``/``userId`` are checked first; any present, non-None
    value is returned via ``str()``. Otherwise the wrappers ``data``/
    ``result``/``user``/``session`` are searched recursively, in that order.
    """
    if not isinstance(value, dict):
        return ""
    direct_values = (value.get(field) for field in ("id", "user_id", "userId"))
    direct = next((item for item in direct_values if item is not None), None)
    if direct is not None:
        return str(direct)
    nested_ids = (_find_user_id(value.get(w)) for w in ("data", "result", "user", "session"))
    return next((uid for uid in nested_ids if uid), "")
def mask_secret(value: Any) -> str:
    """Return a display-safe form of a secret.

    ``None``/empty gives ``""``. Values of up to 4 characters become
    ``"****"``. Values of 5-8 characters keep their first and last 2
    characters. Longer values keep their first and last 4 characters around
    ten asterisks. Note that short secrets keep a large share of their
    characters visible (e.g. 8 of 9).
    """
    text = str(value or "")
    length = len(text)
    if length > 8:
        return f"{text[:4]}**********{text[-4:]}"
    if length > 4:
        return f"{text[:2]}****{text[-2:]}"
    return "****" if length else ""
def _unwrap_data(value: Any) -> Any:
    """Strip a ``{"code"/"message", "data"}`` response envelope, if present.

    Returns ``value["data"]`` when *value* is a dict containing ``data`` plus
    ``code`` or ``message``; otherwise returns *value* unchanged.
    """
    is_envelope = (
        isinstance(value, dict)
        and "data" in value
        and ("code" in value or "message" in value)
    )
    return value.get("data") if is_envelope else value
def _extract_id(value: Any) -> str:
    """Return the first id found in a (possibly nested) key response, or "".

    ``id``/``key_id``/``keyId`` are checked first; any present, non-None value
    is returned via ``str()``. Otherwise the wrappers ``data``/``result``/
    ``key``/``api_key`` are searched recursively, in that order.
    """
    if not isinstance(value, dict):
        return ""
    present = [value[field] for field in ("id", "key_id", "keyId") if value.get(field) is not None]
    if present:
        return str(present[0])
    for wrapper in ("data", "result", "key", "api_key"):
        nested = _extract_id(value.get(wrapper))
        if nested:
            return nested
    return ""
def _extract_key_value(value: Any) -> str:
    """Return the secret key string from a (possibly nested) create-key response.

    A bare string is returned as-is (even if empty). For dicts, the first
    non-empty string among ``key``/``api_key``/``apiKey``/``token``/``value``
    wins. Otherwise the wrappers ``data``/``result``/``api_key``/``key`` are
    searched recursively, in that order. Anything else gives ``""``.
    """
    if isinstance(value, str):
        return value
    if not isinstance(value, dict):
        return ""
    direct_fields = ("key", "api_key", "apiKey", "token", "value")
    direct = next(
        (item for item in map(value.get, direct_fields) if isinstance(item, str) and item),
        "",
    )
    if direct:
        return direct
    wrappers = ("data", "result", "api_key", "key")
    nested = (_extract_key_value(value.get(w)) for w in wrappers)
    return next((found for found in nested if found), "")
def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]:
    """Extract a list of group-like dicts from an upstream response.

    Accepted shapes:

    * a bare list: dict items are kept, string items become
      ``{"id": s, "name": s}``, and other items are dropped;
    * a dict holding such a list under ``data``/``items``/``groups``/
      ``available_groups``/``availableGroups`` (first match wins);
    * a dict holding a paginated dict that itself has a list under one of
      those keys, e.g. ``{"data": {"items": [...], "total": 3}}``;
    * a dict holding a ``group_name -> {desc, ratio}`` mapping (the
      ``/api/user/self/groups`` shape), which becomes one
      ``{"id": name, "name": name}`` entry per key.

    Returns:
        The normalized list, or ``None`` when no list can be found.
    """
    list_keys = ("data", "items", "groups", "available_groups", "availableGroups")

    def _normalize(items: list) -> list[dict[str, Any]]:
        out: list[dict[str, Any]] = []
        for item in items:
            if isinstance(item, dict):
                out.append(item)
            elif isinstance(item, str):
                out.append({"id": item, "name": item})
        return out

    if isinstance(value, list):
        return _normalize(value)
    if not isinstance(value, dict):
        return None
    for key in list_keys:
        nested = value.get(key)
        if isinstance(nested, list):
            return _normalize(nested)
        if isinstance(nested, dict):
            # A paginated envelope ({"items": [...], "total": n}) must not be
            # read as a name -> details mapping, or its own keys ("items",
            # "total") would be reported as groups. In the mapping shape the
            # details are dicts, so a list under a wrapper key marks an envelope.
            for inner_key in list_keys:
                inner = nested.get(inner_key)
                if isinstance(inner, list):
                    return _normalize(inner)
            return [{"id": name, "name": name} for name in nested]
    return None
def _group_id(group: dict[str, Any]) -> str:
    """Return a stable identifier for a group dict.

    The first non-None value among ``id``/``group_id``/``groupId`` is used
    (via ``str()``). Otherwise the id is synthesized as
    ``"<platform>:<name>"``, where name falls back to ``group_name`` and
    either part may be empty.
    """
    for field in ("id", "group_id", "groupId"):
        if group.get(field) is None:
            continue
        return str(group[field])
    label = str(group.get("name") or group.get("group_name") or "")
    return "{}:{}".format(str(group.get("platform") or ""), label)
def _rate_from_group(group: dict[str, Any]) -> str:
    """Return the group's own rate multiplier as normalized by ``decimal_string``.

    Fields are tried in priority order: user rate, then effective rate, then
    base rate (snake_case and camelCase spellings). The first non-empty
    result is returned; ``""`` means no usable rate.
    """
    candidate_fields = (
        "user_rate_multiplier", "userRateMultiplier",
        "effective_rate_multiplier", "effectiveRateMultiplier",
        "rate_multiplier", "rateMultiplier",
    )
    # Lazy generator: stop calling decimal_string at the first hit.
    normalized = (decimal_string(group.get(field)) for field in candidate_fields)
    return next((rate for rate in normalized if rate), "")
def _rates_from_mapping(mapping: dict[Any, Any]) -> dict[str, str]:
    """Convert a ``{group: rate}`` mapping into ``{str(group): decimal_string(rate)}``.

    A value may itself be a dict; the first truthy field among
    ``rate_multiplier``/``rateMultiplier``/``user_rate_multiplier``/
    ``userRateMultiplier``/``ratio`` is then used. Entries whose normalized
    rate is empty are dropped.
    """
    result: dict[str, str] = {}
    for key, value in mapping.items():
        if isinstance(value, dict):
            rate = decimal_string(
                value.get("rate_multiplier") or value.get("rateMultiplier")
                or value.get("user_rate_multiplier") or value.get("userRateMultiplier")
                or value.get("ratio")
            )
        else:
            rate = decimal_string(value)
        if rate:
            result[str(key)] = rate
    return result


def _parse_json_object(text: str) -> Optional[dict[Any, Any]]:
    """Parse *text* as JSON; return the result only if it is an object, else None."""
    try:
        parsed = json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None
    return parsed if isinstance(parsed, dict) else None


def _group_ratio_from_options(options: list[Any]) -> Optional[dict[str, str]]:
    """Read rates from the ``GroupRatio`` entry of a one-api/new-api ``/api/option`` list.

    The first ``{"key": "GroupRatio", "value": ...}`` item whose value is a
    dict, or a string holding a JSON object, is converted. Returns None when
    no such item exists.
    """
    for item in options:
        if not isinstance(item, dict) or item.get("key") != "GroupRatio":
            continue
        value = item.get("value")
        if isinstance(value, str):
            value = _parse_json_object(value)
        if isinstance(value, dict):
            return _rates_from_mapping(value)
    return None


def _find_rate_mapping(raw: dict[str, Any]) -> Optional[dict[Any, Any]]:
    """Return the first rate mapping nested under a known key, or None.

    Keys are tried in order: data, rates, group_rates, groupRates, GroupRatio.
    ``GroupRatio`` may also be a JSON-object string.
    """
    for key in ("data", "rates", "group_rates", "groupRates", "GroupRatio"):
        nested = raw.get(key)
        if isinstance(nested, dict):
            return nested
        if key == "GroupRatio" and isinstance(nested, str):
            parsed = _parse_json_object(nested)
            if parsed is not None:
                return parsed
    return None


def _rates_from_group_list(groups: list[Any]) -> dict[str, str]:
    """Map group objects to ``{_group_id(g): _rate_from_group(g)}``.

    Non-dict items and groups without a rate are skipped.
    """
    result: dict[str, str] = {}
    for group in groups:
        if not isinstance(group, dict):
            continue
        gid = _group_id(group)
        rate = _rate_from_group(group)
        if gid and rate:
            result[gid] = rate
    return result


def _extract_rates_map(raw: Any) -> dict[str, str]:
    """Normalize an upstream group-rate response into ``{group_id: rate}``.

    Shapes are tried in this order:

    1. ``{"data": [{"key": "GroupRatio", "value": ...}, ...]}`` (one-api/
       new-api ``/api/option``);
    2. a dict with a rate mapping under data/rates/group_rates/groupRates/
       GroupRatio (GroupRatio may be a JSON string);
    3. a dict whose ``data`` is a list of group objects (enveloped list);
    4. any other dict, treated as the rate mapping itself;
    5. a bare list of group objects.

    Anything else (including None) yields ``{}``. Rates are normalized with
    ``decimal_string`` and empty rates are omitted.
    """
    if isinstance(raw, list):
        return _rates_from_group_list(raw)
    if not isinstance(raw, dict):
        return {}
    data = raw.get("data")
    if isinstance(data, list):
        option_rates = _group_ratio_from_options(data)
        if option_rates is not None:
            return option_rates
    mapping = _find_rate_mapping(raw)
    if mapping is not None:
        return _rates_from_mapping(mapping)
    if isinstance(data, list):
        # Enveloped group list: without this the envelope's own keys
        # ("code", "message", "data", ...) would be read as group names.
        return _rates_from_group_list(data)
    return _rates_from_mapping(raw)
def build_snapshot(upstream_id: int, base_url: str, api_prefix: str,
                   groups: list[dict[str, Any]], raw_rates: Any) -> dict[str, Any]:
    """Assemble a point-in-time rate snapshot for one upstream.

    Each group is keyed by ``_group_id``. ``default_rate`` comes from the
    group object itself (``_rate_from_group``). A rate for the same id in
    *raw_rates* (parsed by ``_extract_rates_map``) is stored as
    ``override_rate`` and, when present, is used as ``rate``. A later group
    with a duplicate id replaces the earlier entry.

    Returns:
        A dict with ``upstream_id``, ``base_url`` (trailing "/" removed),
        ``api_prefix``, ``captured_at`` (ISO-8601 in the local timezone,
        second precision) and ``groups``.
    """
    from datetime import datetime, timezone

    overrides = _extract_rates_map(raw_rates)
    snapshot_groups: dict[str, dict[str, Any]] = {}
    for group in groups:
        key = _group_id(group)
        fallback = _rate_from_group(group)
        snapshot_groups[key] = {
            "group_id": key,
            "group_name": group.get("name") or group.get("group_name") or "",
            "platform": group.get("platform") or "",
            "rate": overrides.get(key, fallback),
            "default_rate": fallback,
            "override_rate": overrides.get(key, ""),
        }
    captured = datetime.now(timezone.utc).astimezone()
    return {
        "upstream_id": upstream_id,
        "base_url": base_url.rstrip("/"),
        "api_prefix": api_prefix,
        "captured_at": captured.isoformat(timespec="seconds"),
        "groups": snapshot_groups,
    }
class UpstreamClient:
    """Synchronous HTTP client for a single upstream site.

    ``auth_type`` selects how requests are authenticated (see ``_headers``):

    * ``"bearer"``: ``auth_config["token"]`` is sent as
      ``Authorization: Bearer <token>``.
    * ``"api_key"``: ``auth_config["key"]`` is sent in the header named by
      ``auth_config["header"]`` (``Authorization`` when unset or empty).
    * ``"cookie"``: ``auth_config["cookie_string"]`` is sent as ``Cookie``,
      and ``auth_config["new_api_user"]``, when set, as ``New-Api-User``.
    * ``"login_password"``: call ``login()`` first. It stores either a token
      (sent as a Bearer header) or relies on the session cookies from the
      response.

    The instance owns an ``httpx.Client``. Use it as a context manager or
    call ``close()`` when done.
    """

    def __init__(
        self,
        base_url: str,
        api_prefix: str,
        auth_type: str,
        auth_config: dict[str, Any],
        timeout: float = 30.0,
    ) -> None:
        """Store connection settings and open the underlying HTTP client.

        Args:
            base_url: Site root; a trailing "/" is removed.
            api_prefix: Path segment inserted before every request path;
                surrounding "/" are removed and "" disables it.
            auth_type: One of the modes described on the class.
            auth_config: Credentials and options for *auth_type*.
            timeout: Passed to ``httpx.Client(timeout=...)``.
        """
        self.base_url = base_url.rstrip("/")
        self.api_prefix = api_prefix.strip("/")
        self.auth_type = auth_type
        self.auth_config = auth_config
        self.timeout = timeout
        self._token: str = ""  # set by login() when the response carries a token
        # Cookies gathered from every response; re-sent on each request and
        # used by login() to detect a cookie-based session.
        self._cookies: dict[str, str] = {}
        self._new_api_user: str = ""  # New-Api-User for cookie-based login sessions
        self._client = httpx.Client(timeout=timeout)

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()

    def __enter__(self) -> UpstreamClient:
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()

    def _url(self, path: str) -> str:
        """Join ``base_url``, the optional ``api_prefix`` and *path* into a URL."""
        prefix = f"/{self.api_prefix}" if self.api_prefix else ""
        return f"{self.base_url}{prefix}/{path.lstrip('/')}"

    def _headers(self, auth: bool = True) -> dict[str, str]:
        """Build request headers, adding credentials for ``auth_type`` when *auth* is true.

        Raises:
            UpstreamError: a configured credential cannot be used as a header value.
        """
        headers: dict[str, str] = {
            "Accept": "application/json",
            "User-Agent": "SmartUp/1.0",
        }
        if not auth:
            return headers
        if self.auth_type == "bearer":
            token = _clean_auth_header_value(self.auth_config.get("token", ""), "Bearer token")
            if token:
                headers["Authorization"] = f"Bearer {token}"
        elif self.auth_type == "api_key":
            key = _clean_auth_header_value(self.auth_config.get("key", ""), "API key")
            # A saved config with an empty/None header name would otherwise
            # produce an invalid header; fall back to the documented default.
            header = self.auth_config.get("header") or "Authorization"
            if key:
                headers[header] = key
        elif self.auth_type == "cookie":
            cookie_str = _clean_auth_header_value(self.auth_config.get("cookie_string", ""), "Cookie")
            if cookie_str:
                headers["Cookie"] = cookie_str
            new_api_user = _clean_auth_header_value(self.auth_config.get("new_api_user", ""), "New-Api-User")
            if new_api_user:
                headers["New-Api-User"] = new_api_user
        elif self.auth_type == "login_password" and self._token:
            token = _clean_auth_header_value(self._token, "Login token")
            if token:
                headers["Authorization"] = f"Bearer {token}"
        if self.auth_type == "login_password" and self._new_api_user:
            headers["New-Api-User"] = self._new_api_user
        return headers

    def _request(
        self,
        method: str,
        path: str,
        body: Any = None,
        auth: bool = True,
        params: Optional[dict[str, Any]] = None,
    ) -> Any:
        """Send one request and return the decoded JSON body.

        Args:
            method: HTTP method.
            path: Path relative to the API root (see ``_url``).
            body: JSON request body; ``None`` sends no body.
            auth: Whether to attach credentials from ``_headers``.
            params: Query-string parameters; empty or ``None`` sends none.

        Returns:
            The parsed JSON, or ``None`` when the response body is empty.

        Raises:
            UpstreamError: cookie auth without ``new_api_user`` targets a
                ``user/self`` endpoint, or the server answered with HTML.
            httpx.HTTPStatusError: the response has an error status.
        """
        if auth and self.auth_type == "cookie" and "user/self" in path and not self.auth_config.get("new_api_user"):
            raise UpstreamError("New-API user endpoint requires New-Api-User; re-extract the session cookie after login and save the upstream")
        # json=None / params=None are httpx's defaults, i.e. "no body" / "no query".
        # NOTE(review): recent httpx versions warn that per-request cookies= is
        # deprecated; consider relying on self._client.cookies instead.
        resp = self._client.request(
            method,
            self._url(path),
            json=body,
            params=params or None,
            headers=self._headers(auth),
            cookies=self._cookies,
        )
        self._cookies.update(dict(resp.cookies))
        resp.raise_for_status()
        if not resp.content:
            return None
        content_type = resp.headers.get("content-type", "")
        # Login walls and error pages often come back as HTML with a 200 status.
        if "application/json" not in content_type and resp.text.lstrip().startswith("<"):
            raise UpstreamError(f"{method} {path} returned HTML, not JSON")
        return resp.json()

    def login(self) -> None:
        """Authenticate with email/password when ``auth_type`` is ``"login_password"``.

        Posts ``{username_field: email, "password": password}`` (both read from
        ``auth_config``) to ``login_path`` (default ``/auth/login``) without
        auth headers. A token found in the response is kept for Bearer auth.
        Otherwise, if any cookies have been collected, the session is
        cookie-based and ``New-Api-User`` is taken from config or from the
        user id in the response. Does nothing for other auth types.

        Raises:
            UpstreamError: email/password are missing, or the response yields
                neither a token nor cookies.
        """
        if self.auth_type != "login_password":
            return
        email = self.auth_config.get("email", "")
        password = self.auth_config.get("password", "")
        login_path = self.auth_config.get("login_path", "/auth/login")
        username_field = self.auth_config.get("username_field", "email")
        if not email or not password:
            raise UpstreamError("login_password auth requires email and password in auth_config")
        resp = self._request("POST", login_path, {username_field: email, "password": password}, auth=False)
        token = _find_token(resp)
        if token:
            self._token = token
            return
        if self._cookies:
            self._new_api_user = self.auth_config.get("new_api_user", "") or _find_user_id(resp)
            return
        raise UpstreamError("login succeeded but no token or session cookie found in response")

    def get_available_groups(self, endpoint: str) -> list[dict[str, Any]]:
        """Fetch *endpoint* and return its groups as a list of dicts (see ``_unwrap_list``).

        Raises:
            UpstreamError: the response contains no recognizable list.
        """
        resp = self._request("GET", endpoint)
        groups = _unwrap_list(resp)
        if groups is None:
            raise UpstreamError(f"{endpoint} did not return a list")
        return groups

    def get_group_rates(self, endpoint: str) -> Any:
        """Return the raw JSON of the group-rate endpoint (parse with ``_extract_rates_map``)."""
        return self._request("GET", endpoint)

    def get_balance(self, endpoint: str, response_path: str) -> Optional[float]:
        """Fetch *endpoint* and read a numeric balance at a dot-separated JSON path.

        Examples of *response_path*:
            "balance"            -> resp["balance"]
            "data.quota"         -> resp["data"]["quota"]
            "data.total_balance" -> resp["data"]["total_balance"]

        Returns ``None`` (without a request when an argument is empty) if either
        argument is empty, the response is not a dict, a path segment cannot
        be followed, the value is missing/None, or ``float()`` rejects it.
        """
        if not endpoint or not response_path:
            return None
        resp = self._request("GET", endpoint)
        if not isinstance(resp, dict):
            return None
        value: Any = resp
        for part in response_path.split("."):
            if not isinstance(value, dict):
                return None
            value = value.get(part)
        if value is None:
            return None
        try:
            return float(value)
        except (ValueError, TypeError):
            return None

    def list_api_keys(
        self,
        search: str = "",
        group_id: str | int | None = None,
        status: str = "active",
        endpoint: str = "/keys",
    ) -> list[dict[str, Any]]:
        """List keys on the upstream, filtered by name search, group and status.

        Goes through ``_request`` so it shares cookie tracking and HTML-response
        detection with all other calls. Digit-string group ids are sent as ints.
        The key list is read from a bare list, from ``data``/``result``/
        ``response`` (directly or under items/keys/list/records/data), or from
        top-level items/keys/list/records.

        Raises:
            UpstreamError: the response holds no recognizable list, or is HTML.
        """
        params: dict[str, Any] = {}
        if search:
            params["search"] = search
        if group_id is not None:
            params["group_id"] = int(group_id) if str(group_id).isdigit() else group_id
        if status:
            params["status"] = status
        data = self._request("GET", endpoint, params=params)
        if isinstance(data, list):
            return data
        if isinstance(data, dict):
            # Common wrappers: {"data": [...]} or {"data": {"items": [...]}}.
            for top_key in ("data", "result", "response"):
                val = data.get(top_key)
                if isinstance(val, list):
                    return val
                if isinstance(val, dict):
                    for inner_key in ("items", "keys", "list", "records", "data"):
                        inner = val.get(inner_key)
                        if isinstance(inner, list):
                            return inner
            # Top-level list-like wrapper without an envelope.
            for key in ("items", "keys", "list", "records"):
                val = data.get(key)
                if isinstance(val, list):
                    return val
        raise UpstreamError(f"unexpected keys response type: {type(data).__name__}")

    def delete_api_key(self, key_id: str, endpoint: str = "/keys") -> None:
        """Delete one key on the upstream via ``DELETE {endpoint}/{key_id}``."""
        # Avoid "//" when the configured endpoint ends with a slash.
        self._request("DELETE", f"{endpoint.rstrip('/')}/{key_id}")

    def find_smartup_group_key(
        self,
        group_id: str | int,
        expected_name: str,
        prefix: str = "SmartUp",
    ) -> dict[str, Any] | None:
        """Return the first active key in *group_id* named *expected_name*, or None.

        The listing is narrowed server-side with ``search=prefix``; the name
        (``name`` or ``key_name``) is compared to *expected_name* ignoring
        surrounding whitespace. Non-dict entries are skipped.
        """
        wanted = expected_name.strip()
        for key in self.list_api_keys(search=prefix, group_id=group_id, status="active"):
            if not isinstance(key, dict):
                continue
            name = key.get("name") or key.get("key_name") or ""
            # Some backends pad names with whitespace, so compare trimmed values.
            if str(name).strip() == wanted:
                return key
        return None

    def create_api_key(
        self,
        name: str,
        group_id: str | int,
        quota: float = 0,
        expires_in_days: int | None = None,
        rate_limit_5h: float = 0,
        rate_limit_1d: float = 0,
        rate_limit_7d: float = 0,
        endpoint: str = "/keys",
    ) -> dict[str, Any]:
        """Create a key on the upstream and return its id and secret.

        Digit-string *group_id* values are sent as ints. *quota* and the rate
        limits are sent as given. *expires_in_days* is included only when truthy.

        Returns:
            ``{"id", "key", "masked_key", "raw"}``, where ``raw`` is the
            unwrapped response (wrapped as ``{"value": ...}`` if not a dict).

        Raises:
            UpstreamError: the response does not contain a key value.
        """
        body: dict[str, Any] = {
            "name": name,
            "group_id": int(group_id) if str(group_id).isdigit() else group_id,
            "quota": quota,
            "rate_limit_5h": rate_limit_5h,
            "rate_limit_1d": rate_limit_1d,
            "rate_limit_7d": rate_limit_7d,
        }
        if expires_in_days:
            body["expires_in_days"] = expires_in_days
        resp = self._request("POST", endpoint, body)
        data = _unwrap_data(resp)
        key_value = _extract_key_value(data)
        if not key_value:
            raise UpstreamError("key create response did not include key")
        return {
            "id": _extract_id(data),
            "key": key_value,
            "masked_key": mask_secret(key_value),
            "raw": data if isinstance(data, dict) else {"value": data},
        }