馃 distributed transcription service thistle.dunkirk.sh
"""Whisper transcription server with SQLite job tracking and SSE progress.

Uploaded audio is transcribed by faster-whisper in a worker thread. Job
status, progress percentage and the partial transcript are stored in a local
SQLite database so clients can poll ``GET /transcribe/{job_id}`` or stream
updates from ``GET /transcribe/{job_id}/stream``.
"""
import os
import json
import tempfile
import asyncio
import sqlite3
import threading
import time
import uuid
from faster_whisper import WhisperModel
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse, StreamingResponse
from sse_starlette.sse import EventSourceResponse

# --- 1. Load Model on Startup ---
# Loaded once at import time and shared by every job, not per request.
print("--- Loading faster-whisper model... ---")
model_size = "small"
# Use device="cuda" and compute_type="float16" instead if a GPU is available.
model = WhisperModel(model_size, device="cpu", compute_type="int8")
print(f"--- Model '{model_size}' loaded. Server is ready. ---")

# --- 2. Setup Database for Job Tracking ---
db_path = "./whisper.db"  # Independent DB for Whisper server
# One connection is shared by the event loop, the worker threads started with
# asyncio.to_thread, and the threadpool FastAPI uses for the sync endpoints.
# check_same_thread=False only disables sqlite3's same-thread check; it does
# not serialize access, so every statement (with its commit/fetch) holds
# db_lock. Hold times are a single statement, so briefly taking the lock on
# the event loop is acceptable.
db = sqlite3.connect(db_path, check_same_thread=False)
db_lock = threading.Lock()


def _now() -> int:
    """Return the current Unix time in whole seconds (the unit of updated_at)."""
    return int(time.time())


def _db_write(sql: str, params: tuple = ()) -> int:
    """Execute one write statement, commit it, and return the affected row count."""
    with db_lock:
        cursor = db.execute(sql, params)
        db.commit()
        return cursor.rowcount


def _db_fetchone(sql: str, params: tuple = ()):
    """Execute a query and return its first row, or None if there is none."""
    with db_lock:
        return db.execute(sql, params).fetchone()


def _db_fetchall(sql: str, params: tuple = ()) -> list:
    """Execute a query and return all of its rows."""
    with db_lock:
        return db.execute(sql, params).fetchall()


_db_write("""
CREATE TABLE IF NOT EXISTS whisper_jobs (
    id TEXT PRIMARY KEY,
    status TEXT DEFAULT 'pending',
    progress REAL DEFAULT 0,
    transcript TEXT DEFAULT '',
    error_message TEXT DEFAULT '',
    created_at INTEGER,
    updated_at INTEGER
)
""")

# --- 3. Create FastAPI App ---
app = FastAPI(title="Whisper Transcription Server with Progress")

# Strong references to running transcription tasks. The event loop keeps only
# weak references to tasks, so without this set a task could be garbage
# collected before it finishes. Each task removes itself when done.
_background_tasks = set()


def _job_not_found() -> JSONResponse:
    """Build a real HTTP 404 response with body {"error": "Job not found"}."""
    return JSONResponse(status_code=404, content={"error": "Job not found"})


def _remove_temp_file(job_id: str, temp_file_path: str) -> None:
    """Delete a job's temp file, logging instead of raising if that fails."""
    print(f"Job {job_id}: Cleaning up temp file: {temp_file_path}")
    try:
        os.remove(temp_file_path)
    except OSError as e:
        print(f"Job {job_id}: Could not remove temp file {temp_file_path}: {e}")


# --- 4. Define the Transcription Function ---
def run_transcription(job_id: str, temp_file_path: str) -> None:
    """Transcribe ``temp_file_path`` and record progress for ``job_id``.

    Runs in a worker thread. Marks the job 'processing', then after every
    segment stores the progress percentage and the transcript so far. Ends
    with status 'completed' (progress 100) or 'failed' (with the exception
    text in error_message). The temp file is always deleted afterwards.
    """
    try:
        _db_write(
            "UPDATE whisper_jobs SET status = 'processing', updated_at = ? WHERE id = ?",
            (_now(), job_id),
        )

        segments, info = model.transcribe(
            temp_file_path,
            beam_size=5,
            vad_filter=True,
        )

        total_duration = round(info.duration, 2)
        print(f"Job {job_id}: Total audio duration: {total_duration}s")
        print(f"Job {job_id}: Detected language: {info.language}")

        # Segments are consumed one at a time so progress is reported per segment.
        parts = []
        for segment in segments:
            text = segment.text.strip()
            if text:
                parts.append(text)

            if total_duration > 0:
                # Clamp: total_duration is rounded, so segment.end can exceed it.
                progress_percent = min(segment.end / total_duration * 100, 100.0)
            else:
                # Clips shorter than 0.005 s round to 0; avoid ZeroDivisionError.
                progress_percent = 0.0

            _db_write(
                "UPDATE whisper_jobs SET progress = ?, transcript = ?, updated_at = ? WHERE id = ?",
                (round(progress_percent, 2), " ".join(parts), _now(), job_id),
            )

        _db_write(
            "UPDATE whisper_jobs SET status = 'completed', progress = 100, updated_at = ? WHERE id = ?",
            (_now(), job_id),
        )

    except Exception as e:
        # Some exceptions stringify to ""; fall back to the type name so a
        # failed job always carries a non-empty error_message.
        error_message = str(e) or type(e).__name__
        print(f"Job {job_id}: Transcription failed: {error_message}")
        _db_write(
            "UPDATE whisper_jobs SET status = 'failed', error_message = ?, updated_at = ? WHERE id = ?",
            (error_message, _now(), job_id),
        )

    finally:
        _remove_temp_file(job_id, temp_file_path)


