praeclarumjj3 committed
Commit f1653dd · 1 Parent(s): 20b4d0d

:zap: Fix version

Files changed (2)
  1. app.py +5 -6
  2. demo.py +0 -486
app.py CHANGED
@@ -1,7 +1,6 @@
 import gradio as gr
 import torch
 import numpy as np
-import spaces
 from ola_vlm.constants import DEFAULT_IMAGE_TOKEN
 
 from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
@@ -9,8 +8,7 @@ from ola_vlm.conversation import conv_templates, SeparatorStyle
 from ola_vlm.model.builder import load_pretrained_model
 from ola_vlm.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images
 
-from diffusers import StableUnCLIPImg2ImgPipeline
-from diffusers import DPMSolverMultistepScheduler
+from diffusers import StableUnCLIPImg2ImgPipeline, DPMSolverMultistepScheduler
 from transformers import OneFormerProcessor
 from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead
 from ola_vlm.ola_utils import visualize_oneformer_masks_on_image, oneformer_prepare_panoptic_instance_prediction
@@ -150,10 +148,9 @@ our_chatbot = None
 
 pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(f"stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16")
 pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
-pipe = pipe.to("cuda")
 
 oneformer_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_coco_swin_large")
-oneformer = OneFormerHead.from_pretrained("shi-labs/oneformer_coco_swin_large").to("cuda")
+oneformer = OneFormerHead.from_pretrained("shi-labs/oneformer_coco_swin_large")
 
 gen_layer_indices = model.config.image_gen["img_layer_indices"].split("-")
 seg_layer_indices = model.config.image_seg["seg_layer_indices"].split("-")
@@ -181,6 +178,7 @@ def add_text(state, imagebox, textbox, image_process_mode):
     yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
 
 def get_gen_images(out):
+    pipe = pipe.to("cuda")
     img_embeds = out.image_embs
     if len(img_embeds) == 0:
         return None
@@ -213,6 +211,7 @@ def get_depth_images(out, org_size):
     return grid_image
 
 def get_seg_images(out, image):
+    oneformer = oneformer.to("cuda")
     seg_embs = out.seg_embs
 
     if len(seg_embs) == 0:
@@ -252,7 +251,7 @@ def regenerate(state, image_process_mode):
 # @spaces.GPU
 # def get_interm_outs(state):
 
-
+import spaces
 @spaces.GPU
 def generate(state, temperature, top_p, max_output_tokens, is_inter=False):
     if is_inter:
