ceadibc committed
Commit d90bfe0 · verified · 1 Parent(s): 3ca5ab1

Update app.py

Files changed (1)
app.py +291 -61
app.py CHANGED
@@ -1,64 +1,294 @@
- import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
  )

- if __name__ == "__main__":
-     demo.launch()
 
+ import os
+ from dotenv import find_dotenv, load_dotenv
+ import streamlit as st
+ from groq import Groq
+ import base64
+
+ # Load environment variables
+ load_dotenv(find_dotenv())
+
+ # Function to encode the image to a base64 string
+ def encode_image(uploaded_file):
+     """
+     Encodes an uploaded image file into a base64 string.
+     Args:
+         uploaded_file: The file-like object uploaded via Streamlit.
+     Returns:
+         str: The base64 encoded string of the image.
+     """
+     return base64.b64encode(uploaded_file.read()).decode('utf-8')
+
+ # Initialize the Groq client using the API key from the environment variables
+ client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
+
+ # Set up Streamlit page configuration
+ st.set_page_config(
+     page_icon="📃",
+     layout="wide",
+     page_title="Groq & LLaMA3x Chat Bot"
  )

+ # App Title
+ st.title("Groq Chat with LLaMA3x")
+
+ # Cache the model fetching function to improve performance
+ @st.cache_data
+ def fetch_available_models():
+     """
+     Fetches the available models from the Groq API.
+     Returns a list of models or an empty list if there's an error.
+     """
+     try:
+         models_response = client.models.list()
+         return models_response.data
+     except Exception as e:
+         st.error(f"Error fetching models: {e}")
+         return []
+
+ # Load available models and filter them
+ available_models = fetch_available_models()
+ filtered_models = [
+     model for model in available_models if 'llama' in model.id
+ ]
+
+ # Prepare a dictionary of model metadata
+ models = {
+     model.id: {
+         "name": model.id,
+         "tokens": 4000,
+         "developer": model.owned_by,
+     }
+     for model in filtered_models
+ }
+
+ # Initialize session state variables
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ if "selected_model" not in st.session_state:
+     st.session_state.selected_model = None
+
+ # Sidebar: Controls
+ with st.sidebar:
+
+     # Powered by Groq logo
+     st.markdown(
+         """
+         <a href="https://groq.com" target="_blank" rel="noopener noreferrer">
+             <img
+                 src="https://groq.com/wp-content/uploads/2024/03/PBG-mark1-color.svg"
+                 alt="Powered by Groq for fast inference."
+                 width="100%"
+             />
+         </a>
+         """,
+         unsafe_allow_html=True
+     )
+     st.markdown("---")
+
+     # Define a function to clear messages when the model changes
+     def reset_chat_on_model_change():
+         st.session_state.messages = []
+         st.session_state.image_used = False
+         uploaded_file = None
+         base64_image = None
+
+     # Model selection dropdown
+     if models:
+         model_option = st.selectbox(
+             "Choose a model:",
+             options=list(models.keys()),
+             format_func=lambda x: f"{models[x]['name']} ({models[x]['developer']})",
+             on_change=reset_chat_on_model_change,  # Reset chat when model changes
+         )
+     else:
+         st.warning("No available models to select.")
+         model_option = None
+
+     # Token limit slider
+     if models:
+         max_tokens_range = models[model_option]["tokens"]
+         max_tokens = st.slider(
+             "Max Tokens:",
+             min_value=200,
+             max_value=max_tokens_range,
+             value=max(100, int(max_tokens_range * 0.5)),
+             step=256,
+             help=f"Adjust the maximum number of tokens for the response. Maximum for the selected model: {max_tokens_range}"
+         )
+     else:
+         max_tokens = 200
+
+     # Additional options
+     stream_mode = st.checkbox("Enable Streaming", value=True)
+
+     # Button to clear the chat
+     if st.button("Clear Chat"):
+         st.session_state.messages = []
+         st.session_state.image_used = False
+
+     # Initialize session state for tracking uploaded image usage
+     if "image_used" not in st.session_state:
+         st.session_state.image_used = False  # Flag to track image usage
+
+     # Check if the selected model supports vision
+     base64_image = None
+     uploaded_file = None
+     if model_option and "vision" in model_option.lower():
+         st.markdown(
+             "### Upload an Image"
+             "\n\n*One per conversation*"
+         )
+
+         # File uploader for images (only if image hasn't been used yet)
+         if not st.session_state.image_used:
+             uploaded_file = st.file_uploader(
+                 "Upload an image for the model to process:",
+                 type=["png", "jpg", "jpeg"],
+                 help="Upload an image if the model supports vision tasks.",
+                 accept_multiple_files=False
+             )
+             if uploaded_file:
+                 base64_image = encode_image(uploaded_file)
+                 st.image(uploaded_file, caption="Uploaded Image")
+             else:
+                 base64_image = None
+
+     st.markdown("### Usage Summary")
+     usage_box = st.empty()
+
+     # Disclaimer
+     st.markdown(
+         """
+         -----
+         ⚠️ **Important:**
+         *The responses provided by this application are generated automatically using an AI model.
+         Users are responsible for verifying the accuracy of the information before relying on it.
+         Always cross-check facts and data for critical decisions.*
+         """
+     )
+
+ # Main Chat Interface
+ st.markdown("### Chat Interface")
+
+ # Display the chat history
+ for message in st.session_state.messages:
+     avatar = "🔋" if message["role"] == "assistant" else "🧑‍💻"
+     with st.chat_message(message["role"], avatar=avatar):
+         # Check if the content is a list (text and image combined)
+         if isinstance(message["content"], list):
+             for item in message["content"]:
+                 if item["type"] == "text":
+                     st.markdown(item["text"])
+                 elif item["type"] == "image_url":
+                     # Handle base64-encoded image URLs
+                     if item["image_url"]["url"].startswith("data:image"):
+                         st.image(item["image_url"]["url"], caption="Uploaded Image")
+                         st.session_state.image_used = True
+                     else:
+                         st.warning("Invalid image format or unsupported URL.")
+         else:
+             # For plain text content
+             st.markdown(message["content"])
+
+ # Capture user input
+ if user_input := st.chat_input("Enter your message here..."):
+     # Append the user input to the session state,
+     # including the image if uploaded
+     if base64_image and not st.session_state.image_used:
+         # Append the user message with the image to session state
+         st.session_state.messages.append(
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "text", "text": user_input},
+                     {
+                         "type": "image_url",
+                         "image_url": {
+                             "url": f"data:image/jpeg;base64,{base64_image}",
+                         },
+                     },
+                 ],
+             }
+         )
+         st.session_state.image_used = True
+     else:
+         st.session_state.messages.append({"role": "user", "content": user_input})
+
+     # Display the uploaded image and user query in the chat
+     with st.chat_message("user", avatar="🧑‍💻"):
+         # Display the user input
+         st.markdown(user_input)
+
+         # Display the uploaded image only if it's included in the current message
+         if base64_image and st.session_state.image_used:
+             st.image(uploaded_file, caption="Uploaded Image")
+             base64_image = None
+
+     # Generate a response using the selected model
+     try:
+         full_response = ""
+         usage_summary = ""
+
+         if stream_mode:
+             # Generate a response with streaming enabled
+             chat_completion = client.chat.completions.create(
+                 model=model_option,
+                 messages=[
+                     {"role": m["role"], "content": m["content"]}
+                     for m in st.session_state.messages
+                 ],
+                 max_tokens=max_tokens,
+                 stream=True
+             )
+
+             with st.chat_message("assistant", avatar="🔋"):
+                 response_placeholder = st.empty()
+
+                 for chunk in chat_completion:
+                     if chunk.choices[0].delta.content:
+                         full_response += chunk.choices[0].delta.content
+                         response_placeholder.markdown(full_response)
+         else:
+             # Generate a response without streaming
+             chat_completion = client.chat.completions.create(
+                 model=model_option,
+                 messages=[
+                     {"role": m["role"], "content": m["content"]}
+                     for m in st.session_state.messages
+                 ],
+                 max_tokens=max_tokens,
+                 stream=False
+             )
+
+             response = chat_completion.choices[0].message.content
+             usage_data = chat_completion.usage
+
+             with st.chat_message("assistant", avatar="🔋"):
+                 st.markdown(response)
+                 full_response = response
+
+             if usage_data:
+                 usage_summary = (
+                     f"**Token Usage:**\n"
+                     f"- Prompt Tokens: {usage_data.prompt_tokens}\n"
+                     f"- Response Tokens: {usage_data.completion_tokens}\n"
+                     f"- Total Tokens: {usage_data.total_tokens}\n\n"
+                     f"**Timings:**\n"
+                     f"- Prompt Time: {round(usage_data.prompt_time, 5)} secs\n"
+                     f"- Response Time: {round(usage_data.completion_time, 5)} secs\n"
+                     f"- Total Time: {round(usage_data.total_time, 5)} secs"
+                 )
+
+         if usage_summary:
+             usage_box.markdown(usage_summary)
+
+         # Append the assistant's response to the session state
+         st.session_state.messages.append(
+             {"role": "assistant", "content": full_response}
+         )
+
+     except Exception as e:
+         st.error(f"Error generating the response: {e}")
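
Note: for anyone trying the updated app.py locally, the sketch below exercises the same Groq calls the new code makes (client.models.list() and client.chat.completions.create(...)) outside of Streamlit. It is a minimal sketch, not part of this commit: it assumes GROQ_API_KEY is set in the environment or a .env file (as the app expects), and "llama-3.1-8b-instant" is only an example fallback id, not a model named in the diff.

import os

from dotenv import find_dotenv, load_dotenv
from groq import Groq

# Load GROQ_API_KEY from a .env file if present, as the app does.
load_dotenv(find_dotenv())

client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# List the Llama models that the app's sidebar dropdown would offer.
llama_ids = [m.id for m in client.models.list().data if "llama" in m.id]
print("Available Llama models:", llama_ids)

# Send one non-streaming chat completion, mirroring the request shape used in app.py.
# The fallback model id below is hypothetical; prefer one of the ids printed above.
completion = client.chat.completions.create(
    model=llama_ids[0] if llama_ids else "llama-3.1-8b-instant",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=200,
)
print(completion.choices[0].message.content)
print("Total tokens:", completion.usage.total_tokens)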