llama3 chatQA
KonradSzafer committed · Commit 19fc5ee · Parent(s): 34ac5d3
Files changed:
- .github/workflows/tests-integration.yml +1 -0
- config/prompt_templates/llama3_chat.txt +7 -0
- config/prompt_templates/phi3.txt +1 -0
- qa_engine/config.py +1 -1
- qa_engine/logger.py +2 -0
- qa_engine/qa_engine.py +54 -127
- requirements.txt +0 -1
.github/workflows/tests-integration.yml CHANGED

@@ -30,6 +30,7 @@ jobs:
 
       - name: Install dependencies
         run: |
+          pip3 install --upgrade pip
           pip install --no-cache-dir -r requirements.txt
           cp config/.env.example config/.env
       - name: Run unit tests
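The only functional change here is that pip itself is upgraded before the project dependencies are installed. For reference, the same steps can be reproduced outside CI; the sketch below simply mirrors the workflow commands from a Python script run at the repository root (an illustration, not part of the workflow):

import shutil
import subprocess
import sys

# Mirror the CI 'Install dependencies' step (run from the repository root).
subprocess.run([sys.executable, '-m', 'pip', 'install', '--upgrade', 'pip'], check=True)
subprocess.run([sys.executable, '-m', 'pip', 'install', '--no-cache-dir', '-r', 'requirements.txt'], check=True)
# The workflow copies the example env file so the tests have a config to read.
shutil.copy('config/.env.example', 'config/.env')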
config/prompt_templates/llama3_chat.txt ADDED

@@ -0,0 +1,7 @@
+System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context.
+
+{context}
+
+User: {question} Please give a full and complete answer for the question.
+
+Assistant:
config/prompt_templates/phi3.txt ADDED

@@ -0,0 +1 @@
+<|user|>\n Try to be as factual as possible. Answer a question from a given context. CONTEXT: {context} \n QUESTION: {question} <|end|>\n<|assistant|>
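Both new templates expose the same two placeholders, {context} and {question}. As a rough illustration of how such a file could be rendered before being passed to a model, assuming plain str.format substitution (the engine may wire the template in differently), consider:

from pathlib import Path

# Illustration only: fill the template's {context} and {question} placeholders.
# The context/question values below are placeholders, not real data.
template = Path('config/prompt_templates/llama3_chat.txt').read_text()
prompt = template.format(
    context='...retrieved documentation chunks...',
    question='How do I load a dataset?',
)
print(prompt)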
qa_engine/config.py CHANGED

@@ -11,7 +11,7 @@ def get_env(env_name: str, default: Any = None, warn: bool = True) -> str:
     if default is not None:
         if warn:
             logger.warning(
-                f'Environment variable {env_name} not found.' \
+                f'Environment variable {env_name} not found. ' \
                 f'Using the default value: {default}.'
             )
         return default
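The one-character fix matters because the two f-strings are joined by implicit string concatenation; without the trailing space the warning reads "...not found.Using the default value...". A small illustration of the difference:

env_name, default = 'SOME_VAR', None

# Adjacent (f-)string literals are concatenated with no separator added.
before = f'Environment variable {env_name} not found.' \
         f'Using the default value: {default}.'
after = f'Environment variable {env_name} not found. ' \
        f'Using the default value: {default}.'

print(before)  # Environment variable SOME_VAR not found.Using the default value: None.
print(after)   # Environment variable SOME_VAR not found. Using the default value: None.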
qa_engine/logger.py CHANGED

@@ -2,6 +2,8 @@ import logging
 
 
 logger = logging.getLogger(__name__)
+logging.getLogger('discord').setLevel(logging.ERROR)
+logging.getLogger('discord.gateway').setLevel(logging.ERROR)
 
 def setup_logger() -> None:
     """
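Raising the 'discord' and 'discord.gateway' loggers to ERROR keeps the Discord library's routine INFO/WARNING output out of the bot's logs while still letting real errors through, since child loggers inherit the level set on their parent. The effect, roughly:

import logging

logging.basicConfig(level=logging.INFO)
logging.getLogger('discord').setLevel(logging.ERROR)
logging.getLogger('discord.gateway').setLevel(logging.ERROR)

logging.getLogger('qa_engine').info('still visible')        # application logs pass
logging.getLogger('discord.gateway').info('suppressed')     # library chatter filtered
logging.getLogger('discord').error('still visible')         # real errors still shown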
qa_engine/qa_engine.py CHANGED

@@ -1,19 +1,11 @@
-import os
 import re
-import json
-import requests
-import subprocess
 from typing import Mapping, Optional, Any
 
 import torch
 import transformers
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from huggingface_hub import snapshot_download
-from urllib.parse import quote
-from langchain import PromptTemplate, HuggingFaceHub, LLMChain
-from langchain.llms import HuggingFacePipeline
-from langchain.llms.base import LLM
-from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings, HuggingFaceInstructEmbeddings
+from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.vectorstores import FAISS
 from sentence_transformers import CrossEncoder
 
@@ -22,41 +14,7 @@ from qa_engine.response import Response
 from qa_engine.mocks import MockLocalBinaryModel
 
 
-class LocalBinaryModel(LLM):
-    model_id: str = None
-    model_path: str = None
-    llm: None = None
-
-    def __init__(self, config: Config):
-        super().__init__()
-        # pip install llama_cpp_python==0.1.39
-        from llama_cpp import Llama
-
-        self.model_id = config.question_answering_model_id
-        self.model_path = f'qa_engine/{self.model_id}'
-        if not os.path.exists(self.model_path):
-            raise ValueError(f'{self.model_path} does not exist')
-        self.llm = Llama(model_path=self.model_path, n_ctx=4096)
-
-    def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
-        output = self.llm(
-            prompt,
-            max_tokens=1024,
-            stop=['Q:'],
-            echo=False
-        )
-        return output['choices'][0]['text']
-
-    @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        return {'name_of_model': self.model_id}
-
-    @property
-    def _llm_type(self) -> str:
-        return self.model_id
-
-
-class TransformersPipelineModel(LLM):
+class HuggingFaceModel:
     model_id: str = None
     min_new_tokens: int = None
     max_new_tokens: int = None
@@ -64,7 +22,8 @@ class TransformersPipelineModel(LLM):
     top_k: int = None
     top_p: float = None
     do_sample: bool = None
+    tokenizer: transformers.PreTrainedTokenizer = None
+    model: transformers.PreTrainedModel = None
 
     def __init__(self, config: Config):
         super().__init__()
@@ -76,35 +35,32 @@
         self.top_p = config.top_p
         self.do_sample = config.do_sample
 
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-        model = AutoModelForCausalLM.from_pretrained(
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        self.model = AutoModelForCausalLM.from_pretrained(
             self.model_id,
-            torch_dtype=torch.
-            load_in_8bit=False,
-            device_map='auto',
-            resume_download=True,
+            torch_dtype=torch.float16,
+            device_map="auto"
         )
-        tokenizer
+
+    def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
+        tokenized_prompt = self.tokenizer(
+            self.tokenizer.bos_token + prompt,
+            return_tensors="pt"
+        ).to(self.model.device)
+        terminators = [
+            self.tokenizer.eos_token_id,
+            self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
+        ]
+        outputs = self.model.generate(
+            input_ids=tokenized_prompt.input_ids,
+            attention_mask=tokenized_prompt.attention_mask,
             min_new_tokens=self.min_new_tokens,
             max_new_tokens=self.max_new_tokens,
-            top_k=self.top_k,
-            top_p=self.top_p,
-            do_sample=self.do_sample,
+            eos_token_id=terminators
         )
-        output_text = output_text.replace(prompt+'\n', '')
-        return output_text
+        response = outputs[0][tokenized_prompt.input_ids.shape[-1]:]
+        decoded_response = self.tokenizer.decode(response, skip_special_tokens=True)
+        return decoded_response
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
@@ -115,39 +71,6 @@ class TransformersPipelineModel(LLM):
         return self.model_id
 
 
-class APIServedModel(LLM):
-    model_url: str = None
-    debug: bool = None
-
-    def __init__(self, model_url: str, debug: bool = False):
-        super().__init__()
-        if model_url[-1] == '/':
-            raise ValueError('URL should not end with a slash - "/"')
-        self.model_url = model_url
-        self.debug = debug
-
-    def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
-        prompt_encoded = quote(prompt, safe='')
-        url = f'{self.model_url}/?prompt={prompt_encoded}'
-        if self.debug:
-            logger.info(f'URL: {url}')
-        try:
-            response = requests.get(url, timeout=1200, verify=False)
-            response.raise_for_status()
-            return json.loads(response.content)['output_text']
-        except Exception as err:
-            logger.error(f'Error: {err}')
-            return f'Error: {err}'
-
-    @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        return {'name_of_model': f'model url: {self.model_url}'}
-
-    @property
-    def _llm_type(self) -> str:
-        return 'api_model'
-
-
 class QAEngine():
     """
     QAEngine class, used for generating answers to questions.
@@ -163,16 +86,10 @@ class QAEngine():
         self.num_relevant_docs=config.num_relevant_docs
         self.add_sources_to_response=config.add_sources_to_response
         self.use_messages_for_context=config.use_messages_in_context
-        self.debug=config.debug
-
+        self.debug=config.debug
         self.first_stage_docs: int = 50
 
-        prompt = PromptTemplate(
-            template=self.prompt_template,
-            input_variables=['question', 'context']
-        )
         self.llm_model = self._get_model()
-        self.llm_chain = LLMChain(prompt=prompt, llm=self.llm_model)
 
         if self.use_docs_for_context:
             logger.info(f'Downloading {self.index_repo_id}')
@@ -196,29 +113,39 @@ class QAEngine():
 
 
     def _get_model(self):
-        if
-            logger.
-            return LocalBinaryModel(self.config)
-        elif 'api_models/' in self.question_answering_model_id:
-            logger.info('using api served model')
-            return APIServedModel(
-                model_url=self.question_answering_model_id.replace('api_models/', ''),
-                debug=self.debug
-            )
-        elif self.question_answering_model_id == 'mock':
-            logger.info('using mock model')
+        if self.question_answering_model_id == 'mock':
+            logger.warn('using mock model')
             return MockLocalBinaryModel()
         else:
             logger.info('using transformers pipeline model')
-            return TransformersPipelineModel(self.config)
+            return HuggingFaceModel(self.config)
+
     @staticmethod
-    def
+    def _preprocess_input(question: str, context: str) -> str:
        if '?' not in question:
            question += '?'
+
+        # llama3 chatQA specific
+        messages = [
+            {"role": "user", "content": question}
+        ]
+
+        system = "System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context."
+        instruction = "Please give a full and complete answer for the question."
+
+        for item in messages:
+            if item['role'] == "user":
+                ## only apply this instruction for the first user turn
+                item['content'] = instruction + " " + item['content']
+                break
+
+        conversation = '\n\n'.join([
+            "User: " + item["content"] if item["role"] == "user" else
+            "Assistant: " + item["content"] for item in messages
+        ]) + "\n\nAssistant:"
 
+        inputs = system + "\n\n" + context + "\n\n" + conversation
+        return inputs
 
     @staticmethod
     def _postprocess_answer(answer: str) -> str:
@@ -289,8 +216,8 @@ class QAEngine():
         response.set_sources(sources=[str(m['source']) for m in metadata])
 
         logger.info('Running LLM chain')
-        answer = self.
+        inputs = QAEngine._preprocess_input(question, context)
+        answer = self.llm_model._call(inputs)
         answer_postprocessed = QAEngine._postprocess_answer(answer)
         response.set_answer(answer_postprocessed)
         logger.info('Received answer')
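Taken together, the qa_engine.py changes drop the llama.cpp and HTTP-served backends along with the LangChain PromptTemplate/LLMChain plumbing, and route everything through the new HuggingFaceModel: QAEngine._preprocess_input builds a Llama-3 ChatQA style prompt (System block, retrieved context, then a User/Assistant turn), and HuggingFaceModel._call tokenizes it, generates with <|eot_id|> as an extra terminator, and decodes only the newly generated tokens. The sketch below mirrors that flow outside the class; the model id, context, and generation length are illustrative stand-ins for what the bot reads from Config and the FAISS index:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = 'nvidia/Llama3-ChatQA-1.5-8B'  # illustrative; the bot takes this from Config
context = 'Retrieved documentation chunks would go here.'
question = 'How do I load a dataset?'

# Prompt assembly, mirroring QAEngine._preprocess_input.
system = ("System: This is a chat between a user and an artificial intelligence assistant. "
          "The assistant gives helpful, detailed, and polite answers to the user's questions "
          "based on the context. The assistant should also indicate when the answer cannot be "
          "found in the context.")
instruction = 'Please give a full and complete answer for the question.'
conversation = f'User: {instruction} {question}\n\nAssistant:'
prompt = f'{system}\n\n{context}\n\n{conversation}'

# Generation, mirroring HuggingFaceModel._call.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map='auto')
inputs = tokenizer(tokenizer.bos_token + prompt, return_tensors='pt').to(model.device)
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids('<|eot_id|>')]
outputs = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=256,                   # the real code uses min/max_new_tokens from Config
    eos_token_id=terminators,
)
answer = tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
print(answer)

Passing <|eot_id|> alongside the regular EOS token matters for Llama-3-family chat models, which otherwise tend to keep generating past the end of their turn.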
requirements.txt CHANGED

@@ -1,5 +1,4 @@
 torch
-torchvision
 transformers
 accelerate
 einops