File size: 1,902 Bytes
997bbfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b50d476
997bbfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import torch
import transformers

# Restrict this process to GPU 0. This overwrites any CUDA_VISIBLE_DEVICES
# value already present in the environment.
os.environ.update(CUDA_VISIBLE_DEVICES="0")


class ChatService:
    """Chat front-end over a Hugging Face text-generation pipeline.

    ``load_model`` stores the loaded tokenizer and pipeline in the module-level
    globals ``tokenizer`` and ``pipeline``; ``generate_message`` reads them, so
    ``load_model`` must be called first.
    """

    # Tag that closes the prompt; the model's reply is generated after it.
    _ENDING_TAG = "[/INST]"

    def __init__(self):
        pass

    @staticmethod
    def load_model(model_name=""):
        """Load *model_name* into the module-level ``tokenizer`` and ``pipeline``.

        Args:
            model_name: Model id or path passed to
                ``transformers.AutoTokenizer.from_pretrained`` and
                ``transformers.pipeline``.
        """
        global tokenizer, pipeline

        print("Loading " + model_name + "...")

        # config
        gpu_count = torch.cuda.device_count()
        print('gpu_count', gpu_count)

        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        # Hand the tokenizer we just loaded to the pipeline so it does not
        # load a second copy from the same checkpoint.
        pipeline = transformers.pipeline(
            task="text-generation",
            model=model_name,
            tokenizer=tokenizer,
            torch_dtype=torch.float16,
            device_map="auto",
        )

    @staticmethod
    def _param(req, key, default):
        """Return ``req.get(key)``, or *default* when it is missing or None.

        Falsy values such as ``0`` are kept as-is.
        """
        value = req.get(key)
        return default if value is None else value

    @staticmethod
    def _build_prompt(history, system_message, assistant_name):
        """Build the prompt text sent to the model.

        Layout: ``[INST] <<SYS>>`` + system message + ``<</SYS>>``, then the
        history entries joined by blank lines, then a blank line, the
        assistant's ``"<name>: "`` cue and the closing ``[/INST]`` tag.
        """
        return (
            "[INST] <<SYS>>" + system_message + "<</SYS>>"
            + "\n\n".join(history)
            + "\n\n" + assistant_name + ChatService._ENDING_TAG
        )

    @staticmethod
    def _extract_reply(completion, assistant_name):
        """Extract the assistant's reply from the text generated after the prompt.

        Args:
            completion: Newly generated text only (prompt excluded).
            assistant_name: The ``"<name>: "`` prefix used in the prompt.

        Returns:
            The reply with surrounding whitespace stripped.
        """
        # Cut at a "[/INST]" the model emits on its own.
        reply = completion.split(ChatService._ENDING_TAG)[0]
        # If the model repeated the "<name>: " cue, keep only the segment
        # after its first occurrence.
        parts = reply.split(assistant_name)
        reply = parts[1] if len(parts) > 1 else parts[0]
        return reply.strip()

    @staticmethod
    def generate_message(req):
        """Generate the assistant's next message for a chat request.

        Args:
            req: Mapping with keys
                ``"chat"`` -- iterable of str turns, joined with blank lines;
                ``"assistant_name"`` -- str, the assistant's display name;
                optional ``"system_message"`` (default ``""``),
                ``"temperature"`` (1), ``"top_p"`` (1), ``"top_k"`` (10),
                ``"max_length"`` (1000). A value of None means "use default".

        Returns:
            The generated reply as a stripped string.

        Raises:
            RuntimeError: If ``load_model`` has not been called.
            KeyError: If ``"chat"`` or ``"assistant_name"`` is missing.
        """
        if "pipeline" not in globals():
            raise RuntimeError(
                "ChatService.load_model() must be called before generate_message()"
            )

        history = req["chat"]
        assistant_name = req["assistant_name"] + ": "
        system_message = ChatService._param(req, "system_message", "")
        temperature = ChatService._param(req, "temperature", 1)
        top_p = ChatService._param(req, "top_p", 1)
        top_k = ChatService._param(req, "top_k", 10)
        max_length = ChatService._param(req, "max_length", 1000)

        fulltext = ChatService._build_prompt(history, system_message, assistant_name)

        sequences = pipeline(
            fulltext,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            # NOTE: max_length counts prompt tokens as well as generated ones.
            max_length=max_length,
            # Return only the newly generated text. Parsing the full text by
            # splitting on "[/INST]" broke whenever the user-supplied history
            # or system message contained that tag, returning prompt content
            # (e.g. the system message) as the reply.
            return_full_text=False,
        )

        return ChatService._extract_reply(sequences[0]["generated_text"], assistant_name)