fix: address multiple code audit findings

- CORS: replace wildcard with explicit origin list from CORS_ORIGINS env
- Auth: enforce strong defaults, JWT blacklist (RevokedToken model), login rate limiting
- Auth: validate password length before bcrypt (72-byte limit)
- Scheduler: single-threaded worker to mitigate SQLite write contention
- Scheduler: graceful shutdown (wait=True)
- Snapshots: add prune_snapshots() with configurable retention count
- Storage: isolate localStorage keys via VITE_APP_KEY prefix
- Config: add cors_origins, login_rate_limit, snapshot_retention_count settings
This commit is contained in:
SmartUp Developer
2026-05-17 10:52:18 +08:00
parent a42ecf7bcc
commit ad16618406
25 changed files with 792 additions and 165 deletions
+10 -2
View File
@@ -4,9 +4,13 @@ from functools import lru_cache
class Settings(BaseSettings):
admin_email: str = "admin@smartup.local"
admin_password: str = "changeme"
jwt_secret: str = "change-me-in-production"
admin_password: str = ""
jwt_secret: str = ""
jwt_expire_hours: int = 24
cors_origins: str = "http://localhost:8899,http://127.0.0.1:8899"
login_rate_limit_attempts: int = 5
login_rate_limit_window_seconds: int = 300
snapshot_retention_count: int = 500
database_url: str = "sqlite:////app/data/app.db"
tz: str = "Asia/Shanghai"
# consecutive failures before upstream goes unhealthy
@@ -14,6 +18,10 @@ class Settings(BaseSettings):
browser_profiles_dir: str = "/app/data/browser-profiles"
browser_headless: bool = True
@property
def cors_origin_list(self) -> list[str]:
    """Return the allowed CORS origins parsed from ``cors_origins``.

    ``cors_origins`` is a comma-separated string. Each entry is trimmed of
    surrounding whitespace and trailing slashes; entries that end up empty
    are dropped.
    """
    origins: list[str] = []
    for raw in self.cors_origins.split(","):
        # A browser's Origin header is scheme://host[:port] with no trailing
        # slash, so an entry such as "http://localhost:8899/" would never
        # match. Normalise it instead of silently breaking CORS.
        origin = raw.strip().rstrip("/")
        if origin:
            origins.append(origin)
    return origins
model_config = {"env_file": ".env", "case_sensitive": False, "extra": "ignore"}
+1 -1
View File
@@ -26,7 +26,7 @@ def get_db():
def init_db():
"""Create all tables."""
# import models so SQLAlchemy registers them
from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website # noqa: F401
from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website, revoked_token # noqa: F401
Base.metadata.create_all(bind=engine)
_migrate_custom_pages()
+17 -5
View File
@@ -12,7 +12,7 @@ from app.config import get_settings
from app.database import init_db
from app.models.admin_user import AdminUser
from app.database import SessionLocal
from app.utils.auth import hash_password
from app.utils.auth import hash_password, validate_password_supported
from app.services.scheduler import start_scheduler, stop_scheduler
from app.routers import auth, upstreams, webhooks, logs, custom_pages, browser_sessions, websites
from app.services.browser_session_service import browser_sessions as browser_session_service
@@ -21,11 +21,21 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name
logger = logging.getLogger(__name__)
def _validate_runtime_settings() -> None:
    """Fail fast at startup when security-critical settings are missing or unsafe.

    Raises:
        RuntimeError: if the admin password or JWT secret is unset or still a
            known placeholder, if no CORS origin is configured, or if the
            CORS origin list contains the ``*`` wildcard.
        ValueError: propagated from ``validate_password_supported`` when the
            admin password is too long to be hashed.
    """
    settings = get_settings()
    if not settings.admin_password:
        raise RuntimeError("ADMIN_PASSWORD must be set")
    if settings.admin_password in {"changeme", "changeme123"}:
        raise RuntimeError("ADMIN_PASSWORD must not use the default placeholder")
    if not settings.jwt_secret or settings.jwt_secret == "change-me-in-production":
        raise RuntimeError("JWT_SECRET must be set to a non-default value")
    origins = settings.cors_origin_list
    if not origins:
        raise RuntimeError("CORS_ORIGINS must include at least one explicit origin")
    # The CORS middleware is installed with allow_credentials=True; accepting
    # "*" here would bring back the wildcard this check exists to prevent.
    if "*" in origins:
        raise RuntimeError("CORS_ORIGINS must not contain the '*' wildcard")
    validate_password_supported(settings.admin_password)
