896 lines
36 KiB
Python
896 lines
36 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
from datetime import datetime, timezone
|
||
from typing import List
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.database import get_db
|
||
from app.models.snapshot import UpstreamRateSnapshot
|
||
from app.models.upstream import Upstream
|
||
from app.models.upstream_key import UpstreamGeneratedKey
|
||
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
|
||
from app.schemas.website import (
|
||
BindingCreate,
|
||
BindingResponse,
|
||
BindingUpdate,
|
||
ImportAccountItem,
|
||
ImportAccountsRequest,
|
||
ImportAccountsResponse,
|
||
ImportGroupItem,
|
||
ImportGroupsRequest,
|
||
ImportGroupsResponse,
|
||
ReorderPriorityItem,
|
||
ReorderPriorityRequest,
|
||
ReorderPriorityResponse,
|
||
SyncImportStatusRequest,
|
||
TestResult,
|
||
WebsiteCreate,
|
||
WebsiteGroupResponse,
|
||
WebsiteResponse,
|
||
WebsiteSyncLogResponse,
|
||
WebsiteUpdate,
|
||
WebsiteBatchSyncResponse,
|
||
)
|
||
|
||
from app.services.website_client import Sub2ApiWebsiteClient
|
||
from app.services.website_sync import (
|
||
binding_sources,
|
||
sync_binding,
|
||
build_rate_priority_map,
|
||
reconcile_upstream_keys_full,
|
||
sync_account_priorities_for_upstream,
|
||
)
|
||
from app.utils.auth import get_current_user
|
||
|
||
logger = logging.getLogger(__name__)
router = APIRouter(tags=["websites"])

# Placeholder sent to clients in place of secret auth-config values; update
# requests that echo it back leave the stored secret unchanged.
MASK = "***"
# Auth-config key names (compared lower-cased) whose values get masked.
SECRET_KEYS = {"password", "token", "key", "secret", "api_key"}
# Rate-derivation algorithms accepted for group bindings.
ALGORITHMS = {"max_plus_percent", "average_plus_percent", "min_plus_percent"}
|
||
|
||
|
||
def _mask(cfg: dict) -> dict:
    """Return a copy of *cfg* with non-empty secret values replaced by ``MASK``.

    A key is treated as secret when its lower-cased name is in ``SECRET_KEYS``.
    Falsy values are copied through unchanged even for secret keys.
    """
    return {
        name: (MASK if name.lower() in SECRET_KEYS and value else value)
        for name, value in cfg.items()
    }
|
||
|
||
|
||
def _website_response(row: Website) -> WebsiteResponse:
    """Convert a ``Website`` row into its API schema.

    The stored auth-config JSON is decoded (``{}`` when empty) and passed
    through ``_mask`` so values of keys listed in ``SECRET_KEYS`` are masked.
    """
    passthrough = (
        "id",
        "name",
        "site_type",
        "base_url",
        "api_prefix",
        "auth_type",
        "groups_endpoint",
        "group_update_endpoint",
        "enabled",
        "auto_sync_enabled",
        "timeout_seconds",
        "last_status",
        "last_checked_at",
        "last_error",
        "created_at",
        "updated_at",
    )
    payload = {field: getattr(row, field) for field in passthrough}
    payload["auth_config_masked"] = _mask(json.loads(row.auth_config_json or "{}"))
    return WebsiteResponse(**payload)
|
||
|
||
|
||
def _binding_response(db: Session, row: WebsiteGroupBinding) -> BindingResponse:
    """Convert a binding row into its API schema.

    Issues one query per call to look up the owning website for
    ``website_name``; the name is ``""`` when that website no longer exists.
    ``percent`` is coerced to float (``0`` when empty).
    """
    owner = db.query(Website).filter(Website.id == row.website_id).first()
    owner_name = owner.name if owner is not None else ""
    copied = (
        "id",
        "website_id",
        "target_group_id",
        "target_group_name",
        "algorithm",
        "enabled",
        "created_at",
        "updated_at",
    )
    payload = {attr: getattr(row, attr) for attr in copied}
    payload.update(
        website_name=owner_name,
        source_groups=binding_sources(row),
        percent=float(row.percent or 0),
    )
    return BindingResponse(**payload)
|
||
|
||
|
||
def _log_response(row: WebsiteSyncLog) -> WebsiteSyncLogResponse:
    """Convert a stored sync-log row into its API schema.

    ``percent`` is coerced to float (``0`` when empty) and the stored JSON list
    of source rates is decoded (``[]`` when empty).
    """
    copied = (
        "id",
        "website_id",
        "binding_id",
        "target_group_id",
        "target_group_name",
        "algorithm",
        "old_rate",
        "new_rate",
        "status",
        "message",
        "created_at",
    )
    payload = {attr: getattr(row, attr) for attr in copied}
    payload["percent"] = float(row.percent or 0)
    payload["source_rates"] = json.loads(row.source_rates_json or "[]")
    return WebsiteSyncLogResponse(**payload)
