283 lines
9.9 KiB
Python
283 lines
9.9 KiB
Python
"""Upstream management CRUD + test + check-now + snapshots."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from typing import List
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.database import get_db
|
|
from app.models.admin_user import AdminUser
|
|
from app.models.upstream import Upstream
|
|
from app.models.snapshot import UpstreamRateSnapshot
|
|
from app.schemas.upstream import (
|
|
UpstreamCreate, UpstreamUpdate, UpstreamResponse, SnapshotResponse, TestResult
|
|
)
|
|
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot
|
|
from app.services.snapshot_service import diff_snapshots
|
|
from app.services import scheduler as sched_svc
|
|
from app.services import webhook_service
|
|
from app.utils.auth import get_current_user
|
|
|
|
# Router for every /api/upstreams endpoint in this module; each endpoint
# additionally depends on get_current_user.
router = APIRouter(tags=["upstreams"], prefix="/api/upstreams")
|
|
|
|
MASK = "***"
|
|
SECRET_KEYS = {"password", "token", "key", "secret"}
|
|
|
|
|
|
def _mask_auth_config(auth_type: str, cfg: dict) -> dict:
|
|
masked = {}
|
|
for k, v in cfg.items():
|
|
if k.lower() in SECRET_KEYS and v:
|
|
masked[k] = MASK
|
|
else:
|
|
masked[k] = v
|
|
return masked
|
|
|
|
|
|
def _to_response(u: Upstream) -> UpstreamResponse:
    """Convert an ``Upstream`` ORM row into its API response model.

    The stored auth-config JSON (empty object when unset) is decoded and run
    through ``_mask_auth_config`` before being exposed as
    ``auth_config_masked``; every other field is copied verbatim.
    """
    decoded_cfg = json.loads(u.auth_config_json or "{}")
    passthrough_fields = (
        "id",
        "name",
        "base_url",
        "api_prefix",
        "auth_type",
        "rate_endpoint",
        "groups_endpoint",
        "enabled",
        "check_interval_seconds",
        "timeout_seconds",
        "last_status",
        "last_checked_at",
        "last_error",
        "created_at",
        "updated_at",
    )
    payload = {field: getattr(u, field) for field in passthrough_fields}
    payload["auth_config_masked"] = _mask_auth_config(u.auth_type, decoded_cfg)
    return UpstreamResponse(**payload)
|
|
|
|
|
|
@router.get("", response_model=List[UpstreamResponse])
def list_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)):
    """List every configured upstream, ordered by id."""
    ordered = db.query(Upstream).order_by(Upstream.id)
    return list(map(_to_response, ordered.all()))
|
|
|
|
|
|
@router.post("", response_model=UpstreamResponse, status_code=201)
def create_upstream(
    body: UpstreamCreate,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Create an upstream and register its periodic check with the scheduler.

    A trailing ``/`` is stripped from ``base_url`` and ``auth_config`` is
    stored as JSON. Responds 201 with the masked representation.
    """
    fields = dict(
        name=body.name,
        base_url=body.base_url.rstrip("/"),
        api_prefix=body.api_prefix,
        auth_type=body.auth_type,
        auth_config_json=json.dumps(body.auth_config, ensure_ascii=False),
        rate_endpoint=body.rate_endpoint,
        groups_endpoint=body.groups_endpoint,
        enabled=body.enabled,
        check_interval_seconds=body.check_interval_seconds,
        timeout_seconds=body.timeout_seconds,
    )
    upstream = Upstream(**fields)
    db.add(upstream)
    db.commit()
    # Reload so database-assigned values (e.g. the new id) are populated.
    db.refresh(upstream)
    sched_svc.refresh_upstream(
        upstream.id, upstream.check_interval_seconds, upstream.enabled
    )
    return _to_response(upstream)
