6044b00685
- 上游 Key 命名改为 {prefix}-{upstream.id}-{safe_group_name}-{group_id}
- 唯一约束 (upstream_id, group_id, managed_prefix) 加 managed_prefix 列
- 上游检测成功时同步 Key 状态,远端已删/分组已删自动清理
- 重复分组导入跳过,目标网站已存在同名分组返回 exists
- 账号导入平台自动识别(auto/manual 模式)
- 全选可导入 Key 按钮 + 目标分组自动匹配
- 导入幂等:已导入过的 Key 校验远端账号,不存在则重建
- 新增同步接口 POST /sync-imported-upstream-keys
- account_exists() 通过拉取账号列表判断,避免 404 误判
- credentials.base_url 注入来源上游地址,避免 401
- 前端导入弹窗自动同步+刷新按钮+并发/优先级设置
- 新增 12 个测试覆盖同步、幂等、远端删除、校验失败路径
282 lines
11 KiB
Python
282 lines
11 KiB
Python
from __future__ import annotations
|
||
|
||
import logging
|
||
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
|
||
from typing import Any
|
||
from urllib.parse import quote
|
||
|
||
import httpx
|
||
|
||
from app.utils.number import decimal_string
|
||
|
||
# Module-level logger. The error-translation helpers and the client log raw
# HTTP/connection details here, while users only see the friendly messages.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class WebsiteError(RuntimeError):
    """Error raised by the target-website client and the rate helpers in this module.

    The message is the user-facing text built by this module; when the error
    wraps an httpx failure, the original exception is chained as ``__cause__``.
    """
|
||
|
||
|
||
def _friendly_http_error(exc: httpx.HTTPStatusError) -> str:
    """Translate an HTTP status error into a friendly (Chinese) message.

    The raw status, URL and exception text are written to the log first, so
    the returned message can stay short.

    Args:
        exc: The ``HTTPStatusError`` raised by ``raise_for_status()``.

    Returns:
        A message specific to 401, 403, 404 or any 5xx status, otherwise a
        generic message that includes the status code.
    """
    code = exc.response.status_code
    request_url = exc.request.url if exc.request else "?"
    logger.warning("website_client HTTP %s from %s: %s", code, request_url, exc)

    # Statuses whose message does not depend on the response.
    fixed_messages = {
        401: "目标网站认证失败,请检查 Admin API Key / JWT 是否正确",
        403: "目标网站权限不足,请检查当前凭证是否有分组管理权限",
    }
    if code in fixed_messages:
        return fixed_messages[code]
    if code == 404:
        # Include the path so a wrong API prefix / endpoint is easy to spot.
        return f"目标网站接口不存在,请检查 API Prefix 和分组接口路径({exc.response.url.path})"
    if 500 <= code < 600:
        return "目标网站服务异常,请稍后重试"
    return f"目标网站返回错误(HTTP {code})"
|
||
|
||
|
||
def _friendly_connection_error(exc: Exception) -> str:
    """Translate a network-level failure into a friendly (Chinese) message.

    The raw exception is logged first. Timeouts and connect errors get
    dedicated messages (checked in that order). Anything else falls back to a
    generic message that embeds the exception text.
    """
    logger.warning("website_client connection error: %s", exc)
    known_failures = (
        (httpx.TimeoutException, "目标网站请求超时,请检查网络连接和 API 地址是否正确"),
        (httpx.ConnectError, "无法连接目标网站,请检查 API 地址和网络连通性"),
    )
    for failure_type, message in known_failures:
        if isinstance(exc, failure_type):
            return message
    return f"目标网站通信异常:{exc}"
