import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch # Needed for model operations, especially on GPU
import os
# --- Model Loading ---
# Define the model ID
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
tokenizer = None
model = None
# Use device_map="auto" to automatically handle placing the model on GPU/CPU
# Use torch_dtype=torch.bfloat16 or torch.float16 for reduced memory usage on compatible GPUs
try:
    print(f"Loading tokenizer for {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    print("Tokenizer loaded.")

    print(f"Loading model {model_id}...")
    # Adjust torch_dtype based on your GPU capability and memory (float16 or bfloat16 are common for speed/memory)
    # If no GPU is available, remove device_map="auto" and the torch_dtype argument, or set device_map="cpu"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",          # Automatically select device (GPU or CPU)
        torch_dtype=torch.bfloat16  # Use bfloat16 for better performance/memory on compatible GPUs
        # If you have less VRAM, try torch.float16, or remove this line for float32 (uses more VRAM)
    )
    print("Model loaded successfully!")

    # Optional: check that the tokenizer actually ships a chat template (DeepSeek/Qwen tokenizers should)
    if getattr(tokenizer, "chat_template", None) is None:
        print(f"Warning: Tokenizer for {model_id} does not have a chat template. Model might not be optimized for chat.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    tokenizer = None  # Ensure both are None if loading fails
    model = None
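# Note (an assumption, not part of the original app): if the Space runs out of GPU
# memory with a 7B model, 4-bit quantization via bitsandbytes is one way to shrink
# the footprint. Rough sketch, assuming the bitsandbytes package is installed:
#   from transformers import BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(load_in_4bit=True)
#   model = AutoModelForCausalLM.from_pretrained(
#       model_id, device_map="auto", quantization_config=bnb_config
#   )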
# --- Inference Function for Gradio ---
def chat_with_model(user_input_string):
    if model is None or tokenizer is None:
        # Return an error message if model loading failed
        return "Model or tokenizer failed to load. Please check the App Space logs."

    # --- 1. Format the input into the chat structure ---
    # For a single-turn chat from user input, the messages list is simple
    messages = [
        {"role": "user", "content": user_input_string},
        # Add previous turns here for multi-turn chat (see the commented sketch below)
    ]
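    # Hypothetical multi-turn shape (not wired up in this single-turn demo); this is
    # an assumption about how earlier turns would be prepended to `messages`:
    #   messages = [
    #       {"role": "user", "content": "Hello!"},
    #       {"role": "assistant", "content": "Hi! How can I help?"},
    #       {"role": "user", "content": user_input_string},
    #   ]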
    # --- 2. Apply the chat template ---
    # The tokenizer converts the messages list into a single string formatted
    # according to the model's specific chat requirements (e.g., adding <|im_start|>user tokens)
    # add_generation_prompt=True tells the model it should generate the assistant's response next
    try:
        chat_input_string = tokenizer.apply_chat_template(
            messages,
            tokenize=False,  # Return a string, not token IDs yet
            add_generation_prompt=True
        )
        print(f"Formatted chat input: {chat_input_string[:200]}...")  # Log the formatted input
    except Exception as e:
        print(f"Error applying chat template: {e}")
        return f"Error formatting input: {e}"
    # --- 3. Tokenize the formatted input ---
    try:
        input_ids = tokenizer(chat_input_string, return_tensors="pt").input_ids
        # Move input tensors to the same device as the model (e.g., GPU)
        if model.device.type != 'cpu':
            input_ids = input_ids.to(model.device)
        print(f"Input token IDs shape: {input_ids.shape}")
    except Exception as e:
        print(f"Error tokenizing input: {e}")
        return f"Error tokenizing input: {e}"
    # --- 4. Generate response ---
    try:
        print("Starting text generation...")
        # Use model.generate() for text generation
        # max_new_tokens limits the length of the generated response
        # Add other generation parameters (temperature, top_p, etc.) for more control
        with torch.no_grad():  # Inference doesn't need gradient calculation, saves memory
            outputs = model.generate(
                input_ids,
                max_new_tokens=512,  # Limit the response length
                temperature=0.7,     # Control creativity (adjust as needed)
                do_sample=True,      # Enable sampling (recommended for chat)
                top_p=0.95,          # Top-p sampling
                # Add other parameters like num_return_sequences if you want multiple responses
            )
        print("Text generation complete.")
        # --- 5. Decode the output ---
        # The generated output contains the original input tokens + the new tokens generated by the model.
        # Decode only the new tokens that the model generated.
        generated_tokens = outputs[0, input_ids.shape[-1]:]
        assistant_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Clean up potential leading/trailing whitespace
        assistant_response = assistant_response.strip()
        print(f"Generated response: {assistant_response[:200]}...")  # Log the generated response

        return assistant_response
    except Exception as e:
        print(f"Error during text generation: {e}")
        return f"An error occurred during generation: {e}"
# --- Gradio Interface Definition ---
# Only create the interface if the model and tokenizer loaded successfully
if model is not None and tokenizer is not None:
print("Creating Gradio interface...")
interface = gr.Interface(
fn=chat_with_model,
inputs=gr.Textbox(label="Digite sua mensagem (Chat em Português do Brasil)", lines=5),
outputs=gr.Textbox(label="Resposta do Modelo", lines=10),
title="DeepSeek-R1-Distill-Qwen-7B Chat PT-BR Demo",
description="Converse com o modelo DeepSeek-R1-Distill-Qwen-7B, versão destilada.",
allow_flagging="never" # Disable flagging for a simple demo
)
print("Gradio interface created.")
else:
    # Create a simple interface indicating an error if model loading failed
    print("Model/Tokenizer failed to load, creating error interface.")
    interface = gr.Interface(
        fn=lambda x: "O modelo ou tokenizer falhou ao carregar. Verifique os logs do App Space para mais detalhes.",
        inputs=gr.Textbox(label="Status da Aplicação"),
        outputs=gr.Textbox(),
        title="Erro na Aplicação",
        description="Falha ao carregar o modelo Transformers. Consulte os logs para diagnóstico."
    )
# --- Launch the Gradio App ---
# This part is necessary for the App Space to run your Gradio app
if __name__ == "__main__":
    print("Launching Gradio interface...")
    # App Spaces automatically set server_name and server_port
    interface.launch()
    print("Gradio launch initiated.")