Files
SmartUp/backend/app/services/website_client.py
T
2026-06-03 18:39:21 +08:00

295 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import logging
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
from typing import Any
from urllib.parse import quote
import httpx
from app.utils.number import decimal_string
# Module-level logger; records the raw HTTP/connection errors behind the
# friendly messages returned to callers.
logger = logging.getLogger(__name__)
class WebsiteError(RuntimeError):
    """Raised when talking to the target website, or interpreting its data, fails."""
def _friendly_http_error(exc: httpx.HTTPStatusError) -> str:
    """Translate a common HTTP error status into a friendly Chinese message.

    The status, request URL and raw exception text are logged at WARNING level
    so the original details are not lost.

    Args:
        exc: The status error raised by ``Response.raise_for_status()``.

    Returns:
        A human-readable message, intended to be wrapped in ``WebsiteError``.
    """
    status = exc.response.status_code
    url = exc.request.url if exc.request else "?"
    logger.warning("website_client HTTP %s from %s: %s", status, url, exc)
    if status == 401:
        return "目标网站认证失败,请检查 Admin API Key / JWT 是否正确"
    if status == 403:
        return "目标网站权限不足,请检查当前凭证是否有分组管理权限"
    if status == 404:
        # The parenthesis around the path was previously never closed.
        return f"目标网站接口不存在,请检查 API Prefix 和分组接口路径({exc.response.url.path})"
    if 500 <= status < 600:
        return "目标网站服务异常,请稍后重试"
    return f"目标网站返回错误(HTTP {status})"
def _friendly_connection_error(exc: Exception) -> str:
    """Map a network-level exception to a friendly Chinese message.

    The raw exception is logged at WARNING level. Timeouts and connection
    failures get specific hints; anything else gets a generic message that
    embeds the exception text.
    """
    logger.warning("website_client connection error: %s", exc)
    # Checked in order; the first matching exception type wins.
    known_failures = (
        (httpx.TimeoutException, "目标网站请求超时,请检查网络连接和 API 地址是否正确"),
        (httpx.ConnectError, "无法连接目标网站,请检查 API 地址和网络连通性"),
    )
    for failure_type, hint in known_failures:
        if isinstance(exc, failure_type):
            return hint
    return f"目标网站通信异常:{exc}"
def parse_positive_decimal(value: Any) -> Decimal | None:
    """Parse *value* as a strictly positive, finite ``Decimal``.

    Returns ``None`` instead of raising for ``None``, ``""``, unparsable
    input, zero, negative numbers and non-finite values (``NaN``, ``sNaN``,
    ``Infinity``).
    """
    if value is None or value == "":
        return None
    try:
        d = Decimal(str(value))
    except (InvalidOperation, ValueError):
        return None
    # Ordering comparisons on a NaN raise InvalidOperation (outside the try
    # above), and an infinite rate cannot be quantized later, so reject
    # non-finite values before comparing.
    if not d.is_finite():
        return None
    return d if d > 0 else None
