2934473770
- UpstreamClient & Sub2ApiWebsiteClient: add __enter__/__exit__ - Convert all call sites to `with Client(...) as c:` pattern - Remove unused `upstream_name`/`upstream_base_url` locals in scheduler - Fix stale _decimal_str→decimal_string in _rate_from_group
313 lines
12 KiB
Python
313 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from typing import List
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.database import get_db
|
|
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
|
|
from app.schemas.website import (
|
|
BindingCreate,
|
|
BindingResponse,
|
|
BindingUpdate,
|
|
TestResult,
|
|
WebsiteCreate,
|
|
WebsiteGroupResponse,
|
|
WebsiteResponse,
|
|
WebsiteSyncLogResponse,
|
|
WebsiteUpdate,
|
|
)
|
|
from app.services.website_client import Sub2ApiWebsiteClient
|
|
from app.services.website_sync import binding_sources, sync_binding
|
|
from app.utils.auth import get_current_user
|
|
|
|
# Module-level logger; used below for failures that are logged rather than
# returned to the client (e.g. a failed sync after creating a binding).
logger = logging.getLogger(__name__)
# Every website / group-binding endpoint in this module is registered here.
router = APIRouter(tags=["websites"])
|
|
|
|
MASK = "***"
|
|
SECRET_KEYS = {"password", "token", "key", "secret", "api_key"}
|
|
ALGORITHMS = {"max_plus_percent", "average_plus_percent", "min_plus_percent"}
|
|
|
|
|
|
def _mask(cfg: dict) -> dict:
|
|
masked = {}
|
|
for key, value in cfg.items():
|
|
masked[key] = MASK if key.lower() in SECRET_KEYS and value else value
|
|
return masked
|
|
|
|
|
|
def _website_response(row: Website) -> WebsiteResponse:
    """Convert a ``Website`` ORM row into its API response model.

    The stored auth configuration is decoded from JSON (an empty/NULL column
    counts as ``{}``) and passed through ``_mask`` so secret values are never
    echoed back to the client.
    """
    stored_auth = json.loads(row.auth_config_json or "{}")
    fields = dict(
        id=row.id,
        name=row.name,
        site_type=row.site_type,
        base_url=row.base_url,
        api_prefix=row.api_prefix,
        auth_type=row.auth_type,
        auth_config_masked=_mask(stored_auth),
        groups_endpoint=row.groups_endpoint,
        group_update_endpoint=row.group_update_endpoint,
        enabled=row.enabled,
        auto_sync_enabled=row.auto_sync_enabled,
        timeout_seconds=row.timeout_seconds,
        last_status=row.last_status,
        last_checked_at=row.last_checked_at,
        last_error=row.last_error,
        created_at=row.created_at,
        updated_at=row.updated_at,
    )
    return WebsiteResponse(**fields)
|
|
|
|
|
|
def _binding_response(db: Session, row: WebsiteGroupBinding) -> BindingResponse:
    """Build the API response model for a group binding.

    ``website_name`` comes from the owning ``Website`` and is ``""`` when that
    website cannot be found. ``percent`` is coerced to float (NULL -> 0.0) and
    ``source_groups`` is decoded via ``binding_sources``.
    """
    # Session.get() consults the session's identity map before querying, so
    # when list_bindings builds responses for many bindings of the same
    # website only the first lookup hits the database (query(...).first()
    # issued one SELECT per binding). A NULL website_id cannot match any row,
    # so skip the lookup entirely in that case.
    website = db.get(Website, row.website_id) if row.website_id is not None else None
    return BindingResponse(
        id=row.id,
        website_id=row.website_id,
        website_name=website.name if website else "",
        target_group_id=row.target_group_id,
        target_group_name=row.target_group_name,
        source_groups=binding_sources(row),
        percent=float(row.percent or 0),
        algorithm=row.algorithm,
        enabled=row.enabled,
        created_at=row.created_at,
        updated_at=row.updated_at,
    )
|
|
|
|
|
|
def _log_response(row: WebsiteSyncLog) -> WebsiteSyncLogResponse:
    """Convert a ``WebsiteSyncLog`` row into its API response model.

    ``percent`` is coerced to float (NULL -> 0.0), and ``source_rates_json`` is
    decoded from JSON, with an empty/NULL column treated as an empty list.
    """
    pct = float(row.percent or 0)
    rates = json.loads(row.source_rates_json or "[]")
    return WebsiteSyncLogResponse(
        id=row.id,
        website_id=row.website_id,
        binding_id=row.binding_id,
        target_group_id=row.target_group_id,
        target_group_name=row.target_group_name,
        algorithm=row.algorithm,
        percent=pct,
        source_rates=rates,
        old_rate=row.old_rate,
        new_rate=row.new_rate,
        status=row.status,
        message=row.message,
        created_at=row.created_at,
    )