|
||
|
||
|
||
def parse_positive_decimal(value: Any) -> Decimal | None:
    """Parse *value* as a strictly positive, finite ``Decimal``.

    Parsing goes through ``str(value)``, so a float such as ``0.1`` becomes
    ``Decimal("0.1")`` rather than its full binary expansion.

    Args:
        value: Any value whose ``str()`` may be a decimal literal.

    Returns:
        The parsed ``Decimal`` when it is finite and greater than zero.
        Otherwise None: for None, ``""``, unparseable text, zero or negative
        numbers, and non-finite values (NaN, sNaN, +/-Infinity).
    """
    if value is None or value == "":
        return None
    try:
        d = Decimal(str(value))
    except (InvalidOperation, ValueError):
        return None
    # Ordering comparisons on NaN raise InvalidOperation under the default
    # decimal context. Infinity would later make quantize() fail. Neither is
    # a usable rate, so treat them like any other invalid input.
    if not d.is_finite():
        return None
    return d if d > 0 else None
|
||
|
||
|
||
def calculate_target_rate(values: list[Any], percent: Any = 0, algorithm: str = "max_plus_percent") -> Decimal:
    """Compute a target rate from upstream rates plus a percentage markup.

    Args:
        values: Candidate upstream rates. Entries that do not parse as a
            positive, finite decimal (per ``parse_positive_decimal``) are ignored.
        percent: Markup in percent (``10`` means x1.10). None, ``""`` and
            ``0`` mean no markup.
        algorithm: ``"max_plus_percent"`` (default), ``"min_plus_percent"``
            or ``"average_plus_percent"``. This picks the base rate the
            markup is applied to.

    Returns:
        ``base * (1 + percent / 100)`` quantized to 4 decimal places with
        ROUND_HALF_UP.

    Raises:
        WebsiteError: if no usable rate remains, the algorithm is unknown,
            or the percent is unparseable, non-finite or negative.
    """
    # The is_finite() guard keeps Infinity away from quantize(), which would
    # otherwise raise a raw InvalidOperation.
    rates = [
        rate
        for rate in (parse_positive_decimal(v) for v in values)
        if rate is not None and rate.is_finite()
    ]
    if not rates:
        raise WebsiteError("没有可用的正数上游倍率")
    if algorithm == "average_plus_percent":
        base = sum(rates, Decimal("0")) / Decimal(len(rates))
    elif algorithm == "min_plus_percent":
        base = min(rates)
    elif algorithm == "max_plus_percent":
        base = max(rates)
    else:
        raise WebsiteError(f"不支持的算法:{algorithm}")
    # Callers handle WebsiteError, so bad user input must not escape as a
    # raw decimal.InvalidOperation. NaN would also make `pct < 0` raise.
    try:
        pct = Decimal(str(percent or 0))
    except (InvalidOperation, ValueError) as exc:
        raise WebsiteError(f"百分比无效:{percent}") from exc
    if not pct.is_finite():
        raise WebsiteError(f"百分比无效:{percent}")
    if pct < 0:
        raise WebsiteError("百分比不能为负数")
    return (base * (Decimal("1") + pct / Decimal("100"))).quantize(Decimal("0.0001"), rounding=ROUND_HALF_UP)
|
||
|
||
|
||
def _unwrap_data(value: Any) -> Any:
    """Peel common API envelopes off a decoded JSON payload.

    Non-dict input is returned unchanged. A dict with a ``data`` key is
    replaced by ``data`` when it looks like an envelope. That is the case
    when it also has ``code`` or ``message``, when ``data`` is a list, or
    when ``data`` is a dict containing ``items``/``groups``. If the result is
    still a dict, its ``items`` (or else ``groups``) entry is returned when
    present. Otherwise the dict itself is returned.
    """
    if not isinstance(value, dict):
        return value

    if "data" in value:
        inner = value.get("data")
        looks_like_envelope = (
            "code" in value
            or "message" in value
            or isinstance(inner, list)
            or (isinstance(inner, dict) and ("items" in inner or "groups" in inner))
        )
        if looks_like_envelope:
            value = inner
            if not isinstance(value, dict):
                return value

    for list_key in ("items", "groups"):
        if list_key in value:
            return value.get(list_key)
    return value
