"""Group-rate helpers and a small HTTP client for a Sub2Api-style website.

Contains pure helpers that parse upstream rate values and compute a target
rate, a tolerant normalizer for group-list API responses, and a synchronous
``httpx``-based client that lists groups and updates a group's rate
multiplier.
"""

from __future__ import annotations

from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
from typing import Any
from urllib.parse import quote

import httpx

from app.utils.number import decimal_string


class WebsiteError(RuntimeError):
    """Raised for website API failures and invalid rate-calculation input."""


def parse_positive_decimal(value: Any) -> Decimal | None:
    """Parse *value* as a strictly positive, finite ``Decimal``.

    Args:
        value: Any value whose ``str()`` form may be a decimal number
            (e.g. ``"1.5"``, ``2``, ``0.8``).

    Returns:
        The parsed ``Decimal`` when it is finite and greater than zero,
        otherwise ``None``. This covers ``None``, ``""``, unparsable text,
        zero, negatives, NaN and infinities.
    """
    if value is None or value == "":
        return None
    try:
        d = Decimal(str(value))
    except (InvalidOperation, ValueError):
        return None
    # Decimal happily parses "NaN"/"Infinity". An ordering comparison against a
    # NaN raises InvalidOperation, and an infinite rate cannot be quantized
    # later, so both are treated as "not a usable rate".
    if not d.is_finite():
        return None
    return d if d > 0 else None


def calculate_target_rate(values: list[Any], percent: Any = 0, algorithm: str = "max_plus_percent") -> Decimal:
    """Combine upstream rates into one base rate, then apply a percentage markup.

    Args:
        values: Candidate upstream rates. Entries that do not parse as
            positive, finite decimals are ignored.
        percent: Markup in percent (``10`` means +10%). ``None``, ``""`` and
            ``0`` all mean no markup.
        algorithm: Selects the base rate: ``"max_plus_percent"`` (default,
            the largest rate), ``"min_plus_percent"`` (the smallest) or
            ``"average_plus_percent"`` (the arithmetic mean).

    Returns:
        ``base * (1 + percent / 100)`` rounded half-up to 4 decimal places.

    Raises:
        WebsiteError: if no usable rate remains, *algorithm* is unknown, or
            *percent* is not a finite, non-negative number.
    """
    rates = [rate for rate in (parse_positive_decimal(v) for v in values) if rate is not None]
    if not rates:
        raise WebsiteError("没有可用的正数上游倍率")
    if algorithm == "average_plus_percent":
        base = sum(rates, Decimal("0")) / Decimal(len(rates))
    elif algorithm == "min_plus_percent":
        base = min(rates)
    elif algorithm == "max_plus_percent":
        base = max(rates)
    else:
        raise WebsiteError(f"不支持的算法:{algorithm}")
    try:
        pct = Decimal(str(percent or 0))
    except (InvalidOperation, ValueError) as exc:
        raise WebsiteError(f"百分比不是有效数字:{percent}") from exc
    # Reject NaN/Infinity: comparing a NaN raises InvalidOperation, and an
    # infinite multiplier cannot be quantized.
    if not pct.is_finite():
        raise WebsiteError(f"百分比不是有效数字:{percent}")
    if pct < 0:
        raise WebsiteError("百分比不能为负数")
    return (base * (Decimal("1") + pct / Decimal("100"))).quantize(Decimal("0.0001"), rounding=ROUND_HALF_UP)


def _unwrap_data(value: Any) -> Any:
    """Strip common response envelopes to reach the group payload.

    A dict is replaced by its ``"data"`` member only when it looks like an
    envelope. That means one of the following holds:

    - it also has ``"code"`` or ``"message"``;
    - ``data`` is a list;
    - ``data`` is a dict containing ``"items"`` or ``"groups"``.

    If the (possibly unwrapped) value is then a dict with an ``"items"`` or
    ``"groups"`` key, that member is returned. Any other value is returned
    unchanged.
    """
    if isinstance(value, dict):
        data = value.get("data")
        if "data" in value and (
            "code" in value
            or "message" in value
            or isinstance(data, list)
            or (isinstance(data, dict) and any(key in data for key in ("items", "groups")))
        ):
            value = data
    if not isinstance(value, dict):
        return value
    for key in ("items", "groups"):
        if key in value:
            return value.get(key)
    return value


def _first_present(mapping: dict[str, Any], *keys: str) -> Any:
    """Return the value of the first key in *keys* that is present and neither ``None`` nor ``""``.

    Unlike an ``a or b or c`` chain, this keeps falsy-but-meaningful values
    such as a numeric id or rate of ``0``.
    """
    for key in keys:
        candidate = mapping.get(key)
        if candidate is not None and candidate != "":
            return candidate
    return None