def calculate_target_rate(values: list[Any], percent: Any = 0, algorithm: str = "max_plus_percent") -> Decimal:
    """Combine upstream rates into a base rate and add a percentage markup.

    Args:
        values: Raw upstream rates. Entries that ``parse_positive_decimal``
            rejects (empty, non-numeric, zero, negative) are ignored, as are
            non-finite values.
        percent: Markup in percent (``10`` means +10%); falsy values mean 0.
        algorithm: ``"max_plus_percent"`` (default), ``"min_plus_percent"``
            or ``"average_plus_percent"``; selects max, min or mean of the
            usable rates as the base.

    Returns:
        ``base * (1 + percent / 100)`` rounded half-up to 4 decimal places.

    Raises:
        WebsiteError: if there is no usable rate, the algorithm is unknown,
            or *percent* is not a finite, non-negative number.
    """
    # The is_finite() filter keeps an "Infinity" rate from reaching quantize().
    rates = [
        rate
        for rate in (parse_positive_decimal(v) for v in values)
        if rate is not None and rate.is_finite()
    ]
    if not rates:
        raise WebsiteError("没有可用的正数上游倍率")
    if algorithm == "average_plus_percent":
        base = sum(rates, Decimal("0")) / Decimal(len(rates))
    elif algorithm == "min_plus_percent":
        base = min(rates)
    elif algorithm == "max_plus_percent":
        base = max(rates)
    else:
        raise WebsiteError(f"不支持的算法:{algorithm}")
    # A malformed percent used to escape as a bare decimal.InvalidOperation
    # (for NaN, from the `< 0` comparison); report it as WebsiteError instead.
    try:
        pct = Decimal(str(percent or 0))
    except (InvalidOperation, ValueError) as exc:
        raise WebsiteError(f"百分比格式无效:{percent}") from exc
    if not pct.is_finite():
        raise WebsiteError(f"百分比格式无效:{percent}")
    if pct < 0:
        raise WebsiteError("百分比不能为负数")
    return (base * (Decimal("1") + pct / Decimal("100"))).quantize(Decimal("0.0001"), rounding=ROUND_HALF_UP)
def _unwrap_data(value: Any) -> Any:
    """Peel common API envelopes off a response payload.

    A dict with a ``data`` key is treated as an envelope, and replaced by its
    ``data`` value, when it also has ``code`` or ``message``, or when ``data``
    is a list or a dict holding ``items``/``groups``. Afterwards, if the
    result is a dict containing ``items`` (checked first) or ``groups``, that
    entry is returned. Anything else is returned unchanged.
    """
    if isinstance(value, dict) and "data" in value:
        inner = value.get("data")
        inner_is_listing = isinstance(inner, dict) and ("items" in inner or "groups" in inner)
        is_envelope = "code" in value or "message" in value or isinstance(inner, list) or inner_is_listing
        if is_envelope:
            value = inner
    if isinstance(value, dict):
        for list_key in ("items", "groups"):
            if list_key in value:
                return value.get(list_key)
    return value
def _extract_id(value: Any) -> str:
    """Find an identifier in an API payload and return it as a string.

    Direct id keys are checked first, in order:
    ``id``, ``account_id``, ``accountId``, ``group_id``, ``groupId``.
    The first non-``None`` value is returned as ``str``.

    Otherwise the nested payloads under ``data``, ``result``, ``account`` and
    ``group`` are searched recursively, and the first non-empty result wins.

    Returns ``""`` when nothing is found or *value* is not a dict.
    """
    if not isinstance(value, dict):
        return ""
    id_keys = ("id", "account_id", "accountId", "group_id", "groupId")
    for candidate in (value.get(k) for k in id_keys):
        if candidate is not None:
            return str(candidate)
    # Lazily evaluated, so recursion stops at the first wrapper yielding an id.
    nested_ids = (_extract_id(value.get(k)) for k in ("data", "result", "account", "group"))
    return next((found for found in nested_ids if found), "")
