rgenius committed
Commit c7a1dbc
Parent: 0ea462d

Init commit

Files changed (2):
  1. app.py +356 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,356 @@
import gradio as gr

from uuid import uuid4
from huggingface_hub import snapshot_download
from langchain.document_loaders import (
    CSVLoader,
    EverNoteLoader,
    PDFMinerLoader,
    TextLoader,
    UnstructuredEmailLoader,
    UnstructuredEPubLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from chromadb.config import Settings
from llama_cpp import Llama

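# System prompt for Saiga (Russian): "You are Saiga, a Russian-language automatic
# assistant. You talk to people and help them." The ids below are the hard-coded
# tokens for the role markers ("system", "user", "bot") and the newline token in
# the Saiga/LLaMA vocabulary.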
SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."
SYSTEM_TOKEN = 1788
USER_TOKEN = 1404
BOT_TOKEN = 9225
LINEBREAK_TOKEN = 13

ROLE_TOKENS = {
    "user": USER_TOKEN,
    "bot": BOT_TOKEN,
    "system": SYSTEM_TOKEN,
}

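# Supported input formats: each extension maps to a LangChain document loader
# class and its constructor kwargs.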
LOADER_MAPPING = {
    ".csv": (CSVLoader, {}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PDFMinerLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
}

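# Fetch the 4-bit quantized Saiga 13B weights from the Hugging Face Hub and load
# them with llama.cpp; a multilingual sentence-transformers model produces the
# embeddings for retrieval.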
repo_name = "IlyaGusev/saiga_13b_lora_llamacpp"
model_name = "ggml-model-q4_1.bin"
embedder_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

snapshot_download(repo_id=repo_name, local_dir=".", allow_patterns=model_name)

model = Llama(
    model_path=model_name,
    n_ctx=2000,
    n_parts=1,
)

max_new_tokens = 1500
embeddings = HuggingFaceEmbeddings(model_name=embedder_name)

def get_uuid():
    return str(uuid4())

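# Load one file as a LangChain Document, using the loader registered for its
# extension in LOADER_MAPPING.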
def load_single_document(file_path: str) -> Document:
    ext = "." + file_path.rsplit(".", 1)[-1]
    assert ext in LOADER_MAPPING
    loader_class, loader_args = LOADER_MAPPING[ext]
    loader = loader_class(file_path, **loader_args)
    return loader.load()[0]

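# Encode a chat message in the Saiga format. tokenize() prepends BOS, the role
# token and a newline are spliced in right after it, and EOS closes the message:
# <s> {role} \n {content} </s>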
def get_message_tokens(model, role, content):
    message_tokens = model.tokenize(content.encode("utf-8"))
    message_tokens.insert(1, ROLE_TOKENS[role])
    message_tokens.insert(2, LINEBREAK_TOKEN)
    message_tokens.append(model.token_eos())
    return message_tokens

def get_system_tokens(model):
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    return get_message_tokens(model, **system_message)

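# Gradio passes uploaded files as tempfile wrappers; keep only their paths.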
def upload_files(files, file_paths):
    file_paths = [f.name for f in files]
    return file_paths

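# Clean a chunk: drop near-empty lines and discard fragments shorter than 10
# characters.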
def process_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if len(line.strip()) > 2]
    text = "\n".join(lines).strip()
    if len(text) < 10:
        return None
    return text

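# Build the retrieval index: load each uploaded file, split it into overlapping
# chunks, clean every chunk, embed them into an in-memory Chroma store, and
# report (in Russian) how many fragments were loaded.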
def build_index(file_paths, db, chunk_size, chunk_overlap, file_warning):
    documents = [load_single_document(path) for path in file_paths]
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    documents = text_splitter.split_documents(documents)
    fixed_documents = []
    for doc in documents:
        doc.page_content = process_text(doc.page_content)
        if not doc.page_content:
            continue
        fixed_documents.append(doc)

    db = Chroma.from_documents(
        fixed_documents,
        embeddings,
        client_settings=Settings(
            anonymized_telemetry=False
        )
    )
    file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы."
    return db, file_warning

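# Chat callbacks: user() appends the message to the history, retrieve() pulls the
# top-k matching fragments from the Chroma index, and bot() streams the answer.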
130
+ def user(message, history, system_prompt):
131
+ new_history = history + [[message, None]]
132
+ return "", new_history
133
+
134
+
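# Fetch the k most similar fragments for the last user message; if no index has
# been built yet, the textbox value is passed through unchanged.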
def retrieve(history, db, retrieved_docs, k_documents):
    if db:
        last_user_message = history[-1][0]
        retriever = db.as_retriever(search_kwargs={"k": k_documents})
        docs = retriever.get_relevant_documents(last_user_message)
        retrieved_docs = "\n\n".join([doc.page_content for doc in docs])
    return retrieved_docs

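# Assemble the full Saiga prompt (system message, prior turns, and the last
# question wrapped with the retrieved context), then stream tokens into the chat.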
def bot(history, system_prompt, conversation_id, retrieved_docs, top_p, top_k, temp):
    if not history:
        return

    tokens = get_system_tokens(model)[:]
    tokens.append(LINEBREAK_TOKEN)

    for user_message, bot_message in history[:-1]:
        message_tokens = get_message_tokens(model=model, role="user", content=user_message)
        tokens.extend(message_tokens)
        if bot_message:
            message_tokens = get_message_tokens(model=model, role="bot", content=bot_message)
            tokens.extend(message_tokens)

    last_user_message = history[-1][0]
    if retrieved_docs:
        last_user_message = f"Контекст: {retrieved_docs}\n\nИспользуя контекст, ответь на вопрос: {last_user_message}"
    message_tokens = get_message_tokens(model=model, role="user", content=last_user_message)
    tokens.extend(message_tokens)

    role_tokens = [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
    tokens.extend(role_tokens)
    generator = model.generate(
        tokens,
        top_k=top_k,
        top_p=top_p,
        temp=temp
    )

    partial_text = ""
    for i, token in enumerate(generator):
        if token == model.token_eos() or (max_new_tokens is not None and i >= max_new_tokens):
            break
        partial_text += model.detokenize([token]).decode("utf-8", "ignore")
        history[-1][1] = partial_text
        yield history

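# UI: file upload and chunking controls on top; retrieved fragments, the chat
# window, sampling controls, and the send/stop/clear buttons below.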
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    db = gr.State(None)
    conversation_id = gr.State(get_uuid)
    favicon = '<img src="https://cdn.midjourney.com/b88e5beb-6324-4820-8504-a1a37a9ba36d/0_1.png" width="48px" style="display: inline">'
    gr.Markdown(
        f"""<h1><center>{favicon}Saiga 13B llama.cpp: retrieval QA</center></h1>
        """
    )

    with gr.Row():
        with gr.Column(scale=5):
            file_output = gr.File(file_count="multiple", label="Загрузка файлов")
            file_paths = gr.State([])
            file_warning = gr.Markdown("Фрагменты ещё не загружены!")

        with gr.Column(min_width=200, scale=3):
            with gr.Tab(label="Параметры нарезки"):
                chunk_size = gr.Slider(
                    minimum=50,
                    maximum=2000,
                    value=250,
                    step=50,
                    interactive=True,
                    label="Размер фрагментов",
                )
                chunk_overlap = gr.Slider(
                    minimum=0,
                    maximum=500,
                    value=30,
                    step=10,
                    interactive=True,
                    label="Пересечение"
                )

    with gr.Row():
        k_documents = gr.Slider(
            minimum=1,
            maximum=10,
            value=2,
            step=1,
            interactive=True,
            label="Кол-во фрагментов для контекста"
        )
    with gr.Row():
        retrieved_docs = gr.Textbox(
            lines=6,
            label="Извлеченные фрагменты",
            placeholder="Появятся после задавания вопросов",
            interactive=False
        )
    with gr.Row():
        with gr.Column(scale=5):
            system_prompt = gr.Textbox(label="Системный промпт", placeholder="", value=SYSTEM_PROMPT, interactive=False)
            chatbot = gr.Chatbot(label="Диалог").style(height=400)
        with gr.Column(min_width=80, scale=1):
            with gr.Tab(label="Параметры генерации"):
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    interactive=True,
                    label="Top-p",
                )
                top_k = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=30,
                    step=5,
                    interactive=True,
                    label="Top-k",
                )
                temp = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.1,
                    step=0.1,
                    interactive=True,
                    label="Temp"
                )

    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Отправить сообщение",
                placeholder="Отправить сообщение",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Отправить")
                stop = gr.Button("Остановить")
                clear = gr.Button("Очистить")

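    # Event wiring: each user action is a .success() chain — append the message,
    # retrieve context, then stream the bot reply; each step runs only if the
    # previous one finished without errors.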
    # Upload files
    upload_event = file_output.change(
        fn=upload_files,
        inputs=[file_output, file_paths],
        outputs=[file_paths],
        queue=True,
    ).success(
        fn=build_index,
        inputs=[file_paths, db, chunk_size, chunk_overlap, file_warning],
        outputs=[db, file_warning],
        queue=True
    )

    # Pressing Enter
    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot, system_prompt],
        outputs=[msg, chatbot],
        queue=False,
    ).success(
        fn=retrieve,
        inputs=[chatbot, db, retrieved_docs, k_documents],
        outputs=[retrieved_docs],
        queue=True,
    ).success(
        fn=bot,
        inputs=[
            chatbot,
            system_prompt,
            conversation_id,
            retrieved_docs,
            top_p,
            top_k,
            temp
        ],
        outputs=chatbot,
        queue=True,
    )

    # Pressing the button
    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot, system_prompt],
        outputs=[msg, chatbot],
        queue=False,
    ).success(
        fn=retrieve,
        inputs=[chatbot, db, retrieved_docs, k_documents],
        outputs=[retrieved_docs],
        queue=True,
    ).success(
        fn=bot,
        inputs=[
            chatbot,
            system_prompt,
            conversation_id,
            retrieved_docs,
            top_p,
            top_k,
            temp
        ],
        outputs=chatbot,
        queue=True,
    )

    # Stop generation
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )

    # Clear history
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue(max_size=128, concurrency_count=1)
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
llama-cpp-python==0.1.53
langchain==0.0.174
huggingface-hub==0.14.1
chromadb==0.3.23
pdfminer.six==20221105
unstructured==0.6.10
gradio==3.32.0
tabulate