File size: 3,608 Bytes
a36df50
 
da304d0
44f51cc
 
 
1762ffe
 
 
 
 
 
 
 
 
 
 
44f51cc
 
 
 
 
 
 
 
 
a36df50
44f51cc
 
 
 
 
 
 
 
 
1232794
44f51cc
a36df50
44f51cc
 
 
 
1232794
44f51cc
bfab15d
1232794
44f51cc
23cc1b1
 
 
44f51cc
1232794
44f51cc
bfab15d
1232794
44f51cc
 
 
1232794
44f51cc
4ac2c78
44f51cc
 
 
 
 
 
a36df50
44f51cc
 
 
 
 
 
 
a36df50
44f51cc
 
a36df50
44f51cc
 
a36df50
44f51cc
 
a36df50
 
44f51cc
1232794
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import re

import gradio as gr
import numpy as np
import pandas as pd
import torch
from pydub import AudioSegment
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Load the uploaded model
# Repository id (or local path) that both Whisper objects are resolved from.
repo_name = "ireneminhee/speech-to-depression"
# The processor and the model are loaded independently of each other;
# transcribe() reads both of these module-level names.
processor = WhisperProcessor.from_pretrained(repo_name)
model = WhisperForConditionalGeneration.from_pretrained(repo_name)

# Function that converts speech to text
def transcribe(audio):
    """Transcribe *audio* with the module-level Whisper ``processor``/``model``.

    Args:
        audio: Audio passed unchanged to the processor, which is told that
            the sampling rate is 16000 Hz.

    Returns:
        The first decoded transcription, with special tokens removed.
    """
    features = processor(audio, return_tensors="pt", sampling_rate=16000).input_features
    token_ids = model.generate(features)
    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]

