Files
SmartUp/backend/app/routers/websites.py
T
SmartUp Developer ad16618406 fix: address multiple code audit findings
- CORS: replace wildcard with explicit origin list from CORS_ORIGINS env
- Auth: enforce strong defaults, JWT blacklist (RevokedToken model), login rate limiting
- Auth: validate password length before bcrypt (72-byte limit)
- Scheduler: single-threaded worker to mitigate SQLite write contention
- Scheduler: graceful shutdown (wait=True)
- Snapshots: add prune_snapshots() with configurable retention count
- Storage: isolate localStorage keys via VITE_APP_KEY prefix
- Config: add cors_origins, login_rate_limit, snapshot_retention_count settings
2026-05-17 10:52:18 +08:00

311 lines
12 KiB
Python

from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.schemas.website import (
BindingCreate,
BindingResponse,
BindingUpdate,
TestResult,
WebsiteCreate,
WebsiteGroupResponse,
WebsiteResponse,
WebsiteSyncLogResponse,
WebsiteUpdate,
)
from app.services.website_client import Sub2ApiWebsiteClient
from app.services.website_sync import binding_sources, sync_binding
from app.utils.auth import get_current_user
# Module-level logger used by the route handlers below.
logger = logging.getLogger(__name__)
# Every website / binding / sync-log endpoint in this module registers here.
router = APIRouter(tags=["websites"])

# Placeholder shown to clients instead of secret auth-config values;
# update_website treats an incoming value equal to it as "leave unchanged".
MASK = "***"
# auth_config keys whose values are masked (compared after lower-casing).
SECRET_KEYS = {
    "password",
    "token",
    "key",
    "secret",
    "api_key",
}
# Algorithm names accepted for group bindings (checked on create/update).
ALGORITHMS = {
    "max_plus_percent",
    "average_plus_percent",
    "min_plus_percent",
}
def _mask(cfg: dict) -> dict:
    """Return a copy of ``cfg`` with secret-looking values replaced by ``MASK``.

    A key counts as secret when its normalized form (lower-cased, ``-``
    replaced by ``_``) is one of ``SECRET_KEYS`` or ends with ``_<name>`` for
    some name in ``SECRET_KEYS``. Examples are ``access_token``,
    ``client_secret`` and ``x-api-key``. Exact-name matching alone let such
    keys be returned to clients in clear text.

    Falsy values (e.g. empty strings) are returned unchanged.

    Args:
        cfg: Decoded auth-config mapping.

    Returns:
        A new dict with the same keys; ``cfg`` itself is not modified.
    """
    secret_suffixes = tuple(f"_{name}" for name in SECRET_KEYS)
    masked = {}
    for key, value in cfg.items():
        normalized = str(key).lower().replace("-", "_")
        is_secret = normalized in SECRET_KEYS or normalized.endswith(secret_suffixes)
        masked[key] = MASK if is_secret and value else value
    return masked
def _website_response(row: Website) -> WebsiteResponse:
    """Serialize a ``Website`` row into its API response model.

    Columns are copied verbatim. The exception is the stored
    ``auth_config_json`` (an empty value is treated as ``{}``): it is decoded
    and passed through ``_mask`` before being exposed as
    ``auth_config_masked``.
    """
    copied_columns = (
        "id",
        "name",
        "site_type",
        "base_url",
        "api_prefix",
        "auth_type",
        "groups_endpoint",
        "group_update_endpoint",
        "enabled",
        "auto_sync_enabled",
        "timeout_seconds",
        "last_status",
        "last_checked_at",
        "last_error",
        "created_at",
        "updated_at",
    )
    payload = {column: getattr(row, column) for column in copied_columns}
    payload["auth_config_masked"] = _mask(json.loads(row.auth_config_json or "{}"))
    return WebsiteResponse(**payload)
def _binding_response(db: Session, row: WebsiteGroupBinding) -> BindingResponse:
    """Serialize a binding row, resolving the owning website's name.

    The ``Website`` is looked up by ``row.website_id``. If it is missing,
    ``website_name`` is reported as an empty string. ``percent`` is returned
    as a float, with a missing value treated as ``0``.
    """
    owner = db.query(Website).filter(Website.id == row.website_id).first()
    owner_name = owner.name if owner else ""
    sources = binding_sources(row)
    percent_value = float(row.percent or 0)
    return BindingResponse(
        id=row.id,
        website_id=row.website_id,
        website_name=owner_name,
        target_group_id=row.target_group_id,
        target_group_name=row.target_group_name,
        source_groups=sources,
        percent=percent_value,
        algorithm=row.algorithm,
        enabled=row.enabled,
        created_at=row.created_at,
        updated_at=row.updated_at,
    )
