greg0rs committed on
Commit
2db3ee9
·
verified ·
1 Parent(s): e320304

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -3
app.py CHANGED
@@ -1,7 +1,76 @@
1
- from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  @app.get("/")
6
- def greet_json():
7
- return {"Hello": "World!"}
 
 
1
+ import os
2
+ import io
3
+ import subprocess
4
+
5
+ from fastapi import FastAPI, UploadFile, File
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ import torchaudio
8
+ import torch
9
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
10
+
11
# Use writable cache paths so model/processor downloads succeed in a
# container whose default HF cache dir is read-only (e.g. a Spaces image
# where only /app/cache is writable).
# NOTE(review): TRANSFORMERS_CACHE is reportedly superseded by HF_HOME in
# newer transformers releases — confirm before removing either line.
os.environ['TRANSFORMERS_CACHE'] = '/app/cache'
os.environ['HF_HOME'] = '/app/cache'
os.environ['TORCH_HOME'] = '/app/cache'

app = FastAPI()

# CORS config: allow only the local dev frontend origin to call this API
# (credentials enabled, all methods/headers permitted for that origin).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8080"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model + processor once at import time so every request reuses them;
# first run downloads the weights into the cache dirs configured above.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
30
+
31
def convert_webm_to_wav(webm_bytes: bytes) -> io.BytesIO:
    """Convert in-memory WebM audio to WAV via ffmpeg, fully in pipes.

    Args:
        webm_bytes: Raw bytes of a WebM (or any ffmpeg-readable) audio file.

    Returns:
        io.BytesIO containing the WAV-encoded audio, positioned at offset 0.

    Raises:
        RuntimeError: If ffmpeg exits non-zero. The message now includes
            ffmpeg's stderr so API callers can see *why* conversion failed
            (previously the detail was only printed server-side).
    """
    # argv list with shell=False (the default): no shell-injection surface,
    # and the audio never touches disk — stdin/stdout pipes only.
    process = subprocess.run(
        ["ffmpeg", "-i", "pipe:0", "-f", "wav", "pipe:1"],
        input=webm_bytes,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,  # capture diagnostics instead of inheriting the server's stderr
    )

    if process.returncode != 0:
        # errors="replace" so a non-UTF-8 byte in ffmpeg's output can't turn
        # a conversion failure into a UnicodeDecodeError.
        error_detail = process.stderr.decode(errors="replace")
        print("❌ ffmpeg error:", error_detail)
        # Surface ffmpeg's own diagnostics (truncated) in the exception.
        raise RuntimeError(f"ffmpeg conversion failed: {error_detail.strip()[:500]}")

    return io.BytesIO(process.stdout)
44
+
45
+
46
@app.post("/api/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio clip (WebM from the browser recorder)
    with the module-level wav2vec2 model.

    Args:
        audio: Multipart file upload; any container/codec ffmpeg can read.

    Returns:
        dict: ``{"phonemes": <decoded text>}``. NOTE(review): the model
        (wav2vec2-base-960h) emits characters/words, not phonemes — the key
        name is kept for backward compatibility with the frontend.

    Raises:
        RuntimeError: Propagated from convert_webm_to_wav on ffmpeg failure.
    """
    # Read the whole upload into memory.
    contents = await audio.read()

    # Convert to WAV in-memory so torchaudio can decode it.
    wav_io = convert_webm_to_wav(contents)

    # torchaudio.load returns (channels, samples) plus the native sample rate.
    waveform, sample_rate = torchaudio.load(wav_io)

    # Downmix to mono. The previous `.squeeze()` left stereo input as shape
    # (2, samples), which the processor/model mishandles; averaging channels
    # fixes stereo uploads and is a no-op reshape for mono (1, samples).
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)

    # wav2vec2-base-960h expects 16 kHz input; resample if needed.
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
        sample_rate = 16000

    # Feature-extract and run the model without tracking gradients.
    input_values = processor(waveform, sampling_rate=sample_rate, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)

    # Greedy CTC decode of the best path to text.
    transcription = processor.decode(predicted_ids[0])

    return {"phonemes": transcription}
72
+
73
@app.get("/")
def root():
    """Health-check endpoint: confirms the backend is up."""
    status_message = "Backend is running"
    return {"message": status_message}
76
+