def normalize_groups(value: Any) -> list[dict[str, Any]]:
    """Normalize a group-list API response into a list of uniform dicts.

    Args:
        value: Decoded JSON from a groups endpoint. It may be wrapped in an
            envelope; see ``_unwrap_data``.

    Returns:
        One dict per recognised group, with these keys:

        - ``"id"`` (str);
        - ``"name"`` (str);
        - ``"rate_multiplier"``: the ``decimal_string`` of the rate, or ``None``;
        - ``"raw"``: the original item.

        Plain string items become groups whose id and name are that string.
        Dict items without any usable id/name are skipped, as are items of
        any other type.

    Raises:
        WebsiteError: if the unwrapped payload is neither a list nor a dict.
    """
    raw = _unwrap_data(value)
    if isinstance(raw, dict):
        # A dict that is not an envelope is treated as a collection of groups;
        # only its values are used.
        raw = list(raw.values())
    if not isinstance(raw, list):
        raise WebsiteError("分组接口没有返回列表")
    groups: list[dict[str, Any]] = []
    for item in raw:
        if isinstance(item, str):
            groups.append({"id": item, "name": item, "rate_multiplier": None, "raw": {"id": item, "name": item}})
            continue
        if not isinstance(item, dict):
            continue
        # Accept several field spellings; fall back to the name when no id exists.
        gid = _first_present(item, "id", "group_id", "groupId", "name", "group_name")
        if gid is None:
            continue
        name = item.get("name") or item.get("group_name") or str(gid)
        rate = _first_present(item, "rate_multiplier", "rateMultiplier", "ratio")
        groups.append({
            "id": str(gid),
            "name": str(name),
            "rate_multiplier": decimal_string(rate) if rate is not None else None,
            "raw": item,
        })
    return groups


class Sub2ApiWebsiteClient:
    """Synchronous JSON client for a website's group API.

    Usable as a context manager, which closes the underlying ``httpx.Client``
    on exit.
    """

    def __init__(
        self,
        base_url: str,
        api_prefix: str,
        auth_type: str,
        auth_config: dict[str, Any],
        timeout: float = 30.0,
    ) -> None:
        """Create a client.

        Args:
            base_url: Site root; trailing slashes are removed.
            api_prefix: Path prefix inserted before every request path;
                surrounding slashes are removed. May be empty.
            auth_type: ``"api_key"`` or ``"bearer"``. Any other value sends
                no auth header.
            auth_config: Credentials read by ``_headers``: ``key``/``api_key``
                and an optional ``header`` for API keys, ``token`` for bearer.
            timeout: Timeout in seconds passed to ``httpx.Client``.
        """
        self.base_url = base_url.rstrip("/")
        self.api_prefix = api_prefix.strip("/")
        self.auth_type = auth_type
        self.auth_config = auth_config
        self.timeout = timeout
        self._client = httpx.Client(timeout=timeout)

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()

    def __enter__(self) -> Sub2ApiWebsiteClient:
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()

    def _url(self, path: str) -> str:
        """Join base URL, optional API prefix, and *path* with single slashes."""
        prefix = f"/{self.api_prefix}" if self.api_prefix else ""
        return f"{self.base_url}{prefix}/{path.lstrip('/')}"

    def _headers(self) -> dict[str, str]:
        """Build request headers, adding the configured auth header if any."""
        headers = {"Accept": "application/json", "User-Agent": "SmartUp/1.0"}
        if self.auth_type == "api_key":
            key = self.auth_config.get("key") or self.auth_config.get("api_key") or ""
            header = self.auth_config.get("header") or "x-api-key"
            if key:
                headers[header] = key
        elif self.auth_type == "bearer":
            token = self.auth_config.get("token") or ""
            if token:
                headers["Authorization"] = f"Bearer {token}"
        return headers

    def _request(self, method: str, path: str, body: Any = None) -> Any:
        """Send *body* as JSON and return the decoded JSON response.

        Returns:
            The decoded JSON, or ``None`` when the response body is empty.

        Raises:
            httpx.HTTPStatusError: for error status codes (via ``raise_for_status``).
            WebsiteError: when a non-JSON response body looks like HTML
                (for example a login or error page).

        Decoding errors from ``resp.json()`` propagate unchanged.
        """
        resp = self._client.request(method, self._url(path), json=body, headers=self._headers())
        resp.raise_for_status()
        if not resp.content:
            return None
        text = resp.text
        if "application/json" not in resp.headers.get("content-type", "") and text.lstrip().startswith("<"):
            raise WebsiteError(f"{method} {path} returned HTML, not JSON")
        return resp.json()

    def get_groups(self, endpoint: str = "/groups") -> list[dict[str, Any]]:
        """Fetch and normalize the group list, falling back to ``/groups/all``.

        Args:
            endpoint: Path tried first.

        Returns:
            The ``normalize_groups`` result from the first path that succeeds.

        Raises:
            WebsiteError: if every candidate path fails. The message lists
                each path with its error.
        """
        fallback = "/groups/all"
        # Avoid requesting the same URL twice when the caller already asked
        # for the fallback path (``_url`` ignores leading slashes).
        if endpoint.lstrip("/") == fallback.lstrip("/"):
            paths = [endpoint]
        else:
            paths = [endpoint, fallback]
        errors: list[str] = []
        for path in paths:
            try:
                return normalize_groups(self._request("GET", path))
            except Exception as exc:  # deliberate: any failure moves on to the next path
                errors.append(f"{path}: {exc}")
        raise WebsiteError("; ".join(errors))

    def update_group_rate(self, endpoint_template: str, group_id: str, rate: Decimal) -> Any:
        """PUT a new ``rate_multiplier`` for one group.

        Args:
            endpoint_template: Path in which ``{id}`` is replaced by the
                URL-encoded group id.
            group_id: Group identifier. Non-str values are converted with
                ``str()`` before encoding.
            rate: New multiplier. It is sent as a float because the stdlib
                JSON encoder cannot serialize ``Decimal``.

        Returns:
            The decoded JSON response, or ``None`` for an empty body.
        """
        path = endpoint_template.replace("{id}", quote(str(group_id), safe=""))
        return self._request("PUT", path, {"rate_multiplier": float(rate)})