Files
SmartUp/backend/app/services/scheduler.py
T
liumangmang bea4344bb3 fix: reconcile upstream keys on list/generate/import to prevent stale key imports
- Extract reconcile_upstream_keys() to website_sync.py (shared scheduler + on-demand)
- Add reconcile_upstream_keys_full() for on-demand reconciliation at three entry points:
  list_generated_keys, generate_keys_by_groups, import_upstream_keys_as_accounts
- Safe on failure: active_group_ids=None / remote_key_ids=None skip cleanup
- Support custom managed_prefix via _fetch_remote_managed_key_ids() helper
- Exclude orphaned keys from frontend importable list
- Remove hardcoded search='SmartUp' from scheduler path
2026-06-01 11:29:37 +08:00

316 lines
12 KiB
Python

"""APScheduler background scheduler for upstream checks."""
from __future__ import annotations

import json
import logging
from datetime import datetime, timezone
from typing import Any, Optional

from apscheduler.executors.pool import ThreadPoolExecutor
from apscheduler.schedulers.background import BackgroundScheduler
from sqlalchemy.orm import Session

from app.database import SessionLocal
from app.models.upstream import Upstream
from app.models.snapshot import UpstreamRateSnapshot
from app.services.upstream_client import UpstreamClient, build_snapshot
from app.services.snapshot_service import diff_snapshots, prune_snapshots
from app.services import webhook_service
from app.services import website_sync
from app.config import get_settings
logger = logging.getLogger(__name__)

# Module-wide scheduler instance. Its "default" executor owns a single worker
# thread; jobs registered without an explicit executor (as refresh_upstream
# does) therefore execute one at a time.
_scheduler = BackgroundScheduler(
    timezone="UTC",
    executors={
        "default": ThreadPoolExecutor(max_workers=1),
    },
)
def get_scheduler() -> BackgroundScheduler:
    """Return the scheduler instance created at import time of this module."""
    scheduler = _scheduler
    return scheduler
def _coerce_captured_at(value: object) -> datetime:
    """Normalise a snapshot's ``captured_at`` field for key reconciliation.

    * ``str``: parsed with ``datetime.fromisoformat``. A trailing ``Z`` is
      rewritten to ``+00:00`` first, because ``fromisoformat`` only accepts
      ``Z`` from Python 3.11 on. Unparseable strings fall back to the current
      UTC time.
    * ``None``: the current UTC time.
    * anything else: returned unchanged (presumably already a ``datetime``).

    NOTE(review): an ISO string without an offset parses to a *naive*
    datetime; confirm that is what ``reconcile_upstream_keys`` expects.
    """
    if value is None:
        return datetime.now(timezone.utc)
    if isinstance(value, str):
        text = value[:-1] + "+00:00" if value.endswith("Z") else value
        try:
            return datetime.fromisoformat(text)
        except ValueError:
            return datetime.now(timezone.utc)
    return value  # type: ignore[return-value]


