import os
import random
import streamlit as st
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList
from unsloth import FastLanguageModel, is_bfloat16_supported
from utils import SpecificStringStoppingCriteria
from cot import EIGHT_SHOT_PROMPT, FOUR_SHOT_PROMPT
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
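# Strings that mark the end of an answer; generation is cut off once one of them is produced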
generation_util = [
    "Q:",
    "</s>",
    "<|im_end|>"
]
# GPT-2 and Mistral model registry
gpt_models = {
    "GPT-2 Small BL": "openai-community/gpt2",
    "GPT-2 Small CPT+CL+IFT": "jonathantiedchen/GPT2-Small-CPT-CL-IFT"
}
mistral_models = {
    "Mistral 7B BL": "unsloth/mistral-7b-bnb-4bit",
    "Mistral 7B CPT+CL": "jonathantiedchen/Mistral-7B-CPT-CL",
    "Mistral 7B CPT+IFT": "jonathantiedchen/MistralMath-CPT-IFT"
}
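# Merged registry: display name -> Hugging Face Hub repo id, used to populate the model selector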
all_models = gpt_models | mistral_models
### Load GSM8K once
@st.cache_resource
def load_gsm8k_dataset():
    return load_dataset("openai/gsm8k", "main")["test"]
### Load Mistral
@st.cache_resource
def load_mistral(mistral_path, _models):
    try:
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=mistral_path,
            max_seq_length=2048,
            dtype=torch.bfloat16 if is_bfloat16_supported() else torch.float16,
            load_in_4bit=True
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        FastLanguageModel.for_inference(model)
        _models[mistral_path] = {"tokenizer": tokenizer, "model": model}
    except Exception as e:
        st.sidebar.error(f"⚠️ Failed to load Mistral model with Unsloth: {e}")
    return _models
### Load GPT-2
@st.cache_resource
def load_gpts(path, _models):
    try:
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(path).to(device)
        model.eval()
        _models[path] = {"tokenizer": tokenizer, "model": model}
    except Exception as e:
        st.sidebar.error(f"⚠️ Failed to load GPT model: {e}")
    return _models
# Load models
st.title("🧠 Math LLM Demo")
models = {}
with st.sidebar:
    with st.spinner("📥 Loading all models. This might take a while."):
        for model_path in mistral_models.values():
            models = load_mistral(model_path, models)
        for model_path in gpt_models.values():
            models = load_gpts(model_path, models)
    st.write("✅ Successfully loaded all models.")
# Load GSM8K dataset and allow selection
st.sidebar.write("📥 Load GSM8K")
gsm8k_data = load_gsm8k_dataset()
st.sidebar.write("📊 GSM8K loaded:", len(gsm8k_data), "samples")
# Check for random question index in query params
random_index = st.query_params.get("question_index")
if random_index is not None:
    try:
        default_index = int(random_index)
    except (ValueError, TypeError):
        default_index = 0
else:
    default_index = 0
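# Pre-select the question index carried over from the URL query params (defaults to 0)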
question_index = st.selectbox("🔢 Select GSM8K question index", range(len(gsm8k_data)), index=default_index)
if st.button("🎲 Pick Random Question"):
    new_random_index = random.randint(0, len(gsm8k_data) - 1)
    st.query_params.update(question_index=new_random_index)
    st.rerun()  # Force app to rerun to update the selectbox
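# Fallback prompt used when no GSM8K question index is available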
default_prompt = "Jasper has 5 apples and eats 2 of them. How many apples does he have left?"
selected_question = gsm8k_data[question_index]["question"] if question_index is not None else default_prompt
correct_answer = gsm8k_data[question_index]["answer"]
# Prompt options
st.write('##')
use_cot = st.toggle("Use Few-Shot Prompt")
model_choice = st.selectbox("Choose a model:", list(all_models.keys()))
model_path = all_models[model_choice]
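# Look up the already-loaded tokenizer/model pair for the selected checkpoint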
tokenizer = models[model_path]["tokenizer"]
model = models[model_path]["model"]
# Prompt input
prompt = st.text_area("Enter your math prompt:", selected_question)
# Generation
if st.button("Generate Response", key="manual"):
# Check if the current prompt is from GSM8K dataset
is_gsm8k_question = prompt == selected_question
with st.sidebar:
with st.spinner("π Generating..."):
if use_cot:
if 'mistral' in model_choice.lower():
prompt_template = EIGHT_SHOT_PROMPT
else:
prompt_template = FOUR_SHOT_PROMPT
input_text = prompt_template.format(question=prompt)
else:
input_text = prompt
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
stop_criteria = SpecificStringStoppingCriteria(tokenizer, generation_util, len(input_text))
stopping_criteria_list = StoppingCriteriaList([stop_criteria])
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=512,
temperature=1,
pad_token_id=tokenizer.eos_token_id,
stopping_criteria=stopping_criteria_list
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
response_only = generated_text[len(input_text):].strip() if generated_text.startswith(input_text) else generated_text.strip()
with st.expander("π Prompt"):
st.subheader("π Prompt")
st.write(input_text)
st.subheader("π§ Model Output")
st.success(response_only)
# Only show correct answer if using actual GSM8K question
if is_gsm8k_question:
st.subheader("β
Correct Answer (GSM8K)")
st.info(correct_answer) |