import gradio as gr
import os, sys
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, pipeline
from transformers import LlamaTokenizer
import torch
import spaces
import psutil
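
# Note (assumption, not from the original code): `spaces` is Hugging Face's
# ZeroGPU helper; it is imported above but not used below, whereas on ZeroGPU
# hardware the generation call would typically be decorated with @spaces.GPU.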

# Define the model repository
REPO_NAME = 'schuler/experimental-JP47D56C'

# How to cache?
def load_model(repo_name):
    # tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
    tokenizer = LlamaTokenizer.from_pretrained(repo_name, trust_remote_code=True)
    generator_conf = GenerationConfig.from_pretrained(repo_name)
    model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True, torch_dtype=torch.bfloat16, attn_implementation="eager")
    # model.to('cuda')
    return tokenizer, generator_conf, model
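
# A minimal sketch answering the "How to cache?" note above (an assumption,
# not part of the original Space): memoize load_model with the standard-library
# functools.lru_cache so repeated calls reuse the already loaded objects.
# The module-level load_model(REPO_NAME) call below already runs only once per
# process, so this wrapper is optional and left unused.
import functools

@functools.lru_cache(maxsize=1)
def cached_load_model(repo_name):
    # Delegates to load_model; the (tokenizer, config, model) tuple is cached
    # per repo_name for the lifetime of the process.
    return load_model(repo_name)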

# tokenizer, generator_conf, model, generator = False, False, False, False
# with gr.Blocks() as main_block:
tokenizer, generator_conf, model = load_model(REPO_NAME)
global_error = ''
try:
    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    global_error = f"Failed to load model: {str(e)}"

def local_generate(
    prompt,
    generation_config,
    max_new_tokens,
    do_sample=True,
    top_p=0.25,
    repetition_penalty=1.2,
    temperature=1.0
):
    response_output = generator(
        prompt,
        generation_config=generation_config,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature
    )
    generated_text = response_output[0]['generated_text']
    # Extract the assistant's response
    result = generated_text[len(prompt):]
    return result
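
# Illustrative usage of local_generate (an assumption, kept commented out so it
# does not run at import time): one short completion in the prompt format that
# respond() builds below.
# example = local_generate(
#     "<|user|>What is biology?<|end|><|assistant|>",
#     generation_config=generator_conf,
#     max_new_tokens=16,
# )
# print(example)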

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    result = 'none'
    try:
        # Build the conversation prompt
        prompt = ''
        messages = []
        if len(system_message) > 0:
            prompt = "<|assistant|>" + system_message + "<|end|>\n"
        for val in history:
            if val[0]:
                messages.append({"role": "user", "content": val[0]})
            if val[1]:
                messages.append({"role": "assistant", "content": val[1]})
        messages.append({"role": "user", "content": message})
        for hmessage in messages:
            role = "<|assistant|>" if hmessage['role'] == 'assistant' else "<|user|>"
            prompt += f"{role}{hmessage['content']}<|end|>"
        prompt += "<|assistant|>"
        tokens_cnt = 0
        tokens_inc = 64
        last_token_len = 1
        full_result = ''
        # Generate in chunks of tokens_inc tokens, streaming the partial text
        # after each chunk; stop once max_tokens is reached or the model
        # returns an empty continuation.
        while (tokens_cnt < max_tokens) and (last_token_len > 0):
            # Generate the response
            result = local_generate(
                prompt,
                generation_config=generator_conf,
                max_new_tokens=tokens_inc,
                do_sample=True,
                top_p=top_p,
                repetition_penalty=1.2,
                temperature=temperature
            )
            full_result = full_result + result
            prompt = prompt + result
            tokens_cnt = tokens_cnt + tokens_inc
            last_token_len = len(result)
            yield full_result
    except Exception as error:
        # Report the error (with type, file and line) back to the chat window
        # instead of crashing the interface.
        exc_type, exc_obj, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        result = f"{error}:{exc_type}:{fname}:{exc_tb.tb_lineno}"
        yield result
""" | |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
""" | |

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
embed_params = sum(p.numel() for p in model.model.embed_tokens.parameters()) * 2
non_embed_params = (trainable_params - embed_params) / 1e6
cpu_usage = psutil.cpu_percent(interval=1)
status_text = \
    f"This chat uses the {REPO_NAME} model with {model.get_memory_footprint() / 1e6:.2f} MB memory footprint. " + \
    f"Current CPU usage is {cpu_usage:.2f}%. " + \
    f"Total number of non-embedding trainable parameters: {non_embed_params:.2f} million. " + \
    f"You may ask questions such as 'What is biology?' or 'What is the human body?'"
# """ | |
demo = gr.ChatInterface( | |
respond, | |
additional_inputs=[ | |
gr.Textbox(value="" + global_error, label="System message"), | |
gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max new tokens"), | |
gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"), | |
gr.Slider( | |
minimum=0.1, | |
maximum=1.0, | |
value=0.25, | |
step=0.05, | |
label="Top-p (nucleus sampling)", | |
), | |
], | |
description=status_text | |
) | |
""" | |
with gr.Blocks() as demo: | |
# Display the status text at the top | |
gr.Markdown(status_text) | |
# Create the ChatInterface | |
chat = gr.ChatInterface( | |
respond, | |
additional_inputs=[ | |
gr.Textbox(value="" + global_error, label="System message"), | |
gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max new tokens"), | |
gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"), | |
gr.Slider( | |
minimum=0.1, | |
maximum=1.0, | |
value=0.25, | |
step=0.05, | |
label="Top-p (nucleus sampling)", | |
), | |
], | |
) | |
""" | |

if __name__ == "__main__":
    demo.launch()