"""Client for the target website's admin API (groups and accounts) plus rate helpers."""

from __future__ import annotations

import logging
from decimal import ROUND_HALF_UP, Decimal, InvalidOperation
from typing import Any
from urllib.parse import quote

import httpx

from app.utils.number import decimal_string

logger = logging.getLogger(__name__)


class WebsiteError(RuntimeError):
    """Domain error raised by this module; its message is meant to be shown to users."""


def _friendly_http_error(exc: httpx.HTTPStatusError) -> str:
    """Translate common HTTP errors into a friendly (Chinese) message.

    The raw status/URL/exception details are kept in the log.
    """
    status = exc.response.status_code
    url = exc.request.url if exc.request else "?"
    logger.warning("website_client HTTP %s from %s: %s", status, url, exc)
    if status == 401:
        return "目标网站认证失败,请检查 Admin API Key / JWT 是否正确"
    if status == 403:
        return "目标网站权限不足,请检查当前凭证是否有分组管理权限"
    if status == 404:
        return f"目标网站接口不存在,请检查 API Prefix 和分组接口路径({exc.response.url.path})"
    if 500 <= status < 600:
        return "目标网站服务异常,请稍后重试"
    return f"目标网站返回错误(HTTP {status})"


def _friendly_connection_error(exc: Exception) -> str:
    """Translate network/timeout exceptions into a friendly (Chinese) message."""
    logger.warning("website_client connection error: %s", exc)
    if isinstance(exc, httpx.TimeoutException):
        return "目标网站请求超时,请检查网络连接和 API 地址是否正确"
    if isinstance(exc, httpx.ConnectError):
        return "无法连接目标网站,请检查 API 地址和网络连通性"
    return f"目标网站通信异常:{exc}"


def parse_positive_decimal(value: Any) -> Decimal | None:
    """Parse *value* as a strictly positive, finite ``Decimal``.

    Returns None for ``None``, ``""``, unparsable input, NaN/Infinity, zero and
    negative numbers.
    """
    if value is None or value == "":
        return None
    try:
        d = Decimal(str(value))
    except (InvalidOperation, ValueError):
        return None
    # NaN must be rejected before the ordering comparison below, because
    # ``Decimal("NaN") > 0`` raises InvalidOperation. Infinity is rejected too,
    # since it cannot be quantized later.
    if not d.is_finite():
        return None
    return d if d > 0 else None


def calculate_target_rate(
    values: list[Any],
    percent: Any = 0,
    algorithm: str = "max_plus_percent",
) -> Decimal:
    """Compute a target rate from upstream rates plus a percentage markup.

    Args:
        values: Candidate upstream rates; non-positive or unparsable entries are ignored.
        percent: Non-negative markup in percent (``None``/falsy means 0).
        algorithm: ``"max_plus_percent"``, ``"min_plus_percent"`` or
            ``"average_plus_percent"`` — which aggregate of the rates to mark up.

    Returns:
        ``base * (1 + percent / 100)`` rounded half-up to 4 decimal places.

    Raises:
        WebsiteError: no usable rate, unknown algorithm, or invalid/negative percent.
    """
    rates = [rate for rate in (parse_positive_decimal(v) for v in values) if rate is not None]
    if not rates:
        raise WebsiteError("没有可用的正数上游倍率")
    if algorithm == "average_plus_percent":
        base = sum(rates, Decimal("0")) / Decimal(len(rates))
    elif algorithm == "min_plus_percent":
        base = min(rates)
    elif algorithm == "max_plus_percent":
        base = max(rates)
    else:
        raise WebsiteError(f"不支持的算法:{algorithm}")
    try:
        pct = Decimal(str(percent or 0))
    except (InvalidOperation, ValueError) as exc:
        raise WebsiteError(f"百分比无效:{percent}") from exc
    # NaN/Infinity would make the comparison or quantize below raise InvalidOperation.
    if not pct.is_finite():
        raise WebsiteError(f"百分比无效:{percent}")
    if pct < 0:
        raise WebsiteError("百分比不能为负数")
    return (base * (Decimal("1") + pct / Decimal("100"))).quantize(
        Decimal("0.0001"), rounding=ROUND_HALF_UP
    )