def _log_response(row: WebsiteSyncLog) -> WebsiteSyncLogResponse:
    """Serialize a sync-log row into its API response model.

    ``source_rates_json`` is decoded (empty -> ``[]``) and ``percent`` is
    converted to float (missing -> ``0``). All other columns are copied as-is.
    """
    decoded_rates = json.loads(row.source_rates_json or "[]")
    percent_value = float(row.percent or 0)
    return WebsiteSyncLogResponse(
        id=row.id,
        website_id=row.website_id,
        binding_id=row.binding_id,
        target_group_id=row.target_group_id,
        target_group_name=row.target_group_name,
        algorithm=row.algorithm,
        percent=percent_value,
        source_rates=decoded_rates,
        old_rate=row.old_rate,
        new_rate=row.new_rate,
        status=row.status,
        message=row.message,
        created_at=row.created_at,
    )
def _ensure_unique_target(db: Session, website_id: int, target_group_id: str, exclude_id: int | None = None) -> None:
    """Reject a second binding for the same (website, target group) pair.

    Args:
        db: Active session.
        website_id: Website the binding points at.
        target_group_id: Target group id on that website.
        exclude_id: Binding id to ignore (the row being updated), if any.

    Raises:
        HTTPException: 400 when another binding already targets the group.
            The message roughly reads "a target website group may only have
            one binding record".
    """
    conditions = [
        WebsiteGroupBinding.website_id == website_id,
        WebsiteGroupBinding.target_group_id == target_group_id,
    ]
    if exclude_id is not None:
        conditions.append(WebsiteGroupBinding.id != exclude_id)
    clash = db.query(WebsiteGroupBinding).filter(*conditions).first()
    if clash:
        raise HTTPException(400, "同一目标网站分组只能维护一条绑定记录")
def _client(row: Website) -> Sub2ApiWebsiteClient:
    """Build an API client configured from a ``Website`` row.

    The stored (unmasked) auth config is decoded from JSON (empty -> ``{}``).
    The per-site timeout is passed as a float.
    """
    settings = {
        "base_url": row.base_url,
        "api_prefix": row.api_prefix,
        "auth_type": row.auth_type,
        "auth_config": json.loads(row.auth_config_json or "{}"),
        "timeout": float(row.timeout_seconds),
    }
    return Sub2ApiWebsiteClient(**settings)
