greg0rs committed
Commit 4adedcb · verified · 1 Parent(s): 11a1daf

Update app.py

Files changed (1)
  1. app.py +33 -13
app.py CHANGED
@@ -6,16 +6,15 @@ from fastapi import FastAPI, UploadFile, File
  from fastapi.middleware.cors import CORSMiddleware
  import torchaudio
  import torch
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 
  # Use writable cache paths
  os.environ['HF_HOME'] = '/app/cache'
  os.environ['TORCH_HOME'] = '/app/cache'
 
- # FastAPI app setup
+ # FastAPI setup
  app = FastAPI()
 
- # CORS: allow frontend from localhost
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["http://localhost:8080"],
@@ -24,11 +23,17 @@ app.add_middleware(
      allow_headers=["*"],
  )
 
- # Load phoneme model + processor
+ # Load models
  try:
-     processor = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
-     model = Wav2Vec2ForCTC.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
-     print("✅ Model loaded successfully.")
+     # Phoneme model
+     phoneme_processor = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
+     phoneme_model = Wav2Vec2ForCTC.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
+
+     # Speech-to-text model
+     stt_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
+     stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
+
+     print("✅ Models loaded successfully.")
  except Exception as e:
      print("❌ Model load error:", str(e))
      raise
@@ -57,18 +62,33 @@ async def transcribe(audio: UploadFile = File(...)):
          waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
          sample_rate = 16000
 
-         input_values = processor(waveform.squeeze(), sampling_rate=sample_rate, return_tensors="pt").input_values
+         # Run phoneme model
+         phoneme_inputs = phoneme_processor(waveform.squeeze(), sampling_rate=sample_rate, return_tensors="pt").input_values
          with torch.no_grad():
-             logits = model(input_values).logits
-         predicted_ids = torch.argmax(logits, dim=-1)
-         phonemes = processor.decode(predicted_ids[0])
+             phoneme_logits = phoneme_model(phoneme_inputs).logits
+         phoneme_ids = torch.argmax(phoneme_logits, dim=-1)
+         phonemes = phoneme_processor.decode(phoneme_ids[0])
 
-         return {"phonemes": phonemes}
+         # Run speech-to-text model
+         stt_inputs = stt_processor(waveform.squeeze(), sampling_rate=sample_rate, return_tensors="pt").input_values
+         with torch.no_grad():
+             stt_logits = stt_model(stt_inputs).logits
+         stt_ids = torch.argmax(stt_logits, dim=-1)
+         transcript = stt_processor.decode(stt_ids[0])
+
+         return {
+             "phonemes": phonemes,
+             "transcript": transcript
+         }
 
      except Exception as e:
          print("❌ Transcription error:", str(e))
-         return {"phonemes": "[Error: " + str(e) + "]"}
+         return {
+             "phonemes": "[Error]",
+             "transcript": "[Error: " + str(e) + "]"
+         }
 
  @app.get("/")
  def root():
      return {"message": "Backend is running"}
+
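
For quick verification of the change, here is a minimal client sketch that exercises the updated endpoint and its new two-field response. The multipart field name "audio" matches the handler signature async def transcribe(audio: UploadFile = File(...)) shown in the hunk header; the URL path /transcribe and the host/port are assumptions, since the route decorator falls outside the hunks in this diff.

import requests

# "sample.wav" is a hypothetical local test file.
with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/transcribe",  # assumed host, port, and route path
        files={"audio": ("sample.wav", f, "audio/wav")},
    )

payload = resp.json()
# Both keys are present even on failure: the handler returns
# {"phonemes": "[Error]", "transcript": "[Error: ...]"} rather than a 500.
print("phonemes:  ", payload["phonemes"])
print("transcript:", payload["transcript"])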