Man-isH-07 committed
Commit 8c22e42 · 1 parent: 83557e0

Again At Normal

Files changed (2)
  1. app.py +54 -142
  2. style.css +0 -19
app.py CHANGED
@@ -24,10 +24,7 @@ if torch.cuda.is_available():
     model_id = "mistralai/Mistral-7B-Instruct-v0.3"
     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    # Set the pad token to avoid warnings
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-        model.config.pad_token_id = tokenizer.pad_token_id
+

 @spaces.GPU
 def generate(
@@ -41,31 +38,15 @@ def generate(
 ) -> Iterator[str]:
     conversation = [*chat_history, {"role": "user", "content": message}]

-    # Apply chat template
-    inputs = tokenizer.apply_chat_template(conversation, return_tensors="pt", padding=True, return_attention_mask=True)
-
-    # Check if inputs is a dictionary or a tensor
-    if isinstance(inputs, dict):
-        input_ids = inputs["input_ids"]
-        attention_mask = inputs.get("attention_mask", None)
-    else:
-        input_ids = inputs
-        attention_mask = (input_ids != tokenizer.pad_token_id).long() if tokenizer.pad_token_id is not None else None
-
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        if attention_mask is not None:
-            attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-
     input_ids = input_ids.to(model.device)
-    if attention_mask is not None:
-        attention_mask = attention_mask.to(model.device)

     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
+        {"input_ids": input_ids},
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
@@ -78,132 +59,63 @@ def generate(
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()

-    # First, yield the user's message (which contains the prompts)
-    yield message
-
-    # Then, yield the model's response
     outputs = []
     for text in streamer:
         outputs.append(text)
         yield "".join(outputs)

-    # Yield the final model output
-    final_output = "".join(outputs)
-    yield final_output
-
-# Updated JavaScript with debugging and robustness
-custom_js = """
-function splitPrompts() {
-    console.log("Running splitPrompts function"); // Debug log
-    const messages = document.querySelectorAll('.chatbot-message, .message, [class*="message"]');
-    console.log("Found messages:", messages.length); // Debug log
-    messages.forEach((message, index) => {
-        const text = message.innerHTML;
-        console.log("Message", index, "text:", text); // Debug log
-
-        if (text.includes('Positive Prompt:') && text.includes('Negative Prompt:')) {
-            console.log("Found Positive and Negative prompts in message", index); // Debug log
-            const positiveMatch = text.match(/Positive Prompt:(.*?)(?=(Negative Prompt:|$))/s);
-            const negativeMatch = text.match(/Negative Prompt:(.*)/s);

-            if (positiveMatch && negativeMatch) {
-                const positivePrompt = positiveMatch[1].trim();
-                const negativePrompt = negativeMatch[1].trim();
-                console.log("Positive Prompt:", positivePrompt); // Debug log
-                console.log("Negative Prompt:", negativePrompt); // Debug log
-
-                message.innerHTML = `
-                    <div class="positive-prompt"><strong>Positive Prompt:</strong><br>${positivePrompt}</div>
-                    <div class="negative-prompt"><strong>Negative Prompt:</strong><br>${negativePrompt}</div>
-                `;
-            } else {
-                console.log("Failed to match prompts in message", index); // Debug log
-            }
-        }
-    });
-}
-
-// Run the function when the DOM is fully loaded
-document.addEventListener('DOMContentLoaded', () => {
-    console.log("DOM fully loaded, setting up MutationObserver"); // Debug log
-    const observer = new MutationObserver((mutations) => {
-        console.log("MutationObserver triggered", mutations); // Debug log
-        splitPrompts();
-    });
-
-    const chatArea = document.querySelector('.gr-chatbot, [class*="chatbot"], [class*="chat"]');
-    console.log("Chat area found:", chatArea); // Debug log
-    if (chatArea) {
-        observer.observe(chatArea, { childList: true, subtree: true });
-    } else {
-        console.log("Chat area not found, retrying in 1 second"); // Debug log
-        setTimeout(() => {
-            const retryChatArea = document.querySelector('.gr-chatbot, [class*="chatbot"], [class*="chat"]');
-            if (retryChatArea) {
-                observer.observe(retryChatArea, { childList: true, subtree: true });
-            } else {
-                console.log("Chat area still not found after retry"); // Debug log
-            }
-        }, 1000);
-    }
-
-    // Run initially
-    splitPrompts();
-});
-"""
-
-# Use gr.Blocks to allow custom JavaScript injection
-with gr.Blocks(css="style.css", js=custom_js) as demo:
-    gr.Markdown(DESCRIPTION)
-    chat_interface = gr.ChatInterface(
-        fn=generate,
-        additional_inputs=[
-            gr.Slider(
-                label="Max new tokens",
-                minimum=1,
-                maximum=MAX_MAX_NEW_TOKENS,
-                step=1,
-                value=DEFAULT_MAX_NEW_TOKENS,
-            ),
-            gr.Slider(
-                label="Temperature",
-                minimum=0.1,
-                maximum=4.0,
-                step=0.1,
-                value=0.6,
-            ),
-            gr.Slider(
-                label="Top-p (nucleus sampling)",
-                minimum=0.05,
-                maximum=1.0,
-                step=0.05,
-                value=0.9,
-            ),
-            gr.Slider(
-                label="Top-k",
-                minimum=1,
-                maximum=1000,
-                step=1,
-                value=50,
-            ),
-            gr.Slider(
-                label="Repetition penalty",
-                minimum=1.0,
-                maximum=2.0,
-                step=0.05,
-                value=1.2,
-            ),
-        ],
-        stop_btn=None,
-        examples=[
-            ["Hello there! How are you doing?"],
-            ["Can you explain briefly to me what is the Python programming language?"],
-            ["Explain the plot of Cinderella in a sentence."],
-            ["How many hours does it take a man to eat a Helicopter?"],
-            ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-        ],
-        type="messages",
-    )
+demo = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["Hello there! How are you doing?"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+    ],
+    type="messages",
+    description=DESCRIPTION,
+    css_paths="style.css",
+)

 if __name__ == "__main__":
     demo.queue(max_size=20).launch()
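For reference, a minimal standalone sketch of the streaming path that app.py follows after this commit, assuming the transformers TextIteratorStreamer API, a GPU with enough memory for the fp16 weights, and the same model id; the prompt, the MAX_INPUT_TOKEN_LENGTH value, and the sampling settings below are illustrative stand-ins for the Space's own constants. With its default arguments, tokenizer.apply_chat_template returns a plain tensor of token ids, which is why the dict and attention-mask handling could be dropped above.

from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_INPUT_TOKEN_LENGTH = 4096  # illustrative; the Space defines its own constant

model_id = "mistralai/Mistral-7B-Instruct-v0.3"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

conversation = [{"role": "user", "content": "Hello there! How are you doing?"}]

# With default arguments this returns a plain tensor of token ids,
# so no dict unpacking or explicit attention mask is required.
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]  # keep the most recent tokens
input_ids = input_ids.to(model.device)

# model.generate blocks until generation finishes, so it runs on a worker
# thread while the streamer yields decoded text pieces as they are produced.
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs=dict(
    input_ids=input_ids,
    streamer=streamer,
    max_new_tokens=256,  # illustrative budget
    do_sample=True,
)).start()

for text in streamer:
    print(text, end="", flush=True)

Running model.generate on a background thread is what lets the Gradio callback yield partial output while generation is still in progress.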
style.css CHANGED
@@ -9,22 +9,3 @@ h1 {
   background: #1565c0;
   border-radius: 100vh;
 }
-
-/* Style for the positive prompt box */
-.positive-prompt {
-  background-color: #2a2a2a; /* Dark background to match the theme */
-  border: 1px solid #444; /* Subtle border */
-  border-radius: 8px;
-  padding: 15px;
-  margin-bottom: 10px; /* Space between the two boxes */
-  color: #ffffff; /* White text for readability */
-}
-
-/* Style for the negative prompt box */
-.negative-prompt {
-  background-color: #2a2a2a;
-  border: 1px solid #444;
-  border-radius: 8px;
-  padding: 15px;
-  color: #ffffff;
-}