# Arsh-llm-demo / app.py
import gradio as gr
import torch
import transformers
import huggingface_hub
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from gradio_client import utils as client_utils
# Patch gradio_client.utils.get_type: some gradio_client releases crash with
# "argument of type 'bool' is not iterable" when a JSON schema is a bare bool.
def patched_get_type(schema):
    if isinstance(schema, bool):
        return "any" if schema else "never"
    if "const" in schema:
        return repr(schema["const"])
    if "type" in schema:
        return schema["type"]
    return "any"

client_utils.get_type = patched_get_type
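# This monkey-patch is a workaround, not part of the app's logic; once the
# Space pins a gradio/gradio_client release without the schema bug, it can go.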
# Print package versions for debugging
print(f"Gradio version: {gr.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"Torch version: {torch.__version__}")
print(f"Huggingface_hub version: {huggingface_hub.__version__}")
# Load model and tokenizer; force_download re-fetches the files on every start.
# (resume_download is deprecated in recent huggingface_hub and dropped here.)
model_name = "arshiaafshani/Arsh-llm"
tokenizer = AutoTokenizer.from_pretrained(model_name, force_download=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    force_download=True,
)
# Configure the model's special tokens
tokenizer.bos_token = "<sos>"
tokenizer.eos_token = "<|endoftext|>"
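# Assumption: "<sos>" and "<|endoftext|>" already exist in this model's vocab.
# If they did not, they would need tokenizer.add_special_tokens(...) plus
# model.resize_token_embeddings(len(tokenizer)); the check below only warns.
if tokenizer.convert_tokens_to_ids(tokenizer.eos_token) in (None, tokenizer.unk_token_id):
    print("Warning: eos token may not be in the tokenizer vocabulary")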
# Create a text-generation pipeline on GPU when available, otherwise CPU
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)
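# Note: device=0 moves the fp16 weights to cuda:0; on CPU the model stays in
# float32, so large max_new_tokens settings will generate slowly.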
def respond(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
    chat_history = chat_history or []
    # Rebuild the conversation: each history entry is a (user, assistant) pair,
    # so both sides must be replayed, not just the user turns.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})
    # add_generation_prompt appends the assistant header to the prompt,
    # replacing the empty-assistant-message hack.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    output = pipe(
        prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repeat_penalty,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # The pipeline returns prompt + completion; strip the prompt prefix.
    response = output[0]["generated_text"][len(prompt):].strip()
    chat_history.append((message, response))
    return chat_history
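# Example round trip (hypothetical values, assuming the model has loaded):
#   history = respond("Hello!", [], "You are Arsh...", 64, 0.7, 0.95, 40, 1.1)
#   history[-1]  ->  ("Hello!", "<model reply>")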
with gr.Blocks() as demo:
    gr.Markdown("# Arsh-LLM Demo")
    with gr.Row():
        with gr.Column():
            system_msg = gr.Textbox(
                "You are Arsh, a helpful assistant by Arshia Afshani. You should answer the user carefully.",
                label="System Message",
            )
            max_tokens = gr.Slider(1, 4096, value=2048, step=1, label="Max Tokens")
            temperature = gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
            top_k = gr.Slider(0, 100, value=40, step=1, label="Top-k")
            repeat_penalty = gr.Slider(0.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("Clear")

    def submit_message(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
        # respond() returns the updated history; clear the input box alongside it
        updated_history = respond(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty)
        return "", updated_history

    msg.submit(
        submit_message,
        [msg, chatbot, system_msg, max_tokens, temperature, top_p, top_k, repeat_penalty],
        [msg, chatbot],
    )
    clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
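# To run locally: `python app.py`, then open http://localhost:7860.
# On a Hugging Face Space the same port is exposed automatically.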