distributed transcription service
thistle.dunkirk.sh
1import os
2import json
3import tempfile
4import asyncio
5import sqlite3
6import time
7import uuid
8from faster_whisper import WhisperModel
9from fastapi import FastAPI, UploadFile, File
10from fastapi.responses import StreamingResponse
11from sse_starlette.sse import EventSourceResponse
12
# --- 1. Load Model on Startup ---
# This loads the model only once, not on every request.
# The load happens at import time, so the module-level `model` object is
# already constructed before any endpoint defined in this file can run.
print("--- Loading faster-whisper model... ---")
model_size = "small"
# You can change this to "cuda" and "float16" if you have a GPU
model = WhisperModel(model_size, device="cpu", compute_type="int8")
print(f"--- Model '{model_size}' loaded. Server is ready. ---")
20
# --- 2. Setup Database for Job Tracking ---
db_path = "./whisper.db" # Independent DB for Whisper server
# check_same_thread=False lets this single connection be used from threads
# other than the one that created it; run_transcription is dispatched to a
# worker thread via asyncio.to_thread and writes through this connection.
db = sqlite3.connect(db_path, check_same_thread=False)
# One row per transcription job. `progress` is stored as a percentage and
# `created_at` / `updated_at` hold int(time.time()) values (whole seconds).
db.execute("""
CREATE TABLE IF NOT EXISTS whisper_jobs (
 id TEXT PRIMARY KEY,
 status TEXT DEFAULT 'pending',
 progress REAL DEFAULT 0,
 transcript TEXT DEFAULT '',
 error_message TEXT DEFAULT '',
 created_at INTEGER,
 updated_at INTEGER
)
""")
db.commit()
36
# --- 2. Create FastAPI App ---
# Single application instance; the @app.<method> decorators further down
# register their handlers on it, and the __main__ guard passes it to uvicorn.
app = FastAPI(title="Whisper Transcription Server with Progress")
39
40
41# --- 3. Define the Transcription Function ---
42# Runs in background and updates DB
def run_transcription(job_id: str, temp_file_path: str):
    """
    Transcribe an audio file and record progress in the whisper_jobs table.

    The job row is moved to 'processing', updated after every segment with
    the running transcript and a progress percentage, and finally marked
    'completed' (progress 100) or 'failed' (with the exception text stored in
    error_message). The temporary audio file is removed in every case.

    Args:
        job_id: Primary key of the whisper_jobs row to update.
        temp_file_path: Path of the audio file to transcribe; deleted when
            the job ends.
    """
    try:
        # 1. Update to processing
        db.execute(
            "UPDATE whisper_jobs SET status = 'processing', updated_at = ? WHERE id = ?",
            (int(time.time()), job_id),
        )
        db.commit()

        # 2. Get segments and total audio duration
        segments, info = model.transcribe(
            temp_file_path,
            beam_size=5,
            vad_filter=True,
        )

        total_duration = round(info.duration, 2)
        print(f"Job {job_id}: Total audio duration: {total_duration}s")
        print(f"Job {job_id}: Detected language: {info.language}")

        transcript = ""

        # 3. Process each segment
        for segment in segments:
            if total_duration > 0:
                # The duration was rounded, so a segment ending at the very
                # end of the audio could otherwise report slightly over 100.
                progress_percent = min((segment.end / total_duration) * 100, 100.0)
            else:
                # A duration that rounds to zero would divide by zero and
                # fail a job that is otherwise transcribing fine.
                progress_percent = 0.0
            transcript += segment.text.strip() + " "

            db.execute("""
                UPDATE whisper_jobs SET progress = ?, transcript = ?, updated_at = ? WHERE id = ?
            """, (round(progress_percent, 2), transcript.strip(), int(time.time()), job_id))
            db.commit()

        # 4. Complete
        db.execute(
            "UPDATE whisper_jobs SET status = 'completed', progress = 100, updated_at = ? WHERE id = ?",
            (int(time.time()), job_id),
        )
        db.commit()

    except Exception as e:
        # Surface the failure in the server log as well as in the job row.
        print(f"Job {job_id}: Transcription failed: {e}")
        db.execute(
            "UPDATE whisper_jobs SET status = 'failed', error_message = ?, updated_at = ? WHERE id = ?",
            (str(e), int(time.time()), job_id),
        )
        db.commit()

    finally:
        # Clean up temp file
        print(f"Job {job_id}: Cleaning up temp file: {temp_file_path}")
        try:
            os.remove(temp_file_path)
        except OSError as cleanup_error:
            # Cleanup is best-effort: the job outcome is already recorded, and
            # an exception escaping here would go unobserved by any caller.
            print(f"Job {job_id}: Could not remove temp file: {cleanup_error}")
84
85
86# --- 4. Define the FastAPI Endpoints ---
# Strong references to in-flight transcription tasks. The event loop only
# keeps weak references to tasks, so a task nobody else references may be
# garbage collected before it finishes.
_background_tasks = set()


