Add remote browser pages and website sync
Enable managed remote browser custom pages with login autofill and add website sync workflows so external admin surfaces can be handled inside SmartUp. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,320 @@
|
||||
"""Managed Playwright browser sessions for custom pages."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BrowserDependencyError(RuntimeError):
|
||||
"""Raised when Playwright or its browser runtime is unavailable."""
|
||||
|
||||
|
||||
class BrowserSessionError(RuntimeError):
|
||||
"""Raised when an existing browser session can no longer be used."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrowserSession:
|
||||
id: str
|
||||
custom_page_id: int
|
||||
profile_key: str
|
||||
context: Any
|
||||
page: Any
|
||||
lock: asyncio.Lock
|
||||
|
||||
|
||||
class BrowserSessionService:
|
||||
def __init__(self) -> None:
|
||||
self._playwright: Optional[Any] = None
|
||||
self._sessions: dict[str, BrowserSession] = {}
|
||||
self._profiles: dict[str, str] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def create(
|
||||
self,
|
||||
custom_page_id: int,
|
||||
url: str,
|
||||
width: int = 1280,
|
||||
height: int = 720,
|
||||
login_config: Optional[dict[str, Any]] = None,
|
||||
) -> BrowserSession:
|
||||
if not url.startswith(("http://", "https://")):
|
||||
raise ValueError("Only http/https URLs are allowed")
|
||||
width = max(320, min(width, 2560))
|
||||
height = max(240, min(height, 1600))
|
||||
async with self._lock:
|
||||
await self._ensure_playwright()
|
||||
profile_key = self._profile_key(custom_page_id, url)
|
||||
existing_id = self._profiles.get(profile_key)
|
||||
existing = self._sessions.get(existing_id or "")
|
||||
if existing and not existing.page.is_closed():
|
||||
async with existing.lock:
|
||||
await existing.page.set_viewport_size({"width": width, "height": height})
|
||||
if existing.page.url == "about:blank":
|
||||
await existing.page.goto(url, wait_until="domcontentloaded", timeout=45000)
|
||||
await self._autofill_login(existing.page, login_config)
|
||||
await self._reset_page_zoom(existing)
|
||||
return existing
|
||||
if existing_id:
|
||||
self._profiles.pop(profile_key, None)
|
||||
context = await self._playwright.chromium.launch_persistent_context(
|
||||
str(self._profile_dir(profile_key)),
|
||||
headless=get_settings().browser_headless,
|
||||
viewport={"width": width, "height": height},
|
||||
args=["--no-sandbox", "--disable-dev-shm-usage"],
|
||||
)
|
||||
page = context.pages[0] if context.pages else await context.new_page()
|
||||
session = BrowserSession(
|
||||
id=uuid4().hex,
|
||||
custom_page_id=custom_page_id,
|
||||
profile_key=profile_key,
|
||||
context=context,
|
||||
page=page,
|
||||
lock=asyncio.Lock(),
|
||||
)
|
||||
self._sessions[session.id] = session
|
||||
self._profiles[profile_key] = session.id
|
||||
try:
|
||||
await page.goto(url, wait_until="domcontentloaded", timeout=45000)
|
||||
await self._autofill_login(page, login_config)
|
||||
await self._reset_page_zoom(session)
|
||||
except Exception:
|
||||
await self.close(session.id)
|
||||
raise
|
||||
return session
|
||||
|
||||
async def screenshot(self, session_id: str) -> bytes:
|
||||
session = self._get(session_id)
|
||||
async with session.lock:
|
||||
self._ensure_open(session)
|
||||
return await session.page.screenshot(type="jpeg", quality=78, full_page=False)
|
||||
|
||||
async def event(self, session_id: str, event_type: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
session = self._get(session_id)
|
||||
async with session.lock:
|
||||
self._ensure_open(session)
|
||||
page = session.page
|
||||
if event_type == "click":
|
||||
await page.mouse.click(float(payload["x"]), float(payload["y"]), button=payload.get("button", "left"))
|
||||
elif event_type == "dblclick":
|
||||
await page.mouse.dblclick(float(payload["x"]), float(payload["y"]), button=payload.get("button", "left"))
|
||||
elif event_type == "mousemove":
|
||||
await page.mouse.move(float(payload["x"]), float(payload["y"]))
|
||||
elif event_type == "mousedown":
|
||||
await page.mouse.move(float(payload["x"]), float(payload["y"]))
|
||||
await page.mouse.down(button=payload.get("button", "left"))
|
||||
elif event_type == "mouseup":
|
||||
await page.mouse.move(float(payload["x"]), float(payload["y"]))
|
||||
await page.mouse.up(button=payload.get("button", "left"))
|
||||
elif event_type == "type":
|
||||
text = str(payload.get("text", ""))
|
||||
if text:
|
||||
await page.keyboard.type(text)
|
||||
elif event_type == "key":
|
||||
key = str(payload.get("key", ""))
|
||||
if key:
|
||||
await page.keyboard.press(key)
|
||||
elif event_type == "scroll":
|
||||
if payload.get("x") is not None and payload.get("y") is not None:
|
||||
await page.mouse.move(float(payload["x"]), float(payload["y"]))
|
||||
await page.mouse.wheel(float(payload.get("delta_x", 0)), float(payload.get("delta_y", 0)))
|
||||
elif event_type == "reload":
|
||||
await page.reload(wait_until="domcontentloaded", timeout=45000)
|
||||
elif event_type == "back":
|
||||
await page.go_back(wait_until="domcontentloaded", timeout=45000)
|
||||
elif event_type == "forward":
|
||||
await page.go_forward(wait_until="domcontentloaded", timeout=45000)
|
||||
elif event_type == "resize":
|
||||
width = max(320, min(int(payload.get("width", 1280)), 2560))
|
||||
height = max(240, min(int(payload.get("height", 720)), 1600))
|
||||
await page.set_viewport_size({"width": width, "height": height})
|
||||
else:
|
||||
raise ValueError("Unsupported browser event")
|
||||
return await self._session_state(session)
|
||||
|
||||
async def close(self, session_id: str) -> None:
|
||||
session = self._discard_session(session_id)
|
||||
if not session:
|
||||
return
|
||||
try:
|
||||
await session.context.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
sessions = list(self._sessions)
|
||||
for session_id in sessions:
|
||||
await self.close(session_id)
|
||||
if self._playwright:
|
||||
await self._playwright.stop()
|
||||
self._playwright = None
|
||||
|
||||
async def state(self, session_id: str) -> dict[str, Any]:
|
||||
session = self._get(session_id)
|
||||
async with session.lock:
|
||||
self._ensure_open(session)
|
||||
return await self._session_state(session)
|
||||
|
||||
async def _session_state(self, session: BrowserSession) -> dict[str, Any]:
|
||||
return {
|
||||
"id": session.id,
|
||||
"custom_page_id": session.custom_page_id,
|
||||
"url": session.page.url,
|
||||
"title": await session.page.title(),
|
||||
}
|
||||
|
||||
async def _ensure_playwright(self) -> None:
|
||||
if self._playwright:
|
||||
return
|
||||
try:
|
||||
from playwright.async_api import async_playwright
|
||||
except ImportError as exc:
|
||||
raise BrowserDependencyError("Playwright is not installed. Run `pip install -r requirements.txt`.") from exc
|
||||
try:
|
||||
self._playwright = await async_playwright().start()
|
||||
except Exception as exc:
|
||||
raise BrowserDependencyError(f"Unable to start Playwright: {exc}") from exc
|
||||
|
||||
async def _reset_page_zoom(self, session: BrowserSession) -> None:
|
||||
try:
|
||||
cdp = await session.context.new_cdp_session(session.page)
|
||||
try:
|
||||
await cdp.send("Emulation.setPageScaleFactor", {"pageScaleFactor": 1})
|
||||
finally:
|
||||
await cdp.detach()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _autofill_login(
|
||||
self,
|
||||
page: Any,
|
||||
config: Optional[dict[str, Any]],
|
||||
*,
|
||||
max_wait_seconds: float = 8.0,
|
||||
poll_interval_seconds: float = 0.25,
|
||||
) -> None:
|
||||
if not config or not config.get("enabled"):
|
||||
return
|
||||
username = str(config.get("username") or "")
|
||||
password = str(config.get("password") or "")
|
||||
if not username or not password:
|
||||
return
|
||||
try:
|
||||
username_selectors = [
|
||||
config.get("username_selector"),
|
||||
"input[type='email']",
|
||||
"input[name*='user' i]",
|
||||
"input[id*='user' i]",
|
||||
"input[name*='email' i]",
|
||||
"input[id*='email' i]",
|
||||
"input[name*='login' i]",
|
||||
"input[id*='login' i]",
|
||||
"input[autocomplete='username']",
|
||||
"input:not([type]), input[type='text']",
|
||||
]
|
||||
password_selectors = [
|
||||
config.get("password_selector"),
|
||||
"input[type='password']",
|
||||
"input[autocomplete='current-password']",
|
||||
]
|
||||
username_locator, password_locator = await self._wait_for_login_locators(
|
||||
page,
|
||||
username_selectors,
|
||||
password_selectors,
|
||||
max_wait_seconds=max_wait_seconds,
|
||||
poll_interval_seconds=poll_interval_seconds,
|
||||
)
|
||||
if not username_locator or not password_locator:
|
||||
logger.info("Login autofill skipped for %s: login fields not found", page.url)
|
||||
return
|
||||
await username_locator.fill(username, timeout=3000)
|
||||
await password_locator.fill(password, timeout=3000)
|
||||
submit_selector = str(config.get("submit_selector") or "").strip()
|
||||
if submit_selector:
|
||||
submit = await self._first_visible_locator(page, [submit_selector], timeout=500)
|
||||
if submit:
|
||||
await submit.click(timeout=3000)
|
||||
except Exception as exc:
|
||||
logger.info("Login autofill skipped for %s: %s", page.url, exc)
|
||||
|
||||
async def _wait_for_login_locators(
|
||||
self,
|
||||
page: Any,
|
||||
username_selectors: list[Optional[str]],
|
||||
password_selectors: list[Optional[str]],
|
||||
*,
|
||||
max_wait_seconds: float,
|
||||
poll_interval_seconds: float,
|
||||
) -> tuple[Optional[Any], Optional[Any]]:
|
||||
deadline = time.monotonic() + max_wait_seconds
|
||||
while True:
|
||||
username_locator = await self._first_visible_locator(page, username_selectors, timeout=150)
|
||||
password_locator = await self._first_visible_locator(page, password_selectors, timeout=150)
|
||||
if username_locator and password_locator:
|
||||
return username_locator, password_locator
|
||||
if time.monotonic() >= deadline:
|
||||
return None, None
|
||||
await asyncio.sleep(poll_interval_seconds)
|
||||
|
||||
async def _first_visible_locator(
|
||||
self,
|
||||
page: Any,
|
||||
selectors: list[Optional[str]],
|
||||
*,
|
||||
timeout: float = 1500,
|
||||
) -> Optional[Any]:
|
||||
for selector in selectors:
|
||||
selector = str(selector or "").strip()
|
||||
if not selector:
|
||||
continue
|
||||
try:
|
||||
locator = page.locator(selector).first
|
||||
if await locator.count() and await locator.is_visible(timeout=timeout):
|
||||
return locator
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
def _get(self, session_id: str) -> BrowserSession:
|
||||
session = self._sessions.get(session_id)
|
||||
if not session:
|
||||
raise KeyError("browser session not found")
|
||||
return session
|
||||
|
||||
def _ensure_open(self, session: BrowserSession) -> None:
|
||||
if session.page.is_closed():
|
||||
self._discard_session(session.id)
|
||||
raise BrowserSessionError("browser page is closed")
|
||||
|
||||
def _discard_session(self, session_id: str) -> BrowserSession | None:
|
||||
session = self._sessions.pop(session_id, None)
|
||||
if session and self._profiles.get(session.profile_key) == session_id:
|
||||
self._profiles.pop(session.profile_key, None)
|
||||
return session
|
||||
|
||||
def _profile_dir(self, profile_key: str) -> Path:
|
||||
root = Path(get_settings().browser_profiles_dir)
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
profile = root / profile_key
|
||||
profile.mkdir(parents=True, exist_ok=True)
|
||||
return profile
|
||||
|
||||
def _profile_key(self, custom_page_id: int, url: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
origin = f"{parsed.scheme}-{parsed.netloc}".lower()
|
||||
safe_origin = re.sub(r"[^a-z0-9_.-]+", "_", origin).strip("_") or "page"
|
||||
return f"page-{custom_page_id}-{safe_origin[:80]}"
|
||||
|
||||
|
||||
browser_sessions = BrowserSessionService()
|
||||
@@ -14,6 +14,7 @@ from app.models.snapshot import UpstreamRateSnapshot
|
||||
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot
|
||||
from app.services.snapshot_service import diff_snapshots
|
||||
from app.services import webhook_service
|
||||
from app.services import website_sync
|
||||
from app.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -105,6 +106,7 @@ def _check_upstream(upstream_id: int) -> None:
|
||||
webhook_service.send_rate_changed(
|
||||
db, upstream.id, upstream.name, upstream.base_url, changes
|
||||
)
|
||||
website_sync.sync_affected_bindings(db, upstream.id, changes)
|
||||
logger.info("upstream %s: %d rate change(s)", upstream.name, len(changes))
|
||||
else:
|
||||
logger.debug("upstream %s: no changes", upstream.name)
|
||||
|
||||
@@ -27,14 +27,42 @@ def _find_token(value: Any) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _find_user_id(value: Any) -> str:
|
||||
if isinstance(value, dict):
|
||||
for key in ("id", "user_id", "userId"):
|
||||
candidate = value.get(key)
|
||||
if candidate is not None:
|
||||
return str(candidate)
|
||||
for key in ("data", "result", "user", "session"):
|
||||
user_id = _find_user_id(value.get(key))
|
||||
if user_id:
|
||||
return user_id
|
||||
return ""
|
||||
|
||||
|
||||
def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]:
|
||||
def _normalize(lst: list) -> list[dict[str, Any]]:
|
||||
out = []
|
||||
for i in lst:
|
||||
if isinstance(i, dict):
|
||||
out.append(i)
|
||||
elif isinstance(i, str):
|
||||
out.append({"id": i, "name": i})
|
||||
return out
|
||||
|
||||
if isinstance(value, list):
|
||||
return [i for i in value if isinstance(i, dict)]
|
||||
return _normalize(value)
|
||||
if isinstance(value, dict):
|
||||
for key in ("data", "items", "groups", "available_groups", "availableGroups"):
|
||||
nested = value.get(key)
|
||||
if isinstance(nested, list):
|
||||
return [i for i in nested if isinstance(i, dict)]
|
||||
return _normalize(nested)
|
||||
elif isinstance(nested, dict):
|
||||
# Handle /api/user/self/groups where data is a dict of group_name -> { desc, ratio }
|
||||
out = []
|
||||
for k in nested.keys():
|
||||
out.append({"id": k, "name": k})
|
||||
return out
|
||||
return None
|
||||
|
||||
|
||||
@@ -76,19 +104,59 @@ def _rate_from_group(group: dict[str, Any]) -> str:
|
||||
def _extract_rates_map(raw: Any) -> dict[str, str]:
|
||||
if raw is None:
|
||||
return {}
|
||||
|
||||
# Handle one-api/new-api /api/option response where GroupRatio is in a list of options
|
||||
if isinstance(raw, dict) and isinstance(raw.get("data"), list):
|
||||
for item in raw["data"]:
|
||||
if isinstance(item, dict) and item.get("key") == "GroupRatio":
|
||||
val = item.get("value")
|
||||
if isinstance(val, str):
|
||||
try:
|
||||
import json
|
||||
parsed = json.loads(val)
|
||||
if isinstance(parsed, dict):
|
||||
result: dict[str, str] = {}
|
||||
for k, v in parsed.items():
|
||||
r = _decimal_str(v)
|
||||
if r:
|
||||
result[str(k)] = r
|
||||
return result
|
||||
except Exception:
|
||||
pass
|
||||
elif isinstance(val, dict):
|
||||
# In case it's returned as dict directly
|
||||
result = {}
|
||||
for k, v in val.items():
|
||||
r = _decimal_str(v)
|
||||
if r:
|
||||
result[str(k)] = r
|
||||
return result
|
||||
|
||||
if isinstance(raw, dict):
|
||||
candidates = raw
|
||||
for key in ("data", "rates", "group_rates", "groupRates"):
|
||||
for key in ("data", "rates", "group_rates", "groupRates", "GroupRatio"):
|
||||
nested = raw.get(key)
|
||||
if isinstance(nested, dict):
|
||||
candidates = nested
|
||||
break
|
||||
elif isinstance(nested, str) and key == "GroupRatio":
|
||||
# Handle GroupRatio as a JSON string
|
||||
try:
|
||||
import json
|
||||
parsed = json.loads(nested)
|
||||
if isinstance(parsed, dict):
|
||||
candidates = parsed
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
result: dict[str, str] = {}
|
||||
for k, v in candidates.items():
|
||||
if isinstance(v, dict):
|
||||
r = _decimal_str(
|
||||
v.get("rate_multiplier") or v.get("rateMultiplier")
|
||||
or v.get("user_rate_multiplier") or v.get("userRateMultiplier")
|
||||
or v.get("ratio")
|
||||
)
|
||||
else:
|
||||
r = _decimal_str(v)
|
||||
@@ -151,6 +219,8 @@ class UpstreamClient:
|
||||
self.auth_config = auth_config
|
||||
self.timeout = timeout
|
||||
self._token: str = ""
|
||||
self._cookies: dict[str, str] = {}
|
||||
self._new_api_user: str = ""
|
||||
|
||||
def _url(self, path: str) -> str:
|
||||
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
|
||||
@@ -174,15 +244,29 @@ class UpstreamClient:
|
||||
headers[header] = key
|
||||
elif self.auth_type == "login_password" and self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
if self.auth_type == "login_password" and self._new_api_user:
|
||||
headers["New-Api-User"] = self._new_api_user
|
||||
return headers
|
||||
|
||||
def _request(self, method: str, path: str, body: Any = None, auth: bool = True) -> Any:
|
||||
url = self._url(path)
|
||||
with httpx.Client(timeout=self.timeout) as client:
|
||||
if body is not None:
|
||||
resp = client.request(method, url, json=body, headers=self._headers(auth))
|
||||
resp = client.request(
|
||||
method,
|
||||
url,
|
||||
json=body,
|
||||
headers=self._headers(auth),
|
||||
cookies=self._cookies,
|
||||
)
|
||||
else:
|
||||
resp = client.request(method, url, headers=self._headers(auth))
|
||||
resp = client.request(
|
||||
method,
|
||||
url,
|
||||
headers=self._headers(auth),
|
||||
cookies=self._cookies,
|
||||
)
|
||||
self._cookies.update(dict(resp.cookies))
|
||||
resp.raise_for_status()
|
||||
ct = resp.headers.get("content-type", "")
|
||||
if not resp.content:
|
||||
@@ -198,13 +282,18 @@ class UpstreamClient:
|
||||
email = self.auth_config.get("email", "")
|
||||
password = self.auth_config.get("password", "")
|
||||
login_path = self.auth_config.get("login_path", "/auth/login")
|
||||
username_field = self.auth_config.get("username_field", "email")
|
||||
if not email or not password:
|
||||
raise UpstreamError("login_password auth requires email and password in auth_config")
|
||||
resp = self._request("POST", login_path, {"email": email, "password": password}, auth=False)
|
||||
resp = self._request("POST", login_path, {username_field: email, "password": password}, auth=False)
|
||||
token = _find_token(resp)
|
||||
if not token:
|
||||
raise UpstreamError("login succeeded but no token found in response")
|
||||
self._token = token
|
||||
if token:
|
||||
self._token = token
|
||||
return
|
||||
if self._cookies:
|
||||
self._new_api_user = self.auth_config.get("new_api_user", "") or _find_user_id(resp)
|
||||
return
|
||||
raise UpstreamError("login succeeded but no token or session cookie found in response")
|
||||
|
||||
def get_available_groups(self, endpoint: str) -> list[dict[str, Any]]:
|
||||
resp = self._request("GET", endpoint)
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.models.notification_log import NotificationLog
|
||||
from app.utils.dingtalk import (
|
||||
dingtalk_signed_url,
|
||||
format_dingtalk_rate_changed,
|
||||
format_dingtalk_website_rate_changed,
|
||||
format_dingtalk_status,
|
||||
)
|
||||
|
||||
@@ -101,6 +102,54 @@ def send_rate_changed(
|
||||
_log(db, wh, "upstream_rate_changed", generic_payload, "failed", str(exc))
|
||||
|
||||
|
||||
def send_website_rate_changed(
|
||||
db: Session,
|
||||
website_id: int,
|
||||
website_name: str,
|
||||
base_url: str,
|
||||
binding_id: int,
|
||||
target_group_id: str,
|
||||
target_group_name: str,
|
||||
old_rate: Any,
|
||||
new_rate: Any,
|
||||
source_rates: list[dict[str, Any]],
|
||||
) -> None:
|
||||
webhooks = (
|
||||
db.query(WebhookConfig)
|
||||
.filter(WebhookConfig.enabled == True)
|
||||
.all()
|
||||
)
|
||||
changed_at = _now_iso()
|
||||
generic_payload = {
|
||||
"event": "website_rate_changed",
|
||||
"website": {"id": website_id, "name": website_name, "base_url": base_url},
|
||||
"binding": {"id": binding_id},
|
||||
"target_group": {
|
||||
"id": target_group_id,
|
||||
"name": target_group_name,
|
||||
"old_rate": old_rate,
|
||||
"new_rate": new_rate,
|
||||
},
|
||||
"source_rates": source_rates,
|
||||
"changed_at": changed_at,
|
||||
}
|
||||
for wh in webhooks:
|
||||
events = json.loads(wh.events_json or "[]")
|
||||
if "website_rate_changed" not in events:
|
||||
continue
|
||||
try:
|
||||
if wh.type == "dingtalk":
|
||||
msg = format_dingtalk_website_rate_changed(
|
||||
website_name, target_group_name, changed_at, old_rate, new_rate
|
||||
)
|
||||
resp_text = _send_dingtalk(wh.url, wh.secret, msg)
|
||||
else:
|
||||
resp_text = _send_generic(wh.url, generic_payload)
|
||||
_log(db, wh, "website_rate_changed", generic_payload, "success", resp_text)
|
||||
except Exception as exc:
|
||||
_log(db, wh, "website_rate_changed", generic_payload, "failed", str(exc))
|
||||
|
||||
|
||||
def send_status_event(
|
||||
db: Session,
|
||||
upstream_id: int,
|
||||
|
||||
@@ -0,0 +1,154 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class WebsiteError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def decimal_string(value: Any) -> str:
|
||||
if value is None or value == "":
|
||||
return ""
|
||||
try:
|
||||
d = Decimal(str(value))
|
||||
except (InvalidOperation, ValueError):
|
||||
return str(value)
|
||||
n = d.normalize()
|
||||
if n == n.to_integral():
|
||||
return str(n.quantize(Decimal("1")))
|
||||
return format(n, "f")
|
||||
|
||||
|
||||
def parse_positive_decimal(value: Any) -> Decimal | None:
|
||||
if value is None or value == "":
|
||||
return None
|
||||
try:
|
||||
d = Decimal(str(value))
|
||||
except (InvalidOperation, ValueError):
|
||||
return None
|
||||
return d if d > 0 else None
|
||||
|
||||
|
||||
def calculate_target_rate(values: list[Any], percent: Any = 0, algorithm: str = "max_plus_percent") -> Decimal:
|
||||
rates = [rate for rate in (parse_positive_decimal(v) for v in values) if rate is not None]
|
||||
if not rates:
|
||||
raise WebsiteError("没有可用的正数上游倍率")
|
||||
if algorithm == "average_plus_percent":
|
||||
base = sum(rates, Decimal("0")) / Decimal(len(rates))
|
||||
elif algorithm == "min_plus_percent":
|
||||
base = min(rates)
|
||||
elif algorithm == "max_plus_percent":
|
||||
base = max(rates)
|
||||
else:
|
||||
raise WebsiteError(f"不支持的算法:{algorithm}")
|
||||
pct = Decimal(str(percent or 0))
|
||||
if pct < 0:
|
||||
raise WebsiteError("百分比不能为负数")
|
||||
return (base * (Decimal("1") + pct / Decimal("100"))).quantize(Decimal("0.0001"), rounding=ROUND_HALF_UP)
|
||||
|
||||
|
||||
def _unwrap_data(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
data = value.get("data")
|
||||
if "data" in value and (
|
||||
"code" in value
|
||||
or "message" in value
|
||||
or isinstance(data, list)
|
||||
or (isinstance(data, dict) and any(key in data for key in ("items", "groups")))
|
||||
):
|
||||
value = data
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
for key in ("items", "groups"):
|
||||
if key in value:
|
||||
return value.get(key)
|
||||
return value
|
||||
|
||||
|
||||
def normalize_groups(value: Any) -> list[dict[str, Any]]:
|
||||
raw = _unwrap_data(value)
|
||||
if isinstance(raw, dict):
|
||||
raw = list(raw.values())
|
||||
if not isinstance(raw, list):
|
||||
raise WebsiteError("分组接口没有返回列表")
|
||||
groups: list[dict[str, Any]] = []
|
||||
for item in raw:
|
||||
if isinstance(item, str):
|
||||
groups.append({"id": item, "name": item, "rate_multiplier": None, "raw": {"id": item, "name": item}})
|
||||
continue
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
gid = item.get("id") or item.get("group_id") or item.get("groupId") or item.get("name") or item.get("group_name")
|
||||
if gid is None:
|
||||
continue
|
||||
name = item.get("name") or item.get("group_name") or str(gid)
|
||||
rate = item.get("rate_multiplier") or item.get("rateMultiplier") or item.get("ratio")
|
||||
groups.append({
|
||||
"id": str(gid),
|
||||
"name": str(name),
|
||||
"rate_multiplier": decimal_string(rate) if rate is not None else None,
|
||||
"raw": item,
|
||||
})
|
||||
return groups
|
||||
|
||||
|
||||
class Sub2ApiWebsiteClient:
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
api_prefix: str,
|
||||
auth_type: str,
|
||||
auth_config: dict[str, Any],
|
||||
timeout: float = 30.0,
|
||||
) -> None:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.api_prefix = api_prefix.strip("/")
|
||||
self.auth_type = auth_type
|
||||
self.auth_config = auth_config
|
||||
self.timeout = timeout
|
||||
|
||||
def _url(self, path: str) -> str:
|
||||
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
|
||||
return f"{self.base_url}{prefix}/{path.lstrip('/')}"
|
||||
|
||||
def _headers(self) -> dict[str, str]:
|
||||
headers = {"Accept": "application/json", "User-Agent": "SmartUp/1.0"}
|
||||
if self.auth_type == "api_key":
|
||||
key = self.auth_config.get("key") or self.auth_config.get("api_key") or ""
|
||||
header = self.auth_config.get("header") or "x-api-key"
|
||||
if key:
|
||||
headers[header] = key
|
||||
elif self.auth_type == "bearer":
|
||||
token = self.auth_config.get("token") or ""
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
return headers
|
||||
|
||||
def _request(self, method: str, path: str, body: Any = None) -> Any:
|
||||
with httpx.Client(timeout=self.timeout) as client:
|
||||
resp = client.request(method, self._url(path), json=body, headers=self._headers())
|
||||
resp.raise_for_status()
|
||||
if not resp.content:
|
||||
return None
|
||||
text = resp.text
|
||||
if "application/json" not in resp.headers.get("content-type", "") and text.lstrip().startswith("<"):
|
||||
raise WebsiteError(f"{method} {path} returned HTML, not JSON")
|
||||
return resp.json()
|
||||
|
||||
def get_groups(self, endpoint: str = "/groups") -> list[dict[str, Any]]:
|
||||
errors: list[str] = []
|
||||
for path in [endpoint, "/groups/all"]:
|
||||
try:
|
||||
return normalize_groups(self._request("GET", path))
|
||||
except Exception as exc:
|
||||
errors.append(f"{path}: {exc}")
|
||||
raise WebsiteError("; ".join(errors))
|
||||
|
||||
def update_group_rate(self, endpoint_template: str, group_id: str, rate: Decimal) -> Any:
|
||||
path = endpoint_template.replace("{id}", quote(group_id, safe=""))
|
||||
return self._request("PUT", path, {"rate_multiplier": float(rate)})
|
||||
@@ -0,0 +1,165 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.snapshot import UpstreamRateSnapshot
|
||||
from app.models.upstream import Upstream
|
||||
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
|
||||
from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, calculate_target_rate, decimal_string
|
||||
from app.services import webhook_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def binding_sources(binding: WebsiteGroupBinding) -> list[dict[str, Any]]:
|
||||
try:
|
||||
data = json.loads(binding.source_groups_json or "[]")
|
||||
except Exception:
|
||||
return []
|
||||
return data if isinstance(data, list) else []
|
||||
|
||||
|
||||
def latest_rate_map(db: Session, upstream_id: int) -> dict[str, Any]:
|
||||
row = (
|
||||
db.query(UpstreamRateSnapshot)
|
||||
.filter(UpstreamRateSnapshot.upstream_id == upstream_id)
|
||||
.order_by(UpstreamRateSnapshot.captured_at.desc())
|
||||
.first()
|
||||
)
|
||||
if not row:
|
||||
return {}
|
||||
snapshot = json.loads(row.snapshot_json or "{}")
|
||||
groups = snapshot.get("groups") or {}
|
||||
return groups if isinstance(groups, dict) else {}
|
||||
|
||||
|
||||
def get_affected_bindings(db: Session, changes: list[dict[str, Any]], upstream_id: int) -> list[WebsiteGroupBinding]:
|
||||
changed_ids = {str(change.get("group_id")) for change in changes if change.get("group_id") is not None}
|
||||
if not changed_ids:
|
||||
return []
|
||||
result: list[WebsiteGroupBinding] = []
|
||||
bindings = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.enabled == True).all()
|
||||
for binding in bindings:
|
||||
for source in binding_sources(binding):
|
||||
if int(source.get("upstream_id") or 0) == upstream_id and str(source.get("group_id")) in changed_ids:
|
||||
result.append(binding)
|
||||
break
|
||||
return result
|
||||
|
||||
|
||||
def _client_for(website: Website) -> Sub2ApiWebsiteClient:
|
||||
return Sub2ApiWebsiteClient(
|
||||
base_url=website.base_url,
|
||||
api_prefix=website.api_prefix,
|
||||
auth_type=website.auth_type,
|
||||
auth_config=json.loads(website.auth_config_json or "{}"),
|
||||
timeout=float(website.timeout_seconds),
|
||||
)
|
||||
|
||||
|
||||
def _log(
|
||||
db: Session,
|
||||
binding: WebsiteGroupBinding,
|
||||
website: Website,
|
||||
source_rates: list[dict[str, Any]],
|
||||
status: str,
|
||||
message: str,
|
||||
old_rate: Any = None,
|
||||
new_rate: Any = None,
|
||||
) -> WebsiteSyncLog:
|
||||
row = WebsiteSyncLog(
|
||||
website_id=website.id,
|
||||
binding_id=binding.id,
|
||||
target_group_id=binding.target_group_id,
|
||||
target_group_name=binding.target_group_name,
|
||||
algorithm=binding.algorithm,
|
||||
percent=binding.percent,
|
||||
source_rates_json=json.dumps(source_rates, ensure_ascii=False),
|
||||
old_rate=decimal_string(old_rate) if old_rate not in (None, "") else None,
|
||||
new_rate=decimal_string(new_rate) if new_rate not in (None, "") else None,
|
||||
status=status,
|
||||
message=message,
|
||||
)
|
||||
db.add(row)
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
return row
|
||||
|
||||
|
||||
def sync_binding(db: Session, binding: WebsiteGroupBinding, write: bool = True) -> WebsiteSyncLog:
|
||||
website = db.query(Website).filter(Website.id == binding.website_id).first()
|
||||
if not website:
|
||||
raise WebsiteError("网站不存在")
|
||||
sources = binding_sources(binding)
|
||||
source_rates: list[dict[str, Any]] = []
|
||||
for source in sources:
|
||||
upstream_id = int(source.get("upstream_id") or 0)
|
||||
group_id = str(source.get("group_id") or "")
|
||||
groups = latest_rate_map(db, upstream_id)
|
||||
group = groups.get(group_id) if group_id else None
|
||||
upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
|
||||
source_rates.append({
|
||||
"upstream_id": upstream_id,
|
||||
"upstream_name": source.get("upstream_name") or (upstream.name if upstream else ""),
|
||||
"group_id": group_id,
|
||||
"group_name": source.get("group_name") or (group.get("group_name", "") if isinstance(group, dict) else ""),
|
||||
"rate": group.get("rate") if isinstance(group, dict) else None,
|
||||
})
|
||||
try:
|
||||
target_rate = calculate_target_rate([item.get("rate") for item in source_rates], binding.percent, binding.algorithm)
|
||||
except Exception as exc:
|
||||
return _log(db, binding, website, source_rates, "failed", str(exc))
|
||||
|
||||
old_rate = None
|
||||
if write and website.enabled and website.auto_sync_enabled and binding.enabled:
|
||||
try:
|
||||
client = _client_for(website)
|
||||
groups = client.get_groups(website.groups_endpoint)
|
||||
target = next((item for item in groups if item.get("id") == binding.target_group_id), None)
|
||||
old_rate = target.get("rate_multiplier") if target else None
|
||||
client.update_group_rate(website.group_update_endpoint, binding.target_group_id, target_rate)
|
||||
website.last_status = "healthy"
|
||||
website.last_error = None
|
||||
except Exception as exc:
|
||||
website.last_status = "unhealthy"
|
||||
website.last_error = str(exc)
|
||||
db.commit()
|
||||
return _log(db, binding, website, source_rates, "failed", f"写回失败:{exc}", old_rate, target_rate)
|
||||
db.commit()
|
||||
log = _log(db, binding, website, source_rates, "success", "同步成功", old_rate, target_rate)
|
||||
old_rate_str = decimal_string(old_rate) if old_rate not in (None, "") else None
|
||||
new_rate_str = decimal_string(target_rate)
|
||||
if old_rate_str != new_rate_str:
|
||||
webhook_service.send_website_rate_changed(
|
||||
db,
|
||||
website.id,
|
||||
website.name,
|
||||
website.base_url,
|
||||
binding.id,
|
||||
binding.target_group_id,
|
||||
binding.target_group_name,
|
||||
old_rate_str,
|
||||
new_rate_str,
|
||||
source_rates,
|
||||
)
|
||||
return log
|
||||
|
||||
message = "已计算建议倍率,未写回"
|
||||
if not website.enabled or not website.auto_sync_enabled:
|
||||
message = "网站未启用自动同步,未写回"
|
||||
elif not binding.enabled:
|
||||
message = "绑定未启用,未写回"
|
||||
return _log(db, binding, website, source_rates, "success", message, old_rate, target_rate)
|
||||
|
||||
|
||||
def sync_affected_bindings(db: Session, upstream_id: int, changes: list[dict[str, Any]]) -> None:
|
||||
for binding in get_affected_bindings(db, changes, upstream_id):
|
||||
try:
|
||||
sync_binding(db, binding, write=True)
|
||||
except Exception as exc:
|
||||
logger.exception("website sync failed for binding %s: %s", binding.id, exc)
|
||||
Reference in New Issue
Block a user