8a6ed249be
- HTTP connection pooling: UpstreamClient & WebsiteClient reuse httpx.Client - Deduplicate decimal_string into shared app/utils/number.py - Split scheduler transaction: snapshot write → webhook/website sync in separate sessions - Remove hardcoded 170.106.100.210 migration from database.py - Reset consecutive_failures on upstream update - Healthcheck: install curl, replace python -c with curl -f - Add .dockerignore to reduce build context - Frontend: add axios-retry with exponential backoff (5xx/network errors only)
147 lines
5.2 KiB
Python
147 lines
5.2 KiB
Python
from __future__ import annotations
|
|
|
|
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
|
|
from typing import Any
|
|
from urllib.parse import quote
|
|
|
|
import httpx
|
|
|
|
from app.utils.number import decimal_string
|
|
|
|
|
|
class WebsiteError(RuntimeError):
    """Raised when the target website's API cannot be used as expected.

    Used for invalid rate-calculation inputs, unexpected response payloads
    (an HTML page instead of JSON, group data that is not a list), and for
    the case where every group endpoint tried by the client has failed.
    """
|
|
|
|
|
|
def parse_positive_decimal(value: Any) -> Decimal | None:
    """Parse *value* as a strictly positive, finite ``Decimal``.

    The value goes through ``str()`` before conversion, so ints, floats,
    numeric strings and ``Decimal`` instances are all accepted.

    Args:
        value: Raw value to parse (e.g. an upstream rate from an API payload).

    Returns:
        The parsed ``Decimal`` when it is finite and greater than zero.
        ``None`` for ``None``, ``""``, unparsable input, zero, negative
        numbers, NaN and infinities.
    """
    if value is None or value == "":
        return None
    try:
        parsed = Decimal(str(value))
    except (InvalidOperation, ValueError):
        return None
    # Check finiteness first: an ordering comparison on a NaN Decimal raises
    # InvalidOperation, and an infinite rate cannot be quantized downstream.
    if not parsed.is_finite():
        return None
    return parsed if parsed > 0 else None
|
|
|
|
|
|
def calculate_target_rate(values: list[Any], percent: Any = 0, algorithm: str = "max_plus_percent") -> Decimal:
    """Derive a target rate multiplier from a set of upstream rates.

    Entries of *values* that ``parse_positive_decimal`` rejects are ignored.
    The remaining rates are reduced to a base according to *algorithm*. The
    base is then raised by *percent* percent and rounded half-up to four
    decimal places.

    Args:
        values: Candidate upstream rates.
        percent: Markup in percent. ``None``, ``""`` and ``0`` mean no markup.
        algorithm: ``"max_plus_percent"`` (default), ``"min_plus_percent"``
            or ``"average_plus_percent"`` (arithmetic mean).

    Returns:
        ``base * (1 + percent / 100)`` quantized to ``0.0001``.

    Raises:
        WebsiteError: if no usable rate remains, the algorithm is unknown,
            or *percent* is not a finite, non-negative number.
    """
    rates = [rate for rate in map(parse_positive_decimal, values) if rate is not None]
    if not rates:
        raise WebsiteError("没有可用的正数上游倍率")

    if algorithm == "average_plus_percent":
        base = sum(rates, Decimal("0")) / Decimal(len(rates))
    elif algorithm == "min_plus_percent":
        base = min(rates)
    elif algorithm == "max_plus_percent":
        base = max(rates)
    else:
        raise WebsiteError(f"不支持的算法:{algorithm}")

    try:
        pct = Decimal(str(percent or 0))
    except (InvalidOperation, ValueError) as exc:
        raise WebsiteError(f"百分比格式无效:{percent}") from exc
    # A NaN markup would make the `< 0` check raise InvalidOperation, and an
    # infinite one cannot be quantized, so reject both with a clear error.
    if not pct.is_finite():
        raise WebsiteError(f"百分比格式无效:{percent}")
    if pct < 0:
        raise WebsiteError("百分比不能为负数")

    return (base * (Decimal("1") + pct / Decimal("100"))).quantize(Decimal("0.0001"), rounding=ROUND_HALF_UP)
|
|
|
|
|
|
def _unwrap_data(value: Any) -> Any:
|
|
if isinstance(value, dict):
|
|
data = value.get("data")
|
|
if "data" in value and (
|
|
"code" in value
|
|
or "message" in value
|
|
or isinstance(data, list)
|
|
or (isinstance(data, dict) and any(key in data for key in ("items", "groups")))
|
|
):
|
|
value = data
|
|
if not isinstance(value, dict):
|
|
return value
|
|
for key in ("items", "groups"):
|
|
if key in value:
|
|
return value.get(key)
|
|
return value
|
|
|
|
|
|
def _first_present(mapping: dict[str, Any], keys: tuple[str, ...]) -> Any:
    """Return the first value under *keys* that is neither ``None`` nor ``""``."""
    for key in keys:
        candidate = mapping.get(key)
        if candidate is not None and candidate != "":
            return candidate
    return None