# Load the depression prediction model
def load_model_from_safetensors(model_name, safetensors_path):
    """Build a sequence classifier and load its weights from a checkpoint file.

    Args:
        model_name: Hub id or local path used to resolve the tokenizer, the
            configuration and the initial model weights.
        safetensors_path: Path of the checkpoint whose state dict replaces
            the initial weights. Files ending in ``.safetensors`` are read as
            safetensors; anything else is read with ``torch.load``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in evaluation mode.

    Raises:
        RuntimeError: From ``model.load_state_dict`` when the checkpoint keys
            or tensor shapes do not match the model.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, config=model_name)

    checkpoint_path = str(safetensors_path)
    if checkpoint_path.endswith(".safetensors"):
        # torch.load() only understands pickled checkpoints and fails on the
        # safetensors container, so use the reader that ships with the
        # already-imported transformers package.
        # NOTE(review): this helper is not part of transformers' documented
        # public API - confirm it exists in the pinned transformers version.
        from transformers.modeling_utils import load_state_dict as read_checkpoint

        state_dict = read_checkpoint(checkpoint_path)
    else:
        # Map to CPU so a checkpoint saved on a GPU machine still loads here.
        state_dict = torch.load(checkpoint_path, map_location="cpu")

    model.load_state_dict(state_dict)
    model.eval()
    return model, tokenizer

# Prediction function
def predict_depression(sentences, model, tokenizer):
    """Classify each sentence and return ``(sentence, label)`` pairs.

    Args:
        sentences: Iterable of sentence strings. A single ``str`` is treated
            as one sentence instead of being iterated character by
            character; an empty or whitespace-only ``str`` yields no results.
        model: Callable that accepts the tokenizer output as keyword
            arguments and returns an object with a ``logits`` tensor.
        tokenizer: Callable that turns one sentence into a mapping of model
            inputs.

    Returns:
        A list of ``(sentence, label)`` tuples in input order, where
        ``label`` is the argmax index over the last dimension of the logits.
    """
    if isinstance(sentences, str):
        # Iterating a str would classify every character separately.
        sentences = [sentences] if sentences.strip() else []

    results = []
    # Inference only: no gradients are needed for any of the forward passes.
    with torch.no_grad():
        for sentence in sentences:
            inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
            logits = model(**inputs).logits
            results.append((sentence, torch.argmax(logits, dim=-1).item()))
    return results

# Function that runs the whole pipeline
# A sentence ends at '.', '!' or '?' followed by whitespace.
_SENTENCE_END = re.compile(r"(?<=[.!?])\s+")

# Sampling rate that transcribe() declares to the Whisper processor.
_TARGET_SAMPLE_RATE = 16000

# Checkpoint file and base model of the sentence classifier.
_CLASSIFIER_WEIGHTS = "./model/model.safetensors"
_CLASSIFIER_BASE = "klue/bert-base"

_RESULT_COLUMNS = ["Sentence", "Depression_Prediction"]


def _split_sentences(text):
    """Split *text* at sentence-ending punctuation and drop empty pieces."""
    pieces = (piece.strip() for piece in _SENTENCE_END.split(text))
    return [piece for piece in pieces if piece]


def _to_whisper_waveform(audio_data):
    """Convert a ``(sample_rate, samples)`` pair to 16 kHz mono float32."""
    sample_rate, samples = audio_data
    samples = np.asarray(samples)

    if np.issubdtype(samples.dtype, np.integer):
        # Scale integer PCM into [-1, 1); assumes signed samples - TODO confirm.
        full_scale = float(np.iinfo(samples.dtype).max) + 1.0
        samples = samples.astype(np.float32) / full_scale
    else:
        samples = samples.astype(np.float32)

    if samples.ndim > 1:
        # Assumes a (samples, channels) layout; average the channels to mono.
        samples = samples.mean(axis=1)

    if sample_rate != _TARGET_SAMPLE_RATE and samples.size:
        # Linear interpolation keeps the block free of extra dependencies.
        # NOTE(review): there is no anti-aliasing filter; swap in a proper
        # resampler if transcription quality suffers.
        target_length = max(1, int(round(samples.size * _TARGET_SAMPLE_RATE / sample_rate)))
        target_positions = np.linspace(0, samples.size - 1, num=target_length)
        samples = np.interp(target_positions, np.arange(samples.size), samples)
        samples = samples.astype(np.float32)

    return samples


def _analyze(audio):
    """Transcribe *audio*, classify each sentence, return ``(mean, frame)``."""
    # 1. Speech to text with the Whisper model.
    text = transcribe(audio)

    # 2. predict_depression() expects one entry per sentence, not raw text.
    sentences = _split_sentences(text)

    # 3. Load the sentence classifier.
    # NOTE(review): the classifier is reloaded on every call; cache it if
    # request latency matters.
    classifier, tokenizer = load_model_from_safetensors(_CLASSIFIER_BASE, _CLASSIFIER_WEIGHTS)

    # 4. Predict a label for every sentence.
    results = predict_depression(sentences, classifier, tokenizer)

    # 5. Aggregate: the mean of the predicted class indices.
    df_result = pd.DataFrame(results, columns=_RESULT_COLUMNS)
    if results:
        average = float(df_result["Depression_Prediction"].mean())
    else:
        # Nothing was transcribed, so there is nothing to average.
        average = float("nan")
    return average, df_result


def process_audio_and_predict(audio):
    """Run the full pipeline on *audio* and return a one-line summary.

    Args:
        audio: Audio in whatever form ``transcribe`` accepts.

    Returns:
        ``"Average Depression Probability: <mean>"`` where ``<mean>`` is the
        mean predicted label over all sentences, formatted with two
        decimals (``nan`` when no sentence was transcribed).
    """
    average_probability, _ = _analyze(audio)
    return f"Average Depression Probability: {average_probability:.2f}"


# Callback wired into the Gradio interface
def gradio_process_audio(audio_data):
    """Gradio callback: analyze recorded audio.

    Args:
        audio_data: ``(sample_rate, samples)`` pair as produced by
            ``gr.Audio(type="numpy")``, or ``None`` when nothing was
            recorded.

    Returns:
        A ``(summary, dataframe)`` tuple: the summary string and the
        per-sentence predictions.
    """
    if audio_data is None:
        return "No audio was provided.", pd.DataFrame(columns=_RESULT_COLUMNS)

    # The audio arrives as an in-memory array, so no temporary file is needed.
    waveform = _to_whisper_waveform(audio_data)
    average_probability, df_result = _analyze(waveform)

    return f"Average Depression Probability: {average_probability:.2f}", df_result

# Define the Gradio interface
interface = gr.Interface(
    title="Depression Detection from Audio",
    description="Record your voice, and the model will analyze the text for depression likelihood.",
    # Function Gradio calls for each submission.
    fn=gradio_process_audio,
    # User voice input, delivered to the callback in "numpy" format.
    inputs=gr.Audio(type="numpy"),
    # One output component per value returned by the callback.
    outputs=[
        gr.Textbox(label="Depression Probability"),
        gr.Dataframe(label="Sentence-wise Analysis"),
    ],
)

# Launch the Gradio app
interface.launch(share=True)