demo.py DELETED
@@ -1,486 +0,0 @@
-import gradio as gr
-import os
-import torch
-import numpy as np
-
-from ola_vlm.constants import DEFAULT_IMAGE_TOKEN
-
-from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
-from ola_vlm.conversation import conv_templates, SeparatorStyle
-from ola_vlm.model.builder import load_pretrained_model
-from ola_vlm.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images
-
-from diffusers import StableUnCLIPImg2ImgPipeline
-from diffusers import DPMSolverMultistepScheduler
-from transformers import OneFormerProcessor
-from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead
-from ola_vlm.ola_utils import visualize_oneformer_masks_on_image, oneformer_prepare_panoptic_instance_prediction
-import matplotlib
-from PIL import Image, ImageDraw, ImageFont
-import argparse
-import math
-
-from transformers import TextIteratorStreamer
-from threading import Thread
-
-def make_grid(pil_images, layer_indices=None):
-    new_images = []
-    new_captions = []
-
-    # Resize images and prepare captions
-    for i, pil_image in enumerate(pil_images):
-        pil_image = pil_image.resize((256, 256))
-        new_images.append(pil_image)
-        if layer_indices is not None:
-            new_captions.append(f"Layer: {layer_indices[i]}")
-        else:
-            new_captions.append(f"Layer: {i+1}")
-
-    images = new_images
-    captions = new_captions
-
-    width, height = images[0].size
-    font_size = 18
-
-    # Calculate the number of rows and columns for the grid
-    images_per_row = min(len(images), 4) # Max 4 images per row
-    row_count = math.ceil(len(images) / images_per_row)
-    total_width = width * images_per_row
-    total_height = height * row_count
-
-    # Create a new blank image
-    new_image = Image.new("RGB", (total_width, total_height), "white")
-    draw = ImageDraw.Draw(new_image)
-
-    # Load a default font
-    try:
-        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
-    except:
-        font = ImageFont.load_default()
-
-    # Place images and captions in the grid
-    for i, (image, caption) in enumerate(zip(images, captions)):
-        row = i // images_per_row
-        col = i % images_per_row
-        x_offset = col * width
-        y_offset = row * height
-
-        # Paste the image
-        new_image.paste(image, (x_offset, y_offset))
-
-        # Calculate text and background positions
-        text_width, text_height = draw.textsize(caption, font=font)
-        text_position = (x_offset + 10, y_offset + height - text_height - 10)
-        background_position = (
-            text_position[0] - 5,
-            text_position[1] - 5,
-            text_position[0] + text_width + 5,
-            text_position[1] + text_height + 5,
-        )
-
-        # Draw background rectangle and text
-        draw.rectangle(background_position, fill="white", outline="black")
-        draw.text(text_position, caption, fill="black", font=font)
-
-    return new_image
-
-def reload_from_ckpt(model_path, model, cache_dir=None):
-    import os
-    from safetensors import safe_open
-    from huggingface_hub import hf_hub_download, list_repo_files
-
-    state_dict = {}
-
-    # Check if the path is a local directory or HF Hub model
-    if os.path.isdir(model_path):
-        # Local directory: Load safetensors files
-        safetensors_paths = [os.path.join(model_path, f) for f in os.listdir(model_path) if f.endswith('.safetensors')]
-    else:
-        # HF Hub: Get list of safetensors files and download them
-        repo_files = list_repo_files(model_path)
-        safetensors_paths = [
-            hf_hub_download(model_path, file_name, cache_dir=cache_dir)
-            for file_name in repo_files if file_name.endswith('.safetensors')
-        ]
-
-    # Load safetensors files into the state_dict
-    for path in safetensors_paths:
-        with safe_open(path, framework="pt", device="cpu") as f:
-            for key in f.keys():
-                state_dict[key] = f.get_tensor(key)
-
-    # Load the state dict into the model
-    model.load_state_dict(state_dict, strict=False)
-    return model
-
-# os.environ['GRADIO_TEMP_DIR'] = './gradio_tmp'
-no_change_btn = gr.Button()
-enable_btn = gr.Button(interactive=True)
-disable_btn = gr.Button(interactive=False)
-
-argparser = argparse.ArgumentParser()
-argparser.add_argument("--server_name", default="0.0.0.0", type=str)
-argparser.add_argument("--port", default="6324", type=str)
-argparser.add_argument("--model-path", default="shi-labs/pretrain_dsg_OLA-VLM-CLIP-ViT-Llama3-8b", type=str)
-argparser.add_argument("--model-base", type=str, default=None)
-argparser.add_argument("--num-gpus", type=int, default=1)
-argparser.add_argument("--conv-mode", type=str, default="llava_llama_3")
-argparser.add_argument("--temperature", type=float, default=0.2)
-argparser.add_argument("--max-new-tokens", type=int, default=512)
-argparser.add_argument("--num_frames", type=int, default=16)
-argparser.add_argument("--load-8bit", action="store_true")
-argparser.add_argument("--load-4bit", action="store_true")
-argparser.add_argument("--debug", action="store_true")
-
-args = argparser.parse_args()
-model_path = args.model_path
-conv_mode = args.conv_mode
-filt_invalid="cut"
-model_name = get_model_name_from_path(args.model_path)
-tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
-model = reload_from_ckpt("shi-labs/OLA-VLM-CLIP-ViT-Llama3-8b", model)
-our_chatbot = None
-
-pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(f"stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16")
-pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
-pipe = pipe.to("cuda")
-
-oneformer_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_coco_swin_large")
-oneformer = OneFormerHead.from_pretrained("shi-labs/oneformer_coco_swin_large").to("cuda")
-
-gen_layer_indices = model.config.image_gen["img_layer_indices"].split("-")
-seg_layer_indices = model.config.image_seg["seg_layer_indices"].split("-")
-depth_layer_indices = model.config.image_depth["depth_layer_indices"].split("-")
-
-
-def clear_history():
-    state =conv_templates[conv_mode].copy()
-    return (state, state.to_gradio_chatbot(), "", None, None, None, None) + (disable_btn,) * 5
-
-def add_text(state, imagebox, textbox, image_process_mode):
-    if state is None:
-        state = conv_templates[conv_mode].copy()
-
-    if imagebox is not None:
-        textbox = DEFAULT_IMAGE_TOKEN + '\n' + textbox
-        image = Image.open(imagebox).convert('RGB')
-
-    if imagebox is not None:
-        textbox = (textbox, image, image_process_mode)
-
-    state.append_message(state.roles[0], textbox)
-    state.append_message(state.roles[1], None)
-
-    yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
-
-def get_gen_images(out):
-    img_embeds = out.image_embs
-    if len(img_embeds) == 0:
-        return None
-    images = []
-    for img_embed in img_embeds:
-        gen_image = pipe(image_embeds=img_embed.squeeze(1),
-            num_inference_steps=25,
-        ).images[0]
-        images.append(gen_image)
-    grid_image = make_grid(images, gen_layer_indices)
-    return grid_image
-
-def get_depth_images(out, org_size):
-    depth_preds = out.depth_preds
-
-    if len(depth_preds) == 0:
-        return None
-    depths = []
-
-    for i, depth_pred in enumerate(depth_preds):
-        depth = (depth_pred - depth_pred.min()) / (depth_pred.max() - depth_pred.min()) * 255.0
-        depth = depth.squeeze(0).cpu().numpy()
-        depth = depth.astype(np.uint8)
-        cmap = matplotlib.colormaps.get_cmap('Spectral_r')
-        depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
-        depth = Image.fromarray(depth)
-        depth = depth.resize(org_size)
-        depths.append(depth)
-    grid_image = make_grid(depths, depth_layer_indices)
-    return grid_image
-
-def get_seg_images(out, image):
-    seg_embs = out.seg_embs
-
-    if len(seg_embs) == 0:
-        return None
-
-    seg_preds = []
-    inputs = oneformer_processor(image, ["semantic"], return_tensors="pt")
-    inputs["pixel_values"] = inputs["pixel_values"].to(out.logits.device, out.logits.dtype)
-    inputs["task_inputs"] = inputs["task_inputs"].to(out.logits.device, out.logits.dtype)
-    backbone_features = oneformer.get_backbone_feats(**inputs)
-    for i, seg_emb in enumerate(seg_embs):
-        pred = oneformer.get_masks(**inputs, backbone_last_feature=seg_emb.float(), all_backbone_features=backbone_features)
-        pred = oneformer_processor.post_process_panoptic_segmentation(
-            pred, target_sizes=[image.size[::-1]]
-        )[0]
-        pred_msk, pred_cls = oneformer_prepare_panoptic_instance_prediction(**pred, oneformer=oneformer)
-        pred = visualize_oneformer_masks_on_image(image, pred_msk, pred_cls)
-        seg_preds.append(pred)
-    grid_image = make_grid(seg_preds, seg_layer_indices)
-    return grid_image
-
-def delete_text(state, image_process_mode):
-    state.messages[-1][-1] = None
-    prev_human_msg = state.messages[-2]
-    if type(prev_human_msg[1]) in (tuple, list):
-        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
-    yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
-
-def regenerate(state, image_process_mode):
-    state.messages[-1][-1] = None
-    prev_human_msg = state.messages[-2]
-    if type(prev_human_msg[1]) in (tuple, list):
-        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
-    state.skip_next = False
-    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
-
-def get_interm_outs(state):
-    prompt = state.get_prompt()
-    images = state.get_images(return_pil=True)
-    #prompt, image_args = process_image(prompt, images)
-
-    if images is not None and len(images) > 0:
-        if len(images) > 0:
-            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
-                raise ValueError("Number of images does not match number of <image> tokens in prompt")
-
-            #images = [load_image_from_base64(image) for image in images]
-            image_sizes = [image.size for image in images]
-            inp_images = process_images(images, image_processor, model.config)
-
-            if type(inp_images) is list:
-                inp_images = [image.to(model.device, dtype=torch.float16) for image in images]
-            else:
-                inp_images = inp_images.to(model.device, dtype=torch.float16)
-        else:
-            inp_images = None
-            image_sizes = None
-        image_args = {"images": inp_images, "image_sizes": image_sizes}
-    else:
-        inp_images = None
-        image_args = {}
-
-    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
-
-    interm_outs = model.get_visual_interpretations(
-        input_ids,
-        **image_args
-    )
-
-    depth_outs = get_depth_images(interm_outs, image_sizes[0])
-    seg_outs = get_seg_images(interm_outs, images[0])
-    gen_outs = get_gen_images(interm_outs)
-
-    return depth_outs, seg_outs, gen_outs
-
-# @spaces.GPU
-def generate(state, temperature, top_p, max_output_tokens):
-    prompt = state.get_prompt()
-    images = state.get_images(return_pil=True)
-    #prompt, image_args = process_image(prompt, images)
-
-    ori_prompt = prompt
-    num_image_tokens = 0
-
-    if images is not None and len(images) > 0:
-        if len(images) > 0:
-            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
-                raise ValueError("Number of images does not match number of <image> tokens in prompt")
-
-            #images = [load_image_from_base64(image) for image in images]
-            image_sizes = [image.size for image in images]
-            images = process_images(images, image_processor, model.config)
-
-            if type(images) is list:
-                images = [image.to(model.device, dtype=torch.float16) for image in images]
-            else:
-                images = images.to(model.device, dtype=torch.float16)
-        else:
-            images = None
-            image_sizes = None
-        image_args = {"images": images, "image_sizes": image_sizes}
-    else:
-        images = None
-        image_args = {}
-
-    max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
-    max_new_tokens = max_output_tokens
-    do_sample = True if temperature > 0.001 else False
-    stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
-
-    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
-
-    max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
-
-    if max_new_tokens < 1:
-        return
-
-    thread = Thread(target=model.generate, kwargs=dict(
-        inputs=input_ids,
-        do_sample=do_sample,
-        temperature=temperature,
-        top_p=top_p,
-        max_new_tokens=max_new_tokens,
-        streamer=streamer,
-        use_cache=True,
-        pad_token_id=tokenizer.eos_token_id,
-        **image_args
-    ))
-    thread.start()
-    generated_text = ''
-    for new_text in streamer:
-        generated_text += new_text
-        if generated_text.endswith(stop_str):
-            generated_text = generated_text[:-len(stop_str)]
-        state.messages[-1][-1] = generated_text
-        yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
-
-    yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
-
-    torch.cuda.empty_cache()
-
-txt = gr.Textbox(
-    scale=4,
-    show_label=False,
-    placeholder="Enter text and press enter.",
-    container=False,
-)
-
-
-title = "<h1 style='margin-bottom: -10px; text-align: center'>OLA-VLM: Optimizing Language Model Representations for Enhanced Visual Quality and Alignment</h1>"
-description = "<p style='font-size: 16px; margin: 5px; font-weight: w300; text-align: center'> <a href='https://praeclarumjj3.github.io/' style='text-decoration:none' target='_blank'>Jitesh Jain</a> &nbsp;&nbsp <a href='https://zyang-ur.github.io/' style='text-decoration:none' target='_blank'>Zhengyuan Yang</a> &nbsp;&nbsp <a href='https://www.humphreyshi.com/home' style='text-decoration:none' target='_blank'>Humphrey Shi<sup>*</sup></a> &nbsp;&nbsp <a href='https://www.humphreyshi.com/home' style='text-decoration:none' target='_blank'>Jianfeng Gao<sup>*</sup></a> &nbsp;&nbsp <a href='https://jwyang.github.io/' style='text-decoration:none' target='_blank'>Jianwei Yang<sup>*</sup></a></p>" \
-    + "<p style='font-size: 12px; margin: 5px; font-weight: w300; text-align: center'><sup>*</sup>Equal Advising</p>" \
-    + "<p style='font-size: 16px; margin: 5px; font-weight: w600; text-align: center'> <a href='https://praeclarumjj3.github.io/ola_vlm/' target='_blank'>Project Page</a> | <a href='https://youtu.be/' target='_blank'>Video</a> | <a href='https://arxiv.org/abs/' target='_blank'>ArXiv</a> | <a href='https://github.com/SHI-Labs/OLA-VLM' target='_blank'>Github</a></p>"
-
-tos_markdown = ("""
-### Terms of use
-By using this service, users are required to agree to the following terms:
-The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
-""")
-
-
-learn_more_markdown = ("""
-### License
-The service is a research preview intended for non-commercial use only, subject to the [License](https://huggingface.co/lmsys/vicuna-7b-v1.5) of Vicuna-v1.5, [License](https://github.com/haotian-liu/LLaVA/blob/main/LICENSE) of LLaVA, [Terms of Use](https://cocodataset.org/#termsofuse) of the COCO dataset, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
-""")
-
-block_css = """
-#buttons button {
-    min-width: min(120px,100%);
-}
-"""
-
-
-textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
-with gr.Blocks(title="OLA-VLM", theme=gr.themes.Default(), css=block_css) as demo:
-    state = gr.State()
-
-    gr.Markdown(title)
-    gr.Markdown(description)
-
-    with gr.Row():
-        with gr.Column(scale=4):
-            imagebox = gr.Image(label="Input Image", type="filepath")
-            image_process_mode = gr.Radio(
-                ["Crop", "Resize", "Pad", "Default"],
-                value="Default",
-                label="Preprocess for non-square image", visible=False)
-
-            # with gr.Accordion("Parameters", open=False) as parameter_row:
-            with gr.Row():
-                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
-                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
-                max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
-
-        with gr.Column(scale=8):
-            chatbot = gr.Chatbot(
-                elem_id="chatbot",
-                label="OLA-VLM",
-                height=300,
-                layout="panel",
-            )
-            textbox.render()
-            with gr.Row(elem_id="buttons") as button_row:
-                upvote_btn = gr.Button(value="👍 Upvote", interactive=False, visible=False)
-                downvote_btn = gr.Button(value="👎 Downvote", interactive=False, visible=False)
-                flag_btn = gr.Button(value="⚠️ Flag", interactive=False, visible=False)
-                #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
-                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
-                clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
-                submit_btn = gr.Button(value="Send", variant="primary")
-
-    with gr.Accordion("Representations from selected layers of the LLM (expects only a single image input)", open=False) as interm_out:
-        inter_vis_btn = gr.Button(value="✨ Visualize")
-        with gr.Row():
-            depth_box = gr.Image(label="depth", type="pil", visible=True)
-            seg_box = gr.Image(label="seg", type="pil", visible=True)
-            gen_box = gr.Image(label="gen", type="pil", visible=True)
-
-    gr.Examples(examples=[
-        [f"assets/cars.jpg", "Which car is in front: the blue or the brown one?"],
-        [f"assets/pb.jpg", "Where is the bulding located with respect to the man?"],
-    ], inputs=[imagebox, textbox], cache_examples=False)
-
-    # gr.Markdown(tos_markdown)
-    # gr.Markdown(learn_more_markdown)
-    # url_params = gr.JSON(visible=False)
-
-    # Register listeners
-    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
-
-    inter_vis_btn.click(
-        get_interm_outs,
-        [state],
-        [depth_box, seg_box, gen_box],
-    )
-
-    clear_btn.click(
-        clear_history,
-        None,
-        [state, chatbot, textbox, imagebox, depth_box, gen_box, seg_box] + btn_list,
-        queue=False
-    )
-
-    regenerate_btn.click(
-        delete_text,
-        [state, image_process_mode],
-        [state, chatbot, textbox, imagebox] + btn_list,
-    ).then(
-        generate,
-        [state, temperature, top_p, max_output_tokens],
-        [state, chatbot, textbox, imagebox] + btn_list,
-    )
-    textbox.submit(
-        add_text,
-        [state, imagebox, textbox, image_process_mode],
-        [state, chatbot, textbox, imagebox] + btn_list,
-    ).then(
-        generate,
-        [state, temperature, top_p, max_output_tokens],
-        [state, chatbot, textbox, imagebox] + btn_list,
-    )
-
-    submit_btn.click(
-        add_text,
-        [state, imagebox, textbox, image_process_mode],
-        [state, chatbot, textbox, imagebox] + btn_list,
-    ).then(
-        generate,
-        [state, temperature, top_p, max_output_tokens],
-        [state, chatbot, textbox, imagebox] + btn_list,
-    )
-
-demo.queue(
-    status_update_rate=10,
-    api_open=False
-).launch(share=True)
-demo.queue()