Files
MusicWorkshop/backend/app/main.py
T
2026-04-30 14:34:28 +08:00

443 lines
13 KiB
Python

import asyncio
import mimetypes
import os
import threading
from pathlib import Path
from fastapi import FastAPI, Query, Request, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from .exception_service import ExceptionItemNotFoundError, ExceptionService
from .matcher import Matcher
from .metadata_status import probe_metadata_services
from .preprocessor import Preprocessor
from .repair_runner import RepairRunner, RepairService
from .library_service import LibraryService, LibraryTrackNotFoundError
from .scanner import Scanner
from .schemas import (
ConfigPayload,
ConfigSaveResponse,
ExceptionDetailPayload,
ExceptionListResponse,
ExceptionSummaryPayload,
LibraryMoveToExceptionResponse,
LibrarySummaryPayload,
LibraryTracksPageResponse,
MetadataStatusResponse,
RepairExecuteRequest,
RepairPreviewRequest,
RepairPreviewResponse,
RepairTaskCurrentResponse,
RepairTaskRunResponse,
TaskCurrentResponse,
TaskDetailResponse,
TaskHistoryListResponse,
TaskItemsPageResponse,
TaskLogsPageResponse,
TaskRunResponse
)
from .storage import ConfigStore
from .task_runner import TaskRunner
from .task_store import TaskConflictError, TaskNotFoundError, TaskStore
from .task_stream import TaskStreamManager
# Backend root directory: two `.parent` hops above this file's resolved path
# (i.e. the directory that contains the `app` package).
BASE_DIR = Path(__file__).resolve().parent.parent
# Default database file location, under <BASE_DIR>/data.
DEFAULT_DB_PATH = BASE_DIR / 'data' / 'music_workshop.db'
# The MUSIC_WORKSHOP_DB_PATH environment variable, when set, overrides the default.
DB_PATH = Path(os.getenv('MUSIC_WORKSHOP_DB_PATH', DEFAULT_DB_PATH))

# Module-level service singletons shared by every route handler in this module.
# ConfigStore and TaskStore are both constructed against the same DB_PATH.
store = ConfigStore(DB_PATH)
task_store = TaskStore(DB_PATH)
task_stream = TaskStreamManager()
scanner = Scanner()
preprocessor = Preprocessor()
matcher = Matcher()
library_service = LibraryService(task_store, preprocessor)
exception_service = ExceptionService(task_store)
task_runner = TaskRunner(task_store, scanner, preprocessor, task_stream, matcher)
repair_service = RepairService(task_store, exception_service, matcher, preprocessor, task_stream)
repair_runner = RepairRunner(task_store, task_stream, repair_service)
# RepairRunner takes the service at construction time, so the reverse link
# (service -> runner) can only be attached afterwards. Keep this line after
# both objects exist.
repair_service.runner = repair_runner

app = FastAPI(title='Music Workshop API', version='0.1.0')
# CORS: allow browser pages served from localhost / 127.0.0.1 on port 5173
# (presumably the frontend dev server, confirm) to call this API, including
# credentialed requests, with any method and any header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        'http://localhost:5173',
        'http://127.0.0.1:5173'
    ],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*']
)
@app.on_event('startup')
async def startup():
    """Initialise shared runtime state when the ASGI app starts.

    Hands the running event loop to the task stream manager (presumably so
    worker threads can schedule pushes onto it, so confirm in TaskStreamManager).
    Then calls ``task_store.fail_stale_active_tasks()``, which by its name
    cleans up tasks left active by a previous process.
    """
    running_loop = asyncio.get_running_loop()
    task_stream.set_loop(running_loop)
    task_store.fail_stale_active_tasks()
@app.get('/api/health')
def healthcheck():
    """Liveness probe that always answers ``{'status': 'ok'}``."""
    response = {'status': 'ok'}
    return response
@app.get('/api/config', response_model=ConfigPayload)
def get_config():
    """Return the configuration currently persisted in the config store."""
    current_config = store.get_config()
    return current_config
@app.get('/api/config/metadata-status', response_model=MetadataStatusResponse)
def get_config_metadata_status():
    """Probe the metadata services described by the saved configuration.

    Returns ``{'metadataStatus': ...}`` with whatever
    ``probe_metadata_services`` reports for the ``metadata`` section.
    """
    metadata_settings = store.get_config()['metadata']
    status = probe_metadata_services(metadata_settings)
    return {'metadataStatus': status}
