309 lines
12 KiB
Python
309 lines
12 KiB
Python
"""APScheduler background scheduler for upstream checks."""
|
|
from __future__ import annotations

import json
import logging
from datetime import datetime, timezone
from typing import Any, Optional

from apscheduler.executors.pool import ThreadPoolExecutor
from apscheduler.jobstores.base import JobLookupError
from apscheduler.schedulers.background import BackgroundScheduler
from sqlalchemy.orm import Session

from app.database import SessionLocal
from app.models.upstream import Upstream
from app.models.upstream_key import UpstreamGeneratedKey
from app.models.snapshot import UpstreamRateSnapshot
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot
from app.services.snapshot_service import diff_snapshots, prune_snapshots
from app.services import webhook_service
from app.services import website_sync
from app.config import get_settings
|
|
|
|
logger = logging.getLogger(__name__)

# Module-level scheduler shared by every function below. The default executor
# is a thread pool with a single worker thread.
_scheduler = BackgroundScheduler(
    timezone="UTC",
    executors={"default": ThreadPoolExecutor(max_workers=1)},
)
|
|
|
|
|
|
def get_scheduler() -> BackgroundScheduler:
    """Return the module-level ``BackgroundScheduler`` instance."""
    return _scheduler
|
|
|
|
|
|
def _refresh_balance(client: UpstreamClient, upstream: Upstream) -> None:
    """Fetch the upstream balance through *client* and store it on *upstream*.

    Does nothing unless both ``balance_endpoint`` and ``balance_response_path``
    are set. A failure while fetching or scaling the balance is logged as a
    warning and otherwise ignored, so it never fails the surrounding check.
    """
    if not (upstream.balance_endpoint and upstream.balance_response_path):
        return

    balance: Optional[float] = None
    try:
        raw_balance = client.get_balance(upstream.balance_endpoint, upstream.balance_response_path)
        if raw_balance is not None:
            # A missing or zero divisor falls back to 1.0 (no scaling).
            balance = raw_balance / (upstream.balance_divisor or 1.0)
    except Exception as exc:
        logger.warning("upstream %s balance fetch failed: %s", upstream.name, exc)

    if balance is not None:
        upstream.balance = balance
        upstream.balance_updated_at = datetime.now(timezone.utc)


def _record_check_failure(db: Session, upstream: Upstream, exc: Exception, threshold: int) -> bool:
    """Persist a failed check on *upstream* and report whether it just turned unhealthy.

    Increments ``consecutive_failures``, stores the error text and check time,
    flips ``last_status`` to ``"unhealthy"`` once *threshold* is reached (only
    if it was not already unhealthy), and commits.

    Returns:
        True when this call performed the healthy -> unhealthy transition.
    """
    upstream.consecutive_failures = (upstream.consecutive_failures or 0) + 1
    upstream.last_error = str(exc)
    upstream.last_checked_at = datetime.now(timezone.utc)
    became_unhealthy = (
        upstream.consecutive_failures >= threshold
        and upstream.last_status != "unhealthy"
    )
    if became_unhealthy:
        upstream.last_status = "unhealthy"
    db.commit()
    return became_unhealthy


def _snapshot_captured_at(snapshot: dict[str, Any]) -> datetime:
    """Return the snapshot's ``captured_at`` value as a datetime.

    ISO-format strings are parsed; an unparseable string or a missing value
    falls back to the current UTC time. Any other value is returned unchanged.
    """
    value = snapshot.get("captured_at")
    if isinstance(value, str):
        try:
            return datetime.fromisoformat(value)
        except ValueError:
            return datetime.now(timezone.utc)
    if value is None:
        return datetime.now(timezone.utc)
    return value


