import os
import time
import gc
import threading
from itertools import islice
from datetime import datetime

import gradio as gr
import torch
from transformers import pipeline, TextIteratorStreamer
from duckduckgo_search import DDGS
import spaces

# Set by the Cancel button to tell the active streaming loop to stop.
cancel_event = threading.Event()

# Selectable models: display name -> Hugging Face repo id and short description.
MODELS = {
    "Qwen3-8B": {"repo_id": "Qwen/Qwen3-8B", "description": "Qwen3-8B"},
    "Qwen3-4B": {"repo_id": "Qwen/Qwen3-4B", "description": "Qwen3-4B"},
    "Qwen3-1.7B": {"repo_id": "Qwen/Qwen3-1.7B", "description": "Qwen3-1.7B"},
    "Qwen3-0.6B": {"repo_id": "Qwen/Qwen3-0.6B", "description": "Qwen3-0.6B"},
    "Gemma-3-4B-IT": {"repo_id": "unsloth/gemma-3-4b-it", "description": "Gemma-3-4B-IT"},
    "SmolLM2-135M-Instruct-TaiwanChat": {"repo_id": "Luigi/SmolLM2-135M-Instruct-TaiwanChat", "description": "SmolLM2-135M Instruct fine-tuned on TaiwanChat"},
    "SmolLM2-135M-Instruct": {"repo_id": "HuggingFaceTB/SmolLM2-135M-Instruct", "description": "Original SmolLM2-135M Instruct"},
    "SmolLM2-360M-Instruct-TaiwanChat": {"repo_id": "Luigi/SmolLM2-360M-Instruct-TaiwanChat", "description": "SmolLM2-360M Instruct fine-tuned on TaiwanChat"},
    "Llama-3.2-Taiwan-3B-Instruct": {"repo_id": "lianghsun/Llama-3.2-Taiwan-3B-Instruct", "description": "Llama-3.2-Taiwan-3B-Instruct"},
    "MiniCPM3-4B": {"repo_id": "openbmb/MiniCPM3-4B", "description": "MiniCPM3-4B"},
    "Qwen2.5-3B-Instruct": {"repo_id": "Qwen/Qwen2.5-3B-Instruct", "description": "Qwen2.5-3B-Instruct"},
    "Qwen2.5-7B-Instruct": {"repo_id": "Qwen/Qwen2.5-7B-Instruct", "description": "Qwen2.5-7B-Instruct"},
    "Phi-4-mini-Instruct": {"repo_id": "unsloth/Phi-4-mini-instruct", "description": "Phi-4-mini-Instruct"},
    "Meta-Llama-3.1-8B-Instruct": {"repo_id": "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct", "description": "Meta-Llama-3.1-8B-Instruct"},
    "DeepSeek-R1-Distill-Llama-8B": {"repo_id": "unsloth/DeepSeek-R1-Distill-Llama-8B", "description": "DeepSeek-R1-Distill-Llama-8B"},
    "Mistral-7B-Instruct-v0.3": {"repo_id": "MaziyarPanahi/Mistral-7B-Instruct-v0.3", "description": "Mistral-7B-Instruct-v0.3"},
    "Qwen2.5-Coder-7B-Instruct": {"repo_id": "Qwen/Qwen2.5-Coder-7B-Instruct", "description": "Qwen2.5-Coder-7B-Instruct"},
}

# Cache of loaded text-generation pipelines, keyed by model display name.
PIPELINES = {}

def load_pipeline(model_name):
    """
    Load and cache a transformers text-generation pipeline for the selected model.
    Tries bfloat16 first, then falls back to float16 and finally float32.
    """
    global PIPELINES
    if model_name in PIPELINES:
        return PIPELINES[model_name]
    repo = MODELS[model_name]["repo_id"]
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(
                task="text-generation",
                model=repo,
                tokenizer=repo,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto"
            )
            PIPELINES[model_name] = pipe
            return pipe
        except Exception:
            continue
    # Last resort: let transformers pick the default dtype.
    pipe = pipeline(
        task="text-generation",
        model=repo,
        tokenizer=repo,
        trust_remote_code=True,
        device_map="auto"
    )
    PIPELINES[model_name] = pipe
    return pipe
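
# Illustrative usage (not executed here; assumes the chosen model fits in memory):
#   pipe = load_pipeline("Qwen3-0.6B")
#   out = pipe("Hello", max_new_tokens=16, return_full_text=False)
#   print(out[0]["generated_text"])
# A second call with the same model name returns the cached pipeline.
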
def retrieve_context(query, max_results=6, max_chars=600):
    """
    Run a DuckDuckGo text search (intended to run in a background thread).
    Returns a list of numbered "title - snippet" strings, or [] on any failure.
    """
    try:
        with DDGS() as ddgs:
            return [
                f"{i+1}. {r.get('title', 'No Title')} - {r.get('body', '')[:max_chars]}"
                for i, r in enumerate(
                    islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results)
                )
            ]
    except Exception:
        return []
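
# Illustrative call (requires network access; actual text varies):
#   retrieve_context("open-source LLM news", max_results=2, max_chars=80)
# might return something like:
#   ['1. Some Title - first 80 characters of the snippet...',
#    '2. Another Title - ...']
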
def format_conversation(history, system_prompt):
    """
    Flatten the chat history and system prompt into a single prompt string.
    """
    prompt = system_prompt.strip() + "\n"
    for msg in history:
        if msg['role'] == 'user':
            prompt += "User: " + msg['content'].strip() + "\n"
        elif msg['role'] == 'assistant':
            prompt += "Assistant: " + msg['content'].strip() + "\n"
        else:
            prompt += msg['content'].strip() + "\n"
    if not prompt.strip().endswith("Assistant:"):
        prompt += "Assistant: "
    return prompt
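
# Example of the flattened prompt this produces:
#   format_conversation(
#       [{'role': 'user', 'content': 'Hi'},
#        {'role': 'assistant', 'content': 'Hello!'},
#        {'role': 'user', 'content': 'How are you?'}],
#       "You are a helpful assistant.")
# returns:
#   "You are a helpful assistant.\nUser: Hi\nAssistant: Hello!\nUser: How are you?\nAssistant: "
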
@spaces.GPU(duration=60)
def chat_response(user_msg, chat_history, system_prompt,
                  enable_search, max_results, max_chars,
                  model_name, max_tokens, temperature,
                  top_k, top_p, repeat_penalty):
    """
    Generate a streaming chat response, optionally enriched with background web search.
    Yields (history, debug_markdown) pairs so Gradio can update the UI incrementally.
    """
    cancel_event.clear()
    history = list(chat_history or [])
    history.append({'role': 'user', 'content': user_msg})

    debug = ''
    search_results = []
    if enable_search:
        debug = 'Search task started.'
        # Run the DuckDuckGo lookup in a daemon thread so it does not block generation.
        thread_search = threading.Thread(
            target=lambda: search_results.extend(
                retrieve_context(user_msg, int(max_results), int(max_chars))
            )
        )
        thread_search.daemon = True
        thread_search.start()
    else:
        debug = 'Web search disabled.'

    # Placeholder assistant message, filled in as tokens stream back.
    history.append({'role': 'assistant', 'content': ''})

    try:
        # Give the search thread a short grace period, then report what it found.
        if enable_search:
            thread_search.join(timeout=1.0)
            if search_results:
                debug = "### Search results merged into prompt\n\n" + "\n".join(
                    f"- {r}" for r in search_results
                )
            else:
                debug = "*No web search results found.*"

        # Enrich the system prompt with any retrieved context.
        if search_results:
            enriched = system_prompt.strip() + "\n\nRelevant context:\n" + "\n".join(search_results)
        else:
            enriched = system_prompt

        prompt = format_conversation(history, enriched)

        # Run generation in a worker thread; the streamer feeds tokens back to this generator.
        pipe = load_pipeline(model_name)
        streamer = TextIteratorStreamer(pipe.tokenizer,
                                        skip_prompt=True,
                                        skip_special_tokens=True)
        gen_thread = threading.Thread(
            target=pipe,
            args=(prompt,),
            kwargs={
                'max_new_tokens': max_tokens,
                'temperature': temperature,
                'top_k': top_k,
                'top_p': top_p,
                'repetition_penalty': repeat_penalty,
                'streamer': streamer,
                'return_full_text': False
            }
        )
        gen_thread.start()

        # Stream partial output, updating the last (assistant) message on each chunk.
        assistant_text = ''
        for chunk in streamer:
            if cancel_event.is_set():
                break
            assistant_text += chunk
            history[-1]['content'] = assistant_text
            yield history, debug
        gen_thread.join()
    except Exception as e:
        history[-1]['content'] = f"Error: {e}"
        yield history, debug
    finally:
        gc.collect()
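
# chat_response is a generator: Gradio re-renders the chatbot on every yield.
# Rough direct use outside the UI (assumes @spaces.GPU degrades to a no-op locally):
#   for history, debug in chat_response("Hi", [], "You are helpful.",
#                                       False, 6, 600, "Qwen3-0.6B",
#                                       64, 0.7, 40, 0.9, 1.1):
#       pass
#   # history[-1]['content'] now holds the final assistant reply.
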
def cancel_generation():
    cancel_event.set()
    return 'Generation cancelled.'
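
# Note: cancellation only stops the streaming loop in chat_response from yielding
# further chunks; the worker thread keeps generating until the model finishes.
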
def update_default_prompt(enable_search):
    today = datetime.now().strftime('%Y-%m-%d')
    return f"You are a helpful assistant. Today is {today}."
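
# The enable_search argument is unused here, but the signature must accept the
# checkbox value because this function is wired to search_chk.change below.
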
with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
    gr.Markdown("## 🧠 ZeroGPU LLM Inference with Web Search")
    gr.Markdown("Interact with the model. Select parameters and chat below.")
    with gr.Row():
        with gr.Column(scale=3):
            model_dd = gr.Dropdown(label="Select Model", choices=list(MODELS.keys()), value=list(MODELS.keys())[0])
            search_chk = gr.Checkbox(label="Enable Web Search", value=True)
            sys_prompt = gr.Textbox(label="System Prompt", lines=3, value=update_default_prompt(search_chk.value))
            gr.Markdown("### Generation Parameters")
            max_tok = gr.Slider(64, 1024, value=512, step=32, label="Max Tokens")
            temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            k = gr.Slider(1, 100, value=40, step=1, label="Top-K")
            p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
            rp = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
            gr.Markdown("### Web Search Settings")
            mr = gr.Number(value=6, precision=0, label="Max Results")
            mc = gr.Number(value=600, precision=0, label="Max Chars/Result")
            clr = gr.Button("Clear Chat")
            cnl = gr.Button("Cancel Generation")
        with gr.Column(scale=7):
            chat = gr.Chatbot(type="messages")
            txt = gr.Textbox(placeholder="Type your message and press Enter...")
            dbg = gr.Markdown()
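
    # Event wiring: the checkbox refreshes the default system prompt, the buttons
    # clear the chat or signal cancellation, and submitting the textbox streams
    # chat_response updates into the chatbot and debug panel.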
    search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt)
    clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
    cnl.click(fn=cancel_generation, outputs=dbg)
    txt.submit(fn=chat_response,
               inputs=[txt, chat, sys_prompt, search_chk, mr, mc,
                       model_dd, max_tok, temp, k, p, rp],
               outputs=[chat, dbg])

demo.launch()