|
|
|
|
|
|
@router.get("/{uid}", response_model=UpstreamResponse)
def get_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return one upstream by id.

    Raises:
        HTTPException: 404 if no upstream has this id.
    """
    found = db.query(Upstream).filter(Upstream.id == uid).first()
    if not found:
        raise HTTPException(404, "upstream not found")
    return _to_response(found)
|
|
|
|
|
|
@router.put("/{uid}", response_model=UpstreamResponse)
def update_upstream(
    uid: int,
    body: UpstreamUpdate,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Partially update an upstream and resync its scheduler job.

    Only fields supplied with a non-None value are applied. ``auth_config``
    is merged key-by-key into the stored config rather than replacing it, and
    any incoming value equal to ``MASK`` is skipped so a client can echo back
    the masked config it received without overwriting the real secret.
    A trailing ``/`` is stripped from ``base_url``.

    Raises:
        HTTPException: 404 if no upstream has this id.
    """
    target = db.query(Upstream).filter(Upstream.id == uid).first()
    if not target:
        raise HTTPException(404, "upstream not found")

    changes = body.model_dump(exclude_none=True)

    # exclude_none above guarantees a present auth_config is never None.
    incoming_cfg = changes.pop("auth_config", None)
    if incoming_cfg is not None:
        merged = json.loads(target.auth_config_json or "{}")
        merged.update(
            {key: val for key, val in incoming_cfg.items() if val != MASK}
        )
        target.auth_config_json = json.dumps(merged, ensure_ascii=False)

    if "base_url" in changes:
        changes["base_url"] = changes["base_url"].rstrip("/")

    for attr, value in changes.items():
        setattr(target, attr, value)
    target.updated_at = datetime.now(timezone.utc)

    db.commit()
    db.refresh(target)
    sched_svc.refresh_upstream(
        target.id, target.check_interval_seconds, target.enabled
    )
    return _to_response(target)
|
|
|
|
|
|
@router.delete("/{uid}", status_code=204)
def delete_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Delete an upstream and unregister its scheduled check.

    The scheduler job is removed before the row is deleted. If the delete
    fails to commit, the session is rolled back and the job is re-registered
    with the upstream's previous interval/enabled settings, so the scheduler
    does not silently stop checking an upstream that still exists. The
    original exception is then re-raised.

    Raises:
        HTTPException: 404 if no upstream has this id.
    """
    u = db.query(Upstream).filter(Upstream.id == uid).first()
    if not u:
        raise HTTPException(404, "upstream not found")
    # Capture scheduling settings up front so restoring the job after a
    # rollback does not depend on reloading the (expired) instance.
    interval = u.check_interval_seconds
    enabled = u.enabled
    sched_svc.refresh_upstream(uid, 0, False)  # remove job
    try:
        db.delete(u)
        db.commit()
    except Exception:
        db.rollback()
        sched_svc.refresh_upstream(uid, interval, enabled)  # restore job
        raise
|
|
|
|
|
|
@router.post("/{uid}/test", response_model=TestResult)
def test_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Try logging in to an upstream and fetching its groups.

    Any exception raised while talking to the upstream (or building the
    success result) is reported as ``success=False`` with the error text in
    ``detail``. It is not propagated.

    Raises:
        HTTPException: 404 if no upstream has this id.
    """
    upstream = db.query(Upstream).filter(Upstream.id == uid).first()
    if not upstream:
        raise HTTPException(404, "upstream not found")

    cfg = json.loads(upstream.auth_config_json or "{}")
    client = UpstreamClient(
        base_url=upstream.base_url,
        api_prefix=upstream.api_prefix,
        auth_type=upstream.auth_type,
        auth_config=cfg,
        timeout=float(upstream.timeout_seconds),
    )
    try:
        client.login()
        group_list = client.get_available_groups(upstream.groups_endpoint)
        summary = f"连接成功,获取到 {len(group_list)} 个分组"
        return TestResult(success=True, message=summary)
    except Exception as err:
        return TestResult(success=False, message="连接失败", detail=str(err))