@app.put('/api/config', response_model=ConfigSaveResponse)
def update_config(payload: ConfigPayload):
    """Persist *payload* and echo the stored config plus a fresh metadata probe.

    The probe runs against the config as returned by ``save_config``, so it
    reflects whatever the store actually persisted.
    """
    saved = store.save_config(payload.model_dump())
    metadata_status = probe_metadata_services(saved['metadata'])
    return {'config': saved, 'metadataStatus': metadata_status}
@app.post('/api/tasks/run', response_model=TaskRunResponse, status_code=202)
def run_task():
    """Start a new processing task on a background thread unless one is active.

    Returns:
        202 with the new task's id, status, stage info and start time; or a
        409 JSON body carrying the id of the task that is already running.
    """
    snapshot = store.get_config()
    try:
        created = task_store.create_task_if_idle(snapshot)
    except TaskConflictError as conflict:
        conflict_body = {
            'detail': 'Task already running',
            'task_id': conflict.active_task_id
        }
        return JSONResponse(status_code=409, content=conflict_body)

    # Daemon thread: the response does not wait for the runner, and the
    # thread will not keep the process alive on shutdown. The same config
    # snapshot used to create the task is handed to the runner.
    worker = threading.Thread(
        target=task_runner.start_task,
        args=(created['task_id'], snapshot),
        daemon=True
    )
    worker.start()

    summary_keys = ('task_id', 'status', 'current_stage', 'stage_states', 'started_at')
    return {key: created[key] for key in summary_keys}
@app.get('/api/tasks/current', response_model=TaskCurrentResponse)
def get_current_task():
    """Return the active task, or the most recent task when none is active."""
    active = task_store.get_active_task()
    if active:
        return {'task': active}
    return {'task': task_store.get_latest_task()}
@app.get('/api/repair-tasks/current', response_model=RepairTaskCurrentResponse)
def get_current_repair_task():
    """Return the active repair task, or the most recent repair task if none is active."""
    active_repair = task_store.get_active_task('repair')
    if active_repair:
        return {'task': active_repair}
    return {'task': task_store.get_latest_task('repair')}
@app.get('/api/tasks', response_model=TaskHistoryListResponse)
def get_tasks(
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=8, ge=1, le=200)
):
    """Return one page of task history (8 entries per page by default, at most 200)."""
    history_page = task_store.list_task_history(page, page_size)
    return history_page
@app.get('/api/tasks/{task_id}', response_model=TaskDetailResponse)
def get_task(task_id: str):
    """Return a single task wrapped as ``{'task': ...}``.

    Unknown ids presumably surface as TaskNotFoundError from the store,
    which this module's exception handler turns into a 404.
    """
    found = task_store.get_task(task_id)
    return {'task': found}
@app.get('/api/tasks/{task_id}/items', response_model=TaskItemsPageResponse)
def get_task_items(
    task_id: str,
    scan_status: str | None = None,
    preprocess_status: str | None = None,
    match_status: str | None = None,
    dedupe_status: str | None = None,
    organize_status: str | None = None,
    active_only: bool = False,
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=50, ge=1, le=200)
):
    """Return a filtered page of the items belonging to one task.

    Each ``*_status`` filter and ``active_only`` is forwarded unchanged to
    ``task_store.list_task_items``.
    """
    # Result discarded: this call only serves as an existence check
    # (presumably raising TaskNotFoundError for unknown ids).
    task_store.get_task(task_id)
    stage_filters = {
        'preprocess_status': preprocess_status,
        'match_status': match_status,
        'dedupe_status': dedupe_status,
        'organize_status': organize_status,
        'active_only': active_only
    }
    # scan_status is passed positionally, ahead of the paging arguments.
    return task_store.list_task_items(task_id, scan_status, page, page_size, **stage_filters)
@app.get('/api/tasks/{task_id}/logs', response_model=TaskLogsPageResponse)
def get_task_logs(
    task_id: str,
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=50, ge=1, le=200)
):
    """Return one page of log entries for an existing task."""
    # Result discarded: existence check only (presumably raises
    # TaskNotFoundError for unknown ids).
    task_store.get_task(task_id)
    logs_page = task_store.list_task_logs(task_id, page, page_size)
    return logs_page
@app.get('/api/library/summary', response_model=LibrarySummaryPayload)
def get_library_summary():
    """Summarise the library at the configured ``output`` location ('' when unset)."""
    output_root = store.get_config().get('output') or ''
    return library_service.get_summary(output_root)
