feat: sync upstream keys and reorder priorities
This commit is contained in:
@@ -23,6 +23,9 @@ from app.schemas.website import (
|
||||
ImportGroupItem,
|
||||
ImportGroupsRequest,
|
||||
ImportGroupsResponse,
|
||||
ReorderPriorityItem,
|
||||
ReorderPriorityRequest,
|
||||
ReorderPriorityResponse,
|
||||
SyncImportStatusRequest,
|
||||
TestResult,
|
||||
WebsiteCreate,
|
||||
@@ -34,7 +37,13 @@ from app.schemas.website import (
|
||||
)
|
||||
|
||||
from app.services.website_client import Sub2ApiWebsiteClient
|
||||
from app.services.website_sync import binding_sources, sync_binding, build_rate_priority_map, reconcile_upstream_keys_full
|
||||
from app.services.website_sync import (
|
||||
binding_sources,
|
||||
sync_binding,
|
||||
build_rate_priority_map,
|
||||
reconcile_upstream_keys_full,
|
||||
sync_account_priorities_for_upstream,
|
||||
)
|
||||
from app.utils.auth import get_current_user
|
||||
|
||||
router = APIRouter(tags=["websites"])
|
||||
@@ -463,6 +472,44 @@ def sync_imported_upstream_keys(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/websites/{wid}/accounts/reorder-priority", response_model=ReorderPriorityResponse)
|
||||
def reorder_account_priorities(
|
||||
wid: int,
|
||||
body: ReorderPriorityRequest,
|
||||
db: Session = Depends(get_db),
|
||||
_=Depends(get_current_user),
|
||||
):
|
||||
website = db.query(Website).filter(Website.id == wid).first()
|
||||
if not website:
|
||||
raise HTTPException(404, "website not found")
|
||||
if website.site_type != "sub2api":
|
||||
raise HTTPException(400, "目前只支持 sub2api")
|
||||
|
||||
upstream = db.query(Upstream).filter(Upstream.id == body.upstream_id).first()
|
||||
if not upstream:
|
||||
raise HTTPException(404, "upstream not found")
|
||||
|
||||
results = sync_account_priorities_for_upstream(db, body.upstream_id, website_id=wid)
|
||||
failed_count = sum(1 for item in results if item.get("status") == "failed")
|
||||
success_count = sum(1 for item in results if item.get("status") == "success")
|
||||
skipped_count = sum(1 for item in results if item.get("status") == "skipped")
|
||||
|
||||
parts = []
|
||||
if success_count:
|
||||
parts.append(f"更新 {success_count}")
|
||||
if failed_count:
|
||||
parts.append(f"失败 {failed_count}")
|
||||
if skipped_count:
|
||||
parts.append(f"跳过 {skipped_count}")
|
||||
message = "、".join(parts) + f" / 共 {len(results)} 个" if parts else "没有需要重排的账号"
|
||||
|
||||
return ReorderPriorityResponse(
|
||||
success=failed_count == 0,
|
||||
message=message,
|
||||
items=[ReorderPriorityItem(**item) for item in results],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/websites/{wid}/accounts/import-upstream-keys", response_model=ImportAccountsResponse)
|
||||
def import_upstream_keys_as_accounts(
|
||||
wid: int,
|
||||
|
||||
@@ -180,6 +180,26 @@ class ImportAccountsResponse(BaseModel):
|
||||
items: list[ImportAccountItem]
|
||||
|
||||
|
||||
class ReorderPriorityRequest(BaseModel):
|
||||
upstream_id: int = Field(gt=0)
|
||||
|
||||
|
||||
class ReorderPriorityItem(BaseModel):
|
||||
account_id: Optional[str] = None
|
||||
group_id: str = ""
|
||||
upstream_id: int
|
||||
old_priority: Optional[int] = None
|
||||
new_priority: Optional[int] = None
|
||||
status: str
|
||||
message: str = ""
|
||||
|
||||
|
||||
class ReorderPriorityResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
items: list[ReorderPriorityItem]
|
||||
|
||||
|
||||
class WebsiteBatchSyncResponse(BaseModel):
|
||||
total: int
|
||||
success: int
|
||||
|
||||
@@ -407,17 +407,25 @@ class UpstreamClient:
|
||||
out["group_name"] = str(out.get("group"))
|
||||
return out
|
||||
|
||||
def _list_new_api_tokens(
|
||||
self,
|
||||
search: str = "",
|
||||
group_id: str | int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
if search:
|
||||
path = "/api/token/search"
|
||||
params = {"keyword": search, "token": "", "p": 1, "size": 100}
|
||||
else:
|
||||
path = "/api/token/"
|
||||
params = {"p": 1, "size": 100}
|
||||
@staticmethod
|
||||
def _extract_new_api_token_items(payload: Any) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
nested = _unwrap_data(payload)
|
||||
meta: dict[str, Any] = {}
|
||||
items: list[dict[str, Any]] | None = None
|
||||
if isinstance(nested, list):
|
||||
items = [i for i in nested if isinstance(i, dict)]
|
||||
elif isinstance(nested, dict):
|
||||
meta = nested
|
||||
for key in ("items", "tokens", "list", "records"):
|
||||
value = nested.get(key)
|
||||
if isinstance(value, list):
|
||||
items = [i for i in value if isinstance(i, dict)]
|
||||
break
|
||||
if items is None:
|
||||
raise UpstreamError("unexpected New-API token list response")
|
||||
return items, meta
|
||||
|
||||
def _request_new_api_token_list(self, path: str, params: dict[str, Any]) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
resp = self._client.request(
|
||||
"GET",
|
||||
self._url(path),
|
||||
@@ -429,23 +437,80 @@ class UpstreamClient:
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
self._ensure_api_success(data, "list New-API tokens")
|
||||
nested = _unwrap_data(data)
|
||||
items: list[dict[str, Any]] | None = None
|
||||
if isinstance(nested, list):
|
||||
items = [i for i in nested if isinstance(i, dict)]
|
||||
elif isinstance(nested, dict):
|
||||
for key in ("items", "tokens", "list", "records"):
|
||||
value = nested.get(key)
|
||||
if isinstance(value, list):
|
||||
items = [i for i in value if isinstance(i, dict)]
|
||||
break
|
||||
if items is None:
|
||||
raise UpstreamError("unexpected New-API token list response")
|
||||
normalized = [self._normalize_key_record(i) for i in items]
|
||||
return self._extract_new_api_token_items(data)
|
||||
|
||||
def _list_all_new_api_tokens(self, page_size: int = 100, max_pages: int = 20) -> list[dict[str, Any]]:
|
||||
all_items: list[dict[str, Any]] = []
|
||||
for page in range(1, max_pages + 1):
|
||||
items, meta = self._request_new_api_token_list(
|
||||
"/api/token/",
|
||||
{"p": page, "size": page_size},
|
||||
)
|
||||
all_items.extend(items)
|
||||
total = meta.get("total") if isinstance(meta, dict) else None
|
||||
if isinstance(total, int) and len(all_items) >= total:
|
||||
break
|
||||
if len(items) < page_size:
|
||||
break
|
||||
return [self._normalize_key_record(i) for i in all_items]
|
||||
|
||||
@staticmethod
|
||||
def _matches_new_api_token_search(record: dict[str, Any], search: str) -> bool:
|
||||
if not search:
|
||||
return True
|
||||
needle = search.strip()
|
||||
if not needle:
|
||||
return True
|
||||
name = str(record.get("name") or record.get("key_name") or "")
|
||||
key = str(record.get("key") or record.get("api_key") or record.get("apiKey") or record.get("token") or "")
|
||||
return name.startswith(needle) or name == needle or key == needle
|
||||
|
||||
def _hydrate_new_api_token_key(self, record: dict[str, Any]) -> dict[str, Any]:
|
||||
out = dict(record)
|
||||
key_value = _extract_key_value(out)
|
||||
if key_value and "*" not in key_value:
|
||||
out["key"] = key_value
|
||||
out["masked_key"] = out.get("masked_key") or mask_secret(key_value)
|
||||
return out
|
||||
token_id = out.get("id")
|
||||
if token_id is None:
|
||||
return out
|
||||
try:
|
||||
plaintext = self._get_new_api_token_key(token_id)
|
||||
except Exception:
|
||||
return out
|
||||
out["key"] = plaintext
|
||||
out["masked_key"] = mask_secret(plaintext)
|
||||
return out
|
||||
|
||||
def _list_new_api_tokens(
|
||||
self,
|
||||
search: str = "",
|
||||
group_id: str | int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
normalized: list[dict[str, Any]] = []
|
||||
if search:
|
||||
search_items, _ = self._request_new_api_token_list(
|
||||
"/api/token/search",
|
||||
{"keyword": search, "token": "", "p": 1, "size": 100},
|
||||
)
|
||||
normalized.extend(self._normalize_key_record(i) for i in search_items)
|
||||
|
||||
all_tokens = self._list_all_new_api_tokens()
|
||||
seen_ids = {str(i.get("id")) for i in normalized if i.get("id") is not None}
|
||||
for item in all_tokens:
|
||||
item_id = item.get("id")
|
||||
if item_id is not None and str(item_id) in seen_ids:
|
||||
continue
|
||||
if self._matches_new_api_token_search(item, search):
|
||||
normalized.append(item)
|
||||
if item_id is not None:
|
||||
seen_ids.add(str(item_id))
|
||||
|
||||
if group_id is not None:
|
||||
gid = str(group_id)
|
||||
normalized = [i for i in normalized if str(i.get("group_id") or i.get("group") or "") == gid]
|
||||
return normalized
|
||||
return [self._hydrate_new_api_token_key(i) for i in normalized]
|
||||
|
||||
def _get_new_api_token_key(self, token_id: str | int) -> str:
|
||||
payload = self._request("POST", f"/api/token/{token_id}/key")
|
||||
|
||||
@@ -18,6 +18,15 @@ from app.services import webhook_service
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
PRIORITY_BASE = 1
|
||||
PRIORITY_STEP = 10
|
||||
|
||||
|
||||
def priority_for_rate_rank(rank: int) -> int:
|
||||
"""Convert a zero-based sorted rate rank to an account priority."""
|
||||
return PRIORITY_BASE + rank * PRIORITY_STEP
|
||||
|
||||
|
||||
def binding_sources(binding: WebsiteGroupBinding) -> list[dict[str, Any]]:
|
||||
try:
|
||||
data = json.loads(binding.source_groups_json or "[]")
|
||||
@@ -187,7 +196,7 @@ def build_rate_priority_map(db: Session, upstream_ids: set[int]) -> dict[str, in
|
||||
|
||||
使用 (upstream_id, group_id) 复合键避免不同上游的同名分组互相覆盖。
|
||||
遍历所有涉及的上游的最新快照,收集分组的倍率,按倍率升序排列后赋值 priority。
|
||||
倍率最低的 priority=1,次低的 priority=2,以此类推。相同倍率的分组共享同一 priority。
|
||||
倍率最低的 priority=1,次低的 priority=11,以此类推。相同倍率的分组共享同一 priority。
|
||||
"""
|
||||
group_rates: dict[str, float] = {}
|
||||
for uid in upstream_ids:
|
||||
@@ -199,7 +208,7 @@ def build_rate_priority_map(db: Session, upstream_ids: set[int]) -> dict[str, in
|
||||
key = f"{uid}:{gid}"
|
||||
group_rates[key] = rate
|
||||
unique_rates = sorted(set(group_rates.values()))
|
||||
rate_to_priority = {rate: idx + 1 for idx, rate in enumerate(unique_rates)}
|
||||
rate_to_priority = {rate: priority_for_rate_rank(idx) for idx, rate in enumerate(unique_rates)}
|
||||
return {key: rate_to_priority[rate] for key, rate in group_rates.items()}
|
||||
|
||||
|
||||
@@ -285,27 +294,30 @@ def _try_send_priority_webhook(
|
||||
logger.warning("account_priority_changed webhook failed for website %s: %s", wid, exc)
|
||||
|
||||
|
||||
def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[dict]:
|
||||
def sync_account_priorities_for_upstream(
|
||||
db: Session,
|
||||
upstream_id: int,
|
||||
website_id: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""上游倍率变化后,自动更新已导入下游账号的 priority。
|
||||
|
||||
只处理同一目标分组内有多个账号(存在竞争)的情况:
|
||||
- 竞争分组键:imported_target_group_id(老数据 fallback 到 group_id)
|
||||
- 同一竞争分组内按倍率升序排序,priority 从 1 开始(相同倍率共享)
|
||||
- 同一竞争分组内按倍率升序排序,priority 从 1 开始,每档间隔 10(相同倍率共享)
|
||||
- 单账号分组:完全跳过,不调用 update_account,不发通知
|
||||
- 无竞争分组:直接返回,不写日志,不发通知
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
key_rows = (
|
||||
db.query(UpstreamGeneratedKey)
|
||||
.filter(
|
||||
UpstreamGeneratedKey.upstream_id == upstream_id,
|
||||
UpstreamGeneratedKey.imported_website_id.isnot(None),
|
||||
UpstreamGeneratedKey.imported_account_id.isnot(None),
|
||||
UpstreamGeneratedKey.status != "orphaned",
|
||||
)
|
||||
.all()
|
||||
key_query = db.query(UpstreamGeneratedKey).filter(
|
||||
UpstreamGeneratedKey.upstream_id == upstream_id,
|
||||
UpstreamGeneratedKey.imported_website_id.isnot(None),
|
||||
UpstreamGeneratedKey.imported_account_id.isnot(None),
|
||||
UpstreamGeneratedKey.status != "orphaned",
|
||||
)
|
||||
if website_id is not None:
|
||||
key_query = key_query.filter(UpstreamGeneratedKey.imported_website_id == website_id)
|
||||
key_rows = key_query.all()
|
||||
if not key_rows:
|
||||
return []
|
||||
|
||||
@@ -387,7 +399,7 @@ def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[
|
||||
continue
|
||||
# 组内按倍率升序排序(倍率低 → priority 小 → 优先)
|
||||
unique_rates = sorted(set(r for _, r in rated))
|
||||
rate_to_prio = {rate: idx + 1 for idx, rate in enumerate(unique_rates)}
|
||||
rate_to_prio = {rate: priority_for_rate_rank(idx) for idx, rate in enumerate(unique_rates)}
|
||||
for row, rate in rated:
|
||||
priority_assignment[row.imported_account_id] = rate_to_prio[rate]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user