akki8602 committed on
Commit d35e7f4 · 1 Parent(s): 9c73e47

Add application file

Files changed (1)
  1. app.py  +42 -8
app.py CHANGED
@@ -1,8 +1,9 @@
 import os
-
+import pickle
 import gradio as gr
 import torch
 from PIL import Image
+import matplotlib.pyplot as plt

 from mmgpt.models.builder import create_model_and_transforms

@@ -13,7 +14,9 @@ response_split = "### Response:"
 class Inferencer:

     def __init__(self, finetune_path, llama_path, open_flamingo_path):
+        print("inferencer initialization begun")
         ckpt = torch.load(finetune_path, map_location="cpu")
+        print("ckpt: ", ckpt)
         if "model_state_dict" in ckpt:
             state_dict = ckpt["model_state_dict"]
             # remove the "module." prefix
@@ -23,6 +26,7 @@ class Inferencer:
             }
         else:
             state_dict = ckpt
+        print("state_dict has been set")
         tuning_config = ckpt.get("tuning_config")
         if tuning_config is None:
             print("tuning_config not found in checkpoint")
@@ -46,15 +50,19 @@ class Inferencer:
         self.model = model
         self.image_processor = image_processor
         self.tokenizer = tokenizer
+        print("finished inferencer initialization")

     def __call__(self, prompt, imgpaths, max_new_token, num_beams, temperature,
                  top_k, top_p, do_sample):
+        print("inferecer called")
         if len(imgpaths) > 1:
             raise gr.Error(
                 "Current only support one image, please clear gallery and upload one image"
             )
         lang_x = self.tokenizer([prompt], return_tensors="pt")
+        print("tokenized")
         if len(imgpaths) == 0 or imgpaths is None:
+            print("imgpath len is 0 or None")
             for layer in self.model.lang_encoder._get_decoder_layers():
                 layer.condition_only_lang_x(True)
             output_ids = self.model.lang_encoder.generate(
@@ -70,10 +78,16 @@ class Inferencer:
             for layer in self.model.lang_encoder._get_decoder_layers():
                 layer.condition_only_lang_x(False)
         else:
+            print("imgpath is valid")
             images = (Image.open(fp) for fp in imgpaths)
+            print("images retrieved")
             vision_x = [self.image_processor(im).unsqueeze(0) for im in images]
             vision_x = torch.cat(vision_x, dim=0)
             vision_x = vision_x.unsqueeze(1).unsqueeze(0).half()
+            print("vision_x retrieved")
+            torch.cuda.empty_cache()
+            print(f"Allocated GPU memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
+            print(f"Available GPU memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

             output_ids = self.model.generate(
                 vision_x=vision_x.cuda(),
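For reference, the `unsqueeze` chain above produces the (batch, num_media, num_frames, C, H, W) layout that OpenFlamingo-style models expect for `vision_x`. A minimal stand-alone sketch with a dummy tensor (224×224 RGB input is an assumption):

```python
import torch

# One processed image, shaped like the output of the image processor.
frames = [torch.randn(3, 224, 224).unsqueeze(0)]   # each entry: (1, C, H, W)
vision_x = torch.cat(frames, dim=0)                 # (num_media, C, H, W)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)       # (batch, num_media, num_frames, C, H, W)
print(vision_x.shape)                               # torch.Size([1, 1, 1, 3, 224, 224])
```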
@@ -86,12 +100,24 @@ class Inferencer:
                 top_p=top_p,
                 do_sample=do_sample,
             )[0]
+            print("output_ids retrieved")
         generated_text = self.tokenizer.decode(
             output_ids, skip_special_tokens=True)
-        # print(generated_text)
+        print("text generated:", generated_text)
         result = generated_text.split(response_split)[-1].strip()
+        print("result: ", result)
         return result

+    def save(self, file_path):
+        print("Saving model components...")
+        data = {
+            "model_state_dict": self.model.state_dict(),
+            "tokenizer": self.tokenizer,
+            "image_processor": self.image_processor,
+        }
+        with open(file_path, "wb") as f:
+            pickle.dump(data, f)
+        print(f"Model components saved to {file_path}")

 class PromptGenerator:

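A hedged counterpart for reading back the file that the new `save()` method writes; this assumes the same pickle layout as above and is not part of the commit:

```python
import pickle

# Assumes the dictionary layout written by Inferencer.save() above.
with open("inferencer.pkl", "rb") as f:
    data = pickle.load(f)

tokenizer = data["tokenizer"]
image_processor = data["image_processor"]
# The weights still need a freshly built model to load into, e.g.:
# model.load_state_dict(data["model_state_dict"])
```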
@@ -103,7 +129,7 @@ class PromptGenerator:
         sep: str = "\n\n### ",
         buffer_size=0,
     ):
-        self.all_history = list()
+        self.all_history = [("user", "Welcome to the chatbot!")]
         self.ai_prefix = ai_prefix
         self.user_prefix = user_prefix
         self.buffer_size = buffer_size
@@ -217,16 +243,23 @@ def bot(
     state.sep = seperator
     state.buffer_size = history_buffer
     if image:
+        print(image)
+        print(text)
         state.add_message(user_prefix, (text, image))
+        print("added message")
     else:
         state.add_message(user_prefix, text)
     state.add_message(ai_prefix, None)
+    print("added ai_prefix message")
     inputs = state.get_prompt()
+    print("retrived inputs")
     image_paths = state.get_images()[-1:]
+    print("retrieved image_paths")

     inference_results = inferencer(inputs, image_paths, max_new_token,
                                    num_beams, temperature, top_k, top_p,
                                    do_sample)
+    print(inference_results)
     state.all_history[-1][-1] = inference_results
     memory_allocated = str(round(torch.cuda.memory_allocated() / 1024**3,
                                  2)) + 'GB'
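The GB-string formatting above also appears in `__main__`; a small helper along these lines (hypothetical, not in the commit) would express the same calculation once:

```python
import torch

def gpu_memory_gb() -> str:
    """Allocated CUDA memory formatted like the inline expression above."""
    if not torch.cuda.is_available():
        return "0.0GB"
    return str(round(torch.cuda.memory_allocated() / 1024**3, 2)) + "GB"
```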
@@ -284,14 +317,13 @@ def build_conversation_demo():
             with gr.Column(scale=6):
                 with gr.Row():
                     with gr.Column():
-                        chatbot = gr.Chatbot(elem_id="chatbot").style(
-                            height=750)
+                        chatbot = gr.Chatbot(elem_id="chatbot", height=750)
                 with gr.Row():
                     with gr.Column(scale=8):
                         textbox = gr.Textbox(
                             show_label=False,
                             placeholder="Enter text and press ENTER",
-                        ).style(container=False)
+                            container=False)
                         submit_btn = gr.Button(value="Submit")
                         clear_btn = gr.Button(value="🗑️ Clear history")
         cur_dir = os.path.dirname(os.path.abspath(__file__))
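The `.style()` removals above track recent Gradio releases, which dropped that helper in favor of constructor keyword arguments. A minimal sketch of the same pattern in isolation (check the exact keywords against the installed Gradio version):

```python
import gradio as gr

with gr.Blocks() as demo:
    # Styling options are passed directly to the component constructors.
    chatbot = gr.Chatbot(elem_id="chatbot", height=750)
    textbox = gr.Textbox(show_label=False,
                         placeholder="Enter text and press ENTER",
                         container=False)
```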
@@ -354,7 +386,6 @@ def build_conversation_demo():
         [state, chatbot, textbox, imagebox, model_inputs])
     return demo

-
 if __name__ == "__main__":
     llama_path = "checkpoints/llama-7b_hf"
     open_flamingo_path = "checkpoints/OpenFlamingo-9B/checkpoint.pt"
@@ -365,8 +396,11 @@ if __name__ == "__main__":
         open_flamingo_path=open_flamingo_path,
         finetune_path=finetune_path)
     init_memory = str(round(torch.cuda.memory_allocated() / 1024**3, 2)) + 'GB'
+
+    inferencer.save("inferencer.pkl")
+
     demo = build_conversation_demo()
-    demo.queue(concurrency_count=3)
+    demo.queue()
     IP = "0.0.0.0"
     PORT = 8997
     demo.launch(server_name=IP, server_port=PORT, share=True)
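Dropping `concurrency_count` matches Gradio 4.x, where `queue()` no longer accepts that argument; if a per-event cap is still wanted, newer releases expose a `default_concurrency_limit` argument instead. A hedged sketch, to be verified against the installed version:

```python
# Hedged sketch, not part of the commit: cap concurrent workers on Gradio 4.x.
demo = build_conversation_demo()
demo.queue(default_concurrency_limit=3)
demo.launch(server_name="0.0.0.0", server_port=8997, share=True)
```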
 