Add remote browser pages and website sync

Enable managed remote browser custom pages with login autofill and add website sync workflows so external admin surfaces can be handled inside SmartUp.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
liumangmang
2026-05-15 15:43:58 +08:00
parent a13a0070a5
commit 7adc7c00ab
43 changed files with 6615 additions and 641 deletions
+3 -1
View File
@@ -11,8 +11,10 @@ class Settings(BaseSettings):
tz: str = "Asia/Shanghai"
# consecutive failures before upstream goes unhealthy
unhealthy_threshold: int = 3
browser_profiles_dir: str = "/app/data/browser-profiles"
browser_headless: bool = True
model_config = {"env_file": ".env", "case_sensitive": False}
model_config = {"env_file": ".env", "case_sensitive": False, "extra": "ignore"}
@lru_cache
+50 -2
View File
@@ -1,4 +1,4 @@
from sqlalchemy import create_engine
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.orm import sessionmaker, DeclarativeBase
from app.config import get_settings
@@ -26,5 +26,53 @@ def get_db():
def init_db():
"""Create all tables."""
# import models so SQLAlchemy registers them
from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page # noqa: F401
from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website # noqa: F401
Base.metadata.create_all(bind=engine)
_migrate_custom_pages()
def _migrate_custom_pages():
"""Apply small SQLite-safe migrations for deployments without Alembic."""
inspector = inspect(engine)
if "custom_pages" not in inspector.get_table_names():
return
columns = {col["name"] for col in inspector.get_columns("custom_pages")}
with engine.begin() as conn:
if "access_mode" not in columns:
conn.execute(text("ALTER TABLE custom_pages ADD COLUMN access_mode VARCHAR(32) NOT NULL DEFAULT 'direct'"))
conn.execute(text("UPDATE custom_pages SET access_mode = CASE WHEN use_proxy = 1 THEN 'proxy' ELSE 'direct' END"))
if "login_username" not in columns:
conn.execute(text("ALTER TABLE custom_pages ADD COLUMN login_username VARCHAR(255)"))
if "login_password" not in columns:
conn.execute(text("ALTER TABLE custom_pages ADD COLUMN login_password TEXT"))
if "login_username_selector" not in columns:
conn.execute(text("ALTER TABLE custom_pages ADD COLUMN login_username_selector VARCHAR(512)"))
if "login_password_selector" not in columns:
conn.execute(text("ALTER TABLE custom_pages ADD COLUMN login_password_selector VARCHAR(512)"))
if "login_submit_selector" not in columns:
conn.execute(text("ALTER TABLE custom_pages ADD COLUMN login_submit_selector VARCHAR(512)"))
if "login_autofill_enabled" not in columns:
conn.execute(text("ALTER TABLE custom_pages ADD COLUMN login_autofill_enabled BOOLEAN NOT NULL DEFAULT 0"))
if "login_autofill_backfilled_at" not in columns:
conn.execute(text("ALTER TABLE custom_pages ADD COLUMN login_autofill_backfilled_at DATETIME"))
conn.execute(
text(
"UPDATE custom_pages "
"SET login_autofill_enabled = 1, login_autofill_backfilled_at = CURRENT_TIMESTAMP "
"WHERE login_autofill_enabled = 0 "
"AND NULLIF(TRIM(login_username), '') IS NOT NULL "
"AND NULLIF(TRIM(login_password), '') IS NOT NULL"
)
)
conn.execute(
text(
"UPDATE custom_pages "
"SET access_mode = 'remote_browser', use_proxy = 0 "
"WHERE url LIKE :host OR url LIKE :host_slash OR url LIKE :host_port"
),
{
"host": "%://170.106.100.210",
"host_slash": "%://170.106.100.210/%",
"host_port": "%://170.106.100.210:%",
},
)
+13 -2
View File
@@ -3,7 +3,7 @@ import logging
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
@@ -14,7 +14,8 @@ from app.models.admin_user import AdminUser
from app.database import SessionLocal
from app.utils.auth import hash_password
from app.services.scheduler import start_scheduler, stop_scheduler
from app.routers import auth, upstreams, webhooks, logs, custom_pages
from app.routers import auth, upstreams, webhooks, logs, custom_pages, browser_sessions, websites
from app.services.browser_session_service import browser_sessions as browser_session_service
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
logger = logging.getLogger(__name__)
@@ -48,6 +49,7 @@ async def lifespan(app: FastAPI):
_init_admin()
start_scheduler()
yield
await browser_session_service.shutdown()
stop_scheduler()
@@ -75,6 +77,8 @@ app.include_router(upstreams.router)
app.include_router(webhooks.router)
app.include_router(logs.router)
app.include_router(custom_pages.router)
app.include_router(browser_sessions.router)
app.include_router(websites.router)
@app.get("/healthz")
@@ -87,6 +91,13 @@ STATIC_DIR = Path(__file__).parent.parent / "static"
if STATIC_DIR.exists():
app.mount("/assets", StaticFiles(directory=str(STATIC_DIR / "assets")), name="assets")
@app.api_route("/favicon.svg", methods=["GET", "HEAD"])
def serve_favicon():
favicon = STATIC_DIR / "favicon.svg"
if not favicon.exists():
raise HTTPException(status_code=404, detail="favicon not found")
return FileResponse(str(favicon), media_type="image/svg+xml")
@app.get("/{full_path:path}")
def serve_spa(full_path: str):
index = STATIC_DIR / "index.html"
+8
View File
@@ -16,7 +16,15 @@ class CustomPage(Base):
sort_order: Mapped[int] = mapped_column(Integer, default=0)
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
use_proxy: Mapped[bool] = mapped_column(Boolean, default=False)
access_mode: Mapped[str] = mapped_column(String(32), default="direct", nullable=False)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
login_username: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
login_password: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
login_username_selector: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
login_password_selector: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
login_submit_selector: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
login_autofill_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
login_autofill_backfilled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at: Mapped[datetime] = mapped_column(
DateTime,
+1 -1
View File
@@ -14,7 +14,7 @@ class WebhookConfig(Base):
url: Mapped[str] = mapped_column(String(1024), nullable=False)
secret: Mapped[str] = mapped_column(String(512), default="")
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
# JSON array: ["upstream_rate_changed","upstream_unhealthy","upstream_recovered"]
# JSON array: ["upstream_rate_changed","website_rate_changed","upstream_unhealthy","upstream_recovered"]
events_json: Mapped[str] = mapped_column(Text, default='["upstream_rate_changed"]')
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at: Mapped[datetime] = mapped_column(
+68
View File
@@ -0,0 +1,68 @@
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class Website(Base):
__tablename__ = "websites"
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
name: Mapped[str] = mapped_column(String(255), nullable=False)
site_type: Mapped[str] = mapped_column(String(32), default="sub2api")
base_url: Mapped[str] = mapped_column(String(512), nullable=False)
api_prefix: Mapped[str] = mapped_column(String(128), default="/api/v1/admin")
auth_type: Mapped[str] = mapped_column(String(32), default="api_key")
auth_config_json: Mapped[str] = mapped_column(Text, default="{}")
groups_endpoint: Mapped[str] = mapped_column(String(256), default="/groups")
group_update_endpoint: Mapped[str] = mapped_column(String(256), default="/groups/{id}")
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
auto_sync_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
timeout_seconds: Mapped[int] = mapped_column(Integer, default=30)
last_status: Mapped[str] = mapped_column(String(32), default="unknown")
last_checked_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
last_error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)
)
class WebsiteGroupBinding(Base):
__tablename__ = "website_group_bindings"
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
website_id: Mapped[int] = mapped_column(Integer, ForeignKey("websites.id", ondelete="CASCADE"), index=True)
target_group_id: Mapped[str] = mapped_column(String(255), nullable=False)
target_group_name: Mapped[str] = mapped_column(String(255), default="")
source_groups_json: Mapped[str] = mapped_column(Text, default="[]")
percent: Mapped[str] = mapped_column(String(32), default="0")
algorithm: Mapped[str] = mapped_column(String(64), default="max_plus_percent")
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)
)
class WebsiteSyncLog(Base):
__tablename__ = "website_sync_logs"
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
website_id: Mapped[int] = mapped_column(Integer, ForeignKey("websites.id", ondelete="CASCADE"), index=True)
binding_id: Mapped[Optional[int]] = mapped_column(
Integer, ForeignKey("website_group_bindings.id", ondelete="SET NULL"), nullable=True, index=True
)
target_group_id: Mapped[str] = mapped_column(String(255), default="")
target_group_name: Mapped[str] = mapped_column(String(255), default="")
algorithm: Mapped[str] = mapped_column(String(64), default="")
percent: Mapped[str] = mapped_column(String(32), default="0")
source_rates_json: Mapped[str] = mapped_column(Text, default="[]")
old_rate: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
new_rate: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
status: Mapped[str] = mapped_column(String(16), nullable=False)
message: Mapped[str] = mapped_column(Text, default="")
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc), index=True)
+252
View File
@@ -0,0 +1,252 @@
"""Remote browser session API."""
from __future__ import annotations
import asyncio
import hashlib
import json
import logging
from typing import Any, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.responses import Response
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.custom_page import CustomPage
from app.services.browser_session_service import (
BrowserDependencyError,
BrowserSessionError,
browser_sessions,
)
from app.utils.auth import decode_token, get_current_user, get_user_from_token_param
from app.database import SessionLocal
from app.models.admin_user import AdminUser
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/browser-sessions", tags=["browser-sessions"])
class BrowserSessionCreate(BaseModel):
custom_page_id: int
width: int = Field(default=1280)
height: int = Field(default=720)
class BrowserSessionResponse(BaseModel):
id: str
custom_page_id: int
url: str
title: str
class BrowserEvent(BaseModel):
type: Literal["click", "dblclick", "mousemove", "mousedown", "mouseup", "type", "key", "scroll", "reload", "back", "forward", "resize"]
x: Optional[float] = None
y: Optional[float] = None
button: Optional[Literal["left", "right", "middle"]] = "left"
text: Optional[str] = None
key: Optional[str] = None
delta_x: Optional[float] = 0
delta_y: Optional[float] = 0
width: Optional[int] = None
height: Optional[int] = None
def _error_from_browser(exc: Exception) -> HTTPException:
if isinstance(exc, BrowserDependencyError):
return HTTPException(503, str(exc))
if isinstance(exc, BrowserSessionError):
return HTTPException(409, str(exc))
if isinstance(exc, KeyError):
return HTTPException(404, "browser session not found")
if isinstance(exc, ValueError):
return HTTPException(400, str(exc))
return HTTPException(502, f"Browser error: {exc}")
@router.post("", response_model=BrowserSessionResponse, status_code=201)
async def create_session(
body: BrowserSessionCreate,
db: Session = Depends(get_db),
_=Depends(get_current_user),
):
page = db.query(CustomPage).filter(CustomPage.id == body.custom_page_id).first()
if not page or not page.enabled:
raise HTTPException(404, "page not found")
if page.access_mode != "remote_browser":
raise HTTPException(400, "custom page is not configured for remote browser mode")
login_config = {
"enabled": page.login_autofill_enabled,
"username": page.login_username,
"password": page.login_password,
"username_selector": page.login_username_selector,
"password_selector": page.login_password_selector,
"submit_selector": page.login_submit_selector,
}
try:
session = await browser_sessions.create(page.id, page.url, body.width, body.height, login_config)
return await browser_sessions.state(session.id)
except Exception as exc:
raise _error_from_browser(exc)
@router.get("/{session_id}", response_model=BrowserSessionResponse)
async def get_session(session_id: str, _=Depends(get_current_user)):
try:
return await browser_sessions.state(session_id)
except Exception as exc:
raise _error_from_browser(exc)
@router.get("/{session_id}/screenshot")
async def session_screenshot(session_id: str, _=Depends(get_user_from_token_param)):
try:
image = await browser_sessions.screenshot(session_id)
except Exception as exc:
raise _error_from_browser(exc)
return Response(content=image, media_type="image/jpeg", headers={"Cache-Control": "no-store"})
@router.post("/{session_id}/events", response_model=BrowserSessionResponse)
async def send_event(session_id: str, body: BrowserEvent, _=Depends(get_current_user)):
try:
payload: dict[str, Any] = body.model_dump(exclude_none=True)
event_type = payload.pop("type")
return await browser_sessions.event(session_id, event_type, payload)
except Exception as exc:
raise _error_from_browser(exc)
@router.delete("/{session_id}", status_code=204)
async def close_session(session_id: str, _=Depends(get_current_user)):
await browser_sessions.close(session_id)
# ——— WebSocket stream ———
# Frame interval & diff detection
_WS_MIN_INTERVAL = 0.05 # 50 ms floor (≈20 fps max)
_WS_IDLE_INTERVAL = 0.15 # 150 ms when nothing changed recently
_WS_ACTIVE_INTERVAL = 0.08 # 80 ms right after a user event
async def _ws_authenticate(token: Optional[str]) -> bool:
"""Validate JWT token for WebSocket connections."""
if not token:
return False
email = decode_token(token)
if not email:
return False
db = SessionLocal()
try:
user = db.query(AdminUser).filter(AdminUser.email == email).first()
return user is not None
finally:
db.close()
@router.websocket("/{session_id}/ws")
async def session_ws(
websocket: WebSocket,
session_id: str,
token: Optional[str] = Query(default=None),
):
"""WebSocket endpoint: pushes JPEG frames as binary, receives JSON event messages."""
# Authenticate before accepting
if not await _ws_authenticate(token):
await websocket.close(code=4401)
return
await websocket.accept()
# Track when a user event arrived so we can temporarily speed up
last_event_at: float = 0.0
last_frame_hash: str = ""
# Task: receive events from client
async def receive_loop():
nonlocal last_event_at
try:
while True:
raw = await websocket.receive_text()
try:
msg = json.loads(raw)
except json.JSONDecodeError:
continue
msg_type = msg.get("type")
if not msg_type:
continue
payload: dict[str, Any] = {k: v for k, v in msg.items() if k != "type"}
try:
await browser_sessions.event(session_id, msg_type, payload)
last_event_at = asyncio.get_event_loop().time()
except Exception as exc:
logger.warning("ws event error: %s", exc)
try:
await websocket.send_json({"error": str(exc)})
except Exception:
pass
except (WebSocketDisconnect, asyncio.CancelledError):
pass
except Exception as exc:
logger.debug("ws receive_loop ended: %s", exc)
# Task: push screenshots
async def push_loop():
nonlocal last_frame_hash
try:
while True:
now = asyncio.get_event_loop().time()
# Faster cadence right after a user interaction
interval = _WS_ACTIVE_INTERVAL if (now - last_event_at) < 1.0 else _WS_IDLE_INTERVAL
try:
frame = await browser_sessions.screenshot(session_id)
except KeyError:
# Session gone
await websocket.send_json({"error": "session_not_found"})
break
except Exception as exc:
logger.warning("ws screenshot error: %s", exc)
await asyncio.sleep(interval)
continue
# Only push if content changed
frame_hash = hashlib.md5(frame).hexdigest()
if frame_hash != last_frame_hash:
last_frame_hash = frame_hash
try:
await websocket.send_bytes(frame)
except Exception:
break
await asyncio.sleep(max(_WS_MIN_INTERVAL, interval))
except (WebSocketDisconnect, asyncio.CancelledError):
pass
except Exception as exc:
logger.debug("ws push_loop ended: %s", exc)
# Send initial metadata so client knows session info
try:
state = await browser_sessions.state(session_id)
await websocket.send_json({"type": "init", "session": state})
except Exception as exc:
await websocket.send_json({"error": f"session error: {exc}"})
await websocket.close()
return
recv_task = asyncio.create_task(receive_loop())
push_task = asyncio.create_task(push_loop())
# Run until one side closes
done, pending = await asyncio.wait(
[recv_task, push_task],
return_when=asyncio.FIRST_COMPLETED,
)
for t in pending:
t.cancel()
try:
await t
except asyncio.CancelledError:
pass
+387 -10
View File
@@ -1,10 +1,10 @@
"""Custom pages CRUD router + transparent iframe proxy."""
"""Custom pages CRUD router + authenticated iframe proxy."""
from __future__ import annotations
import re
from datetime import datetime, timezone
from typing import List, Optional
from urllib.parse import urljoin, urlparse, urlencode, quote
from typing import Any, List, Literal, Optional
from urllib.parse import parse_qs, parse_qsl, urlencode, urljoin, urlparse
import httpx
from fastapi import APIRouter, Depends, HTTPException, Query, Request
@@ -13,8 +13,11 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.admin_user import AdminUser
from app.models.custom_page import CustomPage
from app.utils.auth import get_current_user, get_user_from_token_param
from app.models.upstream import Upstream
from app.services.upstream_client import _find_user_id
from app.utils.auth import decode_token, get_current_user, get_user_from_token_param
router = APIRouter(prefix="/api/custom-pages", tags=["custom-pages"])
@@ -37,7 +40,14 @@ class CustomPageCreate(BaseModel):
sort_order: int = 0
enabled: bool = True
use_proxy: bool = False
access_mode: Literal["direct", "proxy", "remote_browser"] = "direct"
description: Optional[str] = None
login_username: Optional[str] = None
login_password: Optional[str] = None
login_username_selector: Optional[str] = None
login_password_selector: Optional[str] = None
login_submit_selector: Optional[str] = None
login_autofill_enabled: bool = False
class CustomPageUpdate(BaseModel):
@@ -47,7 +57,15 @@ class CustomPageUpdate(BaseModel):
sort_order: Optional[int] = None
enabled: Optional[bool] = None
use_proxy: Optional[bool] = None
access_mode: Optional[Literal["direct", "proxy", "remote_browser"]] = None
description: Optional[str] = None
login_username: Optional[str] = None
login_password: Optional[str] = None
login_username_selector: Optional[str] = None
login_password_selector: Optional[str] = None
login_submit_selector: Optional[str] = None
login_autofill_enabled: Optional[bool] = None
login_password_clear: Optional[bool] = None
class CustomPageResponse(BaseModel):
@@ -58,27 +76,80 @@ class CustomPageResponse(BaseModel):
sort_order: int
enabled: bool
use_proxy: bool
access_mode: str
description: Optional[str]
login_username: Optional[str]
login_username_selector: Optional[str]
login_password_selector: Optional[str]
login_submit_selector: Optional[str]
login_autofill_enabled: bool
login_password_configured: bool
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
def _blank_to_none(value: Optional[str]) -> Optional[str]:
if value is None:
return None
stripped = value.strip()
return stripped or None
def _has_login_credentials(username: Optional[str], password: Optional[str]) -> bool:
return bool(_blank_to_none(username) and _blank_to_none(password))
def _page_response(page: CustomPage) -> CustomPageResponse:
return CustomPageResponse(
id=page.id,
name=page.name,
url=page.url,
icon=page.icon,
sort_order=page.sort_order,
enabled=page.enabled,
use_proxy=page.use_proxy,
access_mode=page.access_mode,
description=page.description,
login_username=page.login_username,
login_username_selector=page.login_username_selector,
login_password_selector=page.login_password_selector,
login_submit_selector=page.login_submit_selector,
login_autofill_enabled=page.login_autofill_enabled,
login_password_configured=bool(page.login_password),
created_at=page.created_at,
updated_at=page.updated_at,
)
# ---- CRUD Endpoints ----
@router.get("", response_model=List[CustomPageResponse])
def list_pages(db: Session = Depends(get_db), _=Depends(get_current_user)):
return db.query(CustomPage).order_by(CustomPage.sort_order, CustomPage.id).all()
pages = db.query(CustomPage).order_by(CustomPage.sort_order, CustomPage.id).all()
return [_page_response(page) for page in pages]
@router.post("", response_model=CustomPageResponse, status_code=201)
def create_page(body: CustomPageCreate, db: Session = Depends(get_db), _=Depends(get_current_user)):
page = CustomPage(**body.model_dump())
data = body.model_dump()
data["use_proxy"] = data["access_mode"] == "proxy"
for key in (
"login_username",
"login_password",
"login_username_selector",
"login_password_selector",
"login_submit_selector",
):
data[key] = _blank_to_none(data.get(key))
if "login_autofill_enabled" not in body.model_fields_set and _has_login_credentials(data.get("login_username"), data.get("login_password")):
data["login_autofill_enabled"] = True
page = CustomPage(**data)
db.add(page)
db.commit()
db.refresh(page)
return page
return _page_response(page)
@router.put("/{pid}", response_model=CustomPageResponse)
@@ -86,12 +157,39 @@ def update_page(pid: int, body: CustomPageUpdate, db: Session = Depends(get_db),
page = db.query(CustomPage).filter(CustomPage.id == pid).first()
if not page:
raise HTTPException(404, "page not found")
for k, v in body.model_dump(exclude_none=True).items():
data = body.model_dump(exclude_none=True)
fields_set = body.model_fields_set
if "access_mode" in data:
data["use_proxy"] = data["access_mode"] == "proxy"
elif "use_proxy" in data:
data["access_mode"] = "proxy" if data["use_proxy"] else "direct"
for key in (
"login_username",
"login_username_selector",
"login_password_selector",
"login_submit_selector",
):
if key in data:
data[key] = _blank_to_none(data[key])
new_password_saved = False
if "login_password" in data:
# Empty password on update means "keep the existing secret"; the API never echoes it back.
password = data.pop("login_password")
if password and password.strip():
data["login_password"] = password
new_password_saved = True
if data.pop("login_password_clear", False):
data["login_password"] = None
next_username = data.get("login_username", page.login_username)
next_password = data.get("login_password", page.login_password)
if "login_autofill_enabled" not in fields_set and new_password_saved and _has_login_credentials(next_username, next_password):
data["login_autofill_enabled"] = True
for k, v in data.items():
setattr(page, k, v)
page.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(page)
return page
return _page_response(page)
@router.delete("/{pid}", status_code=204)
@@ -114,6 +212,286 @@ _STRIP_REQ = {
"host", "connection", "transfer-encoding", "te",
"trailers", "upgrade", "proxy-authorization", "authorization",
}
_PROXY_STATE: dict[int, dict[str, Any]] = {}
def _origin(url: str) -> str:
parsed = urlparse(url)
if not parsed.scheme or not parsed.netloc:
return ""
return f"{parsed.scheme}://{parsed.netloc}"
def _same_origin(a: str, b: str) -> bool:
return _origin(a).rstrip("/") == _origin(b).rstrip("/")
def _find_matching_upstream(db: Session, page: CustomPage) -> Optional[Upstream]:
page_origin = _origin(page.url)
if not page_origin:
return None
for upstream in db.query(Upstream).order_by(Upstream.id).all():
if _origin(upstream.base_url) == page_origin:
return upstream
return None
def _headers_for_upstream(request: Request, state: Optional[dict[str, Any]] = None) -> dict[str, str]:
fwd: dict[str, str] = {}
for k, v in request.headers.items():
lk = k.lower()
if lk in _STRIP_REQ or lk.startswith("x-forwarded"):
continue
fwd[k] = v
fwd["user-agent"] = (
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 "
"(KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
)
fwd.setdefault("accept", "text/html,application/xhtml+xml,*/*;q=0.8")
if state and state.get("new_api_user"):
fwd["New-Api-User"] = str(state["new_api_user"])
return fwd
async def _ensure_new_api_state(page_id: int, upstream: Optional[Upstream]) -> Optional[dict[str, Any]]:
if not upstream or upstream.auth_type != "login_password":
return None
cached = _PROXY_STATE.get(page_id)
if cached and cached.get("cookies"):
return cached
import json
cfg = json.loads(upstream.auth_config_json or "{}")
email = cfg.get("email", "")
password = cfg.get("password", "")
if not email or not password:
return None
login_path = cfg.get("login_path", "/api/user/login")
username_field = cfg.get("username_field", "username")
login_url = urljoin(upstream.base_url.rstrip("/") + "/", login_path.lstrip("/"))
async with httpx.AsyncClient(follow_redirects=True, timeout=float(upstream.timeout_seconds)) as client:
resp = await client.post(
login_url,
json={username_field: email, "password": password},
headers={
"Accept": "application/json",
"Content-Type": "application/json",
"User-Agent": "SmartUp/1.0",
},
)
resp.raise_for_status()
try:
payload = resp.json()
except ValueError:
payload = {}
cookies = dict(resp.cookies)
if not cookies:
return None
state = {
"cookies": cookies,
"new_api_user": cfg.get("new_api_user", "") or _find_user_id(payload),
}
_PROXY_STATE[page_id] = state
return state
def _with_token(url: str, token: Optional[str]) -> str:
if not token:
return url
sep = "&" if "?" in url else "?"
return f"{url}{sep}token={token}"
def _token_from_request(request: Request, token: Optional[str]) -> Optional[str]:
if token:
return token
ref = request.headers.get("referer", "")
if not ref:
return None
parsed = urlparse(ref)
values = parse_qs(parsed.query).get("token", [])
return values[0] if values else None
def _require_proxy_user(request: Request, token: Optional[str], db: Session) -> None:
raw = _token_from_request(request, token)
if not raw:
raise HTTPException(401, "Not authenticated")
email = decode_token(raw)
if not email:
raise HTTPException(401, "Invalid token")
user = db.query(AdminUser).filter(AdminUser.email == email).first()
if not user:
raise HTTPException(401, "User not found")
def _rewrite_html(content: bytes, page_id: int, target_url: str, token: Optional[str]) -> bytes:
try:
html = content.decode("utf-8")
except UnicodeDecodeError:
return content
proxy_root = f"/api/custom-pages/{page_id}/proxy"
target_origin = _origin(target_url)
def rewrite_url(value: str) -> str:
if value.startswith(("data:", "blob:", "mailto:", "tel:", "#", "javascript:")):
return value
if value.startswith(proxy_root):
return value
if value.startswith("//"):
absolute = f"{urlparse(target_url).scheme}:{value}"
if _same_origin(absolute, target_url):
return _with_token(f"{proxy_root}{urlparse(absolute).path or '/'}", token)
return value
if value.startswith(("http://", "https://")):
if _same_origin(value, target_url):
parsed = urlparse(value)
proxied = f"{proxy_root}{parsed.path or '/'}" + (f"?{parsed.query}" if parsed.query else "")
return _with_token(proxied, token)
return value
if value.startswith("/"):
return _with_token(f"{proxy_root}{value}", token)
absolute = urljoin(target_url, value)
if _origin(absolute) == target_origin:
parsed = urlparse(absolute)
proxied = f"{proxy_root}{parsed.path or '/'}" + (f"?{parsed.query}" if parsed.query else "")
return _with_token(proxied, token)
return value
html = re.sub(
r'(?P<attr>\b(?:src|href|action)=)(?P<quote>["\'])(?P<url>[^"\']+)(?P=quote)',
lambda m: f"{m.group('attr')}{m.group('quote')}{rewrite_url(m.group('url'))}{m.group('quote')}",
html,
flags=re.IGNORECASE,
)
inject = f"""
<script>
(function() {{
var root = {proxy_root!r};
var token = {token or ''!r};
function withToken(url) {{
if (!token || url.indexOf('token=') !== -1) return url;
return url + (url.indexOf('?') === -1 ? '?' : '&') + 'token=' + encodeURIComponent(token);
}}
function proxify(input) {{
if (typeof input !== 'string') return input;
if (input.indexOf(root) === 0) return withToken(input);
if (input.indexOf('/') === 0) return withToken(root + input);
try {{
var url = new URL(input, window.location.href);
if (url.origin === window.location.origin && url.pathname.indexOf(root) !== 0) {{
return withToken(root + url.pathname + url.search + url.hash);
}}
}} catch (e) {{}}
return input;
}}
var oldFetch = window.fetch;
if (oldFetch) {{
window.fetch = function(input, init) {{
if (typeof input === 'string') input = proxify(input);
else if (input && input.url) input = new Request(proxify(input.url), input);
return oldFetch.call(this, input, init);
}};
}}
var oldOpen = XMLHttpRequest.prototype.open;
XMLHttpRequest.prototype.open = function(method, url) {{
arguments[1] = proxify(url);
return oldOpen.apply(this, arguments);
}};
}})();
</script>
"""
if "</head>" in html:
html = html.replace("</head>", inject + "</head>", 1)
else:
html = inject + html
return html.encode("utf-8")
async def _proxy_to_page(
request: Request,
page: CustomPage,
target_url: str,
state: Optional[dict[str, Any]],
) -> httpx.Response:
body = await request.body() if request.method not in ("GET", "HEAD") else None
async with httpx.AsyncClient(follow_redirects=True, timeout=30) as client:
return await client.request(
method=request.method,
url=target_url,
headers=_headers_for_upstream(request, state),
cookies=(state or {}).get("cookies", {}),
content=body,
)
def _response_from_upstream(
resp: httpx.Response,
page_id: int,
target_url: str,
token: Optional[str],
) -> Response:
out: dict[str, str] = {}
for k, v in resp.headers.items():
kl = k.lower()
if kl in _STRIP_RESP:
continue
if kl in ("content-encoding", "transfer-encoding", "content-length", "set-cookie"):
continue
out[k] = v
content = resp.content
content_type = resp.headers.get("content-type", "")
if "text/html" in content_type:
content = _rewrite_html(content, page_id, target_url, token)
return Response(
content=content,
status_code=resp.status_code,
media_type=content_type,
headers=out,
)
@router.api_route("/{pid}/proxy", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"])
@router.api_route("/{pid}/proxy/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"])
async def page_proxy(
pid: int,
request: Request,
path: str = "",
token: Optional[str] = Query(default=None),
db: Session = Depends(get_db),
):
_require_proxy_user(request, token, db)
page = db.query(CustomPage).filter(CustomPage.id == pid).first()
if not page or not page.enabled:
raise HTTPException(404, "page not found")
if not page.url.startswith(("http://", "https://")):
raise HTTPException(400, "Only http/https URLs are allowed")
base = page.url.rstrip("/") + "/"
target_url = urljoin(base, path or "")
query = urlencode([(k, v) for k, v in parse_qsl(request.url.query, keep_blank_values=True) if k != "token"])
if query:
target_url += f"?{query}"
upstream = _find_matching_upstream(db, page)
state = await _ensure_new_api_state(pid, upstream)
try:
resp = await _proxy_to_page(request, page, target_url, state)
if resp.status_code == 401 and upstream:
_PROXY_STATE.pop(pid, None)
state = await _ensure_new_api_state(pid, upstream)
resp = await _proxy_to_page(request, page, target_url, state)
except httpx.RequestError as exc:
raise HTTPException(502, f"Proxy error: {exc}")
except httpx.HTTPStatusError as exc:
raise HTTPException(exc.response.status_code, exc.response.text)
return _response_from_upstream(resp, pid, target_url, _token_from_request(request, token))
@router.api_route("/frame-proxy", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"])
@@ -175,4 +553,3 @@ async def frame_proxy(
media_type=resp.headers.get("content-type"),
headers=out,
)
+2
View File
@@ -19,6 +19,7 @@ from app.services.upstream_client import UpstreamClient, UpstreamError, build_sn
from app.services.snapshot_service import diff_snapshots
from app.services import scheduler as sched_svc
from app.services import webhook_service
from app.services import website_sync
from app.utils.auth import get_current_user
router = APIRouter(prefix="/api/upstreams", tags=["upstreams"])
@@ -209,6 +210,7 @@ def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_use
webhook_service.send_status_event(db, u.id, u.name, u.base_url, "upstream_recovered")
if changes:
webhook_service.send_rate_changed(db, u.id, u.name, u.base_url, changes)
website_sync.sync_affected_bindings(db, u.id, changes)
msg = f"检测成功,{len(groups)} 个分组"
if changes:
+308
View File
@@ -0,0 +1,308 @@
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.schemas.website import (
BindingCreate,
BindingResponse,
BindingUpdate,
TestResult,
WebsiteCreate,
WebsiteGroupResponse,
WebsiteResponse,
WebsiteSyncLogResponse,
WebsiteUpdate,
)
from app.services.website_client import Sub2ApiWebsiteClient
from app.services.website_sync import binding_sources, sync_binding
from app.utils.auth import get_current_user
router = APIRouter(tags=["websites"])
logger = logging.getLogger(__name__)
MASK = "***"
SECRET_KEYS = {"password", "token", "key", "secret", "api_key"}
ALGORITHMS = {"max_plus_percent", "average_plus_percent", "min_plus_percent"}
def _mask(cfg: dict) -> dict:
masked = {}
for key, value in cfg.items():
masked[key] = MASK if key.lower() in SECRET_KEYS and value else value
return masked
def _website_response(row: Website) -> WebsiteResponse:
return WebsiteResponse(
id=row.id,
name=row.name,
site_type=row.site_type,
base_url=row.base_url,
api_prefix=row.api_prefix,
auth_type=row.auth_type,
auth_config_masked=_mask(json.loads(row.auth_config_json or "{}")),
groups_endpoint=row.groups_endpoint,
group_update_endpoint=row.group_update_endpoint,
enabled=row.enabled,
auto_sync_enabled=row.auto_sync_enabled,
timeout_seconds=row.timeout_seconds,
last_status=row.last_status,
last_checked_at=row.last_checked_at,
last_error=row.last_error,
created_at=row.created_at,
updated_at=row.updated_at,
)
def _binding_response(db: Session, row: WebsiteGroupBinding) -> BindingResponse:
website = db.query(Website).filter(Website.id == row.website_id).first()
return BindingResponse(
id=row.id,
website_id=row.website_id,
website_name=website.name if website else "",
target_group_id=row.target_group_id,
target_group_name=row.target_group_name,
source_groups=binding_sources(row),
percent=float(row.percent or 0),
algorithm=row.algorithm,
enabled=row.enabled,
created_at=row.created_at,
updated_at=row.updated_at,
)
def _log_response(row: WebsiteSyncLog) -> WebsiteSyncLogResponse:
return WebsiteSyncLogResponse(
id=row.id,
website_id=row.website_id,
binding_id=row.binding_id,
target_group_id=row.target_group_id,
target_group_name=row.target_group_name,
algorithm=row.algorithm,
percent=float(row.percent or 0),
source_rates=json.loads(row.source_rates_json or "[]"),
old_rate=row.old_rate,
new_rate=row.new_rate,
status=row.status,
message=row.message,
created_at=row.created_at,
)
def _ensure_unique_target(db: Session, website_id: int, target_group_id: str, exclude_id: int | None = None) -> None:
q = db.query(WebsiteGroupBinding).filter(
WebsiteGroupBinding.website_id == website_id,
WebsiteGroupBinding.target_group_id == target_group_id,
)
if exclude_id is not None:
q = q.filter(WebsiteGroupBinding.id != exclude_id)
if q.first():
raise HTTPException(400, "同一目标网站分组只能维护一条绑定记录")
def _client(row: Website) -> Sub2ApiWebsiteClient:
return Sub2ApiWebsiteClient(
base_url=row.base_url,
api_prefix=row.api_prefix,
auth_type=row.auth_type,
auth_config=json.loads(row.auth_config_json or "{}"),
timeout=float(row.timeout_seconds),
)
@router.get("/api/websites", response_model=List[WebsiteResponse])
def list_websites(db: Session = Depends(get_db), _=Depends(get_current_user)):
return [_website_response(row) for row in db.query(Website).order_by(Website.id).all()]
@router.post("/api/websites", response_model=WebsiteResponse, status_code=201)
def create_website(body: WebsiteCreate, db: Session = Depends(get_db), _=Depends(get_current_user)):
if body.site_type != "sub2api":
raise HTTPException(400, "目前只支持 sub2api")
row = Website(
name=body.name,
site_type=body.site_type,
base_url=body.base_url.rstrip("/"),
api_prefix=body.api_prefix,
auth_type=body.auth_type,
auth_config_json=json.dumps(body.auth_config, ensure_ascii=False),
groups_endpoint=body.groups_endpoint,
group_update_endpoint=body.group_update_endpoint,
enabled=body.enabled,
auto_sync_enabled=body.auto_sync_enabled,
timeout_seconds=body.timeout_seconds,
)
db.add(row)
db.commit()
db.refresh(row)
return _website_response(row)
@router.put("/api/websites/{wid}", response_model=WebsiteResponse)
def update_website(wid: int, body: WebsiteUpdate, db: Session = Depends(get_db), _=Depends(get_current_user)):
row = db.query(Website).filter(Website.id == wid).first()
if not row:
raise HTTPException(404, "website not found")
data = body.model_dump(exclude_none=True)
if "site_type" in data and data["site_type"] != "sub2api":
raise HTTPException(400, "目前只支持 sub2api")
if "auth_config" in data:
existing = json.loads(row.auth_config_json or "{}")
incoming = data.pop("auth_config")
for key, value in incoming.items():
if value != MASK:
existing[key] = value
row.auth_config_json = json.dumps(existing, ensure_ascii=False)
if "base_url" in data:
data["base_url"] = data["base_url"].rstrip("/")
for key, value in data.items():
setattr(row, key, value)
row.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(row)
return _website_response(row)
@router.delete("/api/websites/{wid}", status_code=204)
def delete_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
row = db.query(Website).filter(Website.id == wid).first()
if not row:
raise HTTPException(404, "website not found")
db.delete(row)
db.commit()
@router.post("/api/websites/{wid}/test", response_model=TestResult)
def test_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
row = db.query(Website).filter(Website.id == wid).first()
if not row:
raise HTTPException(404, "website not found")
try:
groups = _client(row).get_groups(row.groups_endpoint)
row.last_status = "healthy"
row.last_error = None
row.last_checked_at = datetime.now(timezone.utc)
db.commit()
return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组")
except Exception as exc:
row.last_status = "unhealthy"
row.last_error = str(exc)
row.last_checked_at = datetime.now(timezone.utc)
db.commit()
return TestResult(success=False, message="连接失败", detail=str(exc))
@router.get("/api/websites/{wid}/groups", response_model=List[WebsiteGroupResponse])
def list_website_groups(wid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
row = db.query(Website).filter(Website.id == wid).first()
if not row:
raise HTTPException(404, "website not found")
try:
return _client(row).get_groups(row.groups_endpoint)
except Exception as exc:
raise HTTPException(502, str(exc))
@router.get("/api/group-bindings", response_model=List[BindingResponse])
def list_bindings(db: Session = Depends(get_db), _=Depends(get_current_user)):
rows = db.query(WebsiteGroupBinding).order_by(WebsiteGroupBinding.id.desc()).all()
return [_binding_response(db, row) for row in rows]
@router.post("/api/group-bindings", response_model=BindingResponse, status_code=201)
def create_binding(body: BindingCreate, db: Session = Depends(get_db), _=Depends(get_current_user)):
website = db.query(Website).filter(Website.id == body.website_id).first()
if not website:
raise HTTPException(404, "website not found")
if body.algorithm not in ALGORITHMS:
raise HTTPException(400, "不支持的算法")
_ensure_unique_target(db, body.website_id, body.target_group_id)
row = WebsiteGroupBinding(
website_id=body.website_id,
target_group_id=body.target_group_id,
target_group_name=body.target_group_name,
source_groups_json=json.dumps([item.model_dump() for item in body.source_groups], ensure_ascii=False),
percent=str(body.percent),
algorithm=body.algorithm,
enabled=body.enabled,
)
db.add(row)
db.commit()
db.refresh(row)
try:
sync_binding(db, row, write=True)
except Exception as exc:
logger.exception("initial website sync failed for binding %s: %s", row.id, exc)
return _binding_response(db, row)
@router.put("/api/group-bindings/{bid}", response_model=BindingResponse)
def update_binding(bid: int, body: BindingUpdate, db: Session = Depends(get_db), _=Depends(get_current_user)):
row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
if not row:
raise HTTPException(404, "binding not found")
data = body.model_dump(exclude_none=True)
if "website_id" in data and not db.query(Website).filter(Website.id == data["website_id"]).first():
raise HTTPException(404, "website not found")
if "algorithm" in data and data["algorithm"] not in ALGORITHMS:
raise HTTPException(400, "不支持的算法")
next_website_id = int(data.get("website_id", row.website_id))
next_target_group_id = str(data.get("target_group_id", row.target_group_id))
_ensure_unique_target(db, next_website_id, next_target_group_id, exclude_id=row.id)
if "source_groups" in data:
row.source_groups_json = json.dumps(data.pop("source_groups"), ensure_ascii=False)
if "percent" in data:
row.percent = str(data.pop("percent"))
for key, value in data.items():
setattr(row, key, value)
row.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(row)
try:
sync_binding(db, row, write=True)
except Exception as exc:
logger.exception("sync failed after updating binding %s: %s", row.id, exc)
return _binding_response(db, row)
@router.delete("/api/group-bindings/{bid}", status_code=204)
def delete_binding(bid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
if not row:
raise HTTPException(404, "binding not found")
db.delete(row)
db.commit()
@router.post("/api/group-bindings/{bid}/sync-now", response_model=WebsiteSyncLogResponse)
def sync_now(bid: int, db: Session = Depends(get_db), _=Depends(get_current_user)):
row = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.id == bid).first()
if not row:
raise HTTPException(404, "binding not found")
return _log_response(sync_binding(db, row, write=True))
@router.get("/api/website-sync-logs", response_model=List[WebsiteSyncLogResponse])
def list_sync_logs(
website_id: int | None = Query(None),
binding_id: int | None = Query(None),
limit: int = Query(50, le=200),
offset: int = Query(0),
db: Session = Depends(get_db),
_=Depends(get_current_user),
):
q = db.query(WebsiteSyncLog)
if website_id:
q = q.filter(WebsiteSyncLog.website_id == website_id)
if binding_id:
q = q.filter(WebsiteSyncLog.binding_id == binding_id)
rows = q.order_by(WebsiteSyncLog.created_at.desc()).offset(offset).limit(limit).all()
return [_log_response(row) for row in rows]
+2 -2
View File
@@ -2,7 +2,7 @@ from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel
ALLOWED_EVENTS = ["upstream_rate_changed", "upstream_unhealthy", "upstream_recovered"]
ALLOWED_EVENTS = ["upstream_rate_changed", "website_rate_changed", "upstream_unhealthy", "upstream_recovered"]
class WebhookCreate(BaseModel):
@@ -11,7 +11,7 @@ class WebhookCreate(BaseModel):
url: str
secret: str = ""
enabled: bool = True
events: List[str] = ["upstream_rate_changed"]
events: List[str] = ["upstream_rate_changed", "website_rate_changed"]
class WebhookUpdate(BaseModel):
+124
View File
@@ -0,0 +1,124 @@
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel, Field
class TestResult(BaseModel):
success: bool
message: str
detail: Optional[str] = None
class WebsiteCreate(BaseModel):
name: str
site_type: str = "sub2api"
base_url: str
api_prefix: str = "/api/v1/admin"
auth_type: str = "api_key"
auth_config: dict[str, Any] = {}
groups_endpoint: str = "/groups"
group_update_endpoint: str = "/groups/{id}"
enabled: bool = True
auto_sync_enabled: bool = True
timeout_seconds: int = 30
class WebsiteUpdate(BaseModel):
name: Optional[str] = None
site_type: Optional[str] = None
base_url: Optional[str] = None
api_prefix: Optional[str] = None
auth_type: Optional[str] = None
auth_config: Optional[dict[str, Any]] = None
groups_endpoint: Optional[str] = None
group_update_endpoint: Optional[str] = None
enabled: Optional[bool] = None
auto_sync_enabled: Optional[bool] = None
timeout_seconds: Optional[int] = None
class WebsiteResponse(BaseModel):
id: int
name: str
site_type: str
base_url: str
api_prefix: str
auth_type: str
auth_config_masked: dict[str, Any]
groups_endpoint: str
group_update_endpoint: str
enabled: bool
auto_sync_enabled: bool
timeout_seconds: int
last_status: str
last_checked_at: Optional[datetime]
last_error: Optional[str]
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
class WebsiteGroupResponse(BaseModel):
id: str
name: str
rate_multiplier: Optional[str] = None
raw: dict[str, Any] = {}
class BindingSourceGroup(BaseModel):
upstream_id: int
group_id: str
upstream_name: str = ""
group_name: str = ""
class BindingCreate(BaseModel):
website_id: int
target_group_id: str
target_group_name: str = ""
source_groups: list[BindingSourceGroup] = Field(default_factory=list)
percent: float = Field(default=0, ge=0)
algorithm: str = "max_plus_percent"
enabled: bool = True
class BindingUpdate(BaseModel):
website_id: Optional[int] = None
target_group_id: Optional[str] = None
target_group_name: Optional[str] = None
source_groups: Optional[list[BindingSourceGroup]] = None
percent: Optional[float] = Field(default=None, ge=0)
algorithm: Optional[str] = None
enabled: Optional[bool] = None
class BindingResponse(BaseModel):
id: int
website_id: int
website_name: str
target_group_id: str
target_group_name: str
source_groups: list[BindingSourceGroup]
percent: float
algorithm: str
enabled: bool
created_at: datetime
updated_at: datetime
class WebsiteSyncLogResponse(BaseModel):
id: int
website_id: int
binding_id: Optional[int]
target_group_id: str
target_group_name: str
algorithm: str
percent: float
source_rates: list[dict[str, Any]]
old_rate: Optional[str]
new_rate: Optional[str]
status: str
message: str
created_at: datetime
@@ -0,0 +1,320 @@
"""Managed Playwright browser sessions for custom pages."""
from __future__ import annotations
import asyncio
import logging
import re
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
from urllib.parse import urlparse
from uuid import uuid4
from app.config import get_settings
logger = logging.getLogger(__name__)
class BrowserDependencyError(RuntimeError):
"""Raised when Playwright or its browser runtime is unavailable."""
class BrowserSessionError(RuntimeError):
"""Raised when an existing browser session can no longer be used."""
@dataclass
class BrowserSession:
id: str
custom_page_id: int
profile_key: str
context: Any
page: Any
lock: asyncio.Lock
class BrowserSessionService:
def __init__(self) -> None:
self._playwright: Optional[Any] = None
self._sessions: dict[str, BrowserSession] = {}
self._profiles: dict[str, str] = {}
self._lock = asyncio.Lock()
async def create(
self,
custom_page_id: int,
url: str,
width: int = 1280,
height: int = 720,
login_config: Optional[dict[str, Any]] = None,
) -> BrowserSession:
if not url.startswith(("http://", "https://")):
raise ValueError("Only http/https URLs are allowed")
width = max(320, min(width, 2560))
height = max(240, min(height, 1600))
async with self._lock:
await self._ensure_playwright()
profile_key = self._profile_key(custom_page_id, url)
existing_id = self._profiles.get(profile_key)
existing = self._sessions.get(existing_id or "")
if existing and not existing.page.is_closed():
async with existing.lock:
await existing.page.set_viewport_size({"width": width, "height": height})
if existing.page.url == "about:blank":
await existing.page.goto(url, wait_until="domcontentloaded", timeout=45000)
await self._autofill_login(existing.page, login_config)
await self._reset_page_zoom(existing)
return existing
if existing_id:
self._profiles.pop(profile_key, None)
context = await self._playwright.chromium.launch_persistent_context(
str(self._profile_dir(profile_key)),
headless=get_settings().browser_headless,
viewport={"width": width, "height": height},
args=["--no-sandbox", "--disable-dev-shm-usage"],
)
page = context.pages[0] if context.pages else await context.new_page()
session = BrowserSession(
id=uuid4().hex,
custom_page_id=custom_page_id,
profile_key=profile_key,
context=context,
page=page,
lock=asyncio.Lock(),
)
self._sessions[session.id] = session
self._profiles[profile_key] = session.id
try:
await page.goto(url, wait_until="domcontentloaded", timeout=45000)
await self._autofill_login(page, login_config)
await self._reset_page_zoom(session)
except Exception:
await self.close(session.id)
raise
return session
async def screenshot(self, session_id: str) -> bytes:
session = self._get(session_id)
async with session.lock:
self._ensure_open(session)
return await session.page.screenshot(type="jpeg", quality=78, full_page=False)
async def event(self, session_id: str, event_type: str, payload: dict[str, Any]) -> dict[str, Any]:
session = self._get(session_id)
async with session.lock:
self._ensure_open(session)
page = session.page
if event_type == "click":
await page.mouse.click(float(payload["x"]), float(payload["y"]), button=payload.get("button", "left"))
elif event_type == "dblclick":
await page.mouse.dblclick(float(payload["x"]), float(payload["y"]), button=payload.get("button", "left"))
elif event_type == "mousemove":
await page.mouse.move(float(payload["x"]), float(payload["y"]))
elif event_type == "mousedown":
await page.mouse.move(float(payload["x"]), float(payload["y"]))
await page.mouse.down(button=payload.get("button", "left"))
elif event_type == "mouseup":
await page.mouse.move(float(payload["x"]), float(payload["y"]))
await page.mouse.up(button=payload.get("button", "left"))
elif event_type == "type":
text = str(payload.get("text", ""))
if text:
await page.keyboard.type(text)
elif event_type == "key":
key = str(payload.get("key", ""))
if key:
await page.keyboard.press(key)
elif event_type == "scroll":
if payload.get("x") is not None and payload.get("y") is not None:
await page.mouse.move(float(payload["x"]), float(payload["y"]))
await page.mouse.wheel(float(payload.get("delta_x", 0)), float(payload.get("delta_y", 0)))
elif event_type == "reload":
await page.reload(wait_until="domcontentloaded", timeout=45000)
elif event_type == "back":
await page.go_back(wait_until="domcontentloaded", timeout=45000)
elif event_type == "forward":
await page.go_forward(wait_until="domcontentloaded", timeout=45000)
elif event_type == "resize":
width = max(320, min(int(payload.get("width", 1280)), 2560))
height = max(240, min(int(payload.get("height", 720)), 1600))
await page.set_viewport_size({"width": width, "height": height})
else:
raise ValueError("Unsupported browser event")
return await self._session_state(session)
async def close(self, session_id: str) -> None:
session = self._discard_session(session_id)
if not session:
return
try:
await session.context.close()
except Exception:
pass
async def shutdown(self) -> None:
sessions = list(self._sessions)
for session_id in sessions:
await self.close(session_id)
if self._playwright:
await self._playwright.stop()
self._playwright = None
async def state(self, session_id: str) -> dict[str, Any]:
session = self._get(session_id)
async with session.lock:
self._ensure_open(session)
return await self._session_state(session)
async def _session_state(self, session: BrowserSession) -> dict[str, Any]:
return {
"id": session.id,
"custom_page_id": session.custom_page_id,
"url": session.page.url,
"title": await session.page.title(),
}
async def _ensure_playwright(self) -> None:
if self._playwright:
return
try:
from playwright.async_api import async_playwright
except ImportError as exc:
raise BrowserDependencyError("Playwright is not installed. Run `pip install -r requirements.txt`.") from exc
try:
self._playwright = await async_playwright().start()
except Exception as exc:
raise BrowserDependencyError(f"Unable to start Playwright: {exc}") from exc
async def _reset_page_zoom(self, session: BrowserSession) -> None:
try:
cdp = await session.context.new_cdp_session(session.page)
try:
await cdp.send("Emulation.setPageScaleFactor", {"pageScaleFactor": 1})
finally:
await cdp.detach()
except Exception:
pass
async def _autofill_login(
self,
page: Any,
config: Optional[dict[str, Any]],
*,
max_wait_seconds: float = 8.0,
poll_interval_seconds: float = 0.25,
) -> None:
if not config or not config.get("enabled"):
return
username = str(config.get("username") or "")
password = str(config.get("password") or "")
if not username or not password:
return
try:
username_selectors = [
config.get("username_selector"),
"input[type='email']",
"input[name*='user' i]",
"input[id*='user' i]",
"input[name*='email' i]",
"input[id*='email' i]",
"input[name*='login' i]",
"input[id*='login' i]",
"input[autocomplete='username']",
"input:not([type]), input[type='text']",
]
password_selectors = [
config.get("password_selector"),
"input[type='password']",
"input[autocomplete='current-password']",
]
username_locator, password_locator = await self._wait_for_login_locators(
page,
username_selectors,
password_selectors,
max_wait_seconds=max_wait_seconds,
poll_interval_seconds=poll_interval_seconds,
)
if not username_locator or not password_locator:
logger.info("Login autofill skipped for %s: login fields not found", page.url)
return
await username_locator.fill(username, timeout=3000)
await password_locator.fill(password, timeout=3000)
submit_selector = str(config.get("submit_selector") or "").strip()
if submit_selector:
submit = await self._first_visible_locator(page, [submit_selector], timeout=500)
if submit:
await submit.click(timeout=3000)
except Exception as exc:
logger.info("Login autofill skipped for %s: %s", page.url, exc)
async def _wait_for_login_locators(
self,
page: Any,
username_selectors: list[Optional[str]],
password_selectors: list[Optional[str]],
*,
max_wait_seconds: float,
poll_interval_seconds: float,
) -> tuple[Optional[Any], Optional[Any]]:
deadline = time.monotonic() + max_wait_seconds
while True:
username_locator = await self._first_visible_locator(page, username_selectors, timeout=150)
password_locator = await self._first_visible_locator(page, password_selectors, timeout=150)
if username_locator and password_locator:
return username_locator, password_locator
if time.monotonic() >= deadline:
return None, None
await asyncio.sleep(poll_interval_seconds)
async def _first_visible_locator(
self,
page: Any,
selectors: list[Optional[str]],
*,
timeout: float = 1500,
) -> Optional[Any]:
for selector in selectors:
selector = str(selector or "").strip()
if not selector:
continue
try:
locator = page.locator(selector).first
if await locator.count() and await locator.is_visible(timeout=timeout):
return locator
except Exception:
continue
return None
def _get(self, session_id: str) -> BrowserSession:
session = self._sessions.get(session_id)
if not session:
raise KeyError("browser session not found")
return session
def _ensure_open(self, session: BrowserSession) -> None:
if session.page.is_closed():
self._discard_session(session.id)
raise BrowserSessionError("browser page is closed")
def _discard_session(self, session_id: str) -> BrowserSession | None:
session = self._sessions.pop(session_id, None)
if session and self._profiles.get(session.profile_key) == session_id:
self._profiles.pop(session.profile_key, None)
return session
def _profile_dir(self, profile_key: str) -> Path:
root = Path(get_settings().browser_profiles_dir)
root.mkdir(parents=True, exist_ok=True)
profile = root / profile_key
profile.mkdir(parents=True, exist_ok=True)
return profile
def _profile_key(self, custom_page_id: int, url: str) -> str:
parsed = urlparse(url)
origin = f"{parsed.scheme}-{parsed.netloc}".lower()
safe_origin = re.sub(r"[^a-z0-9_.-]+", "_", origin).strip("_") or "page"
return f"page-{custom_page_id}-{safe_origin[:80]}"
browser_sessions = BrowserSessionService()
+2
View File
@@ -14,6 +14,7 @@ from app.models.snapshot import UpstreamRateSnapshot
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot
from app.services.snapshot_service import diff_snapshots
from app.services import webhook_service
from app.services import website_sync
from app.config import get_settings
logger = logging.getLogger(__name__)
@@ -105,6 +106,7 @@ def _check_upstream(upstream_id: int) -> None:
webhook_service.send_rate_changed(
db, upstream.id, upstream.name, upstream.base_url, changes
)
website_sync.sync_affected_bindings(db, upstream.id, changes)
logger.info("upstream %s: %d rate change(s)", upstream.name, len(changes))
else:
logger.debug("upstream %s: no changes", upstream.name)
+98 -9
View File
@@ -27,14 +27,42 @@ def _find_token(value: Any) -> str:
return ""
def _find_user_id(value: Any) -> str:
if isinstance(value, dict):
for key in ("id", "user_id", "userId"):
candidate = value.get(key)
if candidate is not None:
return str(candidate)
for key in ("data", "result", "user", "session"):
user_id = _find_user_id(value.get(key))
if user_id:
return user_id
return ""
def _unwrap_list(value: Any) -> Optional[list[dict[str, Any]]]:
def _normalize(lst: list) -> list[dict[str, Any]]:
out = []
for i in lst:
if isinstance(i, dict):
out.append(i)
elif isinstance(i, str):
out.append({"id": i, "name": i})
return out
if isinstance(value, list):
return [i for i in value if isinstance(i, dict)]
return _normalize(value)
if isinstance(value, dict):
for key in ("data", "items", "groups", "available_groups", "availableGroups"):
nested = value.get(key)
if isinstance(nested, list):
return [i for i in nested if isinstance(i, dict)]
return _normalize(nested)
elif isinstance(nested, dict):
# Handle /api/user/self/groups where data is a dict of group_name -> { desc, ratio }
out = []
for k in nested.keys():
out.append({"id": k, "name": k})
return out
return None
@@ -76,19 +104,59 @@ def _rate_from_group(group: dict[str, Any]) -> str:
def _extract_rates_map(raw: Any) -> dict[str, str]:
if raw is None:
return {}
# Handle one-api/new-api /api/option response where GroupRatio is in a list of options
if isinstance(raw, dict) and isinstance(raw.get("data"), list):
for item in raw["data"]:
if isinstance(item, dict) and item.get("key") == "GroupRatio":
val = item.get("value")
if isinstance(val, str):
try:
import json
parsed = json.loads(val)
if isinstance(parsed, dict):
result: dict[str, str] = {}
for k, v in parsed.items():
r = _decimal_str(v)
if r:
result[str(k)] = r
return result
except Exception:
pass
elif isinstance(val, dict):
# In case it's returned as dict directly
result = {}
for k, v in val.items():
r = _decimal_str(v)
if r:
result[str(k)] = r
return result
if isinstance(raw, dict):
candidates = raw
for key in ("data", "rates", "group_rates", "groupRates"):
for key in ("data", "rates", "group_rates", "groupRates", "GroupRatio"):
nested = raw.get(key)
if isinstance(nested, dict):
candidates = nested
break
elif isinstance(nested, str) and key == "GroupRatio":
# Handle GroupRatio as a JSON string
try:
import json
parsed = json.loads(nested)
if isinstance(parsed, dict):
candidates = parsed
break
except Exception:
pass
result: dict[str, str] = {}
for k, v in candidates.items():
if isinstance(v, dict):
r = _decimal_str(
v.get("rate_multiplier") or v.get("rateMultiplier")
or v.get("user_rate_multiplier") or v.get("userRateMultiplier")
or v.get("ratio")
)
else:
r = _decimal_str(v)
@@ -151,6 +219,8 @@ class UpstreamClient:
self.auth_config = auth_config
self.timeout = timeout
self._token: str = ""
self._cookies: dict[str, str] = {}
self._new_api_user: str = ""
def _url(self, path: str) -> str:
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
@@ -174,15 +244,29 @@ class UpstreamClient:
headers[header] = key
elif self.auth_type == "login_password" and self._token:
headers["Authorization"] = f"Bearer {self._token}"
if self.auth_type == "login_password" and self._new_api_user:
headers["New-Api-User"] = self._new_api_user
return headers
def _request(self, method: str, path: str, body: Any = None, auth: bool = True) -> Any:
url = self._url(path)
with httpx.Client(timeout=self.timeout) as client:
if body is not None:
resp = client.request(method, url, json=body, headers=self._headers(auth))
resp = client.request(
method,
url,
json=body,
headers=self._headers(auth),
cookies=self._cookies,
)
else:
resp = client.request(method, url, headers=self._headers(auth))
resp = client.request(
method,
url,
headers=self._headers(auth),
cookies=self._cookies,
)
self._cookies.update(dict(resp.cookies))
resp.raise_for_status()
ct = resp.headers.get("content-type", "")
if not resp.content:
@@ -198,13 +282,18 @@ class UpstreamClient:
email = self.auth_config.get("email", "")
password = self.auth_config.get("password", "")
login_path = self.auth_config.get("login_path", "/auth/login")
username_field = self.auth_config.get("username_field", "email")
if not email or not password:
raise UpstreamError("login_password auth requires email and password in auth_config")
resp = self._request("POST", login_path, {"email": email, "password": password}, auth=False)
resp = self._request("POST", login_path, {username_field: email, "password": password}, auth=False)
token = _find_token(resp)
if not token:
raise UpstreamError("login succeeded but no token found in response")
self._token = token
if token:
self._token = token
return
if self._cookies:
self._new_api_user = self.auth_config.get("new_api_user", "") or _find_user_id(resp)
return
raise UpstreamError("login succeeded but no token or session cookie found in response")
def get_available_groups(self, endpoint: str) -> list[dict[str, Any]]:
resp = self._request("GET", endpoint)
+49
View File
@@ -13,6 +13,7 @@ from app.models.notification_log import NotificationLog
from app.utils.dingtalk import (
dingtalk_signed_url,
format_dingtalk_rate_changed,
format_dingtalk_website_rate_changed,
format_dingtalk_status,
)
@@ -101,6 +102,54 @@ def send_rate_changed(
_log(db, wh, "upstream_rate_changed", generic_payload, "failed", str(exc))
def send_website_rate_changed(
db: Session,
website_id: int,
website_name: str,
base_url: str,
binding_id: int,
target_group_id: str,
target_group_name: str,
old_rate: Any,
new_rate: Any,
source_rates: list[dict[str, Any]],
) -> None:
webhooks = (
db.query(WebhookConfig)
.filter(WebhookConfig.enabled == True)
.all()
)
changed_at = _now_iso()
generic_payload = {
"event": "website_rate_changed",
"website": {"id": website_id, "name": website_name, "base_url": base_url},
"binding": {"id": binding_id},
"target_group": {
"id": target_group_id,
"name": target_group_name,
"old_rate": old_rate,
"new_rate": new_rate,
},
"source_rates": source_rates,
"changed_at": changed_at,
}
for wh in webhooks:
events = json.loads(wh.events_json or "[]")
if "website_rate_changed" not in events:
continue
try:
if wh.type == "dingtalk":
msg = format_dingtalk_website_rate_changed(
website_name, target_group_name, changed_at, old_rate, new_rate
)
resp_text = _send_dingtalk(wh.url, wh.secret, msg)
else:
resp_text = _send_generic(wh.url, generic_payload)
_log(db, wh, "website_rate_changed", generic_payload, "success", resp_text)
except Exception as exc:
_log(db, wh, "website_rate_changed", generic_payload, "failed", str(exc))
def send_status_event(
db: Session,
upstream_id: int,
+154
View File
@@ -0,0 +1,154 @@
from __future__ import annotations
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
from typing import Any
from urllib.parse import quote
import httpx
class WebsiteError(RuntimeError):
pass
def decimal_string(value: Any) -> str:
if value is None or value == "":
return ""
try:
d = Decimal(str(value))
except (InvalidOperation, ValueError):
return str(value)
n = d.normalize()
if n == n.to_integral():
return str(n.quantize(Decimal("1")))
return format(n, "f")
def parse_positive_decimal(value: Any) -> Decimal | None:
if value is None or value == "":
return None
try:
d = Decimal(str(value))
except (InvalidOperation, ValueError):
return None
return d if d > 0 else None
def calculate_target_rate(values: list[Any], percent: Any = 0, algorithm: str = "max_plus_percent") -> Decimal:
rates = [rate for rate in (parse_positive_decimal(v) for v in values) if rate is not None]
if not rates:
raise WebsiteError("没有可用的正数上游倍率")
if algorithm == "average_plus_percent":
base = sum(rates, Decimal("0")) / Decimal(len(rates))
elif algorithm == "min_plus_percent":
base = min(rates)
elif algorithm == "max_plus_percent":
base = max(rates)
else:
raise WebsiteError(f"不支持的算法:{algorithm}")
pct = Decimal(str(percent or 0))
if pct < 0:
raise WebsiteError("百分比不能为负数")
return (base * (Decimal("1") + pct / Decimal("100"))).quantize(Decimal("0.0001"), rounding=ROUND_HALF_UP)
def _unwrap_data(value: Any) -> Any:
if isinstance(value, dict):
data = value.get("data")
if "data" in value and (
"code" in value
or "message" in value
or isinstance(data, list)
or (isinstance(data, dict) and any(key in data for key in ("items", "groups")))
):
value = data
if not isinstance(value, dict):
return value
for key in ("items", "groups"):
if key in value:
return value.get(key)
return value
def normalize_groups(value: Any) -> list[dict[str, Any]]:
raw = _unwrap_data(value)
if isinstance(raw, dict):
raw = list(raw.values())
if not isinstance(raw, list):
raise WebsiteError("分组接口没有返回列表")
groups: list[dict[str, Any]] = []
for item in raw:
if isinstance(item, str):
groups.append({"id": item, "name": item, "rate_multiplier": None, "raw": {"id": item, "name": item}})
continue
if not isinstance(item, dict):
continue
gid = item.get("id") or item.get("group_id") or item.get("groupId") or item.get("name") or item.get("group_name")
if gid is None:
continue
name = item.get("name") or item.get("group_name") or str(gid)
rate = item.get("rate_multiplier") or item.get("rateMultiplier") or item.get("ratio")
groups.append({
"id": str(gid),
"name": str(name),
"rate_multiplier": decimal_string(rate) if rate is not None else None,
"raw": item,
})
return groups
class Sub2ApiWebsiteClient:
def __init__(
self,
base_url: str,
api_prefix: str,
auth_type: str,
auth_config: dict[str, Any],
timeout: float = 30.0,
) -> None:
self.base_url = base_url.rstrip("/")
self.api_prefix = api_prefix.strip("/")
self.auth_type = auth_type
self.auth_config = auth_config
self.timeout = timeout
def _url(self, path: str) -> str:
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
return f"{self.base_url}{prefix}/{path.lstrip('/')}"
def _headers(self) -> dict[str, str]:
headers = {"Accept": "application/json", "User-Agent": "SmartUp/1.0"}
if self.auth_type == "api_key":
key = self.auth_config.get("key") or self.auth_config.get("api_key") or ""
header = self.auth_config.get("header") or "x-api-key"
if key:
headers[header] = key
elif self.auth_type == "bearer":
token = self.auth_config.get("token") or ""
if token:
headers["Authorization"] = f"Bearer {token}"
return headers
def _request(self, method: str, path: str, body: Any = None) -> Any:
with httpx.Client(timeout=self.timeout) as client:
resp = client.request(method, self._url(path), json=body, headers=self._headers())
resp.raise_for_status()
if not resp.content:
return None
text = resp.text
if "application/json" not in resp.headers.get("content-type", "") and text.lstrip().startswith("<"):
raise WebsiteError(f"{method} {path} returned HTML, not JSON")
return resp.json()
def get_groups(self, endpoint: str = "/groups") -> list[dict[str, Any]]:
errors: list[str] = []
for path in [endpoint, "/groups/all"]:
try:
return normalize_groups(self._request("GET", path))
except Exception as exc:
errors.append(f"{path}: {exc}")
raise WebsiteError("; ".join(errors))
def update_group_rate(self, endpoint_template: str, group_id: str, rate: Decimal) -> Any:
path = endpoint_template.replace("{id}", quote(group_id, safe=""))
return self._request("PUT", path, {"rate_multiplier": float(rate)})
+165
View File
@@ -0,0 +1,165 @@
from __future__ import annotations
import json
import logging
from decimal import Decimal
from typing import Any
from sqlalchemy.orm import Session
from app.models.snapshot import UpstreamRateSnapshot
from app.models.upstream import Upstream
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, calculate_target_rate, decimal_string
from app.services import webhook_service
logger = logging.getLogger(__name__)
def binding_sources(binding: WebsiteGroupBinding) -> list[dict[str, Any]]:
try:
data = json.loads(binding.source_groups_json or "[]")
except Exception:
return []
return data if isinstance(data, list) else []
def latest_rate_map(db: Session, upstream_id: int) -> dict[str, Any]:
row = (
db.query(UpstreamRateSnapshot)
.filter(UpstreamRateSnapshot.upstream_id == upstream_id)
.order_by(UpstreamRateSnapshot.captured_at.desc())
.first()
)
if not row:
return {}
snapshot = json.loads(row.snapshot_json or "{}")
groups = snapshot.get("groups") or {}
return groups if isinstance(groups, dict) else {}
def get_affected_bindings(db: Session, changes: list[dict[str, Any]], upstream_id: int) -> list[WebsiteGroupBinding]:
changed_ids = {str(change.get("group_id")) for change in changes if change.get("group_id") is not None}
if not changed_ids:
return []
result: list[WebsiteGroupBinding] = []
bindings = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.enabled == True).all()
for binding in bindings:
for source in binding_sources(binding):
if int(source.get("upstream_id") or 0) == upstream_id and str(source.get("group_id")) in changed_ids:
result.append(binding)
break
return result
def _client_for(website: Website) -> Sub2ApiWebsiteClient:
return Sub2ApiWebsiteClient(
base_url=website.base_url,
api_prefix=website.api_prefix,
auth_type=website.auth_type,
auth_config=json.loads(website.auth_config_json or "{}"),
timeout=float(website.timeout_seconds),
)
def _log(
db: Session,
binding: WebsiteGroupBinding,
website: Website,
source_rates: list[dict[str, Any]],
status: str,
message: str,
old_rate: Any = None,
new_rate: Any = None,
) -> WebsiteSyncLog:
row = WebsiteSyncLog(
website_id=website.id,
binding_id=binding.id,
target_group_id=binding.target_group_id,
target_group_name=binding.target_group_name,
algorithm=binding.algorithm,
percent=binding.percent,
source_rates_json=json.dumps(source_rates, ensure_ascii=False),
old_rate=decimal_string(old_rate) if old_rate not in (None, "") else None,
new_rate=decimal_string(new_rate) if new_rate not in (None, "") else None,
status=status,
message=message,
)
db.add(row)
db.commit()
db.refresh(row)
return row
def sync_binding(db: Session, binding: WebsiteGroupBinding, write: bool = True) -> WebsiteSyncLog:
website = db.query(Website).filter(Website.id == binding.website_id).first()
if not website:
raise WebsiteError("网站不存在")
sources = binding_sources(binding)
source_rates: list[dict[str, Any]] = []
for source in sources:
upstream_id = int(source.get("upstream_id") or 0)
group_id = str(source.get("group_id") or "")
groups = latest_rate_map(db, upstream_id)
group = groups.get(group_id) if group_id else None
upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
source_rates.append({
"upstream_id": upstream_id,
"upstream_name": source.get("upstream_name") or (upstream.name if upstream else ""),
"group_id": group_id,
"group_name": source.get("group_name") or (group.get("group_name", "") if isinstance(group, dict) else ""),
"rate": group.get("rate") if isinstance(group, dict) else None,
})
try:
target_rate = calculate_target_rate([item.get("rate") for item in source_rates], binding.percent, binding.algorithm)
except Exception as exc:
return _log(db, binding, website, source_rates, "failed", str(exc))
old_rate = None
if write and website.enabled and website.auto_sync_enabled and binding.enabled:
try:
client = _client_for(website)
groups = client.get_groups(website.groups_endpoint)
target = next((item for item in groups if item.get("id") == binding.target_group_id), None)
old_rate = target.get("rate_multiplier") if target else None
client.update_group_rate(website.group_update_endpoint, binding.target_group_id, target_rate)
website.last_status = "healthy"
website.last_error = None
except Exception as exc:
website.last_status = "unhealthy"
website.last_error = str(exc)
db.commit()
return _log(db, binding, website, source_rates, "failed", f"写回失败:{exc}", old_rate, target_rate)
db.commit()
log = _log(db, binding, website, source_rates, "success", "同步成功", old_rate, target_rate)
old_rate_str = decimal_string(old_rate) if old_rate not in (None, "") else None
new_rate_str = decimal_string(target_rate)
if old_rate_str != new_rate_str:
webhook_service.send_website_rate_changed(
db,
website.id,
website.name,
website.base_url,
binding.id,
binding.target_group_id,
binding.target_group_name,
old_rate_str,
new_rate_str,
source_rates,
)
return log
message = "已计算建议倍率,未写回"
if not website.enabled or not website.auto_sync_enabled:
message = "网站未启用自动同步,未写回"
elif not binding.enabled:
message = "绑定未启用,未写回"
return _log(db, binding, website, source_rates, "success", message, old_rate, target_rate)
def sync_affected_bindings(db: Session, upstream_id: int, changes: list[dict[str, Any]]) -> None:
for binding in get_affected_bindings(db, changes, upstream_id):
try:
sync_binding(db, binding, write=True)
except Exception as exc:
logger.exception("website sync failed for binding %s: %s", binding.id, exc)
+24
View File
@@ -40,6 +40,30 @@ def format_dingtalk_rate_changed(upstream_name: str, changed_at: str, changes: l
}
def format_dingtalk_website_rate_changed(
website_name: str,
target_group_name: str,
changed_at: str,
old_rate: Any,
new_rate: Any,
) -> dict[str, Any]:
group_name = target_group_name or "unknown"
lines = [
f"### 网站倍率变更:{website_name}",
"",
f"- 时间:{changed_at}",
f"- 分组:`{group_name}`",
f"- 倍率:`{old_rate}` -> `{new_rate}`",
]
return {
"msgtype": "markdown",
"markdown": {
"title": f"{website_name} 网站倍率变更",
"text": "\n".join(lines),
},
}
def format_dingtalk_status(upstream_name: str, event: str, changed_at: str, error: str = "") -> dict[str, Any]:
emoji = "🔴" if event == "upstream_unhealthy" else "🟢"
label = "服务异常" if event == "upstream_unhealthy" else "服务恢复"
+1
View File
@@ -8,3 +8,4 @@ apscheduler==3.10.4
python-dotenv==1.0.1
pydantic-settings==2.6.1
python-multipart==0.0.20
playwright==1.52.0
+104
View File
@@ -0,0 +1,104 @@
import asyncio
from app.services.browser_session_service import BrowserSessionService
class FakeLocator:
def __init__(self, *, visible=True, count=1):
self._visible = list(visible) if isinstance(visible, list) else [visible]
self._count = count
self.filled = []
self.clicked = 0
self.timeouts = []
@property
def first(self):
return self
async def count(self):
return self._count
async def is_visible(self, timeout=0):
self.timeouts.append(timeout)
if not self._visible:
return False
if len(self._visible) == 1:
return self._visible[0]
return self._visible.pop(0)
async def fill(self, value, timeout=0):
self.filled.append((value, timeout))
async def click(self, timeout=0):
self.clicked += 1
class FakePage:
url = "https://example.test/login"
def __init__(self, locators):
self.locators = locators
self.queries = []
def locator(self, selector):
self.queries.append(selector)
return self.locators.get(selector, FakeLocator(visible=False, count=0))
def run(coro):
return asyncio.run(coro)
def test_autofill_retries_until_delayed_fields_are_visible():
service = BrowserSessionService()
username = FakeLocator(visible=[False, True])
password = FakeLocator(visible=True)
submit = FakeLocator(visible=True)
page = FakePage({
"#user": username,
"#pass": password,
"#submit": submit,
})
run(service._autofill_login(
page,
{
"enabled": True,
"username": "alice",
"password": "secret",
"username_selector": "#user",
"password_selector": "#pass",
"submit_selector": "#submit",
},
max_wait_seconds=1,
poll_interval_seconds=0,
))
assert page.queries[0] == "#user"
assert "#pass" in page.queries
assert "input[type='password']" not in page.queries
assert username.filled == [("alice", 3000)]
assert password.filled == [("secret", 3000)]
assert submit.clicked == 1
def test_autofill_returns_without_selectors_when_disabled_or_missing_credentials():
service = BrowserSessionService()
disabled_page = FakePage({"#user": FakeLocator()})
run(service._autofill_login(
disabled_page,
{"enabled": False, "username": "alice", "password": "secret"},
max_wait_seconds=1,
poll_interval_seconds=0,
))
assert disabled_page.queries == []
missing_password_page = FakePage({"#user": FakeLocator()})
run(service._autofill_login(
missing_password_page,
{"enabled": True, "username": "alice", "password": ""},
max_wait_seconds=1,
poll_interval_seconds=0,
))
assert missing_password_page.queries == []
+171
View File
@@ -0,0 +1,171 @@
import sys
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
sys.path.insert(0, str(Path(__file__).resolve().parent))
from app import database as database_module
from app.database import Base, get_db
from app.main import app
from app.models.custom_page import CustomPage
from app.utils.auth import get_current_user
@pytest.fixture()
def db_session():
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
Base.metadata.create_all(bind=engine)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = TestingSessionLocal()
try:
yield db
finally:
db.close()
Base.metadata.drop_all(bind=engine)
@pytest.fixture()
def client(db_session):
def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = lambda: object()
try:
yield TestClient(app)
finally:
app.dependency_overrides.clear()
def test_create_page_auto_enables_autofill_when_credentials_are_saved(client):
response = client.post("/api/custom-pages", json={
"name": "Login page",
"url": "https://example.test/login",
"access_mode": "remote_browser",
"login_username": "alice",
"login_password": "secret",
})
assert response.status_code == 201
assert response.json()["login_autofill_enabled"] is True
assert response.json()["login_password_configured"] is True
def test_create_page_respects_explicit_autofill_disable(client):
response = client.post("/api/custom-pages", json={
"name": "Login page",
"url": "https://example.test/login",
"access_mode": "remote_browser",
"login_username": "alice",
"login_password": "secret",
"login_autofill_enabled": False,
})
assert response.status_code == 201
assert response.json()["login_autofill_enabled"] is False
def test_update_page_auto_enables_autofill_when_new_password_is_saved(client, db_session):
page = CustomPage(
name="Login page",
url="https://example.test/login",
access_mode="remote_browser",
login_username="alice",
login_password="old-secret",
login_autofill_enabled=False,
)
db_session.add(page)
db_session.commit()
db_session.refresh(page)
response = client.put(f"/api/custom-pages/{page.id}", json={
"login_username": "alice@example.test",
"login_password": "new-secret",
})
assert response.status_code == 200
assert response.json()["login_autofill_enabled"] is True
def test_update_page_keeps_autofill_disabled_when_existing_password_is_kept(client, db_session):
page = CustomPage(
name="Login page",
url="https://example.test/login",
access_mode="remote_browser",
login_username="alice",
login_password="secret",
login_autofill_enabled=False,
)
db_session.add(page)
db_session.commit()
db_session.refresh(page)
response = client.put(f"/api/custom-pages/{page.id}", json={
"login_username": "alice@example.test",
})
assert response.status_code == 200
assert response.json()["login_autofill_enabled"] is False
def test_update_page_respects_explicit_autofill_disable(client, db_session):
page = CustomPage(
name="Login page",
url="https://example.test/login",
access_mode="remote_browser",
login_username="alice",
login_password="secret",
login_autofill_enabled=False,
)
db_session.add(page)
db_session.commit()
db_session.refresh(page)
response = client.put(f"/api/custom-pages/{page.id}", json={
"login_username": "alice@example.test",
"login_autofill_enabled": False,
})
assert response.status_code == 200
assert response.json()["login_autofill_enabled"] is False
def test_custom_page_migration_backfills_autofill_once(monkeypatch):
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
monkeypatch.setattr(database_module, "engine", engine)
Base.metadata.create_all(bind=engine)
with engine.begin() as conn:
conn.execute(text("ALTER TABLE custom_pages DROP COLUMN login_autofill_backfilled_at"))
conn.execute(text(
"INSERT INTO custom_pages "
"(name, url, icon, sort_order, enabled, use_proxy, access_mode, login_username, login_password, login_autofill_enabled, created_at, updated_at) "
"VALUES ('Login page', 'https://example.test/login', 'Link', 0, 1, 0, 'remote_browser', 'alice', 'secret', 0, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)"
))
database_module._migrate_custom_pages()
with engine.begin() as conn:
row = conn.execute(text(
"SELECT login_autofill_enabled, login_autofill_backfilled_at FROM custom_pages WHERE name = 'Login page'"
)).one()
assert row.login_autofill_enabled == 1
assert row.login_autofill_backfilled_at is not None
conn.execute(text("UPDATE custom_pages SET login_autofill_enabled = 0"))
database_module._migrate_custom_pages()
with engine.begin() as conn:
row = conn.execute(text("SELECT login_autofill_enabled FROM custom_pages WHERE name = 'Login page'")).one()
assert row.login_autofill_enabled == 0
+353
View File
@@ -0,0 +1,353 @@
import json
import sys
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
sys.path.insert(0, str(Path(__file__).resolve().parent))
from app.database import Base, get_db
from app.main import app
from app.models.snapshot import UpstreamRateSnapshot
from app.models.upstream import Upstream
from app.models.notification_log import NotificationLog
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
from app.models.webhook_config import WebhookConfig
from app.routers import websites as websites_router
from app.utils.auth import get_current_user
@pytest.fixture()
def db_session():
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
Base.metadata.create_all(bind=engine)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = TestingSessionLocal()
try:
yield db
finally:
db.close()
Base.metadata.drop_all(bind=engine)
@pytest.fixture()
def client(db_session):
def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = lambda: object()
try:
yield TestClient(app)
finally:
app.dependency_overrides.clear()
def seed_rows(db_session, *, website_enabled=True, auto_sync_enabled=True):
website = Website(
name="Target",
site_type="sub2api",
base_url="http://target.local",
api_prefix="/api",
auth_type="api_key",
auth_config_json="{}",
groups_endpoint="/groups",
group_update_endpoint="/groups/{id}",
enabled=website_enabled,
auto_sync_enabled=auto_sync_enabled,
)
upstream = Upstream(
name="Upstream",
base_url="http://upstream.local",
api_prefix="/api",
auth_type="bearer",
auth_config_json="{}",
)
db_session.add_all([website, upstream])
db_session.commit()
db_session.refresh(website)
db_session.refresh(upstream)
snapshot = UpstreamRateSnapshot(
upstream_id=upstream.id,
snapshot_json=json.dumps({
"groups": {
"source": {
"group_id": "source",
"group_name": "Source",
"rate": "2",
}
}
}),
)
db_session.add(snapshot)
db_session.commit()
return website, upstream
def binding_payload(website_id, upstream_id, *, enabled=True):
return {
"website_id": website_id,
"target_group_id": "target",
"target_group_name": "Target group",
"source_groups": [{
"upstream_id": upstream_id,
"upstream_name": "Upstream",
"group_id": "source",
"group_name": "Source",
}],
"percent": 10,
"algorithm": "max_plus_percent",
"enabled": enabled,
}
def test_create_binding_runs_initial_sync(monkeypatch, client, db_session):
website, upstream = seed_rows(db_session)
calls = []
class FakeClient:
def __init__(self, **kwargs):
pass
def get_groups(self, endpoint):
return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]
def update_group_rate(self, endpoint, group_id, rate):
calls.append((endpoint, group_id, str(rate)))
monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeClient)
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient)
response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))
assert response.status_code == 201
assert response.json()["target_group_id"] == "target"
assert calls == [("/groups/{id}", "target", "2.2000")]
log = db_session.query(WebsiteSyncLog).one()
assert log.status == "success"
assert log.message == "同步成功"
assert log.old_rate == "1"
assert log.new_rate == "2.2"
def test_create_binding_skips_write_when_website_auto_sync_disabled(client, db_session):
website, upstream = seed_rows(db_session, auto_sync_enabled=False)
response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))
assert response.status_code == 201
assert db_session.query(WebsiteGroupBinding).count() == 1
log = db_session.query(WebsiteSyncLog).one()
assert log.status == "success"
assert log.message == "网站未启用自动同步,未写回"
assert log.old_rate is None
assert log.new_rate == "2.2"
def test_create_binding_skips_write_when_binding_disabled(client, db_session):
website, upstream = seed_rows(db_session)
response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id, enabled=False))
assert response.status_code == 201
log = db_session.query(WebsiteSyncLog).one()
assert log.status == "success"
assert log.message == "绑定未启用,未写回"
assert log.new_rate == "2.2"
def test_create_binding_keeps_binding_when_initial_sync_calculation_fails(client, db_session):
website, upstream = seed_rows(db_session)
db_session.query(UpstreamRateSnapshot).delete()
db_session.commit()
response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))
assert response.status_code == 201
assert db_session.query(WebsiteGroupBinding).count() == 1
log = db_session.query(WebsiteSyncLog).one()
assert log.status == "failed"
assert "没有可用的正数上游倍率" in log.message
assert log.new_rate is None
def test_update_binding_runs_sync_after_save(monkeypatch, client, db_session):
website, upstream = seed_rows(db_session)
binding = WebsiteGroupBinding(
website_id=website.id,
target_group_id="target",
target_group_name="Target group",
source_groups_json=json.dumps([{
"upstream_id": upstream.id,
"upstream_name": "Upstream",
"group_id": "source",
"group_name": "Source",
}], ensure_ascii=False),
percent="10",
algorithm="max_plus_percent",
enabled=True,
)
db_session.add(binding)
db_session.commit()
db_session.refresh(binding)
calls = []
class FakeClient:
def __init__(self, **kwargs):
pass
def get_groups(self, endpoint):
return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]
def update_group_rate(self, endpoint, group_id, rate):
calls.append((endpoint, group_id, str(rate)))
monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeClient)
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient)
response = client.put(
f"/api/group-bindings/{binding.id}",
json={
"target_group_name": "Target group",
"percent": 20,
"enabled": True,
},
)
assert response.status_code == 200
assert calls == [("/groups/{id}", "target", "2.4000")]
log = db_session.query(WebsiteSyncLog).one()
assert log.status == "success"
assert log.message == "同步成功"
assert log.new_rate == "2.4"
def test_update_binding_skips_write_when_disabled(monkeypatch, client, db_session):
website, upstream = seed_rows(db_session)
binding = WebsiteGroupBinding(
website_id=website.id,
target_group_id="target",
target_group_name="Target group",
source_groups_json=json.dumps([{
"upstream_id": upstream.id,
"upstream_name": "Upstream",
"group_id": "source",
"group_name": "Source",
}], ensure_ascii=False),
percent="10",
algorithm="max_plus_percent",
enabled=False,
)
db_session.add(binding)
db_session.commit()
db_session.refresh(binding)
class FakeClient:
def __init__(self, **kwargs):
raise AssertionError("should not write when binding is disabled")
monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeClient)
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient)
response = client.put(
f"/api/group-bindings/{binding.id}",
json={
"target_group_name": "Target group",
"percent": 20,
},
)
assert response.status_code == 200
log = db_session.query(WebsiteSyncLog).one()
assert log.status == "success"
assert log.message == "绑定未启用,未写回"
def test_create_binding_notifies_when_website_rate_changes(monkeypatch, client, db_session):
website, upstream = seed_rows(db_session)
webhook = WebhookConfig(
name="Notify",
type="generic",
url="http://notify.local/webhook",
enabled=True,
events_json=json.dumps(["website_rate_changed"]),
)
db_session.add(webhook)
db_session.commit()
sent_payloads = []
class FakeClient:
def __init__(self, **kwargs):
pass
def get_groups(self, endpoint):
return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]
def update_group_rate(self, endpoint, group_id, rate):
pass
def fake_send_generic(url, payload, timeout=15.0):
sent_payloads.append((url, payload))
return "ok"
monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeClient)
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient)
monkeypatch.setattr("app.services.webhook_service._send_generic", fake_send_generic)
response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))
assert response.status_code == 201
assert len(sent_payloads) == 1
_, payload = sent_payloads[0]
assert payload["event"] == "website_rate_changed"
assert payload["website"]["id"] == website.id
assert payload["target_group"]["old_rate"] == "1"
assert payload["target_group"]["new_rate"] == "2.2"
log = db_session.query(NotificationLog).one()
assert log.event_type == "website_rate_changed"
assert log.status == "success"
def test_create_binding_does_not_notify_when_website_rate_unchanged(monkeypatch, client, db_session):
website, upstream = seed_rows(db_session)
webhook = WebhookConfig(
name="Notify",
type="generic",
url="http://notify.local/webhook",
enabled=True,
events_json=json.dumps(["website_rate_changed"]),
)
db_session.add(webhook)
db_session.commit()
class FakeClient:
def __init__(self, **kwargs):
pass
def get_groups(self, endpoint):
return [{"id": "target", "name": "Target group", "rate_multiplier": "2.2"}]
def update_group_rate(self, endpoint, group_id, rate):
pass
def fake_send_generic(url, payload, timeout=15.0):
raise AssertionError("should not notify when target rate is unchanged")
monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeClient)
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient)
monkeypatch.setattr("app.services.webhook_service._send_generic", fake_send_generic)
response = client.post("/api/group-bindings", json=binding_payload(website.id, upstream.id))
assert response.status_code == 201
assert db_session.query(NotificationLog).count() == 0
+36
View File
@@ -0,0 +1,36 @@
import json
import sys
import logging
from app.services.upstream_client import UpstreamClient
logging.basicConfig(level=logging.DEBUG)
def main():
client = UpstreamClient(
base_url="http://170.106.100.210:55555",
api_prefix="",
auth_type="bearer",
auth_config={"token": ""}, # We don't have token, but /api/group/ in some new-api may be open, or fail with 401
timeout=10.0,
)
try:
groups = client.get_available_groups("/api/group/")
print("Groups:", groups)
except Exception as e:
print("Groups Error:", e)
try:
rates = client.get_group_rates("/api/option/?key=GroupRatio")
print("Rates:", rates)
except Exception as e:
print("Rates Error:", e)
try:
from app.services.upstream_client import _extract_rates_map, _unwrap_list
print("Unwrapped Groups:", _unwrap_list(groups))
print("Extracted Rates:", _extract_rates_map(rates))
except Exception as e:
pass
if __name__ == "__main__":
main()
+70
View File
@@ -0,0 +1,70 @@
from app.services.website_client import normalize_groups
def test_normalize_groups_unwraps_sub2api_paginated_response():
groups = normalize_groups({
"code": 0,
"message": "success",
"data": {
"items": [
{"id": "codex-free", "name": "codex-free", "rate_multiplier": 1},
{"id": "my-plus", "name": "my-plus", "rate_multiplier": "1.5"},
{"id": "deepseek", "name": "deepseek"},
],
"total": 3,
"page": 1,
},
})
assert [group["id"] for group in groups] == ["codex-free", "my-plus", "deepseek"]
assert groups[0]["rate_multiplier"] == "1"
assert groups[1]["rate_multiplier"] == "1.5"
assert groups[2]["rate_multiplier"] is None
def test_normalize_groups_unwraps_wrapped_list_response():
groups = normalize_groups({
"code": 0,
"message": "success",
"data": [
{"id": "default", "name": "Default", "rateMultiplier": "2.0"},
],
})
assert groups == [{
"id": "default",
"name": "Default",
"rate_multiplier": "2",
"raw": {"id": "default", "name": "Default", "rateMultiplier": "2.0"},
}]
def test_normalize_groups_unwraps_groups_key_response():
groups = normalize_groups({
"data": {
"groups": [
{"group_id": "vip", "group_name": "VIP", "ratio": "0.75"},
],
},
})
assert groups[0]["id"] == "vip"
assert groups[0]["name"] == "VIP"
assert groups[0]["rate_multiplier"] == "0.75"
def test_normalize_groups_keeps_string_list_compatibility():
groups = normalize_groups(["free", "paid"])
assert [group["id"] for group in groups] == ["free", "paid"]
assert groups[0]["raw"] == {"id": "free", "name": "free"}
def test_normalize_groups_keeps_plain_dict_mapping_compatibility():
groups = normalize_groups({
"free": {"id": "free", "name": "Free", "rate_multiplier": "1.00"},
"paid": {"id": "paid", "name": "Paid"},
})
assert [group["id"] for group in groups] == ["free", "paid"]
assert groups[0]["rate_multiplier"] == "1"