from __future__ import annotations import json import logging from datetime import datetime, timezone from typing import List from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlalchemy.orm import Session from app.database import get_db from app.models.snapshot import UpstreamRateSnapshot from app.models.upstream import Upstream from app.models.upstream_key import UpstreamGeneratedKey from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog from app.schemas.website import ( BindingCreate, BindingResponse, BindingUpdate, ImportAccountItem, ImportAccountsRequest, ImportAccountsResponse, ImportGroupItem, ImportGroupsRequest, ImportGroupsResponse, SyncImportStatusRequest, TestResult, WebsiteCreate, WebsiteGroupResponse, WebsiteResponse, WebsiteSyncLogResponse, WebsiteUpdate, ) from app.services.website_client import Sub2ApiWebsiteClient from app.services.website_sync import binding_sources, sync_binding from app.utils.auth import get_current_user router = APIRouter(tags=["websites"]) logger = logging.getLogger(__name__) MASK = "***" SECRET_KEYS = {"password", "token", "key", "secret", "api_key"} ALGORITHMS = {"max_plus_percent", "average_plus_percent", "min_plus_percent"} def _mask(cfg: dict) -> dict: masked = {} for key, value in cfg.items(): masked[key] = MASK if key.lower() in SECRET_KEYS and value else value return masked def _website_response(row: Website) -> WebsiteResponse: return WebsiteResponse( id=row.id, name=row.name, site_type=row.site_type, base_url=row.base_url, api_prefix=row.api_prefix, auth_type=row.auth_type, auth_config_masked=_mask(json.loads(row.auth_config_json or "{}")), groups_endpoint=row.groups_endpoint, group_update_endpoint=row.group_update_endpoint, enabled=row.enabled, auto_sync_enabled=row.auto_sync_enabled, timeout_seconds=row.timeout_seconds, last_status=row.last_status, last_checked_at=row.last_checked_at, last_error=row.last_error, created_at=row.created_at, updated_at=row.updated_at, ) def _binding_response(db: Session, row: WebsiteGroupBinding) -> BindingResponse: website = db.query(Website).filter(Website.id == row.website_id).first() return BindingResponse( id=row.id, website_id=row.website_id, website_name=website.name if website else "", target_group_id=row.target_group_id, target_group_name=row.target_group_name, source_groups=binding_sources(row), percent=float(row.percent or 0), algorithm=row.algorithm, enabled=row.enabled, created_at=row.created_at, updated_at=row.updated_at, ) def _log_response(row: WebsiteSyncLog) -> WebsiteSyncLogResponse: return WebsiteSyncLogResponse( id=row.id, website_id=row.website_id, binding_id=row.binding_id, target_group_id=row.target_group_id, target_group_name=row.target_group_name, algorithm=row.algorithm, percent=float(row.percent or 0), source_rates=json.loads(row.source_rates_json or "[]"), old_rate=row.old_rate, new_rate=row.new_rate, status=row.status, message=row.message, created_at=row.created_at, ) def _ensure_unique_target(db: Session, website_id: int, target_group_id: str, exclude_id: int | None = None) -> None: q = db.query(WebsiteGroupBinding).filter( WebsiteGroupBinding.website_id == website_id, WebsiteGroupBinding.target_group_id == target_group_id, ) if exclude_id is not None: q = q.filter(WebsiteGroupBinding.id != exclude_id) if q.first(): raise HTTPException(400, "同一目标网站分组只能维护一条绑定记录") def _client(row: Website) -> Sub2ApiWebsiteClient: return Sub2ApiWebsiteClient( base_url=row.base_url, api_prefix=row.api_prefix, auth_type=row.auth_type, auth_config=json.loads(row.auth_config_json or "{}"), timeout=float(row.timeout_seconds), ) def _latest_upstream_groups(db: Session, upstream_id: int) -> list[dict]: row = ( db.query(UpstreamRateSnapshot) .filter(UpstreamRateSnapshot.upstream_id == upstream_id) .order_by(UpstreamRateSnapshot.captured_at.desc()) .first() ) if not row: raise HTTPException(404, "no upstream snapshot found; run upstream check first") snapshot = json.loads(row.snapshot_json or "{}") groups = snapshot.get("groups") or {} if not isinstance(groups, dict): return [] return [item for item in groups.values() if isinstance(item, dict)] def _source_group_id(group: dict) -> str: return str(group.get("group_id") or group.get("id") or group.get("name") or "") def _source_group_name(group: dict, gid: str) -> str: return str(group.get("group_name") or group.get("name") or gid) def _source_group_rate(group: dict) -> float: raw = group.get("rate") or group.get("default_rate") or group.get("rate_multiplier") or 1 try: return float(raw) except (TypeError, ValueError): return 1.0 def _numeric_group_id(value: str | None) -> int | None: if value is None or value == "": return None try: return int(value) except ValueError: return None @router.get("/api/websites", response_model=List[WebsiteResponse]) def list_websites(db: Session = Depends(get_db), _=Depends(get_current_user)): return [_website_response(row) for row in db.query(Website).order_by(Website.id).all()] @router.post("/api/websites", response_model=WebsiteResponse, status_code=201) def create_website(body: WebsiteCreate, db: Session = Depends(get_db), _=Depends(get_current_user)): if body.site_type != "sub2api": raise HTTPException(400, "目前只支持 sub2api") row = Website( name=body.name, site_type=body.site_type, base_url=body.base_url.rstrip("/"), api_prefix=body.api_prefix, auth_type=body.auth_type, auth_config_json=json.dumps(body.auth_config, ensure_ascii=False), groups_endpoint=body.groups_endpoint, group_update_endpoint=body.group_update_endpoint, enabled=body.enabled, auto_sync_enabled=body.auto_sync_enabled, timeout_seconds=body.timeout_seconds, ) db.add(row) db.commit() db.refresh(row) return _website_response(row) @router.put("/api/websites/{wid}", response_model=WebsiteResponse) def update_website(wid: int, body: WebsiteUpdate, db: Session = Depends(get_db), _=Depends(get_current_user)): row = db.query(Website).filter(Website.id == wid).first() if not row: raise HTTPException(404, "website not found") data = body.model_dump(exclude_none=True) if "site_type" in data and data["site_type"] != "sub2api": raise HTTPException(400, "目前只支持 sub2api") if "auth_config" in data: existing = json.loads(row.auth_config_json or "{}") incoming = data.pop("auth_config") for key, value in incoming.items(): if value != MASK: existing[key] = value row.auth_config_json = json.dumps(existing, ensure_ascii=False) if "base_url" in data: data["base_url"] = data["base_url"].rstrip("/") for key, value in data.items(): setattr(row, key, value) row.updated_at = datetime.now(timezone.utc) db.commit() db.refresh(row) return _website_response(row) @router.delete("/api/websites/{wid}", status_code=204) def delete_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): row = db.query(Website).filter(Website.id == wid).first() if not row: raise HTTPException(404, "website not found") db.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == wid).delete(synchronize_session=False) db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.website_id == wid).delete(synchronize_session=False) db.delete(row) db.commit() @router.post("/api/websites/{wid}/test", response_model=TestResult) def test_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): row = db.query(Website).filter(Website.id == wid).first() if not row: raise HTTPException(404, "website not found") try: with _client(row) as c: groups = c.get_groups(row.groups_endpoint) row.last_status = "healthy" row.last_error = None row.last_checked_at = datetime.now(timezone.utc) db.commit() return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组") except Exception as exc: row.last_status = "unhealthy" row.last_error = str(exc) row.last_checked_at = datetime.now(timezone.utc) db.commit() return TestResult(success=False, message="连接失败", detail=str(exc)) @router.get("/api/websites/{wid}/groups", response_model=List[WebsiteGroupResponse]) def list_website_groups(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): row = db.query(Website).filter(Website.id == wid).first() if not row: raise HTTPException(404, "website not found") try: with _client(row) as c: return c.get_groups(row.groups_endpoint) except Exception as exc: raise HTTPException(502, str(exc)) @router.post("/api/websites/{wid}/groups/import-from-upstream/{upstream_id}", response_model=ImportGroupsResponse) def import_groups_from_upstream( wid: int, upstream_id: int, body: ImportGroupsRequest, db: Session = Depends(get_db), _=Depends(get_current_user), ): website = db.query(Website).filter(Website.id == wid).first() if not website: raise HTTPException(404, "website not found") if website.site_type != "sub2api": raise HTTPException(400, "目前只支持 sub2api") upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first() if not upstream: raise HTTPException(404, "upstream not found") selected = set(body.group_ids) groups = _latest_upstream_groups(db, upstream_id) # 拉取目标网站已有分组,同名则跳过 try: existing_names = set() with _client(website) as c: for eg in c.get_groups(website.groups_endpoint): gname = eg.get("name") or eg.get("group_name") or "" if gname: existing_names.add(gname) except Exception: existing_names = set() items: list[ImportGroupItem] = [] with _client(website) as c: for group in groups: source_gid = _source_group_id(group) if not source_gid or (selected and source_gid not in selected): continue source_name = _source_group_name(group, source_gid) target_name = f"{body.name_prefix}{source_name}" if body.name_prefix else source_name # 检查是否已存在同名分组 if target_name in existing_names: items.append(ImportGroupItem( source_group_id=source_gid, source_group_name=source_name, target_group_name=target_name, status="exists", message="目标分组已存在,已跳过", )) continue create_body = { "name": target_name, "description": group.get("description") or f"Imported from {upstream.name} / {source_name}", "platform": group.get("platform") or "openai", "rate_multiplier": _source_group_rate(group), } if group.get("rpm_limit") is not None: create_body["rpm_limit"] = group.get("rpm_limit") try: created = c.create_group(create_body) target_id = c.extract_id(created) items.append(ImportGroupItem( source_group_id=source_gid, source_group_name=source_name, target_group_id=target_id or None, target_group_name=str(created.get("name") or target_name), status="created", message="已创建", raw=created, )) except Exception as exc: msg = str(exc) # 捕获 409 等已存在错误 if "已存在" in msg or "already exists" in msg.lower() or "409" in msg or "Conflict" in msg: items.append(ImportGroupItem( source_group_id=source_gid, source_group_name=source_name, target_group_name=target_name, status="exists", message="目标分组已存在(接口返回冲突)", )) else: logger.exception("import website group failed website=%s upstream=%s group=%s", wid, upstream_id, source_gid) items.append(ImportGroupItem( source_group_id=source_gid, source_group_name=source_name, target_group_name=target_name, status="failed", message=msg, )) created_count = len([item for item in items if item.status == "created"]) exists_count = len([item for item in items if item.status == "exists"]) failed_count = len([item for item in items if item.status == "failed"]) msg_parts = [] if created_count: msg_parts.append(f"新建 {created_count}") if exists_count: msg_parts.append(f"已存在 {exists_count}") if failed_count: msg_parts.append(f"失败 {failed_count}") return ImportGroupsResponse( success=failed_count == 0, message="、".join(msg_parts) + f" / 共 {len(items)} 个" if msg_parts else f"共处理 {len(items)} 个分组", items=items, ) def _detect_platform(text: str, fallback: str = "openai") -> str: """根据 Key 名或分组名关键词判断平台类型。""" lower = text.lower() if "claude" in lower or "anthropic" in lower: return "anthropic" if "gemini" in lower: return "gemini" if "antigravity" in lower: return "antigravity" return fallback @router.post("/api/websites/{wid}/accounts/sync-imported-upstream-keys", response_model=ImportAccountsResponse) def sync_imported_upstream_keys( wid: int, body: SyncImportStatusRequest, db: Session = Depends(get_db), _=Depends(get_current_user), ): """校验已导入的上游 Key 在目标 Sub2API 账号管理中是否仍存在。""" website = db.query(Website).filter(Website.id == wid).first() if not website: raise HTTPException(404, "website not found") rows = ( db.query(UpstreamGeneratedKey) .filter( UpstreamGeneratedKey.upstream_id == body.upstream_id, UpstreamGeneratedKey.imported_website_id == wid, UpstreamGeneratedKey.imported_account_id.isnot(None), ) .all() ) items: list[ImportAccountItem] = [] with _client(website) as c: for row in rows: platform = _detect_platform(f"{row.group_name} {row.group_id} {row.key_name}", "openai") if not row.imported_account_id: continue old_account_id = row.imported_account_id exists = c.account_exists(row.imported_account_id) if exists is False: row.imported_website_id = None row.imported_account_id = None row.imported_at = None row.status = "created" items.append(ImportAccountItem( upstream_key_id=row.id, source_group_id=row.group_id, source_group_name=row.group_name, account_id=old_account_id, platform=platform, status="stale_cleared", message="目标账号已删除,已清除导入标记", )) elif exists is True: items.append(ImportAccountItem( upstream_key_id=row.id, source_group_id=row.group_id, source_group_name=row.group_name, account_id=old_account_id, platform=platform, status="exists", message="目标账号仍存在", )) else: items.append(ImportAccountItem( upstream_key_id=row.id, source_group_id=row.group_id, source_group_name=row.group_name, account_id=old_account_id, platform=platform, status="check_failed", message="无法校验目标账号存在性(目标网站认证/网络问题)", )) db.commit() cleared_count = len([i for i in items if i.status == "stale_cleared"]) check_failed_count = len([i for i in items if i.status == "check_failed"]) msg_parts = [] if cleared_count: msg_parts.append(f"清除 {cleared_count}") if check_failed_count: msg_parts.append(f"校验失败 {check_failed_count}") return ImportAccountsResponse( success=check_failed_count == 0, message="、".join(msg_parts) + f" / 共 {len(items)} 个" if msg_parts else f"共校验 {len(items)} 个,无变化", items=items, ) @router.post("/api/websites/{wid}/accounts/import-upstream-keys", response_model=ImportAccountsResponse) def import_upstream_keys_as_accounts( wid: int, body: ImportAccountsRequest, db: Session = Depends(get_db), _=Depends(get_current_user), ): website = db.query(Website).filter(Website.id == wid).first() if not website: raise HTTPException(404, "website not found") if website.site_type != "sub2api": raise HTTPException(400, "目前只支持 sub2api") if not body.upstream_key_ids: raise HTTPException(400, "请选择要导入的 Key") rows = ( db.query(UpstreamGeneratedKey) .filter(UpstreamGeneratedKey.id.in_(body.upstream_key_ids)) .order_by(UpstreamGeneratedKey.id) .all() ) found_ids = {row.id for row in rows} missing_ids = [kid for kid in body.upstream_key_ids if kid not in found_ids] items: list[ImportAccountItem] = [ ImportAccountItem( upstream_key_id=kid, source_group_id="", source_group_name="", platform=body.default_platform, status="failed", message="key not found", ) for kid in missing_ids ] # 查出来源上游的 Base URL upstream_base_url = "" if body.upstream_key_ids: first_row = rows[0] if rows else None if first_row: from app.models.upstream import Upstream as _Up _u = db.query(_Up).filter(_Up.id == first_row.upstream_id).first() if _u: upstream_base_url = _u.base_url with _client(website) as c: for row in rows: # 先确定平台(失败项也需要记录) if body.platform_mode == "auto": platform = _detect_platform( f"{row.group_name} {row.group_id} {row.key_name}", body.default_platform, ) else: platform = body.default_platform # 幂等校验:已导入过则检查远端账号是否仍存在 if row.imported_website_id == wid and row.imported_account_id: old_account_id = row.imported_account_id exists = c.account_exists(row.imported_account_id) if exists is True: items.append(ImportAccountItem( upstream_key_id=row.id, source_group_id=row.group_id, source_group_name=row.group_name, target_group_id=body.target_group_map.get(row.group_id), account_id=old_account_id, account_name=f"{body.account_name_prefix}-{row.group_name or row.group_id}-{row.id}", platform=platform, upstream_base_url=upstream_base_url, status="exists", message="已导入过,已跳过", )) continue elif exists is False: # 远端已删除,清空标记后继续创建 row.imported_website_id = None row.imported_account_id = None row.imported_at = None row.status = "created" # 继续往下走(不 continue) else: # 校验失败,保守跳过 items.append(ImportAccountItem( upstream_key_id=row.id, source_group_id=row.group_id, source_group_name=row.group_name, target_group_id=body.target_group_map.get(row.group_id), account_id=old_account_id, platform=platform, status="check_failed", message="无法校验目标账号状态,已保守跳过", )) continue if not row.key_value: items.append(ImportAccountItem( upstream_key_id=row.id, source_group_id=row.group_id, source_group_name=row.group_name, platform=platform, status="failed", message="该 Key 无明文值,无法导入(远端已存在 Key 不会保留明文,请重新创建或手动填入)", )) continue target_group_id = body.target_group_map.get(row.group_id) group_ids = [] numeric_target = _numeric_group_id(target_group_id) if numeric_target is not None: group_ids.append(numeric_target) account_name = f"{body.account_name_prefix}-{row.group_name or row.group_id}-{row.id}" account_body = { "name": account_name, "platform": platform, "type": "apikey", "credentials": { "api_key": row.key_value, "base_url": upstream_base_url, }, "group_ids": group_ids, "rate_multiplier": 1, "concurrency": body.concurrency, "priority": body.priority, "notes": f"Imported by SmartUp from upstream key #{row.id}", } try: created = c.create_account(account_body) account_id = c.extract_id(created) row.imported_website_id = wid row.imported_account_id = account_id or None row.imported_at = datetime.now(timezone.utc) row.status = "imported" row.error = None db.commit() items.append(ImportAccountItem( upstream_key_id=row.id, source_group_id=row.group_id, source_group_name=row.group_name, target_group_id=target_group_id, account_id=account_id or None, account_name=str(created.get("name") or account_name), platform=platform, upstream_base_url=upstream_base_url, status="created", message="已创建账号", raw=created, )) except Exception as exc: logger.exception("import upstream key as account failed website=%s key=%s", wid, row.id) row.status = "import_failed" row.error = str(exc) db.commit() items.append(ImportAccountItem( upstream_key_id=row.id, source_group_id=row.group_id, source_group_name=row.group_name, target_group_id=target_group_id, account_name=account_name, platform=platform, upstream_base_url=upstream_base_url, status="failed", message=str(exc), )) created_count = len([item for item in items if item.status == "created"]) exists_count = len([item for item in items if item.status == "exists"]) failed_count = len([item for item in items if item.status == "failed"]) check_failed_count = len([item for item in items if item.status == "check_failed"]) msg_parts = [] if created_count: msg_parts.append(f"新建 {created_count}") if exists_count: msg_parts.append(f"已存在 {exists_count}") if check_failed_count: msg_parts.append(f"校验失败 {check_failed_count}") if failed_count: msg_parts.append(f"失败 {failed_count}") return ImportAccountsResponse( success=failed_count == 0 and check_failed_count == 0, message="、".join(msg_parts) + f" / 共 {len(items)} 个" if msg_parts else f"共处理 {len(items)} 个", items=items, ) @router.get("/api/group-bindings", response_model=List[BindingResponse]) def list_bindings(db: Session = Depends(get_db), _=Depends(get_current_user)): rows = db.query(WebsiteGroupBinding).order_by(WebsiteGroupBinding.id.desc()).all() return [_binding_response(db, row) for row in rows] @router.post("/api/group-bindings", response_model=BindingResponse, status_code=201) def create_binding(body: BindingCreate, db: Session = Depends(get_db), _=Depends(get_current_user)): website = db.query(Website).filter(Website.id == body.website_id).first() if not website: raise HTTPException(404, "website not found") if body.algorithm not in ALGORITHMS: raise HTTPException(400, "不支持的算法") _ensure_unique_target(db, body.website_id, body.target_group_id) row = WebsiteGroupBinding( website_id=body.website_id, target_group_id=body.target_group_id, target_group_name=body.target_group_name, source_groups_json=json.dumps([item.model_dump() for item in body.source_groups], ensure_ascii=False), percent=str(body.percent), algorithm=body.algorithm, enabled=body.enabled, ) db.add(row) db.commit() db.refresh(row) try: sync_binding(db, row, write=True) except Exception as exc: logger.exception("initial website sync failed for binding %s: %s", row.id, exc) return _binding_response(db, row) @router.put("/api/group-bindings/{bid}", response_model=BindingResponse) def update_binding(bid: int, body: BindingUpdate, db: Session = Depends(get_db), _=Depends(get_current_user)): row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first() if not row: raise HTTPException(404, "binding not found") data = body.model_dump(exclude_none=True) if "website_id" in data and not db.query(Website).filter(Website.id == data["website_id"]).first(): raise HTTPException(404, "website not found") if "algorithm" in data and data["algorithm"] not in ALGORITHMS: raise HTTPException(400, "不支持的算法") next_website_id = int(data.get("website_id", row.website_id)) next_target_group_id = str(data.get("target_group_id", row.target_group_id)) _ensure_unique_target(db, next_website_id, next_target_group_id, exclude_id=row.id) if "source_groups" in data: row.source_groups_json = json.dumps(data.pop("source_groups"), ensure_ascii=False) if "percent" in data: row.percent = str(data.pop("percent")) for key, value in data.items(): setattr(row, key, value) row.updated_at = datetime.now(timezone.utc) db.commit() db.refresh(row) try: sync_binding(db, row, write=True) except Exception as exc: logger.exception("sync failed after updating binding %s: %s", row.id, exc) return _binding_response(db, row) @router.delete("/api/group-bindings/{bid}", status_code=204) def delete_binding(bid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first() if not row: raise HTTPException(404, "binding not found") db.delete(row) db.commit() @router.post("/api/group-bindings/{bid}/sync-now", response_model=WebsiteSyncLogResponse) def sync_now(bid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first() if not row: raise HTTPException(404, "binding not found") return _log_response(sync_binding(db, row, write=True)) @router.get("/api/website-sync-logs", response_model=List[WebsiteSyncLogResponse]) def list_sync_logs( website_id: int | None = Query(None), binding_id: int | None = Query(None), limit: int = Query(50, le=200), offset: int = Query(0), db: Session = Depends(get_db), _=Depends(get_current_user), ): q = db.query(WebsiteSyncLog) if website_id: q = q.filter(WebsiteSyncLog.website_id == website_id) if binding_id: q = q.filter(WebsiteSyncLog.binding_id == binding_id) rows = q.order_by(WebsiteSyncLog.created_at.desc()).offset(offset).limit(limit).all() return [_log_response(row) for row in rows]