hyeongnym committed on
Commit
4f7e18c
· verified ·
1 Parent(s): f7c2d07

Create app.py

Files changed (1)
  1. app.py +343 -0
app.py ADDED
@@ -0,0 +1,343 @@
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os, gc, logging
from threading import Thread
from datasets import load_dataset
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import pandas as pd
from typing import Tuple, Iterator
from datetime import datetime
from functools import lru_cache
from pdfminer.high_level import extract_text
from pdfminer.layout import LAParams
from pydantic import BaseModel
import unittest

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('app.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Configuration class
class Config:
    def __init__(self):
        self.MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
        self.MAX_HISTORY = 10
        self.MAX_TOKENS = 4096
        self.DEFAULT_TEMPERATURE = 0.8
        self.HF_TOKEN = os.environ.get("HF_TOKEN", None)
        self.MODELS = os.environ.get("MODELS")

config = Config()
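
# NOTE (editor's sketch): `model`, `tokenizer`, and `current_file_context` are used
# as globals below but were never initialized in this commit; without these lines the
# first call to stream_chat raises NameError. Loading the tokenizer eagerly here is
# an assumption about the intended setup.
model = None
current_file_context = None
tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID, token=config.HF_TOKEN)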

# Custom exception class
class FileProcessingError(Exception):
    pass

# Response model
class ChatResponse(BaseModel):
    message: str
    status: str
    timestamp: datetime

# File processing class
class FileProcessor:
    @staticmethod
    def process_pdf(file_path):
        try:
            # Extract text with layout-aware settings via pdfminer
            text = extract_text(
                file_path,
                laparams=LAParams(
                    line_margin=0.5,
                    word_margin=0.1,
                    char_margin=2.0,
                    all_texts=True
                )
            )
            return text
        except Exception as e:
            raise FileProcessingError(f"PDF processing error: {str(e)}")

    @staticmethod
    def process_csv(file_path):
        try:
            # Try common encodings for Korean and Western CSV files
            encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
            for encoding in encodings:
                try:
                    return pd.read_csv(file_path, encoding=encoding)
                except UnicodeDecodeError:
                    continue
            raise FileProcessingError("Unable to read CSV with supported encodings")
        except Exception as e:
            raise FileProcessingError(f"CSV processing error: {str(e)}")

# Memory management
@torch.no_grad()
def clear_cuda_memory():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

# Model loading
@spaces.GPU
def load_model():
    try:
        model = AutoModelForCausalLM.from_pretrained(
            config.MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        return model
    except Exception as e:
        logger.error(f"Model loading error: {str(e)}")
        raise

# Context search
@lru_cache(maxsize=100)
def find_relevant_context(query, top_k=3):
    try:
        # Score the query against the pre-vectorized dataset questions
        query_vector = vectorizer.transform([query])
        similarities = (query_vector * question_vectors.T).toarray()[0]
        top_indices = np.argsort(similarities)[-top_k:][::-1]

        relevant_contexts = []
        for idx in top_indices:
            if similarities[idx] > 0:
                relevant_contexts.append({
                    'question': questions[idx],
                    'answer': wiki_dataset['train']['answer'][idx],
                    'similarity': similarities[idx]
                })
        return relevant_contexts
    except Exception as e:
        logger.error(f"Context search error: {str(e)}")
        return []
132
+
133
+ # ์ŠคํŠธ๋ฆฌ๋ฐ ์ฑ„ํŒ…
134
+ @spaces.GPU
135
+ def stream_chat(message: str, history: list, uploaded_file, temperature: float,
136
+ max_new_tokens: int, top_p: float, top_k: int, penalty: float) -> Iterator[Tuple[str, list]]:
137
+ """
138
+ ์ŠคํŠธ๋ฆฌ๋ฐ ์ฑ„ํŒ… ์‘๋‹ต์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
139
+
140
+ Args:
141
+ message (str): ์‚ฌ์šฉ์ž ์ž…๋ ฅ ๋ฉ”์‹œ์ง€
142
+ history (list): ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ
143
+ uploaded_file: ์—…๋กœ๋“œ๋œ ํŒŒ์ผ
144
+ temperature (float): ์ƒ์„ฑ ์˜จ๋„
145
+ max_new_tokens (int): ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜
146
+ top_p (float): ์ƒ์œ„ p ์ƒ˜ํ”Œ๋ง
147
+ top_k (int): ์ƒ์œ„ k ์ƒ˜ํ”Œ๋ง
148
+ penalty (float): ๋ฐ˜๋ณต ํŽ˜๋„ํ‹ฐ
149
+
150
+ Returns:
151
+ Iterator[Tuple[str, list]]: ์ƒ์„ฑ๋œ ์‘๋‹ต๊ณผ ์—…๋ฐ์ดํŠธ๋œ ํžˆ์Šคํ† ๋ฆฌ
152
+ """
153
+ global model, current_file_context
154
+
155
+ try:
156
+ if model is None:
157
+ model = load_model()
158
+
159
+ logger.info(f'Processing message: {message}')
160
+ logger.debug(f'History length: {len(history)}')
161
+
162
+ # ํŒŒ์ผ ์ฒ˜๋ฆฌ
163
+ file_context = ""
164
+ if uploaded_file:
165
+ try:
166
+ file_ext = os.path.splitext(uploaded_file.name)[1].lower()
167
+ if file_ext == '.pdf':
168
+ content = FileProcessor.process_pdf(uploaded_file.name)
169
+ elif file_ext == '.csv':
170
+ content = FileProcessor.process_csv(uploaded_file.name)
171
+ else:
172
+ content = safe_file_read(uploaded_file.name)
173
+
174
+ file_context = analyze_file_content(content, file_ext)
175
+ current_file_context = file_context
176
+ except Exception as e:
177
+ logger.error(f"File processing error: {str(e)}")
178
+ file_context = f"\n\nโŒ File analysis error: {str(e)}"
179
+
180
+ # ์ปจํ…์ŠคํŠธ ๊ฒ€์ƒ‰ ๋ฐ ํ”„๋กฌํ”„ํŠธ ๊ตฌ์„ฑ
181
+ relevant_contexts = find_relevant_context(message)
182
+ wiki_context = "\n\n๊ด€๋ จ ์œ„ํ‚คํ”ผ๋””์•„ ์ •๋ณด:\n" + "\n".join([
183
+ f"Q: {ctx['question']}\nA: {ctx['answer']}\n์œ ์‚ฌ๋„: {ctx['similarity']:.3f}"
184
+ for ctx in relevant_contexts
185
+ ])
186
+
187
+ # ํ† ํฐํ™” ๋ฐ ์ƒ์„ฑ
188
+ conversation = [
189
+ {"role": "user" if i % 2 == 0 else "assistant", "content": msg}
190
+ for hist in history[-config.MAX_HISTORY:]
191
+ for i, msg in enumerate(hist)
192
+ ]
193
+
194
+ final_message = f"{file_context}{wiki_context}\nํ˜„์žฌ ์งˆ๋ฌธ: {message}"
195
+ conversation.append({"role": "user", "content": final_message})
196
+
197
+ inputs = tokenizer(
198
+ tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True),
199
+ return_tensors="pt"
200
+ ).to("cuda")
201
+
202
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
203
+
204
+ generate_kwargs = dict(
205
+ inputs,
206
+ streamer=streamer,
207
+ top_k=top_k,
208
+ top_p=top_p,
209
+ repetition_penalty=penalty,
210
+ max_new_tokens=min(max_new_tokens, 2048),
211
+ do_sample=True,
212
+ temperature=temperature,
213
+ eos_token_id=[255001],
214
+ )
215
+
216
+ clear_cuda_memory()
217
+
218
+ thread = Thread(target=model.generate, kwargs=generate_kwargs)
219
+ thread.start()
220
+
221
+ buffer = ""
222
+ for new_text in streamer:
223
+ buffer += new_text
224
+ yield "", history + [[message, buffer]]
225
+
226
+ clear_cuda_memory()
227
+
228
+ except Exception as e:
229
+ logger.error(f"Stream chat error: {str(e)}")
230
+ yield "", history + [[message, f"Error: {str(e)}"]]
231
+ clear_cuda_memory()
232
+
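
# NOTE (editor's sketch): create_demo below references `UPDATED_CSS`, which is not
# defined anywhere in this commit; an empty stylesheet is a safe placeholder for the
# styling that presumably targeted the custom element classes used in the layout.
UPDATED_CSS = ""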

# UI creation
def create_demo():
    with gr.Blocks(css=UPDATED_CSS) as demo:
        # UI component layout
        with gr.Column(elem_classes="markdown-style"):
            gr.Markdown("""
            # 🤖 RAGOndevice
            #### 📊 RAG: Upload and Analyze Files (TXT, CSV, PDF, Parquet files)
            Upload your files for data analysis and learning
            """)

        chatbot = gr.Chatbot(
            value=[],
            height=600,
            label="GiniGEN AI Assistant",
            elem_classes="chat-container"
        )

        # Input components
        with gr.Row(elem_classes="input-container"):
            with gr.Column(scale=1, min_width=70):
                file_upload = gr.File(
                    type="filepath",
                    elem_classes="file-upload-icon",
                    scale=1,
                    container=True,
                    interactive=True,
                    show_label=False
                )

            with gr.Column(scale=3):
                msg = gr.Textbox(
                    show_label=False,
                    placeholder="Type your message here... 💭",
                    container=False,
                    elem_classes="input-textbox",
                    scale=1
                )

            with gr.Column(scale=1, min_width=70):
                send = gr.Button(
                    "Send",
                    elem_classes="send-button custom-button",
                    scale=1
                )

            with gr.Column(scale=1, min_width=70):
                clear = gr.Button(
                    "Clear",
                    elem_classes="clear-button custom-button",
                    scale=1
                )

        # Advanced settings
        with gr.Accordion("🎮 Advanced Settings", open=False):
            with gr.Row():
                with gr.Column(scale=1):
                    temperature = gr.Slider(
                        minimum=0, maximum=1, step=0.1, value=config.DEFAULT_TEMPERATURE,
                        label="Creativity Level 🎨"
                    )
                    max_new_tokens = gr.Slider(
                        minimum=128, maximum=8000, step=1, value=4000,
                        label="Maximum Token Count 📝"
                    )
                with gr.Column(scale=1):
                    top_p = gr.Slider(
                        minimum=0.0, maximum=1.0, step=0.1, value=0.8,
                        label="Diversity Control 🎯"
                    )
                    top_k = gr.Slider(
                        minimum=1, maximum=20, step=1, value=20,
                        label="Selection Range 📊"
                    )
                    penalty = gr.Slider(
                        minimum=0.0, maximum=2.0, step=0.1, value=1.0,
                        label="Repetition Penalty 🔄"
                    )

        # Event bindings
        msg.submit(stream_chat, [msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty], [msg, chatbot])
        send.click(stream_chat, [msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty], [msg, chatbot])
        clear.click(lambda: ([], None, ""), outputs=[chatbot, file_upload, msg])

    return demo

# Main entry point
if __name__ == "__main__":
    # Load the Korean Wikipedia QnA dataset
    wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
    logger.info("Wikipedia dataset loaded")

    # Initialize the TF-IDF vectorizer over the first 10,000 questions
    questions = wiki_dataset['train']['question'][:10000]
    vectorizer = TfidfVectorizer(max_features=1000)
    question_vectors = vectorizer.fit_transform(questions)
    logger.info("TF-IDF vectorization completed")

    # Launch the UI
    demo = create_demo()
    demo.launch()

# Test code
class TestChatBot(unittest.TestCase):
    def test_file_processing(self):
        # Minimal sketch for the empty stub: process_csv should round-trip
        # a small CSV into a DataFrame.
        import tempfile
        with tempfile.NamedTemporaryFile('w', suffix='.csv', delete=False) as f:
            f.write("a,b\n1,2\n")
            path = f.name
        df = FileProcessor.process_csv(path)
        self.assertEqual(df.shape, (1, 2))
        os.unlink(path)

    def test_context_search(self):
        # Minimal sketch for the empty stub: find_relevant_context degrades to
        # an empty list on error (e.g., when the TF-IDF globals are missing).
        self.assertIsInstance(find_relevant_context("test query"), list)