def normalize_groups(value: Any) -> list[dict[str, Any]]:
    """Convert a group-list API response into a uniform list of group dicts.

    The payload is first unwrapped with ``_unwrap_data``; a dict result is
    treated as a mapping whose values are the groups.

    - A non-empty string item becomes a group whose id and name are that
      string.
    - A dict item takes its id from the first present field among ``id``,
      ``group_id``, ``groupId``, ``name`` and ``group_name``; it is skipped
      if none is present.
    - Other item types are ignored.

    Each result has:
    - ``id``: str.
    - ``name``: str; falls back to the id.
    - ``rate_multiplier``: ``decimal_string`` of the first present rate
      field, or ``None``.
    - ``raw``: the source item.

    Raises:
        WebsiteError: if the unwrapped payload is neither a list nor a dict.
    """

    def first_present(item: dict[str, Any], *keys: str) -> Any:
        # A chained `or` would skip legitimate falsy values such as id 0 or a
        # 0 rate; only missing / None / empty-string fields fall through.
        for key in keys:
            candidate = item.get(key)
            if candidate is not None and candidate != "":
                return candidate
        return None

    raw = _unwrap_data(value)
    if isinstance(raw, dict):
        raw = list(raw.values())
    if not isinstance(raw, list):
        raise WebsiteError("分组接口没有返回列表")
    groups: list[dict[str, Any]] = []
    for item in raw:
        if isinstance(item, str):
            if not item:
                # An empty string cannot serve as a group id.
                continue
            groups.append({"id": item, "name": item, "rate_multiplier": None, "raw": {"id": item, "name": item}})
            continue
        if not isinstance(item, dict):
            continue
        gid = first_present(item, "id", "group_id", "groupId", "name", "group_name")
        if gid is None:
            continue
        name = first_present(item, "name", "group_name")
        rate = first_present(item, "rate_multiplier", "rateMultiplier", "ratio")
        groups.append({
            "id": str(gid),
            "name": str(name) if name is not None else str(gid),
            "rate_multiplier": decimal_string(rate) if rate is not None else None,
            "raw": item,
        })
    return groups