@router.get("/api/websites", response_model=List[WebsiteResponse])
def list_websites(db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return every configured website ordered by id, with auth config masked."""
    websites = db.query(Website).order_by(Website.id).all()
    return list(map(_website_response, websites))
@router.post("/api/websites", response_model=WebsiteResponse, status_code=201)
def create_website(body: WebsiteCreate, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Register a new website and return it with auth config masked.

    The base URL is stored without trailing slashes. The auth config is
    persisted as JSON text.

    Raises:
        HTTPException: 400 when ``site_type`` is not ``"sub2api"``.
    """
    if body.site_type != "sub2api":
        # Message: "currently only sub2api is supported".
        raise HTTPException(400, "目前只支持 sub2api")
    normalized_url = body.base_url.rstrip("/")
    auth_json = json.dumps(body.auth_config, ensure_ascii=False)
    website = Website(
        name=body.name,
        site_type=body.site_type,
        base_url=normalized_url,
        api_prefix=body.api_prefix,
        auth_type=body.auth_type,
        auth_config_json=auth_json,
        groups_endpoint=body.groups_endpoint,
        group_update_endpoint=body.group_update_endpoint,
        enabled=body.enabled,
        auto_sync_enabled=body.auto_sync_enabled,
        timeout_seconds=body.timeout_seconds,
    )
    db.add(website)
    db.commit()
    db.refresh(website)
    return _website_response(website)
@router.put("/api/websites/{wid}", response_model=WebsiteResponse)
def update_website(wid: int, body: WebsiteUpdate, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Apply a partial update to a website.

    Only fields that are not ``None`` in ``body`` are applied. An incoming
    ``auth_config`` is merged into the stored one. Entries whose value equals
    ``MASK`` are skipped, so echoing back a masked response keeps the stored
    secrets. A new ``base_url`` has trailing slashes stripped.

    Raises:
        HTTPException: 404 if the website does not exist; 400 for a
            ``site_type`` other than ``"sub2api"``.
    """
    website = db.query(Website).filter(Website.id == wid).first()
    if not website:
        raise HTTPException(404, "website not found")
    changes = body.model_dump(exclude_none=True)
    if changes.get("site_type", "sub2api") != "sub2api":
        raise HTTPException(400, "目前只支持 sub2api")
    # exclude_none guarantees a present auth_config is not None.
    new_auth = changes.pop("auth_config", None)
    if new_auth is not None:
        merged = json.loads(website.auth_config_json or "{}")
        merged.update({name: val for name, val in new_auth.items() if val != MASK})
        website.auth_config_json = json.dumps(merged, ensure_ascii=False)
    if "base_url" in changes:
        changes["base_url"] = changes["base_url"].rstrip("/")
    for attr, val in changes.items():
        setattr(website, attr, val)
    website.updated_at = datetime.now(timezone.utc)
    db.commit()
    db.refresh(website)
    return _website_response(website)
@router.delete("/api/websites/{wid}", status_code=204)
def delete_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Delete a website together with its sync logs and group bindings.

    Dependent rows are removed first with bulk deletes
    (``synchronize_session=False``). The website row follows, and everything
    is committed once.

    Raises:
        HTTPException: 404 if the website does not exist.
    """
    website = db.query(Website).filter(Website.id == wid).first()
    if not website:
        raise HTTPException(404, "website not found")
    dependents = (
        (WebsiteSyncLog, WebsiteSyncLog.website_id),
        (WebsiteGroupBinding, WebsiteGroupBinding.website_id),
    )
    for model, owner_column in dependents:
        db.query(model).filter(owner_column == wid).delete(synchronize_session=False)
    db.delete(website)
    db.commit()
@router.post("/api/websites/{wid}/test", response_model=TestResult)
def test_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Probe a website by fetching its group list and record the outcome.

    The result (``last_status``, ``last_error``, ``last_checked_at``) is
    stored on the website row. A failed probe is reported in the response
    body, not as an HTTP error.

    Only client construction and the remote call sit inside the ``try``. A
    database failure while saving the result therefore propagates instead of
    being misreported as a connection failure. Previously such a failure was
    followed by a second commit on a session already in a failed state.

    Raises:
        HTTPException: 404 if the website does not exist.
    """
    row = db.query(Website).filter(Website.id == wid).first()
    if not row:
        raise HTTPException(404, "website not found")
    try:
        groups = _client(row).get_groups(row.groups_endpoint)
        # Kept inside the try: an unusable return value still counts as a failed probe.
        group_count = len(groups)
    except Exception as exc:
        error_text = str(exc)
        logger.warning("connectivity test failed for website %s: %s", wid, error_text)
        row.last_status = "unhealthy"
        row.last_error = error_text
        row.last_checked_at = datetime.now(timezone.utc)
        db.commit()
        return TestResult(success=False, message="连接失败", detail=error_text)
    row.last_status = "healthy"
    row.last_error = None
    row.last_checked_at = datetime.now(timezone.utc)
    db.commit()
    return TestResult(success=True, message=f"连接成功,获取到 {group_count} 个分组")
@router.get("/api/websites/{wid}/groups", response_model=List[WebsiteGroupResponse])
def list_website_groups(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Fetch the group list live from the remote website.

    Raises:
        HTTPException: 404 if the website does not exist; 502 carrying the
            underlying error text if building the client or the remote call
            fails.
    """
    row = db.query(Website).filter(Website.id == wid).first()
    if not row:
        raise HTTPException(404, "website not found")
    try:
        return _client(row).get_groups(row.groups_endpoint)
    except Exception as exc:
        # The 502 is handled by FastAPI without logging, so record the cause here.
        logger.warning("fetching groups for website %s failed: %s", wid, exc)
        # Chain the original error as __cause__ so its traceback is kept.
        raise HTTPException(502, str(exc)) from exc
@router.get("/api/group-bindings", response_model=List[BindingResponse])
def list_bindings(db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return all group bindings, newest (highest id) first."""
    newest_first = db.query(WebsiteGroupBinding).order_by(WebsiteGroupBinding.id.desc())
    return [_binding_response(db, binding) for binding in newest_first.all()]
@router.post("/api/group-bindings", response_model=BindingResponse, status_code=201)
def create_binding(body: BindingCreate, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Create a group binding, then attempt an immediate sync.

    The binding is committed before syncing. The sync is best effort: if it
    raises, the error is logged and the session is rolled back, and the
    created binding is still returned. Without the rollback, a failure inside
    ``sync_binding`` (e.g. a failed flush) can leave the session unusable.
    The response lookup below would then fail too.

    Raises:
        HTTPException: 404 if the website does not exist; 400 for an
            unsupported algorithm or a duplicate (website, target group) pair.
    """
    website = db.query(Website).filter(Website.id == body.website_id).first()
    if not website:
        raise HTTPException(404, "website not found")
    if body.algorithm not in ALGORITHMS:
        raise HTTPException(400, "不支持的算法")
    _ensure_unique_target(db, body.website_id, body.target_group_id)
    row = WebsiteGroupBinding(
        website_id=body.website_id,
        target_group_id=body.target_group_id,
        target_group_name=body.target_group_name,
        source_groups_json=json.dumps([item.model_dump() for item in body.source_groups], ensure_ascii=False),
        # percent is persisted as text.
        percent=str(body.percent),
        algorithm=body.algorithm,
        enabled=body.enabled,
    )
    db.add(row)
    db.commit()
    db.refresh(row)
    # Capture the id now: after a failed sync, reading row.id may need a
    # round-trip on a broken session.
    binding_id = row.id
    try:
        sync_binding(db, row, write=True)
    except Exception as exc:
        logger.exception("initial website sync failed for binding %s: %s", binding_id, exc)
        db.rollback()
    return _binding_response(db, row)
@router.put("/api/group-bindings/{bid}", response_model=BindingResponse)
def update_binding(bid: int, body: BindingUpdate, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Apply a partial update to a binding, then attempt a re-sync.

    Only fields that are not ``None`` in ``body`` are applied. The follow-up
    sync is best effort: if it raises, the error is logged and the session
    is rolled back so the response lookup below can still run. The updated
    binding is returned.

    Raises:
        HTTPException: 404 if the binding or a newly referenced website does
            not exist; 400 for an unsupported algorithm or a duplicate
            (website, target group) pair.
    """
    row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
    if not row:
        raise HTTPException(404, "binding not found")
    data = body.model_dump(exclude_none=True)
    if "website_id" in data and not db.query(Website).filter(Website.id == data["website_id"]).first():
        raise HTTPException(404, "website not found")
    if "algorithm" in data and data["algorithm"] not in ALGORITHMS:
        raise HTTPException(400, "不支持的算法")
    # Check uniqueness against the values the row will hold after the update.
    next_website_id = int(data.get("website_id", row.website_id))
    next_target_group_id = str(data.get("target_group_id", row.target_group_id))
    _ensure_unique_target(db, next_website_id, next_target_group_id, exclude_id=row.id)
    # These columns are stored as text: JSON for source groups, str for percent.
    if "source_groups" in data:
        row.source_groups_json = json.dumps(data.pop("source_groups"), ensure_ascii=False)
    if "percent" in data:
        row.percent = str(data.pop("percent"))
    for key, value in data.items():
        setattr(row, key, value)
    row.updated_at = datetime.now(timezone.utc)
    db.commit()
    db.refresh(row)
    try:
        sync_binding(db, row, write=True)
    except Exception as exc:
        # Log with the route id (no DB access needed), then restore a usable session.
        logger.exception("sync failed after updating binding %s: %s", bid, exc)
        db.rollback()
    return _binding_response(db, row)
@router.delete("/api/group-bindings/{bid}", status_code=204)
def delete_binding(bid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Delete a single group binding.

    Sync-log rows that reference the binding are not touched here.

    Raises:
        HTTPException: 404 if the binding does not exist.
    """
    binding = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
    if not binding:
        raise HTTPException(404, "binding not found")
    db.delete(binding)
    db.commit()
@router.post("/api/group-bindings/{bid}/sync-now", response_model=WebsiteSyncLogResponse)
def sync_now(bid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Run a writing sync for one binding right away and return its log entry.

    Raises:
        HTTPException: 404 if the binding does not exist.
    """
    binding = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
    if not binding:
        raise HTTPException(404, "binding not found")
    log_entry = sync_binding(db, binding, write=True)
    return _log_response(log_entry)
@router.get("/api/website-sync-logs", response_model=List[WebsiteSyncLogResponse])
def list_sync_logs(
    website_id: int | None = Query(None),
    binding_id: int | None = Query(None),
    limit: int = Query(50, ge=0, le=200),
    offset: int = Query(0, ge=0),
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Page through sync logs, newest first, optionally filtered.

    Args:
        website_id: When given, only logs for this website.
        binding_id: When given, only logs for this binding.
        limit: Page size, 0-200. The lower bound matters: on SQLite a
            negative LIMIT means "no limit", which would bypass the cap.
        offset: Number of rows to skip; must be non-negative.
    """
    q = db.query(WebsiteSyncLog)
    # Compare against None so an explicit id is never silently ignored.
    if website_id is not None:
        q = q.filter(WebsiteSyncLog.website_id == website_id)
    if binding_id is not None:
        q = q.filter(WebsiteSyncLog.binding_id == binding_id)
    # id breaks ties between equal created_at values so paging stays stable.
    rows = (
        q.order_by(WebsiteSyncLog.created_at.desc(), WebsiteSyncLog.id.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )
    return [_log_response(row) for row in rows]