594 lines
23 KiB
Python
594 lines
23 KiB
Python
"""Upstream management CRUD + test + check-now + snapshots."""
|
||
from __future__ import annotations

import json
import logging
import re
import threading
from datetime import datetime, timezone
from typing import Any, List

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session

from app.database import get_db
from app.models.admin_user import AdminUser
from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.snapshot import UpstreamRateSnapshot
from app.schemas.upstream import (
    GenerateKeysByGroupsRequest,
    GenerateKeysByGroupsResponse,
    GeneratedUpstreamKeyResponse,
    UpstreamCreate, UpstreamUpdate, UpstreamResponse, SnapshotResponse, TestResult
)
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot, mask_secret, _extract_key_value
from app.services.snapshot_service import diff_snapshots
from app.services import scheduler as sched_svc
from app.services import webhook_service
from app.services import website_sync
from app.utils.auth import get_current_user

logger = logging.getLogger(__name__)
|
||
|
||
# Every route defined in this module is served under /api/upstreams.
router = APIRouter(prefix="/api/upstreams", tags=["upstreams"])

# Placeholder sent to clients instead of a secret auth-config value.
MASK = "***"

# Auth-config key names (matched lower-cased) whose values are replaced by MASK.
SECRET_KEYS = {"password", "token", "key", "secret"}
|
||
|
||
|
||
def _group_id(group: dict) -> str:
    """Return a string identifier for an upstream group payload.

    The first of ``id``, ``group_id`` or ``groupId`` that holds a value other
    than ``None`` wins. Otherwise the group's ``name`` / ``group_name`` is
    used, and an empty string is returned when neither is set.
    """
    candidates = (group.get(field) for field in ("id", "group_id", "groupId"))
    explicit = next((value for value in candidates if value is not None), None)
    if explicit is not None:
        return str(explicit)
    fallback = group.get("name") or group.get("group_name") or ""
    return str(fallback)
|
||
|
||
|
||
def _group_name(group: dict, gid: str) -> str:
    """Return the group's display name, falling back to ``gid``.

    ``name`` is preferred over ``group_name``; empty values are skipped.
    """
    label = group.get("name")
    if not label:
        label = group.get("group_name")
    if not label:
        label = gid
    return str(label)
|
||
|
||
|
||
def _key_response(row: UpstreamGeneratedKey, include_value: bool = False) -> GeneratedUpstreamKeyResponse:
    """Convert a stored generated-key row into its API response model.

    The plaintext ``key_value`` is only exposed when ``include_value`` is
    true; ``has_key_value`` always reports whether the row stores one.
    """
    # Attributes copied from the row onto the response unchanged.
    mirrored = (
        "id", "upstream_id", "group_id", "group_name", "key_id", "key_name",
        "masked_key", "status", "error", "imported_website_id",
        "imported_account_id", "imported_at", "created_at", "updated_at",
    )
    payload = {attr: getattr(row, attr) for attr in mirrored}
    payload["key_value"] = row.key_value if include_value else None
    payload["has_key_value"] = bool(row.key_value)
    return GeneratedUpstreamKeyResponse(**payload)
|
||
|
||
|
||
def _mask_auth_config(auth_type: str, cfg: dict) -> dict:
    """Return a copy of ``cfg`` with secret values replaced by ``MASK``.

    A value is masked when its key, lower-cased, is a member of
    ``SECRET_KEYS`` and the value itself is truthy. ``auth_type`` is accepted
    but not consulted.
    """
    return {
        name: (MASK if name.lower() in SECRET_KEYS and value else value)
        for name, value in cfg.items()
    }
|
||
|
||
|
||
def _extract_plaintext_key(payload: dict[str, Any] | None) -> str:
    """Return the plaintext key found in ``payload``, or ``""`` if unusable.

    The empty string is returned for non-dict payloads, when
    ``_extract_key_value`` yields nothing, and when the extracted value
    contains ``*`` (presumably a masked placeholder rather than a real key).
    """
    if isinstance(payload, dict):
        candidate = _extract_key_value(payload)
        if candidate and "*" not in candidate:
            return candidate
    return ""
