import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import spaces

# Model configuration
MODEL_PATH = "ibm-granite/granite-4.0-h-small"

# Load tokenizer (doesn't need GPU)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Load model and move to GPU
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
model.to('cuda')
model.eval()


@spaces.GPU(duration=60)
def generate_response(message, history):
    """Stream a response from the IBM Granite model on ZeroGPU."""
    # Rebuild the conversation from Gradio's (user, assistant) tuple history
    chat = []
    for user_msg, assistant_msg in history:
        chat.append({"role": "user", "content": user_msg})
        if assistant_msg:
            chat.append({"role": "assistant", "content": assistant_msg})

    # Add the current message
    chat.append({"role": "user", "content": message})

    # Apply the model's chat template
    formatted_chat = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize the formatted prompt
    input_tokens = tokenizer(
        formatted_chat,
        return_tensors="pt",
        truncation=True,
        max_length=2048
    ).to('cuda')

    # Streamer yields decoded text as tokens are generated
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )

    # Generation kwargs
    generation_kwargs = dict(
        **input_tokens,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer
    )

    # Run generation in a background thread so we can stream from this one
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Stream the accumulated response back to the UI
    response = ""
    for new_text in streamer:
        response += new_text
        yield response
    thread.join()


# Create the Gradio interface
with gr.Blocks(title="IBM Granite Chat", theme=gr.themes.Soft()) as demo:
    gr.HTML(
        """
        <div style="text-align: center;">
            <h1>Chat with the IBM Granite 4.0-h Small model, powered by ZeroGPU</h1>
            <p>This application uses the IBM Granite 4.0-h Small model to generate responses.</p>
            <p><em>Responses are generated by AI and should be verified for accuracy.</em></p>
        </div>
        """
    )

    # Chat UI wired to the streaming generator above; history arrives as
    # (user, assistant) tuples, matching how generate_response unpacks it.
    gr.ChatInterface(fn=generate_response)

demo.launch()
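
# ---------------------------------------------------------------------------
# Deployment note (a sketch of assumed dependencies, not stated in the
# original source): on Hugging Face ZeroGPU Spaces the `spaces` package is
# preinstalled, and `low_cpu_mem_usage=True` requires `accelerate`. A matching
# requirements.txt might look like:
#
#     torch
#     transformers
#     accelerate
#     gradio
# ---------------------------------------------------------------------------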