import argparse
import os
import json
from tqdm import tqdm
import random
import numpy as np
from PIL import Image
import webdataset as wds
import torch
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms

import openai
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)  # for exponential backoff

from minigpt4.common.config import Config
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat
# Read the API key from the environment rather than hardcoding a secret in the script.
openai.api_key = os.environ.get("OPENAI_API_KEY")
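
# Wrap the fixing instructions and the raw MiniGPT-4 caption into ChatGPT's
# system/user message format.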
def prepare_chatgpt_message(task_prompt, paragraph):
    messages = [{"role": "system", "content": task_prompt},
                {"role": "user", "content": paragraph}]
    return messages
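
# The tenacity imports above ("for exponential backoff") are otherwise unused,
# so the decorator below is an assumed reconstruction of that intent rather than
# part of the original script: retry the OpenAI call with exponential backoff
# on transient errors.
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))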
def call_chatgpt(chatgpt_messages, max_tokens=200, model="gpt-3.5-turbo"):
    response = openai.ChatCompletion.create(
        model=model, messages=chatgpt_messages, temperature=0.7, max_tokens=max_tokens)
    reply = response['choices'][0]['message']['content']
    total_tokens = response['usage']['total_tokens']
    return reply, total_tokens
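

# main() runs the two-stage alignment pipeline: MiniGPT-4 drafts a detailed
# caption for each image, ChatGPT post-processes the draft, and the cleaned
# captions are written to a JSON file keyed by image index.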
def main(args):
    print('Initializing Chat')
    cfg = Config(args)

    model_config = cfg.model_cfg
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cuda:{}'.format(args.device))

    ckpt_path = '/ibex/project/c2133/vicuna_ckpt_test/Vicuna_pretrain_stage2_cc/20230405233_3GPU40kSTEP_MAIN/checkpoint_3.pth'
    ckpt = torch.load(ckpt_path)
    msg = model.load_state_dict(ckpt['model'], strict=False)

    vis_processor_cfg = cfg.datasets_cfg.cc_combine.vis_processor.train
    vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

    text_processor_cfg = cfg.datasets_cfg.laion.text_processor.train
    text_processor = registry.get_processor_class(text_processor_cfg.name).from_config(text_processor_cfg)

    chat = Chat(model, vis_processor, args.device)
    print('Initialization Finished')
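
    # texts collects accepted captions keyed by image index; negative_list keeps
    # captions that still look incomplete or too short after post-processing.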
    texts = {}
    negative_list = []
    for i in tqdm(range(args.begin_id, args.end_id)):
        image = Image.open(os.path.join(args.save_dir, 'image/{}.jpg'.format(i))).convert('RGB')

        fix_prompt = \
            "Fix the error in the given paragraph. " \
            "Remove any repeating sentences, meaningless characters, non-English sentences, and so on. " \
            "Remove unnecessary repetition. " \
            "Rewrite any incomplete sentences. " \
            "Return directly the results WITHOUT explanation. " \
            "Return directly the input paragraph if it is already correct WITHOUT explanation."
        answers = []
        answer_tokens = 0
        chat.reset()
        chat.upload_img(image)
        chat.ask("Describe this image in detail. Give as many details as possible. Say everything you see.")
        answer, tokens = chat.answer()
        answers.append(answer)
        answer_tokens += tokens
        if answer_tokens < 80:  # answer_tokens is an int, so compare it directly rather than calling len()
            chat.ask("Continue")
            answer, answer_token = chat.answer()
            answers.append(answer)
            answer_tokens += answer_token  # accumulate the follow-up answer's tokens, not the first answer's
        answer = ' '.join(answers)
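
        # Stage 2: have ChatGPT clean up the draft caption with fix_prompt, then
        # keep or discard the result based on simple keyword/length heuristics.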
        chatgpt_message = prepare_chatgpt_message(fix_prompt, answer)
        improved_answer, num_token = call_chatgpt(chatgpt_message)

        if 'already correct' in improved_answer:
            if 'repetition' in improved_answer:
                continue
            improved_answer = answer
        if 'incomplete' in improved_answer or len(improved_answer) < 50:
            negative_list.append(improved_answer)
        else:
            texts[i] = improved_answer
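
    # Write the surviving captions to cap_<begin_id>_<end_id>.json under save_dir.
    # Note that negative_list is collected above but not written out here.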
    with open(os.path.join(args.save_dir, "cap_{}_{}.json".format(args.begin_id, args.end_id)), "w") as outfile:
        # write the dictionary to the file in JSON format
        json.dump(texts, outfile)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Create Alignment")
    parser.add_argument("--cfg-path", default='train_config/minigpt4_stage2_align.yaml')
    parser.add_argument("--save-dir", default="/ibex/project/c2133/blip_dataset/image_alignment")
    parser.add_argument("--begin-id", type=int)
    parser.add_argument("--end-id", type=int)
    parser.add_argument("--device", type=int)
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config: key-value pairs "
             "in xxx=yyy format will be merged into the config file "
             "(deprecated, use --cfg-options instead).",
    )
    args = parser.parse_args()

    print("begin_id: ", args.begin_id)
    print("end_id: ", args.end_id)
    print("device:", args.device)

    main(args)
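
# Example invocation (the script name and argument values below are illustrative
# assumptions; adjust them to your own paths and id range):
#   python create_alignment.py --cfg-path train_config/minigpt4_stage2_align.yaml \
#       --save-dir /path/to/image_alignment --begin-id 0 --end-id 1000 --device 0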