def _init_admin() -> None:
settings = get_settings()
db = SessionLocal()
try:
exists = db.query(AdminUser).filter(AdminUser.email == settings.admin_email).first()
@@ -45,6 +55,7 @@ def _init_admin() -> None:
@asynccontextmanager
async def lifespan(app: FastAPI):
_validate_runtime_settings()
init_db()
_init_admin()
start_scheduler()
@@ -63,9 +74,10 @@ app = FastAPI(
openapi_url="/api/openapi.json",
)
settings = get_settings()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_origins=settings.cors_origin_list,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
+14
View File
@@ -0,0 +1,14 @@
from datetime import datetime, timezone
from sqlalchemy import DateTime, Integer, String
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class RevokedToken(Base):
    """Row recording a JWT identifier that must no longer be accepted."""

    __tablename__ = "revoked_tokens"

    # Surrogate primary key.
    id: Mapped[int] = mapped_column(
        Integer,
        primary_key=True,
        index=True,
    )
    # The token's ``jti`` claim; unique so a token is recorded at most once.
    jti: Mapped[str] = mapped_column(
        String(64),
        unique=True,
        index=True,
        nullable=False,
    )
    # The token's expiry time; indexed because expired rows are purged by
    # filtering on this column.
    expires_at: Mapped[datetime] = mapped_column(
        DateTime,
        index=True,
        nullable=False,
    )
    # Time the row was created, defaulting to the current UTC time.
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
    )
+62 -6
View File
@@ -1,18 +1,63 @@
from fastapi import APIRouter, Depends, HTTPException, status
from datetime import datetime, timezone
from threading import Lock
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from app.config import get_settings
from app.database import get_db
from app.models.admin_user import AdminUser
from app.models.revoked_token import RevokedToken
from app.schemas.auth import LoginRequest, TokenResponse, UserInfo
from app.utils.auth import verify_password, create_access_token, get_current_user
from app.utils.auth import bearer_scheme, create_access_token, decode_token_payload, get_current_user, verify_password
router = APIRouter(prefix="/api/auth", tags=["auth"])
_login_attempts: dict[tuple[str, str], list[float]] = {}
_login_attempts_lock = Lock()
def _login_key(request: Request, email: str) -> tuple[str, str]:
    """Build the rate-limit key ``(client ip, lower-cased email)`` for a login.

    The IP is the first entry of the ``X-Forwarded-For`` header when that
    header is non-empty, otherwise the connection's peer host, otherwise the
    literal ``"unknown"``.
    """
    # NOTE(review): X-Forwarded-For is supplied by the client unless a trusted
    # proxy overwrites it - confirm the deployment sets/strips this header,
    # otherwise the key can be varied per request.
    header_value = request.headers.get("x-forwarded-for", "")
    first_hop = header_value.partition(",")[0].strip()
    if first_hop:
        client_ip = first_hop
    elif request.client:
        client_ip = request.client.host
    else:
        client_ip = "unknown"
    return client_ip, email.lower()
def _check_login_limit(key: tuple[str, str]) -> None:
    """Reject the login attempt when ``key`` has too many recent failures.

    Failure timestamps older than ``login_rate_limit_window_seconds`` are
    discarded first. Keys left with no recent failures are removed from the
    table rather than stored as empty lists, and once the table grows large
    the expired entries of other keys are swept as well, so the table cannot
    grow without bound.

    Args:
        key: Rate-limit key as produced by ``_login_key``.

    Raises:
        HTTPException: 429 when the failures recorded for ``key`` inside the
            window have reached ``login_rate_limit_attempts``.
    """
    settings = get_settings()
    cutoff = datetime.now(timezone.utc).timestamp() - settings.login_rate_limit_window_seconds
    # Table size above which every key is checked for expiry; keeps the
    # common path O(1) while still bounding memory.
    sweep_threshold = 1024
    with _login_attempts_lock:
        recent = [stamp for stamp in _login_attempts.get(key, ()) if stamp >= cutoff]
        if recent:
            _login_attempts[key] = recent
        else:
            _login_attempts.pop(key, None)
        if len(_login_attempts) > sweep_threshold:
            expired = [
                other
                for other, stamps in _login_attempts.items()
                if not any(stamp >= cutoff for stamp in stamps)
            ]
            for other in expired:
                del _login_attempts[other]
        blocked = len(recent) >= settings.login_rate_limit_attempts
    if blocked:
        raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="登录尝试过多,请稍后再试")