|
||
|
||
|
||
def _extract_id(value: Any) -> str:
    """Return the first id-like field found in *value*, as a string.

    On a dict, the direct keys id / account_id / accountId / group_id /
    groupId are checked first. The first one whose value is not None wins,
    even if that value is ``""``. Otherwise the sub-objects data / result /
    account / group are searched depth-first, in that order, and the first
    non-empty result is returned. Non-dict input, or no match, yields ``""``.
    """
    if not isinstance(value, dict):
        return ""

    direct_keys = ("id", "account_id", "accountId", "group_id", "groupId")
    direct = next((value.get(k) for k in direct_keys if value.get(k) is not None), None)
    if direct is not None:
        return str(direct)

    wrapper_keys = ("data", "result", "account", "group")
    # The generator is lazy, so recursion stops at the first non-empty hit.
    return next(
        (nested for nested in (_extract_id(value.get(k)) for k in wrapper_keys) if nested),
        "",
    )
|
||
|
||
|
||
def _first_present(item: dict[str, Any], keys: tuple[str, ...]) -> Any:
    """Return the first value under *keys* that is neither None nor ``""`` (else None)."""
    for key in keys:
        candidate = item.get(key)
        if candidate is not None and candidate != "":
            return candidate
    return None


def normalize_groups(value: Any) -> list[dict[str, Any]]:
    """Normalize a group-list payload into a uniform list of group dicts.

    The payload is first unwrapped with ``_unwrap_data``. A dict that remains
    is treated as a mapping and its values are used as the items. Each
    resulting group has the keys ``id`` (str), ``name`` (str),
    ``rate_multiplier`` (``decimal_string(...)`` output or None) and ``raw``
    (the source item).

    - A non-empty string item becomes a group whose id and name are that string.
    - A dict item takes its id from the first of id / group_id / groupId /
      name / group_name and its rate from rate_multiplier / rateMultiplier /
      ratio. Only None or ``""`` count as missing, so a value of ``0`` is kept.
    - Items of any other type, or dicts without a usable id, are skipped.

    Raises:
        WebsiteError: if the payload does not contain a list (or dict) of groups.
    """
    raw = _unwrap_data(value)
    if isinstance(raw, dict):
        raw = list(raw.values())
    if not isinstance(raw, list):
        raise WebsiteError("分组接口没有返回列表")
    groups: list[dict[str, Any]] = []
    for item in raw:
        if isinstance(item, str):
            if not item:
                # An empty string cannot identify a group.
                continue
            groups.append({"id": item, "name": item, "rate_multiplier": None, "raw": {"id": item, "name": item}})
            continue
        if not isinstance(item, dict):
            continue
        # Explicit None/"" checks instead of an `or` chain: with `or`, an id
        # of 0 was replaced by the name, and an empty group_name produced an
        # id of "".
        gid = _first_present(item, ("id", "group_id", "groupId", "name", "group_name"))
        if gid is None:
            continue
        name = _first_present(item, ("name", "group_name"))
        rate = _first_present(item, ("rate_multiplier", "rateMultiplier", "ratio"))
        groups.append({
            "id": str(gid),
            "name": str(name) if name is not None else str(gid),
            "rate_multiplier": decimal_string(rate) if rate is not None else None,
            "raw": item,
        })
    return groups
