import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import spaces

# Model configuration
MODEL_PATH = "ibm-granite/granite-4.0-h-small"

# Load tokenizer (doesn't need GPU)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Load model and move to GPU (ZeroGPU attaches the device on demand)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
model.to('cuda')
model.eval()


@spaces.GPU(duration=60)
def generate_response(message, history):
    """Generate a streaming response from the IBM Granite model on ZeroGPU."""
    # Build the chat as a list of role/content messages from the history
    chat = []
    for user_msg, assistant_msg in history:
        chat.append({"role": "user", "content": user_msg})
        if assistant_msg:
            chat.append({"role": "assistant", "content": assistant_msg})

    # Add the current message
    chat.append({"role": "user", "content": message})

    # Apply the chat template
    formatted_chat = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize the text
    input_tokens = tokenizer(
        formatted_chat,
        return_tensors="pt",
        truncation=True,
        max_length=2048
    ).to('cuda')

    # Set up streaming generation
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )

    # Generation kwargs
    generation_kwargs = dict(
        **input_tokens,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer
    )

    # Run generation in a separate thread so the streamer can be consumed here
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Stream the response as it is generated
    response = ""
    for new_text in streamer:
        response += new_text
        yield response

    thread.join()


# Create the Gradio interface
with gr.Blocks(title="IBM Granite Chat", theme=gr.themes.Soft()) as demo:
    # Header
    gr.HTML(
        """
        <div style="text-align: center; margin-bottom: 20px;">

            <h1>🪨 IBM Granite 4.0 Chat</h1>
            <p>Chat with IBM Granite 4.0-h Small model powered by ZeroGPU</p>
            <p>Built with anycoder</p>

""" ) chatbot = gr.Chatbot( height=500, bubble_full_width=False, show_copy_button=True, layout="panel" ) with gr.Row(): msg = gr.Textbox( label="Your Message", placeholder="Type your message here and press Enter...", lines=2, scale=9, autofocus=True ) submit_btn = gr.Button("Send", variant="primary", scale=1) with gr.Row(): clear_btn = gr.ClearButton([msg, chatbot], value="🗑️ Clear Chat") with gr.Accordion("Advanced Settings", open=False): gr.Markdown(""" ### Model Information - **Model**: IBM Granite 4.0-h Small - **Parameters**: Optimized for efficient inference - **Powered by**: Hugging Face ZeroGPU ### Tips for Better Responses: - Be specific and clear in your questions - Provide context when needed - The model excels at various tasks including coding, analysis, and general conversation """) # Example prompts gr.Examples( examples=[ "Explain quantum computing in simple terms", "Write a Python function to calculate factorial", "What are the main differences between machine learning and deep learning?", "Help me debug this code: def add(a, b) return a + b", "Create a healthy meal plan for a week", "Explain the concept of blockchain technology", ], inputs=msg, label="Example Prompts" ) # Event handlers def user_submit(message, history): if not message.strip(): return "", history return "", history + [[message, None]] def bot_response(history): if not history or history[-1][1] is not None: yield history return user_message = history[-1][0] history[-1][1] = "" for partial_response in generate_response(user_message, history[:-1]): history[-1][1] = partial_response yield history # Connect events msg.submit(user_submit, [msg, chatbot], [msg, chatbot], queue=False).then( bot_response, chatbot, chatbot ) submit_btn.click(user_submit, [msg, chatbot], [msg, chatbot], queue=False).then( bot_response, chatbot, chatbot ) # Add footer gr.HTML( """

            <p>This application uses the IBM Granite 4.0-h Small model.<br>
            Responses are AI-generated and should be verified for accuracy.</p>

""" ) # Launch the application if __name__ == "__main__": demo.queue() demo.launch( show_api=False, share=False )