import os import pathlib import streamlit as st import torch from transformers import AutoTokenizer, AutoModelForCausalLM from unsloth import FastLanguageModel, is_bfloat16_supported import importlib import random from datasets import load_dataset device = torch.device("cuda" if torch.cuda.is_available() else "cpu") st.title("🧠 Math LLM Demo") st.text(f"Using device: {device}") # === MODEL SELECTION === MODEL_OPTIONS = { "Vanilla GPT-2": "openai-community/gpt2", "GPT2-Small-CPT-CL-IFT": "jonathantiedchen/GPT2-Small-CPT-CL-IFT", "Mistral 7B+CPT+CL+IFT": "jonathantiedchen/MistralMath-CPT-IFT" } @st.cache_resource def load_models(): models = {} for name, path in MODEL_OPTIONS.items(): if "mistral" in name.lower(): try: model, tokenizer = FastLanguageModel.from_pretrained( model_name=path, max_seq_length=2048, dtype=torch.bfloat16 if is_bfloat16_supported() else torch.float16, load_in_4bit=True ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token FastLanguageModel.for_inference(model) except Exception as e: st.sidebar.error(f"⚠️ Failed to load Mistral model with Unsloth: {e}") continue else: tokenizer = AutoTokenizer.from_pretrained(path) model = AutoModelForCausalLM.from_pretrained(path).to(device) model.eval() models[name] = {"tokenizer": tokenizer, "model": model} return models st.sidebar.write("📥 Load Models.") models = load_models() st.sidebar.write(f"✅ Successfully loaded models:{models}") model_choice = st.selectbox("Choose a model:", list(MODEL_OPTIONS.keys())) tokenizer = models[model_choice]["tokenizer"] model = models[model_choice]["model"] # === LOAD DATA === @st.cache_resource def load_gsm8k_dataset(): return load_dataset("openai/gsm8k", "main")["test"] st.sidebar.write("📥 Load GSM8K") gsm8k_data = load_gsm8k_dataset() st.sidebar.write("📊 GSM8K loaded:", len(gsm8k_data), "samples") # === TABS === tab1, tab2 = st.tabs(["🔓 Manual Prompting", "📊 GSM8K Evaluation"]) # === MANUAL GENERATION TAB === with tab1: prompt = st.text_area("Enter your math prompt:", "Jasper has 5 apples and eats 2 of them. How many apples does he have left?") if st.button("Generate Response", key="manual"): with st.sidebar.spinner("🔄 Generating..."): inputs = tokenizer(prompt, return_tensors="pt").to(model.device) output = model.generate( **inputs, max_new_tokens=100, temperature=0.7, do_sample=True, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, ) generated_text = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True) response_only = generated_text[len(prompt):].strip() st.subheader("🔎 Prompt") st.code(prompt) st.subheader("🧠 Model Output") st.code(generated_text) st.subheader("✂️ Response Only") st.success(response_only) # === GSM8K TAB === with tab2: st.markdown("A random question from GSM8K will be shown. Click below to test the model.") if st.button("Run GSM8K Sample"): try: with st.sidebar.spinner("🔄 Generating..."): sample = random.choice(gsm8k_data) question = sample["question"] gold_answer = sample["answer"] inputs = tokenizer(question, return_tensors="pt").to(model.device) st.markdown(f"Create Output") output = model.generate( **inputs, max_new_tokens=150, temperature=0.7, do_sample=True, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, ) generated_text = tokenizer.decode(output[0], skip_special_tokens=True) response_only = generated_text[len(question):].strip() st.subheader("📌 GSM8K Question") st.markdown(question) st.subheader("🔍 Model Output") st.markdown(generated_text) st.subheader("✂️ Response Only") st.success(response_only) st.subheader("✅ Gold Answer") st.info(gold_answer) except Exception as e: st.error(f"Error: {e}")