@app.post("/transcribe")
async def transcribe_endpoint(file: UploadFile = File(...)):
    """
    Accepts an audio file, starts transcription in background, returns job ID.

    The upload is copied to a temporary file in 1 MiB chunks, a job row is
    inserted into whisper_jobs, and run_transcription is scheduled on a worker
    thread. The response is {"job_id": <uuid4 string>}.
    """

    # Generate job ID
    job_id = str(uuid.uuid4())

    # Save the uploaded file to a temporary file
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".tmp")
    temp_file_path = temp_file.name
    try:
        with temp_file:
            while content := await file.read(1024 * 1024):
                temp_file.write(content)

        print(f"Job {job_id}: File saved to temporary path: {temp_file_path}")

        # Create job in DB
        now = int(time.time())
        db.execute(
            "INSERT INTO whisper_jobs (id, created_at, updated_at) VALUES (?, ?, ?)",
            (job_id, now, now),
        )
        db.commit()
    except BaseException:
        # The background job normally deletes the temp file, but it has not
        # been started yet, so remove the file here before re-raising.
        try:
            os.remove(temp_file_path)
        except OSError:
            pass
        raise

    # Start transcription in background
    task = asyncio.create_task(asyncio.to_thread(run_transcription, job_id, temp_file_path))
    _background_tasks.add(task)
    # Drop the reference once the task is done so the set does not grow.
    task.add_done_callback(_background_tasks.discard)

    return {"job_id": job_id}
112
@app.get("/transcribe/{job_id}/stream")
async def stream_transcription_status(job_id: str):
    """
    Stream the status and progress of a transcription job via SSE.

    The job row is polled every 500ms. A "message" event is emitted whenever
    the row differs from the last one sent; its JSON payload always carries
    "status" and "progress", plus "transcript" and "error_message" when they
    are non-empty. The stream ends after the job reaches 'completed' or
    'failed', or after a single "error" event if the job does not exist.
    """
    async def event_generator():
        last_sent_row = None

        while True:
            row = db.execute("""
                SELECT status, progress, transcript, error_message, updated_at
                FROM whisper_jobs
                WHERE id = ?
            """, (job_id,)).fetchone()

            if not row:
                yield {
                    "event": "error",
                    "data": json.dumps({"error": "Job not found"})
                }
                return

            status, progress, transcript, error_message, _updated_at = row

            # Compare the whole row rather than updated_at alone: updated_at
            # has one-second resolution, so several updates (including the
            # final status change) can share the same timestamp.
            if row != last_sent_row:
                last_sent_row = row

                data = {
                    "status": status,
                    "progress": progress,
                }

                # Omit empty fields to keep events small
                if transcript:
                    data["transcript"] = transcript

                if error_message:
                    data["error_message"] = error_message

                yield {
                    "event": "message",
                    "data": json.dumps(data)
                }

            # Close stream if job is complete or failed. Any change of status
            # has already been sent above because status is part of the row.
            if status in ('completed', 'failed'):
                return

            # Poll every 500ms
            await asyncio.sleep(0.5)

    return EventSourceResponse(event_generator())
166
@app.get("/transcribe/{job_id}")
def get_transcription_status(job_id: str):
    """
    Get the status and progress of a transcription job.

    Returns a JSON object with "status", "progress", "transcript" and
    "error_message". If no job has this ID, responds with HTTP 404 and
    {"error": "Job not found"}.
    """
    from fastapi.responses import JSONResponse

    row = db.execute(
        "SELECT status, progress, transcript, error_message FROM whisper_jobs WHERE id = ?",
        (job_id,),
    ).fetchone()
    if not row:
        # Returning a (body, 404) tuple would be serialised as a JSON array
        # with status 200; an explicit response sets the real status code.
        return JSONResponse(status_code=404, content={"error": "Job not found"})

    status, progress, transcript, error_message = row
    return {
        "status": status,
        "progress": progress,
        "transcript": transcript,
        "error_message": error_message
    }
183
@app.get("/jobs")
def list_jobs():
    """
    Return every job, newest first, as {"jobs": [...]}. Used for recovery/sync.

    Each entry carries the job's id, status, progress, created_at and
    updated_at; transcripts and error messages are not included.
    """
    # Keys are listed in the same order as the selected columns.
    field_names = ("id", "status", "progress", "created_at", "updated_at")

    cursor = db.execute("""
        SELECT id, status, progress, created_at, updated_at
        FROM whisper_jobs
        ORDER BY created_at DESC
    """)

    return {"jobs": [dict(zip(field_names, record)) for record in cursor.fetchall()]}
206
@app.delete("/transcribe/{job_id}")
def delete_job(job_id: str):
    """
    Delete a job from the database. Used for cleanup.

    Returns {"success": True} when a row was deleted. If no job has this ID,
    responds with HTTP 404 and {"error": "Job not found"}.
    """
    from fastapi.responses import JSONResponse

    result = db.execute("DELETE FROM whisper_jobs WHERE id = ?", (job_id,))
    db.commit()

    if result.rowcount == 0:
        # Returning a (body, 404) tuple would be serialised as a JSON array
        # with status 200; an explicit response sets the real status code.
        return JSONResponse(status_code=404, content={"error": "Job not found"})

    return {"success": True}
219
220
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run as a script.
    import uvicorn
    # Listens on all interfaces, port 8000.
    # NOTE(review): none of the endpoints in this file perform authentication;
    # confirm the server is only reachable from a trusted network.
    uvicorn.run(app, host="0.0.0.0", port=8000)