|
||
|
||
|
||
def _ensure_unique_target(db: Session, website_id: int, target_group_id: str, exclude_id: int | None = None) -> None:
    """Reject a second binding for the same (website, target group) pair.

    ``exclude_id`` lets an update ignore the binding currently being edited.

    Raises:
        HTTPException: 400 when another binding already targets the group.
    """
    conditions = [
        WebsiteGroupBinding.website_id == website_id,
        WebsiteGroupBinding.target_group_id == target_group_id,
    ]
    if exclude_id is not None:
        conditions.append(WebsiteGroupBinding.id != exclude_id)
    duplicate = db.query(WebsiteGroupBinding).filter(*conditions).first()
    if duplicate:
        raise HTTPException(400, "同一目标网站分组只能维护一条绑定记录")
|
||
|
||
|
||
def _client(row: Website) -> Sub2ApiWebsiteClient:
    """Build an API client from a website row's stored connection settings.

    The auth-config JSON is decoded (``{}`` when empty) and the timeout is
    passed as a float.
    """
    settings = dict(
        base_url=row.base_url,
        api_prefix=row.api_prefix,
        auth_type=row.auth_type,
        auth_config=json.loads(row.auth_config_json or "{}"),
        timeout=float(row.timeout_seconds),
    )
    return Sub2ApiWebsiteClient(**settings)
|
||
|
||
|
||
def _latest_upstream_groups(db: Session, upstream_id: int) -> list[dict]:
    """Return the group dicts from the newest rate snapshot of an upstream.

    The snapshot's ``groups`` entry is expected to be a mapping; any other
    shape yields ``[]`` and non-dict entries are dropped.

    Raises:
        HTTPException: 404 when the upstream has no snapshot yet.
    """
    latest = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == upstream_id)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .first()
    )
    if latest is None:
        raise HTTPException(404, "no upstream snapshot found; run upstream check first")
    payload = json.loads(latest.snapshot_json or "{}")
    by_key = payload.get("groups") or {}
    if isinstance(by_key, dict):
        return [entry for entry in by_key.values() if isinstance(entry, dict)]
    return []
|
||
|
||
|
||
def _source_group_id(group: dict) -> str:
    """Return the first truthy of ``group_id``/``id``/``name`` as a string, else ``""``."""
    for field in ("group_id", "id", "name"):
        candidate = group.get(field)
        if candidate:
            return str(candidate)
    return ""
|
||
|
||
|
||
def _source_group_name(group: dict, gid: str) -> str:
    """Return the group's display name (``group_name`` then ``name``), falling back to *gid*."""
    label = group.get("group_name") or group.get("name")
    return str(label) if label else str(gid)
|
||
|
||
|
||
def _source_group_rate(group: dict) -> float:
    """Return the group's rate multiplier as a float.

    The first of ``rate``, ``default_rate`` and ``rate_multiplier`` that is
    present (not ``None`` and not ``""``) is used. A rate of ``0`` is kept,
    because it is a legitimate multiplier (for example, a free group). A value
    that cannot be parsed as a number, or no value at all, yields ``1.0``.
    """
    for key in ("rate", "default_rate", "rate_multiplier"):
        raw = group.get(key)
        # Explicit presence check rather than truthiness so a 0 rate is not
        # mistaken for "missing" and replaced by another key or the 1.0 default.
        if raw is None or raw == "":
            continue
        try:
            return float(raw)
        except (TypeError, ValueError):
            return 1.0
    return 1.0
|
||
|
||
|
||
def _numeric_group_id(value: str | None) -> int | None:
    """Parse a target group id as ``int``; return ``None`` if empty or non-numeric."""
    if value in (None, ""):
        return None
    try:
        parsed = int(value)
    except ValueError:
        parsed = None
    return parsed
|
||
|
||
|
||
def _build_rate_priority_map(db: Session, upstream_ids: set[int]) -> dict[str, int]:
    """Return account priorities derived from upstream group rates.

    Thin delegate to ``website_sync.build_rate_priority_map`` so the logic is
    not duplicated. Callers in this module look keys up as
    ``"{upstream_id}:{group_id}"``.
    """
    priority_by_group = build_rate_priority_map(db, upstream_ids)
    return priority_by_group
|
||
|
||
|
||
@router.get("/api/websites", response_model=List[WebsiteResponse])
def list_websites(db: Session = Depends(get_db), _=Depends(get_current_user)):
    """List all configured websites ordered by id (secrets masked)."""
    websites = db.query(Website).order_by(Website.id).all()
    return list(map(_website_response, websites))
