fix: address multiple code audit findings
- CORS: replace wildcard with explicit origin list from CORS_ORIGINS env - Auth: enforce strong defaults, JWT blacklist (RevokedToken model), login rate limiting - Auth: validate password length before bcrypt (72-byte limit) - Scheduler: single-threaded worker to mitigate SQLite write contention - Scheduler: graceful shutdown (wait=True) - Snapshots: add prune_snapshots() with configurable retention count - Storage: isolate localStorage keys via VITE_APP_KEY prefix - Config: add cors_origins, login_rate_limit, snapshot_retention_count settings
This commit is contained in:
+10
-2
@@ -4,9 +4,13 @@ from functools import lru_cache
|
||||
|
||||
class Settings(BaseSettings):
|
||||
admin_email: str = "admin@smartup.local"
|
||||
admin_password: str = "changeme"
|
||||
jwt_secret: str = "change-me-in-production"
|
||||
admin_password: str = ""
|
||||
jwt_secret: str = ""
|
||||
jwt_expire_hours: int = 24
|
||||
cors_origins: str = "http://localhost:8899,http://127.0.0.1:8899"
|
||||
login_rate_limit_attempts: int = 5
|
||||
login_rate_limit_window_seconds: int = 300
|
||||
snapshot_retention_count: int = 500
|
||||
database_url: str = "sqlite:////app/data/app.db"
|
||||
tz: str = "Asia/Shanghai"
|
||||
# consecutive failures before upstream goes unhealthy
|
||||
@@ -14,6 +18,10 @@ class Settings(BaseSettings):
|
||||
browser_profiles_dir: str = "/app/data/browser-profiles"
|
||||
browser_headless: bool = True
|
||||
|
||||
@property
|
||||
def cors_origin_list(self) -> list[str]:
|
||||
return [item.strip() for item in self.cors_origins.split(",") if item.strip()]
|
||||
|
||||
model_config = {"env_file": ".env", "case_sensitive": False, "extra": "ignore"}
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def get_db():
|
||||
def init_db():
|
||||
"""Create all tables."""
|
||||
# import models so SQLAlchemy registers them
|
||||
from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website # noqa: F401
|
||||
from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website, revoked_token # noqa: F401
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_migrate_custom_pages()
|
||||
|
||||
|
||||
+17
-5
@@ -12,7 +12,7 @@ from app.config import get_settings
|
||||
from app.database import init_db
|
||||
from app.models.admin_user import AdminUser
|
||||
from app.database import SessionLocal
|
||||
from app.utils.auth import hash_password
|
||||
from app.utils.auth import hash_password, validate_password_supported
|
||||
from app.services.scheduler import start_scheduler, stop_scheduler
|
||||
from app.routers import auth, upstreams, webhooks, logs, custom_pages, browser_sessions, websites
|
||||
from app.services.browser_session_service import browser_sessions as browser_session_service
|
||||
@@ -21,11 +21,21 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _init_admin() -> None:
|
||||
def _validate_runtime_settings() -> None:
|
||||
settings = get_settings()
|
||||
if not settings.admin_password:
|
||||
logger.warning("ADMIN_PASSWORD not set, skip admin init")
|
||||
return
|
||||
raise RuntimeError("ADMIN_PASSWORD must be set")
|
||||
if settings.admin_password in {"changeme", "changeme123"}:
|
||||
raise RuntimeError("ADMIN_PASSWORD must not use the default placeholder")
|
||||
if not settings.jwt_secret or settings.jwt_secret == "change-me-in-production":
|
||||
raise RuntimeError("JWT_SECRET must be set to a non-default value")
|
||||
if not settings.cors_origin_list:
|
||||
raise RuntimeError("CORS_ORIGINS must include at least one explicit origin")
|
||||
validate_password_supported(settings.admin_password)
|
||||
|
||||
|
||||
def _init_admin() -> None:
|
||||
settings = get_settings()
|
||||
db = SessionLocal()
|
||||
try:
|
||||
exists = db.query(AdminUser).filter(AdminUser.email == settings.admin_email).first()
|
||||
@@ -45,6 +55,7 @@ def _init_admin() -> None:
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
_validate_runtime_settings()
|
||||
init_db()
|
||||
_init_admin()
|
||||
start_scheduler()
|
||||
@@ -63,9 +74,10 @@ app = FastAPI(
|
||||
openapi_url="/api/openapi.json",
|
||||
)
|
||||
|
||||
settings = get_settings()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=settings.cors_origin_list,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import DateTime, Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class RevokedToken(Base):
|
||||
__tablename__ = "revoked_tokens"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
||||
jti: Mapped[str] = mapped_column(String(64), unique=True, index=True, nullable=False)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime, index=True, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
@@ -1,18 +1,63 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from datetime import datetime, timezone
|
||||
from threading import Lock
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import get_settings
|
||||
from app.database import get_db
|
||||
from app.models.admin_user import AdminUser
|
||||
from app.models.revoked_token import RevokedToken
|
||||
from app.schemas.auth import LoginRequest, TokenResponse, UserInfo
|
||||
from app.utils.auth import verify_password, create_access_token, get_current_user
|
||||
from app.utils.auth import bearer_scheme, create_access_token, decode_token_payload, get_current_user, verify_password
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
_login_attempts: dict[tuple[str, str], list[float]] = {}
|
||||
_login_attempts_lock = Lock()
|
||||
|
||||
|
||||
def _login_key(request: Request, email: str) -> tuple[str, str]:
|
||||
forwarded = request.headers.get("x-forwarded-for", "").split(",", 1)[0].strip()
|
||||
ip = forwarded or (request.client.host if request.client else "unknown")
|
||||
return ip, email.lower()
|
||||
|
||||
|
||||
def _check_login_limit(key: tuple[str, str]) -> None:
|
||||
settings = get_settings()
|
||||
now = datetime.now(timezone.utc).timestamp()
|
||||
cutoff = now - settings.login_rate_limit_window_seconds
|
||||
with _login_attempts_lock:
|
||||
attempts = [item for item in _login_attempts.get(key, []) if item >= cutoff]
|
||||
if len(attempts) >= settings.login_rate_limit_attempts:
|
||||
_login_attempts[key] = attempts
|
||||
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="登录尝试过多,请稍后再试")
|
||||
_login_attempts[key] = attempts
|
||||
|
||||
|
||||
def _record_login_failure(key: tuple[str, str]) -> None:
|
||||
with _login_attempts_lock:
|
||||
_login_attempts.setdefault(key, []).append(datetime.now(timezone.utc).timestamp())
|
||||
|
||||
|
||||
def _clear_login_failures(key: tuple[str, str]) -> None:
|
||||
with _login_attempts_lock:
|
||||
_login_attempts.pop(key, None)
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
def login(req: LoginRequest, db: Session = Depends(get_db)):
|
||||
def login(req: LoginRequest, request: Request, db: Session = Depends(get_db)):
|
||||
key = _login_key(request, req.email)
|
||||
_check_login_limit(key)
|
||||
user = db.query(AdminUser).filter(AdminUser.email == req.email).first()
|
||||
if not user or not verify_password(req.password, user.password_hash):
|
||||
try:
|
||||
password_ok = bool(user and verify_password(req.password, user.password_hash))
|
||||
except ValueError:
|
||||
password_ok = False
|
||||
if not password_ok:
|
||||
_record_login_failure(key)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="邮箱或密码错误")
|
||||
_clear_login_failures(key)
|
||||
token = create_access_token(user.email)
|
||||
return TokenResponse(access_token=token)
|
||||
|
||||
@@ -23,6 +68,17 @@ def me(current_user: AdminUser = Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
def logout():
|
||||
# JWT is stateless — client discards token
|
||||
def logout(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
|
||||
db: Session = Depends(get_db),
|
||||
_current_user: AdminUser = Depends(get_current_user),
|
||||
):
|
||||
payload = decode_token_payload(credentials.credentials)
|
||||
jti = payload.get("jti") if payload else None
|
||||
exp = payload.get("exp") if payload else None
|
||||
if jti and exp and not db.query(RevokedToken).filter(RevokedToken.jti == jti).first():
|
||||
expires_at = datetime.fromtimestamp(exp, tz=timezone.utc)
|
||||
db.add(RevokedToken(jti=jti, expires_at=expires_at))
|
||||
db.query(RevokedToken).filter(RevokedToken.expires_at < datetime.now(timezone.utc)).delete(synchronize_session=False)
|
||||
db.commit()
|
||||
return {"message": "logged out"}
|
||||
|
||||
@@ -41,6 +41,10 @@ class BrowserSessionResponse(BaseModel):
|
||||
title: str
|
||||
|
||||
|
||||
class BrowserSelectionResponse(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class BrowserEvent(BaseModel):
|
||||
type: Literal["click", "dblclick", "mousemove", "mousedown", "mouseup", "type", "key", "scroll", "reload", "back", "forward", "resize"]
|
||||
x: Optional[float] = None
|
||||
@@ -119,6 +123,14 @@ async def send_event(session_id: str, body: BrowserEvent, _=Depends(get_current_
|
||||
raise _error_from_browser(exc)
|
||||
|
||||
|
||||
@router.get("/{session_id}/selection", response_model=BrowserSelectionResponse)
|
||||
async def get_selection(session_id: str, _=Depends(get_current_user)):
|
||||
try:
|
||||
return BrowserSelectionResponse(text=await browser_sessions.selected_text(session_id))
|
||||
except Exception as exc:
|
||||
raise _error_from_browser(exc)
|
||||
|
||||
|
||||
@router.delete("/{session_id}", status_code=204)
|
||||
async def close_session(session_id: str, _=Depends(get_current_user)):
|
||||
await browser_sessions.close(session_id)
|
||||
@@ -126,9 +138,12 @@ async def close_session(session_id: str, _=Depends(get_current_user)):
|
||||
|
||||
# ——— WebSocket stream ———
|
||||
# Frame interval & diff detection
|
||||
_WS_MIN_INTERVAL = 0.05 # 50 ms floor (≈20 fps max)
|
||||
_WS_IDLE_INTERVAL = 0.15 # 150 ms when nothing changed recently
|
||||
_WS_ACTIVE_INTERVAL = 0.08 # 80 ms right after a user event
|
||||
_WS_MIN_INTERVAL = 0.10
|
||||
_WS_IDLE_INTERVAL = 0.35
|
||||
_WS_ACTIVE_INTERVAL = 0.12
|
||||
_WS_BACKOFF_INTERVAL = 0.60
|
||||
_WS_DEEP_IDLE_INTERVAL = 1.00
|
||||
_WS_ACTIVE_WINDOW = 1.25
|
||||
|
||||
|
||||
async def _ws_authenticate(token: Optional[str]) -> bool:
|
||||
@@ -163,10 +178,11 @@ async def session_ws(
|
||||
# Track when a user event arrived so we can temporarily speed up
|
||||
last_event_at: float = 0.0
|
||||
last_frame_hash: str = ""
|
||||
unchanged_count = 0
|
||||
|
||||
# Task: receive events from client
|
||||
async def receive_loop():
|
||||
nonlocal last_event_at
|
||||
nonlocal last_event_at, unchanged_count
|
||||
try:
|
||||
while True:
|
||||
raw = await websocket.receive_text()
|
||||
@@ -179,8 +195,9 @@ async def session_ws(
|
||||
continue
|
||||
payload: dict[str, Any] = {k: v for k, v in msg.items() if k != "type"}
|
||||
try:
|
||||
await browser_sessions.event(session_id, msg_type, payload)
|
||||
await browser_sessions.event(session_id, msg_type, payload, include_state=False)
|
||||
last_event_at = asyncio.get_event_loop().time()
|
||||
unchanged_count = 0
|
||||
except Exception as exc:
|
||||
logger.warning("ws event error: %s", exc)
|
||||
try:
|
||||
@@ -194,17 +211,22 @@ async def session_ws(
|
||||
|
||||
# Task: push screenshots
|
||||
async def push_loop():
|
||||
nonlocal last_frame_hash
|
||||
nonlocal last_frame_hash, unchanged_count
|
||||
try:
|
||||
while True:
|
||||
now = asyncio.get_event_loop().time()
|
||||
# Faster cadence right after a user interaction
|
||||
interval = _WS_ACTIVE_INTERVAL if (now - last_event_at) < 1.0 else _WS_IDLE_INTERVAL
|
||||
if (now - last_event_at) < _WS_ACTIVE_WINDOW:
|
||||
interval = _WS_ACTIVE_INTERVAL
|
||||
elif unchanged_count >= 9:
|
||||
interval = _WS_DEEP_IDLE_INTERVAL
|
||||
elif unchanged_count >= 3:
|
||||
interval = _WS_BACKOFF_INTERVAL
|
||||
else:
|
||||
interval = _WS_IDLE_INTERVAL
|
||||
|
||||
try:
|
||||
frame = await browser_sessions.screenshot(session_id)
|
||||
except KeyError:
|
||||
# Session gone
|
||||
await websocket.send_json({"error": "session_not_found"})
|
||||
break
|
||||
except Exception as exc:
|
||||
@@ -212,14 +234,16 @@ async def session_ws(
|
||||
await asyncio.sleep(interval)
|
||||
continue
|
||||
|
||||
# Only push if content changed
|
||||
frame_hash = hashlib.md5(frame).hexdigest()
|
||||
if frame_hash != last_frame_hash:
|
||||
last_frame_hash = frame_hash
|
||||
unchanged_count = 0
|
||||
try:
|
||||
await websocket.send_bytes(frame)
|
||||
except Exception:
|
||||
break
|
||||
else:
|
||||
unchanged_count += 1
|
||||
|
||||
await asyncio.sleep(max(_WS_MIN_INTERVAL, interval))
|
||||
except (WebSocketDisconnect, asyncio.CancelledError):
|
||||
|
||||
@@ -154,8 +154,18 @@ def test_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current
|
||||
try:
|
||||
client.login()
|
||||
groups = client.get_available_groups(u.groups_endpoint)
|
||||
u.last_status = "healthy"
|
||||
u.last_error = None
|
||||
u.last_checked_at = datetime.now(timezone.utc)
|
||||
u.consecutive_failures = 0
|
||||
db.commit()
|
||||
return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组")
|
||||
except Exception as exc:
|
||||
u.last_status = "unhealthy"
|
||||
u.last_error = str(exc)
|
||||
u.last_checked_at = datetime.now(timezone.utc)
|
||||
u.consecutive_failures = (u.consecutive_failures or 0) + 1
|
||||
db.commit()
|
||||
return TestResult(success=False, message="连接失败", detail=str(exc))
|
||||
|
||||
|
||||
|
||||
@@ -176,6 +176,8 @@ def delete_website(wid: int, db: Session = Depends(get_db), _=Depends(get_curren
|
||||
row = db.query(Website).filter(Website.id == wid).first()
|
||||
if not row:
|
||||
raise HTTPException(404, "website not found")
|
||||
db.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == wid).delete(synchronize_session=False)
|
||||
db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.website_id == wid).delete(synchronize_session=False)
|
||||
db.delete(row)
|
||||
db.commit()
|
||||
|
||||
|
||||
@@ -98,9 +98,16 @@ class BrowserSessionService:
|
||||
session = self._get(session_id)
|
||||
async with session.lock:
|
||||
self._ensure_open(session)
|
||||
return await session.page.screenshot(type="jpeg", quality=78, full_page=False)
|
||||
return await session.page.screenshot(type="jpeg", quality=65, full_page=False)
|
||||
|
||||
async def event(self, session_id: str, event_type: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
async def event(
|
||||
self,
|
||||
session_id: str,
|
||||
event_type: str,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
include_state: bool = True,
|
||||
) -> dict[str, Any] | None:
|
||||
session = self._get(session_id)
|
||||
async with session.lock:
|
||||
self._ensure_open(session)
|
||||
@@ -141,8 +148,17 @@ class BrowserSessionService:
|
||||
await page.set_viewport_size({"width": width, "height": height})
|
||||
else:
|
||||
raise ValueError("Unsupported browser event")
|
||||
if not include_state:
|
||||
return None
|
||||
return await self._session_state(session)
|
||||
|
||||
async def selected_text(self, session_id: str) -> str:
|
||||
session = self._get(session_id)
|
||||
async with session.lock:
|
||||
self._ensure_open(session)
|
||||
value = await session.page.evaluate("() => window.getSelection()?.toString() || ''")
|
||||
return str(value or "")
|
||||
|
||||
async def close(self, session_id: str) -> None:
|
||||
session = self._discard_session(session_id)
|
||||
if not session:
|
||||
|
||||
@@ -5,6 +5,7 @@ import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from apscheduler.executors.pool import ThreadPoolExecutor
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -12,14 +13,14 @@ from app.database import SessionLocal
|
||||
from app.models.upstream import Upstream
|
||||
from app.models.snapshot import UpstreamRateSnapshot
|
||||
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot
|
||||
from app.services.snapshot_service import diff_snapshots
|
||||
from app.services.snapshot_service import diff_snapshots, prune_snapshots
|
||||
from app.services import webhook_service
|
||||
from app.services import website_sync
|
||||
from app.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_scheduler = BackgroundScheduler(timezone="UTC")
|
||||
_scheduler = BackgroundScheduler(timezone="UTC", executors={"default": ThreadPoolExecutor(max_workers=1)})
|
||||
|
||||
|
||||
def get_scheduler() -> BackgroundScheduler:
|
||||
@@ -95,6 +96,7 @@ def _check_upstream(upstream_id: int) -> None:
|
||||
upstream.last_checked_at = datetime.now(timezone.utc)
|
||||
upstream.last_error = None
|
||||
upstream.consecutive_failures = 0
|
||||
prune_snapshots(db, upstream_id, settings.snapshot_retention_count)
|
||||
db.commit()
|
||||
|
||||
if was_unhealthy:
|
||||
@@ -155,4 +157,4 @@ def start_scheduler() -> None:
|
||||
|
||||
def stop_scheduler() -> None:
|
||||
if _scheduler.running:
|
||||
_scheduler.shutdown(wait=False)
|
||||
_scheduler.shutdown(wait=True)
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
"""Snapshot diff logic."""
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.snapshot import UpstreamRateSnapshot
|
||||
|
||||
|
||||
def diff_snapshots(
|
||||
previous: Optional[dict[str, Any]],
|
||||
@@ -37,3 +41,20 @@ def diff_snapshots(
|
||||
"new_rate": None,
|
||||
})
|
||||
return changes
|
||||
|
||||
|
||||
def prune_snapshots(db: Session, upstream_id: int, keep: int) -> None:
|
||||
if keep <= 0:
|
||||
return
|
||||
stale_ids = [
|
||||
row_id
|
||||
for (row_id,) in (
|
||||
db.query(UpstreamRateSnapshot.id)
|
||||
.filter(UpstreamRateSnapshot.upstream_id == upstream_id)
|
||||
.order_by(UpstreamRateSnapshot.captured_at.desc(), UpstreamRateSnapshot.id.desc())
|
||||
.offset(keep)
|
||||
.all()
|
||||
)
|
||||
]
|
||||
if stale_ids:
|
||||
db.query(UpstreamRateSnapshot).filter(UpstreamRateSnapshot.id.in_(stale_ids)).delete(synchronize_session=False)
|
||||
|
||||
+34
-12
@@ -1,25 +1,37 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from jose import JWTError, jwt
|
||||
import bcrypt
|
||||
from fastapi import Depends, HTTPException, Query, Request, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import get_settings
|
||||
from app.database import get_db
|
||||
from app.models.admin_user import AdminUser
|
||||
from app.models.revoked_token import RevokedToken
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
BCRYPT_MAX_PASSWORD_BYTES = 72
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
def validate_password_supported(password: str) -> None:
|
||||
if len(password.encode("utf-8")) > BCRYPT_MAX_PASSWORD_BYTES:
|
||||
raise ValueError("password must be at most 72 bytes when UTF-8 encoded")
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
pw = password.encode("utf-8")[:72]
|
||||
validate_password_supported(password)
|
||||
pw = password.encode("utf-8")
|
||||
return bcrypt.hashpw(pw, bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
|
||||
def verify_password(plain: str, hashed: str) -> bool:
|
||||
pw = plain.encode("utf-8")[:72]
|
||||
validate_password_supported(plain)
|
||||
pw = plain.encode("utf-8")
|
||||
return bcrypt.checkpw(pw, hashed.encode("utf-8"))
|
||||
|
||||
|
||||
@@ -27,19 +39,29 @@ def create_access_token(email: str, expires_hours: Optional[int] = None) -> str:
|
||||
settings = get_settings()
|
||||
hours = expires_hours or settings.jwt_expire_hours
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=hours)
|
||||
data = {"sub": email, "exp": expire}
|
||||
data = {"sub": email, "exp": expire, "jti": uuid4().hex}
|
||||
return jwt.encode(data, settings.jwt_secret, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> Optional[str]:
|
||||
def decode_token_payload(token: str) -> Optional[dict]:
|
||||
settings = get_settings()
|
||||
try:
|
||||
payload = jwt.decode(token, settings.jwt_secret, algorithms=[ALGORITHM])
|
||||
return payload.get("sub")
|
||||
return jwt.decode(token, settings.jwt_secret, algorithms=[ALGORITHM])
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
def decode_token(token: str) -> Optional[str]:
|
||||
payload = decode_token_payload(token)
|
||||
return payload.get("sub") if payload else None
|
||||
|
||||
|
||||
def _is_revoked(db: Session, jti: str | None) -> bool:
|
||||
if not jti:
|
||||
return True
|
||||
return db.query(RevokedToken).filter(RevokedToken.jti == jti).first() is not None
|
||||
|
||||
|
||||
def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
|
||||
db: Session = Depends(get_db),
|
||||
@@ -47,10 +69,10 @@ def get_current_user(
|
||||
token = credentials.credentials if credentials else None
|
||||
if not token:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
|
||||
email = decode_token(token)
|
||||
if not email:
|
||||
payload = decode_token_payload(token)
|
||||
if not payload or _is_revoked(db, payload.get("jti")):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
||||
user = db.query(AdminUser).filter(AdminUser.email == email).first()
|
||||
user = db.query(AdminUser).filter(AdminUser.email == payload.get("sub")).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
|
||||
return user
|
||||
@@ -65,10 +87,10 @@ def get_user_from_token_param(
|
||||
raw = token or (credentials.credentials if credentials else None)
|
||||
if not raw:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
|
||||
email = decode_token(raw)
|
||||
if not email:
|
||||
payload = decode_token_payload(raw)
|
||||
if not payload or _is_revoked(db, payload.get("jti")):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
||||
user = db.query(AdminUser).filter(AdminUser.email == email).first()
|
||||
user = db.query(AdminUser).filter(AdminUser.email == payload.get("sub")).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
|
||||
return user
|
||||
|
||||
Reference in New Issue
Block a user