def _check_upstream(upstream_id: int) -> None:
    """Run one full check of an upstream; invoked by the scheduler.

    Phase 1 - call the upstream API, then write the new snapshot and the
    upstream's status fields in one session. On failure the failure counters
    are committed instead, an ``upstream_unhealthy`` status event is sent if
    the threshold was just crossed, and the function returns.
    Phase 2 - key sync in its own session.
    Phase 3 - recovery / rate-change webhooks and website binding sync, each
    in its own session, so a notification failure cannot roll back the
    snapshot written in phase 1.

    If the upstream no longer exists or is disabled, its scheduler job is
    removed and nothing else happens.

    Args:
        upstream_id: Primary key of the ``Upstream`` row to check.
    """
    settings = get_settings()

    # -- Phase 1: upstream check + DB write ---------------------------------
    db: Session = SessionLocal()
    try:
        upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
        if not upstream or not upstream.enabled:
            _remove_job(upstream_id)
            return

        # Copy what the later phases need into plain locals: after commit/close
        # the ORM instance may be expired or detached and must not be read.
        upstream_name = upstream.name
        upstream_base_url = upstream.base_url
        auth_config = json.loads(upstream.auth_config_json or "{}")
        was_unhealthy = upstream.last_status == "unhealthy"

        with UpstreamClient(
            base_url=upstream.base_url,
            api_prefix=upstream.api_prefix,
            auth_type=upstream.auth_type,
            auth_config=auth_config,
            timeout=float(upstream.timeout_seconds),
        ) as client:
            try:
                client.login()
                groups = client.get_available_groups(upstream.groups_endpoint)
                raw_rates = client.get_group_rates(upstream.rate_endpoint)
                snapshot = build_snapshot(
                    upstream.id, upstream.base_url, upstream.api_prefix, groups, raw_rates
                )
                # Balance is fetched while the client is still open.
                _refresh_balance(client, upstream)
            except Exception as exc:
                # Failure path: persist counters, optionally notify, and stop.
                became_unhealthy = _record_check_failure(
                    db, upstream, exc, settings.unhealthy_threshold
                )
                logger.warning("upstream %s check failed: %s", upstream_name, exc)
                if became_unhealthy:
                    # The notification uses its own session (see _notify_status).
                    _notify_status(
                        upstream_id, upstream_name, upstream_base_url,
                        "upstream_unhealthy", str(exc),
                    )
                return

        # Success path (client already closed by the `with` block).
        prev_snapshot_row = (
            db.query(UpstreamRateSnapshot)
            .filter(UpstreamRateSnapshot.upstream_id == upstream_id)
            .order_by(UpstreamRateSnapshot.captured_at.desc())
            .first()
        )
        previous = json.loads(prev_snapshot_row.snapshot_json) if prev_snapshot_row else None
        changes = diff_snapshots(previous, snapshot)

        # Save the new snapshot and prune old ones.
        new_row = UpstreamRateSnapshot(
            upstream_id=upstream_id,
            snapshot_json=json.dumps(snapshot, ensure_ascii=False),
            captured_at=datetime.now(timezone.utc),
        )
        db.add(new_row)
        prune_snapshots(db, upstream_id, settings.snapshot_retention_count)

        # Mark the upstream healthy and reset failure tracking.
        upstream.last_status = "healthy"
        upstream.last_checked_at = datetime.now(timezone.utc)
        upstream.last_error = None
        upstream.consecutive_failures = 0
        db.commit()

        # Each format string gets exactly the arguments it consumes; passing a
        # spare argument makes logging fail to render the record.
        if changes:
            logger.info("upstream %s: %d rate change(s)", upstream_name, len(changes))
        else:
            logger.info("upstream %s: no changes", upstream_name)
    finally:
        db.close()

    # -- Phase 2: key sync (independent session) ----------------------------
    if snapshot:
        _sync_upstream_keys(upstream_id, snapshot, _snapshot_captured_at(snapshot))

    # -- Phase 3: notifications (independent sessions) ----------------------
    if was_unhealthy:
        _notify_status(upstream_id, upstream_name, upstream_base_url, "upstream_recovered")

    if changes:
        _notify_rate_changed(upstream_id, upstream_name, upstream_base_url, changes)
        _sync_website_bindings(upstream_id, changes)
|
|
|
|
|
|
def _notify_status(
    upstream_id: int,
    upstream_name: str,
    base_url: str,
    event: str,
    error: str = "",
) -> None:
    """Send a status webhook event for an upstream using its own DB session.

    Any exception raised while sending is logged with a traceback and not
    propagated; the session is closed in every case.
    """
    session = SessionLocal()
    try:
        try:
            webhook_service.send_status_event(
                session, upstream_id, upstream_name, base_url, event, error
            )
        except Exception:
            logger.exception("status webhook failed for upstream %s", upstream_name)
    finally:
        session.close()