|
||
|
||
|
||
class Sub2ApiWebsiteClient:
    """Synchronous HTTP client for the target website's admin API.

    Every request goes to ``{base_url}/{api_prefix}/{path}``. Failures in
    ``_request`` surface as ``WebsiteError`` with a friendly message, and the
    original httpx / JSON exception is chained as ``__cause__``. Use the
    client as a context manager, or call ``close()``, to release the
    underlying ``httpx.Client``.
    """

    def __init__(
        self,
        base_url: str,
        api_prefix: str,
        auth_type: str,
        auth_config: dict[str, Any],
        timeout: float = 30.0,
    ) -> None:
        """Store connection settings and open an ``httpx.Client``.

        Args:
            base_url: Site root. A trailing "/" is dropped.
            api_prefix: Path prefix placed before every endpoint. Surrounding
                "/" are dropped, and "" means no prefix.
            auth_type: ``"api_key"`` or ``"bearer"``. Any other value sends no
                auth header.
            auth_config: Credentials for *auth_type* (see ``_headers``). None
                is tolerated and treated as an empty mapping.
            timeout: Request timeout in seconds, passed to httpx.
        """
        self.base_url = base_url.rstrip("/")
        self.api_prefix = api_prefix.strip("/")
        self.auth_type = auth_type
        # Guard against None so _headers() cannot fail on `.get`.
        self.auth_config = auth_config if auth_config is not None else {}
        self.timeout = timeout
        self._client = httpx.Client(timeout=timeout)

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()

    def __enter__(self) -> Sub2ApiWebsiteClient:
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()

    def _url(self, path: str) -> str:
        """Join base URL, optional API prefix and *path* (leading "/" on *path* ignored)."""
        prefix = f"/{self.api_prefix}" if self.api_prefix else ""
        return f"{self.base_url}{prefix}/{path.lstrip('/')}"

    def _headers(self) -> dict[str, str]:
        """Build request headers, including authentication.

        The ``api_key`` auth type sends ``auth_config["key"]`` (or ``["api_key"]``)
        under the header named by ``auth_config["header"]``, which defaults to
        ``x-api-key``. The ``bearer`` auth type sends
        ``Authorization: Bearer <auth_config["token"]>``. Empty credentials
        add no auth header.
        """
        headers = {"Accept": "application/json", "User-Agent": "SmartUp/1.0"}
        if self.auth_type == "api_key":
            key = self.auth_config.get("key") or self.auth_config.get("api_key") or ""
            header = self.auth_config.get("header") or "x-api-key"
            if key:
                headers[header] = key
        elif self.auth_type == "bearer":
            token = self.auth_config.get("token") or ""
            if token:
                headers["Authorization"] = f"Bearer {token}"
        return headers

    def _request(self, method: str, path: str, body: Any = None) -> Any:
        """Send one request and return the decoded JSON body.

        Returns:
            The parsed JSON, or None when the response body is empty.

        Raises:
            WebsiteError: on a non-2xx status, any httpx transport failure, an
                HTML page instead of an API response, or a body that is not
                valid JSON.
        """
        try:
            resp = self._client.request(method, self._url(path), json=body, headers=self._headers())
            resp.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise WebsiteError(_friendly_http_error(exc)) from exc
        except httpx.RequestError as exc:
            # RequestError covers timeouts and connect errors, plus read/write
            # and protocol errors and too-many-redirects. Previously those
            # escaped to callers as raw httpx exceptions.
            raise WebsiteError(_friendly_connection_error(exc)) from exc
        if not resp.content:
            return None
        text = resp.text
        if "application/json" not in resp.headers.get("content-type", "") and text.lstrip().startswith("<"):
            raise WebsiteError(f"{method} {path} 返回了 HTML,请检查接口地址是否正确")
        try:
            return resp.json()
        except ValueError as exc:
            # json.JSONDecodeError / UnicodeDecodeError are ValueError subclasses.
            raise WebsiteError(f"{method} {path} 返回的内容不是有效的 JSON") from exc

    def get_groups(self, endpoint: str = "/groups") -> list[dict[str, Any]]:
        """Fetch and normalize the group list.

        Tries *endpoint* first, then falls back to ``/groups/all``. The
        fallback is skipped when *endpoint* already is that path.
        Authentication or permission failures are re-raised immediately,
        because another path cannot fix them.

        Raises:
            WebsiteError: on an auth/permission failure, or when every
                candidate path failed (the message lists the paths tried).
        """
        candidates = [endpoint]
        if endpoint.lstrip("/") != "groups/all":
            candidates.append("/groups/all")
        last_error: Exception | None = None
        tried_paths: list[str] = []
        for path in candidates:
            tried_paths.append(path)
            try:
                return normalize_groups(self._request("GET", path))
            except WebsiteError as exc:
                msg = str(exc)
                # Auth/permission errors: the fallback path would fail the same way.
                if "认证失败" in msg or "权限不足" in msg:
                    raise
                # 404/5xx or an unexpected payload: try the next path.
                last_error = exc
            except Exception as exc:
                last_error = exc
                logger.info("get_groups fallback %s failed: %s", path, exc)
        msg = str(last_error) if last_error else "拉取分组失败"
        raise WebsiteError(f"{msg}(尝试接口:{'、'.join(tried_paths)})")

    def update_group_rate(self, endpoint_template: str, group_id: str, rate: Decimal) -> Any:
        """PUT a new ``rate_multiplier`` for one group.

        ``{id}`` in *endpoint_template* is replaced with the URL-encoded group
        id. ``safe=""`` ensures a "/" inside the id cannot alter the path. The
        rate is sent as a float, presumably because the JSON encoder cannot
        serialize ``Decimal``.
        """
        # str() so integer ids do not make quote() raise TypeError.
        path = endpoint_template.replace("{id}", quote(str(group_id), safe=""))
        return self._request("PUT", path, {"rate_multiplier": float(rate)})

    def _post_object(self, endpoint: str, body: dict[str, Any]) -> dict[str, Any]:
        """POST *body* and return the unwrapped payload, wrapped as ``{"value": ...}`` if not a dict."""
        data = _unwrap_data(self._request("POST", endpoint, body))
        return data if isinstance(data, dict) else {"value": data}

    def create_group(self, body: dict[str, Any], endpoint: str = "/groups") -> dict[str, Any]:
        """Create a group on the target website and return the created object."""
        return self._post_object(endpoint, body)

    def create_account(self, body: dict[str, Any], endpoint: str = "/accounts") -> dict[str, Any]:
        """Create an account on the target website and return the created object."""
        return self._post_object(endpoint, body)

    @staticmethod
    def _unwrap_list(value: Any) -> list | None:
        """Find the item list inside a list-endpoint payload.

        Returns *value* itself if it is a list. Otherwise it returns the first
        list found under items / accounts / records / list / data at the top
        level, then under items / records / list / data / accounts inside a
        dict-valued ``data``. Only these two levels are inspected (no deeper
        recursion). Returns None when no list is found or *value* is neither
        a list nor a dict.
        """
        if isinstance(value, list):
            return value
        if not isinstance(value, dict):
            return None
        for key in ("items", "accounts", "records", "list", "data"):
            v = value.get(key)
            if isinstance(v, list):
                return v
        data_val = value.get("data")
        if isinstance(data_val, dict):
            for key in ("items", "records", "list", "data", "accounts"):
                v = data_val.get(key)
                if isinstance(v, list):
                    return v
        return None

    def _get_account_ids(self, endpoint: str = "/accounts") -> set[str] | None:
        """Fetch the remote account list and return the set of account ids.

        Returns:
            The ids as strings (possibly an empty set). Returns None when the
            request fails or the payload has no recognizable list; the failure
            is logged, not raised.

        NOTE(review): this issues a single GET with no paging parameters. If
        the target endpoint paginates, accounts beyond the first page are
        missing from the set, and account_exists() would report them as
        deleted. Confirm that the endpoint returns the full list.
        """
        try:
            resp = self._request("GET", endpoint)
        except Exception:
            logger.warning("account list fetch failed for %s", endpoint, exc_info=True)
            return None
        items = self._unwrap_list(resp)
        if items is None:
            logger.warning("account list unexpected format for %s", endpoint)
            return None
        return {item_id for item_id in (self.extract_id(item) for item in items) if item_id}

    def account_exists(self, account_id: str, endpoint: str = "/accounts") -> bool | None:
        """Check whether the target account still exists.

        The decision is made from the remote account list rather than a
        per-id lookup.

        Returns:
            True if the id is in the list, and False if the list was fetched
            but the id is absent (treated as deleted). Returns None if the list
            could not be fetched or parsed; callers should not clean up local
            state in that case.
        """
        ids = self._get_account_ids(endpoint)
        if ids is None:
            logger.warning("account_exists cannot verify %s: list fetch failed", account_id)
            return None
        # ids are strings; an int id would otherwise never match and be
        # misreported as deleted.
        return str(account_id) in ids

    @staticmethod
    def extract_id(value: Any) -> str:
        """Return the id found in an API object (see module-level ``_extract_id``)."""
        return _extract_id(value)
|