|
||
|
||
|
||
def _to_response(u: Upstream) -> UpstreamResponse:
    """Serialise an ``Upstream`` row into the response schema.

    The stored auth configuration JSON is decoded and passed through
    ``_mask_auth_config`` before being exposed. Falsy balance settings are
    replaced by neutral defaults (empty strings, divisor ``1.0``).
    """
    stored_auth = json.loads(u.auth_config_json or "{}")

    # Columns copied onto the response unchanged.
    verbatim = (
        "id", "name", "base_url", "api_prefix", "auth_type",
        "rate_endpoint", "groups_endpoint", "enabled",
        "check_interval_seconds", "timeout_seconds",
        "last_status", "last_checked_at", "last_error",
        "balance", "balance_updated_at", "balance_alert_threshold",
        "created_at", "updated_at",
    )
    fields = {column: getattr(u, column) for column in verbatim}
    fields.update(
        auth_config_masked=_mask_auth_config(u.auth_type, stored_auth),
        balance_endpoint=u.balance_endpoint or "",
        balance_response_path=u.balance_response_path or "",
        balance_divisor=u.balance_divisor or 1.0,
    )
    return UpstreamResponse(**fields)
|
||
|
||
|
||
@router.get("", response_model=List[UpstreamResponse])
def list_upstreams(db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return every upstream, ordered by id, in response form."""
    upstreams = db.query(Upstream).order_by(Upstream.id).all()
    return list(map(_to_response, upstreams))
|
||
|
||
|
||
@router.get("/{uid}/generated-keys", response_model=List[GeneratedUpstreamKeyResponse])
def list_generated_keys(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return up to 200 generated keys of an upstream, newest id first.

    Key values are not included in the result. Raises a 404 ``HTTPException``
    when the upstream does not exist.
    """
    upstream_exists = db.query(Upstream.id).filter(Upstream.id == uid).first()
    if not upstream_exists:
        raise HTTPException(404, "upstream not found")

    keys_of_upstream = db.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == uid
    )
    newest_first = keys_of_upstream.order_by(UpstreamGeneratedKey.id.desc()).limit(200).all()
    return [_key_response(key_row) for key_row in newest_first]
|
||
|
||
|
||
# Module-level lock held by _ensure_group_key for its whole lookup/create
# sequence, so only one thread of this process runs that sequence at a time.
_generate_key_lock = threading.Lock()
|
||
|
||
|
||
def _remote_key_fields(existing: dict[str, Any]) -> tuple[str, str, Any]:
    """Extract ``(key_id, plaintext_key, masked_key)`` from a remote key payload."""
    key_id = str(existing.get("id") or "")
    key_value = _extract_plaintext_key(existing)
    if key_value:
        masked = mask_secret(key_value)
    else:
        masked = existing.get("masked_key") or existing.get("key") or ""
    return key_id, key_value, masked


def _apply_remote_key(row: UpstreamGeneratedKey, existing: dict[str, Any]) -> None:
    """Refresh a local row from a key that already exists upstream (status ``exists``)."""
    key_id, key_value, masked = _remote_key_fields(existing)
    row.key_id = key_id or row.key_id
    if key_value:
        row.key_value = key_value
        row.masked_key = masked
    elif masked:
        # No plaintext available: keep the stored value, only refresh the mask.
        row.masked_key = str(masked)
    row.raw_json = json.dumps(existing, ensure_ascii=False)
    row.status = "exists"
    row.updated_at = datetime.now(timezone.utc)


