from __future__ import annotations import json import logging from datetime import datetime, timezone from typing import List from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlalchemy.orm import Session from app.database import get_db from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog from app.schemas.website import ( BindingCreate, BindingResponse, BindingUpdate, TestResult, WebsiteCreate, WebsiteGroupResponse, WebsiteResponse, WebsiteSyncLogResponse, WebsiteUpdate, ) from app.services.website_client import Sub2ApiWebsiteClient from app.services.website_sync import binding_sources, sync_binding from app.utils.auth import get_current_user router = APIRouter(tags=["websites"]) logger = logging.getLogger(__name__) MASK = "***" SECRET_KEYS = {"password", "token", "key", "secret", "api_key"} ALGORITHMS = {"max_plus_percent", "average_plus_percent", "min_plus_percent"} def _mask(cfg: dict) -> dict: masked = {} for key, value in cfg.items(): masked[key] = MASK if key.lower() in SECRET_KEYS and value else value return masked def _website_response(row: Website) -> WebsiteResponse: return WebsiteResponse( id=row.id, name=row.name, site_type=row.site_type, base_url=row.base_url, api_prefix=row.api_prefix, auth_type=row.auth_type, auth_config_masked=_mask(json.loads(row.auth_config_json or "{}")), groups_endpoint=row.groups_endpoint, group_update_endpoint=row.group_update_endpoint, enabled=row.enabled, auto_sync_enabled=row.auto_sync_enabled, timeout_seconds=row.timeout_seconds, last_status=row.last_status, last_checked_at=row.last_checked_at, last_error=row.last_error, created_at=row.created_at, updated_at=row.updated_at, ) def _binding_response(db: Session, row: WebsiteGroupBinding) -> BindingResponse: website = db.query(Website).filter(Website.id == row.website_id).first() return BindingResponse( id=row.id, website_id=row.website_id, website_name=website.name if website else "", target_group_id=row.target_group_id, 
target_group_name=row.target_group_name, source_groups=binding_sources(row), percent=float(row.percent or 0), algorithm=row.algorithm, enabled=row.enabled, created_at=row.created_at, updated_at=row.updated_at, ) def _log_response(row: WebsiteSyncLog) -> WebsiteSyncLogResponse: return WebsiteSyncLogResponse( id=row.id, website_id=row.website_id, binding_id=row.binding_id, target_group_id=row.target_group_id, target_group_name=row.target_group_name, algorithm=row.algorithm, percent=float(row.percent or 0), source_rates=json.loads(row.source_rates_json or "[]"), old_rate=row.old_rate, new_rate=row.new_rate, status=row.status, message=row.message, created_at=row.created_at, ) def _ensure_unique_target(db: Session, website_id: int, target_group_id: str, exclude_id: int | None = None) -> None: q = db.query(WebsiteGroupBinding).filter( WebsiteGroupBinding.website_id == website_id, WebsiteGroupBinding.target_group_id == target_group_id, ) if exclude_id is not None: q = q.filter(WebsiteGroupBinding.id != exclude_id) if q.first(): raise HTTPException(400, "同一目标网站分组只能维护一条绑定记录") def _client(row: Website) -> Sub2ApiWebsiteClient: return Sub2ApiWebsiteClient( base_url=row.base_url, api_prefix=row.api_prefix, auth_type=row.auth_type, auth_config=json.loads(row.auth_config_json or "{}"), timeout=float(row.timeout_seconds), ) @router.get("/api/websites", response_model=List[WebsiteResponse]) def list_websites(db: Session = Depends(get_db), _=Depends(get_current_user)): return [_website_response(row) for row in db.query(Website).order_by(Website.id).all()] @router.post("/api/websites", response_model=WebsiteResponse, status_code=201) def create_website(body: WebsiteCreate, db: Session = Depends(get_db), _=Depends(get_current_user)): if body.site_type != "sub2api": raise HTTPException(400, "目前只支持 sub2api") row = Website( name=body.name, site_type=body.site_type, base_url=body.base_url.rstrip("/"), api_prefix=body.api_prefix, auth_type=body.auth_type, 
auth_config_json=json.dumps(body.auth_config, ensure_ascii=False), groups_endpoint=body.groups_endpoint, group_update_endpoint=body.group_update_endpoint, enabled=body.enabled, auto_sync_enabled=body.auto_sync_enabled, timeout_seconds=body.timeout_seconds, ) db.add(row) db.commit() db.refresh(row) return _website_response(row) @router.put("/api/websites/{wid}", response_model=WebsiteResponse) def update_website(wid: int, body: WebsiteUpdate, db: Session = Depends(get_db), _=Depends(get_current_user)): row = db.query(Website).filter(Website.id == wid).first() if not row: raise HTTPException(404, "website not found") data = body.model_dump(exclude_none=True) if "site_type" in data and data["site_type"] != "sub2api": raise HTTPException(400, "目前只支持 sub2api") if "auth_config" in data: existing = json.loads(row.auth_config_json or "{}") incoming = data.pop("auth_config") for key, value in incoming.items(): if value != MASK: existing[key] = value row.auth_config_json = json.dumps(existing, ensure_ascii=False) if "base_url" in data: data["base_url"] = data["base_url"].rstrip("/") for key, value in data.items(): setattr(row, key, value) row.updated_at = datetime.now(timezone.utc) db.commit() db.refresh(row) return _website_response(row) @router.delete("/api/websites/{wid}", status_code=204) def delete_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): row = db.query(Website).filter(Website.id == wid).first() if not row: raise HTTPException(404, "website not found") db.delete(row) db.commit() @router.post("/api/websites/{wid}/test", response_model=TestResult) def test_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): row = db.query(Website).filter(Website.id == wid).first() if not row: raise HTTPException(404, "website not found") try: groups = _client(row).get_groups(row.groups_endpoint) row.last_status = "healthy" row.last_error = None row.last_checked_at = datetime.now(timezone.utc) db.commit() return 
TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组") except Exception as exc: row.last_status = "unhealthy" row.last_error = str(exc) row.last_checked_at = datetime.now(timezone.utc) db.commit() return TestResult(success=False, message="连接失败", detail=str(exc)) @router.get("/api/websites/{wid}/groups", response_model=List[WebsiteGroupResponse]) def list_website_groups(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)): row = db.query(Website).filter(Website.id == wid).first() if not row: raise HTTPException(404, "website not found") try: return _client(row).get_groups(row.groups_endpoint) except Exception as exc: raise HTTPException(502, str(exc)) @router.get("/api/group-bindings", response_model=List[BindingResponse]) def list_bindings(db: Session = Depends(get_db), _=Depends(get_current_user)): rows = db.query(WebsiteGroupBinding).order_by(WebsiteGroupBinding.id.desc()).all() return [_binding_response(db, row) for row in rows] @router.post("/api/group-bindings", response_model=BindingResponse, status_code=201) def create_binding(body: BindingCreate, db: Session = Depends(get_db), _=Depends(get_current_user)): website = db.query(Website).filter(Website.id == body.website_id).first() if not website: raise HTTPException(404, "website not found") if body.algorithm not in ALGORITHMS: raise HTTPException(400, "不支持的算法") _ensure_unique_target(db, body.website_id, body.target_group_id) row = WebsiteGroupBinding( website_id=body.website_id, target_group_id=body.target_group_id, target_group_name=body.target_group_name, source_groups_json=json.dumps([item.model_dump() for item in body.source_groups], ensure_ascii=False), percent=str(body.percent), algorithm=body.algorithm, enabled=body.enabled, ) db.add(row) db.commit() db.refresh(row) try: sync_binding(db, row, write=True) except Exception as exc: logger.exception("initial website sync failed for binding %s: %s", row.id, exc) return _binding_response(db, row) 
@router.put("/api/group-bindings/{bid}", response_model=BindingResponse)
def update_binding(bid: int, body: BindingUpdate, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Partially update a binding, then attempt a best-effort sync.

    Fields sent as None are left unchanged.

    Raises:
        HTTPException: 404 unknown binding or website, 400 unsupported algorithm
            or duplicate target group.
    """
    row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
    if not row:
        raise HTTPException(404, "binding not found")
    data = body.model_dump(exclude_none=True)
    if "website_id" in data and not db.query(Website).filter(Website.id == data["website_id"]).first():
        raise HTTPException(404, "website not found")
    if "algorithm" in data and data["algorithm"] not in ALGORITHMS:
        raise HTTPException(400, "不支持的算法")
    # Uniqueness is checked against the values the row will have after the update.
    next_website_id = int(data.get("website_id", row.website_id))
    next_target_group_id = str(data.get("target_group_id", row.target_group_id))
    _ensure_unique_target(db, next_website_id, next_target_group_id, exclude_id=row.id)
    if "source_groups" in data:
        data.pop("source_groups")
        # exclude_none above also strips None-valued fields *inside* each source
        # group. Re-dump this field without it so the stored JSON has the same
        # shape that create_binding writes.
        sources = body.model_dump(include={"source_groups"})["source_groups"]
        row.source_groups_json = json.dumps(sources, ensure_ascii=False)
    if "percent" in data:
        row.percent = str(data.pop("percent"))
    for key, value in data.items():
        setattr(row, key, value)
    row.updated_at = datetime.now(timezone.utc)
    db.commit()
    db.refresh(row)
    # Captured before syncing so logging never needs to reload the row from a
    # session that a failed sync may have left unusable.
    binding_id = row.id
    try:
        sync_binding(db, row, write=True)
    except Exception as exc:
        # The update is already committed; roll back half-done sync work so the
        # session can still build the response.
        logger.exception("sync failed after updating binding %s: %s", binding_id, exc)
        db.rollback()
    return _binding_response(db, row)


@router.delete("/api/group-bindings/{bid}", status_code=204)
def delete_binding(bid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Delete a binding (404 if missing)."""
    row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
    if not row:
        raise HTTPException(404, "binding not found")
    db.delete(row)
    db.commit()


@router.post("/api/group-bindings/{bid}/sync-now", response_model=WebsiteSyncLogResponse)
def sync_now(bid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Run a writing sync for one binding and return the sync log it produced."""
    row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
    if not row:
        raise HTTPException(404, "binding not found")
    return _log_response(sync_binding(db, row, write=True))
@router.get("/api/website-sync-logs", response_model=List[WebsiteSyncLogResponse])
def list_sync_logs(
    website_id: int | None = Query(None),
    binding_id: int | None = Query(None),
    # Bounds are validated so a negative LIMIT cannot bypass the 200 cap and a
    # negative OFFSET cannot reach the database; FastAPI rejects them with 422.
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Return a page of sync logs, newest first.

    Args:
        website_id: Only logs for this website, when given.
        binding_id: Only logs for this binding, when given.
        limit: Page size (1-200).
        offset: Number of rows to skip.
    """
    q = db.query(WebsiteSyncLog)
    if website_id is not None:
        q = q.filter(WebsiteSyncLog.website_id == website_id)
    if binding_id is not None:
        q = q.filter(WebsiteSyncLog.binding_id == binding_id)
    # id is a tiebreaker so offset pagination is stable when timestamps collide.
    rows = (
        q.order_by(WebsiteSyncLog.created_at.desc(), WebsiteSyncLog.id.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )
    return [_log_response(row) for row in rows]