Add remote browser pages and website sync
Enable managed remote browser custom pages with login autofill and add website sync workflows so external admin surfaces can be handled inside SmartUp. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,252 @@
|
||||
"""Remote browser session API."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.custom_page import CustomPage
|
||||
from app.services.browser_session_service import (
|
||||
BrowserDependencyError,
|
||||
BrowserSessionError,
|
||||
browser_sessions,
|
||||
)
|
||||
from app.utils.auth import decode_token, get_current_user, get_user_from_token_param
|
||||
from app.database import SessionLocal
|
||||
from app.models.admin_user import AdminUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/browser-sessions", tags=["browser-sessions"])
|
||||
|
||||
|
||||
class BrowserSessionCreate(BaseModel):
|
||||
custom_page_id: int
|
||||
width: int = Field(default=1280)
|
||||
height: int = Field(default=720)
|
||||
|
||||
|
||||
class BrowserSessionResponse(BaseModel):
|
||||
id: str
|
||||
custom_page_id: int
|
||||
url: str
|
||||
title: str
|
||||
|
||||
|
||||
class BrowserEvent(BaseModel):
|
||||
type: Literal["click", "dblclick", "mousemove", "mousedown", "mouseup", "type", "key", "scroll", "reload", "back", "forward", "resize"]
|
||||
x: Optional[float] = None
|
||||
y: Optional[float] = None
|
||||
button: Optional[Literal["left", "right", "middle"]] = "left"
|
||||
text: Optional[str] = None
|
||||
key: Optional[str] = None
|
||||
delta_x: Optional[float] = 0
|
||||
delta_y: Optional[float] = 0
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
|
||||
|
||||
def _error_from_browser(exc: Exception) -> HTTPException:
|
||||
if isinstance(exc, BrowserDependencyError):
|
||||
return HTTPException(503, str(exc))
|
||||
if isinstance(exc, BrowserSessionError):
|
||||
return HTTPException(409, str(exc))
|
||||
if isinstance(exc, KeyError):
|
||||
return HTTPException(404, "browser session not found")
|
||||
if isinstance(exc, ValueError):
|
||||
return HTTPException(400, str(exc))
|
||||
return HTTPException(502, f"Browser error: {exc}")
|
||||
|
||||
|
||||
@router.post("", response_model=BrowserSessionResponse, status_code=201)
|
||||
async def create_session(
|
||||
body: BrowserSessionCreate,
|
||||
db: Session = Depends(get_db),
|
||||
_=Depends(get_current_user),
|
||||
):
|
||||
page = db.query(CustomPage).filter(CustomPage.id == body.custom_page_id).first()
|
||||
if not page or not page.enabled:
|
||||
raise HTTPException(404, "page not found")
|
||||
if page.access_mode != "remote_browser":
|
||||
raise HTTPException(400, "custom page is not configured for remote browser mode")
|
||||
login_config = {
|
||||
"enabled": page.login_autofill_enabled,
|
||||
"username": page.login_username,
|
||||
"password": page.login_password,
|
||||
"username_selector": page.login_username_selector,
|
||||
"password_selector": page.login_password_selector,
|
||||
"submit_selector": page.login_submit_selector,
|
||||
}
|
||||
try:
|
||||
session = await browser_sessions.create(page.id, page.url, body.width, body.height, login_config)
|
||||
return await browser_sessions.state(session.id)
|
||||
except Exception as exc:
|
||||
raise _error_from_browser(exc)
|
||||
|
||||
|
||||
@router.get("/{session_id}", response_model=BrowserSessionResponse)
|
||||
async def get_session(session_id: str, _=Depends(get_current_user)):
|
||||
try:
|
||||
return await browser_sessions.state(session_id)
|
||||
except Exception as exc:
|
||||
raise _error_from_browser(exc)
|
||||
|
||||
|
||||
@router.get("/{session_id}/screenshot")
|
||||
async def session_screenshot(session_id: str, _=Depends(get_user_from_token_param)):
|
||||
try:
|
||||
image = await browser_sessions.screenshot(session_id)
|
||||
except Exception as exc:
|
||||
raise _error_from_browser(exc)
|
||||
return Response(content=image, media_type="image/jpeg", headers={"Cache-Control": "no-store"})
|
||||
|
||||
|
||||
@router.post("/{session_id}/events", response_model=BrowserSessionResponse)
|
||||
async def send_event(session_id: str, body: BrowserEvent, _=Depends(get_current_user)):
|
||||
try:
|
||||
payload: dict[str, Any] = body.model_dump(exclude_none=True)
|
||||
event_type = payload.pop("type")
|
||||
return await browser_sessions.event(session_id, event_type, payload)
|
||||
except Exception as exc:
|
||||
raise _error_from_browser(exc)
|
||||
|
||||
|
||||
@router.delete("/{session_id}", status_code=204)
|
||||
async def close_session(session_id: str, _=Depends(get_current_user)):
|
||||
await browser_sessions.close(session_id)
|
||||
|
||||
|
||||
# ——— WebSocket stream ———
|
||||
# Frame interval & diff detection
|
||||
_WS_MIN_INTERVAL = 0.05 # 50 ms floor (≈20 fps max)
|
||||
_WS_IDLE_INTERVAL = 0.15 # 150 ms when nothing changed recently
|
||||
_WS_ACTIVE_INTERVAL = 0.08 # 80 ms right after a user event
|
||||
|
||||
|
||||
async def _ws_authenticate(token: Optional[str]) -> bool:
|
||||
"""Validate JWT token for WebSocket connections."""
|
||||
if not token:
|
||||
return False
|
||||
email = decode_token(token)
|
||||
if not email:
|
||||
return False
|
||||
db = SessionLocal()
|
||||
try:
|
||||
user = db.query(AdminUser).filter(AdminUser.email == email).first()
|
||||
return user is not None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.websocket("/{session_id}/ws")
|
||||
async def session_ws(
|
||||
websocket: WebSocket,
|
||||
session_id: str,
|
||||
token: Optional[str] = Query(default=None),
|
||||
):
|
||||
"""WebSocket endpoint: pushes JPEG frames as binary, receives JSON event messages."""
|
||||
# Authenticate before accepting
|
||||
if not await _ws_authenticate(token):
|
||||
await websocket.close(code=4401)
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
# Track when a user event arrived so we can temporarily speed up
|
||||
last_event_at: float = 0.0
|
||||
last_frame_hash: str = ""
|
||||
|
||||
# Task: receive events from client
|
||||
async def receive_loop():
|
||||
nonlocal last_event_at
|
||||
try:
|
||||
while True:
|
||||
raw = await websocket.receive_text()
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
msg_type = msg.get("type")
|
||||
if not msg_type:
|
||||
continue
|
||||
payload: dict[str, Any] = {k: v for k, v in msg.items() if k != "type"}
|
||||
try:
|
||||
await browser_sessions.event(session_id, msg_type, payload)
|
||||
last_event_at = asyncio.get_event_loop().time()
|
||||
except Exception as exc:
|
||||
logger.warning("ws event error: %s", exc)
|
||||
try:
|
||||
await websocket.send_json({"error": str(exc)})
|
||||
except Exception:
|
||||
pass
|
||||
except (WebSocketDisconnect, asyncio.CancelledError):
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.debug("ws receive_loop ended: %s", exc)
|
||||
|
||||
# Task: push screenshots
|
||||
async def push_loop():
|
||||
nonlocal last_frame_hash
|
||||
try:
|
||||
while True:
|
||||
now = asyncio.get_event_loop().time()
|
||||
# Faster cadence right after a user interaction
|
||||
interval = _WS_ACTIVE_INTERVAL if (now - last_event_at) < 1.0 else _WS_IDLE_INTERVAL
|
||||
|
||||
try:
|
||||
frame = await browser_sessions.screenshot(session_id)
|
||||
except KeyError:
|
||||
# Session gone
|
||||
await websocket.send_json({"error": "session_not_found"})
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.warning("ws screenshot error: %s", exc)
|
||||
await asyncio.sleep(interval)
|
||||
continue
|
||||
|
||||
# Only push if content changed
|
||||
frame_hash = hashlib.md5(frame).hexdigest()
|
||||
if frame_hash != last_frame_hash:
|
||||
last_frame_hash = frame_hash
|
||||
try:
|
||||
await websocket.send_bytes(frame)
|
||||
except Exception:
|
||||
break
|
||||
|
||||
await asyncio.sleep(max(_WS_MIN_INTERVAL, interval))
|
||||
except (WebSocketDisconnect, asyncio.CancelledError):
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.debug("ws push_loop ended: %s", exc)
|
||||
|
||||
# Send initial metadata so client knows session info
|
||||
try:
|
||||
state = await browser_sessions.state(session_id)
|
||||
await websocket.send_json({"type": "init", "session": state})
|
||||
except Exception as exc:
|
||||
await websocket.send_json({"error": f"session error: {exc}"})
|
||||
await websocket.close()
|
||||
return
|
||||
|
||||
recv_task = asyncio.create_task(receive_loop())
|
||||
push_task = asyncio.create_task(push_loop())
|
||||
|
||||
# Run until one side closes
|
||||
done, pending = await asyncio.wait(
|
||||
[recv_task, push_task],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
try:
|
||||
await t
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Custom pages CRUD router + transparent iframe proxy."""
|
||||
"""Custom pages CRUD router + authenticated iframe proxy."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urljoin, urlparse, urlencode, quote
|
||||
from typing import Any, List, Literal, Optional
|
||||
from urllib.parse import parse_qs, parse_qsl, urlencode, urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
@@ -13,8 +13,11 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.admin_user import AdminUser
|
||||
from app.models.custom_page import CustomPage
|
||||
from app.utils.auth import get_current_user, get_user_from_token_param
|
||||
from app.models.upstream import Upstream
|
||||
from app.services.upstream_client import _find_user_id
|
||||
from app.utils.auth import decode_token, get_current_user, get_user_from_token_param
|
||||
|
||||
router = APIRouter(prefix="/api/custom-pages", tags=["custom-pages"])
|
||||
|
||||
@@ -37,7 +40,14 @@ class CustomPageCreate(BaseModel):
|
||||
sort_order: int = 0
|
||||
enabled: bool = True
|
||||
use_proxy: bool = False
|
||||
access_mode: Literal["direct", "proxy", "remote_browser"] = "direct"
|
||||
description: Optional[str] = None
|
||||
login_username: Optional[str] = None
|
||||
login_password: Optional[str] = None
|
||||
login_username_selector: Optional[str] = None
|
||||
login_password_selector: Optional[str] = None
|
||||
login_submit_selector: Optional[str] = None
|
||||
login_autofill_enabled: bool = False
|
||||
|
||||
|
||||
class CustomPageUpdate(BaseModel):
|
||||
@@ -47,7 +57,15 @@ class CustomPageUpdate(BaseModel):
|
||||
sort_order: Optional[int] = None
|
||||
enabled: Optional[bool] = None
|
||||
use_proxy: Optional[bool] = None
|
||||
access_mode: Optional[Literal["direct", "proxy", "remote_browser"]] = None
|
||||
description: Optional[str] = None
|
||||
login_username: Optional[str] = None
|
||||
login_password: Optional[str] = None
|
||||
login_username_selector: Optional[str] = None
|
||||
login_password_selector: Optional[str] = None
|
||||
login_submit_selector: Optional[str] = None
|
||||
login_autofill_enabled: Optional[bool] = None
|
||||
login_password_clear: Optional[bool] = None
|
||||
|
||||
|
||||
class CustomPageResponse(BaseModel):
|
||||
@@ -58,27 +76,80 @@ class CustomPageResponse(BaseModel):
|
||||
sort_order: int
|
||||
enabled: bool
|
||||
use_proxy: bool
|
||||
access_mode: str
|
||||
description: Optional[str]
|
||||
login_username: Optional[str]
|
||||
login_username_selector: Optional[str]
|
||||
login_password_selector: Optional[str]
|
||||
login_submit_selector: Optional[str]
|
||||
login_autofill_enabled: bool
|
||||
login_password_configured: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
def _blank_to_none(value: Optional[str]) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
stripped = value.strip()
|
||||
return stripped or None
|
||||
|
||||
|
||||
def _has_login_credentials(username: Optional[str], password: Optional[str]) -> bool:
|
||||
return bool(_blank_to_none(username) and _blank_to_none(password))
|
||||
|
||||
|
||||
def _page_response(page: CustomPage) -> CustomPageResponse:
|
||||
return CustomPageResponse(
|
||||
id=page.id,
|
||||
name=page.name,
|
||||
url=page.url,
|
||||
icon=page.icon,
|
||||
sort_order=page.sort_order,
|
||||
enabled=page.enabled,
|
||||
use_proxy=page.use_proxy,
|
||||
access_mode=page.access_mode,
|
||||
description=page.description,
|
||||
login_username=page.login_username,
|
||||
login_username_selector=page.login_username_selector,
|
||||
login_password_selector=page.login_password_selector,
|
||||
login_submit_selector=page.login_submit_selector,
|
||||
login_autofill_enabled=page.login_autofill_enabled,
|
||||
login_password_configured=bool(page.login_password),
|
||||
created_at=page.created_at,
|
||||
updated_at=page.updated_at,
|
||||
)
|
||||
|
||||
|
||||
# ---- CRUD Endpoints ----
|
||||
|
||||
@router.get("", response_model=List[CustomPageResponse])
|
||||
def list_pages(db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
return db.query(CustomPage).order_by(CustomPage.sort_order, CustomPage.id).all()
|
||||
pages = db.query(CustomPage).order_by(CustomPage.sort_order, CustomPage.id).all()
|
||||
return [_page_response(page) for page in pages]
|
||||
|
||||
|
||||
@router.post("", response_model=CustomPageResponse, status_code=201)
|
||||
def create_page(body: CustomPageCreate, db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
page = CustomPage(**body.model_dump())
|
||||
data = body.model_dump()
|
||||
data["use_proxy"] = data["access_mode"] == "proxy"
|
||||
for key in (
|
||||
"login_username",
|
||||
"login_password",
|
||||
"login_username_selector",
|
||||
"login_password_selector",
|
||||
"login_submit_selector",
|
||||
):
|
||||
data[key] = _blank_to_none(data.get(key))
|
||||
if "login_autofill_enabled" not in body.model_fields_set and _has_login_credentials(data.get("login_username"), data.get("login_password")):
|
||||
data["login_autofill_enabled"] = True
|
||||
page = CustomPage(**data)
|
||||
db.add(page)
|
||||
db.commit()
|
||||
db.refresh(page)
|
||||
return page
|
||||
return _page_response(page)
|
||||
|
||||
|
||||
@router.put("/{pid}", response_model=CustomPageResponse)
|
||||
@@ -86,12 +157,39 @@ def update_page(pid: int, body: CustomPageUpdate, db: Session = Depends(get_db),
|
||||
page = db.query(CustomPage).filter(CustomPage.id == pid).first()
|
||||
if not page:
|
||||
raise HTTPException(404, "page not found")
|
||||
for k, v in body.model_dump(exclude_none=True).items():
|
||||
data = body.model_dump(exclude_none=True)
|
||||
fields_set = body.model_fields_set
|
||||
if "access_mode" in data:
|
||||
data["use_proxy"] = data["access_mode"] == "proxy"
|
||||
elif "use_proxy" in data:
|
||||
data["access_mode"] = "proxy" if data["use_proxy"] else "direct"
|
||||
for key in (
|
||||
"login_username",
|
||||
"login_username_selector",
|
||||
"login_password_selector",
|
||||
"login_submit_selector",
|
||||
):
|
||||
if key in data:
|
||||
data[key] = _blank_to_none(data[key])
|
||||
new_password_saved = False
|
||||
if "login_password" in data:
|
||||
# Empty password on update means "keep the existing secret"; the API never echoes it back.
|
||||
password = data.pop("login_password")
|
||||
if password and password.strip():
|
||||
data["login_password"] = password
|
||||
new_password_saved = True
|
||||
if data.pop("login_password_clear", False):
|
||||
data["login_password"] = None
|
||||
next_username = data.get("login_username", page.login_username)
|
||||
next_password = data.get("login_password", page.login_password)
|
||||
if "login_autofill_enabled" not in fields_set and new_password_saved and _has_login_credentials(next_username, next_password):
|
||||
data["login_autofill_enabled"] = True
|
||||
for k, v in data.items():
|
||||
setattr(page, k, v)
|
||||
page.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(page)
|
||||
return page
|
||||
return _page_response(page)
|
||||
|
||||
|
||||
@router.delete("/{pid}", status_code=204)
|
||||
@@ -114,6 +212,286 @@ _STRIP_REQ = {
|
||||
"host", "connection", "transfer-encoding", "te",
|
||||
"trailers", "upgrade", "proxy-authorization", "authorization",
|
||||
}
|
||||
_PROXY_STATE: dict[int, dict[str, Any]] = {}
|
||||
|
||||
|
||||
def _origin(url: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
return ""
|
||||
return f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
|
||||
def _same_origin(a: str, b: str) -> bool:
|
||||
return _origin(a).rstrip("/") == _origin(b).rstrip("/")
|
||||
|
||||
|
||||
def _find_matching_upstream(db: Session, page: CustomPage) -> Optional[Upstream]:
|
||||
page_origin = _origin(page.url)
|
||||
if not page_origin:
|
||||
return None
|
||||
for upstream in db.query(Upstream).order_by(Upstream.id).all():
|
||||
if _origin(upstream.base_url) == page_origin:
|
||||
return upstream
|
||||
return None
|
||||
|
||||
|
||||
def _headers_for_upstream(request: Request, state: Optional[dict[str, Any]] = None) -> dict[str, str]:
|
||||
fwd: dict[str, str] = {}
|
||||
for k, v in request.headers.items():
|
||||
lk = k.lower()
|
||||
if lk in _STRIP_REQ or lk.startswith("x-forwarded"):
|
||||
continue
|
||||
fwd[k] = v
|
||||
fwd["user-agent"] = (
|
||||
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 "
|
||||
"(KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
|
||||
)
|
||||
fwd.setdefault("accept", "text/html,application/xhtml+xml,*/*;q=0.8")
|
||||
if state and state.get("new_api_user"):
|
||||
fwd["New-Api-User"] = str(state["new_api_user"])
|
||||
return fwd
|
||||
|
||||
|
||||
async def _ensure_new_api_state(page_id: int, upstream: Optional[Upstream]) -> Optional[dict[str, Any]]:
|
||||
if not upstream or upstream.auth_type != "login_password":
|
||||
return None
|
||||
cached = _PROXY_STATE.get(page_id)
|
||||
if cached and cached.get("cookies"):
|
||||
return cached
|
||||
|
||||
import json
|
||||
|
||||
cfg = json.loads(upstream.auth_config_json or "{}")
|
||||
email = cfg.get("email", "")
|
||||
password = cfg.get("password", "")
|
||||
if not email or not password:
|
||||
return None
|
||||
login_path = cfg.get("login_path", "/api/user/login")
|
||||
username_field = cfg.get("username_field", "username")
|
||||
login_url = urljoin(upstream.base_url.rstrip("/") + "/", login_path.lstrip("/"))
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True, timeout=float(upstream.timeout_seconds)) as client:
|
||||
resp = await client.post(
|
||||
login_url,
|
||||
json={username_field: email, "password": password},
|
||||
headers={
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "SmartUp/1.0",
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
try:
|
||||
payload = resp.json()
|
||||
except ValueError:
|
||||
payload = {}
|
||||
|
||||
cookies = dict(resp.cookies)
|
||||
if not cookies:
|
||||
return None
|
||||
state = {
|
||||
"cookies": cookies,
|
||||
"new_api_user": cfg.get("new_api_user", "") or _find_user_id(payload),
|
||||
}
|
||||
_PROXY_STATE[page_id] = state
|
||||
return state
|
||||
|
||||
|
||||
def _with_token(url: str, token: Optional[str]) -> str:
|
||||
if not token:
|
||||
return url
|
||||
sep = "&" if "?" in url else "?"
|
||||
return f"{url}{sep}token={token}"
|
||||
|
||||
|
||||
def _token_from_request(request: Request, token: Optional[str]) -> Optional[str]:
|
||||
if token:
|
||||
return token
|
||||
ref = request.headers.get("referer", "")
|
||||
if not ref:
|
||||
return None
|
||||
parsed = urlparse(ref)
|
||||
values = parse_qs(parsed.query).get("token", [])
|
||||
return values[0] if values else None
|
||||
|
||||
|
||||
def _require_proxy_user(request: Request, token: Optional[str], db: Session) -> None:
|
||||
raw = _token_from_request(request, token)
|
||||
if not raw:
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
email = decode_token(raw)
|
||||
if not email:
|
||||
raise HTTPException(401, "Invalid token")
|
||||
user = db.query(AdminUser).filter(AdminUser.email == email).first()
|
||||
if not user:
|
||||
raise HTTPException(401, "User not found")
|
||||
|
||||
|
||||
def _rewrite_html(content: bytes, page_id: int, target_url: str, token: Optional[str]) -> bytes:
|
||||
try:
|
||||
html = content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return content
|
||||
|
||||
proxy_root = f"/api/custom-pages/{page_id}/proxy"
|
||||
target_origin = _origin(target_url)
|
||||
|
||||
def rewrite_url(value: str) -> str:
|
||||
if value.startswith(("data:", "blob:", "mailto:", "tel:", "#", "javascript:")):
|
||||
return value
|
||||
if value.startswith(proxy_root):
|
||||
return value
|
||||
if value.startswith("//"):
|
||||
absolute = f"{urlparse(target_url).scheme}:{value}"
|
||||
if _same_origin(absolute, target_url):
|
||||
return _with_token(f"{proxy_root}{urlparse(absolute).path or '/'}", token)
|
||||
return value
|
||||
if value.startswith(("http://", "https://")):
|
||||
if _same_origin(value, target_url):
|
||||
parsed = urlparse(value)
|
||||
proxied = f"{proxy_root}{parsed.path or '/'}" + (f"?{parsed.query}" if parsed.query else "")
|
||||
return _with_token(proxied, token)
|
||||
return value
|
||||
if value.startswith("/"):
|
||||
return _with_token(f"{proxy_root}{value}", token)
|
||||
absolute = urljoin(target_url, value)
|
||||
if _origin(absolute) == target_origin:
|
||||
parsed = urlparse(absolute)
|
||||
proxied = f"{proxy_root}{parsed.path or '/'}" + (f"?{parsed.query}" if parsed.query else "")
|
||||
return _with_token(proxied, token)
|
||||
return value
|
||||
|
||||
html = re.sub(
|
||||
r'(?P<attr>\b(?:src|href|action)=)(?P<quote>["\'])(?P<url>[^"\']+)(?P=quote)',
|
||||
lambda m: f"{m.group('attr')}{m.group('quote')}{rewrite_url(m.group('url'))}{m.group('quote')}",
|
||||
html,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
inject = f"""
|
||||
<script>
|
||||
(function() {{
|
||||
var root = {proxy_root!r};
|
||||
var token = {token or ''!r};
|
||||
function withToken(url) {{
|
||||
if (!token || url.indexOf('token=') !== -1) return url;
|
||||
return url + (url.indexOf('?') === -1 ? '?' : '&') + 'token=' + encodeURIComponent(token);
|
||||
}}
|
||||
function proxify(input) {{
|
||||
if (typeof input !== 'string') return input;
|
||||
if (input.indexOf(root) === 0) return withToken(input);
|
||||
if (input.indexOf('/') === 0) return withToken(root + input);
|
||||
try {{
|
||||
var url = new URL(input, window.location.href);
|
||||
if (url.origin === window.location.origin && url.pathname.indexOf(root) !== 0) {{
|
||||
return withToken(root + url.pathname + url.search + url.hash);
|
||||
}}
|
||||
}} catch (e) {{}}
|
||||
return input;
|
||||
}}
|
||||
var oldFetch = window.fetch;
|
||||
if (oldFetch) {{
|
||||
window.fetch = function(input, init) {{
|
||||
if (typeof input === 'string') input = proxify(input);
|
||||
else if (input && input.url) input = new Request(proxify(input.url), input);
|
||||
return oldFetch.call(this, input, init);
|
||||
}};
|
||||
}}
|
||||
var oldOpen = XMLHttpRequest.prototype.open;
|
||||
XMLHttpRequest.prototype.open = function(method, url) {{
|
||||
arguments[1] = proxify(url);
|
||||
return oldOpen.apply(this, arguments);
|
||||
}};
|
||||
}})();
|
||||
</script>
|
||||
"""
|
||||
if "</head>" in html:
|
||||
html = html.replace("</head>", inject + "</head>", 1)
|
||||
else:
|
||||
html = inject + html
|
||||
return html.encode("utf-8")
|
||||
|
||||
|
||||
async def _proxy_to_page(
|
||||
request: Request,
|
||||
page: CustomPage,
|
||||
target_url: str,
|
||||
state: Optional[dict[str, Any]],
|
||||
) -> httpx.Response:
|
||||
body = await request.body() if request.method not in ("GET", "HEAD") else None
|
||||
async with httpx.AsyncClient(follow_redirects=True, timeout=30) as client:
|
||||
return await client.request(
|
||||
method=request.method,
|
||||
url=target_url,
|
||||
headers=_headers_for_upstream(request, state),
|
||||
cookies=(state or {}).get("cookies", {}),
|
||||
content=body,
|
||||
)
|
||||
|
||||
|
||||
def _response_from_upstream(
|
||||
resp: httpx.Response,
|
||||
page_id: int,
|
||||
target_url: str,
|
||||
token: Optional[str],
|
||||
) -> Response:
|
||||
out: dict[str, str] = {}
|
||||
for k, v in resp.headers.items():
|
||||
kl = k.lower()
|
||||
if kl in _STRIP_RESP:
|
||||
continue
|
||||
if kl in ("content-encoding", "transfer-encoding", "content-length", "set-cookie"):
|
||||
continue
|
||||
out[k] = v
|
||||
|
||||
content = resp.content
|
||||
content_type = resp.headers.get("content-type", "")
|
||||
if "text/html" in content_type:
|
||||
content = _rewrite_html(content, page_id, target_url, token)
|
||||
return Response(
|
||||
content=content,
|
||||
status_code=resp.status_code,
|
||||
media_type=content_type,
|
||||
headers=out,
|
||||
)
|
||||
|
||||
|
||||
@router.api_route("/{pid}/proxy", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"])
|
||||
@router.api_route("/{pid}/proxy/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"])
|
||||
async def page_proxy(
|
||||
pid: int,
|
||||
request: Request,
|
||||
path: str = "",
|
||||
token: Optional[str] = Query(default=None),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
_require_proxy_user(request, token, db)
|
||||
page = db.query(CustomPage).filter(CustomPage.id == pid).first()
|
||||
if not page or not page.enabled:
|
||||
raise HTTPException(404, "page not found")
|
||||
if not page.url.startswith(("http://", "https://")):
|
||||
raise HTTPException(400, "Only http/https URLs are allowed")
|
||||
|
||||
base = page.url.rstrip("/") + "/"
|
||||
target_url = urljoin(base, path or "")
|
||||
query = urlencode([(k, v) for k, v in parse_qsl(request.url.query, keep_blank_values=True) if k != "token"])
|
||||
if query:
|
||||
target_url += f"?{query}"
|
||||
|
||||
upstream = _find_matching_upstream(db, page)
|
||||
state = await _ensure_new_api_state(pid, upstream)
|
||||
try:
|
||||
resp = await _proxy_to_page(request, page, target_url, state)
|
||||
if resp.status_code == 401 and upstream:
|
||||
_PROXY_STATE.pop(pid, None)
|
||||
state = await _ensure_new_api_state(pid, upstream)
|
||||
resp = await _proxy_to_page(request, page, target_url, state)
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(502, f"Proxy error: {exc}")
|
||||
except httpx.HTTPStatusError as exc:
|
||||
raise HTTPException(exc.response.status_code, exc.response.text)
|
||||
|
||||
return _response_from_upstream(resp, pid, target_url, _token_from_request(request, token))
|
||||
|
||||
|
||||
@router.api_route("/frame-proxy", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"])
|
||||
@@ -175,4 +553,3 @@ async def frame_proxy(
|
||||
media_type=resp.headers.get("content-type"),
|
||||
headers=out,
|
||||
)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from app.services.upstream_client import UpstreamClient, UpstreamError, build_sn
|
||||
from app.services.snapshot_service import diff_snapshots
|
||||
from app.services import scheduler as sched_svc
|
||||
from app.services import webhook_service
|
||||
from app.services import website_sync
|
||||
from app.utils.auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/upstreams", tags=["upstreams"])
|
||||
@@ -209,6 +210,7 @@ def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_use
|
||||
webhook_service.send_status_event(db, u.id, u.name, u.base_url, "upstream_recovered")
|
||||
if changes:
|
||||
webhook_service.send_rate_changed(db, u.id, u.name, u.base_url, changes)
|
||||
website_sync.sync_affected_bindings(db, u.id, changes)
|
||||
|
||||
msg = f"检测成功,{len(groups)} 个分组"
|
||||
if changes:
|
||||
|
||||
@@ -0,0 +1,308 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
|
||||
from app.schemas.website import (
|
||||
BindingCreate,
|
||||
BindingResponse,
|
||||
BindingUpdate,
|
||||
TestResult,
|
||||
WebsiteCreate,
|
||||
WebsiteGroupResponse,
|
||||
WebsiteResponse,
|
||||
WebsiteSyncLogResponse,
|
||||
WebsiteUpdate,
|
||||
)
|
||||
from app.services.website_client import Sub2ApiWebsiteClient
|
||||
from app.services.website_sync import binding_sources, sync_binding
|
||||
from app.utils.auth import get_current_user
|
||||
|
||||
router = APIRouter(tags=["websites"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MASK = "***"
|
||||
SECRET_KEYS = {"password", "token", "key", "secret", "api_key"}
|
||||
ALGORITHMS = {"max_plus_percent", "average_plus_percent", "min_plus_percent"}
|
||||
|
||||
|
||||
def _mask(cfg: dict) -> dict:
|
||||
masked = {}
|
||||
for key, value in cfg.items():
|
||||
masked[key] = MASK if key.lower() in SECRET_KEYS and value else value
|
||||
return masked
|
||||
|
||||
|
||||
def _website_response(row: Website) -> WebsiteResponse:
|
||||
return WebsiteResponse(
|
||||
id=row.id,
|
||||
name=row.name,
|
||||
site_type=row.site_type,
|
||||
base_url=row.base_url,
|
||||
api_prefix=row.api_prefix,
|
||||
auth_type=row.auth_type,
|
||||
auth_config_masked=_mask(json.loads(row.auth_config_json or "{}")),
|
||||
groups_endpoint=row.groups_endpoint,
|
||||
group_update_endpoint=row.group_update_endpoint,
|
||||
enabled=row.enabled,
|
||||
auto_sync_enabled=row.auto_sync_enabled,
|
||||
timeout_seconds=row.timeout_seconds,
|
||||
last_status=row.last_status,
|
||||
last_checked_at=row.last_checked_at,
|
||||
last_error=row.last_error,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _binding_response(db: Session, row: WebsiteGroupBinding) -> BindingResponse:
|
||||
website = db.query(Website).filter(Website.id == row.website_id).first()
|
||||
return BindingResponse(
|
||||
id=row.id,
|
||||
website_id=row.website_id,
|
||||
website_name=website.name if website else "",
|
||||
target_group_id=row.target_group_id,
|
||||
target_group_name=row.target_group_name,
|
||||
source_groups=binding_sources(row),
|
||||
percent=float(row.percent or 0),
|
||||
algorithm=row.algorithm,
|
||||
enabled=row.enabled,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _log_response(row: WebsiteSyncLog) -> WebsiteSyncLogResponse:
|
||||
return WebsiteSyncLogResponse(
|
||||
id=row.id,
|
||||
website_id=row.website_id,
|
||||
binding_id=row.binding_id,
|
||||
target_group_id=row.target_group_id,
|
||||
target_group_name=row.target_group_name,
|
||||
algorithm=row.algorithm,
|
||||
percent=float(row.percent or 0),
|
||||
source_rates=json.loads(row.source_rates_json or "[]"),
|
||||
old_rate=row.old_rate,
|
||||
new_rate=row.new_rate,
|
||||
status=row.status,
|
||||
message=row.message,
|
||||
created_at=row.created_at,
|
||||
)
|
||||
|
||||
|
||||
def _ensure_unique_target(db: Session, website_id: int, target_group_id: str, exclude_id: int | None = None) -> None:
|
||||
q = db.query(WebsiteGroupBinding).filter(
|
||||
WebsiteGroupBinding.website_id == website_id,
|
||||
WebsiteGroupBinding.target_group_id == target_group_id,
|
||||
)
|
||||
if exclude_id is not None:
|
||||
q = q.filter(WebsiteGroupBinding.id != exclude_id)
|
||||
if q.first():
|
||||
raise HTTPException(400, "同一目标网站分组只能维护一条绑定记录")
|
||||
|
||||
|
||||
def _client(row: Website) -> Sub2ApiWebsiteClient:
|
||||
return Sub2ApiWebsiteClient(
|
||||
base_url=row.base_url,
|
||||
api_prefix=row.api_prefix,
|
||||
auth_type=row.auth_type,
|
||||
auth_config=json.loads(row.auth_config_json or "{}"),
|
||||
timeout=float(row.timeout_seconds),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/websites", response_model=List[WebsiteResponse])
|
||||
def list_websites(db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
return [_website_response(row) for row in db.query(Website).order_by(Website.id).all()]
|
||||
|
||||
|
||||
@router.post("/api/websites", response_model=WebsiteResponse, status_code=201)
|
||||
def create_website(body: WebsiteCreate, db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
if body.site_type != "sub2api":
|
||||
raise HTTPException(400, "目前只支持 sub2api")
|
||||
row = Website(
|
||||
name=body.name,
|
||||
site_type=body.site_type,
|
||||
base_url=body.base_url.rstrip("/"),
|
||||
api_prefix=body.api_prefix,
|
||||
auth_type=body.auth_type,
|
||||
auth_config_json=json.dumps(body.auth_config, ensure_ascii=False),
|
||||
groups_endpoint=body.groups_endpoint,
|
||||
group_update_endpoint=body.group_update_endpoint,
|
||||
enabled=body.enabled,
|
||||
auto_sync_enabled=body.auto_sync_enabled,
|
||||
timeout_seconds=body.timeout_seconds,
|
||||
)
|
||||
db.add(row)
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
return _website_response(row)
|
||||
|
||||
|
||||
@router.put("/api/websites/{wid}", response_model=WebsiteResponse)
|
||||
def update_website(wid: int, body: WebsiteUpdate, db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
row = db.query(Website).filter(Website.id == wid).first()
|
||||
if not row:
|
||||
raise HTTPException(404, "website not found")
|
||||
data = body.model_dump(exclude_none=True)
|
||||
if "site_type" in data and data["site_type"] != "sub2api":
|
||||
raise HTTPException(400, "目前只支持 sub2api")
|
||||
if "auth_config" in data:
|
||||
existing = json.loads(row.auth_config_json or "{}")
|
||||
incoming = data.pop("auth_config")
|
||||
for key, value in incoming.items():
|
||||
if value != MASK:
|
||||
existing[key] = value
|
||||
row.auth_config_json = json.dumps(existing, ensure_ascii=False)
|
||||
if "base_url" in data:
|
||||
data["base_url"] = data["base_url"].rstrip("/")
|
||||
for key, value in data.items():
|
||||
setattr(row, key, value)
|
||||
row.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
return _website_response(row)
|
||||
|
||||
|
||||
@router.delete("/api/websites/{wid}", status_code=204)
|
||||
def delete_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
row = db.query(Website).filter(Website.id == wid).first()
|
||||
if not row:
|
||||
raise HTTPException(404, "website not found")
|
||||
db.delete(row)
|
||||
db.commit()
|
||||
|
||||
|
||||
@router.post("/api/websites/{wid}/test", response_model=TestResult)
|
||||
def test_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
row = db.query(Website).filter(Website.id == wid).first()
|
||||
if not row:
|
||||
raise HTTPException(404, "website not found")
|
||||
try:
|
||||
groups = _client(row).get_groups(row.groups_endpoint)
|
||||
row.last_status = "healthy"
|
||||
row.last_error = None
|
||||
row.last_checked_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组")
|
||||
except Exception as exc:
|
||||
row.last_status = "unhealthy"
|
||||
row.last_error = str(exc)
|
||||
row.last_checked_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
return TestResult(success=False, message="连接失败", detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/api/websites/{wid}/groups", response_model=List[WebsiteGroupResponse])
|
||||
def list_website_groups(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
row = db.query(Website).filter(Website.id == wid).first()
|
||||
if not row:
|
||||
raise HTTPException(404, "website not found")
|
||||
try:
|
||||
return _client(row).get_groups(row.groups_endpoint)
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, str(exc))
|
||||
|
||||
|
||||
@router.get("/api/group-bindings", response_model=List[BindingResponse])
|
||||
def list_bindings(db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
rows = db.query(WebsiteGroupBinding).order_by(WebsiteGroupBinding.id.desc()).all()
|
||||
return [_binding_response(db, row) for row in rows]
|
||||
|
||||
|
||||
@router.post("/api/group-bindings", response_model=BindingResponse, status_code=201)
|
||||
def create_binding(body: BindingCreate, db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
website = db.query(Website).filter(Website.id == body.website_id).first()
|
||||
if not website:
|
||||
raise HTTPException(404, "website not found")
|
||||
if body.algorithm not in ALGORITHMS:
|
||||
raise HTTPException(400, "不支持的算法")
|
||||
_ensure_unique_target(db, body.website_id, body.target_group_id)
|
||||
row = WebsiteGroupBinding(
|
||||
website_id=body.website_id,
|
||||
target_group_id=body.target_group_id,
|
||||
target_group_name=body.target_group_name,
|
||||
source_groups_json=json.dumps([item.model_dump() for item in body.source_groups], ensure_ascii=False),
|
||||
percent=str(body.percent),
|
||||
algorithm=body.algorithm,
|
||||
enabled=body.enabled,
|
||||
)
|
||||
db.add(row)
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
try:
|
||||
sync_binding(db, row, write=True)
|
||||
except Exception as exc:
|
||||
logger.exception("initial website sync failed for binding %s: %s", row.id, exc)
|
||||
return _binding_response(db, row)
|
||||
|
||||
|
||||
@router.put("/api/group-bindings/{bid}", response_model=BindingResponse)
|
||||
def update_binding(bid: int, body: BindingUpdate, db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
|
||||
if not row:
|
||||
raise HTTPException(404, "binding not found")
|
||||
data = body.model_dump(exclude_none=True)
|
||||
if "website_id" in data and not db.query(Website).filter(Website.id == data["website_id"]).first():
|
||||
raise HTTPException(404, "website not found")
|
||||
if "algorithm" in data and data["algorithm"] not in ALGORITHMS:
|
||||
raise HTTPException(400, "不支持的算法")
|
||||
next_website_id = int(data.get("website_id", row.website_id))
|
||||
next_target_group_id = str(data.get("target_group_id", row.target_group_id))
|
||||
_ensure_unique_target(db, next_website_id, next_target_group_id, exclude_id=row.id)
|
||||
if "source_groups" in data:
|
||||
row.source_groups_json = json.dumps(data.pop("source_groups"), ensure_ascii=False)
|
||||
if "percent" in data:
|
||||
row.percent = str(data.pop("percent"))
|
||||
for key, value in data.items():
|
||||
setattr(row, key, value)
|
||||
row.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
try:
|
||||
sync_binding(db, row, write=True)
|
||||
except Exception as exc:
|
||||
logger.exception("sync failed after updating binding %s: %s", row.id, exc)
|
||||
return _binding_response(db, row)
|
||||
|
||||
|
||||
@router.delete("/api/group-bindings/{bid}", status_code=204)
|
||||
def delete_binding(bid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
|
||||
if not row:
|
||||
raise HTTPException(404, "binding not found")
|
||||
db.delete(row)
|
||||
db.commit()
|
||||
|
||||
|
||||
@router.post("/api/group-bindings/{bid}/sync-now", response_model=WebsiteSyncLogResponse)
|
||||
def sync_now(bid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
|
||||
if not row:
|
||||
raise HTTPException(404, "binding not found")
|
||||
return _log_response(sync_binding(db, row, write=True))
|
||||
|
||||
|
||||
@router.get("/api/website-sync-logs", response_model=List[WebsiteSyncLogResponse])
|
||||
def list_sync_logs(
|
||||
website_id: int | None = Query(None),
|
||||
binding_id: int | None = Query(None),
|
||||
limit: int = Query(50, le=200),
|
||||
offset: int = Query(0),
|
||||
db: Session = Depends(get_db),
|
||||
_=Depends(get_current_user),
|
||||
):
|
||||
q = db.query(WebsiteSyncLog)
|
||||
if website_id:
|
||||
q = q.filter(WebsiteSyncLog.website_id == website_id)
|
||||
if binding_id:
|
||||
q = q.filter(WebsiteSyncLog.binding_id == binding_id)
|
||||
rows = q.order_by(WebsiteSyncLog.created_at.desc()).offset(offset).limit(limit).all()
|
||||
return [_log_response(row) for row in rows]
|
||||
Reference in New Issue
Block a user