def _check_upstream(upstream_id: int) -> None:
    """Run one scheduled check of an upstream (executed in the scheduler's worker thread).

    Phase 1: log in, fetch groups, rates and (optionally) the balance, then
    write the new snapshot and health fields in a single DB transaction.
    Phase 2: reconcile SmartUp keys against the fresh snapshot (own session).
    Phase 3: webhooks, website bindings and account priorities (own
    sessions), so a notification failure never rolls back the snapshot.

    If the upstream no longer exists or is disabled, its job is removed and
    nothing else happens. On a failed check the failure counter is bumped,
    the upstream may be flipped to "unhealthy" (with a webhook), and
    Phases 2/3 are skipped.
    """
    settings = get_settings()

    # ── Phase 1: upstream check + DB write ──────────────────────────
    db: Session = SessionLocal()
    try:
        upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
        if not upstream or not upstream.enabled:
            _remove_job(upstream_id)
            return

        # Plain copies of everything the later phases need. Phases 2/3 run
        # after this session is closed and must not read attributes off the
        # (then detached, possibly expired) ORM instance.
        upstream_name = upstream.name
        base_url = upstream.base_url
        was_unhealthy = upstream.last_status == "unhealthy"
        auth_config = json.loads(upstream.auth_config_json or "{}")

        balance_alert_triggered = False
        alert_balance: Optional[float] = None
        alert_threshold: Optional[float] = None
        snapshot = None
        changes = None

        with UpstreamClient(
            base_url=upstream.base_url,
            api_prefix=upstream.api_prefix,
            auth_type=upstream.auth_type,
            auth_config=auth_config,
            timeout=float(upstream.timeout_seconds),
        ) as client:
            try:
                client.login()
                groups = client.get_available_groups(upstream.groups_endpoint)
                raw_rates = client.get_group_rates(upstream.rate_endpoint)
                snapshot = build_snapshot(
                    upstream.id, upstream.base_url, upstream.api_prefix, groups, raw_rates
                )

                # ── Balance fetch (client still open). Best effort: a failure
                # is logged and does not fail the whole check. ──
                balance: Optional[float] = None
                if upstream.balance_endpoint and upstream.balance_response_path:
                    try:
                        raw_balance = client.get_balance(
                            upstream.balance_endpoint, upstream.balance_response_path
                        )
                        if raw_balance is not None:
                            divisor = upstream.balance_divisor or 1.0
                            balance = raw_balance / divisor
                    except Exception as exc:
                        logger.warning("upstream %s balance fetch failed: %s", upstream_name, exc)

                if balance is not None:
                    upstream.balance = balance
                    upstream.balance_updated_at = datetime.now(timezone.utc)
                    # ── Balance alert threshold check: fire once when the
                    # balance drops below the threshold, re-arm once it is
                    # back at or above it. ──
                    alert_limit = upstream.balance_alert_threshold
                    if alert_limit is not None and alert_limit > 0:
                        if balance < alert_limit and not upstream.balance_alert_notified:
                            upstream.balance_alert_notified = True
                            balance_alert_triggered = True
                            alert_balance = balance
                            alert_threshold = alert_limit
                        elif balance >= alert_limit and upstream.balance_alert_notified:
                            upstream.balance_alert_notified = False
            except Exception as exc:
                # ── Failure path: count the failure; flip to "unhealthy" the
                # first time the configured threshold is reached. ──
                upstream.consecutive_failures = (upstream.consecutive_failures or 0) + 1
                upstream.last_error = str(exc)
                upstream.last_checked_at = datetime.now(timezone.utc)
                became_unhealthy = (
                    upstream.consecutive_failures >= settings.unhealthy_threshold
                    and upstream.last_status != "unhealthy"
                )
                if became_unhealthy:
                    upstream.last_status = "unhealthy"
                db.commit()
                logger.warning("upstream %s check failed: %s", upstream_name, exc)
                # Notify via a fresh session (inside _notify_status).
                if became_unhealthy:
                    _notify_status(upstream_id, upstream_name, base_url,
                                   "upstream_unhealthy", str(exc))
                return

        # ── Success path (client already closed by `with`) ──
        prev_snapshot_row = (
            db.query(UpstreamRateSnapshot)
            .filter(UpstreamRateSnapshot.upstream_id == upstream_id)
            .order_by(UpstreamRateSnapshot.captured_at.desc())
            .first()
        )
        previous = json.loads(prev_snapshot_row.snapshot_json) if prev_snapshot_row else None
        changes = diff_snapshots(previous, snapshot)

        # Save the new snapshot and trim history to the retention limit.
        db.add(
            UpstreamRateSnapshot(
                upstream_id=upstream_id,
                snapshot_json=json.dumps(snapshot, ensure_ascii=False),
                captured_at=datetime.now(timezone.utc),
            )
        )
        prune_snapshots(db, upstream_id, settings.snapshot_retention_count)

        # Mark healthy and reset the failure bookkeeping.
        upstream.last_status = "healthy"
        upstream.last_checked_at = datetime.now(timezone.utc)
        upstream.last_error = None
        upstream.consecutive_failures = 0
        db.commit()

        # Two separate calls on purpose: picking the format string with a
        # conditional while always passing two arguments makes the
        # "no changes" message fail to format inside logging.
        if changes:
            logger.info("upstream %s: %d rate change(s)", upstream_name, len(changes))
        else:
            logger.info("upstream %s: no changes", upstream_name)
    finally:
        db.close()

    # ── Phase 2: key sync (independent session) ───────────────────
    if snapshot:
        captured_at = _coerce_captured_at(snapshot.get("captured_at"))
        _sync_upstream_keys(upstream_id, snapshot, captured_at)

    # ── Phase 3: notifications (independent sessions) ──────────────
    if was_unhealthy:
        _notify_status(upstream_id, upstream_name, base_url, "upstream_recovered")
    if changes:
        _notify_rate_changed(upstream_id, upstream_name, base_url, changes)
        _sync_website_bindings(upstream_id, changes)
        _sync_account_priorities(upstream_id)
    if balance_alert_triggered:
        _notify_balance_low(
            upstream_id, upstream_name, base_url,
            alert_balance, alert_threshold,
        )