def _ensure_group_key(
    db: Session,
    client: UpstreamClient,
    upstream: Upstream,
    group: dict[str, Any],
    prefix: str,
    body: GenerateKeysByGroupsRequest,
) -> GeneratedUpstreamKeyResponse:
    """Ensure an upstream group has one managed key carrying ``prefix``.

    A key that already exists upstream is upserted into the local table;
    otherwise a new key is created upstream (using the limits in ``body``) and
    stored. The whole sequence runs under ``_generate_key_lock``.

    Returns:
        The key as a response model. The plaintext value is included only
        when the key was created by this call.

    Failures are not raised: they are logged, the session is rolled back so
    later work on the same session is unaffected, and a response with
    ``status="failed"`` and the error text is returned.
    """
    # Read once up front: ORM attributes expire on commit/rollback, and the
    # failure path must not trigger a reload.
    upstream_id = upstream.id
    gid = _group_id(group)
    gname = _group_name(group, gid)
    # Local rows are matched on upstream_id + group_id (see the query below)
    # rather than on names, so renaming a group does not create duplicates.
    # Readable key name: {prefix}-{upstream.id}-{sanitised group name}-{group_id}
    safe_group_name = re.sub(r"[^a-zA-Z0-9\u4e00-\u9fff_-]", "", gname)[:30] if gname else gid
    stable_name = f"{prefix}-{upstream_id}-{safe_group_name}-{gid}"

    with _generate_key_lock:
        try:
            # 1. Look for a locally recorded managed key for this group. Rows
            #    without managed_prefix (pre-migration) are matched by key name.
            row = (
                db.query(UpstreamGeneratedKey)
                .filter(
                    UpstreamGeneratedKey.upstream_id == upstream_id,
                    UpstreamGeneratedKey.group_id == gid,
                    (UpstreamGeneratedKey.managed_prefix == prefix)
                    | (
                        (UpstreamGeneratedKey.managed_prefix.is_(None))
                        & UpstreamGeneratedKey.key_name.like(f"{prefix}-%")
                    ),
                )
                .first()
            )
            if row and row.key_id:
                # A local record exists; check whether the remote key is still there.
                try:
                    existing = client.find_smartup_group_key(gid, stable_name, prefix)
                except Exception as lookup_exc:
                    # Best effort: fall through to the lookup in step 2, but
                    # leave a trace instead of swallowing the error silently.
                    logger.warning(
                        "remote key lookup failed for upstream=%s group=%s: %s",
                        upstream_id, gid, lookup_exc,
                    )
                    existing = None
                if existing:
                    _apply_remote_key(row, existing)
                    db.add(row)
                    db.commit()
                    db.refresh(row)
                    return _key_response(row, include_value=False)
                # Not found upstream: the key has to be created again.
                row.status = "replaced"

            # 2. Check upstream for a key of the same name (another request
            #    may already have created it).
            existing = client.find_smartup_group_key(gid, stable_name, prefix)
            if existing:
                if row:
                    _apply_remote_key(row, existing)
                else:
                    key_id, key_value, masked = _remote_key_fields(existing)
                    row = UpstreamGeneratedKey(
                        upstream_id=upstream_id,
                        group_id=gid,
                        group_name=gname,
                        key_id=key_id or None,
                        key_name=stable_name,
                        key_value=key_value or "",
                        masked_key=masked,
                        raw_json=json.dumps(existing, ensure_ascii=False),
                        managed_prefix=prefix,
                        status="exists",
                    )
                db.add(row)
                db.commit()
                db.refresh(row)
                return _key_response(row, include_value=False)

            # 3. Nothing upstream: create a new key.
            created = client.create_api_key(
                stable_name,
                gid,
                quota=body.quota,
                expires_in_days=body.expires_in_days,
                rate_limit_5h=body.rate_limit_5h,
                rate_limit_1d=body.rate_limit_1d,
                rate_limit_7d=body.rate_limit_7d,
                endpoint=body.endpoint,
            )
            new_key = created["key"]
            new_columns = {
                "key_id": created.get("id") or None,
                "key_name": stable_name,
                "key_value": new_key,
                "masked_key": created.get("masked_key") or mask_secret(new_key),
                "raw_json": json.dumps(created.get("raw") or {}, ensure_ascii=False),
                "managed_prefix": prefix,
                "status": "created",
            }
            if row:
                # Reuse the old row for the replacement key.
                for column, value in new_columns.items():
                    setattr(row, column, value)
                row.error = None
            else:
                row = UpstreamGeneratedKey(
                    upstream_id=upstream_id,
                    group_id=gid,
                    group_name=gname,
                    **new_columns,
                )
            db.add(row)
            db.commit()
            db.refresh(row)
            return _key_response(row, include_value=True)
        except Exception as exc:
            logger.exception("ensure group key failed for upstream=%s group=%s", upstream_id, gid)
            # Discard half-applied row changes and clear any failed-flush
            # state; the caller keeps using this session for further groups.
            try:
                db.rollback()
            except Exception:
                logger.exception("rollback failed for upstream=%s group=%s", upstream_id, gid)
            return GeneratedUpstreamKeyResponse(
                upstream_id=upstream_id,
                group_id=gid,
                group_name=gname,
                key_name=stable_name,
                status="failed",
                error=str(exc),
            )