def _unwrap_data(value: Any) -> Any:
    """Strip common response envelopes.

    A top-level ``data`` is unwrapped when the dict looks like an envelope: it has
    ``code``/``message``, or ``data`` is a list, or ``data`` holds ``items``/``groups``.
    Then an ``items`` or ``groups`` key, if present, is returned. Non-dict values are
    returned unchanged.
    """
    if isinstance(value, dict):
        data = value.get("data")
        looks_like_envelope = (
            "code" in value
            or "message" in value
            or isinstance(data, list)
            or (isinstance(data, dict) and any(key in data for key in ("items", "groups")))
        )
        if "data" in value and looks_like_envelope:
            value = data
    if not isinstance(value, dict):
        return value
    for key in ("items", "groups"):
        if key in value:
            return value.get(key)
    return value


def _unwrap_to_dict(value: Any) -> dict[str, Any]:
    """Unwrap a response envelope; non-dict payloads are wrapped as ``{"value": payload}``."""
    data = _unwrap_data(value)
    return data if isinstance(data, dict) else {"value": data}


def _extract_id(value: Any) -> str:
    """Find an id in a (possibly nested) response dict; return "" if none is found.

    Direct id-like keys are checked first. The lookup then recurses into ``data``,
    ``result``, ``account`` and ``group`` sub-objects.
    """
    if isinstance(value, dict):
        for key in ("id", "account_id", "accountId", "group_id", "groupId"):
            candidate = value.get(key)
            if candidate is not None:
                return str(candidate)
        for key in ("data", "result", "account", "group"):
            found = _extract_id(value.get(key))
            if found:
                return found
    return ""


def _first_present(item: dict[str, Any], keys: tuple[str, ...]) -> Any:
    """Return the first value under *keys* that is truthy or a numeric zero, else None.

    Unlike an ``a or b or c`` chain, a legitimate ``0`` id/rate is not skipped.
    Booleans are not treated as numbers.
    """
    for key in keys:
        candidate = item.get(key)
        is_numeric_zero = isinstance(candidate, (int, float, Decimal)) and not isinstance(candidate, bool)
        if candidate or is_numeric_zero:
            return candidate
    return None


def normalize_groups(value: Any) -> list[dict[str, Any]]:
    """Normalize a groups API response into ``{"id", "name", "rate_multiplier", "raw"}`` dicts.

    Accepts enveloped or bare lists, dicts of groups (their values are used), and
    plain string entries (used as both id and name). Entries without any id-like
    field are skipped.

    Raises:
        WebsiteError: the unwrapped payload is not a list/dict.
    """
    raw = _unwrap_data(value)
    if isinstance(raw, dict):
        raw = list(raw.values())
    if not isinstance(raw, list):
        raise WebsiteError("分组接口没有返回列表")
    groups: list[dict[str, Any]] = []
    for item in raw:
        if isinstance(item, str):
            groups.append({"id": item, "name": item, "rate_multiplier": None, "raw": {"id": item, "name": item}})
            continue
        if not isinstance(item, dict):
            continue
        gid = _first_present(item, ("id", "group_id", "groupId", "name", "group_name"))
        if gid is None:
            continue
        name = item.get("name") or item.get("group_name") or str(gid)
        rate = _first_present(item, ("rate_multiplier", "rateMultiplier", "ratio"))
        groups.append({
            "id": str(gid),
            "name": str(name),
            "rate_multiplier": decimal_string(rate) if rate is not None else None,
            "raw": item,
        })
    return groups