def normalize_groups(value: Any) -> list[dict[str, Any]]:
    """Convert a group-listing response into a uniform list of dicts.

    The response is first unwrapped with ``_unwrap_data``. If the result is
    a dict, its values are used as the list of groups.

    Each resulting entry has:
        ``id`` (str), ``name`` (str), ``rate_multiplier`` (the result of
        ``decimal_string`` or ``None``), and ``raw`` (the original item).

    Plain string items become groups whose id and name are that string.
    Items that are neither str nor dict, and dict items without any id-like
    field, are skipped.

    Raises:
        WebsiteError: if the unwrapped response is not a list or dict.
    """
    raw = _unwrap_data(value)
    if isinstance(raw, dict):
        raw = list(raw.values())
    if not isinstance(raw, list):
        raise WebsiteError("分组接口没有返回列表")

    groups: list[dict[str, Any]] = []
    for item in raw:
        if isinstance(item, str):
            groups.append({"id": item, "name": item, "rate_multiplier": None, "raw": {"id": item, "name": item}})
            continue
        if not isinstance(item, dict):
            continue
        # Select fields by presence, not truthiness. An `or` chain would drop
        # legitimate falsy values such as an id of 0 or a rate multiplier of 0.
        gid = _first_present(item, ("id", "group_id", "groupId", "name", "group_name"))
        if gid is None:
            continue
        name = _first_present(item, ("name", "group_name"))
        rate = _first_present(item, ("rate_multiplier", "rateMultiplier", "ratio"))
        groups.append({
            "id": str(gid),
            "name": str(name if name is not None else gid),
            "rate_multiplier": decimal_string(rate) if rate is not None else None,
            "raw": item,
        })
    return groups
|
|
|
|
|
|
class Sub2ApiWebsiteClient:
    """HTTP client for the target website's group API.

    Each instance creates one ``httpx.Client`` and reuses it for every
    request. Release it with :meth:`close`, or use the instance as a context
    manager (``with Sub2ApiWebsiteClient(...) as client: ...``).
    """

    def __init__(
        self,
        base_url: str,
        api_prefix: str,
        auth_type: str,
        auth_config: dict[str, Any],
        timeout: float = 30.0,
    ) -> None:
        """Store connection settings and open the shared HTTP client.

        Args:
            base_url: Site root; trailing slashes are removed.
            api_prefix: Path prefix inserted before every endpoint; surrounding
                slashes are removed, and an empty prefix adds nothing.
            auth_type: ``"api_key"`` or ``"bearer"``; any other value sends
                no auth header.
            auth_config: Credentials read by :meth:`_headers`.
            timeout: Timeout passed to ``httpx.Client``.
        """
        self.base_url = base_url.rstrip("/")
        self.api_prefix = api_prefix.strip("/")
        self.auth_type = auth_type
        self.auth_config = auth_config
        self.timeout = timeout
        self._client = httpx.Client(timeout=timeout)

    def close(self) -> None:
        """Close the underlying ``httpx.Client``."""
        self._client.close()

    def __enter__(self) -> Sub2ApiWebsiteClient:
        return self

    def __exit__(self, *exc_info: object) -> None:
        # Always release the connection pool; exceptions still propagate.
        self.close()

    def _url(self, path: str) -> str:
        """Join base URL, optional API prefix and *path* with single slashes."""
        prefix = f"/{self.api_prefix}" if self.api_prefix else ""
        return f"{self.base_url}{prefix}/{path.lstrip('/')}"

    def _headers(self) -> dict[str, str]:
        """Build request headers, including auth derived from ``auth_config``.

        ``api_key``: sends ``auth_config["key"]`` (or ``["api_key"]``) in the
        header named by ``auth_config["header"]``, default ``x-api-key``.
        ``bearer``: sends ``Authorization: Bearer <auth_config["token"]>``.
        Empty credentials are omitted.
        """
        headers = {"Accept": "application/json", "User-Agent": "SmartUp/1.0"}
        if self.auth_type == "api_key":
            key = self.auth_config.get("key") or self.auth_config.get("api_key") or ""
            header = self.auth_config.get("header") or "x-api-key"
            if key:
                headers[header] = key
        elif self.auth_type == "bearer":
            token = self.auth_config.get("token") or ""
            if token:
                headers["Authorization"] = f"Bearer {token}"
        return headers

    def _request(self, method: str, path: str, body: Any = None) -> Any:
        """Send *body* as JSON and return the decoded JSON reply.

        Returns ``None`` when the response body is empty.

        Raises:
            httpx.HTTPStatusError: on an error status (via ``raise_for_status``).
            WebsiteError: when a non-JSON response body starts with ``<``
                (an HTML page, e.g. a login or error page).
        """
        resp = self._client.request(method, self._url(path), json=body, headers=self._headers())
        resp.raise_for_status()
        if not resp.content:
            return None
        text = resp.text
        if "application/json" not in resp.headers.get("content-type", "") and text.lstrip().startswith("<"):
            raise WebsiteError(f"{method} {path} returned HTML, not JSON")
        return resp.json()

    def get_groups(self, endpoint: str = "/groups") -> list[dict[str, Any]]:
        """Fetch and normalize the group list.

        Tries *endpoint* first and falls back to ``/groups/all``.

        Raises:
            WebsiteError: when every endpoint fails. The message lists each
                endpoint's error, and the last error is chained as the cause.
        """
        errors: list[str] = []
        last_exc: Exception | None = None
        # dict.fromkeys de-duplicates while preserving order, so passing
        # "/groups/all" as *endpoint* does not request the same URL twice.
        for path in dict.fromkeys([endpoint, "/groups/all"]):
            try:
                return normalize_groups(self._request("GET", path))
            except Exception as exc:  # best-effort fallback to the next endpoint
                errors.append(f"{path}: {exc}")
                last_exc = exc
        raise WebsiteError("; ".join(errors)) from last_exc

    def update_group_rate(self, endpoint_template: str, group_id: str, rate: Decimal) -> Any:
        """PUT a new ``rate_multiplier`` for one group.

        ``{id}`` in *endpoint_template* is replaced by the URL-encoded group
        id. The id is coerced with ``str()``, so numeric ids work too. The
        rate is sent as a JSON number.
        """
        path = endpoint_template.replace("{id}", quote(str(group_id), safe=""))
        return self._request("PUT", path, {"rate_multiplier": float(rate)})
|