import json
import subprocess
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

import config as conf

device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)
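
# `config.py` is external to this file; a minimal sketch of the shape this
# module assumes (the model id below is a hypothetical placeholder):
#
#     # config.py
#     LLM = {
#         'MODEL': 'mesolitica/mallam-1.1b-20k-instructions',  # placeholder id
#     }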

class Mallam:
    # Check LLM health. The except branch only fires if the body shells out
    # to ping the LLM host (see the sketch below).
    @classmethod
    async def health_check_llm(cls):
        try:
            print("HarvestBuddy-API is running!")
        except subprocess.CalledProcessError as e:
            print(f"Connection to LLM server failed. Error:\n{e.output}")


    @staticmethod
    async def load_model_token():
        # Measure time for loading model and tokenizer
        start_time = time.time()
        print(conf.LLM['MODEL'])

        tokenizer = AutoTokenizer.from_pretrained(conf.LLM['MODEL'])
        model = AutoModelForCausalLM.from_pretrained(
            conf.LLM['MODEL'],
        )
        model = model.to(device)

        end_time = time.time()
        print(f"Model and tokenizer loaded in {end_time - start_time:.2f} seconds.")

        return tokenizer, model
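
    # Optional variant (not what this file does): larger checkpoints are
    # commonly loaded with reduced precision and automatic placement, e.g.
    # AutoModelForCausalLM.from_pretrained(conf.LLM['MODEL'],
    # torch_dtype=torch.float16, device_map='auto').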

    
    @staticmethod
    async def parse_chat(messages, function_call=None):
        # Measure time for prompt parsing
        start_time = time.time()
        user_query = messages[-1]['content']

        users, assistants = [], []
        for q in messages[:-1]:
            if q['role'] == 'user':
                users.append(q['content'])
            elif q['role'] == 'assistant':
                assistants.append(q['content'])

        texts = ['<s>']
        
        if function_call:
            fs = []
            for f in function_call:
                f = json.dumps(f, indent=4)
                fs.append(f)
            fs = '\n\n'.join(fs)
            texts.append(f'\n[FUNCTIONCALL]\n{fs}\n')
            
        for u, a in zip(users, assistants):
            texts.append(f'[INST] {u.strip()} [/INST] {a.strip()}</s>')

        texts.append(f'[INST] {user_query.strip()} [/INST]')
        prompt = ''.join(texts).strip()

        end_time = time.time()
        print(f"Prompt parsed in {end_time - start_time:.2f} seconds.")

        return prompt
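
    # Example of the prompt parse_chat builds (hypothetical conversation):
    #
    #   parse_chat([
    #       {'role': 'user', 'content': 'Hi'},
    #       {'role': 'assistant', 'content': 'Hello!'},
    #       {'role': 'user', 'content': 'How do I plant rice?'},
    #   ])
    #   -> "<s>[INST] Hi [/INST] Hello!</s>[INST] How do I plant rice? [/INST]"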


    @staticmethod
    async def tokenize(tokenizer, prompt):
        # Measure time for tokenization
        start_time = time.time()

        inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)

        end_time = time.time()
        print(f"Tokenization completed in {end_time - start_time:.2f} seconds.")

        return inputs


    @staticmethod
    async def gpt(inputs, model):
        # Measure time for model generation
        start_time = time.time()
        
        generate_kwargs = dict(
            inputs,
            max_new_tokens=256,      # cap on generated tokens (short but informative answers)
            top_p=0.85,              # nucleus sampling: keep only the most probable mass
            top_k=50,                # keep reasonable diversity
            temperature=0.65,        # lower randomness for factual answers
            do_sample=True,          # sample instead of picking greedily
            num_beams=3,             # beam sampling: explore several sequences, keep the best
            repetition_penalty=1.0,  # 1.0 is neutral; raise above 1.0 to discourage repeats
            # stream=True
        )
        r = model.generate(**generate_kwargs)

        end_time = time.time()
        print(f"Model generation completed in {end_time - start_time:.2f} seconds.")

        return r


    @staticmethod
    async def decode(tokenizer, r):
        # Measure time for decoding
        start_time = time.time()

        # Note: the decoded string contains the prompt plus the generated
        # continuation, including special tokens.
        res = tokenizer.decode(r[0])

        end_time = time.time()
        print(f"Decoding completed in {end_time - start_time:.2f} seconds.")

        return res
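

# Minimal end-to-end usage sketch (not part of the original module; the
# sample question is hypothetical):
if __name__ == "__main__":
    import asyncio

    async def main():
        await Mallam.health_check_llm()
        tokenizer, model = await Mallam.load_model_token()
        messages = [{'role': 'user', 'content': 'When should I water chili plants?'}]
        prompt = await Mallam.parse_chat(messages)
        inputs = await Mallam.tokenize(tokenizer, prompt)
        output_ids = await Mallam.gpt(inputs, model)
        print(await Mallam.decode(tokenizer, output_ids))

    asyncio.run(main())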