|
||
|
||
|
||
@router.post("/api/websites", response_model=WebsiteResponse, status_code=201)
def create_website(body: WebsiteCreate, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Create a website entry; only ``sub2api`` sites are supported.

    ``base_url`` is stored without trailing slashes and the auth config is
    stored as JSON.

    Raises:
        HTTPException: 400 for an unsupported ``site_type``.
    """
    if body.site_type != "sub2api":
        raise HTTPException(400, "目前只支持 sub2api")
    columns = {
        "name": body.name,
        "site_type": body.site_type,
        "base_url": body.base_url.rstrip("/"),
        "api_prefix": body.api_prefix,
        "auth_type": body.auth_type,
        "auth_config_json": json.dumps(body.auth_config, ensure_ascii=False),
        "groups_endpoint": body.groups_endpoint,
        "group_update_endpoint": body.group_update_endpoint,
        "enabled": body.enabled,
        "auto_sync_enabled": body.auto_sync_enabled,
        "timeout_seconds": body.timeout_seconds,
    }
    website = Website(**columns)
    db.add(website)
    db.commit()
    db.refresh(website)
    return _website_response(website)
|
||
|
||
|
||
@router.put("/api/websites/{wid}", response_model=WebsiteResponse)
def update_website(wid: int, body: WebsiteUpdate, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Partially update a website; fields sent as ``None`` are left untouched.

    Submitted auth-config entries are merged into the stored config. Entries
    whose value is the ``MASK`` placeholder keep the stored secret.
    ``base_url`` is stored without trailing slashes.

    Raises:
        HTTPException: 404 if the website does not exist; 400 for a
            ``site_type`` other than ``sub2api``.
    """
    website = db.query(Website).filter(Website.id == wid).first()
    if website is None:
        raise HTTPException(404, "website not found")
    changes = body.model_dump(exclude_none=True)
    if changes.get("site_type", "sub2api") != "sub2api":
        raise HTTPException(400, "目前只支持 sub2api")
    if "auth_config" in changes:
        merged = json.loads(website.auth_config_json or "{}")
        submitted = changes.pop("auth_config")
        # Masked values are what the client received from us; keep the real secret.
        merged.update({name: val for name, val in submitted.items() if val != MASK})
        website.auth_config_json = json.dumps(merged, ensure_ascii=False)
    if "base_url" in changes:
        changes["base_url"] = changes["base_url"].rstrip("/")
    for attr, new_value in changes.items():
        setattr(website, attr, new_value)
    website.updated_at = datetime.now(timezone.utc)
    db.commit()
    db.refresh(website)
    return _website_response(website)
|
||
|
||
|
||
@router.delete("/api/websites/{wid}", status_code=204)
def delete_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Delete a website together with its sync logs and group bindings.

    Raises:
        HTTPException: 404 when the website does not exist.
    """
    website = db.query(Website).filter(Website.id == wid).first()
    if website is None:
        raise HTTPException(404, "website not found")
    # Dependent rows first: sync logs, then bindings, then the website itself.
    for model, owner_column in (
        (WebsiteSyncLog, WebsiteSyncLog.website_id),
        (WebsiteGroupBinding, WebsiteGroupBinding.website_id),
    ):
        db.query(model).filter(owner_column == wid).delete(synchronize_session=False)
    db.delete(website)
    db.commit()
|
||
|
||
|
||
@router.post("/api/websites/{wid}/test", response_model=TestResult)
def test_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Check connectivity by listing the website's groups, and record the result.

    Updates ``last_status`` (``healthy``/``unhealthy``), ``last_error`` and
    ``last_checked_at`` on the website row. The result is returned as a
    :class:`TestResult`.

    Raises:
        HTTPException: 404 when the website does not exist.
    """
    row = db.query(Website).filter(Website.id == wid).first()
    if not row:
        raise HTTPException(404, "website not found")
    # Only the remote call is guarded. A failure while persisting the health
    # status is a server error, not a failed connection to the website, and
    # must not trigger a second commit on a broken session.
    try:
        with _client(row) as c:
            group_count = len(c.get_groups(row.groups_endpoint))
    except Exception as exc:
        logger.warning("website %s connectivity test failed: %s", wid, exc)
        row.last_status = "unhealthy"
        row.last_error = str(exc)
        row.last_checked_at = datetime.now(timezone.utc)
        db.commit()
        return TestResult(success=False, message="连接失败", detail=str(exc))
    row.last_status = "healthy"
    row.last_error = None
    row.last_checked_at = datetime.now(timezone.utc)
    db.commit()
    return TestResult(success=True, message=f"连接成功,获取到 {group_count} 个分组")
|
||
|
||
|
||
@router.get("/api/websites/{wid}/groups", response_model=List[WebsiteGroupResponse])
def list_website_groups(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return the group list fetched live from the target website.

    Raises:
        HTTPException: 404 if the website does not exist; 502 carrying the
            underlying error text if fetching the groups fails.
    """
    row = db.query(Website).filter(Website.id == wid).first()
    if not row:
        raise HTTPException(404, "website not found")
    try:
        with _client(row) as c:
            return c.get_groups(row.groups_endpoint)
    except Exception as exc:
        logger.warning("listing groups failed for website %s: %s", wid, exc)
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(502, str(exc)) from exc
|
||
|
||
|
||
def _existing_target_group_names(website: Website) -> set[str]:
    """Return the group names already present on *website* (best effort).

    Any failure while listing is logged and yields an empty set. In that case
    duplicates are only detected through the create call's conflict response.
    """
    names: set[str] = set()
    try:
        with _client(website) as c:
            for eg in c.get_groups(website.groups_endpoint):
                gname = eg.get("name") or eg.get("group_name") or ""
                if gname:
                    names.add(gname)
    except Exception as exc:
        logger.warning("listing existing groups failed website=%s: %s", website.id, exc)
        return set()
    return names


def _is_conflict_error(message: str) -> bool:
    """Heuristically classify a create-group error message as "already exists"."""
    return (
        "已存在" in message
        or "already exists" in message.lower()
        or "409" in message
        or "Conflict" in message
    )


@router.post("/api/websites/{wid}/groups/import-from-upstream/{upstream_id}", response_model=ImportGroupsResponse)
def import_groups_from_upstream(
    wid: int,
    upstream_id: int,
    body: ImportGroupsRequest,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Create groups on website *wid* mirroring the latest snapshot of an upstream.

    Each snapshot group is considered (only those in ``body.group_ids`` when
    that is non-empty). It becomes a target group named
    ``body.name_prefix + source name``, with the source rate as
    ``rate_multiplier``. A target name is reported as ``exists`` instead of
    being created when it is:

    * already on the website, or
    * created or reported as conflicting earlier in this same request.

    Per-group failures are collected in the response rather than raised.

    Raises:
        HTTPException: 404 if the website, the upstream or an upstream snapshot
            is missing; 400 if the website is not ``sub2api``.
    """
    website = db.query(Website).filter(Website.id == wid).first()
    if not website:
        raise HTTPException(404, "website not found")
    if website.site_type != "sub2api":
        raise HTTPException(400, "目前只支持 sub2api")
    upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
    if not upstream:
        raise HTTPException(404, "upstream not found")

    selected = set(body.group_ids)
    groups = _latest_upstream_groups(db, upstream_id)
    # Names that must not be created again. Seeded from the target site and
    # extended below with every name handled in this request, so duplicate
    # target names in the snapshot do not trigger repeated create calls.
    existing_names = _existing_target_group_names(website)

    items: list[ImportGroupItem] = []
    with _client(website) as c:
        for group in groups:
            source_gid = _source_group_id(group)
            if not source_gid or (selected and source_gid not in selected):
                continue
            source_name = _source_group_name(group, source_gid)
            target_name = f"{body.name_prefix}{source_name}" if body.name_prefix else source_name

            if target_name in existing_names:
                items.append(ImportGroupItem(
                    source_group_id=source_gid,
                    source_group_name=source_name,
                    target_group_name=target_name,
                    status="exists",
                    message="目标分组已存在,已跳过",
                ))
                continue

            create_body = {
                "name": target_name,
                "description": group.get("description") or f"Imported from {upstream.name} / {source_name}",
                "platform": group.get("platform") or "openai",
                "rate_multiplier": _source_group_rate(group),
            }
            if group.get("rpm_limit") is not None:
                create_body["rpm_limit"] = group.get("rpm_limit")
            try:
                created = c.create_group(create_body)
                target_id = c.extract_id(created)
                items.append(ImportGroupItem(
                    source_group_id=source_gid,
                    source_group_name=source_name,
                    target_group_id=target_id or None,
                    target_group_name=str(created.get("name") or target_name),
                    status="created",
                    message="已创建",
                    raw=created,
                ))
                existing_names.add(target_name)
            except Exception as exc:
                msg = str(exc)
                if _is_conflict_error(msg):
                    existing_names.add(target_name)
                    items.append(ImportGroupItem(
                        source_group_id=source_gid,
                        source_group_name=source_name,
                        target_group_name=target_name,
                        status="exists",
                        message="目标分组已存在(接口返回冲突)",
                    ))
                else:
                    logger.exception("import website group failed website=%s upstream=%s group=%s", wid, upstream_id, source_gid)
                    items.append(ImportGroupItem(
                        source_group_id=source_gid,
                        source_group_name=source_name,
                        target_group_name=target_name,
                        status="failed",
                        message=msg,
                    ))

    created_count = sum(1 for item in items if item.status == "created")
    exists_count = sum(1 for item in items if item.status == "exists")
    failed_count = sum(1 for item in items if item.status == "failed")
    msg_parts = []
    if created_count:
        msg_parts.append(f"新建 {created_count}")
    if exists_count:
        msg_parts.append(f"已存在 {exists_count}")
    if failed_count:
        msg_parts.append(f"失败 {failed_count}")
    return ImportGroupsResponse(
        success=failed_count == 0,
        message="、".join(msg_parts) + f" / 共 {len(items)} 个" if msg_parts else f"共处理 {len(items)} 个分组",
        items=items,
    )
|
||
|
||
|
||
def _detect_platform(text: str, fallback: str = "openai") -> str:
    """Guess the platform from keywords in a key or group name.

    Matching is case-insensitive and checked in order: anthropic (``claude``
    or ``anthropic``), then ``gemini``, then ``antigravity``. The first hit
    wins; otherwise *fallback* is returned.
    """
    haystack = text.lower()
    rules = (
        (("claude", "anthropic"), "anthropic"),
        (("gemini",), "gemini"),
        (("antigravity",), "antigravity"),
    )
    for needles, platform in rules:
        if any(needle in haystack for needle in needles):
            return platform
    return fallback
|
||
|
||
|
||
@router.post("/api/websites/{wid}/accounts/sync-imported-upstream-keys", response_model=ImportAccountsResponse)
def sync_imported_upstream_keys(
    wid: int,
    body: SyncImportStatusRequest,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Verify that accounts imported from upstream keys still exist on the target Sub2API site.

    Each key of ``body.upstream_id`` recorded as imported into website *wid*
    is checked against the website's account list:

    * ``True``: reported as ``exists``.
    * ``False``: import markers are cleared (status reset to ``created``) and
      the key is reported as ``stale_cleared``.
    * anything else: reported as ``check_failed`` and left untouched.

    Raises:
        HTTPException: 404 when the website does not exist.
    """
    website = db.query(Website).filter(Website.id == wid).first()
    if website is None:
        raise HTTPException(404, "website not found")

    imported_keys = (
        db.query(UpstreamGeneratedKey)
        .filter(
            UpstreamGeneratedKey.upstream_id == body.upstream_id,
            UpstreamGeneratedKey.imported_website_id == wid,
            UpstreamGeneratedKey.imported_account_id.isnot(None),
        )
        .all()
    )

    items: list[ImportAccountItem] = []
    with _client(website) as client:
        for key_row in imported_keys:
            account_id = key_row.imported_account_id
            # The query excludes NULL, but a falsy non-NULL value can still appear.
            if not account_id:
                continue
            shared = {
                "upstream_key_id": key_row.id,
                "source_group_id": key_row.group_id,
                "source_group_name": key_row.group_name,
                "account_id": account_id,
                "platform": _detect_platform(f"{key_row.group_name} {key_row.group_id} {key_row.key_name}", "openai"),
            }
            verdict = client.account_exists(account_id)
            if verdict is True:
                items.append(ImportAccountItem(**shared, status="exists", message="目标账号仍存在"))
            elif verdict is False:
                # Remote account is gone: forget the import so the key can be re-imported.
                key_row.imported_website_id = None
                key_row.imported_account_id = None
                key_row.imported_at = None
                key_row.status = "created"
                items.append(ImportAccountItem(
                    **shared,
                    status="stale_cleared",
                    message="目标账号已删除,已清除导入标记",
                ))
            else:
                items.append(ImportAccountItem(
                    **shared,
                    status="check_failed",
                    message="无法校验目标账号存在性(目标网站认证/网络问题)",
                ))
    db.commit()

    cleared = sum(1 for it in items if it.status == "stale_cleared")
    unverified = sum(1 for it in items if it.status == "check_failed")
    summary_bits = []
    if cleared:
        summary_bits.append(f"清除 {cleared}")
    if unverified:
        summary_bits.append(f"校验失败 {unverified}")
    if summary_bits:
        message = "、".join(summary_bits) + f" / 共 {len(items)} 个"
    else:
        message = f"共校验 {len(items)} 个,无变化"
    return ImportAccountsResponse(success=unverified == 0, message=message, items=items)
|
||
|
||
|
||
@router.post("/api/websites/{wid}/accounts/reorder-priority", response_model=ReorderPriorityResponse)
def reorder_account_priorities(
    wid: int,
    body: ReorderPriorityRequest,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Recompute priorities of accounts imported from one upstream into website *wid*.

    Delegates to ``sync_account_priorities_for_upstream`` and summarizes its
    per-account results by status (``success``/``failed``/``skipped``).

    Raises:
        HTTPException: 404 if the website or upstream does not exist; 400 if
            the website is not ``sub2api``.
    """
    website = db.query(Website).filter(Website.id == wid).first()
    if website is None:
        raise HTTPException(404, "website not found")
    if website.site_type != "sub2api":
        raise HTTPException(400, "目前只支持 sub2api")
    if db.query(Upstream).filter(Upstream.id == body.upstream_id).first() is None:
        raise HTTPException(404, "upstream not found")

    outcomes = sync_account_priorities_for_upstream(db, body.upstream_id, website_id=wid)

    def _count(wanted: str) -> int:
        return sum(1 for entry in outcomes if entry.get("status") == wanted)

    counts = {state: _count(state) for state in ("failed", "success", "skipped")}
    labelled = (("success", "更新"), ("failed", "失败"), ("skipped", "跳过"))
    parts = [f"{label} {counts[state]}" for state, label in labelled if counts[state]]
    if parts:
        message = "、".join(parts) + f" / 共 {len(outcomes)} 个"
    else:
        message = "没有需要重排的账号"

    return ReorderPriorityResponse(
        success=counts["failed"] == 0,
        message=message,
        items=[ReorderPriorityItem(**entry) for entry in outcomes],
    )
|
||
|
||
|
||
def _fetch_upstream_keys(db: Session, key_ids) -> list[UpstreamGeneratedKey]:
    """Load the upstream keys whose ids are in *key_ids*, ordered by id."""
    return (
        db.query(UpstreamGeneratedKey)
        .filter(UpstreamGeneratedKey.id.in_(key_ids))
        .order_by(UpstreamGeneratedKey.id)
        .all()
    )


def _upstream_base_urls(db: Session, upstream_ids: set[int]) -> dict[int, str]:
    """Return ``{upstream_id: base_url}`` for the given ids (unknown ids are omitted)."""
    if not upstream_ids:
        return {}
    return {
        up.id: up.base_url
        for up in db.query(Upstream).filter(Upstream.id.in_(upstream_ids)).all()
    }


@router.post("/api/websites/{wid}/accounts/import-upstream-keys", response_model=ImportAccountsResponse)
def import_upstream_keys_as_accounts(
    wid: int,
    body: ImportAccountsRequest,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Create ``apikey`` accounts on Sub2API website *wid* from stored upstream keys.

    Each selected key becomes one account. The account's credentials use the
    key value and the base URL of the upstream that key belongs to. Keys that
    were already imported into this website are skipped while the remote
    account still exists, and re-created when it is gone. Per-key failures are
    collected in the response rather than raised.

    Raises:
        HTTPException: 404 if the website does not exist; 400 if it is not
            ``sub2api`` or no key ids were given.
    """
    website = db.query(Website).filter(Website.id == wid).first()
    if not website:
        raise HTTPException(404, "website not found")
    if website.site_type != "sub2api":
        raise HTTPException(400, "目前只支持 sub2api")
    if not body.upstream_key_ids:
        raise HTTPException(400, "请选择要导入的 Key")

    rows = _fetch_upstream_keys(db, body.upstream_key_ids)
    # Reconcile before importing. Failures are only logged so an unreachable
    # upstream never causes local keys to be dropped by mistake.
    for uid in {row.upstream_id for row in rows}:
        try:
            reconcile_upstream_keys_full(db, uid)
        except Exception as exc:
            logger.warning("import reconcile failed for upstream %s: %s", uid, exc)
    # Re-query: a successful reconcile may have removed stale keys.
    rows = _fetch_upstream_keys(db, body.upstream_key_ids)

    found_ids = {row.id for row in rows}
    items: list[ImportAccountItem] = [
        ImportAccountItem(
            upstream_key_id=kid,
            source_group_id="",
            source_group_name="",
            platform=body.default_platform,
            status="failed",
            message="key not found",
        )
        for kid in body.upstream_key_ids
        if kid not in found_ids
    ]

    upstream_ids = {row.upstream_id for row in rows}
    # Keys from several upstreams may be imported in one request. Each account
    # must point at its own key's upstream, not at the first row's upstream.
    base_urls = _upstream_base_urls(db, upstream_ids)

    # Optionally assign priorities from upstream group rates.
    rate_priority_map: dict[str, int] = {}
    if body.auto_priority_by_rate:
        try:
            rate_priority_map = _build_rate_priority_map(db, upstream_ids)
        except HTTPException:
            # e.g. no snapshot yet: fall back to body.priority below.
            pass

    with _client(website) as c:
        for row in rows:
            target_group_id = body.target_group_map.get(row.group_id)
            upstream_base_url = base_urls.get(row.upstream_id, "")
            account_name = f"{body.account_name_prefix}-{row.group_name or row.group_id}-{row.id}"

            # Skip keys that are gone upstream or whose import failed.
            # NOTE(review): "import_failed" is also set below when
            # create_account raises (e.g. a transient network error). Such keys
            # are then skipped here with a "remote key gone" message. Confirm
            # that reconcile resets that status; otherwise they cannot be retried.
            if row.status in ("orphaned", "failed", "import_failed"):
                items.append(ImportAccountItem(
                    upstream_key_id=row.id,
                    source_group_id=row.group_id,
                    source_group_name=row.group_name,
                    target_group_id=target_group_id,
                    platform=body.default_platform,
                    status="failed",
                    message="远端 Key 已不存在,请重新生成",
                ))
                continue

            # Determine the platform first; failure items record it too.
            if body.platform_mode == "auto":
                platform = _detect_platform(
                    f"{row.group_name} {row.group_id} {row.key_name}",
                    body.default_platform,
                )
            else:
                platform = body.default_platform

            # Idempotency: if already imported here, check the remote account.
            if row.imported_website_id == wid and row.imported_account_id:
                old_account_id = row.imported_account_id
                exists = c.account_exists(old_account_id)
                if exists is True:
                    # Backfill imported_target_group_id (lets older rows pick it up on re-import).
                    new_tgid = target_group_id or None
                    if new_tgid and row.imported_target_group_id != new_tgid:
                        row.imported_target_group_id = new_tgid
                        db.commit()
                    items.append(ImportAccountItem(
                        upstream_key_id=row.id,
                        source_group_id=row.group_id,
                        source_group_name=row.group_name,
                        target_group_id=target_group_id,
                        account_id=old_account_id,
                        account_name=account_name,
                        platform=platform,
                        upstream_base_url=upstream_base_url,
                        status="exists",
                        message="已导入过,已跳过",
                    ))
                    continue
                if exists is not False:
                    # Existence could not be verified: conservatively skip.
                    items.append(ImportAccountItem(
                        upstream_key_id=row.id,
                        source_group_id=row.group_id,
                        source_group_name=row.group_name,
                        target_group_id=target_group_id,
                        account_id=old_account_id,
                        platform=platform,
                        status="check_failed",
                        message="无法校验目标账号状态,已保守跳过",
                    ))
                    continue
                # Remote account was deleted: clear the stale markers and fall
                # through to re-create. Commit now so the clearing is persisted
                # even if the key is skipped below for lacking a plaintext value.
                row.imported_website_id = None
                row.imported_account_id = None
                row.imported_at = None
                row.status = "created"
                db.commit()

            if not row.key_value:
                items.append(ImportAccountItem(
                    upstream_key_id=row.id,
                    source_group_id=row.group_id,
                    source_group_name=row.group_name,
                    platform=platform,
                    status="failed",
                    message="该 Key 无明文值,无法导入(远端已存在 Key 不会保留明文,请重新创建或手动填入)",
                ))
                continue

            numeric_target = _numeric_group_id(target_group_id)
            group_ids = [numeric_target] if numeric_target is not None else []
            if body.auto_priority_by_rate:
                priority = rate_priority_map.get(f"{row.upstream_id}:{row.group_id}", body.priority)
            else:
                priority = body.priority
            account_body = {
                "name": account_name,
                "platform": platform,
                "type": "apikey",
                "credentials": {
                    "api_key": row.key_value,
                    "base_url": upstream_base_url,
                },
                "group_ids": group_ids,
                "rate_multiplier": 1,
                "concurrency": body.concurrency,
                "priority": priority,
                "notes": f"Imported by SmartUp from upstream key #{row.id}",
            }
            try:
                created = c.create_account(account_body)
                account_id = c.extract_id(created)
                row.imported_website_id = wid
                row.imported_account_id = account_id or None
                row.imported_at = datetime.now(timezone.utc)
                row.imported_target_group_id = target_group_id or None
                # target_group_map only carries ids; the display name may stay NULL.
                row.imported_target_group_name = None
                row.status = "imported"
                row.error = None
                db.commit()
                items.append(ImportAccountItem(
                    upstream_key_id=row.id,
                    source_group_id=row.group_id,
                    source_group_name=row.group_name,
                    target_group_id=target_group_id,
                    account_id=account_id or None,
                    account_name=str(created.get("name") or account_name),
                    platform=platform,
                    upstream_base_url=upstream_base_url,
                    status="created",
                    message="已创建账号",
                    raw=created,
                ))
            except Exception as exc:
                logger.exception("import upstream key as account failed website=%s key=%s", wid, row.id)
                row.status = "import_failed"
                row.error = str(exc)
                db.commit()
                items.append(ImportAccountItem(
                    upstream_key_id=row.id,
                    source_group_id=row.group_id,
                    source_group_name=row.group_name,
                    target_group_id=target_group_id,
                    account_name=account_name,
                    platform=platform,
                    upstream_base_url=upstream_base_url,
                    status="failed",
                    message=str(exc),
                ))

    created_count = sum(1 for item in items if item.status == "created")
    exists_count = sum(1 for item in items if item.status == "exists")
    failed_count = sum(1 for item in items if item.status == "failed")
    check_failed_count = sum(1 for item in items if item.status == "check_failed")
    msg_parts = []
    if created_count:
        msg_parts.append(f"新建 {created_count}")
    if exists_count:
        msg_parts.append(f"已存在 {exists_count}")
    if check_failed_count:
        msg_parts.append(f"校验失败 {check_failed_count}")
    if failed_count:
        msg_parts.append(f"失败 {failed_count}")
    return ImportAccountsResponse(
        success=failed_count == 0 and check_failed_count == 0,
        message="、".join(msg_parts) + f" / 共 {len(items)} 个" if msg_parts else f"共处理 {len(items)} 个",
        items=items,
    )
|
||
|
||
|
||
@router.get("/api/group-bindings", response_model=List[BindingResponse])
def list_bindings(db: Session = Depends(get_db), _=Depends(get_current_user)):
    """List all group bindings, newest (highest id) first.

    ``_binding_response`` looks up each binding's website individually.
    """
    newest_first = WebsiteGroupBinding.id.desc()
    bindings = db.query(WebsiteGroupBinding).order_by(newest_first).all()
    return [_binding_response(db, binding) for binding in bindings]
|
||
|
||
|
||
@router.post("/api/websites/{wid}/group-bindings/sync-now", response_model=WebsiteBatchSyncResponse)
def sync_website_group_bindings(
    wid: int,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Sync every group binding of website *wid* in one request.

    Bindings are processed in ascending id order. If syncing a binding raises,
    a synthetic ``failed`` log is persisted for it so the batch continues. The
    response therefore holds exactly one log per binding.

    Raises:
        HTTPException: 404 when the website does not exist.
    """
    website = db.query(Website).filter(Website.id == wid).first()
    if not website:
        raise HTTPException(404, "website not found")

    bindings = (
        db.query(WebsiteGroupBinding)
        .filter(WebsiteGroupBinding.website_id == wid)
        .order_by(WebsiteGroupBinding.id.asc())
        .all()
    )
    if not bindings:
        return WebsiteBatchSyncResponse(
            total=0, success=0, failed=0, skipped=0,
            message="暂无绑定可同步",
            logs=[],
        )

    results: list[WebsiteSyncLog] = []
    for binding in bindings:
        try:
            results.append(sync_binding(db, binding, write=True))
        except Exception as exc:
            # Discard whatever the failed sync left pending. If it failed
            # mid-flush, the session would otherwise reject the commit below
            # and abort the rest of the batch.
            db.rollback()
            logger.exception("batch sync failed for binding %s", binding.id)
            synthetic_log = WebsiteSyncLog(
                website_id=wid,
                binding_id=binding.id,
                target_group_id=binding.target_group_id,
                target_group_name=binding.target_group_name,
                algorithm=binding.algorithm,
                percent=binding.percent,
                source_rates_json="[]",
                status="failed",
                message=str(exc),
                created_at=datetime.now(timezone.utc),
            )
            db.add(synthetic_log)
            db.commit()
            db.refresh(synthetic_log)
            results.append(synthetic_log)

    success_count = sum(1 for r in results if r.status == "success")
    failed_count = sum(1 for r in results if r.status == "failed")
    skipped_count = sum(1 for r in results if r.status == "skipped")

    msg_parts = []
    if success_count:
        msg_parts.append(f"成功 {success_count}")
    if failed_count:
        msg_parts.append(f"失败 {failed_count}")
    if skipped_count:
        msg_parts.append(f"跳过 {skipped_count}")

    return WebsiteBatchSyncResponse(
        total=len(bindings),
        success=success_count,
        failed=failed_count,
        skipped=skipped_count,
        message=f"同步完成:{'、'.join(msg_parts)}" if msg_parts else "同步完成",
        logs=[_log_response(r) for r in results],
    )
|
||
|
||
|
||
@router.post("/api/group-bindings", response_model=BindingResponse, status_code=201)
def create_binding(body: BindingCreate, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Create a group binding, then attempt an initial sync.

    The binding is committed before syncing. A failed initial sync is logged
    and does not fail the request.

    Raises:
        HTTPException: 404 if the website does not exist; 400 for an
            unsupported algorithm or a duplicate (website, target group) pair.
    """
    website = db.query(Website).filter(Website.id == body.website_id).first()
    if not website:
        raise HTTPException(404, "website not found")
    if body.algorithm not in ALGORITHMS:
        raise HTTPException(400, "不支持的算法")
    _ensure_unique_target(db, body.website_id, body.target_group_id)
    row = WebsiteGroupBinding(
        website_id=body.website_id,
        target_group_id=body.target_group_id,
        target_group_name=body.target_group_name,
        source_groups_json=json.dumps([item.model_dump() for item in body.source_groups], ensure_ascii=False),
        percent=str(body.percent),
        algorithm=body.algorithm,
        enabled=body.enabled,
    )
    db.add(row)
    db.commit()
    db.refresh(row)
    binding_id = row.id
    try:
        sync_binding(db, row, write=True)
    except Exception as exc:
        # Best effort: roll back the failed sync's pending state so the
        # session can still load the response below.
        db.rollback()
        logger.exception("initial website sync failed for binding %s: %s", binding_id, exc)
    return _binding_response(db, row)
|
||
|
||
|
||
@router.put("/api/group-bindings/{bid}", response_model=BindingResponse)
def update_binding(bid: int, body: BindingUpdate, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Partially update a group binding, then attempt to re-sync it.

    Fields sent as ``None`` are left untouched. A failed sync after the update
    is logged and does not fail the request.

    Raises:
        HTTPException: 404 if the binding or the new website does not exist;
            400 for an unsupported algorithm or a duplicate target group.
    """
    row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
    if not row:
        raise HTTPException(404, "binding not found")
    data = body.model_dump(exclude_none=True)
    if "website_id" in data and not db.query(Website).filter(Website.id == data["website_id"]).first():
        raise HTTPException(404, "website not found")
    if "algorithm" in data and data["algorithm"] not in ALGORITHMS:
        raise HTTPException(400, "不支持的算法")
    # Uniqueness is checked against the post-update (website, target group) pair.
    next_website_id = int(data.get("website_id", row.website_id))
    next_target_group_id = str(data.get("target_group_id", row.target_group_id))
    _ensure_unique_target(db, next_website_id, next_target_group_id, exclude_id=row.id)
    if "source_groups" in data:
        row.source_groups_json = json.dumps(data.pop("source_groups"), ensure_ascii=False)
    if "percent" in data:
        row.percent = str(data.pop("percent"))
    for key, value in data.items():
        setattr(row, key, value)
    row.updated_at = datetime.now(timezone.utc)
    db.commit()
    db.refresh(row)
    binding_id = row.id
    try:
        sync_binding(db, row, write=True)
    except Exception as exc:
        # Best effort: roll back the failed sync's pending state so the
        # session can still load the response below.
        db.rollback()
        logger.exception("sync failed after updating binding %s: %s", binding_id, exc)
    return _binding_response(db, row)
|
||
|
||
|
||
@router.delete("/api/group-bindings/{bid}", status_code=204)
def delete_binding(bid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Delete a single group binding.

    NOTE(review): unlike ``delete_website``, which removes sync logs before
    deleting bindings, this leaves ``WebsiteSyncLog`` rows whose
    ``binding_id`` refers to the deleted binding. Confirm this is intended
    (history retention) and that no foreign key on ``binding_id`` rejects it.

    Raises:
        HTTPException: 404 when the binding does not exist.
    """
    binding = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
    if binding is None:
        raise HTTPException(404, "binding not found")
    db.delete(binding)
    db.commit()
|
||
|
||
|
||
@router.post("/api/group-bindings/{bid}/sync-now", response_model=WebsiteSyncLogResponse)
def sync_now(bid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Sync one binding immediately and return the resulting sync log.

    Raises:
        HTTPException: 404 when the binding does not exist.
    """
    binding = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
    if binding is None:
        raise HTTPException(404, "binding not found")
    sync_log = sync_binding(db, binding, write=True)
    return _log_response(sync_log)
|
||
|
||
|
||
@router.get("/api/website-sync-logs", response_model=List[WebsiteSyncLogResponse])
def list_sync_logs(
    website_id: int | None = Query(None),
    binding_id: int | None = Query(None),
    # Lower bounds matter: some backends (e.g. SQLite) treat a negative LIMIT
    # as "no limit", which would bypass the le=200 cap; negative OFFSET is
    # rejected or misbehaves depending on the database.
    limit: int = Query(50, ge=0, le=200),
    offset: int = Query(0, ge=0),
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Page through sync logs, newest first, optionally filtered by website or binding.

    Filters are applied only when the id is truthy; ``0``/omitted means no
    filter.
    """
    q = db.query(WebsiteSyncLog)
    if website_id:
        q = q.filter(WebsiteSyncLog.website_id == website_id)
    if binding_id:
        q = q.filter(WebsiteSyncLog.binding_id == binding_id)
    rows = q.order_by(WebsiteSyncLog.created_at.desc()).offset(offset).limit(limit).all()
    return [_log_response(row) for row in rows]
|