def _notify_status(
    upstream_id: int,
    upstream_name: str,
    base_url: str,
    event: str,
    error: str = "",
) -> None:
    """Send a status webhook event (e.g. "upstream_unhealthy") for an upstream.

    The call uses its own short-lived session, which is always closed. Any
    exception from the webhook layer is logged with its traceback and is not
    re-raised.
    """
    event_args = (upstream_id, upstream_name, base_url, event, error)
    session = SessionLocal()
    try:
        try:
            webhook_service.send_status_event(session, *event_args)
        except Exception:
            logger.exception("status webhook failed for upstream %s", upstream_name)
    finally:
        session.close()
def _notify_rate_changed(
    upstream_id: int,
    upstream_name: str,
    base_url: str,
    changes: list[dict[str, Any]],
) -> None:
    """Deliver the rate-change webhook for an upstream.

    ``changes`` is forwarded unchanged to ``webhook_service.send_rate_changed``.
    A dedicated session is opened and always closed. Errors are logged with
    their traceback rather than propagated.
    """
    payload = (upstream_id, upstream_name, base_url, changes)
    session = SessionLocal()
    try:
        try:
            webhook_service.send_rate_changed(session, *payload)
        except Exception:
            logger.exception("rate webhook failed for upstream %s", upstream_name)
    finally:
        session.close()
def _notify_balance_low(
    upstream_id: int,
    upstream_name: str,
    base_url: str,
    balance: float,
    threshold: float,
) -> None:
    """Send the low-balance webhook for an upstream.

    ``balance`` and ``threshold`` are passed straight through to
    ``webhook_service.send_balance_low``. The private session is always
    closed, and webhook failures are logged instead of raised.
    """
    session = SessionLocal()
    try:
        try:
            webhook_service.send_balance_low(
                session,
                upstream_id,
                upstream_name,
                base_url,
                balance,
                threshold,
            )
        except Exception:
            logger.exception("balance low webhook failed for upstream %s", upstream_name)
    finally:
        session.close()
def _sync_upstream_keys(upstream_id: int, snapshot: dict[str, Any], captured_at: datetime) -> None:
    """Sync SmartUp key state after a successful upstream check.

    Covers keys deleted on the remote side and keys whose group disappeared.
    The core logic is delegated to ``website_sync.reconcile_upstream_keys``.

    Runs in its own session and commits on success. Any error is logged and
    swallowed, so it never affects the caller.

    Either reconciliation input is passed as ``None`` when it could not be
    determined; ``None`` makes reconciliation skip that kind of cleanup
    instead of treating everything as deleted.
    """
    db = SessionLocal()
    try:
        # Only trust the group list when the snapshot actually carries one.
        # A missing or malformed "groups" entry must not be read as "every
        # group was deleted", so fall back to None (skip group cleanup).
        # A present-but-empty dict still yields an empty set, as before.
        groups = snapshot.get("groups")
        active_group_ids = set(groups.keys()) if isinstance(groups, dict) else None
        if active_group_ids is None:
            logger.warning(
                "snapshot for upstream %s has no usable 'groups'; skipping group cleanup",
                upstream_id,
            )

        # Ask the upstream (via UpstreamClient) for its active managed key IDs.
        # Stays None if the lookup fails, so remote-deletion cleanup is skipped.
        remote_key_ids: set[str] | None = None
        try:
            upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
            if upstream:
                auth_config = json.loads(upstream.auth_config_json or "{}")
                with UpstreamClient(
                    base_url=upstream.base_url,
                    api_prefix=upstream.api_prefix,
                    auth_type=upstream.auth_type,
                    auth_config=auth_config,
                    timeout=float(upstream.timeout_seconds),
                ) as client:
                    client.login()
                    remote_key_ids = website_sync._fetch_remote_managed_key_ids(db, client, upstream_id)
        except Exception as exc:
            logger.warning("sync upstream keys list failed for %s: %s", upstream_id, exc)

        website_sync.reconcile_upstream_keys(db, upstream_id, active_group_ids, remote_key_ids, captured_at)
        db.commit()
    except Exception:
        logger.exception("key sync failed for upstream %s", upstream_id)
    finally:
        db.close()
