# MultimodalGPT / app.py
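"""Gradio demo for Multimodal-GPT: an OpenFlamingo-style model (CLIP ViT-L/14 vision
encoder + LLaMA-7B language model) with a LoRA finetune, serving single-image chat."""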
import base64
import os
import pickle
from io import BytesIO

import gradio as gr
import torch
from PIL import Image
from mmgpt.models.builder import create_model_and_transforms
TEMPLATE = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
response_split = "### Response:"
class Inferencer:
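    """Loads the finetuned Multimodal-GPT checkpoint and runs (optionally image-conditioned) generation."""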
def __init__(self, finetune_path, llama_path, open_flamingo_path):
print("inferencer initialization begun")
ckpt = torch.load(finetune_path, map_location="cpu")
print("ckpt: ", ckpt)
if "model_state_dict" in ckpt:
state_dict = ckpt["model_state_dict"]
# remove the "module." prefix
state_dict = {
k[7:]: v
for k, v in state_dict.items() if k.startswith("module.")
}
else:
state_dict = ckpt
print("state_dict has been set")
tuning_config = ckpt.get("tuning_config")
if tuning_config is None:
print("tuning_config not found in checkpoint")
else:
print("tuning_config found in checkpoint: ", tuning_config)
model, image_processor, tokenizer = create_model_and_transforms(
model_name="open_flamingo",
clip_vision_encoder_path="ViT-L-14",
clip_vision_encoder_pretrained="openai",
lang_encoder_path=llama_path,
tokenizer_path=llama_path,
pretrained_model_path=open_flamingo_path,
tuning_config=tuning_config,
)
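        # strict=False: the finetune checkpoint is expected to contain only the tuned
        # parameters (e.g. the LoRA weights), not the full model state.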
model.load_state_dict(state_dict, strict=False)
model.half()
model = model.to("cuda")
model.eval()
tokenizer.padding_side = "left"
tokenizer.add_eos_token = False
self.model = model
self.image_processor = image_processor
self.tokenizer = tokenizer
print("finished inferencer initialization")
def __call__(self, prompt, imgpaths, max_new_token, num_beams, temperature,
top_k, top_p, do_sample):
print("inferecer called")
if len(imgpaths) > 1:
raise gr.Error(
"Current only support one image, please clear gallery and upload one image"
)
lang_x = self.tokenizer([prompt], return_tensors="pt")
print("tokenized")
        if imgpaths is None or len(imgpaths) == 0:
            print("imgpaths is None or empty")
for layer in self.model.lang_encoder._get_decoder_layers():
layer.condition_only_lang_x(True)
output_ids = self.model.lang_encoder.generate(
input_ids=lang_x["input_ids"].cuda(),
attention_mask=lang_x["attention_mask"].cuda(),
max_new_tokens=max_new_token,
num_beams=num_beams,
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=do_sample,
)[0]
for layer in self.model.lang_encoder._get_decoder_layers():
layer.condition_only_lang_x(False)
else:
print("imgpath is valid")
images = (Image.open(fp) for fp in imgpaths)
print("images retrieved")
vision_x = [self.image_processor(im).unsqueeze(0) for im in images]
vision_x = torch.cat(vision_x, dim=0)
vision_x = vision_x.unsqueeze(1).unsqueeze(0).half()
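            # vision_x shape: (batch=1, num_images, frames=1, C, H, W), cast to half
            # precision to match the model weights.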
print("vision_x retrieved")
torch.cuda.empty_cache()
print(f"Allocated GPU memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"Available GPU memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
output_ids = self.model.generate(
vision_x=vision_x.cuda(),
lang_x=lang_x["input_ids"].cuda(),
attention_mask=lang_x["attention_mask"].cuda(),
max_new_tokens=max_new_token,
num_beams=num_beams,
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=do_sample,
)[0]
print("output_ids retrieved")
generated_text = self.tokenizer.decode(
output_ids, skip_special_tokens=True)
print("text generated:", generated_text)
result = generated_text.split(response_split)[-1].strip()
print("result: ", result)
return result
def save(self, file_path):
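        # Note: this pickles the full (half-precision) state dict together with the
        # tokenizer and image processor, so the output file is several gigabytes.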
print("Saving model components...")
data = {
"model_state_dict": self.model.state_dict(),
"tokenizer": self.tokenizer,
"image_processor": self.image_processor,
}
with open(file_path, "wb") as f:
pickle.dump(data, f)
print(f"Model components saved to {file_path}")
class PromptGenerator:
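    """Accumulates the chat history and renders it into an instruction-style prompt."""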
def __init__(
self,
prompt_template=TEMPLATE,
ai_prefix="Response",
user_prefix="Instruction",
sep: str = "\n\n### ",
buffer_size=0,
):
        # to_gradio_chatbot() assumes user turns sit at even indices of the history,
        # so start empty (clear() resets to the same state).
        self.all_history = []
self.ai_prefix = ai_prefix
self.user_prefix = user_prefix
self.buffer_size = buffer_size
self.prompt_template = prompt_template
self.sep = sep
def add_message(self, role, message):
self.all_history.append([role, message])
def get_images(self):
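        # Collect image paths from the same history window that get_prompt() uses.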
img_list = list()
if self.buffer_size > 0:
all_history = self.all_history[-2 * (self.buffer_size + 1):]
elif self.buffer_size == 0:
all_history = self.all_history[-2:]
else:
all_history = self.all_history[:]
for his in all_history:
            if isinstance(his[-1], tuple):
img_list.append(his[-1][-1])
return img_list
def get_prompt(self):
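        # Rendered example for a single user turn with an image, using the default
        # TEMPLATE, prefixes and separator:
        #
        #   Below is an instruction that describes a task. Write a response that
        #   appropriately completes the request.
        #
        #   ### Image:
        #   <image>
        #
        #   ### Instruction:
        #   What is in this image?
        #
        #   ### Response: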
format_dict = dict()
if "{user_prefix}" in self.prompt_template:
format_dict["user_prefix"] = self.user_prefix
if "{ai_prefix}" in self.prompt_template:
format_dict["ai_prefix"] = self.ai_prefix
prompt_template = self.prompt_template.format(**format_dict)
ret = prompt_template
if self.buffer_size > 0:
all_history = self.all_history[-2 * (self.buffer_size + 1):]
elif self.buffer_size == 0:
all_history = self.all_history[-2:]
else:
all_history = self.all_history[:]
context = []
have_image = False
for role, message in all_history[::-1]:
if message:
                if isinstance(message, tuple) and message[1] is not None and not have_image:
                    message, _ = message
                    # Only the most recent image (the history is walked newest-first)
                    # gets an <image> placeholder, matching the single image that is
                    # actually passed to the model.
                    have_image = True
                    context.append(self.sep + "Image:\n<image>" + self.sep +
                                   role + ":\n" + message)
else:
context.append(self.sep + role + ":\n" + message)
else:
context.append(self.sep + role + ":\n")
ret += "".join(context[::-1])
return ret
def to_gradio_chatbot(prompt_generator):
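    # Convert the stored history into Gradio chatbot pairs ([user, assistant]); user
    # turns are expected at even indices, and uploaded images are inlined into the
    # user message as base64 <img> tags.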
ret = []
for i, (role, msg) in enumerate(prompt_generator.all_history):
if i % 2 == 0:
            if isinstance(msg, tuple):
                msg, image = msg
                if isinstance(image, str):
                    image = Image.open(image)
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(
min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
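                # Note: PIL's Image.size is (width, height), so H and W below actually
                # hold width and height in that order; resize() likewise expects
                # (width, height).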
H, W = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((H, W))
# image = image.resize((224, 224))
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user uploaded image" />'
msg = msg + img_str
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def bot(
text,
image,
state,
prompt,
ai_prefix,
user_prefix,
seperator,
history_buffer,
max_new_token,
num_beams,
temperature,
top_k,
top_p,
do_sample,
):
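    # Gradio event handler: sync the prompt settings from the UI controls, append the
    # new user turn (with an optional image), run inference on the rendered prompt, and
    # return the updated chat state plus current GPU memory usage.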
state.prompt_template = prompt
state.ai_prefix = ai_prefix
state.user_prefix = user_prefix
state.sep = seperator
state.buffer_size = history_buffer
if image:
print(image)
print(text)
state.add_message(user_prefix, (text, image))
print("added message")
else:
state.add_message(user_prefix, text)
state.add_message(ai_prefix, None)
print("added ai_prefix message")
inputs = state.get_prompt()
print("retrived inputs")
image_paths = state.get_images()[-1:]
print("retrieved image_paths")
inference_results = inferencer(inputs, image_paths, max_new_token,
num_beams, temperature, top_k, top_p,
do_sample)
print(inference_results)
state.all_history[-1][-1] = inference_results
memory_allocated = str(round(torch.cuda.memory_allocated() / 1024**3,
2)) + 'GB'
return state, to_gradio_chatbot(state), "", None, inputs, memory_allocated
def clear(state):
state.all_history = []
return state, to_gradio_chatbot(state), "", None, ""
title_markdown = ("""
# 🤖 Multi-modal GPT
[[Project]](https://github.com/open-mmlab/Multimodal-GPT.git)""")
def build_conversation_demo():
with gr.Blocks(title="Multi-modal GPT") as demo:
gr.Markdown(title_markdown)
state = gr.State(PromptGenerator())
with gr.Row():
with gr.Column(scale=3):
memory_allocated = gr.Textbox(
value=init_memory, label="Memory")
imagebox = gr.Image(type="filepath")
# TODO config parameters
with gr.Accordion(
"Parameters",
open=True,
):
max_new_token_bar = gr.Slider(
0, 1024, 512, label="max_new_token", step=1)
                    num_beams_bar = gr.Slider(
                        1, 10, 3, label="num_beams", step=1)
temperature_bar = gr.Slider(
0.0, 1.0, 1.0, label="temperature", step=0.01)
topk_bar = gr.Slider(0, 100, 20, label="top_k", step=1)
topp_bar = gr.Slider(0, 1.0, 1.0, label="top_p", step=0.01)
do_sample = gr.Checkbox(True, label="do_sample")
with gr.Accordion(
"Prompt",
open=False,
):
with gr.Row():
ai_prefix = gr.Text("Response", label="AI Prefix")
user_prefix = gr.Text(
"Instruction", label="User Prefix")
                        seperator = gr.Text("\n\n### ", label="Separator")
history_buffer = gr.Slider(
-1, 10, -1, label="History buffer", step=1)
prompt = gr.Text(TEMPLATE, label="Prompt")
model_inputs = gr.Textbox(label="Actual inputs for Model")
with gr.Column(scale=6):
with gr.Row():
with gr.Column():
chatbot = gr.Chatbot(elem_id="chatbot", height=750)
with gr.Row():
with gr.Column(scale=8):
textbox = gr.Textbox(
show_label=False,
placeholder="Enter text and press ENTER",
container=False)
submit_btn = gr.Button(value="Submit")
clear_btn = gr.Button(value="🗑️ Clear history")
cur_dir = os.path.dirname(os.path.abspath(__file__))
gr.Examples(
examples=[
[
f"{cur_dir}/docs/images/demo_image.jpg",
"What is in this image?"
],
],
inputs=[imagebox, textbox],
)
textbox.submit(
bot,
[
textbox,
imagebox,
state,
prompt,
ai_prefix,
user_prefix,
seperator,
history_buffer,
max_new_token_bar,
num_beams_bar,
temperature_bar,
topk_bar,
topp_bar,
do_sample,
],
[
state, chatbot, textbox, imagebox, model_inputs,
memory_allocated
],
)
submit_btn.click(
bot,
[
textbox,
imagebox,
state,
prompt,
ai_prefix,
user_prefix,
seperator,
history_buffer,
max_new_token_bar,
num_beams_bar,
temperature_bar,
topk_bar,
topp_bar,
do_sample,
],
[
state, chatbot, textbox, imagebox, model_inputs,
memory_allocated
],
)
clear_btn.click(clear, [state],
[state, chatbot, textbox, imagebox, model_inputs])
return demo
if __name__ == "__main__":
llama_path = "checkpoints/llama-7b_hf"
open_flamingo_path = "checkpoints/OpenFlamingo-9B/checkpoint.pt"
finetune_path = "checkpoints/mmgpt-lora-v0-release.pt"
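    # Local checkpoint locations: LLaMA-7B in HuggingFace format, the OpenFlamingo-9B
    # checkpoint, and the Multimodal-GPT LoRA release weights.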
inferencer = Inferencer(
llama_path=llama_path,
open_flamingo_path=open_flamingo_path,
finetune_path=finetune_path)
init_memory = str(round(torch.cuda.memory_allocated() / 1024**3, 2)) + 'GB'
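    # Persist the loaded components for reuse; as noted in Inferencer.save, the pickle
    # contains the full model state dict and is several gigabytes on disk.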
inferencer.save("inferencer.pkl")
demo = build_conversation_demo()
demo.queue()
IP = "0.0.0.0"
PORT = 8997
demo.launch(server_name=IP, server_port=PORT, share=True)