harvestbuddy-api / model.py
import json
import subprocess
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

import config as conf

# Run on the first GPU when available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)
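# Assumed shape of config.py (not shown in this view): the only key this module reads is
# conf.LLM['MODEL'], so something like the following is expected; the value below is a
# placeholder, not the actual model id used by the project.
#   LLM = {"MODEL": "<huggingface-model-id>"}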
class Mallam:
    # Check LLM health; the except clause appears to be left over from an earlier
    # subprocess-based ping (see the commented-out lines below).
    @classmethod
    async def health_check_llm(cls):
        try:
            print("HarvestBuddy-API is running!")
        except subprocess.CalledProcessError as e:
            # raise Exception(f"Failed to ping LLM server. Error:\n{e.output}")
            # print(f"Failed to ping {hostname}. Error:\n{e.output}")
            print(f"Connection to LLM server failed. Error:\n{e.output}")
    @staticmethod
    async def load_model_token():
        # Measure time for loading the model and tokenizer
        start_time = time.time()
        print(conf.LLM['MODEL'])
        tokenizer = AutoTokenizer.from_pretrained(conf.LLM['MODEL'])
        model = AutoModelForCausalLM.from_pretrained(
            conf.LLM['MODEL'],
        )
        model = model.to(device)
        end_time = time.time()
        print(f"Model and tokenizer loaded in {end_time - start_time:.2f} seconds.")
        return tokenizer, model
    @staticmethod
    async def parse_chat(messages, function_call=None):
        # Measure time for prompt parsing
        start_time = time.time()
        user_query = messages[-1]['content']
        # Split the conversation history into user turns and assistant turns.
        users, assistants = [], []
        for q in messages[:-1]:
            if q['role'] == 'user':
                users.append(q['content'])
            elif q['role'] == 'assistant':
                assistants.append(q['content'])
        texts = ['<s>']
        # Optional function-calling block: each function schema is serialised as JSON.
        if function_call:
            fs = []
            for f in function_call:
                f = json.dumps(f, indent=4)
                fs.append(f)
            fs = '\n\n'.join(fs)
            texts.append(f'\n[FUNCTIONCALL]\n{fs}\n')
        # Interleave past turns, then append the latest user query as an open [INST] block.
        for u, a in zip(users, assistants):
            texts.append(f'[INST] {u.strip()} [/INST] {a.strip()}</s>')
        texts.append(f'[INST] {user_query.strip()} [/INST]')
        prompt = ''.join(texts).strip()
        end_time = time.time()
        print(f"Prompt parsed in {end_time - start_time:.2f} seconds.")
        return prompt
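    # Illustrative example of the assembled prompt (not from the original file), assuming one
    # prior user/assistant exchange and no function_call:
    #   <s>[INST] earlier question [/INST] earlier answer</s>[INST] latest question [/INST]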
    @staticmethod
    async def tokenize(tokenizer, prompt):
        # Measure time for tokenization; add_special_tokens=False because <s> is already
        # included in the prompt built by parse_chat.
        start_time = time.time()
        inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)
        end_time = time.time()
        print(f"Tokenization completed in {end_time - start_time:.2f} seconds.")
        return inputs
    @staticmethod
    async def gpt(inputs, model):
        # Measure time for model generation
        start_time = time.time()
        generate_kwargs = dict(
            inputs,
            max_new_tokens=256,      # ~120 words (short but informative)
            top_p=0.85,              # Prioritise the most probable tokens
            top_k=50,                # Keep reasonable diversity
            temperature=0.65,        # Less randomness for factual answers
            do_sample=True,          # Slight variety in responses
            num_beams=3,             # Explore multiple candidate sequences, keep the best overall
            repetition_penalty=1.0,  # 1.0 means no penalty; raise above 1.0 to reduce repeated phrases
            # stream=True
        )
        r = model.generate(**generate_kwargs)
        end_time = time.time()
        print(f"Model generation completed in {end_time - start_time:.2f} seconds.")
        return r
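    # Hedged sketch (not in the original file): `stream=True` is not a `generate()` kwarg in
    # transformers; the usual route is a TextIteratorStreamer driven from a background thread.
    # Everything below is illustrative only.
    @staticmethod
    def gpt_stream(inputs, model, tokenizer):
        from threading import Thread
        from transformers import TextIteratorStreamer

        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(inputs, max_new_tokens=256, do_sample=True,
                               temperature=0.65, streamer=streamer)
        Thread(target=model.generate, kwargs=generate_kwargs).start()
        for chunk in streamer:
            yield chunk  # text fragments arrive incrementally as tokens are generated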
    @staticmethod
    async def decode(tokenizer, r):
        # Measure time for decoding the generated token ids back into text
        start_time = time.time()
        res = tokenizer.decode(r[0])
        end_time = time.time()
        print(f"Decoding completed in {end_time - start_time:.2f} seconds.")
        return res
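# Hedged usage sketch (not part of the original file): how the methods above chain into one
# inference pass. The example message is illustrative only.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        tokenizer, model = await Mallam.load_model_token()
        messages = [{"role": "user", "content": "Bila musim menuai padi bermula?"}]
        prompt = await Mallam.parse_chat(messages)
        inputs = await Mallam.tokenize(tokenizer, prompt)
        output_ids = await Mallam.gpt(inputs, model)
        print(await Mallam.decode(tokenizer, output_ids))

    asyncio.run(_demo())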