def _sync_account_priorities(upstream_id: int) -> None:
    """After a rate change, update the priority of imported downstream accounts.

    Delegates to ``website_sync.sync_account_priorities_for_upstream`` in a
    dedicated session. Errors are logged, never raised.
    """
    db = SessionLocal()
    try:
        website_sync.sync_account_priorities_for_upstream(db, upstream_id)
        # Commit explicitly, as _sync_upstream_keys does: closing a session
        # that still holds uncommitted changes rolls them back. If the helper
        # already committed, this is a no-op.
        db.commit()
    except Exception:
        logger.exception("account priority sync failed for upstream %s", upstream_id)
    finally:
        db.close()
def _sync_website_bindings(upstream_id: int, changes: list[dict[str, Any]]) -> None:
    """Propagate detected rate changes to the affected website bindings.

    Delegates to ``website_sync.sync_affected_bindings`` in a dedicated
    session. Errors are logged, never raised.
    """
    db = SessionLocal()
    try:
        website_sync.sync_affected_bindings(db, upstream_id, changes)
        # Commit explicitly, as _sync_upstream_keys does: closing a session
        # that still holds uncommitted changes rolls them back. If the helper
        # already committed, this is a no-op.
        db.commit()
    except Exception:
        logger.exception("website sync failed for upstream %s", upstream_id)
    finally:
        db.close()
def _remove_job(upstream_id: int) -> None:
    """Remove the scheduler job of an upstream if it exists; otherwise do nothing.

    Uses try/except instead of a get_job() pre-check. This function is
    reached both from the job itself (a disabled or deleted upstream) and
    from refresh_upstream. If the job vanishes between the check and the
    removal, remove_job() would raise JobLookupError.
    """
    from apscheduler.jobstores.base import JobLookupError

    job_id = f"upstream_{upstream_id}"
    try:
        _scheduler.remove_job(job_id)
    except JobLookupError:
        # No such job (never scheduled or already removed): nothing to do.
        pass
def refresh_upstream(upstream_id: int, interval_seconds: int = 0, enabled: bool = True) -> None:
    """Create, replace or drop the periodic check job for one upstream.

    A disabled upstream or a non-positive interval removes the job. Otherwise
    an interval job calling ``_check_upstream(upstream_id)`` is registered
    under the id ``upstream_<id>``, replacing any existing job with that id.
    """
    if not enabled or interval_seconds <= 0:
        _remove_job(upstream_id)
        return

    job_id = f"upstream_{upstream_id}"
    _scheduler.add_job(
        _check_upstream,
        trigger="interval",
        args=[upstream_id],
        id=job_id,
        replace_existing=True,
        # Run settings: collapse backlogged runs into one, never run two
        # instances of this job at once, tolerate up to 60s lateness.
        coalesce=True,
        max_instances=1,
        misfire_grace_time=60,
        # Interval trigger settings.
        seconds=interval_seconds,
        jitter=30,
    )
    logger.info("scheduler job %s set to %ds interval", job_id, interval_seconds)
def start_scheduler() -> None:
    """Start the background scheduler, then register a job for each enabled upstream."""
    _scheduler.start()
    session: Session = SessionLocal()
    try:
        enabled_upstreams = (
            session.query(Upstream)
            .filter(Upstream.enabled == True)  # noqa: E712 - SQLAlchemy column expression
            .all()
        )
        for upstream in enabled_upstreams:
            refresh_upstream(upstream.id, upstream.check_interval_seconds, upstream.enabled)
        logger.info("scheduler started with %d upstream job(s)", len(enabled_upstreams))
    finally:
        session.close()
def stop_scheduler() -> None:
    """Shut the scheduler down, waiting for running jobs, if it is running."""
    if not _scheduler.running:
        return
    _scheduler.shutdown(wait=True)