import os
import re
import random
import shutil
import tempfile

import gradio as gr
import pandas as pd
import torch
import yaml
from sentence_transformers import SentenceTransformer, util

# ----- File path constants -----
GLOSSARY_FILE = "glossary.md"
INFO_FILE = "info.md"
PERSONA_FILE = "persona.yaml"
CHITCHAT_FILE = "chitchat.yaml"
KEYWORD_MAP_FILE = "keyword_map.yaml"
CEO_VIDEO_FILE = "ceo_video.mp4"


# ----- Utility functions -----
def load_yaml(file_path, default_data=None):
    """Load a YAML file, falling back to default_data (or []) on any error."""
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)
    except Exception:
        return default_data if default_data is not None else []


def parse_knowledge_base(file_path):
    """Parse a markdown knowledge base made of 'Q: ... / A: ...' blocks."""
    if not os.path.exists(file_path):
        return []
    with open(file_path, encoding="utf-8") as f:
        content = f.read()
    # The lookahead stops each answer at the next blank-line-separated "Q:" block.
    blocks = re.findall(r"Q:\s*(.*?)\nA:\s*(.*?)(?=(\n{2,}Q:|\Z))", content, re.DOTALL)
    faqs = []
    for q, a, _ in blocks:
        faqs.append({"question": q.strip(), "answer": a.strip()})
    return faqs


# ----- Data loading -----
persona = load_yaml(PERSONA_FILE, {})
chitchat_map = load_yaml(CHITCHAT_FILE, [])
keyword_map = load_yaml(KEYWORD_MAP_FILE, [])
glossary_base = parse_knowledge_base(GLOSSARY_FILE)
info_base = parse_knowledge_base(INFO_FILE)
glossary_questions = [item["question"] for item in glossary_base]
glossary_answers = [item["answer"] for item in glossary_base]
info_questions = [item["question"] for item in info_base]
info_answers = [item["answer"] for item in info_base]
glossary_keywords = ["Balance Block", "U-Clamp", "Punch", "Bush", "메일먼"]
info_keywords = ["복지", "연봉", "조직문화", "52시간", "주력제품"]
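
# NOTE (illustrative sketch, not files from this repo): the loaders above assume data
# files shaped roughly as follows. The field names are inferred from how the code reads
# them (persona.get("greeting") / persona["style"], chat["keywords"] / chat["answer"],
# "Q:"/"A:" blocks); the real files may contain additional keys.
#
# glossary.md / info.md — blank-line-separated Q/A blocks; a question may carry a
# trailing "# ..." annotation, which clean_question_text() and compare_models() strip:
#
#   Q: Balance Block  # 용어
#   A: (설명 텍스트)
#
#   Q: 복지 제도  # 정보
#   A: (설명 텍스트)
#
# persona.yaml:
#   greeting: "..."
#   style:
#     responses: ["저희 아진산업은 "]
#     closings: ["..."]
#     unknown_answer: "죄송합니다. 정확한 답변이 어렵습니다."
#
# chitchat.yaml — a list of {keywords, answer} entries matched by substring:
#   - keywords: ["안녕", "반가워"]
#     answer: "..."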

# ----- Persona styling -----
def apply_persona_style(text, fallback=False):
    """Wrap an answer with a random persona intro/closing; return the fallback line when nothing matched."""
    style = persona.get("style", {})
    if fallback:
        return style.get("unknown_answer", "죄송합니다. 정확한 답변이 어렵습니다.")
    intro = random.choice(style.get("responses", ["저희 아진산업은 "]))
    closing = random.choice(style.get("closings", [""]))
    return f"{intro}{text}{closing}"


# ----- Model management and loading -----
model_cache = {}


def get_model(name):
    """Return a cached SentenceTransformer, loading it on first use."""
    if name not in model_cache:
        model_cache[name] = SentenceTransformer(name)
    return model_cache[name]


default_model_name = "sentence-transformers/LaBSE"
model = get_model(default_model_name)
glossary_embeddings = model.encode(glossary_questions, convert_to_tensor=True) if glossary_questions else None
info_embeddings = model.encode(info_questions, convert_to_tensor=True) if info_questions else None


# ----- Chatbot responses -----
def best_faq_answer_base(user_question, kb_questions, kb_answers, kb_embeddings):
    """Return the styled answer of the KB question most similar to the user's question."""
    if not user_question.strip() or not kb_questions:
        return ""
    q_emb = model.encode([user_question.strip()], convert_to_tensor=True)
    scores = util.cos_sim(q_emb, kb_embeddings)[0]
    best_idx = int(torch.argmax(scores))
    best_score = float(scores[best_idx])
    if best_score < 0.2:  # below this similarity, use the persona's "unknown" reply
        return apply_persona_style("", fallback=True)
    return apply_persona_style(kb_answers[best_idx])


def find_chitchat(user_question):
    """Return a canned chit-chat answer if any of its keywords appear in the question."""
    uq = user_question.lower()
    for chat in chitchat_map:
        if any(kw in uq for kw in chat.get("keywords", [])):
            return chat["answer"]
    return None


def get_temp_video_copy():
    """Copy the CEO video to a temp file so the player restarts on every response."""
    temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    shutil.copyfile(CEO_VIDEO_FILE, temp_file.name)
    return temp_file.name


def best_faq_answer_with_type(user_question, kb_type):
    user_question = user_question.strip()
    if not user_question:
        return "무엇이 궁금하신지 말씀해 주세요."
    chit = find_chitchat(user_question)
    if chit:
        return apply_persona_style(chit)
    if kb_type == "용어":
        return best_faq_answer_base(user_question, glossary_questions, glossary_answers, glossary_embeddings)
    elif kb_type == "정보":
        return best_faq_answer_base(user_question, info_questions, info_answers, info_embeddings)
    return "검색 유형을 선택해 주세요."


def chat_interface(message, history, kb_type, model_name):
    """Chat handler: chit-chat first, then semantic FAQ search with the selected model."""
    if not message.strip():
        # Three outputs are wired to this handler, so also return a no-op video update.
        return history or [], "", gr.update()
    if chit := find_chitchat(message):
        resp = apply_persona_style(chit)
    else:
        model = get_model(model_name)
        if kb_type == "용어":
            kb_qs, kb_as = glossary_questions, glossary_answers
        else:
            kb_qs, kb_as = info_questions, info_answers
        # Re-encode the KB with the currently selected model (the model can change per request).
        emb = model.encode(kb_qs, convert_to_tensor=True)
        q_emb = model.encode([message.strip()], convert_to_tensor=True)
        scores = util.cos_sim(q_emb, emb)[0]
        best_idx = int(torch.argmax(scores))
        best_score = float(scores[best_idx])
        if best_score < 0.2:
            resp = apply_persona_style("", fallback=True)
        else:
            resp = apply_persona_style(kb_as[best_idx])
    history = history or []
    history.append([message, resp])
    temp_video_path = get_temp_video_copy()
    return history, "", gr.Video(value=temp_video_path, autoplay=True, interactive=False)


# ----- Quiz feature (unchanged) -----
def generate_quiz_set(kb_type):
    """Pick five random Q/A items from the selected knowledge base (empty list if it is too small)."""
    base = glossary_base if kb_type == "용어" else info_base
    if len(base) < 5:
        return []
    return random.sample(base, 5)


def clean_question_text(raw_question):
    """Strip a leading 'Q:' prefix and any trailing '# ...' annotation from a question."""
    cleaned = re.sub(r"^Q:\s*", "", raw_question)
    cleaned = re.sub(r"#.*", "", cleaned)
    return cleaned.strip()


def get_question_display(quiz_set, current_index):
    """Build the question text, four shuffled options, and the correct answer for one quiz item."""
    question = clean_question_text(quiz_set[current_index]["question"])
    correct = quiz_set[current_index]["answer"]
    distractors = random.sample([item["answer"] for item in quiz_set if item["answer"] != correct], k=3)
    options = random.sample([correct] + distractors, k=4)
    return f"{current_index+1}번 문제: {question}", options, correct


def check_quiz_answer(user_answer, correct_answer, score, current_index, quiz_set):
    """Grade one answer and either show the next question or finish the quiz.

    The return order matches the outputs wired to submit_btn.click below:
    (answer_select, result_display, quiz_score, quiz_index, quiz_state,
     question_display, correct_answer, result_image).
    """
    result = "✅ 정답입니다!" if user_answer == correct_answer else f"❌ 오답입니다. 정답은: {correct_answer}"
    new_score = score + 1 if user_answer == correct_answer else score
    if current_index + 1 >= len(quiz_set):
        result_msg = f"퀴즈 종료! {new_score}/{len(quiz_set)} 맞추셨습니다. "
        score_msg = {
            5: "당신의 점수는 제네시스 G90입니다. 완벽합니다!",
            4: "당신의 점수는 그랜저입니다. 거의 완벽해요.",
            3: "당신의 점수는 쏘나타입니다. 괜찮은 편이에요.",
            2: "당신의 점수는 아반떼입니다. 조금만 더 노력해보세요.",
            1: "당신의 점수는 캐스퍼입니다. 다시 도전해보시기 바랍니다.",
            0: "당신의 점수는 캐스퍼입니다. 다시 도전해보시기 바랍니다.",
        }
        result_msg += score_msg.get(new_score, "")
        final_image = f"img/score_{new_score}.jpg"
        return (
            gr.update(visible=False),  # hide the answer radio
            result_msg, 0, 0, [], "", "",
            gr.update(value=final_image, visible=True),
        )
    else:
        next_q, next_opts, next_ans = get_question_display(quiz_set, current_index + 1)
        return (
            gr.update(visible=True, choices=next_opts, value=None),
            result, new_score, current_index + 1, quiz_set,
            next_q, next_ans,
            gr.update(visible=False),
        )


# ----- Model comparison -----
def compare_models(kb_type, selected_models):
    """Self-retrieval check: encode the '#'-cleaned questions as queries against the raw
    questions and report Top-1 / Top-3 answer accuracy for each selected model."""
    if kb_type == "용어":
        qs, ans = glossary_questions, glossary_answers
    else:
        qs, ans = info_questions, info_answers
    qs_clean = [re.sub(r"#.*", "", q).strip() for q in qs]
    records = []
    total = len(qs)
    for m in selected_models:
        model = get_model(m)
        emb = model.encode(qs, convert_to_tensor=True)
        test_emb = model.encode(qs_clean, convert_to_tensor=True)
        sims = util.cos_sim(test_emb, emb)
        top1 = torch.argmax(sims, dim=1).tolist()
        top3 = torch.topk(sims, k=3, dim=1).indices.tolist()
        c1 = c3 = 0
        for i in range(total):
            if ans[top1[i]] == ans[i]:
                c1 += 1
            if ans[i] in {ans[idx] for idx in top3[i]}:
                c3 += 1
        records.append({
            "모델": m,
            "Top-1 맞은 수": c1,
            "Top-1 정확도": f"{c1}/{total} ({c1/total:.2%})",
            "Top-3 맞은 수": c3,
            "Top-3 정확도": f"{c3}/{total} ({c3/total:.2%})",
        })
    return pd.DataFrame(records)


# ----- Gradio UI -----
model_choices = [
    "sentence-transformers/LaBSE",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "sentence-transformers/bert-base-nli-mean-tokens",
    "sentence-transformers/distiluse-base-multilingual-cased-v2",
    "bert-base-uncased",
    "distilbert-base-multilingual-cased",
]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Tab("💬 챗봇"):
        with gr.Row():
            with gr.Column(scale=1, min_width=400):
                video_player = gr.Video(value=CEO_VIDEO_FILE, autoplay=False, interactive=False, height=540)
                type_radio = gr.Radio(choices=["용어", "정보"], value="정보", label="검색 유형")
                model_dropdown = gr.Dropdown(choices=model_choices, value=model_choices[0], label="모델 선택")
                example_dropdown = gr.Dropdown(choices=info_keywords, label="추천 키워드", interactive=True)
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(height=540, value=[[None, persona.get("greeting", "무엇이든 물어보세요.")]])
                with gr.Row():
                    msg_input = gr.Textbox(placeholder="무엇이든 물어보세요.", lines=3, show_label=False)
                    send_btn = gr.Button("전송")
        example_dropdown.change(lambda x: x, inputs=example_dropdown, outputs=msg_input)
        type_radio.change(
            lambda x: gr.Dropdown(choices=glossary_keywords if x == "용어" else info_keywords),
            inputs=type_radio, outputs=example_dropdown,
        )
        send_btn.click(chat_interface, inputs=[msg_input, chatbot, type_radio, model_dropdown],
                       outputs=[chatbot, msg_input, video_player])
        msg_input.submit(chat_interface, inputs=[msg_input, chatbot, type_radio, model_dropdown],
                         outputs=[chatbot, msg_input, video_player])

    with gr.Tab("🎯 퀴즈"):
        quiz_type = gr.Radio(["용어", "정보"], value="용어", label="퀴즈 유형 선택")
        start_btn = gr.Button("퀴즈 5문제 시작")
        question_display = gr.Textbox(label="문제", interactive=False)
        answer_select = gr.Radio(choices=[], label="보기 선택", visible=False)
        submit_btn = gr.Button("정답 제출", visible=False)
        result_display = gr.Textbox(label="결과", interactive=False)
        result_image = gr.Image(visible=False, type="filepath", show_label=False, elem_id="result-image", height=600)
        quiz_state = gr.State([])
        quiz_index = gr.State(0)
        quiz_score = gr.State(0)
        correct_answer = gr.State("")

        def start_quiz(kb_type):
            """Start a five-question quiz; return order matches the outputs of start_btn.click below."""
            quiz_set = generate_quiz_set(kb_type)
            if not quiz_set:
                # The selected KB has fewer than five entries, so no quiz can be built.
                return ([], 0, 0, "", "퀴즈를 만들기에 문항이 부족합니다.",
                        gr.update(visible=False), gr.update(visible=False), gr.update(visible=False))
            q_text, options, correct = get_question_display(quiz_set, 0)
            return (quiz_set, 0, 0, correct, q_text,
                    gr.update(visible=True, choices=options, value=None),
                    gr.update(visible=True),
                    gr.update(visible=False))

        start_btn.click(
            start_quiz,
            inputs=[quiz_type],
            outputs=[quiz_state, quiz_index, quiz_score, correct_answer,
                     question_display, answer_select, submit_btn, result_image],
        )
        submit_btn.click(
            check_quiz_answer,
            inputs=[answer_select, correct_answer, quiz_score, quiz_index, quiz_state],
            outputs=[answer_select, result_display, quiz_score, quiz_index, quiz_state,
                     question_display, correct_answer, result_image],
        )

    with gr.Tab("🛠 모델 비교"):
        cmp_type = gr.Radio(["용어", "정보"], value="용어", label="평가할 KB")
        cmp_models = gr.CheckboxGroup(model_choices, value=[default_model_name], label="비교할 모델들")
        run_cmp = gr.Button("비교 실행")
        cmp_table = gr.DataFrame(interactive=False)
        run_cmp.click(compare_models, inputs=[cmp_type, cmp_models], outputs=[cmp_table])

if __name__ == "__main__":
    demo.launch()
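
# NOTE (collected from the paths referenced above, not an exhaustive manifest): besides the
# Python dependencies implied by the imports (gradio, sentence-transformers, torch, pandas,
# pyyaml), the script expects these files next to it at runtime: glossary.md, info.md,
# persona.yaml, chitchat.yaml, keyword_map.yaml, ceo_video.mp4, and the quiz result images
# img/score_0.jpg ... img/score_5.jpg.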