def _record_login_failure(key: tuple[str, str]) -> None:
    """Append the current UTC timestamp to the failure history of ``key``."""
    with _login_attempts_lock:
        failed_at = datetime.now(timezone.utc).timestamp()
        history = _login_attempts.get(key)
        if history is None:
            # First failure for this key: start a new history list.
            history = []
            _login_attempts[key] = history
        history.append(failed_at)
def _clear_login_failures(key: tuple[str, str]) -> None:
    """Forget every recorded failure for ``key``; a no-op if none exist."""
    with _login_attempts_lock:
        if key in _login_attempts:
            del _login_attempts[key]
@router.post("/login", response_model=TokenResponse)
def login(req: LoginRequest, request: Request, db: Session = Depends(get_db)):
    """Authenticate an admin user and return a fresh access token.

    The attempt is rate limited per key from ``_login_key``. A failed attempt
    is recorded against that key; a successful one clears its history.

    Raises:
        HTTPException: 429 from ``_check_login_limit`` when the key is
            currently limited, 401 when the email or password is wrong.
    """
    key = _login_key(request, req.email)
    _check_login_limit(key)
    user = db.query(AdminUser).filter(AdminUser.email == req.email).first()
    try:
        password_ok = bool(user and verify_password(req.password, user.password_hash))
    except ValueError:
        # verify_password raises ValueError for passwords it cannot check
        # (e.g. over the bcrypt length limit); treat that as a failed login.
        password_ok = False
    if not password_ok:
        _record_login_failure(key)
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="邮箱或密码错误")
    _clear_login_failures(key)
    token = create_access_token(user.email)
    return TokenResponse(access_token=token)
@@ -23,6 +68,17 @@ def me(current_user: AdminUser = Depends(get_current_user)):
@router.post("/logout")
def logout():
# JWT is stateless — client discards token
def logout(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
    db: Session = Depends(get_db),
    _current_user: AdminUser = Depends(get_current_user),
):
    """Revoke the caller's bearer token and purge expired revocations.

    The token's ``jti`` is stored with its expiry unless it is already
    recorded. Rows whose expiry lies in the past are then deleted and the
    session is committed.
    """
    claims = decode_token_payload(credentials.credentials) or {}
    jti = claims.get("jti")
    exp = claims.get("exp")
    if jti and exp:
        already_revoked = db.query(RevokedToken).filter(RevokedToken.jti == jti).first()
        if not already_revoked:
            expires_at = datetime.fromtimestamp(exp, tz=timezone.utc)
            db.add(RevokedToken(jti=jti, expires_at=expires_at))
    # Expired tokens are rejected by signature validation anyway, so their
    # blacklist rows are no longer needed.
    now = datetime.now(timezone.utc)
    db.query(RevokedToken).filter(RevokedToken.expires_at < now).delete(synchronize_session=False)
    db.commit()
    return {"message": "logged out"}