|
|
|
|
|
|
def _ensure_unique_target(db: Session, website_id: int, target_group_id: str, exclude_id: int | None = None) -> None:
    """Reject a second binding for the same (website, target group) pair.

    Args:
        db: active ORM session.
        website_id: website the binding belongs to.
        target_group_id: target group on that website.
        exclude_id: id of the binding being updated, so it does not collide
            with itself; ``None`` when creating a new binding.

    Raises:
        HTTPException: 400 when another binding already uses the pair
            (message: "only one binding may be kept per target group").
    """
    criteria = [
        WebsiteGroupBinding.website_id == website_id,
        WebsiteGroupBinding.target_group_id == target_group_id,
    ]
    if exclude_id is not None:
        criteria.append(WebsiteGroupBinding.id != exclude_id)
    clash = db.query(WebsiteGroupBinding).filter(*criteria).first()
    if clash:
        raise HTTPException(400, "同一目标网站分组只能维护一条绑定记录")
|
|
|
|
|
|
def _client(row: Website) -> Sub2ApiWebsiteClient:
    """Create a ``Sub2ApiWebsiteClient`` from a website row's stored settings.

    The auth config is decoded from JSON (empty/NULL -> ``{}``) and the timeout
    is converted to float. Callers in this module use the result as a context
    manager: ``with _client(row) as c: ...``.
    """
    settings = {
        "base_url": row.base_url,
        "api_prefix": row.api_prefix,
        "auth_type": row.auth_type,
        "auth_config": json.loads(row.auth_config_json or "{}"),
        "timeout": float(row.timeout_seconds),
    }
    return Sub2ApiWebsiteClient(**settings)