@app.get('/api/library/tracks', response_model=LibraryTracksPageResponse)
def get_library_tracks(
    q: str | None = None,
    artist: str | None = None,
    album: str | None = None,
    format: str | None = None,
    has_provenance: bool | None = None,
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=50, ge=1, le=200),
    sort_by: str = Query(default='organized_at'),
    sort_order: str = Query(default='desc')
):
    """Return a filtered, sorted page of tracks from the configured output library.

    All filters and sort options are forwarded unchanged to
    ``library_service.get_tracks_page``; ``sort_by`` / ``sort_order`` are
    not validated in this handler.
    """
    output_root = store.get_config().get('output') or ''
    filters = {
        'q': q,
        'artist': artist,
        'album': album,
        'format': format,
        'has_provenance': has_provenance
    }
    paging = {
        'page': page,
        'page_size': page_size,
        'sort_by': sort_by,
        'sort_order': sort_order
    }
    return library_service.get_tracks_page(output_root, **filters, **paging)
@app.post('/api/library/tracks/{track_id}/move-to-exception', response_model=LibraryMoveToExceptionResponse)
def move_library_track_to_exception(track_id: str):
    """Delegate to ``library_service.move_track_to_exception`` for *track_id*.

    Responds 409 with the active task id when the service raises
    TaskConflictError.
    """
    current_config = store.get_config()
    try:
        moved = library_service.move_track_to_exception(current_config, track_id)
    except TaskConflictError as conflict:
        conflict_body = {
            'detail': 'Task already running',
            'task_id': conflict.active_task_id
        }
        return JSONResponse(status_code=409, content=conflict_body)
    return moved
@app.get('/api/exceptions/summary', response_model=ExceptionSummaryPayload)
def get_exception_summary():
    """Return the summary reported by the exception service."""
    summary = exception_service.get_summary()
    return summary
@app.get('/api/exceptions/items', response_model=ExceptionListResponse)
def get_exception_items(
    type: str = Query(default='all'),
    resolution_status: str = Query(default='open'),
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=50, ge=1, le=200)
):
    """Return one page of exception items filtered by type and resolution status."""
    # Note the service's positional order: type, page, page_size, resolution_status.
    items_page = exception_service.get_items(type, page, page_size, resolution_status)
    return items_page
@app.get('/api/exceptions/items/{exception_id}', response_model=ExceptionDetailPayload)
def get_exception_item(exception_id: int):
    """Return the detail payload for a single exception item."""
    item = exception_service.get_item(exception_id)
    return item
@app.get('/api/exceptions/items/{exception_id}/audio')
def get_exception_item_audio(exception_id: int, request: Request):
    """Stream the audio file of an exception item, honouring single-range requests.

    Behaviour:
        * No ``Range`` header, or a range unit other than ``bytes`` (which
          RFC 9110 says servers must ignore): the whole file, status 200.
        * A valid single ``bytes=`` range: that slice, status 206 with
          ``Content-Range``.
        * A ``bytes=`` range that is malformed or lies outside the file:
          status 416 with ``Content-Range: bytes */<size>``, so media
          clients can recover instead of seeing a generic 400.
    """
    audio_path = exception_service.resolve_audio_path(exception_id)
    file_size = audio_path.stat().st_size
    content_type = mimetypes.guess_type(audio_path.name)[0] or 'application/octet-stream'
    range_header = request.headers.get('range')

    unit, has_spec, range_spec = (range_header or '').partition('=')
    if not has_spec or unit.strip().lower() != 'bytes':
        # No usable byte-range request: serve the full representation.
        return StreamingResponse(
            _iter_file_range(audio_path, 0, file_size - 1),
            media_type=content_type,
            headers={
                'Accept-Ranges': 'bytes',
                'Content-Length': str(file_size)
            }
        )

    try:
        # Re-prefix with a lowercase unit so the parser's 'bytes=' check matches
        # regardless of the case the client used.
        start, end = _parse_range_header(f'bytes={range_spec}', file_size)
    except ValueError:
        # RFC 9110 §15.5.17: unsatisfiable/invalid range -> 416 plus the full size.
        return JSONResponse(
            status_code=416,
            content={'detail': 'Invalid Range header'},
            headers={
                'Accept-Ranges': 'bytes',
                'Content-Range': f'bytes */{file_size}'
            }
        )

    return StreamingResponse(
        _iter_file_range(audio_path, start, end),
        status_code=206,
        media_type=content_type,
        headers={
            'Accept-Ranges': 'bytes',
            'Content-Length': str(end - start + 1),
            'Content-Range': f'bytes {start}-{end}/{file_size}'
        }
    )
