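"""Gradio web demo for Multimodal-GPT (mmgpt).

Loads an OpenFlamingo-based vision-language model together with fine-tuned
(LoRA) weights and serves an instruction-following, image-grounded chat
interface.
"""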
import os
import pickle

import gradio as gr
import torch
from PIL import Image
import matplotlib.pyplot as plt

from mmgpt.models.builder import create_model_and_transforms

TEMPLATE = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
response_split = "### Response:"
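

# Wraps model construction and generation: loads the OpenFlamingo backbone,
# applies the fine-tuned (LoRA) checkpoint, and runs generation with or
# without an image.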
class Inferencer:

    def __init__(self, finetune_path, llama_path, open_flamingo_path):
        print("inferencer initialization begun")
        ckpt = torch.load(finetune_path, map_location="cpu")
        print("ckpt: ", ckpt)
        if "model_state_dict" in ckpt:
            state_dict = ckpt["model_state_dict"]
            # remove the "module." prefix
            state_dict = {
                k[7:]: v
                for k, v in state_dict.items() if k.startswith("module.")
            }
        else:
            state_dict = ckpt
        print("state_dict has been set")
        tuning_config = ckpt.get("tuning_config")
        if tuning_config is None:
            print("tuning_config not found in checkpoint")
        else:
            print("tuning_config found in checkpoint: ", tuning_config)
        model, image_processor, tokenizer = create_model_and_transforms(
            model_name="open_flamingo",
            clip_vision_encoder_path="ViT-L-14",
            clip_vision_encoder_pretrained="openai",
            lang_encoder_path=llama_path,
            tokenizer_path=llama_path,
            pretrained_model_path=open_flamingo_path,
            tuning_config=tuning_config,
        )
        model.load_state_dict(state_dict, strict=False)
        model.half()
        model = model.to("cuda")
        model.eval()
        tokenizer.padding_side = "left"
        tokenizer.add_eos_token = False
        self.model = model
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        print("finished inferencer initialization")
    def __call__(self, prompt, imgpaths, max_new_token, num_beams, temperature,
                 top_k, top_p, do_sample):
        print("inferencer called")
        if imgpaths is not None and len(imgpaths) > 1:
            raise gr.Error(
                "Only one image is currently supported; please clear the gallery and upload a single image."
            )
        lang_x = self.tokenizer([prompt], return_tensors="pt")
        print("tokenized")
        if imgpaths is None or len(imgpaths) == 0:
            print("no image provided; conditioning on language only")
            for layer in self.model.lang_encoder._get_decoder_layers():
                layer.condition_only_lang_x(True)
            output_ids = self.model.lang_encoder.generate(
                input_ids=lang_x["input_ids"].cuda(),
                attention_mask=lang_x["attention_mask"].cuda(),
                max_new_tokens=max_new_token,
                num_beams=num_beams,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                do_sample=do_sample,
            )[0]
            for layer in self.model.lang_encoder._get_decoder_layers():
                layer.condition_only_lang_x(False)
        else:
            print("imgpath is valid")
            images = (Image.open(fp) for fp in imgpaths)
            print("images retrieved")
            vision_x = [self.image_processor(im).unsqueeze(0) for im in images]
            vision_x = torch.cat(vision_x, dim=0)
            # shape: (batch, num_media, num_frames, channels, height, width)
            vision_x = vision_x.unsqueeze(1).unsqueeze(0).half()
            print("vision_x retrieved")
            torch.cuda.empty_cache()
            print(f"Allocated GPU memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
            print(f"Reserved GPU memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
            output_ids = self.model.generate(
                vision_x=vision_x.cuda(),
                lang_x=lang_x["input_ids"].cuda(),
                attention_mask=lang_x["attention_mask"].cuda(),
                max_new_tokens=max_new_token,
                num_beams=num_beams,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                do_sample=do_sample,
            )[0]
            print("output_ids retrieved")
        generated_text = self.tokenizer.decode(
            output_ids, skip_special_tokens=True)
        print("text generated:", generated_text)
        result = generated_text.split(response_split)[-1].strip()
        print("result: ", result)
        return result

    def save(self, file_path):
        print("Saving model components...")
        data = {
            "model_state_dict": self.model.state_dict(),
            "tokenizer": self.tokenizer,
            "image_processor": self.image_processor,
        }
        with open(file_path, "wb") as f:
            pickle.dump(data, f)
        print(f"Model components saved to {file_path}")
class PromptGenerator:

    def __init__(
        self,
        prompt_template=TEMPLATE,
        ai_prefix="Response",
        user_prefix="Instruction",
        sep: str = "\n\n### ",
        buffer_size=0,
    ):
        self.all_history = [("user", "Welcome to the chatbot!")]
        self.ai_prefix = ai_prefix
        self.user_prefix = user_prefix
        self.buffer_size = buffer_size
        self.prompt_template = prompt_template
        self.sep = sep

    def add_message(self, role, message):
        self.all_history.append([role, message])

    def get_images(self):
        img_list = list()
        if self.buffer_size > 0:
            all_history = self.all_history[-2 * (self.buffer_size + 1):]
        elif self.buffer_size == 0:
            all_history = self.all_history[-2:]
        else:
            all_history = self.all_history[:]
        for his in all_history:
            if type(his[-1]) == tuple:
                img_list.append(his[-1][-1])
        return img_list
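
    # Render the final prompt string: the instruction template followed by the
    # (optionally truncated) chat history, with an "<image>" tag inserted for
    # the most recent image turn.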
    def get_prompt(self):
        format_dict = dict()
        if "{user_prefix}" in self.prompt_template:
            format_dict["user_prefix"] = self.user_prefix
        if "{ai_prefix}" in self.prompt_template:
            format_dict["ai_prefix"] = self.ai_prefix
        prompt_template = self.prompt_template.format(**format_dict)
        ret = prompt_template
        if self.buffer_size > 0:
            all_history = self.all_history[-2 * (self.buffer_size + 1):]
        elif self.buffer_size == 0:
            all_history = self.all_history[-2:]
        else:
            all_history = self.all_history[:]
        context = []
        have_image = False
        for role, message in all_history[::-1]:
            if message:
                if isinstance(message, tuple) and message[1] is not None and not have_image:
                    message, _ = message
                    context.append(self.sep + "Image:\n<image>" + self.sep +
                                   role + ":\n" + message)
                    have_image = True
                else:
                    context.append(self.sep + role + ":\n" + message)
            else:
                context.append(self.sep + role + ":\n")
        ret += "".join(context[::-1])
        return ret
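

# Convert the stored history into the [user_message, bot_message] pairs that
# gr.Chatbot expects, embedding any uploaded image as an inline base64 <img> tag.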
def to_gradio_chatbot(prompt_generator):
    ret = []
    for i, (role, msg) in enumerate(prompt_generator.all_history):
        if i % 2 == 0:
            if type(msg) is tuple:
                import base64
                from io import BytesIO
                msg, image = msg
                if type(image) is str:
                    image = Image.open(image)
                # downscale so the shorter edge is at most 400 px and the
                # longer edge at most 800 px, preserving the aspect ratio
                max_hw, min_hw = max(image.size), min(image.size)
                aspect_ratio = max_hw / min_hw
                max_len, min_len = 800, 400
                shortest_edge = int(
                    min(max_len / aspect_ratio, min_len, min_hw))
                longest_edge = int(shortest_edge * aspect_ratio)
                W, H = image.size
                if W > H:
                    W, H = longest_edge, shortest_edge
                else:
                    W, H = shortest_edge, longest_edge
                image = image.resize((W, H))
                # image = image.resize((224, 224))
                buffered = BytesIO()
                image.save(buffered, format="JPEG")
                img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                msg = msg + img_str
            ret.append([msg, None])
        else:
            ret[-1][-1] = msg
    return ret
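

# Gradio callback for a user turn: update the prompt settings, record the
# message, run the inferencer, and return the refreshed chat state along with
# the current GPU memory usage.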
def bot(
    text,
    image,
    state,
    prompt,
    ai_prefix,
    user_prefix,
    separator,
    history_buffer,
    max_new_token,
    num_beams,
    temperature,
    top_k,
    top_p,
    do_sample,
):
    state.prompt_template = prompt
    state.ai_prefix = ai_prefix
    state.user_prefix = user_prefix
    state.sep = separator
    state.buffer_size = history_buffer
    if image:
        print(image)
        print(text)
        state.add_message(user_prefix, (text, image))
        print("added message")
    else:
        state.add_message(user_prefix, text)
    state.add_message(ai_prefix, None)
    print("added ai_prefix message")
    inputs = state.get_prompt()
    print("retrieved inputs")
    image_paths = state.get_images()[-1:]
    print("retrieved image_paths")
    inference_results = inferencer(inputs, image_paths, max_new_token,
                                   num_beams, temperature, top_k, top_p,
                                   do_sample)
    print(inference_results)
    state.all_history[-1][-1] = inference_results
    memory_allocated = str(round(torch.cuda.memory_allocated() / 1024**3,
                                 2)) + 'GB'
    return state, to_gradio_chatbot(state), "", None, inputs, memory_allocated


def clear(state):
    state.all_history = []
    return state, to_gradio_chatbot(state), "", None, ""
title_markdown = ("""
# 🤖 Multi-modal GPT
[[Project]](https://github.com/open-mmlab/Multimodal-GPT.git)""")
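

# Build the Gradio Blocks UI: image upload, generation parameters, prompt
# controls, the chat window, and the submit/clear event wiring.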
def build_conversation_demo():
    with gr.Blocks(title="Multi-modal GPT") as demo:
        gr.Markdown(title_markdown)
        state = gr.State(PromptGenerator())

        with gr.Row():
            with gr.Column(scale=3):
                memory_allocated = gr.Textbox(
                    value=init_memory, label="Memory")
                imagebox = gr.Image(type="filepath")
                # TODO: make these parameters configurable
                with gr.Accordion("Parameters", open=True):
                    max_new_token_bar = gr.Slider(
                        0, 1024, 512, label="max_new_token", step=1)
                    num_beams_bar = gr.Slider(
                        1, 10, 3, label="num_beams", step=1)
                    temperature_bar = gr.Slider(
                        0.0, 1.0, 1.0, label="temperature", step=0.01)
                    topk_bar = gr.Slider(0, 100, 20, label="top_k", step=1)
                    topp_bar = gr.Slider(0, 1.0, 1.0, label="top_p", step=0.01)
                    do_sample = gr.Checkbox(True, label="do_sample")
                with gr.Accordion("Prompt", open=False):
                    with gr.Row():
                        ai_prefix = gr.Text("Response", label="AI Prefix")
                        user_prefix = gr.Text(
                            "Instruction", label="User Prefix")
                        separator = gr.Text("\n\n### ", label="Separator")
                    history_buffer = gr.Slider(
                        -1, 10, -1, label="History buffer", step=1)
                    prompt = gr.Text(TEMPLATE, label="Prompt")
                model_inputs = gr.Textbox(label="Actual inputs for Model")
            with gr.Column(scale=6):
                with gr.Row():
                    with gr.Column():
                        chatbot = gr.Chatbot(elem_id="chatbot", height=750)
                        with gr.Row():
                            with gr.Column(scale=8):
                                textbox = gr.Textbox(
                                    show_label=False,
                                    placeholder="Enter text and press ENTER",
                                    container=False)
                            submit_btn = gr.Button(value="Submit")
                            clear_btn = gr.Button(value="🗑️ Clear history")
        cur_dir = os.path.dirname(os.path.abspath(__file__))
        gr.Examples(
            examples=[
                [
                    f"{cur_dir}/docs/images/demo_image.jpg",
                    "What is in this image?"
                ],
            ],
            inputs=[imagebox, textbox],
        )
        textbox.submit(
            bot,
            [
                textbox,
                imagebox,
                state,
                prompt,
                ai_prefix,
                user_prefix,
                separator,
                history_buffer,
                max_new_token_bar,
                num_beams_bar,
                temperature_bar,
                topk_bar,
                topp_bar,
                do_sample,
            ],
            [
                state, chatbot, textbox, imagebox, model_inputs,
                memory_allocated
            ],
        )
        submit_btn.click(
            bot,
            [
                textbox,
                imagebox,
                state,
                prompt,
                ai_prefix,
                user_prefix,
                separator,
                history_buffer,
                max_new_token_bar,
                num_beams_bar,
                temperature_bar,
                topk_bar,
                topp_bar,
                do_sample,
            ],
            [
                state, chatbot, textbox, imagebox, model_inputs,
                memory_allocated
            ],
        )
        clear_btn.click(clear, [state],
                        [state, chatbot, textbox, imagebox, model_inputs])
    return demo


if __name__ == "__main__":
    llama_path = "checkpoints/llama-7b_hf"
    open_flamingo_path = "checkpoints/OpenFlamingo-9B/checkpoint.pt"
    finetune_path = "checkpoints/mmgpt-lora-v0-release.pt"
    inferencer = Inferencer(
        llama_path=llama_path,
        open_flamingo_path=open_flamingo_path,
        finetune_path=finetune_path)
    init_memory = str(round(torch.cuda.memory_allocated() / 1024**3, 2)) + 'GB'
    inferencer.save("inferencer.pkl")
    demo = build_conversation_demo()
    demo.queue()
    IP = "0.0.0.0"
    PORT = 8997
    demo.launch(server_name=IP, server_port=PORT, share=True)