|
|
|
|
|
|
@router.get("/api/websites", response_model=List[WebsiteResponse])
def list_websites(db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return every configured website, ordered by id, with secrets masked."""
    websites = db.query(Website).order_by(Website.id).all()
    return list(map(_website_response, websites))
|
|
|
|
|
|
@router.post("/api/websites", response_model=WebsiteResponse, status_code=201)
def create_website(body: WebsiteCreate, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Create a new website entry and return it (secrets masked).

    The base URL is stored without trailing slashes, and the auth config is
    serialized to JSON (non-ASCII kept as-is).

    Raises:
        HTTPException: 400 if ``site_type`` is not ``"sub2api"``
            (message: "only sub2api is supported for now").
    """
    if body.site_type != "sub2api":
        raise HTTPException(400, "目前只支持 sub2api")

    normalized_url = body.base_url.rstrip("/")
    auth_json = json.dumps(body.auth_config, ensure_ascii=False)
    website = Website(
        name=body.name,
        site_type=body.site_type,
        base_url=normalized_url,
        api_prefix=body.api_prefix,
        auth_type=body.auth_type,
        auth_config_json=auth_json,
        groups_endpoint=body.groups_endpoint,
        group_update_endpoint=body.group_update_endpoint,
        enabled=body.enabled,
        auto_sync_enabled=body.auto_sync_enabled,
        timeout_seconds=body.timeout_seconds,
    )
    db.add(website)
    db.commit()
    db.refresh(website)
    return _website_response(website)
|
|
|
|
|
|
@router.put("/api/websites/{wid}", response_model=WebsiteResponse)
def update_website(wid: int, body: WebsiteUpdate, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Partially update a website; fields that are None in *body* are left alone.

    ``auth_config`` is merged key by key into the stored config. An incoming
    value equal to ``MASK`` (what responses show for secrets) is skipped, so a
    client can send back a masked response without wiping the stored secret.

    Raises:
        HTTPException: 404 if the website does not exist; 400 if ``site_type``
            is set to anything other than ``"sub2api"``.
    """
    website = db.query(Website).filter(Website.id == wid).first()
    if not website:
        raise HTTPException(404, "website not found")

    changes = body.model_dump(exclude_none=True)
    # Absent site_type defaults to the supported value, so only an explicit
    # unsupported value is rejected.
    if changes.get("site_type", "sub2api") != "sub2api":
        raise HTTPException(400, "目前只支持 sub2api")

    incoming_auth = changes.pop("auth_config", None)
    if incoming_auth is not None:
        merged = json.loads(website.auth_config_json or "{}")
        merged.update({name: val for name, val in incoming_auth.items() if val != MASK})
        website.auth_config_json = json.dumps(merged, ensure_ascii=False)

    if "base_url" in changes:
        changes["base_url"] = changes["base_url"].rstrip("/")

    for attr, new_value in changes.items():
        setattr(website, attr, new_value)
    website.updated_at = datetime.now(timezone.utc)

    db.commit()
    db.refresh(website)
    return _website_response(website)
|
|
|
|
|
|
@router.delete("/api/websites/{wid}", status_code=204)
def delete_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Delete a website along with its sync logs and group bindings.

    Child rows are bulk-deleted first (sync logs, then bindings), then the
    website row itself. All of it is committed by the single ``db.commit()``.

    Raises:
        HTTPException: 404 if the website does not exist.
    """
    website = db.query(Website).filter(Website.id == wid).first()
    if not website:
        raise HTTPException(404, "website not found")
    for child_model in (WebsiteSyncLog, WebsiteGroupBinding):
        db.query(child_model).filter(child_model.website_id == wid).delete(synchronize_session=False)
    db.delete(website)
    db.commit()
|
|
|
|
|
|
@router.post("/api/websites/{wid}/test", response_model=TestResult)
def test_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Probe a website's groups endpoint and record its health status.

    Upstream failures do not raise. They are returned as
    ``TestResult(success=False, ...)``. In both the success and failure cases,
    ``last_status``, ``last_error`` and ``last_checked_at`` on the website row
    are updated and committed.

    Raises:
        HTTPException: 404 if the website does not exist.
    """
    row = db.query(Website).filter(Website.id == wid).first()
    if not row:
        raise HTTPException(404, "website not found")

    # Only the upstream call (and reading its result) sits inside the try.
    # Previously the "healthy" db.commit() was inside it too, so a database
    # error was misreported as a connection failure, and the handler then
    # tried to commit again on a session that needed a rollback.
    try:
        with _client(row) as c:
            groups = c.get_groups(row.groups_endpoint)
        group_count = len(groups)
    except Exception as exc:
        logger.warning("connection test failed for website %s: %s", wid, exc)
        row.last_status = "unhealthy"
        row.last_error = str(exc)
        row.last_checked_at = datetime.now(timezone.utc)
        db.commit()
        return TestResult(success=False, message="连接失败", detail=str(exc))

    row.last_status = "healthy"
    row.last_error = None
    row.last_checked_at = datetime.now(timezone.utc)
    db.commit()
    return TestResult(success=True, message=f"连接成功,获取到 {group_count} 个分组")
|
|
|
|
|
|
@router.get("/api/websites/{wid}/groups", response_model=List[WebsiteGroupResponse])
def list_website_groups(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Fetch the website's current group list live from its groups endpoint.

    Raises:
        HTTPException: 404 if the website does not exist; 502 if the upstream
            request fails, with the upstream error text as the detail.
    """
    row = db.query(Website).filter(Website.id == wid).first()
    if not row:
        raise HTTPException(404, "website not found")
    try:
        with _client(row) as c:
            return c.get_groups(row.groups_endpoint)
    except Exception as exc:
        # The 502 alone left no server-side record of what went wrong upstream.
        logger.warning("fetching groups failed for website %s: %s", wid, exc)
        raise HTTPException(502, str(exc)) from exc
|
|
|
|
|
|
@router.get("/api/group-bindings", response_model=List[BindingResponse])
def list_bindings(db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return all group bindings, newest (highest id) first."""
    newest_first = WebsiteGroupBinding.id.desc()
    bindings = db.query(WebsiteGroupBinding).order_by(newest_first).all()
    return [_binding_response(db, binding) for binding in bindings]
|
|
|
|
|
|
@router.post("/api/group-bindings", response_model=BindingResponse, status_code=201)
def create_binding(body: BindingCreate, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Create a group binding, then immediately attempt a first sync.

    The binding is committed before syncing, so it exists even if the initial
    sync fails. Such a failure is logged and does not fail the request.

    Raises:
        HTTPException: 404 if the website does not exist; 400 for an
            unsupported algorithm or a duplicate (website, target group) pair.
    """
    website = db.query(Website).filter(Website.id == body.website_id).first()
    if not website:
        raise HTTPException(404, "website not found")
    if body.algorithm not in ALGORITHMS:
        raise HTTPException(400, "不支持的算法")
    _ensure_unique_target(db, body.website_id, body.target_group_id)
    row = WebsiteGroupBinding(
        website_id=body.website_id,
        target_group_id=body.target_group_id,
        target_group_name=body.target_group_name,
        source_groups_json=json.dumps([item.model_dump() for item in body.source_groups], ensure_ascii=False),
        percent=str(body.percent),
        algorithm=body.algorithm,
        enabled=body.enabled,
    )
    db.add(row)
    db.commit()
    db.refresh(row)

    # Capture the id now: after a failed sync the row may be expired and the
    # session unusable, so reading row.id in the handler could itself raise.
    binding_id = row.id
    try:
        sync_binding(db, row, write=True)
    except Exception as exc:
        logger.exception("initial website sync failed for binding %s: %s", binding_id, exc)
        # Discard whatever the failed sync left pending. Otherwise the queries
        # in _binding_response can fail on a session that needs a rollback.
        # The binding itself was already committed above.
        db.rollback()
    return _binding_response(db, row)
|
|
|
|
|
|
@router.put("/api/group-bindings/{bid}", response_model=BindingResponse)
def update_binding(bid: int, body: BindingUpdate, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Partially update a group binding, then re-sync it.

    Fields that are None in *body* are left unchanged. ``source_groups`` is
    stored as JSON and ``percent`` as a string, as in ``create_binding``. A
    sync failure after the update is logged and does not fail the request.

    Raises:
        HTTPException: 404 if the binding, or a newly given website, does not
            exist; 400 for an unsupported algorithm or a duplicate target group.
    """
    row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
    if not row:
        raise HTTPException(404, "binding not found")
    data = body.model_dump(exclude_none=True)
    if "website_id" in data and not db.query(Website).filter(Website.id == data["website_id"]).first():
        raise HTTPException(404, "website not found")
    if "algorithm" in data and data["algorithm"] not in ALGORITHMS:
        raise HTTPException(400, "不支持的算法")

    # Check uniqueness against the values the row will have after the update.
    next_website_id = int(data.get("website_id", row.website_id))
    next_target_group_id = str(data.get("target_group_id", row.target_group_id))
    _ensure_unique_target(db, next_website_id, next_target_group_id, exclude_id=row.id)

    if "source_groups" in data:
        row.source_groups_json = json.dumps(data.pop("source_groups"), ensure_ascii=False)
    if "percent" in data:
        row.percent = str(data.pop("percent"))
    for key, value in data.items():
        setattr(row, key, value)
    row.updated_at = datetime.now(timezone.utc)
    db.commit()
    db.refresh(row)

    # Capture the id now: after a failed sync the row may be expired and the
    # session unusable, so reading row.id in the handler could itself raise.
    binding_id = row.id
    try:
        sync_binding(db, row, write=True)
    except Exception as exc:
        logger.exception("sync failed after updating binding %s: %s", binding_id, exc)
        # Discard whatever the failed sync left pending. Otherwise the queries
        # in _binding_response can fail on a session that needs a rollback.
        # The update itself was already committed above.
        db.rollback()
    return _binding_response(db, row)
|
|
|
|
|
|
@router.delete("/api/group-bindings/{bid}", status_code=204)
def delete_binding(bid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Delete one group binding.

    Unlike ``delete_website``, this does not delete ``WebsiteSyncLog`` rows
    that reference the binding.

    Raises:
        HTTPException: 404 if the binding does not exist.
    """
    # NOTE(review): if WebsiteSyncLog.binding_id carries a foreign key to this
    # table, existing logs could block the delete. Confirm against the model.
    binding = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
    if not binding:
        raise HTTPException(404, "binding not found")
    db.delete(binding)
    db.commit()
|
|
|
|
|
|
@router.post("/api/group-bindings/{bid}/sync-now", response_model=WebsiteSyncLogResponse)
def sync_now(bid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Sync one binding immediately (with write enabled) and return its log entry.

    Raises:
        HTTPException: 404 if the binding does not exist.
    """
    binding = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
    if not binding:
        raise HTTPException(404, "binding not found")
    log_row = sync_binding(db, binding, write=True)
    return _log_response(log_row)
|
|
|
|
|
|
@router.get("/api/website-sync-logs", response_model=List[WebsiteSyncLogResponse])
def list_sync_logs(
    website_id: int | None = Query(None),
    binding_id: int | None = Query(None),
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Return one page of sync logs, newest first, optionally filtered.

    Args:
        website_id: when given, only logs for this website.
        binding_id: when given, only logs for this binding.
        limit: page size, 1-200. The lower bound matters: SQLite treats a
            negative LIMIT as "no limit", which would bypass the 200-row cap.
        offset: number of rows to skip; must be >= 0.
    """
    q = db.query(WebsiteSyncLog)
    # Compare against None so that an explicit id of 0 filters instead of
    # being silently ignored as falsy.
    if website_id is not None:
        q = q.filter(WebsiteSyncLog.website_id == website_id)
    if binding_id is not None:
        q = q.filter(WebsiteSyncLog.binding_id == binding_id)
    # id breaks ties between logs with identical timestamps, so consecutive
    # pages neither repeat nor skip rows.
    rows = (
        q.order_by(WebsiteSyncLog.created_at.desc(), WebsiteSyncLog.id.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )
    return [_log_response(row) for row in rows]
|