maliahson committed
Commit 71f163e · verified · 1 Parent(s): 557422f

Create app.py

Files changed (1)
app.py  +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
+ import torch
+ from transformers import pipeline
+ import librosa
+ from datetime import datetime
+ from deep_translator import GoogleTranslator
+ from typing import Dict, Union
+ from gliner import GLiNER
+ import gradio as gr
+
+ # Select GPU (device 0) if available, otherwise fall back to CPU
+ device = 0 if torch.cuda.is_available() else "cpu"
+
+ # Load the transcription model
+ whisper_pipeline_agri = pipeline("automatic-speech-recognition", model="maliahson/whisper-agri", device=device)
+
+ # Initialize GLiNER for information extraction
+ gliner_model = GLiNER.from_pretrained("xomad/gliner-model-merge-large-v1.0").to("cpu")
+
+ def merge_entities(entities):
+     """Merge adjacent entities of the same type into a single span."""
+     if not entities:
+         return []
+     merged = []
+     current = entities[0]
+     for next_entity in entities[1:]:
+         # Merge when the next entity has the same label and starts where the
+         # current one ends (or one character after it)
+         if next_entity['entity'] == current['entity'] and (next_entity['start'] == current['end'] + 1 or next_entity['start'] == current['end']):
+             current['word'] += ' ' + next_entity['word']
+             current['end'] = next_entity['end']
+         else:
+             merged.append(current)
+             current = next_entity
+     merged.append(current)
+     return merged
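+
+ # Worked example (hypothetical spans): given two adjacent "match" entities
+ #   {'entity': 'match', 'word': 'wheat', 'start': 0, 'end': 5}
+ #   {'entity': 'match', 'word': 'crop', 'start': 6, 'end': 10}
+ # merge_entities combines them into a single span:
+ #   {'entity': 'match', 'word': 'wheat crop', 'start': 0, 'end': 10}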
+
+ def transcribe_audio(audio_path):
+     """
+     Transcribe a local audio file using the Whisper pipeline and log the elapsed time.
+     """
+     try:
+         start_time = datetime.now()
+
+         # Ensure audio is mono and resampled to 16 kHz, as Whisper expects
+         audio, sr = librosa.load(audio_path, sr=16000, mono=True)
+
+         # Perform transcription
+         transcription = whisper_pipeline_agri(audio, batch_size=8)["text"]
+
+         # Log elapsed time
+         print(f"Transcription finished in {(datetime.now() - start_time).total_seconds():.2f}s")
+
+         return transcription
+
+     except Exception as e:
+         return f"Error processing audio: {e}"
+
+ def translate_text_to_english(text):
+     """
+     Translate text into English using GoogleTranslator.
+     """
+     try:
+         translated_text = GoogleTranslator(source='auto', target='en').translate(text)
+         return translated_text
+     except Exception as e:
+         return f"Error during translation: {e}"
+
+ def extract_information(prompt: str, text: str, threshold: float, nested_ner: bool) -> Dict[str, Union[str, int, float]]:
+     """
+     Extract entities from the English text using the GLiNER model.
+     """
+     try:
+         # Prepend the prompt so GLiNER can match spans against it
+         text = prompt + "\n" + text
+         entities = [
+             {
+                 "entity": entity["label"],
+                 "word": entity["text"],
+                 "start": entity["start"],
+                 "end": entity["end"],
+                 "score": entity.get("score", 0),
+             }
+             for entity in gliner_model.predict_entities(
+                 text, ["match"], flat_ner=not nested_ner, threshold=threshold
+             )
+         ]
+         merged_entities = merge_entities(entities)
+         return {"text": text, "entities": merged_entities}
+     except Exception as e:
+         return {"error": f"Information extraction failed: {e}"}
+
+ def pipeline_fn(audio, prompt, threshold, nested_ner):
+     """
+     Combine transcription, translation, and information extraction in a single pipeline.
+     """
+     transcription = transcribe_audio(audio)
+     if "Error" in transcription:
+         # Return three values to match the three Gradio output components
+         return transcription, "", {}
+
+     translated_text = translate_text_to_english(transcription)
+     if "Error" in translated_text:
+         return transcription, translated_text, {}
+
+     info_extraction = extract_information(prompt, translated_text, threshold, nested_ner)
+     return transcription, translated_text, info_extraction
+
+ # Gradio interface
+ with gr.Blocks(title="Audio Processing and Information Extraction") as interface:
+     gr.Markdown("## Audio Transcription, Translation, and Information Extraction")
+
+     with gr.Row():
+         audio_input = gr.Audio(type="filepath", label="Upload Audio File")
+         prompt_input = gr.Textbox(label="Prompt for Information Extraction", placeholder="Enter your prompt here")
+
+     with gr.Row():
+         threshold_slider = gr.Slider(0, 1, value=0.3, step=0.01, label="NER Threshold")
+         nested_ner_checkbox = gr.Checkbox(label="Enable Nested NER")
+
+     with gr.Row():
+         transcription_output = gr.Textbox(label="Transcription (Urdu)", interactive=False)
+         translation_output = gr.Textbox(label="Translation (English)", interactive=False)
+
+     with gr.Row():
+         extraction_output = gr.HighlightedText(label="Extracted Information")
+
+     process_button = gr.Button("Process Audio")
+
+     process_button.click(
+         fn=pipeline_fn,
+         inputs=[audio_input, prompt_input, threshold_slider, nested_ner_checkbox],
+         outputs=[transcription_output, translation_output, extraction_output],
+     )
+
+ if __name__ == "__main__":
+     interface.launch()
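+
+ # To run locally: `python app.py`; Gradio serves the UI on http://127.0.0.1:7860
+ # by default. The pipeline can also be exercised directly, bypassing the UI
+ # (hypothetical file path): pipeline_fn("sample.wav", "crop names", 0.3, False)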