|
|
|
|
|
|
def _notify_rate_changed(
    upstream_id: int,
    upstream_name: str,
    base_url: str,
    changes: list[dict[str, Any]],
) -> None:
    """Send the rate-changed webhook for an upstream using its own DB session.

    Any exception raised while sending is logged with a traceback and not
    propagated; the session is closed in every case.
    """
    session = SessionLocal()
    try:
        try:
            webhook_service.send_rate_changed(
                session, upstream_id, upstream_name, base_url, changes
            )
        except Exception:
            logger.exception("rate webhook failed for upstream %s", upstream_name)
    finally:
        session.close()
|
|
|
|
|
|
def _fetch_remote_key_ids(upstream: Upstream, auth_config: dict[str, Any]) -> set[str] | None:
    """Return the ids of active remote keys matching "SmartUp", or None on failure.

    ``None`` means the remote lookup failed; an empty set means it succeeded
    and returned no keys. Failures are logged as warnings, never raised.
    """
    try:
        with UpstreamClient(
            base_url=upstream.base_url,
            api_prefix=upstream.api_prefix,
            auth_type=upstream.auth_type,
            auth_config=auth_config,
            timeout=float(upstream.timeout_seconds),
        ) as client:
            client.login()
            remote_keys = client.list_api_keys(search="SmartUp", status="active")
            return {str(k["id"]) for k in remote_keys if k.get("id")}
    except Exception as exc:
        logger.warning("sync upstream keys list failed for %s: %s", upstream.id, exc)
        return None


def _retire_key(db: Session, row: UpstreamGeneratedKey, captured_at: datetime, error: str) -> bool:
    """Orphan or delete a local key row whose source has disappeared.

    A row with both ``imported_website_id`` and ``imported_account_id`` set is
    kept and marked ``"orphaned"`` with *error*; any other row is deleted.

    Returns:
        True if the row was marked orphaned, False if it was deleted.
    """
    if row.imported_website_id and row.imported_account_id:
        row.status = "orphaned"
        row.error = error
        row.updated_at = captured_at
        return True
    db.delete(row)
    return False


def _sync_upstream_keys(upstream_id: int, snapshot: dict[str, Any], captured_at: datetime) -> None:
    """Sync local SmartUp key rows after a successful upstream check.

    Handles keys whose group has left the snapshot and keys that no longer
    exist remotely. Runs in its own DB session; any error is logged with a
    traceback and not propagated.

    Args:
        upstream_id: Primary key of the upstream whose keys are synced.
        snapshot: Snapshot dict; the keys of its ``"groups"`` mapping are
            treated as the currently active group ids.
        captured_at: Timestamp written to ``updated_at`` on touched rows.
    """
    db = SessionLocal()
    try:
        # Load the upstream once and bail out early if it has been deleted,
        # instead of dereferencing a possible None.
        upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
        if upstream is None:
            logger.warning("key sync skipped: upstream %s no longer exists", upstream_id)
            return

        active_group_ids = set(snapshot.get("groups", {}).keys())
        key_rows = (
            db.query(UpstreamGeneratedKey)
            .filter(
                UpstreamGeneratedKey.upstream_id == upstream_id,
                UpstreamGeneratedKey.key_name.like("SmartUp-%"),
            )
            .all()
        )
        auth_config = json.loads(upstream.auth_config_json or "{}")

        # Ask the upstream which key ids are still active.
        # None = lookup failed, set() = lookup succeeded but found nothing.
        remote_key_ids = _fetch_remote_key_ids(upstream, auth_config)

        for row in key_rows:
            # 1. The key's group is no longer in the current snapshot.
            # NOTE(review): assumes row.group_id has the same type as the
            # snapshot's group keys - confirm against the model.
            if row.group_id not in active_group_ids:
                if _retire_key(db, row, captured_at, "来源分组已不存在"):
                    logger.info("marked key %s orphaned (group %s no longer in snapshot)", row.id, row.group_id)
                else:
                    logger.info("removed key %s (group %s no longer in snapshot)", row.id, row.group_id)
                continue

            # Without a successful remote lookup nothing more can be decided.
            if remote_key_ids is None:
                continue

            # 2. Remote lookup succeeded but this key id is not in the list.
            if row.key_id and row.key_id not in remote_key_ids:
                if _retire_key(db, row, captured_at, "远端 Key 已不存在"):
                    logger.info("marked key %s orphaned (key_id %s gone from remote)", row.id, row.key_id)
                else:
                    logger.info("removed key %s (key_id %s gone from remote)", row.id, row.key_id)
                continue

            # 3. Key still exists remotely: refresh its sync timestamp.
            if row.key_id in remote_key_ids:
                row.updated_at = captured_at

        db.commit()
    except Exception:
        logger.exception("key sync failed for upstream %s", upstream_id)
    finally:
        db.close()
