"""APScheduler background scheduler for upstream checks.""" from __future__ import annotations import json import logging from datetime import datetime, timezone from apscheduler.executors.pool import ThreadPoolExecutor from apscheduler.schedulers.background import BackgroundScheduler from sqlalchemy.orm import Session from app.database import SessionLocal from app.models.upstream import Upstream from app.models.upstream_key import UpstreamGeneratedKey from app.models.snapshot import UpstreamRateSnapshot from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot from app.services.snapshot_service import diff_snapshots, prune_snapshots from app.services import webhook_service from app.services import website_sync from app.config import get_settings logger = logging.getLogger(__name__) _scheduler = BackgroundScheduler(timezone="UTC", executors={"default": ThreadPoolExecutor(max_workers=1)}) def get_scheduler() -> BackgroundScheduler: return _scheduler def _check_upstream(upstream_id: int) -> None: """Full upstream check executed by scheduler (runs in thread). Phase 1 — upstream API call + snapshot write (single transaction). Phase 2 — webhook/website sync (separate sessions, so a notification failure never rolls back the snapshot). """ settings = get_settings() # ── Phase 1: upstream check + DB write ────────────────────────── db: Session = SessionLocal() try: upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first() if not upstream or not upstream.enabled: _remove_job(upstream_id) return auth_config = json.loads(upstream.auth_config_json or "{}") was_unhealthy = upstream.last_status == "unhealthy" balance_alert_triggered = False snapshot = None changes = None with UpstreamClient( base_url=upstream.base_url, api_prefix=upstream.api_prefix, auth_type=upstream.auth_type, auth_config=auth_config, timeout=float(upstream.timeout_seconds), ) as client: try: client.login() groups = client.get_available_groups(upstream.groups_endpoint) raw_rates = client.get_group_rates(upstream.rate_endpoint) snapshot = build_snapshot( upstream.id, upstream.base_url, upstream.api_prefix, groups, raw_rates ) # ── Balance fetch (inside with block, client still open) ── balance: Optional[float] = None if upstream.balance_endpoint and upstream.balance_response_path: try: raw_balance = client.get_balance(upstream.balance_endpoint, upstream.balance_response_path) if raw_balance is not None: divisor = upstream.balance_divisor or 1.0 balance = raw_balance / divisor except Exception as exc: logger.warning("upstream %s balance fetch failed: %s", upstream.name, exc) if balance is not None: upstream.balance = balance upstream.balance_updated_at = datetime.now(timezone.utc) # ── 余额告警阈值检查 ── threshold = upstream.balance_alert_threshold if threshold is not None and threshold > 0: if balance < threshold and not upstream.balance_alert_notified: upstream.balance_alert_notified = True balance_alert_triggered = True elif balance >= threshold and upstream.balance_alert_notified: upstream.balance_alert_notified = False except Exception as exc: # failure path upstream.consecutive_failures = (upstream.consecutive_failures or 0) + 1 upstream.last_error = str(exc) upstream.last_checked_at = datetime.now(timezone.utc) threshold = settings.unhealthy_threshold became_unhealthy = ( upstream.consecutive_failures >= threshold and upstream.last_status != "unhealthy" ) if became_unhealthy: upstream.last_status = "unhealthy" db.commit() logger.warning("upstream %s check failed: %s", upstream.name, exc) # Phase 2: notify unhealthy in a fresh session if became_unhealthy: _notify_status(upstream.id, upstream.name, upstream.base_url, "upstream_unhealthy", str(exc)) return # success path (client auto-closed by `with`) prev_snapshot_row = ( db.query(UpstreamRateSnapshot) .filter(UpstreamRateSnapshot.upstream_id == upstream_id) .order_by(UpstreamRateSnapshot.captured_at.desc()) .first() ) previous = json.loads(prev_snapshot_row.snapshot_json) if prev_snapshot_row else None changes = diff_snapshots(previous, snapshot) # save new snapshot new_row = UpstreamRateSnapshot( upstream_id=upstream_id, snapshot_json=json.dumps(snapshot, ensure_ascii=False), captured_at=datetime.now(timezone.utc), ) db.add(new_row) prune_snapshots(db, upstream_id, settings.snapshot_retention_count) # update upstream status upstream.last_status = "healthy" upstream.last_checked_at = datetime.now(timezone.utc) upstream.last_error = None upstream.consecutive_failures = 0 db.commit() logger.info( "upstream %s: %d rate change(s)" if changes else "upstream %s: no changes", upstream.name, len(changes) if changes else 0, ) finally: db.close() # ── Phase 2: key sync (independent session) ─────────────────── if snapshot: captured_at = snapshot.get("captured_at") if isinstance(captured_at, str): from datetime import datetime as dt try: captured_at = dt.fromisoformat(captured_at) except Exception: captured_at = datetime.now(timezone.utc) elif captured_at is None: captured_at = datetime.now(timezone.utc) _sync_upstream_keys(upstream_id, snapshot, captured_at) # ── Phase 3: notifications (independent sessions) ────────────── if was_unhealthy: _notify_status(upstream_id, upstream.name, upstream.base_url, "upstream_recovered") if changes: _notify_rate_changed(upstream_id, upstream.name, upstream.base_url, changes) _sync_website_bindings(upstream_id, changes) if balance_alert_triggered: _notify_balance_low( upstream_id, upstream.name, upstream.base_url, upstream.balance, upstream.balance_alert_threshold, ) def _notify_status( upstream_id: int, upstream_name: str, base_url: str, event: str, error: str = "", ) -> None: db = SessionLocal() try: webhook_service.send_status_event(db, upstream_id, upstream_name, base_url, event, error) except Exception: logger.exception("status webhook failed for upstream %s", upstream_name) finally: db.close() def _notify_rate_changed( upstream_id: int, upstream_name: str, base_url: str, changes: list[dict[str, Any]], ) -> None: db = SessionLocal() try: webhook_service.send_rate_changed(db, upstream_id, upstream_name, base_url, changes) except Exception: logger.exception("rate webhook failed for upstream %s", upstream_name) finally: db.close() def _notify_balance_low( upstream_id: int, upstream_name: str, base_url: str, balance: float, threshold: float, ) -> None: db = SessionLocal() try: webhook_service.send_balance_low(db, upstream_id, upstream_name, base_url, balance, threshold) except Exception: logger.exception("balance low webhook failed for upstream %s", upstream_name) finally: db.close() def _sync_upstream_keys(upstream_id: int, snapshot: dict[str, Any], captured_at: datetime) -> None: """上游检测成功后同步 SmartUp Key 状态(远端删除/分组删除)。""" db = SessionLocal() try: active_group_ids = set(snapshot.get("groups", {}).keys()) key_rows = ( db.query(UpstreamGeneratedKey) .filter( UpstreamGeneratedKey.upstream_id == upstream_id, UpstreamGeneratedKey.key_name.like("SmartUp-%"), ) .all() ) auth_config = json.loads( db.query(Upstream).filter(Upstream.id == upstream_id).first().auth_config_json or "{}" ) # 用 UpstreamClient 查询远端活跃 Key ID 集合 remote_key_ids: set[str] | None = None # None=查询失败,set()=查询成功但为空 try: upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first() if upstream: with UpstreamClient( base_url=upstream.base_url, api_prefix=upstream.api_prefix, auth_type=upstream.auth_type, auth_config=auth_config, timeout=float(upstream.timeout_seconds), ) as client: client.login() remote_keys = client.list_api_keys(search="SmartUp", status="active") remote_key_ids = { str(k["id"]) for k in remote_keys if k.get("id") } except Exception as exc: logger.warning("sync upstream keys list failed for %s: %s", upstream_id, exc) for row in key_rows: # 1. 分组已不在当前快照中 → 删除本地记录 if row.group_id not in active_group_ids: if row.imported_website_id and row.imported_account_id: row.status = "orphaned" row.error = "来源分组已不存在" row.updated_at = captured_at logger.info("marked key %s orphaned (group %s no longer in snapshot)", row.id, row.group_id) else: db.delete(row) logger.info("removed key %s (group %s no longer in snapshot)", row.id, row.group_id) continue # 2. 远端查询成功但 key_id 不在列表中 → 删除本地记录 if row.key_id and remote_key_ids is not None and row.key_id not in remote_key_ids: if row.imported_website_id and row.imported_account_id: row.status = "orphaned" row.error = "远端 Key 已不存在" row.updated_at = captured_at logger.info("marked key %s orphaned (key_id %s gone from remote)", row.id, row.key_id) else: db.delete(row) logger.info("removed key %s (key_id %s gone from remote)", row.id, row.key_id) continue # 3. 更新同步时间戳(仅当查询成功且 Key 仍在远端时) if remote_key_ids is not None and row.key_id in remote_key_ids: row.updated_at = captured_at db.commit() except Exception: logger.exception("key sync failed for upstream %s", upstream_id) finally: db.close() def _sync_website_bindings(upstream_id: int, changes: list[dict[str, Any]]) -> None: db = SessionLocal() try: website_sync.sync_affected_bindings(db, upstream_id, changes) except Exception: logger.exception("website sync failed for upstream %s", upstream_id) finally: db.close() def _remove_job(upstream_id: int) -> None: job_id = f"upstream_{upstream_id}" if _scheduler.get_job(job_id): _scheduler.remove_job(job_id) def refresh_upstream(upstream_id: int, interval_seconds: int = 0, enabled: bool = True) -> None: """Add/update/remove a scheduler job for the given upstream.""" job_id = f"upstream_{upstream_id}" if not enabled or interval_seconds <= 0: _remove_job(upstream_id) return _scheduler.add_job( _check_upstream, "interval", seconds=interval_seconds, id=job_id, args=[upstream_id], replace_existing=True, coalesce=True, max_instances=1, misfire_grace_time=60, jitter=30, ) logger.info("scheduler job %s set to %ds interval", job_id, interval_seconds) def start_scheduler() -> None: """Start scheduler and load all enabled upstreams.""" _scheduler.start() db: Session = SessionLocal() try: upstreams = db.query(Upstream).filter(Upstream.enabled == True).all() for u in upstreams: refresh_upstream(u.id, u.check_interval_seconds, u.enabled) logger.info("scheduler started with %d upstream job(s)", len(upstreams)) finally: db.close() def stop_scheduler() -> None: if _scheduler.running: _scheduler.shutdown(wait=True)