|
||
|
||
|
||
@router.post("/{uid}/keys/generate-by-groups", response_model=GenerateKeysByGroupsResponse)
def generate_keys_by_groups(
    uid: int,
    body: GenerateKeysByGroupsRequest,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Ensure a managed key exists for each selected group of an upstream.

    Groups are fetched from the upstream; when ``body.group_ids`` is non-empty
    only those groups are processed, otherwise all of them. Each group is
    handled by ``_ensure_group_key`` and its result is reported in ``items``.

    Raises:
        HTTPException: 404 if the upstream does not exist, 400 if its API
            prefix is not ``api/v1``, 502 if login or the group listing fails.
    """
    u = db.query(Upstream).filter(Upstream.id == uid).first()
    if not u:
        raise HTTPException(404, "upstream not found")
    if u.api_prefix.strip("/") != "api/v1":
        raise HTTPException(400, "首版仅支持 Sub2API 上游(API Prefix 应为 /api/v1)")

    auth_config = json.loads(u.auth_config_json or "{}")
    selected = set(body.group_ids)
    prefix = body.name_prefix
    results: list[GeneratedUpstreamKeyResponse] = []
    with UpstreamClient(
        base_url=u.base_url,
        api_prefix=u.api_prefix,
        auth_type=u.auth_type,
        auth_config=auth_config,
        timeout=float(u.timeout_seconds),
    ) as client:
        try:
            client.login()
            groups = client.get_available_groups(u.groups_endpoint)
        except Exception as exc:
            # Chain the cause so the original traceback is kept.
            raise HTTPException(502, str(exc)) from exc

        for group in groups:
            gid = _group_id(group)
            # Skip groups without an id and, when a selection was given,
            # groups that are not part of it.
            if not gid or (selected and gid not in selected):
                continue
            results.append(_ensure_group_key(db, client, u, group, prefix, body))

    created = sum(1 for item in results if item.status == "created")
    existed = sum(1 for item in results if item.status == "exists")
    total = len(results)

    msg_parts = []
    if created:
        msg_parts.append(f"新创建 {created}")
    if existed:
        msg_parts.append(f"已存在 {existed}")
    if msg_parts:
        msg = "、".join(msg_parts) + f" / 共 {total} 个分组"
    else:
        msg = f"共处理 {total} 个分组"

    return GenerateKeysByGroupsResponse(
        # Success needs at least one processed group and no failed one.
        success=total > 0 and all(item.status != "failed" for item in results),
        message=msg,
        items=results,
    )
|
||
|
||
|
||
@router.post("", response_model=UpstreamResponse, status_code=201)
def create_upstream(
    body: UpstreamCreate,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Create an upstream, register it with the scheduler and return it.

    Trailing slashes are stripped from ``base_url`` and the auth configuration
    is stored as JSON text.
    """
    # Request fields stored on the row as-is.
    copied_fields = (
        "name", "api_prefix", "auth_type", "rate_endpoint", "groups_endpoint",
        "enabled", "check_interval_seconds", "timeout_seconds",
        "balance_endpoint", "balance_response_path", "balance_divisor",
        "balance_alert_threshold",
    )
    columns = {field: getattr(body, field) for field in copied_fields}
    columns["base_url"] = body.base_url.rstrip("/")
    columns["auth_config_json"] = json.dumps(body.auth_config, ensure_ascii=False)

    u = Upstream(**columns)
    db.add(u)
    db.commit()
    db.refresh(u)

    sched_svc.refresh_upstream(u.id, u.check_interval_seconds, u.enabled)
    return _to_response(u)
|
||
|
||
|
||
@router.get("/{uid}", response_model=UpstreamResponse)
def get_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return a single upstream by id, or raise a 404 ``HTTPException``."""
    found = db.query(Upstream).filter(Upstream.id == uid).first()
    if found:
        return _to_response(found)
    raise HTTPException(404, "upstream not found")
|
||
|
||
|
||
@router.put("/{uid}", response_model=UpstreamResponse)
def update_upstream(
    uid: int,
    body: UpstreamUpdate,
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Apply the non-``None`` fields of ``body`` to an upstream.

    ``auth_config`` is merged into the stored configuration; entries whose
    value equals ``MASK`` are ignored so a masked placeholder sent back by the
    client never overwrites the stored secret. The failure counter is reset
    and the scheduler entry refreshed. Raises a 404 ``HTTPException`` when the
    upstream does not exist.
    """
    u = db.query(Upstream).filter(Upstream.id == uid).first()
    if not u:
        raise HTTPException(404, "upstream not found")

    changes = body.model_dump(exclude_none=True)

    if "auth_config" in changes:
        merged = json.loads(u.auth_config_json or "{}")
        submitted = changes.pop("auth_config")
        merged.update({name: value for name, value in submitted.items() if value != MASK})
        u.auth_config_json = json.dumps(merged, ensure_ascii=False)

    if "base_url" in changes:
        changes["base_url"] = changes["base_url"].rstrip("/")

    for column, value in changes.items():
        setattr(u, column, value)

    # Any update resets the failure counter: the user may have fixed the issue.
    u.consecutive_failures = 0
    u.updated_at = datetime.now(timezone.utc)
    db.commit()
    db.refresh(u)

    sched_svc.refresh_upstream(u.id, u.check_interval_seconds, u.enabled)
    return _to_response(u)
|
||
|
||
|
||
@router.delete("/{uid}", status_code=204)
def delete_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Delete an upstream after removing its scheduler job.

    Raises a 404 ``HTTPException`` when the upstream does not exist.
    """
    target = db.query(Upstream).filter(Upstream.id == uid).first()
    if not target:
        raise HTTPException(404, "upstream not found")

    # Interval 0 with enabled=False removes the scheduled job.
    sched_svc.refresh_upstream(uid, 0, False)

    db.delete(target)
    db.commit()
|
||
|
||
|
||
@router.post("/{uid}/test", response_model=TestResult)
def test_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Test connectivity to an upstream and record the outcome on its row.

    Logs in and lists the available groups. On failure the upstream is marked
    ``unhealthy``, the error is stored and the failure counter incremented; on
    success it is marked ``healthy`` and the counter reset. When a balance
    endpoint and response path are configured, the balance is refreshed on a
    best-effort basis (a failure there is only logged).

    Only errors from the upstream client are treated as a connection failure;
    errors from the database commit propagate to the caller.

    Raises:
        HTTPException: 404 if the upstream does not exist.
    """
    u = db.query(Upstream).filter(Upstream.id == uid).first()
    if not u:
        raise HTTPException(404, "upstream not found")

    auth_config = json.loads(u.auth_config_json or "{}")
    with UpstreamClient(
        base_url=u.base_url,
        api_prefix=u.api_prefix,
        auth_type=u.auth_type,
        auth_config=auth_config,
        timeout=float(u.timeout_seconds),
    ) as client:
        try:
            client.login()
            groups = client.get_available_groups(u.groups_endpoint)
        except Exception as exc:
            u.last_status = "unhealthy"
            u.last_error = str(exc)
            u.last_checked_at = datetime.now(timezone.utc)
            u.consecutive_failures = (u.consecutive_failures or 0) + 1
            db.commit()
            return TestResult(success=False, message="连接失败", detail=str(exc))

        # Also try the balance if it is configured.
        if u.balance_endpoint and u.balance_response_path:
            try:
                raw_balance = client.get_balance(u.balance_endpoint, u.balance_response_path)
                # Only touch the balance fields when a value was returned.
                if raw_balance is not None:
                    u.balance = raw_balance / (u.balance_divisor or 1.0)
                    u.balance_updated_at = datetime.now(timezone.utc)
            except Exception as exc:
                logger.warning("upstream %s balance fetch failed during test: %s", u.name, exc)

    u.last_status = "healthy"
    u.last_error = None
    u.last_checked_at = datetime.now(timezone.utc)
    u.consecutive_failures = 0
    db.commit()
    return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组")
|
||
|
||
|
||
@router.post("/{uid}/check-now", response_model=TestResult)
def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Run an immediate rate check for an upstream and store a new snapshot.

    Steps on success:
      1. Fetch groups and rates, build a snapshot, and refresh the balance on
         a best-effort basis when it is configured.
      2. Diff against the most recent stored snapshot, store the new one and
         mark the upstream ``healthy``.
      3. Send the ``upstream_recovered`` event if it was ``unhealthy`` before.
      4. Sync key state, then notify about / apply rate changes, then sync
         account priorities.

    If the upstream client fails, the error and check time are stored, the
    failure counter is incremented and a failed ``TestResult`` is returned.

    Raises:
        HTTPException: 404 if the upstream does not exist.
    """
    u = db.query(Upstream).filter(Upstream.id == uid).first()
    if not u:
        raise HTTPException(404, "upstream not found")

    auth_config = json.loads(u.auth_config_json or "{}")
    with UpstreamClient(
        base_url=u.base_url,
        api_prefix=u.api_prefix,
        auth_type=u.auth_type,
        auth_config=auth_config,
        timeout=float(u.timeout_seconds),
    ) as client:
        try:
            client.login()
            groups = client.get_available_groups(u.groups_endpoint)
            raw_rates = client.get_group_rates(u.rate_endpoint)
            snapshot = build_snapshot(u.id, u.base_url, u.api_prefix, groups, raw_rates)
            # Also try the balance if it is configured.
            if u.balance_endpoint and u.balance_response_path:
                try:
                    raw_balance = client.get_balance(u.balance_endpoint, u.balance_response_path)
                    # Only touch the balance fields when a value was returned.
                    if raw_balance is not None:
                        u.balance = raw_balance / (u.balance_divisor or 1.0)
                        u.balance_updated_at = datetime.now(timezone.utc)
                except Exception as exc:
                    logger.warning("upstream %s balance fetch failed during check-now: %s", u.name, exc)
        except Exception as exc:
            u.consecutive_failures = (u.consecutive_failures or 0) + 1
            u.last_error = str(exc)
            u.last_checked_at = datetime.now(timezone.utc)
            db.commit()
            return TestResult(success=False, message="检测失败", detail=str(exc))

    prev_row = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == uid)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .first()
    )
    previous = json.loads(prev_row.snapshot_json) if prev_row else None
    changes = diff_snapshots(previous, snapshot)

    # One timestamp for both the snapshot and the upstream's check time.
    checked_at = datetime.now(timezone.utc)
    new_row = UpstreamRateSnapshot(
        upstream_id=uid,
        snapshot_json=json.dumps(snapshot, ensure_ascii=False),
        captured_at=checked_at,
    )
    db.add(new_row)

    was_unhealthy = u.last_status == "unhealthy"
    u.last_status = "healthy"
    u.last_checked_at = checked_at
    u.last_error = None
    u.consecutive_failures = 0
    db.commit()

    if was_unhealthy:
        webhook_service.send_status_event(db, u.id, u.name, u.base_url, "upstream_recovered")

    # Sync key state first (marks orphaned keys) and only then run the
    # priority sync, so keys not yet marked do not take part in it.
    sched_svc._sync_upstream_keys(uid, snapshot, new_row.captured_at)

    if changes:
        webhook_service.send_rate_changed(db, u.id, u.name, u.base_url, changes)
        website_sync.sync_affected_bindings(db, u.id, changes)
    website_sync.sync_account_priorities_for_upstream(db, u.id)

    msg = f"检测成功,{len(groups)} 个分组"
    if changes:
        msg += f",发现 {len(changes)} 处倍率变化"
    elif previous is None:
        msg += ",初始化快照完成"
    else:
        msg += ",无变化"
    return TestResult(success=True, message=msg)
|
||
|
||
|
||
@router.get("/{uid}/snapshots/latest", response_model=SnapshotResponse)
def latest_snapshot(uid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
    """Return the most recently captured snapshot of an upstream.

    Raises a 404 ``HTTPException`` when the upstream has no snapshot.
    """
    snapshots_of_upstream = db.query(UpstreamRateSnapshot).filter(
        UpstreamRateSnapshot.upstream_id == uid
    )
    newest = snapshots_of_upstream.order_by(UpstreamRateSnapshot.captured_at.desc()).first()
    if not newest:
        raise HTTPException(404, "no snapshot found")

    decoded = json.loads(newest.snapshot_json)
    return SnapshotResponse(
        id=newest.id,
        upstream_id=newest.upstream_id,
        snapshot=decoded,
        captured_at=newest.captured_at,
    )
|
||
|
||
|
||
from fastapi import Query as QueryParam
|
||
|
||
@router.get("/{uid}/snapshots", response_model=List[SnapshotResponse])
def list_snapshots(
    uid: int,
    limit: int = QueryParam(20, ge=0, le=100),
    offset: int = QueryParam(0, ge=0),
    db: Session = Depends(get_db),
    _=Depends(get_current_user),
):
    """Return paginated snapshot history, newest first, with a diff summary.

    Each returned snapshot dict gets two extra keys:
      * ``_groups_count``: number of entries under its ``groups`` key.
      * ``_changes_count``: number of changes versus the next older snapshot,
        or ``None`` when there is no older snapshot.

    ``limit`` and ``offset`` are rejected when negative, so a negative value
    can never reach the query or the slice below.
    """
    rows = (
        db.query(UpstreamRateSnapshot)
        .filter(UpstreamRateSnapshot.upstream_id == uid)
        .order_by(UpstreamRateSnapshot.captured_at.desc())
        .offset(offset)
        .limit(limit + 1)  # one extra row: the "previous" snapshot of the last item
        .all()
    )

    # Rows are in descending order, so rows[i + 1] is older than rows[i].
    results = []
    for position, row in enumerate(rows[:limit]):
        snap = json.loads(row.snapshot_json)

        changes_count: int | None = None
        if position + 1 < len(rows):
            older = json.loads(rows[position + 1].snapshot_json)
            changes_count = len(diff_snapshots(older, snap))

        # Embed a lightweight summary so the frontend can display it.
        snap["_groups_count"] = len(snap.get("groups", {}))
        snap["_changes_count"] = changes_count  # None means first ever snapshot

        results.append(
            SnapshotResponse(
                id=row.id,
                upstream_id=row.upstream_id,
                snapshot=snap,
                captured_at=row.captured_at,
            )
        )
    return results
|