# app.py: Gradio chat demo for the "my-deep-world" Hugging Face Space
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch # Needed for model operations, especially on GPU
import os
# --- Model Loading ---
# Define the model ID
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
tokenizer = None
model = None
# Use device_map="auto" to automatically handle placing the model on GPU/CPU
# Use torch_dtype=torch.bfloat16 or torch.float16 for reduced memory usage on compatible GPUs
try:
    print(f"Loading tokenizer for {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    print("Tokenizer loaded.")

    print(f"Loading model {model_id}...")
    # Adjust torch_dtype based on your GPU capability and memory (float16 or bfloat16 are common for speed/memory)
    # If no GPU is available, remove device_map="auto" and the torch_dtype argument, or set device_map="cpu"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",  # Automatically select device (GPU or CPU)
        torch_dtype=torch.bfloat16,  # Use bfloat16 for better performance/memory on compatible GPUs
        # If you have less VRAM, try torch.float16, or remove this line for float32 (uses more VRAM)
    )
    print("Model loaded successfully!")
    # Optional: warn if the tokenizer does not define a chat template (DeepSeek/Qwen tokenizers ship one)
    if getattr(tokenizer, "chat_template", None) is None:
        print(f"Warning: Tokenizer for {model_id} does not define a chat template. The model may not be optimized for chat.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    tokenizer = None  # Ensure both are None if loading fails
    model = None
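
# Optional alternative (not used above): on a low-VRAM GPU the model could instead be
# loaded in 4-bit. This is a minimal sketch, assuming the `bitsandbytes` package is
# installed in the Space; the helper name is illustrative and the function is never called.
def load_model_4bit(model_id: str):
    from transformers import BitsAndBytesConfig  # imported lazily; requires bitsandbytes
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # quantize weights to 4-bit at load time
        bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for the 4-bit matmuls
    )
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=quant_config,
    )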
# --- Inference Function for Gradio ---
def chat_with_model(user_input_string):
    if model is None or tokenizer is None:
        # Return an error message if model loading failed
        return "Model or tokenizer failed to load. Please check the Space logs."

    # --- 1. Format the input into the chat structure ---
    # For a single-turn chat from user input, the messages list is simple
    messages = [
        {"role": "user", "content": user_input_string},
        # Add previous turns here for multi-turn chat (see the commented example below)
    ]
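
    # For reference, a multi-turn conversation would interleave user/assistant turns,
    # for example (illustrative content only):
    # messages = [
    #     {"role": "user", "content": "Hello!"},
    #     {"role": "assistant", "content": "Hi! How can I help?"},
    #     {"role": "user", "content": user_input_string},
    # ]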
    # --- 2. Apply the chat template ---
    # The tokenizer converts the messages list into a single string formatted
    # according to the model's specific chat requirements (e.g., adding <|im_start|>user tokens)
    # add_generation_prompt=True tells the model it should generate the assistant's response next
    try:
        chat_input_string = tokenizer.apply_chat_template(
            messages,
            tokenize=False,  # Return a string, not token IDs yet
            add_generation_prompt=True
        )
        print(f"Formatted chat input: {chat_input_string[:200]}...")  # Log the formatted input
    except Exception as e:
        print(f"Error applying chat template: {e}")
        return f"Error formatting input: {e}"
    # --- 3. Tokenize the formatted input ---
    try:
        input_ids = tokenizer(chat_input_string, return_tensors="pt").input_ids
        # Move input tensors to the same device as the model (e.g., GPU)
        if model.device.type != 'cpu':
            input_ids = input_ids.to(model.device)
        print(f"Input token IDs shape: {input_ids.shape}")
    except Exception as e:
        print(f"Error tokenizing input: {e}")
        return f"Error tokenizing input: {e}"
    # --- 4. Generate response ---
    try:
        print("Starting text generation...")
        # Use model.generate() for text generation
        # max_new_tokens limits the length of the generated response
        # Add other generation parameters (temperature, top_p, etc.) for more control
        with torch.no_grad():  # Inference doesn't need gradient calculation, saves memory
            outputs = model.generate(
                input_ids,
                max_new_tokens=512,  # Limit the response length
                temperature=0.7,     # Control creativity (adjust as needed)
                do_sample=True,      # Enable sampling (recommended for chat)
                top_p=0.95,          # Top-p (nucleus) sampling
                # Add other parameters like num_return_sequences if you want multiple responses
            )
        print("Text generation complete.")

        # --- 5. Decode the output ---
        # The generated output contains the original input tokens + the new tokens generated by the model.
        # Decode only the new tokens that the model generated.
        generated_tokens = outputs[0, input_ids.shape[-1]:]
        assistant_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Clean up potential leading/trailing whitespace
        assistant_response = assistant_response.strip()
        print(f"Generated response: {assistant_response[:200]}...")  # Log the generated response
        return assistant_response
    except Exception as e:
        print(f"Error during text generation: {e}")
        return f"An error occurred during generation: {e}"
# --- Gradio Interface Definition ---
# Only create the interface if the model and tokenizer loaded successfully
if model is not None and tokenizer is not None:
    print("Creating Gradio interface...")
    interface = gr.Interface(
        fn=chat_with_model,
        inputs=gr.Textbox(label="Digite sua mensagem (Chat em Português do Brasil)", lines=5),
        outputs=gr.Textbox(label="Resposta do Modelo", lines=10),
        title="DeepSeek-R1-Distill-Qwen-7B Chat PT-BR Demo",
        description="Converse com o modelo DeepSeek-R1-Distill-Qwen-7B, versão destilada.",
        allow_flagging="never"  # Disable flagging for a simple demo
    )
    print("Gradio interface created.")
else:
    # Create a simple interface indicating an error if model loading failed
    print("Model/Tokenizer failed to load, creating error interface.")
    interface = gr.Interface(
        fn=lambda x: "O modelo ou tokenizer falhou ao carregar. Verifique os logs do Space para mais detalhes.",
        inputs=gr.Textbox(label="Status da Aplicação"),
        outputs=gr.Textbox(),
        title="Erro na Aplicação",
        description="Falha ao carregar o modelo Transformers. Consulte os logs para diagnóstico."
    )
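
# Alternative (not used here): for a multi-turn chat UI, recent Gradio versions
# provide gr.ChatInterface, which passes (message, history) to the callback.
# A minimal sketch, assuming such a Gradio version is installed:
#
# def chat_fn(message, history):
#     return chat_with_model(message)  # history is ignored in this single-turn sketch
#
# interface = gr.ChatInterface(fn=chat_fn, title="DeepSeek-R1-Distill-Qwen-7B Chat PT-BR Demo")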
# --- Launch the Gradio App ---
# This part is necessary for the Space to run your Gradio app
if __name__ == "__main__":
    print("Launching Gradio interface...")
    # Hugging Face Spaces sets server_name and server_port automatically
    interface.launch()
    print("Gradio launch initiated.")