# --- 5. Define the FastAPI Endpoints ---
@app.post("/transcribe")
async def transcribe_endpoint(file: UploadFile = File(...)):
    """
    Accepts an audio file, starts transcription in background, returns job ID.

    The upload is spooled to a temp file in 1 MiB chunks, a 'pending' job row
    is inserted, and run_transcription is started in a worker thread. If
    anything fails before the job is started, the temp file is removed.
    """
    job_id = str(uuid.uuid4())

    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".tmp") as temp_file:
            temp_file_path = temp_file.name
            while content := await file.read(1024 * 1024):
                temp_file.write(content)

        print(f"Job {job_id}: File saved to temporary path: {temp_file_path}")

        now = _now()
        _db_write(
            "INSERT INTO whisper_jobs (id, created_at, updated_at) VALUES (?, ?, ?)",
            (job_id, now, now),
        )
    except BaseException:
        # BaseException so that task cancellation also cleans up the temp file.
        if temp_file_path is not None:
            _remove_temp_file(job_id, temp_file_path)
        raise

    task = asyncio.create_task(
        asyncio.to_thread(run_transcription, job_id, temp_file_path)
    )
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    return {"job_id": job_id}


@app.get("/transcribe/{job_id}/stream")
async def stream_transcription_status(job_id: str):
    """
    Stream the status and progress of a transcription job via SSE.

    Polls the job row every 500 ms and sends a 'message' event whenever it
    changes. The stream closes after the job reaches 'completed' or 'failed'.
    It sends a single 'error' event if the job does not exist.
    """
    async def event_generator():
        last_snapshot = None

        while True:
            row = _db_fetchone(
                """
                SELECT status, progress, transcript, error_message
                FROM whisper_jobs
                WHERE id = ?
                """,
                (job_id,),
            )

            if not row:
                yield {
                    "event": "error",
                    "data": json.dumps({"error": "Job not found"}),
                }
                return

            status, progress, transcript, error_message = row

            # Compare the whole row rather than updated_at. updated_at only has
            # one-second resolution, so a change in the same second as the
            # previous poll would be missed, including the final 'completed'
            # or 'failed' transition.
            if row != last_snapshot:
                last_snapshot = row

                data = {
                    "status": status,
                    "progress": progress,
                }

                # Only include the transcript once there is some text.
                if transcript:
                    data["transcript"] = transcript

                if error_message:
                    data["error_message"] = error_message

                yield {
                    "event": "message",
                    "data": json.dumps(data),
                }

            # A terminal status always differs from the previous snapshot, so
            # it has been sent above before the stream closes.
            if status in ('completed', 'failed'):
                return

            # Poll every 500ms
            await asyncio.sleep(0.5)

    return EventSourceResponse(event_generator())


@app.get("/transcribe/{job_id}")
def get_transcription_status(job_id: str):
    """
    Get the status and progress of a transcription job.

    Returns status, progress, transcript and error_message, or an HTTP 404
    with {"error": "Job not found"} for an unknown job ID.
    """
    row = _db_fetchone(
        "SELECT status, progress, transcript, error_message FROM whisper_jobs WHERE id = ?",
        (job_id,),
    )
    if not row:
        return _job_not_found()

    status, progress, transcript, error_message = row
    return {
        "status": status,
        "progress": progress,
        "transcript": transcript,
        "error_message": error_message,
    }


@app.get("/jobs")
def list_jobs():
    """
    List all jobs with their current status. Used for recovery/sync.

    Jobs are ordered newest first by created_at. Transcripts are not included.
    """
    rows = _db_fetchall(
        """
        SELECT id, status, progress, created_at, updated_at
        FROM whisper_jobs
        ORDER BY created_at DESC
        """
    )

    keys = ("id", "status", "progress", "created_at", "updated_at")
    return {"jobs": [dict(zip(keys, row)) for row in rows]}


@app.delete("/transcribe/{job_id}")
def delete_job(job_id: str):
    """
    Delete a job from the database. Used for cleanup.

    Returns {"success": True}, or an HTTP 404 with {"error": "Job not found"}
    if no row was deleted. A job that is still running keeps executing; its
    later UPDATEs simply match no row.
    """
    if _db_write("DELETE FROM whisper_jobs WHERE id = ?", (job_id,)) == 0:
        return _job_not_found()

    return {"success": True}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)