"""Upstream HTTP client — ported from monitor_ai98pro_group_rates.py.""" from __future__ import annotations import json from typing import Any, Optional from urllib.parse import urljoin import httpx from app.utils.number import decimal_string class UpstreamError(RuntimeError): pass def _find_token(value: Any) -> str: if isinstance(value, str) and value.count(".") >= 2: return value if isinstance(value, dict): for key in ("token", "access_token", "accessToken", "jwt", "auth_token", "authToken"): candidate = value.get(key) if isinstance(candidate, str) and candidate: return candidate for key in ("data", "result", "user", "session"): tok = _find_token(value.get(key)) if tok: return tok return "" def _clean_auth_header_value(value: Any, field_name: str) -> str: text = str(value or "").strip() if not text: return "" if text.startswith("Bearer "): text = text[7:].strip() # Try to sanitize non-latin-1 characters instead of hard-failing try: text.encode("latin-1") except UnicodeEncodeError: # Try stripping non-ASCII characters cleaned = text.encode("ascii", errors="ignore").decode("ascii").strip() if cleaned: return cleaned raise UpstreamError( f"{field_name} 含有非 HTTP 标头字符(如中文或 emoji)," f"请重新登录后再试" ) from None return text def _find_user_id(value: Any) -> str: if isinstance(value, dict): for key in ("id", "user_id", "userId"): candidate = value.get(key) if candidate is not None: return str(candidate) for key in ("data", "result", "user", "session"): user_id = _find_user_id(value.get(key)) if user_id: return user_id return "" def mask_secret(value: Any) -> str: text = str(value or "") if not text: return "" if len(text) <= 8: return text[:2] + "****" + text[-2:] if len(text) > 4 else "****" return text[:4] + "**********" + text[-4:] def _unwrap_data(value: Any) -> Any: if isinstance(value, dict) and "data" in value and ("code" in value or "message" in value): return value.get("data") return value def _extract_id(value: Any) -> str: if isinstance(value, dict): for key in ("id", "key_id", "keyId"): candidate = value.get(key) if candidate is not None: return str(candidate) for key in ("data", "result", "key", "api_key"): found = _extract_id(value.get(key)) if found: return found return "" def _extract_key_value(value: Any) -> str: if isinstance(value, str): return value if isinstance(value, dict): for key in ("key", "api_key", "apiKey", "token", "value"): candidate = value.get(key) if isinstance(candidate, str) and candidate: return candidate for key in ("data", "result", "api_key", "key"): found = _extract_key_value(value.get(key)) if found: return found return "" def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]: def _normalize(lst: list) -> list[dict[str, Any]]: out = [] for i in lst: if isinstance(i, dict): out.append(i) elif isinstance(i, str): out.append({"id": i, "name": i}) return out if isinstance(value, list): return _normalize(value) if isinstance(value, dict): for key in ("data", "items", "groups", "available_groups", "availableGroups"): nested = value.get(key) if isinstance(nested, list): return _normalize(nested) elif isinstance(nested, dict): # Handle /api/user/self/groups where data is a dict of group_name -> { desc, ratio } out = [] for k in nested.keys(): out.append({"id": k, "name": k}) return out return None def _group_id(group: dict[str, Any]) -> str: for key in ("id", "group_id", "groupId"): v = group.get(key) if v is not None: return str(v) name = str(group.get("name") or group.get("group_name") or "") platform = str(group.get("platform") or "") return f"{platform}:{name}" def _rate_from_group(group: dict[str, Any]) -> str: for key in ( "user_rate_multiplier", "userRateMultiplier", "effective_rate_multiplier", "effectiveRateMultiplier", "rate_multiplier", "rateMultiplier", ): r = decimal_string(group.get(key)) if r: return r return "" def _extract_rates_map(raw: Any) -> dict[str, str]: if raw is None: return {} # Handle one-api/new-api /api/option response where GroupRatio is in a list of options if isinstance(raw, dict) and isinstance(raw.get("data"), list): for item in raw["data"]: if isinstance(item, dict) and item.get("key") == "GroupRatio": val = item.get("value") if isinstance(val, str): try: import json parsed = json.loads(val) if isinstance(parsed, dict): result: dict[str, str] = {} for k, v in parsed.items(): r = decimal_string(v) if r: result[str(k)] = r return result except Exception: pass elif isinstance(val, dict): # In case it's returned as dict directly result = {} for k, v in val.items(): r = decimal_string(v) if r: result[str(k)] = r return result if isinstance(raw, dict): candidates = raw for key in ("data", "rates", "group_rates", "groupRates", "GroupRatio"): nested = raw.get(key) if isinstance(nested, dict): candidates = nested break elif isinstance(nested, str) and key == "GroupRatio": # Handle GroupRatio as a JSON string try: import json parsed = json.loads(nested) if isinstance(parsed, dict): candidates = parsed break except Exception: pass result: dict[str, str] = {} for k, v in candidates.items(): if isinstance(v, dict): r = decimal_string( v.get("rate_multiplier") or v.get("rateMultiplier") or v.get("user_rate_multiplier") or v.get("userRateMultiplier") or v.get("ratio") ) else: r = decimal_string(v) if r: result[str(k)] = r return result if isinstance(raw, list): result = {} for item in raw: if not isinstance(item, dict): continue gid = _group_id(item) rate = _rate_from_group(item) if gid and rate: result[gid] = rate return result return {} def build_snapshot(upstream_id: int, base_url: str, api_prefix: str, groups: list[dict[str, Any]], raw_rates: Any) -> dict[str, Any]: from datetime import datetime, timezone override_rates = _extract_rates_map(raw_rates) entries: dict[str, dict[str, Any]] = {} for g in groups: gid = _group_id(g) default_rate = _rate_from_group(g) effective_rate = override_rates.get(gid, default_rate) entries[gid] = { "group_id": gid, "group_name": g.get("name") or g.get("group_name") or "", "platform": g.get("platform") or "", "rate": effective_rate, "default_rate": default_rate, "override_rate": override_rates.get(gid, ""), } return { "upstream_id": upstream_id, "base_url": base_url.rstrip("/"), "api_prefix": api_prefix, "captured_at": datetime.now(timezone.utc).astimezone().isoformat(timespec="seconds"), "groups": entries, } class UpstreamClient: """Sync HTTP client that handles all auth types.""" def __init__( self, base_url: str, api_prefix: str, auth_type: str, auth_config: dict[str, Any], timeout: float = 30.0, ) -> None: self.base_url = base_url.rstrip("/") self.api_prefix = api_prefix.strip("/") self.auth_type = auth_type self.auth_config = auth_config self.timeout = timeout self._token: str = "" self._cookies: dict[str, str] = {} self._new_api_user: str = "" self._client = httpx.Client(timeout=timeout) def close(self) -> None: self._client.close() def __enter__(self) -> UpstreamClient: return self def __exit__(self, *args: Any) -> None: self.close() def _url(self, path: str) -> str: prefix = f"/{self.api_prefix}" if self.api_prefix else "" return f"{self.base_url}{prefix}/{path.lstrip('/')}" def _headers(self, auth: bool = True) -> dict[str, str]: headers: dict[str, str] = { "Accept": "application/json", "User-Agent": "SmartUp/1.0", } if not auth: return headers if self.auth_type == "bearer": token = _clean_auth_header_value(self.auth_config.get("token", ""), "Bearer token") if token: headers["Authorization"] = f"Bearer {token}" elif self.auth_type == "api_key": key = _clean_auth_header_value(self.auth_config.get("key", ""), "API key") header = self.auth_config.get("header", "Authorization") if key: headers[header] = key elif self.auth_type == "cookie": cookie_str = _clean_auth_header_value(self.auth_config.get("cookie_string", ""), "Cookie") if cookie_str: headers["Cookie"] = cookie_str new_api_user = _clean_auth_header_value(self.auth_config.get("new_api_user", ""), "New-Api-User") if new_api_user: headers["New-Api-User"] = new_api_user elif self.auth_type == "login_password" and self._token: token = _clean_auth_header_value(self._token, "Login token") if token: headers["Authorization"] = f"Bearer {token}" if self.auth_type == "login_password" and self._new_api_user: headers["New-Api-User"] = self._new_api_user return headers def _request(self, method: str, path: str, body: Any = None, auth: bool = True) -> Any: if auth and self.auth_type == "cookie" and "user/self" in path and not self.auth_config.get("new_api_user"): raise UpstreamError("New-API user endpoint requires New-Api-User; re-extract the session cookie after login and save the upstream") url = self._url(path) if body is not None: resp = self._client.request( method, url, json=body, headers=self._headers(auth), cookies=self._cookies, ) else: resp = self._client.request( method, url, headers=self._headers(auth), cookies=self._cookies, ) self._cookies.update(dict(resp.cookies)) resp.raise_for_status() ct = resp.headers.get("content-type", "") if not resp.content: return None text = resp.text if "application/json" not in ct and text.lstrip().startswith("<"): raise UpstreamError(f"{method} {path} returned HTML, not JSON") return resp.json() def login(self) -> None: if self.auth_type != "login_password": return email = self.auth_config.get("email", "") password = self.auth_config.get("password", "") login_path = self.auth_config.get("login_path", "/auth/login") username_field = self.auth_config.get("username_field", "email") if not email or not password: raise UpstreamError("login_password auth requires email and password in auth_config") resp = self._request("POST", login_path, {username_field: email, "password": password}, auth=False) token = _find_token(resp) if token: self._token = token return if self._cookies: self._new_api_user = self.auth_config.get("new_api_user", "") or _find_user_id(resp) return raise UpstreamError("login succeeded but no token or session cookie found in response") def get_available_groups(self, endpoint: str) -> list[dict[str, Any]]: resp = self._request("GET", endpoint) groups = _unwrap_list(resp) if groups is None: raise UpstreamError(f"{endpoint} did not return a list") return groups def get_group_rates(self, endpoint: str) -> Any: return self._request("GET", endpoint) def get_balance(self, endpoint: str, response_path: str) -> Optional[float]: """Call the balance endpoint and extract a numeric value using a dot-separated JSON path. response_path 示例: "balance" → resp["balance"] "data.quota" → resp["data"]["quota"] "data.total_balance" → resp["data"]["total_balance"] """ if not endpoint or not response_path: return None resp = self._request("GET", endpoint) if not isinstance(resp, dict): return None parts = response_path.split(".") value: Any = resp for part in parts: if isinstance(value, dict): value = value.get(part) else: return None if value is None: return None try: return float(value) except (ValueError, TypeError): return None def list_api_keys( self, search: str = "", group_id: str | int | None = None, status: str = "active", endpoint: str = "/keys", ) -> list[dict[str, Any]]: """查询远端上游 Key 列表,支持按名称搜索、分组筛选、状态筛选。""" params: dict[str, Any] = {} if search: params["search"] = search if group_id is not None: params["group_id"] = int(group_id) if str(group_id).isdigit() else group_id if status: params["status"] = status url = self._url(endpoint) resp = self._client.request( "GET", url, params=params if params else None, headers=self._headers(), cookies=self._cookies, ) resp.raise_for_status() data = resp.json() if isinstance(data, list): return data if isinstance(data, dict): # 尝试展开常见的包装结构 for top_key in ("data", "result", "response"): val = data.get(top_key) if isinstance(val, list): return val if isinstance(val, dict): for inner_key in ("items", "keys", "list", "records", "data"): inner = val.get(inner_key) if isinstance(inner, list): return inner # 顶层本身就是 list-like wrapper for key in ("items", "keys", "list", "records"): val = data.get(key) if isinstance(val, list): return val raise UpstreamError(f"unexpected keys response type: {type(data).__name__}") def delete_api_key(self, key_id: str, endpoint: str = "/keys") -> None: """删除远端上游上的一个 Key。""" self._request("DELETE", f"{endpoint}/{key_id}") def find_smartup_group_key( self, group_id: str | int, expected_name: str, prefix: str = "SmartUp", ) -> dict[str, Any] | None: """查找同一上游分组下是否已存在 SmartUp 前缀的 Key。 匹配规则:key_name 等于 expected_name,且以 prefix 开头。 返回匹配到的第一个 Key,或 None。 """ gid = int(group_id) if str(group_id).isdigit() else group_id keys = self.list_api_keys(search=prefix, group_id=gid, status="active") for k in keys: name = k.get("name") or k.get("key_name") or "" if name == expected_name: return k # 部分后端返回的 name 可能带空格或 trimming if name.strip() == expected_name.strip(): return k return None def create_api_key( self, name: str, group_id: str | int, quota: float = 0, expires_in_days: int | None = None, rate_limit_5h: float = 0, rate_limit_1d: float = 0, rate_limit_7d: float = 0, endpoint: str = "/keys", ) -> dict[str, Any]: body: dict[str, Any] = { "name": name, "group_id": int(group_id) if str(group_id).isdigit() else group_id, "quota": quota, "rate_limit_5h": rate_limit_5h, "rate_limit_1d": rate_limit_1d, "rate_limit_7d": rate_limit_7d, } if expires_in_days: body["expires_in_days"] = expires_in_days resp = self._request("POST", endpoint, body) data = _unwrap_data(resp) key_value = _extract_key_value(data) if not key_value: raise UpstreamError("key create response did not include key") return { "id": _extract_id(data), "key": key_value, "masked_key": mask_secret(key_value), "raw": data if isinstance(data, dict) else {"value": data}, }