@app.post('/api/exceptions/actions/preview', response_model=RepairPreviewResponse)
def preview_exception_action(payload: RepairPreviewRequest):
    """Ask ``repair_service.preview`` about *payload* using the current config snapshot."""
    request_data = payload.model_dump()
    return repair_service.preview(request_data, store.get_config())
@app.post('/api/exceptions/actions/execute', response_model=RepairTaskRunResponse, status_code=202)
def execute_exception_action(payload: RepairExecuteRequest):
    """Create a repair task for *payload* and run it on a background thread.

    Returns:
        202 with the repair task summary (its id under ``repair_task_id``),
        or a 409 JSON body when ``repair_service.execute`` raises
        TaskConflictError.
    """
    snapshot = store.get_config()
    try:
        repair_task = repair_service.execute(payload.model_dump(), snapshot)
    except TaskConflictError as conflict:
        return JSONResponse(
            status_code=409,
            content={'detail': 'Repair task already running', 'task_id': conflict.active_task_id}
        )

    # Daemon thread so the runner neither blocks the response nor the process exit.
    worker = threading.Thread(
        target=repair_runner.start_task,
        args=(repair_task['task_id'], snapshot),
        daemon=True
    )
    worker.start()

    response = {'repair_task_id': repair_task['task_id']}
    for field in ('status', 'current_stage', 'stage_states', 'started_at'):
        response[field] = repair_task[field]
    return response
@app.get('/api/repair-tasks/{task_id}', response_model=TaskDetailResponse)
def get_repair_task(task_id: str):
    """Return a task only when its ``task_type`` is ``'repair'``; otherwise 404."""
    task = task_store.get_task(task_id)
    is_repair = task.get('task_type') == 'repair'
    if not is_repair:
        raise TaskNotFoundError(task_id)
    return {'task': task}
@app.get('/api/repair-tasks/{task_id}/logs', response_model=TaskLogsPageResponse)
def get_repair_task_logs(
    task_id: str,
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=50, ge=1, le=200)
):
    """Return a page of logs for a repair task; non-repair ids are treated as not found."""
    task = task_store.get_task(task_id)
    is_repair = task.get('task_type') == 'repair'
    if not is_repair:
        raise TaskNotFoundError(task_id)
    logs_page = task_store.list_task_logs(task_id, page, page_size)
    return logs_page
@app.websocket('/api/tasks/{task_id}/stream')
async def stream_task(task_id: str, websocket: WebSocket):
    """Push live updates for *task_id* over a WebSocket.

    Unknown task ids are accepted and immediately closed with code 4404.
    Otherwise the socket is registered with the task stream manager and sent
    an initial ``task.snapshot`` message. It then stays open until the client
    disconnects; incoming text frames are read and discarded.

    The socket is always unregistered on exit, including on task cancellation
    (``asyncio.CancelledError`` derives from ``BaseException`` and so escaped
    the former ``except Exception`` cleanup). Unexpected errors still
    propagate after cleanup.
    """
    try:
        snapshot = task_store.get_task_snapshot(task_id)
    except TaskNotFoundError:
        # Accept first: closing before the handshake completes would reject the
        # connection outright instead of delivering the custom 4404 close code.
        await websocket.accept()
        await websocket.close(code=4404)
        return

    await task_stream.connect(task_id, websocket)
    try:
        task = snapshot['task']
        await websocket.send_json(
            {
                'type': 'task.snapshot',
                'task_id': task_id,
                'stage': task['current_stage'],
                'timestamp': task['updated_at'],
                'data': snapshot
            }
        )
        # Keep reading so a client disconnect is observed (as WebSocketDisconnect).
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    finally:
        task_stream.disconnect(task_id, websocket)
@app.websocket('/api/repair-tasks/{task_id}/stream')
async def stream_repair_task(task_id: str, websocket: WebSocket):
    """Stream updates for a repair task; other task types are rejected as unknown.

    The REST routes under ``/api/repair-tasks/{task_id}`` treat a task whose
    ``task_type`` is not ``'repair'`` as not found. This WebSocket applies the
    same rule: such ids (and unknown ids) are accepted and closed with code
    4404, the code ``stream_task`` uses for unknown tasks. Valid repair tasks
    are delegated to ``stream_task``.
    """
    try:
        task = task_store.get_task(task_id)
    except TaskNotFoundError:
        task = None
    if task is None or task.get('task_type') != 'repair':
        # Accept before closing so the client receives the 4404 close code.
        await websocket.accept()
        await websocket.close(code=4404)
        return
    await stream_task(task_id, websocket)
