jonathantiedchen committed on
Commit
562f5ef
·
verified ·
1 Parent(s): dc78b8b

Create app.py

Files changed (1)
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
+ import os
+ import pathlib
+ import streamlit as st
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from unsloth import FastLanguageModel, is_bfloat16_supported
+ import importlib
+ import random
+ from datasets import load_dataset
+
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ st.title("🧠 Math LLM Demo")
+ st.text(f"Using device: {device}")
+
+ # === MODEL SELECTION ===
+ MODEL_OPTIONS = {
+     "Vanilla GPT-2": "openai-community/gpt2",
+     "GPT2-Small-CPT-CL-IFT": "jonathantiedchen/GPT2-Small-CPT-CL-IFT",
+     "Mistral 7B+CPT+CL+IFT": "jonathantiedchen/MistralMath-CPT-IFT"
+ }
+
+ @st.cache_resource
+ def load_models():
+     models = {}
+     for name, path in MODEL_OPTIONS.items():
+         if "mistral" in name.lower():
+             try:
+                 model, tokenizer = FastLanguageModel.from_pretrained(
+                     model_name=path,
+                     max_seq_length=2048,
+                     dtype=torch.bfloat16 if is_bfloat16_supported() else torch.float16,
+                     load_in_4bit=True  # 4-bit quantization so the 7B model fits in GPU memory
+                 )
+
+                 if tokenizer.pad_token is None:
+                     tokenizer.pad_token = tokenizer.eos_token
+                 FastLanguageModel.for_inference(model)
+
+             except Exception as e:
+                 st.error(f"⚠️ Failed to load Mistral model with Unsloth: {e}")
+                 continue
+         else:
+             tokenizer = AutoTokenizer.from_pretrained(path)
+             model = AutoModelForCausalLM.from_pretrained(path).to(device)
+             model.eval()
+
+         models[name] = {"tokenizer": tokenizer, "model": model}
+     return models
+
+
+ models = load_models()
+
+ model_choice = st.selectbox("Choose a model:", list(models.keys()))  # only offer models that actually loaded
+ tokenizer = models[model_choice]["tokenizer"]
+ model = models[model_choice]["model"]
+
+ # === LOAD DATA ===
+ @st.cache_resource
+ def load_gsm8k_dataset():
+     return load_dataset("openai/gsm8k", "main")["test"]
+
+ gsm8k_data = load_gsm8k_dataset()
+ st.write("📊 GSM8K loaded:", len(gsm8k_data), "samples")
+
+ # === TABS ===
+ tab1, tab2 = st.tabs(["🔓 Manual Prompting", "📊 GSM8K Evaluation"])
+
+ # === MANUAL GENERATION TAB ===
+ with tab1:
+     prompt = st.text_area("Enter your math prompt:", "Jasper has 5 apples and eats 2 of them. How many apples does he have left?")
+     if st.button("Generate Response", key="manual"):
+         with st.spinner("Generating..."):
+             inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+             output = model.generate(
+                 **inputs,
+                 max_new_tokens=100,
+                 temperature=0.7,
+                 do_sample=True,
+                 pad_token_id=tokenizer.eos_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+             )
+             generated_text = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+
+             response_only = generated_text[len(prompt):].strip()  # drop the echoed prompt, keep only the continuation
+
+             st.subheader("🔎 Prompt")
+             st.code(prompt)
+             st.subheader("🧠 Model Output")
+             st.code(generated_text)
+             st.subheader("✂️ Response Only")
+             st.success(response_only)
+
+ # === GSM8K TAB ===
+ with tab2:
+     st.markdown("A random question from GSM8K will be shown. Click below to test the model.")
+
+     if st.button("Run GSM8K Sample"):
+         try:
+             sample = random.choice(gsm8k_data)
+             question = sample["question"]
+             gold_answer = sample["answer"]
+
+             inputs = tokenizer(question, return_tensors="pt").to(model.device)
+
+             st.markdown("Generating model output...")
+             output = model.generate(
+                 **inputs,
+                 max_new_tokens=150,
+                 temperature=0.7,
+                 do_sample=True,
+                 pad_token_id=tokenizer.eos_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+             )
+             generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+             response_only = generated_text[len(question):].strip()
+
+             st.subheader("📌 GSM8K Question")
+             st.markdown(question)
+
+             st.subheader("🔍 Model Output")
+             st.markdown(generated_text)
+
+             st.subheader("✂️ Response Only")
+             st.success(response_only)
+
+             st.subheader("✅ Gold Answer")
+             st.info(gold_answer)
+
+         except Exception as e:
+             st.error(f"Error: {e}")
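The GSM8K tab above displays the gold answer but leaves the correctness check to the reader. A small helper like the sketch below could score the sample automatically; it relies only on the GSM8K convention that gold answers end with "#### <number>", and the function names (extract_gold_answer, extract_predicted_answer, is_correct) are illustrative additions, not part of this commit.

import re

def extract_gold_answer(answer_text: str) -> str:
    # GSM8K gold answers end with "#### <number>"; return that number as a string.
    match = re.search(r"####\s*([-+]?[\d,]*\.?\d+)", answer_text)
    return match.group(1).replace(",", "") if match else ""

def extract_predicted_answer(response_text: str) -> str:
    # Treat the last number in the model's response as its final answer.
    numbers = re.findall(r"[-+]?[\d,]*\.?\d+", response_text)
    return numbers[-1].replace(",", "") if numbers else ""

def is_correct(response_text: str, gold_answer_text: str) -> bool:
    # Compare predicted and gold answers numerically, falling back to string equality.
    pred = extract_predicted_answer(response_text)
    gold = extract_gold_answer(gold_answer_text)
    if not pred or not gold:
        return False
    try:
        return float(pred) == float(gold)
    except ValueError:
        return pred == gold

Inside the GSM8K tab this could drive a simple verdict, e.g. st.success("Correct") if is_correct(response_only, gold_answer) else st.warning("Incorrect"), keeping the manual inspection of the full output intact.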