+34 -10
View File
@@ -41,6 +41,10 @@ class BrowserSessionResponse(BaseModel):
title: str
class BrowserSelectionResponse(BaseModel):
text: str
class BrowserEvent(BaseModel):
type: Literal["click", "dblclick", "mousemove", "mousedown", "mouseup", "type", "key", "scroll", "reload", "back", "forward", "resize"]
x: Optional[float] = None
@@ -119,6 +123,14 @@ async def send_event(session_id: str, body: BrowserEvent, _=Depends(get_current_
raise _error_from_browser(exc)
@router.get("/{session_id}/selection", response_model=BrowserSelectionResponse)
async def get_selection(session_id: str, _=Depends(get_current_user)):
    """Return the text currently selected in the given browser session."""
    try:
        selection = await browser_sessions.selected_text(session_id)
        response = BrowserSelectionResponse(text=selection)
    except Exception as exc:
        # Any failure is converted by _error_from_browser before being raised.
        raise _error_from_browser(exc)
    return response
@router.delete("/{session_id}", status_code=204)
async def close_session(session_id: str, _=Depends(get_current_user)):
await browser_sessions.close(session_id)
@@ -126,9 +138,12 @@ async def close_session(session_id: str, _=Depends(get_current_user)):
# ——— WebSocket stream ———
# Frame interval & diff detection
_WS_MIN_INTERVAL = 0.05 # 50 ms floor (≈20 fps max)
_WS_IDLE_INTERVAL = 0.15 # 150 ms when nothing changed recently
_WS_ACTIVE_INTERVAL = 0.08 # 80 ms right after a user event
_WS_MIN_INTERVAL = 0.10
_WS_IDLE_INTERVAL = 0.35
_WS_ACTIVE_INTERVAL = 0.12
_WS_BACKOFF_INTERVAL = 0.60
_WS_DEEP_IDLE_INTERVAL = 1.00
_WS_ACTIVE_WINDOW = 1.25
async def _ws_authenticate(token: Optional[str]) -> bool:
@@ -163,10 +178,11 @@ async def session_ws(
# Track when a user event arrived so we can temporarily speed up
last_event_at: float = 0.0
last_frame_hash: str = ""
unchanged_count = 0
# Task: receive events from client
async def receive_loop():
nonlocal last_event_at
nonlocal last_event_at, unchanged_count
try:
while True:
raw = await websocket.receive_text()
@@ -179,8 +195,9 @@ async def session_ws(
continue
payload: dict[str, Any] = {k: v for k, v in msg.items() if k != "type"}
try:
await browser_sessions.event(session_id, msg_type, payload)
await browser_sessions.event(session_id, msg_type, payload, include_state=False)
last_event_at = asyncio.get_event_loop().time()
unchanged_count = 0
except Exception as exc:
logger.warning("ws event error: %s", exc)
try:
@@ -194,17 +211,22 @@ async def session_ws(
# Task: push screenshots
async def push_loop():
nonlocal last_frame_hash
nonlocal last_frame_hash, unchanged_count
try:
while True:
now = asyncio.get_event_loop().time()
# Faster cadence right after a user interaction
interval = _WS_ACTIVE_INTERVAL if (now - last_event_at) < 1.0 else _WS_IDLE_INTERVAL
if (now - last_event_at) < _WS_ACTIVE_WINDOW:
interval = _WS_ACTIVE_INTERVAL
elif unchanged_count >= 9:
interval = _WS_DEEP_IDLE_INTERVAL
elif unchanged_count >= 3:
interval = _WS_BACKOFF_INTERVAL
else:
interval = _WS_IDLE_INTERVAL
try:
frame = await browser_sessions.screenshot(session_id)
except KeyError:
# Session gone
await websocket.send_json({"error": "session_not_found"})
break
except Exception as exc:
@@ -212,14 +234,16 @@ async def session_ws(
await asyncio.sleep(interval)
continue
# Only push if content changed
frame_hash = hashlib.md5(frame).hexdigest()
if frame_hash != last_frame_hash:
last_frame_hash = frame_hash
unchanged_count = 0
try:
await websocket.send_bytes(frame)
except Exception:
break
else:
unchanged_count += 1
await asyncio.sleep(max(_WS_MIN_INTERVAL, interval))
except (WebSocketDisconnect, asyncio.CancelledError):
+10
View File
@@ -154,8 +154,18 @@ def test_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current
try:
client.login()
groups = client.get_available_groups(u.groups_endpoint)
u.last_status = "healthy"
u.last_error = None
u.last_checked_at = datetime.now(timezone.utc)
u.consecutive_failures = 0
db.commit()
return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组")
except Exception as exc:
u.last_status = "unhealthy"
u.last_error = str(exc)
u.last_checked_at = datetime.now(timezone.utc)
u.consecutive_failures = (u.consecutive_failures or 0) + 1
db.commit()
return TestResult(success=False, message="连接失败", detail=str(exc))
+2
View File
@@ -176,6 +176,8 @@ def delete_website(wid: int, db: Session = Depends(get_db), _=Depends(get_curren
row = db.query(Website).filter(Website.id == wid).first()
if not row:
raise HTTPException(404, "website not found")
db.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == wid).delete(synchronize_session=False)
db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.website_id == wid).delete(synchronize_session=False)
db.delete(row)
db.commit()
@@ -98,9 +98,16 @@ class BrowserSessionService:
session = self._get(session_id)
async with session.lock:
self._ensure_open(session)
return await session.page.screenshot(type="jpeg", quality=78, full_page=False)
return await session.page.screenshot(type="jpeg", quality=65, full_page=False)
async def event(self, session_id: str, event_type: str, payload: dict[str, Any]) -> dict[str, Any]:
async def event(
self,
session_id: str,
event_type: str,
payload: dict[str, Any],
*,
include_state: bool = True,
) -> dict[str, Any] | None:
session = self._get(session_id)
async with session.lock:
self._ensure_open(session)
@@ -141,8 +148,17 @@ class BrowserSessionService:
await page.set_viewport_size({"width": width, "height": height})
else:
raise ValueError("Unsupported browser event")
if not include_state:
return None
return await self._session_state(session)
async def selected_text(self, session_id: str) -> str:
    """Return the text selected in the session's page, or ``""`` if none.

    The page is queried while holding the session lock, after checking the
    session is still open.
    """
    session = self._get(session_id)
    async with session.lock:
        self._ensure_open(session)
        selection = await session.page.evaluate("() => window.getSelection()?.toString() || ''")
    if not selection:
        return ""
    return str(selection)
async def close(self, session_id: str) -> None:
session = self._discard_session(session_id)
if not session:
+5 -3
View File
@@ -5,6 +5,7 @@ import json
import logging
from datetime import datetime, timezone
from apscheduler.executors.pool import ThreadPoolExecutor
from apscheduler.schedulers.background import BackgroundScheduler
from sqlalchemy.orm import Session
@@ -12,14 +13,14 @@ from app.database import SessionLocal
from app.models.upstream import Upstream
from app.models.snapshot import UpstreamRateSnapshot
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot
from app.services.snapshot_service import diff_snapshots
from app.services.snapshot_service import diff_snapshots, prune_snapshots
from app.services import webhook_service
from app.services import website_sync
from app.config import get_settings
logger = logging.getLogger(__name__)
_scheduler = BackgroundScheduler(timezone="UTC")
_scheduler = BackgroundScheduler(timezone="UTC", executors={"default": ThreadPoolExecutor(max_workers=1)})
def get_scheduler() -> BackgroundScheduler:
@@ -95,6 +96,7 @@ def _check_upstream(upstream_id: int) -> None:
upstream.last_checked_at = datetime.now(timezone.utc)
upstream.last_error = None
upstream.consecutive_failures = 0
prune_snapshots(db, upstream_id, settings.snapshot_retention_count)
db.commit()
if was_unhealthy:
@@ -155,4 +157,4 @@ def start_scheduler() -> None:
def stop_scheduler() -> None:
    """Stop the background scheduler if it is running.

    Shuts down with ``wait=True`` so jobs that are currently executing are
    allowed to finish before this returns.
    """
    if _scheduler.running:
        # A single, waiting shutdown: calling shutdown() twice would fail on
        # the second call because the scheduler is no longer running.
        _scheduler.shutdown(wait=True)
+21
View File
@@ -1,6 +1,10 @@
"""Snapshot diff logic."""
from typing import Any, Optional
from sqlalchemy.orm import Session
from app.models.snapshot import UpstreamRateSnapshot
def diff_snapshots(
previous: Optional[dict[str, Any]],
@@ -37,3 +41,20 @@ def diff_snapshots(
"new_rate": None,
})
return changes
def prune_snapshots(db: Session, upstream_id: int, keep: int) -> None:
    """Delete all but the newest ``keep`` snapshots of one upstream.

    Snapshots are ranked by ``captured_at`` (newest first), ties broken by
    ``id``. The deletes are issued on ``db`` but not committed here; the
    caller owns the transaction.

    Args:
        db: Session used for the lookup and the deletes.
        upstream_id: Upstream whose snapshots are pruned.
        keep: Number of newest snapshots to retain; pruning is skipped
            entirely when this is zero or negative.
    """
    if keep <= 0:
        return
    stale_rows = (
        db.query(UpstreamRateSnapshot.id)
        .filter(UpstreamRateSnapshot.upstream_id == upstream_id)
        .order_by(UpstreamRateSnapshot.captured_at.desc(), UpstreamRateSnapshot.id.desc())
        .offset(keep)
        .all()
    )
    stale_ids = [row_id for (row_id,) in stale_rows]
    # Delete in batches: one IN (...) over every stale id can exceed SQLite's
    # bound-parameter limit the first time a large history is pruned.
    batch_size = 500
    for start in range(0, len(stale_ids), batch_size):
        batch = stale_ids[start:start + batch_size]
        db.query(UpstreamRateSnapshot).filter(UpstreamRateSnapshot.id.in_(batch)).delete(synchronize_session=False)
+34 -12
View File
@@ -1,25 +1,37 @@
from datetime import datetime, timedelta, timezone
from typing import Optional
from uuid import uuid4
from jose import JWTError, jwt
import bcrypt
from fastapi import Depends, HTTPException, Query, Request, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from app.config import get_settings
from app.database import get_db
from app.models.admin_user import AdminUser
from app.models.revoked_token import RevokedToken
ALGORITHM = "HS256"
BCRYPT_MAX_PASSWORD_BYTES = 72
bearer_scheme = HTTPBearer(auto_error=False)
def validate_password_supported(password: str) -> None:
    """Ensure ``password`` fits within ``BCRYPT_MAX_PASSWORD_BYTES``.

    Raises:
        ValueError: if the UTF-8 encoding of ``password`` is longer than
            ``BCRYPT_MAX_PASSWORD_BYTES`` bytes.
    """
    encoded_length = len(password.encode("utf-8"))
    if encoded_length <= BCRYPT_MAX_PASSWORD_BYTES:
        return
    raise ValueError("password must be at most 72 bytes when UTF-8 encoded")
def hash_password(password: str) -> str:
    """Return a salted bcrypt hash of ``password`` as a UTF-8 string.

    Raises:
        ValueError: from ``validate_password_supported`` when the password is
            too long; it is rejected rather than silently truncated.
    """
    validate_password_supported(password)
    pw = password.encode("utf-8")
    return bcrypt.hashpw(pw, bcrypt.gensalt()).decode("utf-8")
def verify_password(plain: str, hashed: str) -> bool:
    """Check ``plain`` against a bcrypt hash produced by ``hash_password``.

    Raises:
        ValueError: from ``validate_password_supported`` when ``plain`` is too
            long; callers that treat this as a failed check must catch it.
    """
    validate_password_supported(plain)
    pw = plain.encode("utf-8")
    return bcrypt.checkpw(pw, hashed.encode("utf-8"))
@@ -27,19 +39,29 @@ def create_access_token(email: str, expires_hours: Optional[int] = None) -> str:
settings = get_settings()
hours = expires_hours or settings.jwt_expire_hours
expire = datetime.now(timezone.utc) + timedelta(hours=hours)
data = {"sub": email, "exp": expire}
data = {"sub": email, "exp": expire, "jti": uuid4().hex}
return jwt.encode(data, settings.jwt_secret, algorithm=ALGORITHM)
def decode_token_payload(token: str) -> Optional[dict]:
    """Decode ``token`` with the configured JWT secret and return its claims.

    Returns:
        The full claims dict (callers read ``sub``, ``jti`` and ``exp`` from
        it), or ``None`` when ``jwt.decode`` raises ``JWTError``.
    """
    settings = get_settings()
    try:
        return jwt.decode(token, settings.jwt_secret, algorithms=[ALGORITHM])
    except JWTError:
        return None
def decode_token(token: str) -> Optional[str]:
    """Return the ``sub`` claim of ``token``, or ``None`` if it cannot be decoded."""
    payload = decode_token_payload(token)
    if not payload:
        return None
    return payload.get("sub")
def _is_revoked(db: Session, jti: str | None) -> bool:
    """Tell whether a token with identifier ``jti`` must be rejected.

    A missing or empty ``jti`` counts as revoked; otherwise the token is
    revoked exactly when a ``RevokedToken`` row with that ``jti`` exists.
    """
    if jti:
        match = db.query(RevokedToken).filter(RevokedToken.jti == jti).first()
        return match is not None
    return True
def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
db: Session = Depends(get_db),
@@ -47,10 +69,10 @@ def get_current_user(
token = credentials.credentials if credentials else None
if not token:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
email = decode_token(token)
if not email:
payload = decode_token_payload(token)
if not payload or _is_revoked(db, payload.get("jti")):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
user = db.query(AdminUser).filter(AdminUser.email == email).first()
user = db.query(AdminUser).filter(AdminUser.email == payload.get("sub")).first()
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
return user
@@ -65,10 +87,10 @@ def get_user_from_token_param(
raw = token or (credentials.credentials if credentials else None)
if not raw:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
email = decode_token(raw)
if not email:
payload = decode_token_payload(raw)
if not payload or _is_revoked(db, payload.get("jti")):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
user = db.query(AdminUser).filter(AdminUser.email == email).first()
user = db.query(AdminUser).filter(AdminUser.email == payload.get("sub")).first()
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
return user