class Sub2ApiWebsiteClient:
    """HTTP client for the target website's admin API (groups and accounts).

    Owns one ``httpx.Client``; use the instance as a context manager or call
    :meth:`close` to release it.

    Methods built on :meth:`_request` raise :class:`WebsiteError` on failure.
    :meth:`list_accounts` and :meth:`account_exists` instead return ``None``
    when the account list cannot be fetched or parsed.
    """

    def __init__(
        self,
        base_url: str,
        api_prefix: str,
        auth_type: str,
        auth_config: dict[str, Any],
        timeout: float = 30.0,
    ) -> None:
        """Create the client.

        Args:
            base_url: Site root; trailing slashes are removed.
            api_prefix: Path segment(s) inserted between *base_url* and every
                endpoint. Surrounding slashes are removed, and it may be empty.
            auth_type: ``"api_key"`` or ``"bearer"``; any other value sends no
                credentials.
            auth_config: Credential settings read by :meth:`_headers`.
            timeout: Timeout passed to ``httpx.Client``.
        """
        self.base_url = base_url.rstrip("/")
        self.api_prefix = api_prefix.strip("/")
        self.auth_type = auth_type
        self.auth_config = auth_config
        self.timeout = timeout
        self._client = httpx.Client(timeout=timeout)

    def close(self) -> None:
        """Close the underlying ``httpx.Client``."""
        self._client.close()

    def __enter__(self) -> Sub2ApiWebsiteClient:
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()

    def _url(self, path: str) -> str:
        """Build ``base_url[/api_prefix]/path``; leading slashes on *path* are ignored."""
        prefix = f"/{self.api_prefix}" if self.api_prefix else ""
        return f"{self.base_url}{prefix}/{path.lstrip('/')}"

    def _headers(self) -> dict[str, str]:
        """Return request headers, including credentials for ``auth_type``.

        - ``api_key``: the key from ``auth_config["key"]`` (or ``"api_key"``)
          is sent in the header named by ``auth_config["header"]``
          (default ``x-api-key``).
        - ``bearer``: ``auth_config["token"]`` is sent as
          ``Authorization: Bearer <token>``.

        Empty credentials are omitted rather than sent blank.
        """
        headers = {"Accept": "application/json", "User-Agent": "SmartUp/1.0"}
        if self.auth_type == "api_key":
            key = self.auth_config.get("key") or self.auth_config.get("api_key") or ""
            header = self.auth_config.get("header") or "x-api-key"
            if key:
                headers[header] = key
        elif self.auth_type == "bearer":
            token = self.auth_config.get("token") or ""
            if token:
                headers["Authorization"] = f"Bearer {token}"
        return headers

    def _request(self, method: str, path: str, body: Any = None) -> Any:
        """Send one request and return the decoded JSON body.

        Returns ``None`` when the response body is empty.

        Raises:
            WebsiteError: on any of the following:
                - an HTTP error status;
                - any transport failure (timeout, refused or dropped
                  connection, unsupported URL scheme, ...);
                - an HTML page where JSON was expected;
                - a body that is not valid JSON.
        """
        try:
            resp = self._client.request(method, self._url(path), json=body, headers=self._headers())
            resp.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise WebsiteError(_friendly_http_error(exc)) from exc
        except httpx.RequestError as exc:
            # RequestError is the base of TimeoutException and ConnectError, and
            # also of ReadError, RemoteProtocolError, UnsupportedProtocol, etc.
            # Previously only the first two were caught; the rest escaped as
            # raw httpx errors.
            raise WebsiteError(_friendly_connection_error(exc)) from exc
        if not resp.content:
            return None
        text = resp.text
        if "application/json" not in resp.headers.get("content-type", "") and text.lstrip().startswith("<"):
            raise WebsiteError(f"{method} {path} 返回了 HTML,请检查接口地址是否正确")
        try:
            return resp.json()
        except ValueError as exc:  # json.JSONDecodeError / UnicodeDecodeError
            raise WebsiteError(f"{method} {path} 返回的内容不是有效的 JSON") from exc

    @staticmethod
    def _is_auth_error(exc: Exception) -> bool:
        """Return True if *exc* stems from an HTTP 401/403 response."""
        cause = exc.__cause__
        if isinstance(cause, httpx.HTTPStatusError):
            return cause.response.status_code in (401, 403)
        # Fallback: match the friendly messages produced for 401/403.
        msg = str(exc)
        return "认证失败" in msg or "权限不足" in msg

    def get_groups(self, endpoint: str = "/groups") -> list[dict[str, Any]]:
        """Fetch and normalize the group list.

        Tries *endpoint* first, then ``/groups/all`` (skipped when identical).
        Authentication/permission errors (HTTP 401/403) are re-raised at once,
        because another path would fail the same way.

        Raises:
            WebsiteError: the auth error above, or — when every path fails —
                the last failure message plus the list of attempted paths.
        """
        last_error: Exception | None = None
        tried_paths: list[str] = []
        # dict.fromkeys de-duplicates while keeping order.
        for path in dict.fromkeys([endpoint, "/groups/all"]):
            tried_paths.append(path)
            try:
                return normalize_groups(self._request("GET", path))
            except WebsiteError as exc:
                if self._is_auth_error(exc):
                    raise
                # 404 / 5xx / unexpected payload: the other path may still work.
                last_error = exc
            except Exception as exc:
                last_error = exc
                logger.info("get_groups fallback %s failed: %s", path, exc)
        msg = str(last_error) if last_error else "拉取分组失败"
        # Paths used to be joined with "" (e.g. "/groups/groups/all"), and the
        # parenthesis was never closed.
        raise WebsiteError(f"{msg}(尝试接口:{', '.join(tried_paths)})") from last_error

    def update_group_rate(self, endpoint_template: str, group_id: str, rate: Decimal) -> Any:
        """PUT ``{"rate_multiplier": rate}`` to *endpoint_template*.

        Every ``{id}`` placeholder in the template is replaced with the
        URL-quoted *group_id*. The rate is sent as a JSON number (float).
        """
        path = endpoint_template.replace("{id}", quote(str(group_id), safe=""))
        return self._request("PUT", path, {"rate_multiplier": float(rate)})

    @staticmethod
    def _item_path(endpoint: str, item_id: Any) -> str:
        """Join a collection endpoint and an item id, URL-quoting the id.

        Quoting with ``safe=""`` (as update_group_rate does) keeps ids that
        contain ``/``, ``?`` or ``#`` from leaking into the path or query.
        """
        return f"{endpoint.rstrip('/')}/{quote(str(item_id), safe='')}"

    @staticmethod
    def _as_dict(resp: Any) -> dict[str, Any]:
        """Unwrap a write response; non-dict payloads become ``{"value": payload}``."""
        data = _unwrap_data(resp)
        return data if isinstance(data, dict) else {"value": data}

    def create_group(self, body: dict[str, Any], endpoint: str = "/groups") -> dict[str, Any]:
        """POST a new group and return the (unwrapped) response payload."""
        return self._as_dict(self._request("POST", endpoint, body))

    def create_account(self, body: dict[str, Any], endpoint: str = "/accounts") -> dict[str, Any]:
        """POST a new account and return the (unwrapped) response payload."""
        return self._as_dict(self._request("POST", endpoint, body))

    def update_account(self, account_id: str, body: dict[str, Any], endpoint: str = "/accounts") -> dict[str, Any]:
        """PUT changes to a remote account; *body* should hold only changed fields."""
        return self._as_dict(self._request("PUT", self._item_path(endpoint, account_id), body))

    @staticmethod
    def _unwrap_list(value: Any) -> list | None:
        """Find the list inside a (possibly wrapped) list response.

        Accepts a bare list, or one level of wrapping under
        ``items``/``accounts``/``records``/``list``/``data``, or two levels
        under ``data`` (``data.items``, ``data.records``, ...).

        Returns ``None`` when no list is found.
        """
        if isinstance(value, list):
            return value
        if not isinstance(value, dict):
            return None
        # Top level first.
        for key in ("items", "accounts", "records", "list", "data"):
            v = value.get(key)
            if isinstance(v, list):
                return v
        # Then one level down inside "data".
        data_val = value.get("data")
        if isinstance(data_val, dict):
            for key in ("items", "records", "list", "data", "accounts"):
                v = data_val.get(key)
                if isinstance(v, list):
                    return v
        return None

    def list_accounts(self, endpoint: str = "/accounts") -> list[dict[str, Any]] | None:
        """Fetch the remote account list.

        Returns the dict entries of the list, or ``None`` (after logging) if
        the request fails or the payload contains no recognizable list.
        """
        # NOTE(review): only a single response is read. If the endpoint
        # paginates, accounts on later pages are missing here and
        # account_exists() would report them as deleted. Confirm the API
        # returns the full list.
        try:
            resp = self._request("GET", endpoint)
        except Exception:
            # Deliberately best-effort: callers treat None as "could not verify".
            logger.warning("account list fetch failed for %s", endpoint, exc_info=True)
            return None
        items = self._unwrap_list(resp)
        if items is None:
            logger.warning("account list unexpected format for %s", endpoint)
            return None
        return [item for item in items if isinstance(item, dict)]

    def _get_account_ids(self, endpoint: str = "/accounts") -> set[str] | None:
        """Return the set of remote account ids (possibly empty), or ``None`` if the list is unavailable."""
        items = self.list_accounts(endpoint)
        if items is None:
            return None
        return {item_id for item_id in map(self.extract_id, items) if item_id}

    def account_exists(self, account_id: str, endpoint: str = "/accounts") -> bool | None:
        """Check whether a remote account still exists, using the account list.

        Returns:
            - ``True`` if the id is in the list;
            - ``False`` if the list was fetched and the id is absent
              (treated as deleted);
            - ``None`` if the list could not be fetched or parsed. Callers
              should then leave local data untouched.
        """
        ids = self._get_account_ids(endpoint)
        if ids is None:
            logger.warning("account_exists cannot verify %s: list fetch failed", account_id)
            return None
        # The ids are strings; normalize so an int id is not misread as deleted.
        return str(account_id) in ids

    @staticmethod
    def extract_id(value: Any) -> str:
        """Return the id found in an API payload, or ``""``."""
        return _extract_id(value)