class Sub2ApiWebsiteClient:
    """Synchronous client for the target website's admin API.

    Request failures are raised as :class:`WebsiteError` with a user-facing message.
    The instance owns an ``httpx.Client``; release it with :meth:`close` or by using
    the instance as a context manager.
    """

    def __init__(
        self,
        base_url: str,
        api_prefix: str,
        auth_type: str,
        auth_config: dict[str, Any],
        timeout: float = 30.0,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_prefix = api_prefix.strip("/")
        self.auth_type = auth_type
        self.auth_config = auth_config
        self.timeout = timeout
        self._client = httpx.Client(timeout=timeout)

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()

    def __enter__(self) -> Sub2ApiWebsiteClient:
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()

    def _url(self, path: str) -> str:
        """Join base URL, optional API prefix and *path* with single slashes."""
        prefix = f"/{self.api_prefix}" if self.api_prefix else ""
        return f"{self.base_url}{prefix}/{path.lstrip('/')}"

    def _headers(self) -> dict[str, str]:
        """Build request headers, including credentials for the configured auth type.

        ``api_key`` sends the key in ``auth_config["header"]`` (default ``x-api-key``);
        ``bearer`` sends ``Authorization: Bearer <token>``. Any other auth type, or an
        empty key/token, sends no credentials.
        """
        headers = {"Accept": "application/json", "User-Agent": "SmartUp/1.0"}
        if self.auth_type == "api_key":
            key = self.auth_config.get("key") or self.auth_config.get("api_key") or ""
            header = self.auth_config.get("header") or "x-api-key"
            if key:
                headers[header] = key
        elif self.auth_type == "bearer":
            token = self.auth_config.get("token") or ""
            if token:
                headers["Authorization"] = f"Bearer {token}"
        return headers

    def _request(self, method: str, path: str, body: Any = None) -> Any:
        """Send a JSON request and return the decoded JSON body (None if the body is empty).

        Raises:
            WebsiteError: HTTP error status, any transport failure, an HTML page
                instead of JSON, or an undecodable JSON body.
        """
        try:
            resp = self._client.request(method, self._url(path), json=body, headers=self._headers())
            resp.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise WebsiteError(_friendly_http_error(exc)) from exc
        except httpx.RequestError as exc:
            # Covers timeouts and connect failures as well as every other transport or
            # protocol error (read errors, remote disconnects, ...). The latter get the
            # generic message from _friendly_connection_error.
            raise WebsiteError(_friendly_connection_error(exc)) from exc
        if not resp.content:
            return None
        # An HTML body usually means a wrong base URL/prefix hit a web page instead of the API.
        if "application/json" not in resp.headers.get("content-type", "") and resp.text.lstrip().startswith("<"):
            raise WebsiteError(f"{method} {path} 返回了 HTML,请检查接口地址是否正确")
        try:
            return resp.json()
        except ValueError as exc:  # json.JSONDecodeError / UnicodeDecodeError
            logger.warning("website_client invalid JSON from %s %s: %s", method, path, exc)
            raise WebsiteError(f"{method} {path} 返回的内容不是有效的 JSON") from exc

    @staticmethod
    def _is_auth_error(exc: WebsiteError) -> bool:
        """True if *exc* was caused by an HTTP 401/403 response."""
        cause = exc.__cause__
        return isinstance(cause, httpx.HTTPStatusError) and cause.response.status_code in (401, 403)

    def get_groups(self, endpoint: str = "/groups") -> list[dict[str, Any]]:
        """Fetch and normalize the group list, trying *endpoint* then ``/groups/all``.

        Raises:
            WebsiteError: immediately on authentication/permission errors (a fallback
                path cannot help); otherwise after every candidate path failed, with the
                last error message and the tried paths.
        """
        last_error: Exception | None = None
        tried_paths: list[str] = []
        # dict.fromkeys de-duplicates while keeping order, so "/groups/all" is not
        # requested twice when it is also the configured endpoint.
        for path in dict.fromkeys([endpoint, "/groups/all"]):
            tried_paths.append(path)
            try:
                return normalize_groups(self._request("GET", path))
            except WebsiteError as exc:
                if self._is_auth_error(exc):
                    raise
                # 404/5xx and other path-related failures: try the next path.
                last_error = exc
                logger.info("get_groups fallback %s failed: %s", path, exc)
            except Exception as exc:
                last_error = exc
                logger.info("get_groups fallback %s failed: %s", path, exc)
        msg = str(last_error) if last_error else "拉取分组失败"
        raise WebsiteError(f"{msg}(尝试接口:{'、'.join(tried_paths)})")

    def update_group_rate(self, endpoint_template: str, group_id: str, rate: Decimal) -> Any:
        """PUT ``{"rate_multiplier": rate}`` to *endpoint_template* with ``{id}`` replaced.

        The id is URL-encoded. The rate is sent as a float because the JSON encoder
        cannot serialize ``Decimal``.
        """
        path = endpoint_template.replace("{id}", quote(group_id, safe=""))
        return self._request("PUT", path, {"rate_multiplier": float(rate)})

    def create_group(self, body: dict[str, Any], endpoint: str = "/groups") -> dict[str, Any]:
        """POST a new group and return the unwrapped response (non-dicts as ``{"value": ...}``)."""
        return _unwrap_to_dict(self._request("POST", endpoint, body))

    def create_account(self, body: dict[str, Any], endpoint: str = "/accounts") -> dict[str, Any]:
        """POST a new account and return the unwrapped response (non-dicts as ``{"value": ...}``)."""
        return _unwrap_to_dict(self._request("POST", endpoint, body))

    def update_account(self, account_id: str, body: dict[str, Any], endpoint: str = "/accounts") -> dict[str, Any]:
        """PUT changed fields to ``{endpoint}/{account_id}`` and return the unwrapped response.

        *body* should contain only the fields to change. The id is URL-encoded, as in
        :meth:`update_group_rate`, so ids containing ``/`` or ``?`` cannot alter the path.
        """
        path = f"{endpoint.rstrip('/')}/{quote(str(account_id), safe='')}"
        return _unwrap_to_dict(self._request("PUT", path, body))

    @staticmethod
    def _unwrap_list(value: Any) -> list | None:
        """Find the list inside common wrappers; return None if there is none.

        Handles a bare list, top-level ``items``/``accounts``/``records``/``list``/``data``,
        and the same keys nested one level under ``data``.
        """
        if isinstance(value, list):
            return value
        if not isinstance(value, dict):
            return None
        # Top level first.
        for key in ("items", "accounts", "records", "list", "data"):
            v = value.get(key)
            if isinstance(v, list):
                return v
        # Then nested under data: data.items, data.records, data.list, ...
        data_val = value.get("data")
        if isinstance(data_val, dict):
            for key in ("items", "records", "list", "data", "accounts"):
                v = data_val.get(key)
                if isinstance(v, list):
                    return v
        return None

    def list_accounts(self, endpoint: str = "/accounts") -> list[dict[str, Any]] | None:
        """Fetch remote accounts; return the dict entries, or None on any failure.

        Failures are deliberately logged rather than raised (best-effort lookup).
        """
        try:
            resp = self._request("GET", endpoint)
        except Exception:
            logger.warning("account list fetch failed for %s", endpoint, exc_info=True)
            return None
        items = self._unwrap_list(resp)
        if items is None:
            logger.warning("account list unexpected format for %s", endpoint)
            return None
        return [item for item in items if isinstance(item, dict)]

    def _get_account_ids(self, endpoint: str = "/accounts") -> set[str] | None:
        """Return the set of remote account ids (possibly empty), or None if the list fetch failed."""
        items = self.list_accounts(endpoint)
        if items is None:
            return None
        return {item_id for item_id in (self.extract_id(item) for item in items) if item_id}

    def account_exists(self, account_id: str, endpoint: str = "/accounts") -> bool | None:
        """Check whether *account_id* exists on the target site.

        Returns:
            True if it is in the fetched account list, False if it is not (treated as
            deleted), or None if the list could not be fetched. None means "could not
            verify" and the caller should not clear local state.

        NOTE(review): only a single GET of the list endpoint is made. If the remote API
        paginates, accounts on later pages would be reported as False. Confirm the
        endpoint returns all accounts.
        """
        ids = self._get_account_ids(endpoint)
        if ids is None:
            logger.warning("account_exists cannot verify %s: list fetch failed", account_id)
            return None
        return account_id in ids

    @staticmethod
    def extract_id(value: Any) -> str:
        """Public wrapper around the module-level id extractor."""
        return _extract_id(value)