import os

import torch
import transformers

# Expose only the first GPU to this process (takes effect as long as no
# CUDA call has been made yet).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
class ChatService:
    def __init__(self):
        self.tokenizer = None
        self.pipeline = None

    def load_model(self, model_name=""):
        print(f"Loading {model_name}...")
        gpu_count = torch.cuda.device_count()
        print("gpu_count", gpu_count)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        self.pipeline = transformers.pipeline(
            task="text-generation",
            model=model_name,
            torch_dtype=torch.float16,  # half precision roughly halves GPU memory use
            device_map="auto",          # place model layers on the available device(s)
        )
    def generate_message(self, req):
        history = req["chat"]
        assistant_name = req["assistant_name"] + ": "
        # Fall back to a default when a field is absent or explicitly None.
        system_message = req.get("system_message") if req.get("system_message") is not None else ""
        temperature = req.get("temperature") if req.get("temperature") is not None else 1
        top_p = req.get("top_p") if req.get("top_p") is not None else 1
        top_k = req.get("top_k") if req.get("top_k") is not None else 10
        max_length = req.get("max_length") if req.get("max_length") is not None else 1000

        # Build a Llama-2-style prompt: system message inside <<SYS>> tags,
        # then the chat history, then the assistant name just before the
        # closing [/INST] tag.
        ending_tag = "[/INST]"
        fulltext = (
            "[INST] <<SYS>>" + system_message + "<</SYS>>"
            + "\n\n".join(history)
            + "\n\n" + assistant_name + ending_tag
        )
        sequences = self.pipeline(
            fulltext,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            max_length=max_length,  # counts prompt tokens plus generated tokens
        )
        # The pipeline echoes the prompt, so keep only the text after [/INST],
        # and drop a leading "AssistantName: " if the model repeated it.
        response = sequences[0]["generated_text"].split(ending_tag)[1].split(assistant_name)
        response = response[1] if len(response) > 1 else response[0]
        return response.strip()
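

# Usage sketch: a minimal, hypothetical driver for the service above. The
# model name and the request fields below are assumptions for illustration
# (any chat model that follows the Llama-2 [INST]/<<SYS>> prompt format
# should fit); they are not part of the original code.
if __name__ == "__main__":
    service = ChatService()
    service.load_model("meta-llama/Llama-2-7b-chat-hf")  # assumed example model
    reply = service.generate_message({
        "chat": ["User: What is the capital of France?"],
        "assistant_name": "Assistant",
        "system_message": "You are a concise, helpful assistant.",
        "temperature": 0.7,
        "max_length": 256,
    })
    print(reply)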