@app.exception_handler(ValueError)
def value_error_handler(_, exc: ValueError):
    """Translate an uncaught ValueError into a 400 whose detail is its message."""
    message = str(exc)
    return JSONResponse(status_code=400, content={'detail': message})
@app.exception_handler(TaskNotFoundError)
def task_not_found_error_handler(_, exc: TaskNotFoundError):
    """Translate TaskNotFoundError into a 404 prefixed with 'Task not found'."""
    body = {'detail': f'Task not found: {exc}'}
    return JSONResponse(status_code=404, content=body)
@app.exception_handler(ExceptionItemNotFoundError)
def exception_item_not_found_error_handler(_, exc: ExceptionItemNotFoundError):
    """Translate ExceptionItemNotFoundError into a 404 response."""
    body = {'detail': f'Exception item not found: {exc}'}
    return JSONResponse(status_code=404, content=body)
@app.exception_handler(LibraryTrackNotFoundError)
def library_track_not_found_error_handler(_, exc: LibraryTrackNotFoundError):
    """Translate LibraryTrackNotFoundError into a 404 response."""
    body = {'detail': f'Library track not found: {exc}'}
    return JSONResponse(status_code=404, content=body)
@app.exception_handler(FileNotFoundError)
def file_not_found_error_handler(_, exc: FileNotFoundError):
    """Translate FileNotFoundError into a 404 whose detail is the error text."""
    message = str(exc)
    return JSONResponse(status_code=404, content={'detail': message})
def _parse_range_header(header_value: str, file_size: int) -> tuple[int, int]:
    """Parse a single-range ``Range`` header value into an inclusive byte span.

    Supported forms (RFC 9110 §14.1.2): ``bytes=START-END``, ``bytes=START-``
    (through end of file) and ``bytes=-SUFFIX`` (the last SUFFIX bytes). The
    range unit is matched case-insensitively, as RFC 9110 §14.1 specifies;
    whitespace around the numbers is tolerated.

    Args:
        header_value: Raw value of the ``Range`` request header.
        file_size: Size in bytes of the file being served.

    Returns:
        ``(start, end)`` with ``end`` clamped to ``file_size - 1``.

    Raises:
        ValueError: with the message ``'Invalid Range header'`` for every
            rejected value: wrong unit, non-numeric or multi-range specs,
            ``end < start``, a zero suffix, or a start at/after the end of the
            file (so any range on an empty file is rejected).
    """
    def invalid() -> ValueError:
        return ValueError('Invalid Range header')

    unit, has_equals, range_value = header_value.partition('=')
    if not has_equals or unit.lower() != 'bytes':
        raise invalid()

    start_text, _, end_text = range_value.strip().partition('-')
    start_text, end_text = start_text.strip(), end_text.strip()
    # int() alone would accept '+5', '1_000' and non-ASCII digits, and a
    # multi-range list such as '0-1,5-6' would fail with int()'s own message.
    # Require plain ASCII digits so every malformed value gets the same error.
    for text in (start_text, end_text):
        if text and not (text.isascii() and text.isdigit()):
            raise invalid()

    if start_text:
        start = int(start_text)
        end = int(end_text) if end_text else file_size - 1
    elif end_text:
        suffix_length = int(end_text)
        if suffix_length <= 0:
            raise invalid()
        start = max(file_size - suffix_length, 0)
        end = file_size - 1
    else:
        raise invalid()

    if end < start or start >= file_size:
        raise invalid()
    return start, min(end, file_size - 1)
def _iter_file_range(file_path: Path, start: int, end: int, chunk_size: int = 64 * 1024):
    """Yield the bytes of *file_path* from offset *start* through *end*, inclusive.

    Each yielded chunk is at most *chunk_size* bytes. Iteration stops early
    if the file ends before *end*. The file is opened in a ``with`` block,
    so it is closed when the generator finishes or is closed.
    """
    with file_path.open('rb') as stream:
        stream.seek(start)
        bytes_left = end - start + 1
        while bytes_left > 0 and (piece := stream.read(min(chunk_size, bytes_left))):
            bytes_left -= len(piece)
            yield piece