|
|
|
|
|
|
@router.post("/{uid}/check-now", response_model=TestResult)
def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Run one rate check against an upstream immediately.

    On failure, ``consecutive_failures`` is incremented and ``last_error`` and
    ``last_checked_at`` are recorded and committed. ``last_status`` is left
    unchanged. The endpoint returns ``success=False`` instead of raising.

    On success, a new snapshot row is stored and the upstream is marked
    ``"healthy"`` with its failure counters cleared. If it was previously
    ``"unhealthy"``, a recovery webhook is sent. If ``diff_snapshots``
    reports changes versus the latest prior snapshot, a rate-changed webhook
    is sent.

    Raises:
        HTTPException: 404 if no upstream has this id.
    """
    upstream = db.query(Upstream).filter(Upstream.id == uid).first()
    if not upstream:
        raise HTTPException(404, "upstream not found")

    cfg = json.loads(upstream.auth_config_json or "{}")
    client = UpstreamClient(
        base_url=upstream.base_url,
        api_prefix=upstream.api_prefix,
        auth_type=upstream.auth_type,
        auth_config=cfg,
        timeout=float(upstream.timeout_seconds),
    )

    # Fetch phase: any error is recorded on the upstream, not raised.
    try:
        client.login()
        group_list = client.get_available_groups(upstream.groups_endpoint)
        rates = client.get_group_rates(upstream.rate_endpoint)
        current = build_snapshot(
            upstream.id, upstream.base_url, upstream.api_prefix, group_list, rates
        )
    except Exception as err:
        upstream.consecutive_failures = (upstream.consecutive_failures or 0) + 1
        upstream.last_error = str(err)
        upstream.last_checked_at = datetime.now(timezone.utc)
        db.commit()
        return TestResult(success=False, message="检测失败", detail=str(err))

    # Compare against the most recently captured snapshot (if any).
    latest = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == uid)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .first()
    )
    previous = json.loads(latest.snapshot_json) if latest else None
    changes = diff_snapshots(previous, current)

    db.add(
        UpstreamRateSnapshot(
            upstream_id=uid,
            snapshot_json=json.dumps(current, ensure_ascii=False),
            captured_at=datetime.now(timezone.utc),
        )
    )
    recovered = upstream.last_status == "unhealthy"
    upstream.last_status = "healthy"
    upstream.last_checked_at = datetime.now(timezone.utc)
    upstream.last_error = None
    upstream.consecutive_failures = 0
    db.commit()

    if recovered:
        webhook_service.send_status_event(
            db, upstream.id, upstream.name, upstream.base_url, "upstream_recovered"
        )
    if changes:
        webhook_service.send_rate_changed(
            db, upstream.id, upstream.name, upstream.base_url, changes
        )

    head = f"检测成功,{len(group_list)} 个分组"
    if changes:
        tail = f",发现 {len(changes)} 处倍率变化"
    elif previous is None:
        tail = ",初始化快照完成"
    else:
        tail = ",无变化"
    return TestResult(success=True, message=head + tail)
|
|
|
|
|
|
@router.get("/{uid}/snapshots/latest", response_model=SnapshotResponse)
def latest_snapshot(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return the most recently captured snapshot for an upstream.

    Raises:
        HTTPException: 404 if no snapshot is stored for ``uid``. The upstream
            itself is not looked up, so an unknown id also yields this 404.
    """
    newest_first = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == uid)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
    )
    newest = newest_first.first()
    if not newest:
        raise HTTPException(404, "no snapshot found")
    decoded = json.loads(newest.snapshot_json)
    return SnapshotResponse(
        id=newest.id,
        upstream_id=newest.upstream_id,
        snapshot=decoded,
        captured_at=newest.captured_at,
    )
|
|
|
|
|
|
from fastapi import Query as QueryParam
|
|
|
|
@router.get("/{uid}/snapshots", response_model=List[SnapshotResponse])
def list_snapshots(
    uid: int,
    limit: int = QueryParam(20, ge=0, le=100),
    offset: int = QueryParam(0, ge=0),
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Return paginated snapshot history, newest first, with change summaries.

    Each returned snapshot dict gets two extra keys for the frontend:

    * ``_groups_count``: ``len(snapshot["groups"])`` (0 if the key is absent).
    * ``_changes_count``: the number of entries ``diff_snapshots`` reports
      against the next-older snapshot, or ``None`` when no older snapshot
      exists (i.e. this is the first one ever captured).

    Args:
        uid: Upstream id. An unknown id simply yields an empty list.
        limit: Page size, 0..100. Negative values are rejected because they
            would produce a negative SQL LIMIT and a negative slice.
        offset: Number of newest snapshots to skip. Must be >= 0; negative
            OFFSET is rejected by some databases.
    """
    rows = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == uid)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .offset(offset)
        .limit(limit + 1)  # fetch one extra to get the "previous" for diffing
        .all()
    )
    # rows are newest-first, so rows[i + 1] is older than rows[i].
    results = []
    for i, row in enumerate(rows[:limit]):
        snap = json.loads(row.snapshot_json)
        changes_count: int | None = None  # None means first ever snapshot
        if i + 1 < len(rows):
            older = json.loads(rows[i + 1].snapshot_json)
            changes_count = len(diff_snapshots(older, snap))
        # Embed a lightweight summary so the frontend can display it.
        snap["_groups_count"] = len(snap.get("groups", {}))
        snap["_changes_count"] = changes_count
        results.append(
            SnapshotResponse(
                id=row.id,
                upstream_id=row.upstream_id,
                snapshot=snap,
                captured_at=row.captured_at,
            )
        )
    return results
|