|
|
|
|
|
|
def _sync_website_bindings(upstream_id: int, changes: list[dict[str, Any]]) -> None:
    """Run ``website_sync.sync_affected_bindings`` for the changes in its own DB session.

    Any exception raised by the sync is logged with a traceback and not
    propagated; the session is closed in every case.
    """
    session = SessionLocal()
    try:
        try:
            website_sync.sync_affected_bindings(session, upstream_id, changes)
        except Exception:
            logger.exception("website sync failed for upstream %s", upstream_id)
    finally:
        session.close()
|
|
|
|
|
|
def _remove_job(upstream_id: int) -> None:
    """Remove the scheduler job for *upstream_id*; do nothing if it has none.

    The job id is ``upstream_<id>``. Removal is attempted directly and a
    missing job is ignored, rather than checking ``get_job`` first: this
    function is called both from the scheduler's worker thread and from
    callers of ``refresh_upstream``, so the job can vanish between a check
    and the removal.
    """
    job_id = f"upstream_{upstream_id}"
    try:
        _scheduler.remove_job(job_id)
    except JobLookupError:
        # Already gone (never scheduled, or removed concurrently).
        pass
|
|
|
|
|
|
def refresh_upstream(upstream_id: int, interval_seconds: int = 0, enabled: bool = True) -> None:
    """Add, update or remove the scheduler job for the given upstream.

    Args:
        upstream_id: Primary key of the upstream; the job id is
            ``upstream_<id>``.
        interval_seconds: Check interval. ``None``, zero or a negative value
            means "no schedule" and removes any existing job.
        enabled: When False the job is removed regardless of the interval.
    """
    job_id = f"upstream_{upstream_id}"

    # `not interval_seconds` also covers None (e.g. a NULL interval column
    # passed straight through by start_scheduler), which would otherwise make
    # the `<= 0` comparison raise TypeError.
    if not enabled or not interval_seconds or interval_seconds <= 0:
        _remove_job(upstream_id)
        return

    _scheduler.add_job(
        _check_upstream,
        "interval",
        seconds=interval_seconds,
        id=job_id,
        args=[upstream_id],
        replace_existing=True,  # re-registering updates the existing job in place
        coalesce=True,
        max_instances=1,
        misfire_grace_time=60,
        jitter=30,
    )
    logger.info("scheduler job %s set to %ds interval", job_id, interval_seconds)
|
|
|
|
|
|
def start_scheduler() -> None:
    """Start the scheduler (if not already running) and register all enabled upstreams.

    Every enabled ``Upstream`` row is passed to ``refresh_upstream``, which
    creates or replaces its job. Safe to call again while the scheduler is
    running: the start is skipped and the jobs are simply refreshed.
    """
    # Mirror stop_scheduler's guard; starting a running scheduler raises.
    if not _scheduler.running:
        _scheduler.start()

    db: Session = SessionLocal()
    try:
        # `== True` is intentional: it builds a SQL expression, not a Python bool.
        upstreams = db.query(Upstream).filter(Upstream.enabled == True).all()  # noqa: E712
        for u in upstreams:
            refresh_upstream(u.id, u.check_interval_seconds, u.enabled)
        logger.info("scheduler started with %d upstream job(s)", len(upstreams))
    finally:
        db.close()
|
|
|
|
|
|
def stop_scheduler() -> None:
    """Shut the scheduler down with ``wait=True``; no-op when it is not running."""
    if not _scheduler.running:
        return
    _scheduler.shutdown(wait=True)
|