Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						016285f
	
1
								Parent(s):
							
							4be744b
								
Upload 13 files
Browse files
- app.py +169 -57
- create_cache.py +83 -0
- model/__init__.py +6 -0
- model/model/caption_model.py +89 -0
- model/model/question_asking_model.py +83 -0
- model/model/question_generator.py +194 -0
- model/model/question_model_base.py +85 -0
- model/model/response_model.py +190 -0
- model/run_question_asking_model.py +186 -0
- model/utils.py +54 -0
- open_db.py +1 -6
- pilot-study.csv +161 -0
- response_db.py +1 -2
    	
        app.py
    CHANGED
    
    | @@ -1,80 +1,192 @@ | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
             
            from response_db import StResponseDb
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 3 | 
             
            db = StResponseDb()
         | 
| 4 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 5 |  | 
| 6 | 
            -
            def get_next_question(history):
         | 
| 7 | 
            -
                if len(history)==2:
         | 
| 8 | 
            -
                    question = "What is the man doing?"
         | 
| 9 | 
            -
                elif len(history)==4:
         | 
| 10 | 
            -
                    question = "How many apples are there?"
         | 
| 11 | 
            -
                else:
         | 
| 12 | 
            -
                    question = "What color is the cat?"
         | 
| 13 | 
            -
                return question
         | 
| 14 | 
            -
             | 
| 15 | 
            -
            def ask_a_question(input, taskid, history=[]):
         | 
| 16 | 
            -
                history.append(input)
         | 
| 17 | 
            -
                db.add(int(a.value), taskid, len(history)//2-1, history[-2], history[-1])
         | 
| 18 | 
            -
                history.append(get_next_question(history))
         | 
| 19 | 
            -
                
         | 
| 20 | 
             
                # write some HTML
         | 
| 21 | 
             
                html = "<div class='chatbot'>"
         | 
| 22 | 
            -
                for m, msg in enumerate(history):
         | 
|  | |
| 23 | 
             
                    cls = "bot" if m%2 == 0 else "user"
         | 
| 24 | 
             
                    html += "<div class='msg {}'> {}</div>".format(cls, msg)
         | 
| 25 | 
             
                html += "</div>"
         | 
| 26 | 
            -
                return html, history
         | 
| 27 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 28 |  | 
| 29 | 
            -
            css = """
         | 
| 30 | 
            -
            .chatbox {display:flex;flex-direction:column}
         | 
| 31 | 
            -
            .msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
         | 
| 32 | 
            -
            .msg.user {background-color:cornflowerblue;color:white}
         | 
| 33 | 
            -
            .msg.bot {background-color:lightgray;align-self:self-end}
         | 
| 34 | 
            -
            .footer {display:none !important}
         | 
| 35 | 
            -
            """
         | 
| 36 |  | 
| 37 | 
             
            def set_images(taskid):
         | 
| 38 | 
            -
                 | 
| 39 | 
            -
                 | 
| 40 | 
            -
                 | 
| 41 | 
            -
             | 
| 42 | 
            -
                 | 
| 43 | 
            -
             | 
| 44 | 
            -
                 | 
| 45 | 
            -
                 | 
| 46 | 
            -
                 | 
| 47 | 
            -
                 | 
| 48 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 49 | 
             
                first_question_html = f"<div class='chatbot'><div class='msg bot'>{first_question}</div></div>"
         | 
| 50 | 
            -
                 | 
| 51 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 52 |  | 
| 53 | 
            -
            with gr.Blocks(css=css) as demo:
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 54 |  | 
| 55 | 
             
                with gr.Column() as img_block:
         | 
| 56 | 
             
                    with gr.Row():
         | 
| 57 | 
            -
                        img1 = gr.Image()
         | 
| 58 | 
            -
                        img2 = gr.Image()
         | 
| 59 | 
            -
                        img3 = gr.Image()
         | 
| 60 | 
            -
                        img4 = gr.Image()
         | 
| 61 | 
            -
                        img5 = gr.Image()
         | 
| 62 | 
             
                    with gr.Row():
         | 
| 63 | 
            -
                        img6 = gr.Image()
         | 
| 64 | 
            -
                        img7 = gr.Image()
         | 
| 65 | 
            -
                        img8 = gr.Image()
         | 
| 66 | 
            -
                        img9 = gr.Image()
         | 
| 67 | 
            -
                        img10 = gr.Image()
         | 
| 68 | 
             
                conversation = gr.HTML()
         | 
| 69 | 
            -
                 | 
| 70 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 71 | 
             
                with gr.Column():
         | 
| 72 | 
             
                    with gr.Row():
         | 
| 73 | 
            -
                         | 
| 74 | 
            -
                         | 
| 75 | 
            -
             | 
| 76 | 
            -
                 | 
| 77 | 
            -
                submit =  | 
| 78 | 
            -
                 | 
| 79 | 
            -
             | 
| 80 | 
            -
             | 
|  | |
|  | 
|  | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
             
            from response_db import StResponseDb
         | 
| 3 | 
            +
            from create_cache import Game_Cache
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            from PIL import Image
         | 
| 6 | 
            +
            import pandas as pd
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            import pickle
         | 
| 9 | 
            +
            import uuid
         | 
| 10 | 
            +
             | 
| 11 | 
             
            db = StResponseDb()
         | 
| 12 | 
            +
            css = """
         | 
| 13 | 
            +
            .chatbot {display:flex;flex-direction:column}
         | 
| 14 | 
            +
            .msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
         | 
| 15 | 
            +
            .msg.user {background-color:cornflowerblue;color:white;align-self:self-end}
         | 
| 16 | 
            +
            .msg.bot {background-color:lightgray}
         | 
| 17 | 
            +
            .na_button {background-color:red;color:red}
         | 
| 18 | 
            +
            """
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            from model.run_question_asking_model import return_modules, return_modules_yn
         | 
| 21 | 
            +
            question_model, response_model_simul, _, caption_model = return_modules()
         | 
| 22 | 
            +
            question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            class Game_Session:
                """State of one guessing game.

                A session holds references to both model triples loaded at module level
                (open-ended and yes/no), selects the triple matching ``yn``, and carries
                the per-game data that the Gradio handlers fill in and update: image
                paths and arrays, captions, candidate questions, the belief ``p_y_x``
                and the question/answer ``history``.
                """

                def __init__(self, taskid, yn, hard_setting):
                    """Create an empty session and select the active models.

                    Args:
                        taskid: Identifier of the task being played. Accepted for
                            call-site compatibility; the session does not store it.
                        yn: Truthy to use the yes/no models, falsy for the open-ended ones.
                        hard_setting: Stored as-is; not otherwise used by this class.
                    """
                    self.yn = yn
                    self.hard_setting = hard_setting

                    # The module-level models are only read here, so no ``global``
                    # declaration is required.
                    self.question_model = question_model
                    self.response_model_simul = response_model_simul
                    self.caption_model = caption_model
                    self.question_model_yn = question_model_yn
                    self.response_model_simul_yn = response_model_simul_yn
                    self.caption_model_yn = caption_model_yn

                    # Per-game data, assigned by the handlers once a task is loaded.
                    self.image_files = None
                    # Must be spelled ``images_np``: that is the attribute
                    # get_next_question() and the handlers read.
                    self.images_np = None
                    self.p_y_x = None
                    self.p_r_qy = None
                    self.p_y_xqr = None
                    self.captions = None
                    self.questions = None
                    self.target_questions = None

                    self.history = []
                    # Random identifier used to tag this game's rows in the response DB.
                    self.game_id = str(uuid.uuid4())
                    self.set_curr_models()

                def set_curr_models(self):
                    """Point the ``curr_*`` attributes at the triple selected by ``self.yn``."""
                    if self.yn:
                        self.curr_question_model = self.question_model_yn
                        self.curr_caption_model = self.caption_model_yn
                        self.curr_response_model_simul = self.response_model_simul_yn
                    else:
                        self.curr_question_model = self.question_model
                        self.curr_caption_model = self.caption_model
                        self.curr_response_model_simul = self.response_model_simul

                def get_next_question(self):
                    """Return the question model's best question for the current game state."""
                    return self.curr_question_model.select_best_question(
                        self.p_y_x,
                        self.questions,
                        self.images_np,
                        self.captions,
                        self.curr_response_model_simul,
                    )
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            def ask_a_question(input, taskid, gs):
                """Record the user's answer, update the belief over images and redraw the chat.

                Args:
                    input: The answer to the latest question in ``gs.history`` (free text,
                        or one of the fixed values sent by the Yes / No / N/A buttons).
                    taskid: Value of the task-id field; passed through to ``db.add``.
                    gs: The current ``Game_Session``; it is mutated in place.

                Returns:
                    A 9-tuple in the order of the handlers' ``outputs`` lists: chat HTML,
                    the session, then visibility updates for answer, na_box, submit,
                    taskid, start_button, yes_box and no_box.
                """
                # Local import: the file does not import ``html`` at module level.
                from html import escape

                gs.history.append(input)
                question = gs.history[-2]

                # Multiply the current belief by the response likelihood and renormalise;
                # an all-zero product is kept as zeros rather than divided by zero.
                gs.p_r_qy = gs.curr_response_model_simul.get_p_r_qy(input, question, gs.images_np, gs.captions)
                gs.p_y_xqr = gs.p_y_x * gs.p_r_qy
                total = torch.sum(gs.p_y_xqr)
                gs.p_y_xqr = gs.p_y_xqr / total if total != 0 else torch.zeros_like(gs.p_y_xqr)
                gs.p_y_x = gs.p_y_xqr

                gs.questions.remove(question)
                db.add(gs.game_id, taskid, len(gs.history)//2-1, question, gs.history[-1])
                gs.history.append(gs.get_next_question())

                top_prob = torch.max(gs.p_y_x).item()
                top_pred = torch.argmax(gs.p_y_x).item()
                finished = top_prob > 0.8
                if finished:
                    # The game is over, so drop the question that was just queued.
                    gs.history = gs.history[:-1]
                    # NOTE(review): the DB row stores the 0-based index while the UI below
                    # shows top_pred+1 -- confirm which one downstream analysis expects.
                    db.add(gs.game_id, taskid, len(gs.history)//2, f"Guess: Image {top_pred}", "")

                # write some HTML
                parts = ["<div class='chatbot'>"]
                for m, msg in enumerate(gs.history):
                    if msg == "nothing":
                        msg = "n/a"
                    cls = "bot" if m % 2 == 0 else "user"
                    # Answers are user-supplied text: escape them so they cannot inject markup.
                    parts.append("<div class='msg {}'> {}</div>".format(cls, escape(str(msg))))
                parts.append("</div>")
                chat_html = "".join(parts)

                ### Game finished:
                if finished:
                    chat_html += f"<p>The model identified <b>Image {top_pred+1}</b> as the image. Please select a new task ID to continue.</p>"

                show_text = (not finished) and (not gs.yn)
                show_yn = (not finished) and bool(gs.yn)
                return (
                    chat_html,
                    gs,
                    gr.Textbox.update(visible=show_text),  # answer
                    gr.Button.update(visible=show_text),   # na_box
                    gr.Button.update(visible=show_text),   # submit
                    gr.Number.update(visible=finished),    # taskid
                    gr.Button.update(visible=finished),    # start_button
                    # yes_box / no_box are Buttons, so both get Button updates.
                    gr.Button.update(visible=show_yn),     # yes_box
                    gr.Button.update(visible=show_yn),     # no_box
                )
         | 
| 89 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 90 |  | 
| 91 | 
             
            def set_images(taskid):
                """Load the cached game for a task and return the initial UI state.

                ``taskid`` is used as a row index into the ``mscoco-id`` column of
                ``pilot-study.csv``; the id found there names the pickle under ``cache/``.

                Returns:
                    A 20-tuple in the order of ``start_button``'s ``outputs`` list: ten
                    image paths, the new ``Game_Session``, the first-question HTML, the
                    task banner, then visibility updates for answer, na_box, submit,
                    taskid, start_button, yes_box and no_box.
                """
                pilot_ids = pd.read_csv("pilot-study.csv")['mscoco-id'].tolist()
                taskid_original = taskid
                taskid = pilot_ids[int(taskid)]

                with open(f'cache/{int(taskid)}.p', 'rb') as cache_file:
                    game_cache = pickle.load(cache_file)
                gs = Game_Session(int(taskid), game_cache.yn, game_cache.hard_setting)

                # Paths handed to the ten gr.Image components.
                display_paths = tuple(
                    f"./mscoco-images/val2014/{game_cache.image_files[k]}" for k in range(10)
                )

                # Dropping the first 15 characters removes the "./mscoco-images" prefix.
                gs.image_files = [path[15:] for path in display_paths]
                gs.images_np = []
                for rel_path in gs.image_files:
                    pixels = np.asarray(Image.open(f"./mscoco-images/{rel_path}"))
                    # 2-D arrays are stacked three times along a new last axis.
                    if pixels.ndim == 2:
                        pixels = np.dstack([pixels]*3)
                    gs.images_np.append(pixels)

                q_model = gs.curr_question_model
                gs.p_y_x = (torch.ones(10)/10).to(q_model.device)
                gs.captions = gs.curr_caption_model.get_captions(gs.image_files)
                gs.questions, gs.target_questions = q_model.get_questions(gs.image_files, gs.captions, 0)
                q_model.reset_question_bank()
                q_model.question_bank = game_cache.question_dict
                first_question = q_model.select_best_question(
                    gs.p_y_x, gs.questions, gs.images_np, gs.captions, gs.curr_response_model_simul
                )
                first_question_html = f"<div class='chatbot'><div class='msg bot'>{first_question}</div></div>"
                gs.history.append(first_question)
                html = f"<p>Current Task ID: <b>{int(taskid_original)}</b></p>"

                open_ended = not gs.yn
                yes_no = not open_ended
                if open_ended:
                    answer_update = gr.Textbox.update(visible=True, value='')
                else:
                    answer_update = gr.Textbox.update(visible=False)
                return (
                    *display_paths,
                    gs,
                    first_question_html,
                    gr.HTML.update(value=html, visible=True),
                    answer_update,
                    gr.Button.update(visible=open_ended),
                    gr.Button.update(visible=open_ended),
                    gr.Number.update(visible=False),
                    gr.Button.update(visible=False),
                    gr.Button.update(visible=yes_no),
                    gr.Button.update(visible=yes_no),
                )
         | 
| 128 | 
            +
             | 
| 129 |  | 
| 130 | 
            +
            with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
         | 
| 131 | 
            +
                gr.HTML("<h1>Image Q&A Guessing Game</h1>\
         | 
| 132 | 
            +
                <p style='font-size:120%;'>\
         | 
| 133 | 
            +
                Imagine you are playing 20-questions with an AI model.<br>\
         | 
| 134 | 
            +
                The AI model plays the role of the question asker. You play the role of the responder. <br>\
         | 
| 135 | 
            +
                There are 10 images. <b>Your image is Image 1</b>. The other images are distraction images.\
         | 
| 136 | 
            +
                The model can see all 10 images and all the questions and answers for the current set of images. It will ask a question based on the available information.<br>\
         | 
| 137 | 
            +
                <span style='color: #0000ff'>The goal of the model is to accurately guess the correct image (i.e. <b><span style='color: #0000ff'>Image 1</span></b>) in as few turns as possible.<br>\
         | 
| 138 | 
            +
                Your goal is to help the model guess the image by answering as clearly and accurately as possible.</span><br><br>\
         | 
| 139 | 
            +
                <b>Guidelines:</b><br>\
         | 
| 140 | 
            +
                <ol style='font-size:120%;'>\
         | 
| 141 | 
            +
                    <li>It is best to keep your answers short (a single word or a short phrase). No need to answer in full sentences.</li>\
         | 
| 142 | 
            +
                    <li>If you feel that the question cannot be answered or does not apply to Image 1, please select N/A.</li>\
         | 
| 143 | 
            +
                </ol> \
         | 
| 144 | 
            +
                <br>\
         | 
| 145 | 
            +
                (Note: We are testing multiple game settings. In some instances, the game will be open-ended, while in other instances, the answer choices will be limited to yes/no.)<br></p>\
         | 
| 146 | 
            +
                <br>\
         | 
| 147 | 
            +
                <h2>Please enter a TaskID to start</h2>")
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                with gr.Column():
         | 
| 150 | 
            +
                    with gr.Row():
         | 
| 151 | 
            +
                        taskid = gr.Number(label="Task ID (Enter a number from 0 to 160)", value=0)
         | 
| 152 | 
            +
                        start_button = gr.Button("Enter")
         | 
| 153 | 
            +
                    with gr.Row():
         | 
| 154 | 
            +
                        task_text = gr.HTML()
         | 
| 155 |  | 
| 156 | 
             
                with gr.Column() as img_block:
         | 
| 157 | 
             
                    with gr.Row():
         | 
| 158 | 
            +
                        img1 = gr.Image(label="Image 1", show_label=True)
         | 
| 159 | 
            +
                        img2 = gr.Image(label="Image 2", show_label=True)
         | 
| 160 | 
            +
                        img3 = gr.Image(label="Image 3", show_label=True)
         | 
| 161 | 
            +
                        img4 = gr.Image(label="Image 4", show_label=True)
         | 
| 162 | 
            +
                        img5 = gr.Image(label="Image 5", show_label=True)
         | 
| 163 | 
             
                    with gr.Row():
         | 
| 164 | 
            +
                        img6 = gr.Image(label="Image 6", show_label=True)
         | 
| 165 | 
            +
                        img7 = gr.Image(label="Image 7", show_label=True)
         | 
| 166 | 
            +
                        img8 = gr.Image(label="Image 8", show_label=True)
         | 
| 167 | 
            +
                        img9 = gr.Image(label="Image 9", show_label=True)
         | 
| 168 | 
            +
                        img10 = gr.Image(label="Image 10", show_label=True)
         | 
| 169 | 
             
                conversation = gr.HTML()
         | 
| 170 | 
            +
                game_session_state = gr.State()
         | 
| 171 |  | 
| 172 | 
            +
                answer = gr.Textbox(placeholder="Insert answer here.", label="Answer the given question.", visible=False)
         | 
| 173 | 
            +
                null_answer = gr.Textbox("nothing", visible=False)
         | 
| 174 | 
            +
                yes_answer = gr.Textbox("yes", visible=False)
         | 
| 175 | 
            +
                no_answer = gr.Textbox("no", visible=False)
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                with gr.Column():
         | 
| 178 | 
            +
                    with gr.Row():
         | 
| 179 | 
            +
                        yes_box = gr.Button("Yes", visible=False)
         | 
| 180 | 
            +
                        no_box = gr.Button("No", visible=False)
         | 
| 181 | 
             
                with gr.Column():
         | 
| 182 | 
             
                    with gr.Row():
         | 
| 183 | 
            +
                        na_box = gr.Button("N/A", visible=False, elem_classes="na_button")
         | 
| 184 | 
            +
                        submit = gr.Button("Submit", visible=False)
         | 
| 185 | 
            +
                ### Button click events
         | 
| 186 | 
            +
                start_button.click(fn=set_images, inputs=taskid, outputs=[img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, game_session_state, conversation, task_text, answer, na_box, submit, taskid, start_button, yes_box, no_box])
         | 
| 187 | 
            +
                submit.click(fn=ask_a_question, inputs=[answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
         | 
| 188 | 
            +
                na_box.click(fn=ask_a_question, inputs=[null_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
         | 
| 189 | 
            +
                yes_box.click(fn=ask_a_question, inputs=[yes_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
         | 
| 190 | 
            +
                no_box.click(fn=ask_a_question, inputs=[no_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
         | 
| 191 | 
            +
             | 
| 192 | 
            +
            demo.launch()
         | 
    	
        create_cache.py
    ADDED
    
    | @@ -0,0 +1,83 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            from PIL import Image
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import pickle
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            class Game_Cache:
                """Record bundling a question dict, image file names and two setting flags.

                Attribute names are kept exactly as given to ``__init__`` because pickled
                instances are read back by attribute name.
                """

                def __init__(self, question_dict, image_files, yn, hard_setting):
                    """Store the four arguments unchanged under the same names."""
                    self.question_dict, self.image_files = question_dict, image_files
                    self.yn, self.hard_setting = yn, hard_setting
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            # One stripped entry per line of each listing file.
            with open('./mscoco/mscoco_images.txt', 'r') as f:
                image_list = [line.strip() for line in f]
            with open('./mscoco/mscoco_images_attribute_n=1.txt', 'r') as f:
                image_list_hard = [line.strip() for line in f]

            # Task ids 40-79 and 120-159.
            yn_indices = [*range(40, 80), *range(120, 160)]
            # Task ids 80-159.
            hard_setting_indices = [*range(80, 160)]
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            from model.run_question_asking_model import return_modules, return_modules_yn
         | 
| 28 | 
            +
            global image_files, images_np, p_y_x, p_r_qy, p_y_xqr, captions, questions, target_questions 
         | 
| 29 | 
            +
            global question_model, response_model_simul, caption_model
         | 
| 30 | 
            +
            question_model, response_model_simul, _, caption_model = return_modules()
         | 
| 31 | 
            +
            global question_model_yn, response_model_simul_yn, caption_model_yn
         | 
| 32 | 
            +
            question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            def create_cache(taskid):
         | 
| 35 | 
            +
                original_taskid = taskid
         | 
| 36 | 
            +
                global question_model, response_model_simul, caption_model
         | 
| 37 | 
            +
                global question_model_yn, response_model_simul_yn, caption_model_yn
         | 
| 38 | 
            +
                if taskid in yn_indices: 
         | 
| 39 | 
            +
                    yn = True
         | 
| 40 | 
            +
                    curr_question_model, curr_response_model_simul, curr_caption_model = question_model, response_model_simul, caption_model
         | 
| 41 | 
            +
                    taskid-=40
         | 
| 42 | 
            +
                else: 
         | 
| 43 | 
            +
                    yn = False
         | 
| 44 | 
            +
                    curr_question_model, curr_response_model_simul, curr_caption_model = question_model_yn, response_model_simul_yn, caption_model_yn
         | 
| 45 | 
            +
                if taskid in hard_setting_indices: 
         | 
| 46 | 
            +
                    hard_setting = True
         | 
| 47 | 
            +
                    image_list_curr = image_list_hard
         | 
| 48 | 
            +
                    taskid -= 80
         | 
| 49 | 
            +
                else: 
         | 
| 50 | 
            +
                    hard_setting = False    
         | 
| 51 | 
            +
                    image_list_curr = image_list
         | 
| 52 | 
            +
                
         | 
| 53 | 
            +
                id1 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+0]}"
         | 
| 54 | 
            +
                id2 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+1]}"
         | 
| 55 | 
            +
                id3 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+2]}"
         | 
| 56 | 
            +
                id4 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+3]}"
         | 
| 57 | 
            +
                id5 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+4]}"
         | 
| 58 | 
            +
                id6 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+5]}"
         | 
| 59 | 
            +
                id7 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+6]}"
         | 
| 60 | 
            +
                id8 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+7]}"
         | 
| 61 | 
            +
                id9 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+8]}"
         | 
| 62 | 
            +
                id10 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+9]}"
         | 
| 63 | 
            +
                image_names = []
         | 
| 64 | 
            +
                for i in range(10):
         | 
| 65 | 
            +
                    image_names.append(image_list_curr[int(taskid)*10+i])
         | 
| 66 | 
            +
                image_files = [id1, id2, id3, id4, id5, id6, id7, id8, id9, id10]
         | 
| 67 | 
            +
                image_files = [x[15:] for x in image_files]
         | 
| 68 | 
            +
                images_np = [np.asarray(Image.open(f"./mscoco-images/{i}")) for i in image_files]
         | 
| 69 | 
            +
                images_np = [np.dstack([i]*3) if len(i.shape)==2 else i for i in images_np]
         | 
| 70 | 
            +
                p_y_x = (torch.ones(10)/10).to(curr_question_model.device)
         | 
| 71 | 
            +
                captions = curr_caption_model.get_captions(image_files)
         | 
| 72 | 
            +
                questions, target_questions = curr_question_model.get_questions(image_files, captions, 0)
         | 
| 73 | 
            +
                curr_question_model.reset_question_bank()
         | 
| 74 | 
            +
                first_question = curr_question_model.select_best_question(p_y_x, questions, images_np, captions, curr_response_model_simul)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                gc = Game_Cache(curr_question_model.question_bank, image_names, yn, hard_setting)
         | 
| 77 | 
            +
                with open(f'./cache{int(taskid)}.p', 'wb') as fp:
         | 
| 78 | 
            +
                    pickle.dump(gc, fp, protocol=pickle.HIGHEST_PROTOCOL)
         | 
| 79 | 
            +
             | 
| 80 | 
            +
            if __name__=="__main__":
                # Run create_cache once for every task id from 0 through 159.
                for task_id in range(160):
                    create_cache(task_id)
         | 
| 83 | 
            +
             | 
    	
        model/__init__.py
    ADDED
    
    | @@ -0,0 +1,6 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import model.run_question_asking_model
         | 
| 2 | 
            +
            import model.model.caption_model
         | 
| 3 | 
            +
            import model.model.question_asking_model
         | 
| 4 | 
            +
            import model.model.question_generator
         | 
| 5 | 
            +
            import model.model.question_model_base
         | 
| 6 | 
            +
            import model.model.response_model
         | 
    	
        model/model/caption_model.py
    ADDED
    
    | @@ -0,0 +1,89 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import pandas as pd
         | 
| 2 | 
            +
            from pycocotools.coco import COCO
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            def get_caption_model(args, question_model):
                """Build the caption model named by ``args.caption_strategy``.

                "simple" wraps ``question_model`` in a CaptionModelSimple, "granular"
                builds a CaptionModelGranular and "gtruth" builds a CaptionModelCOCO.

                Raises:
                    ValueError: if the strategy is none of the three names above.
                """
                strategy = args.caption_strategy
                if strategy == "simple":
                    return CaptionModelSimple(question_model)
                if strategy == "granular":
                    return CaptionModelGranular()
                if strategy == "gtruth":
                    return CaptionModelCOCO()
                raise ValueError(f"{strategy} is not a valid caption strategy.")
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            class CaptionModel():
                """Base class intended for the other caption models to inherit from."""

                def __init__(self):
                    pass

                def get_captions(self, images, **kwargs):
                    """Return one caption per entry of ``images``; must be overridden.

                    Raises:
                        NotImplementedError: always, in this base class.
                    """
                    # ``NotImplemented`` is the binary-operator sentinel, not an exception;
                    # raising it fails with a confusing TypeError instead.
                    raise NotImplementedError
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            class CaptionModelCOCO():
                """Caption model backed by the ground-truth COCO val2014 annotations."""

                def __init__(self):
                    """Load the val2014 caption and instance annotation files from ./mscoco."""
                    dataDir='./mscoco'
                    val_file = '{}/annotations/captions_val2014.json'.format(dataDir)
                    self.coco_caps_val = COCO(val_file)
                    val_file = '{}/annotations/instances_val2014.json'.format(dataDir)
                    self.coco_anns_val = COCO(val_file)

                @staticmethod
                def _image_id(image):
                    """Return the integer id encoded after the last '_' of an image file name."""
                    # int() already ignores leading zeros; stripping them first turned an
                    # all-zero id into '' and made int() raise.
                    return int(image.split('_')[-1].split('.')[0])

                def get_captions(self, images, return_all=False):
                    """Return the COCO captions for each image file name.

                    Args:
                        images: iterable of file names ending in ``_<digits>.<ext>``.
                        return_all: when True each entry is the list of all captions of
                            the image; otherwise only the first caption is returned.

                    Returns:
                        A list with one entry per image.
                    """
                    captions = []
                    for image in images:
                        ann_ids = self.coco_caps_val.getAnnIds(imgIds=self._image_id(image))
                        texts = [ann['caption'] for ann in self.coco_caps_val.loadAnns(ann_ids)]
                        captions.append(texts if return_all else texts[0])
                    return captions

                def get_subjects(self, images):
                    """Return, per image, the category and supercategory names annotated on it.

                    Only the val2014 instance annotations are consulted: no train
                    annotations are loaded by ``__init__`` (the previous code read a
                    non-existent ``coco_caps_train`` attribute and always failed).

                    Returns:
                        A list with one list of unique names per image (order unspecified).
                    """
                    subjects = []
                    for image in images:
                        ann_ids = self.coco_anns_val.getAnnIds(imgIds=self._image_id(image))
                        anns = self.coco_anns_val.loadAnns(ann_ids)
                        category_ids = list({ann['category_id'] for ann in anns})
                        cats = self.coco_anns_val.loadCats(ids=category_ids)
                        names = {cat['supercategory'] for cat in cats} | {cat['name'] for cat in cats}
                        subjects.append(list(names))
                    return subjects
         | 
| 62 | 
            +
             | 
| 63 | 
            +
             | 
| 64 | 
            +
            class CaptionModelSimple():
                """Caption model that delegates to a question-answering model."""

                def __init__(self, qa_model):
                    # qa_model must expose generate_description(image, images).
                    self.qa_model = qa_model

                def get_captions(self, images):
                    """Return ``qa_model.generate_description(image, images)`` for each image."""
                    describe = self.qa_model.generate_description
                    return [describe(image, images) for image in images]
         | 
| 74 | 
            +
                    
         | 
| 75 | 
            +
            class CaptionModelGranular():
                """Caption model that looks captions up in pre-computed JSON-lines files."""

                def __init__(self):
                    """Read the train and val caption files into ``self.caption_dict``.

                    Keys are ``str(image_id)``, values are the caption text. The val file
                    is read second, so it wins when an id appears in both files.
                    """
                    self.caption_dict = {}
                    caption_files = ("captions/coco_train_captions.jsonl",
                                     "captions/coco_val_captions.jsonl")
                    for path in caption_files:
                        df = pd.read_json(path, lines=True)
                        # Iterate the two columns directly instead of doing a label
                        # lookup per row through range(len(df)).
                        for image_id, caption in zip(df.image_id, df.caption):
                            self.caption_dict[str(image_id)] = caption

                def get_captions(self, images):
                    """Return the stored caption for each image file name.

                    The lookup key is the text after the last '_' of the part of the name
                    before its first '.', with leading zeros removed.

                    Raises:
                        KeyError: if an image id has no stored caption.
                    """
                    return [
                        self.caption_dict[image.split('.')[0].split('_')[-1].lstrip('0')]
                        for image in images
                    ]
         | 
    	
        model/model/question_asking_model.py
    ADDED
    
    | @@ -0,0 +1,83 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import random
         | 
| 2 | 
            +
            import pandas as pd
         | 
| 3 | 
            +
            from model.model.question_model_base import QuestionAskingModel
         | 
| 4 | 
            +
            import openai
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            def get_question_model(args):
                """Build the question-asking model named by ``args.question_strategy``.

                "rule" builds a QuestionAskingModelSimple and "gpt3" builds a
                QuestionAskingModelGPT3; both receive ``args``.

                Raises:
                    ValueError: if the strategy is neither "rule" nor "gpt3".
                """
                strategy = args.question_strategy
                if strategy == "rule":
                    return QuestionAskingModelSimple(args)
                if strategy == "gpt3":
                    return QuestionAskingModelGPT3(args)
                raise ValueError(f"{strategy} is not a valid question strategy.")
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            class QuestionAskingModelSimple(QuestionAskingModel):
                """Rule-based question model built on ``generate_is_there_question``."""

                def __init__(self, args):
                    super(QuestionAskingModelSimple, self).__init__(args)

                def get_questions(self, images, captions, target_idx=0):
                    """Generate candidate questions from the captions.

                    Args:
                        images: image identifiers, paired positionally with ``captions``.
                        captions: one caption per image.
                        target_idx: position of the target image.

                    Returns:
                        ``(questions, target_questions)``: the de-duplicated questions of
                        all images in first-seen order, and the questions generated for
                        the image at ``target_idx`` (empty if no image has that index).
                    """
                    # NOTE(review): self.question_generator is not set in this class;
                    # presumably the base class provides it.
                    questions = []
                    # Pre-bind so an out-of-range target_idx or empty input no longer
                    # ends in an UnboundLocalError at the return statement.
                    target_questions = []
                    for i, (image, caption) in enumerate(zip(images, captions)):
                        image_questions = self.question_generator.generate_is_there_question(caption)
                        if i == target_idx:
                            target_questions = image_questions
                        questions += image_questions
                    # dict.fromkeys de-duplicates like set() but keeps a stable order, so
                    # repeated runs see the questions in the same sequence.
                    questions = list(dict.fromkeys(questions))
                    return questions, target_questions
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            class QuestionAskingModelGPT3(QuestionAskingModel):
                """Question model that asks text-davinci-003 for questions about each caption.

                Generated questions are cached in ``data/<args.gpt3_save_name>.csv`` with
                one (caption, question) row per question, so a caption is only sent to
                the API the first time it is seen.
                """

                def __init__(self, args):
                    super(QuestionAskingModelGPT3, self).__init__(args)
                    self.gpt3_path = f"data/{args.gpt3_save_name}.csv"
                    try:
                        self.gpt3_captions = pd.read_csv(self.gpt3_path)      # cache locally to save GPT3 compute
                    except (OSError, ValueError):
                        # Missing/unreadable file (OSError) or empty/malformed CSV (pandas
                        # parse errors derive from ValueError): start with an empty cache.
                        # A bare except would also swallow KeyboardInterrupt/SystemExit.
                        self.gpt3_captions = pd.DataFrame({"caption":[], "question":[]})

                def generate_gpt3_questions(self, caption):
                    """Return the questions the completion model produces for ``caption``.

                    A four-example few-shot prompt is sent with temperature 0; the reply
                    is split on '?' and every non-trivial piece is returned with the '?'
                    re-attached.
                    """
                    examples = (
                        ("A living room with a couch, coffee table and two large windows with white curtains.",
                         "What color is the couch? How many windows are there? How many tables are there? What color is the table? What color are the curtains? What is next to the table? What is next to the couch?"),
                        ("A large, shiny, stainless, side by side refrigerator in a kitchen.",
                         "Where is the refrigerator? What color is the refrigerator?"),
                        ("A stop sign with a skeleton painted on it, next to a car.",
                         "What color is the sign? What color is the car? What is next to the sign? What is next to the car? What is on the sign? Where is the car?"),
                        ("A man brushing his teeth with a toothbrush",
                         "What is the man doing? Where is the man? What color is the toothbrush? How many people are there?"),
                    )
                    prompt = ""
                    for example_caption, example_questions in examples:
                        prompt += f"Generate questions for the following caption:\nCaption: {example_caption}\nQuestions: {example_questions}\n"
                    prompt += f"Generate questions for the following caption:\nCaption: {caption}\nQuestions:"
                    response = openai.Completion.create(
                        model="text-davinci-003",
                        prompt=prompt,
                        temperature=0,
                        max_tokens=1024,
                        top_p=1,
                        frequency_penalty=0,
                        presence_penalty=0
                    )
                    text = response["choices"][0]["text"]
                    # Pieces of length <= 1 after stripping are dropped (e.g. the empty
                    # remainder after the final '?').
                    return [piece.strip() + "?" for piece in text.split('?') if len(piece.strip()) > 1]

                def get_questions(self, images, captions, target_idx=0):
                    """Return cached or freshly generated questions for the captions.

                    Args:
                        images: image identifiers, paired positionally with ``captions``.
                        captions: one caption per image.
                        target_idx: position of the target image.

                    Returns:
                        ``(questions, target_questions)``: the de-duplicated questions of
                        all images in first-seen order, and the questions for the image
                        at ``target_idx`` (empty if no image has that index).
                    """
                    questions = []
                    # Pre-bind so an out-of-range target_idx or empty input no longer
                    # ends in an UnboundLocalError at the return statement.
                    target_questions = []
                    for i, (image, caption) in enumerate(zip(images, captions)):
                        if caption in self.gpt3_captions.caption.tolist():
                            image_questions = self.gpt3_captions[self.gpt3_captions.caption==caption].question.tolist()
                        else:
                            image_questions = self.generate_gpt3_questions(caption)
                            image_df = pd.DataFrame({'caption':[caption for _ in image_questions], 'question':image_questions})
                            self.gpt3_captions = pd.concat([self.gpt3_captions, image_df])
                            # Persist after every new caption so API results survive a crash.
                            self.gpt3_captions.to_csv(self.gpt3_path, index=False)
                        if i == target_idx:
                            target_questions = image_questions
                        questions += image_questions
                    # dict.fromkeys de-duplicates like set() but keeps a stable order.
                    questions = list(dict.fromkeys(questions))
                    return questions, target_questions
         | 
    	
        model/model/question_generator.py
    ADDED
    
    | @@ -0,0 +1,194 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import spacy
         | 
| 2 | 
            +
            import nltk.tree
         | 
| 3 | 
            +
            import collections
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import benepar
         | 
| 6 | 
            +
            ### Run this the first time if benepar_en3 is not yet downloaded
         | 
| 7 | 
            +
            # benepar.download('benepar_en3')
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            def find_word(tree, kind, parents=None):
                """Depth-first search for the first subtree labelled ``kind``.

                Args:
                    tree: an ``nltk.tree.Tree`` (anything else yields no match).
                    kind: the node label to look for.
                    parents: list that accumulates the ancestors of the current node;
                        callers normally leave it unset.

                Returns:
                    ``(first_child_of_match, ancestors)`` where ``ancestors`` runs from
                    the search root down to the parent of the match, or ``(None, None)``
                    when nothing matches.
                """
                parents = [] if parents is None else parents
                if not isinstance(tree, nltk.tree.Tree):
                    return None, None
                if tree.label() == kind:
                    return tree[0], parents
                # Record this node while its children are explored, undo on failure.
                parents.append(tree)
                for child in tree:
                    found, ancestry = find_word(child, kind, parents)
                    if found is not None:
                        return found, ancestry
                parents.pop()
                return None, None
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            def find_subtrees(tree, kind, blocking_kinds=()):
                """Collect every subtree labelled ``kind`` in pre-order.

                Args:
                    tree: an ``nltk.tree.Tree`` (anything else yields an empty list).
                    kind: the node label to collect.
                    blocking_kinds: labels whose nodes are not descended into. A node
                        with a blocking label is still collected itself if it matches
                        ``kind``.

                Returns:
                    A list of matching subtrees, possibly including ``tree`` itself.
                """
                result = []
                if not isinstance(tree, nltk.tree.Tree):
                    return result
                if tree.label() == kind:
                    result.append(tree)
                if tree.label() not in blocking_kinds:
                    for st in tree:
                        # Pass blocking_kinds down: previously it was dropped here, so
                        # blocking only ever applied to the root of the search.
                        result.extend(find_subtrees(st, kind, blocking_kinds))
                return result
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            def tree_to_str(tree, transform=lambda w: w):
                """Join the leaf words of ``tree`` into a space-separated string.

                Args:
                    tree: a string leaf or a nested iterable whose leaves are strings.
                    transform: applied to every leaf word before joining.

                Returns:
                    The transformed words joined by single spaces, with one trailing
                    '.' token removed; '' when the tree has no words.
                """
                words = []

                def collect(node):
                    if isinstance(node, str):
                        words.append(transform(node))
                    else:
                        for child in node:
                            collect(child)

                collect(tree)
                # Guard the lookup: a tree without any leaf used to raise IndexError.
                if words and words[-1] == '.':
                    words.pop()
                return ' '.join(words)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            def tree_to_nouns(tree, transform=lambda w: w):
                """List the words of ``tree`` that sit directly under an 'NN' node.

                Args:
                    tree: a nested iterable whose non-string nodes expose ``label()``
                        and whose leaves are strings.
                    transform: applied to every collected word.

                Returns:
                    The transformed singular-noun words in tree order, with one trailing
                    '.' entry removed; an empty list when there is none.
                """
                nouns = []

                def collect(node, noun=False):
                    if isinstance(node, str):
                        if noun:
                            nouns.append(transform(node))
                        return
                    for child in node:
                        # A labelled child resets the flag; a bare string child keeps
                        # whatever flag is currently in effect.
                        if not isinstance(child, str):
                            noun = child.label() == 'NN'
                        collect(child, noun)

                collect(tree)
                # Guard the lookup: a tree without any 'NN' word used to raise IndexError.
                if nouns and nouns[-1] == '.':
                    nouns.pop()
                return nouns
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            def make_determinate(w):
                """Map the indefinite articles 'a'/'an' (any case) to 'the'; pass other words through."""
                return 'the' if w.lower() in ('a', 'an') else w
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            def make_indeterminate(w):
                """Map 'the' and the possessives his/her/their/its (any case) to 'a'; pass other words through."""
                definite_words = ('the', 'his', 'her', 'their', 'its')
                return 'a' if w.lower() in definite_words else w
         | 
| 76 | 
            +
             | 
| 77 | 
            +
            def pluralize(singular, plural, number):
                """Return ``plural`` when ``number`` is greater than 1, otherwise ``singular``."""
                return plural if number > 1 else singular
         | 
| 81 | 
            +
             | 
| 82 | 
            +
            def count_labels(tree):
                """Count how often each node label occurs in ``tree``.

                Returns:
                    A ``collections.defaultdict(int)`` mapping label to number of nodes,
                    filled in pre-order; only ``nltk.tree.Tree`` children are visited.
                """
                counts = collections.defaultdict(int)
                pending = [tree]
                while pending:
                    node = pending.pop()
                    counts[node.label()] += 1
                    subtrees = [child for child in node if isinstance(child, nltk.tree.Tree)]
                    # Push in reverse so the leftmost child is popped (visited) first.
                    pending.extend(reversed(subtrees))
                return counts
         | 
| 91 | 
            +
             | 
| 92 | 
            +
            def get_number(tree):
                """Estimate the grammatical number of a parse (sub)tree.

                Returns:
                    1 for an 'NN' node, 2 for an 'NNS' node or for an 'NP' with more than
                    one direct 'NP' child, otherwise the number of the first direct
                    NP/NN/NNS child; 0 for non-trees or when no such child exists.
                """
                if not isinstance(tree, nltk.tree.Tree):
                    return 0
                own_label = tree.label()
                if own_label == 'NN':
                    return 1
                if own_label == 'NNS':
                    return 2

                np_children = 0
                first_number = None
                for child in tree:
                    child_label = child.label() if isinstance(child, nltk.tree.Tree) else None
                    if child_label == 'NP':
                        np_children += 1
                    # Only the first noun-like child decides the fallback number.
                    if first_number is None and child_label in ('NP', 'NN', 'NNS'):
                        first_number = get_number(child)

                if own_label == 'NP' and np_children > 1:
                    return 2
                return first_number or 0
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            def is_present_continuous(verb):
                """Return True when ``verb`` ends in 'ing'."""
                return verb[-3:] == 'ing'
         | 
| 113 | 
            +
             | 
| 114 | 
            +
            class QuestionGenerator:
         | 
| 115 | 
            +
                def __init__(self):
                    """Create the benepar parser from the "benepar_en3" model."""
                    # The model must already be downloaded (see the note next to the
                    # benepar import at the top of this file).
                    self.parser = benepar.Parser("benepar_en3")
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                def generate_what_question(self, s):
                    """Build a "What is/are <subject> <verb>ing?" question from sentence ``s``.

                    The sentence is parsed with ``self.parser`` and the first element of
                    the parse is used. A question is produced only when that tree starts
                    with an NP followed by a VP, the VP contains a VBG word ending in
                    'ing', and the second child of that word's parent is an NP.

                    Returns:
                        A list with zero or one ``(question, noun)`` tuple, where ``noun``
                        is the last entry of ``tree_to_nouns`` for that NP (presumably
                        the expected answer). Exceptions raised while analysing the
                        tree are printed and result in an empty list.
                    """
                    tree = self.parser.parse(s)[0]
                    questions = []
                    try:
                        if len(tree) >= 2 and tree[0].label() == 'NP' and tree[1].label() == 'VP':
                            np = tree[0]
                            verb = None
                            vp = tree[1]
                            vnp = None

                            # Look for a VBG word; verb_parents[-1] is the subtree that
                            # directly contains it (find_word records each ancestor).
                            while True:
                                verb, verb_parents = find_word(vp, 'VBG')
                                if verb is None:
                                    break
                                if is_present_continuous(verb):
                                    # The candidate object is the sibling right after the verb.
                                    if len(verb_parents[-1]) > 1:
                                        vnp = verb_parents[-1][1]
                                    break
                                else:
                                    # Not an -ing form: continue the search in the sibling.
                                    vp = verb_parents[-1][1]
                            to_be = pluralize('is', 'are', get_number(np))
                            if vnp is not None and vnp.label() == 'NP':
                                questions.append(('What {} {} {}?'
                                                 .format(to_be, 
                                                         tree_to_str(np, make_determinate).lower(), 
                                                         verb),
                                                 tree_to_nouns(vnp)[-1]))
                    except Exception as e:
                        # Best effort: report the failure and return what was collected.
                        print(e)
                    return questions
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                def generate_is_there_question(self, s):
         | 
| 150 | 
            +
                    tree = self.parser.parse(s)
         | 
| 151 | 
            +
                    questions = []
         | 
| 152 | 
            +
                    nps = find_subtrees(tree, 'NP', ('PP',))
         | 
| 153 | 
            +
                    for np in nps:
         | 
| 154 | 
            +
                        only_child_label = len(np) == 1 and next(iter(np)).label()
         | 
| 155 | 
            +
                        if only_child_label in ('PRP', 'EX'):
         | 
| 156 | 
            +
                            continue
         | 
| 157 | 
            +
                        try:
         | 
| 158 | 
            +
                            to_be = pluralize('Is', 'Are', get_number(np))
         | 
| 159 | 
            +
                            questions.append('{} there {}?'
         | 
| 160 | 
            +
                                             .format(to_be, 
         | 
| 161 | 
            +
                                                     tree_to_str(np, make_indeterminate).lower()))
         | 
| 162 | 
            +
                        except Exception as e:
         | 
| 163 | 
            +
                            print(e)
         | 
| 164 | 
            +
                    return questions
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                def generate_is_there_question_v2(self, s):
         | 
| 167 | 
            +
                    tree = self.parser.parse(s)
         | 
| 168 | 
            +
                    questions = []
         | 
| 169 | 
            +
                    nps = find_subtrees(tree, 'NP', ('PP',))
         | 
| 170 | 
            +
                    for np in nps:
         | 
| 171 | 
            +
                        only_child_label = len(np) == 1 and next(iter(np)).label()
         | 
| 172 | 
            +
                        if only_child_label in ('PRP', 'EX'):
         | 
| 173 | 
            +
                            continue
         | 
| 174 | 
            +
                        try:
         | 
| 175 | 
            +
                            to_be = pluralize('Is', 'Are', get_number(np))
         | 
| 176 | 
            +
                            questions.append('{} there {}?'
         | 
| 177 | 
            +
                                             .format(to_be, 
         | 
| 178 | 
            +
                                                     tree_to_str(np, make_indeterminate).lower()))
         | 
| 179 | 
            +
                        except Exception as e:
         | 
| 180 | 
            +
                            print(e)
         | 
| 181 | 
            +
                    if len(questions)==0:
         | 
| 182 | 
            +
                        nps = find_subtrees(tree, 'NNS', ('PP',))
         | 
| 183 | 
            +
                        for np in nps:
         | 
| 184 | 
            +
                            only_child_label = len(np) == 1
         | 
| 185 | 
            +
                            if only_child_label in ('PRP', 'EX'):
         | 
| 186 | 
            +
                                continue
         | 
| 187 | 
            +
                            try:
         | 
| 188 | 
            +
                                to_be = pluralize('Is', 'Are', get_number(np))
         | 
| 189 | 
            +
                                questions.append('{} there {}?'
         | 
| 190 | 
            +
                                                .format(to_be, 
         | 
| 191 | 
            +
                                                        tree_to_str(np, make_indeterminate).lower()))
         | 
| 192 | 
            +
                            except Exception as e:
         | 
| 193 | 
            +
                                print(e)
         | 
| 194 | 
            +
                    return questions
         | 
    	
        model/model/question_model_base.py
    ADDED
    
    | @@ -0,0 +1,85 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import scipy
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import operator
         | 
| 5 | 
            +
            import random
         | 
| 6 | 
            +
            from model.model.question_generator import QuestionGenerator
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            class QuestionAskingModel():
                """Base class that ranks candidate questions by expected information gain.

                Subclasses implement ``get_questions``. Per-question response
                distributions are cached in ``self.question_bank`` until
                ``reset_question_bank`` is called.
                """

                def __init__(self, args):
                    """Read settings from ``args`` and build the question generator.

                    ``args`` must expose ``device``, ``include_what``, ``multiplier_mode``
                    and ``num_images``.
                    """
                    self.device = args.device
                    self.include_what = args.include_what
                    self.max_length = 128
                    self.eps = 1e-25  # keeps log2(1 / p) finite when p == 0
                    self.multiplier_mode = args.multiplier_mode
                    self.num_images = args.num_images

                    # Initialize question generation model
                    self.question_generator = QuestionGenerator()
                    self.reset_question_bank()

                def get_questions(self, images, captions, target_idx=0):
                    """Return candidate questions; must be implemented by subclasses."""
                    raise NotImplementedError

                def get_negative_information_gain(self, p_y_x, question, image_set, caption_set, response_model):
                    """Return the expected posterior entropy term for asking ``question``.

                    ``p_y_x`` is the current belief over images. The response distribution
                    ``p_r_qy`` comes from ``self.question_bank`` when cached; otherwise it
                    is computed with ``response_model`` (optionally scaled by an "is there"
                    multiplier, depending on ``self.multiplier_mode``) and then cached.
                    Lower values mean a more informative question.
                    """
                    if question in self.question_bank:
                        p_r_qy = self.question_bank[question]
                    else:
                        if self.multiplier_mode == "none":
                            # No multiplier: the plain response distribution is the result,
                            # so the response model is queried exactly once.
                            p_r_qy = response_model.get_p_r_qy(None, question, image_set, caption_set)
                        else:
                            is_a_questions = self.question_generator.generate_is_there_question_v2(question)
                            is_a_multiplier = []
                            for is_a_q in is_a_questions:
                                is_a_p = response_model.get_p_r_qy(None, is_a_q, image_set, caption_set, is_a=True)
                                is_a_multiplier.append(is_a_p.detach().cpu().numpy())
                            if not is_a_multiplier:
                                is_a_multiplier.append([0] * self.num_images)
                            # Geometric mean over all "is there" questions, per image.
                            is_a_multiplier = torch.tensor(
                                scipy.stats.mstats.gmean(is_a_multiplier, axis=0)).to(self.device)
                            if self.multiplier_mode == "hard":
                                # Snap each multiplier to a near-zero or a high constant.
                                for i in range(is_a_multiplier.shape[0]):
                                    if is_a_multiplier[i] < 0.5:
                                        is_a_multiplier[i] = 1e-6
                                    else:
                                        is_a_multiplier[i] = 0.9
                            # "soft" mode uses the geometric mean unchanged.
                            p_r_qy = response_model.get_p_r_qy(None, question, image_set, caption_set)
                            p_r_qy = torch.stack([is_a_multiplier*p_r_qy[r] for r in range(len(p_r_qy))])
                            for i in range(self.num_images):
                                if is_a_multiplier[i] < 0.5:
                                    p_r_qy[i] = 1-is_a_multiplier
                                else:
                                    p_r_qy[i] *= is_a_multiplier
                        if not self.include_what:
                            # Binary case: keep the distribution and its complement.
                            p_r_qy = [p_r_qy, 1-p_r_qy]
                        self.question_bank[question] = p_r_qy

                    p_y_xqr = torch.stack([p_y_x*p_r_qy[r] for r in range(len(p_r_qy))])
                    posteriors = []
                    for r in range(len(p_y_xqr)):
                        total = torch.sum(p_y_xqr[r])
                        if total != 0:
                            posteriors.append(p_y_xqr[r]/total)
                        else:
                            # Keep a tensor here: a plain list would break the "+ eps" below.
                            posteriors.append(torch.zeros_like(p_y_xqr[r]))
                    return torch.sum(torch.stack([p_r_qy[r]*p_y_x*torch.log2(1/(posteriors[r]+self.eps))
                                                  for r in range(len(p_r_qy))]))

                def get_question_ranking(self, p_y_x, question_set, image_set, caption_set, response_model):
                    """Return ``(score, question)`` pairs sorted from best to worst.

                    The score is the negated value of ``get_negative_information_gain``.
                    """
                    H_y_rxq = [
                        self.get_negative_information_gain(p_y_x, question, image_set, caption_set, response_model)
                        for question in question_set
                    ]

                    IG = - torch.stack(H_y_rxq).unsqueeze(1)
                    ranked_questions = sorted(zip(list(IG.data.cpu().numpy()), question_set),
                                              key = operator.itemgetter(0))[::-1]
                    return ranked_questions

                def select_best_question(self, p_y_x, question_set, image_set, caption_set, response_model):
                    """Return the top-ranked question from ``question_set``."""
                    ranked_questions = self.get_question_ranking(p_y_x, question_set, image_set, caption_set, response_model)
                    return ranked_questions[0][1]

                def reset_question_bank(self):
                    """Drop every cached response distribution."""
                    self.question_bank = {}
         | 
    	
        model/model/response_model.py
    ADDED
    
    | @@ -0,0 +1,190 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            from transformers import AutoProcessor, AutoTokenizer, AutoModelForQuestionAnswering, pipeline
         | 
| 4 | 
            +
            from transformers import ViltProcessor, ViltForQuestionAnswering
         | 
| 5 | 
            +
            from transformers import BlipProcessor, BlipForQuestionAnswering
         | 
| 6 | 
            +
            from sentence_transformers import SentenceTransformer
         | 
| 7 | 
            +
            import openai
         | 
| 8 | 
            +
            from PIL import Image
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            def get_response_model(args, response_type):
                """Build the response model selected by ``response_type``.

                ``"QA"`` yields a ``ResponseModelQA``; ``"VQA1"``..``"VQA4"`` yield a
                ``ResponseModelVQA`` with the matching ``vqa_type``. ``args`` supplies
                ``device`` and ``include_what``, plus ``question_generator`` for the VQA
                variants.

                Raises:
                    ValueError: if ``response_type`` is none of the names above.
                """
                vqa_backends = {
                    "VQA1": "vilt1",
                    "VQA2": "vilt2",
                    "VQA3": "blip",
                    "VQA4": "git",
                }
                if response_type == "QA":
                    return ResponseModelQA(args.device, args.include_what)
                if response_type in vqa_backends:
                    backend = vqa_backends[response_type]
                    return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type=backend)
                raise ValueError(f"{response_type} is not a valid response type.")
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            class ResponseModel(nn.Module):
                """Base class for the other ResponseModels to inherit from.

                Stores the target ``device`` and the ``include_what`` flag; ``self.model``
                starts as ``None`` and is set by subclasses.
                """

                def __init__(self, device, include_what):
                    super(ResponseModel, self).__init__()
                    self.device = device
                    self.include_what = include_what
                    self.model = None

                def get_response(self, question, image, caption, target_questions, **kwargs):
                    """Answer ``question``; must be implemented by subclasses."""
                    # NotImplemented is a comparison sentinel, not an exception; raising it
                    # would surface as a TypeError instead.
                    raise NotImplementedError

                def get_p_r_qy(self, response, question, images, captions, **kwargs):
                    """Score ``response`` to ``question`` per image; must be implemented by subclasses."""
                    raise NotImplementedError
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            class ResponseModelQA(ResponseModel):
                """Response model that answers from captions with an extractive QA pipeline.

                With ``include_what`` false, the ``AmazonScience/qanlu`` model is used and
                every caption is prefixed with ``"Yes. No. "`` so that a yes/no answer can
                be extracted as a span. With ``include_what`` true, the
                ``deepset/roberta-base-squad2`` model is used instead.
                """

                def __init__(self, device, include_what):
                    super(ResponseModelQA, self).__init__(device, include_what)
                    if not self.include_what:
                        tokenizer = AutoTokenizer.from_pretrained("AmazonScience/qanlu")
                        model = AutoModelForQuestionAnswering.from_pretrained("AmazonScience/qanlu")
                        self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=0)   # remove device=0 for cpu
                    else:
                        tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
                        model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
                        self.model_wh = pipeline('question-answering', model=model, tokenizer=tokenizer, device=0)   # remove device=0 for cpu
                        # get_response() calls self.model; without this alias it would
                        # still be None in include_what mode.
                        self.model = self.model_wh

                def get_response(self, question, image, caption, target_questions, **kwargs):
                    """Answer ``question`` from ``caption``.

                    In ``include_what`` mode, returns the last word of the extracted span.
                    Otherwise returns ``'yes'`` or ``'no'``. When the span is ambiguous or
                    scores 0.5 or less, the answer is ``'yes'`` only if ``question`` is in
                    ``target_questions``.
                    """
                    if self.include_what:
                        answer = self.model({'context':caption, 'question':question})
                        return answer['answer'].split(' ')[-1]
                    else:
                        answer = self.model({'context':f"Yes. No. {caption}", 'question':question})
                        response, score = answer['answer'], answer['score']
                        if score>0.5:
                            response = response.lower().replace('.','')
                            if "yes" in response.split() and "no" not in response.split():
                                response = 'yes'
                            elif "no" in response.split() and "yes" not in response.split():
                                response = 'no'
                            else:
                                response = 'yes' if question in target_questions else 'no'
                        else:
                            response = 'yes' if question in target_questions else 'no'
                        return response

                def get_p_r_qy(self, response, question, images, captions, **kwargs):
                    """Return one probability per caption for ``response`` to ``question``.

                    When ``response`` is ``None``, the probability of "yes" is returned.
                    A caption whose extracted span is neither ``"Yes."`` nor ``"No."``
                    gets 0.5. The result is moved to ``self.device``.

                    Raises:
                        NotImplementedError: in ``include_what`` mode.
                    """
                    if self.include_what:
                        raise NotImplementedError
                    else:
                        p_r_qy = torch.zeros(len(captions))
                        qa_input = {'context':[f"Yes. No. {c}" for c in captions], 'question':[question for _ in captions]}
                        answers = self.model(qa_input)
                        if isinstance(answers, dict):
                            # The pipeline returns a bare dict for a single input.
                            answers = [answers]
                        for idx, answer in enumerate(answers):
                            curr_ans, score = answer['answer'], answer['score']
                            if curr_ans.strip() in ["Yes.", "No."]:
                                if response is None:
                                    if curr_ans.strip()=="No.": p_r_qy[idx] = 1-score
                                    if curr_ans.strip()=="Yes.": p_r_qy[idx] = score
                                elif curr_ans.strip().lower().replace('.','')==response: p_r_qy[idx]=score
                                else: p_r_qy[idx]=1-score
                            else:
                                p_r_qy[idx]=0.5
                        return p_r_qy.to(self.device)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            class ResponseModelVQA(ResponseModel):
         | 
| 90 | 
            +
                def __init__(self, device, include_what, question_generator, vqa_type):
                    """Load the processor and model for ``vqa_type`` and place the model on ``device``.

                    ``vqa_type`` is one of ``"vilt1"``, ``"vilt2"``, ``"blip"`` or
                    ``"git"``. For ``"git"`` nothing is loaded, so ``self.model`` stays
                    ``None`` and no processor is set.

                    Raises:
                        ValueError: if ``vqa_type`` is not one of the names above.
                    """
                    super(ResponseModelVQA, self).__init__(device, include_what)
                    self.vqa_type = vqa_type
                    self.question_generator = question_generator
                    self.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
                    if vqa_type=="vilt1":
                        self.processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
                        self.model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa").to(device)
                        self.vocab = list(self.model.config.label2id.keys())
                    elif vqa_type=="vilt2":
                        self.processor = AutoProcessor.from_pretrained("tufa15nik/vilt-finetuned-vqasi")
                        # Honour the requested device instead of a hard-coded "cuda".
                        self.model = ViltForQuestionAnswering.from_pretrained("tufa15nik/vilt-finetuned-vqasi").to(device)
                        self.vocab = list(self.model.config.label2id.keys())
                    elif vqa_type=="blip":
                        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
                        # Inputs are moved to self.device at inference time, so the model
                        # has to live on the same device.
                        self.model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)
                    elif vqa_type=="git":
                        pass
                    else:
                        raise ValueError(f"{vqa_type} is not a valid vqa_type.")
         | 
| 110 | 
            +
             | 
| 111 | 
            +
             | 
| 112 | 
            +
                def get_response(self, question, image, caption, target_questions, is_a=False):
                    """Answer ``question`` about ``image`` with the configured VQA model.

                    Unless ``is_a`` is set, "is there" sub-questions derived from
                    ``question`` are asked first. If at least half of them are answered
                    "no", the result is ``"0"`` for a "How many" question and
                    ``"nothing"`` otherwise. Three generic "What is in the ..." questions
                    skip that check.
                    """
                    encoding = self.processor(image, question, return_tensors="pt").to(self.device)
                    if is_a==False:
                        existence_questions = self.question_generator.generate_is_there_question_v2(question)
                        if question in ["What is in the photo?", "What is in the picture?", "What is in the background?"]:
                            existence_questions = []
                        existence_answers = [
                            self.get_response(sub_question, image, caption, target_questions, is_a=True)
                            for sub_question in existence_questions
                        ]
                        negatives = sum(answer.lower()=="no" for answer in existence_answers)
                        if len(existence_answers)>0 and negatives/len(existence_answers)>=0.5:
                            return "0" if question[:8]=="How many" else "nothing"
                    if self.vqa_type in ["vilt1", "vilt2"]:
                        outputs = self.model(**encoding)
                        probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
                        best_label = self.model.config.id2label[probabilities.argmax(-1).item()]
                        response = best_label.lower().replace('.','').strip()
                    elif self.vqa_type == "blip":
                        generated = self.model.generate(**encoding)
                        response = self.processor.decode(generated[0], skip_special_tokens=True)
                    return response
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                def get_p_r_qy(self, response, question, images, captions, is_a=False):
                    """Return the probability of ``response`` to ``question`` for every image.

                    Yes/no mode applies when ``include_what`` is false or ``is_a`` is
                    true. It returns one value per image: the probability of "yes" when
                    ``response`` is ``None``, otherwise the probability assigned to
                    ``response``. In open-answer mode the softmax outputs of all images
                    are concatenated and indexed by the per-image top answers
                    (``response is None``) or by the index of ``response``.
                    """
                    p_r_qy = torch.zeros(len(captions))
                    logits_arr = []
                    yes_no_mode = self.include_what==False or is_a==True
                    if len(question) > 150: question=""     # ignore question if too long
                    for i, image in enumerate(images):
                        with torch.no_grad():
                            encoding = self.processor(image, question, return_tensors="pt").to(self.device)
                            outputs = self.model(**encoding)
                        # Despite the name, these are softmax probabilities.
                        logits = torch.nn.functional.softmax(outputs.logits, dim=1)
                        idx = logits.argmax(-1).item()
                        curr_response = self.model.config.id2label[idx]
                        curr_response = curr_response.lower().replace('.','').strip()
                        if yes_no_mode:
                            # idx is the top-scoring label, so in the yes/no branches it is
                            # that label's own index; no hard-coded vocabulary ids needed.
                            top_p = logits[0][idx].item()
                            if response is None:
                                if curr_response=="yes": p_r_qy[i] = top_p
                                elif curr_response=="no": p_r_qy[i] = 1-top_p
                                else: p_r_qy[i] = 0.5
                            elif curr_response==response: p_r_qy[i] = top_p
                            else: p_r_qy[i] = 1-top_p
                        else:
                            logits_arr.append(logits)
                    if yes_no_mode:
                        return p_r_qy.to(self.device)
                    else:
                        logits = torch.concat(logits_arr)
                        if response is None:
                            top_answers = logits.argmax(1)
                            p_r_qy = logits[:,top_answers]
                        else:
                            response_idx = self.get_response_idx(response)
                            p_r_qy = logits[:,response_idx]

                        # check if this
                        # consider rerunning also without the geometric mean
                        if response=="nothing":
                            # "nothing" is certain for an image when at least half of its
                            # "is there" sub-questions are answered "no".
                            is_a_questions = self.question_generator.generate_is_there_question_v2(question)
                            for idx, (caption, image) in enumerate(zip(captions, images)):
                                current_responses = []
                                for is_a_q in is_a_questions:
                                    current_responses.append(self.get_response(is_a_q, image, caption, None, is_a=True))
                                no_cnt = sum([i.lower()=="no" for i in current_responses])
                                if len(current_responses)>0 and no_cnt/len(current_responses)>=0.5:
                                    p_r_qy[idx] = 1.0
                        return p_r_qy.to(self.device)
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                def get_response_idx(self, response):
                    """Return the label id of the model that corresponds to ``response``.

                    If ``response`` is itself a key of ``self.model.config.label2id`` its id
                    is returned directly.  Otherwise the entry of ``self.vocab`` whose
                    sentence embedding has the highest cosine similarity to the embedding
                    of ``response`` is used as a stand-in, and that entry's id is returned.

                    Args:
                        response: Free-text answer string.

                    Returns:
                        The value stored in ``label2id`` for the exact or closest label.

                    Raises:
                        KeyError: If the closest vocabulary entry is not a key of
                            ``label2id``.
                    """
                    label2id = self.model.config.label2id
                    if response in label2id:
                        return label2id[response]

                    # The vocabulary embeddings do not depend on `response`, so encode them
                    # once and reuse them; re-encode only if the vocabulary size changed.
                    vocab_embs = getattr(self, "_vocab_embs", None)
                    if vocab_embs is None or vocab_embs.shape[0] != len(self.vocab):
                        vocab_embs = self.sentence_transformer.encode(self.vocab, convert_to_tensor=True)
                        self._vocab_embs = vocab_embs

                    emb_response = self.sentence_transformer.encode([response], convert_to_tensor=True)
                    # (1, dim) against (len(vocab), dim) broadcasts to one similarity per entry.
                    similarities = torch.nn.CosineSimilarity(-1)(emb_response, vocab_embs)
                    # Convert the 0-dim tensor to a plain int before indexing the Python list.
                    best_response = self.vocab[int(torch.argmax(similarities))]
                    return label2id[best_response]
         | 
    	
        model/run_question_asking_model.py
    ADDED
    
    | @@ -0,0 +1,186 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from model.model.question_asking_model import get_question_model
         | 
| 2 | 
            +
            from model.model.caption_model import get_caption_model
         | 
| 3 | 
            +
            from model.model.response_model import get_response_model
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            from torch.utils.data import Dataset, DataLoader
         | 
| 7 | 
            +
            from PIL import Image
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import argparse
         | 
| 10 | 
            +
            import random
         | 
| 11 | 
            +
            from tqdm.auto import tqdm
         | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
            import pandas as pd
         | 
| 14 | 
            +
            import logging
         | 
| 15 | 
            +
            from model.utils import logging_handler, image_saver, assert_checks
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            # Fix the seed of Python's `random` module for this process.
            random.seed(123)

            # Command-line interface. NOTE: several of these options are overwritten
            # with hard-coded values right after parsing (see below).
            parser = argparse.ArgumentParser()
            parser.add_argument('--device', type=str, default='cuda')
            parser.add_argument('--include_what', action='store_true')
            parser.add_argument('--target_idx', type=int, default=0)
            parser.add_argument('--max_num_questions', type=int, default=25)
            parser.add_argument('--num_images', type=int, default=10)
            parser.add_argument('--beam', type=int, default=1)
            parser.add_argument('--num_samples', type=int, default=100)
            parser.add_argument('--threshold', type=float, default=0.9)

            parser.add_argument('--caption_strategy', type=str, default='simple', choices=['simple', 'granular', 'gtruth'])
            parser.add_argument('--sample_strategy', type=str, default='random', choices=['random', 'attribute', 'clip'])
            parser.add_argument('--attribute_n', type=int, default=1)               # Number of attributes to split
            parser.add_argument('--response_type_simul', type=str, default='VQA1', choices=['simple', 'QA', 'VQA1', 'VQA2', 'VQA3', 'VQA4'])
            parser.add_argument('--response_type_gtruth', type=str, default='VQA2', choices=['simple', 'QA', 'VQA1', 'VQA2', 'VQA3', 'VQA4'])
            parser.add_argument('--question_strategy', type=str, default='gpt3', choices=['rule', 'gpt3'])
            parser.add_argument('--multiplier_mode', type=str, default='soft', choices=['soft', 'hard', 'none'])

            parser.add_argument('--gpt3_save_name', type=str, default='questions_gpt3')
            parser.add_argument('--save_name', type=str, default=None)
            parser.add_argument('--verbose', action='store_true')
            args = parser.parse_args()

            # 'cuda' is the default device; on a host without CUDA the `.to(args.device)`
            # calls below would fail, so fall back to the CPU instead.
            if args.device.startswith('cuda') and not torch.cuda.is_available():
                args.device = 'cpu'

            # Hard-coded configuration of the first pipeline (question_strategy='gpt3',
            # include_what=True). These assignments replace whatever was passed on the
            # command line for the same options.
            args.question_strategy='gpt3'
            args.include_what=True
            args.response_type_simul='VQA1'
            args.response_type_gtruth='VQA3'
            args.multiplier_mode='soft'
            args.sample_strategy='attribute'
            args.attribute_n=1
            args.caption_strategy='gtruth'
            assert_checks(args)
            # A module-level logger is only created when no save name is given.
            if args.save_name is None: logger = logging_handler(args.verbose, args.save_name)
            args.load_response_model = True

            print("1. Loading question model ...")
            question_model = get_question_model(args)
            # Stored on `args` so that it travels with the namespace passed to the
            # response/caption model factories below.
            args.question_generator = question_model.question_generator
            print("2. Loading response model simul ...")
            response_model_simul = get_response_model(args, args.response_type_simul)
            response_model_simul.to(args.device)
            print("3. Loading response model gtruth ...")
            response_model_gtruth = get_response_model(args, args.response_type_gtruth)
            response_model_gtruth.to(args.device)
            print("4. Loading caption model ...")
            caption_model = get_caption_model(args, question_model)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
            def return_modules():
                """Return the module-level objects of the first (gpt3 / include_what) pipeline.

                Returns:
                    Tuple ``(question_model, response_model_simul,
                    response_model_gtruth, caption_model)``.
                """
                modules = (
                    question_model,
                    response_model_simul,
                    response_model_gtruth,
                    caption_model,
                )
                return modules
         | 
| 69 | 
            +
             | 
| 70 | 
            +
             | 
| 71 | 
            +
             | 
| 72 | 
            +
            # Second pipeline: reuse the same `args` namespace, switched to rule-based
            # questions with include_what disabled, and load a separate set of models.
            vars(args).update(
                question_strategy='rule',
                include_what=False,
                response_type_simul='VQA1',
                response_type_gtruth='VQA3',
                multiplier_mode='none',
                sample_strategy='attribute',
                attribute_n=1,
                caption_strategy='gtruth',
            )

            print("1. Loading question model ...")
            question_model_yn = get_question_model(args)
            args.question_generator_yn = question_model_yn.question_generator

            print("2. Loading response model simul ...")
            response_model_simul_yn = get_response_model(args, args.response_type_simul)
            response_model_simul_yn.to(args.device)

            print("3. Loading response model gtruth ...")
            response_model_gtruth_yn = get_response_model(args, args.response_type_gtruth)
            response_model_gtruth_yn.to(args.device)

            print("4. Loading caption model ...")
            caption_model_yn = get_caption_model(args, question_model_yn)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
             | 
| 94 | 
            +
            def return_modules_yn():
                """Return the module-level objects of the second (rule / no include_what) pipeline.

                Returns:
                    Tuple ``(question_model_yn, response_model_simul_yn,
                    response_model_gtruth_yn, caption_model_yn)``.
                """
                modules_yn = (
                    question_model_yn,
                    response_model_simul_yn,
                    response_model_gtruth_yn,
                    caption_model_yn,
                )
                return modules_yn
         | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
             | 
| 99 | 
            +
            # args.question_strategy='gpt3'
         | 
| 100 | 
            +
            # args.include_what=True
         | 
| 101 | 
            +
            # args.response_type_simul='VQA1'
         | 
| 102 | 
            +
            # args.response_type_gtruth='VQA3'
         | 
| 103 | 
            +
            # args.multiplier_mode='none'
         | 
| 104 | 
            +
            # args.sample_strategy='attribute'
         | 
| 105 | 
            +
            # args.attribute_n=1
         | 
| 106 | 
            +
            # args.caption_strategy='gtruth'
         | 
| 107 | 
            +
            # assert_checks(args)
         | 
| 108 | 
            +
            # if args.save_name is None: logger = logging_handler(args.verbose, args.save_name)
         | 
| 109 | 
            +
            # args.load_response_model = True
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            # print("1. Loading question model ...")
         | 
| 112 | 
            +
            # question_model = get_question_model(args)
         | 
| 113 | 
            +
            # args.question_generator = question_model.question_generator
         | 
| 114 | 
            +
            # print("2. Loading response model simul ...")
         | 
| 115 | 
            +
            # response_model_simul = get_response_model(args, args.response_type_simul)
         | 
| 116 | 
            +
            # response_model_simul.to(args.device)
         | 
| 117 | 
            +
            # print("3. Loading response model gtruth ...")
         | 
| 118 | 
            +
            # response_model_gtruth = get_response_model(args, args.response_type_gtruth)
         | 
| 119 | 
            +
            # response_model_gtruth.to(args.device)
         | 
| 120 | 
            +
            # print("4. Loading caption model ...")
         | 
| 121 | 
            +
            # caption_model = get_caption_model(args, question_model)
         | 
| 122 | 
            +
             | 
| 123 | 
            +
            # # dataloader = DataLoader(dataset=ReferenceGameData(split='test', 
         | 
| 124 | 
            +
            # #                                                   num_images=args.num_images, 
         | 
| 125 | 
            +
            # #                                                   num_samples=args.num_samples,
         | 
| 126 | 
            +
            # #                                                   sample_strategy=args.sample_strategy,
         | 
| 127 | 
            +
            # #                                                   attribute_n=args.attribute_n))
         | 
| 128 | 
            +
             | 
| 129 | 
            +
            # def return_modules():
         | 
| 130 | 
            +
            #     return question_model, response_model_simul, response_model_gtruth, caption_model 
         | 
| 131 | 
            +
            # # game_lens, game_preds = [], []
         | 
| 132 | 
            +
            # for t, batch in enumerate(tqdm(dataloader)):
         | 
| 133 | 
            +
            #     image_files = [image[0] for image in batch['images'][:args.num_images]]
         | 
| 134 | 
            +
            #     image_files = [str(i).split('/')[1] for i in image_files]
         | 
| 135 | 
            +
            #     with open("mscoco_images_attribute_n=1.txt", 'a') as f:
         | 
| 136 | 
            +
            #         for i in image_files:
         | 
| 137 | 
            +
            #             f.write(str(i)+"\n")
         | 
| 138 | 
            +
                # images = [np.asarray(Image.open(f"./../../../data/ms-coco/images/{i}")) for i in image_files]
         | 
| 139 | 
            +
                # images = [np.dstack([i]*3) if len(i.shape)==2 else i for i in images]
         | 
| 140 | 
            +
            #     p_y_x = (torch.ones(args.num_images)/args.num_images).to(question_model.device)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
            #     if args.save_name is not None: 
         | 
| 143 | 
            +
            #         logger = logging_handler(args.verbose, args.save_name, t)
         | 
| 144 | 
            +
            #         image_saver(images, args.save_name, t)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
            #     captions = caption_model.get_captions(image_files)
         | 
| 147 | 
            +
            #     questions, target_questions = question_model.get_questions(image_files, captions, args.target_idx)
         | 
| 148 | 
            +
                
         | 
| 149 | 
            +
            #     question_model.reset_question_bank()
         | 
| 150 | 
            +
            #     logger.info(questions)
         | 
| 151 | 
            +
            #     for idx, c in enumerate(captions): logger.info(f"Image_{idx}: {c}")
         | 
| 152 | 
            +
                
         | 
| 153 | 
            +
            #     num_questions_original = len(questions)
         | 
| 154 | 
            +
            #     for j in range(min(args.max_num_questions, num_questions_original)):
         | 
| 155 | 
            +
            #         # Select best question
         | 
| 156 | 
            +
            #         question = question_model.select_best_question(p_y_x, questions, images, captions, response_model_simul)
         | 
| 157 | 
            +
            #         logger.info(f"Question: {question}")
         | 
| 158 | 
            +
             | 
| 159 | 
            +
            #         # Ask the question and get the model's response
         | 
| 160 | 
            +
            #         response = response_model_gtruth.get_response(question, images[args.target_idx], captions[args.target_idx], target_questions, is_a=1-args.include_what)
         | 
| 161 | 
            +
            #         logger.info(f"Response: {response}")
         | 
| 162 | 
            +
             | 
| 163 | 
            +
            #         # Update the probabilities
         | 
| 164 | 
            +
            #         p_r_qy = response_model_simul.get_p_r_qy(response, question, images, captions)
         | 
| 165 | 
            +
            #         logger.info(f"P(r|q,y):\n{np.around(p_r_qy.cpu().detach().numpy(), 3)}")
         | 
| 166 | 
            +
            #         p_y_xqr = p_y_x*p_r_qy
         | 
| 167 | 
            +
            #         p_y_xqr = p_y_xqr/torch.sum(p_y_xqr)if torch.sum(p_y_xqr) != 0 else torch.zeros_like(p_y_xqr)        
         | 
| 168 | 
            +
            #         p_y_x = p_y_xqr
         | 
| 169 | 
            +
            #         logger.info(f"Updated distribution:\n{np.around(p_y_x.cpu().detach().numpy(), 3)}\n")
         | 
| 170 | 
            +
             | 
| 171 | 
            +
            #         # Don't repeat the same question again in the future
         | 
| 172 | 
            +
            #         questions.remove(question)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
            #         # Terminate if probability exceeds threshold or if out of questions to ask
         | 
| 175 | 
            +
            #         top_prob = torch.max(p_y_x).item()
         | 
| 176 | 
            +
            #         if top_prob >= args.threshold or j==min(args.max_num_questions, num_questions_original)-1:
         | 
| 177 | 
            +
            #             game_preds.append(torch.argmax(p_y_x).item())
         | 
| 178 | 
            +
            #             game_lens.append(j+1)
         | 
| 179 | 
            +
            #             logger.info(f"pred:{game_preds[-1]} game_len:{game_lens[-1]}")
         | 
| 180 | 
            +
            #             break
         | 
| 181 | 
            +
             | 
| 182 | 
            +
            # logger = logging_handler(args.verbose, args.save_name, "final_results")
         | 
| 183 | 
            +
            # logger.info(f"Game lenths:\n{game_lens}")
         | 
| 184 | 
            +
            # logger.info(sum(game_lens)/len(game_lens))
         | 
| 185 | 
            +
            # logger.info(f"Predictions:\n{game_preds}")
         | 
| 186 | 
            +
            # logger.info(f"Accuracy:\n{sum([i==args.target_idx for i in game_preds])/len(game_preds)}")
         | 
    	
        model/utils.py
    ADDED
    
    | @@ -0,0 +1,54 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import logging
         | 
| 3 | 
            +
            import matplotlib.pyplot as plt
         | 
| 4 | 
            +
            from PIL import Image
         | 
| 5 | 
            +
            import nltk
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            def logging_handler(verbose, save_name, idx=0):
                """Create (or reconfigure) the logger named ``str(idx)``.

                The logger emits bare messages at INFO level to the console and, when
                ``save_name`` is given, also to ``results/<save_name>/<idx>.log``.

                Args:
                    verbose: Accepted for interface compatibility; not used here.
                    save_name: Sub-directory of ``results/`` for the log file, or
                        ``None`` to log to the console only.
                    idx: Identifier used both as logger name and log file stem.

                Returns:
                    The configured ``logging.Logger``.
                """
                logger = logging.getLogger(str(idx))
                logger.setLevel(logging.INFO)

                # getLogger() hands back the same object for the same name, so handlers
                # attached by an earlier call must be removed first; otherwise every
                # message is emitted once per call and file handles pile up.
                for stale in list(logger.handlers):
                    logger.removeHandler(stale)
                    stale.close()

                stream_logger = logging.StreamHandler()
                stream_logger.setFormatter(logging.Formatter("%(message)s"))
                logger.addHandler(stream_logger)

                if save_name is not None:
                    savepath = f"results/{save_name}"
                    # exist_ok avoids the check-then-create race of exists() + makedirs().
                    os.makedirs(savepath, exist_ok=True)
                    file_logger = logging.FileHandler(f"{savepath}/{idx}.log")
                    file_logger.setFormatter(logging.Formatter("%(message)s"))
                    logger.addHandler(file_logger)

                return logger
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            def image_saver(images, save_name, idx=0, interactive=True):
                """Save up to ten images as a 2x5 grid PNG.

                Args:
                    images: Indexable collection of images accepted by ``Axes.imshow``.
                        Only the first ten are drawn; with fewer than ten the remaining
                        grid cells are left blank instead of raising ``IndexError``.
                    save_name: Output name. With ``interactive=True`` the figure is
                        written to ``<save_name>.png``; otherwise to
                        ``results/<save_name>/<idx>.png``.
                    idx: File stem used in the non-interactive case.
                    interactive: Selects the output location (see ``save_name``).
                """
                fig, a =  plt.subplots(2,5)
                fig.set_size_inches(30, 15)
                # `a.flat` walks the 2x5 axes grid row by row, i.e. in the same order as
                # indexing with a[i//5][i%5].
                for i, ax in enumerate(a.flat):
                    if i < len(images):
                        ax.imshow(images[i])
                    ax.axis('off')
                    ax.set_aspect('equal')
                plt.tight_layout()
                plt.subplots_adjust(wspace=0, hspace=0)
                if not interactive:
                    out_dir = f"results/{save_name}"
                    # Make sure the target directory exists before writing into it.
                    os.makedirs(out_dir, exist_ok=True)
                    fig.savefig(f"{out_dir}/{idx}.png")
                    # Batch use creates one figure per call; release it so that open
                    # figures do not accumulate.
                    plt.close(fig)
                else:
                    fig.savefig(f"{save_name}.png")
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            def assert_checks(args):
                """Validate the combination of options in ``args``.

                Args:
                    args: Object with ``question_strategy`` and ``include_what``
                        attributes.

                Raises:
                    AssertionError: If ``question_strategy`` is ``"gpt3"`` while
                        ``include_what`` is falsy.
                """
                # Raised explicitly rather than via `assert`, which is removed when
                # Python runs with -O and would silently skip the check.
                if args.question_strategy=="gpt3" and not args.include_what:
                    raise AssertionError("question_strategy='gpt3' requires include_what to be enabled")
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            def extract_nouns(sents):
                """Collect the nouns of each sentence.

                Every sentence is split on whitespace and tagged with ``nltk.pos_tag``;
                tokens tagged ``NN`` or ``NNS`` are kept, with ``.`` characters removed
                and the result lower-cased.

                Args:
                    sents: Iterable of sentence strings.

                Returns:
                    A list with one list of noun strings per input sentence, in order.
                """
                noun_tags = ("NN", "NNS")
                return [
                    [
                        token.replace('.','').lower()
                        for token, tag in nltk.pos_tag(sentence.split())
                        if tag in noun_tags
                    ]
                    for sentence in sents
                ]
         | 
    	
        open_db.py
    CHANGED
    
    | @@ -4,9 +4,4 @@ import pandas as pd | |
| 4 | 
             
            db = sqlite3.connect("response.db")
         | 
| 5 | 
             
            df = pd.read_sql('SELECT * from responses', db)
         | 
| 6 | 
             
            print(df)
         | 
| 7 | 
            -
             | 
| 8 | 
            -
            #      conn = sqlite3.connect(db)
         | 
| 9 | 
            -
            #      c = conn.cursor()
         | 
| 10 | 
            -
            #      c.execute("SELECT name FROM sqlite_master WHERE type='table';")
         | 
| 11 | 
            -
            #      for table in c.fetchall()
         | 
| 12 | 
            -
            #          yield list(c.execute('SELECT * from ?;', (table[0],)))
         | 
|  | |
| 4 | 
             
            db = sqlite3.connect("response.db")
         | 
| 5 | 
             
            df = pd.read_sql('SELECT * from responses', db)
         | 
| 6 | 
             
            print(df)
         | 
| 7 | 
            +
            df.to_csv("responses.csv", index=False)
         | 
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        pilot-study.csv
    ADDED
    
    | @@ -0,0 +1,161 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            taskID,mscoco-id,task-type
         | 
| 2 | 
            +
            0,35,wh- standard
         | 
| 3 | 
            +
            1,102,wh- hard (n=1)
         | 
| 4 | 
            +
            2,25,wh- standard
         | 
| 5 | 
            +
            3,133,yes/no hard (n=1)
         | 
| 6 | 
            +
            4,80,wh- hard (n=1)
         | 
| 7 | 
            +
            5,15,wh- standard
         | 
| 8 | 
            +
            6,92,wh- hard (n=1)
         | 
| 9 | 
            +
            7,108,wh- hard (n=1)
         | 
| 10 | 
            +
            8,60,yes/no standard
         | 
| 11 | 
            +
            9,57,yes/no standard
         | 
| 12 | 
            +
            10,125,yes/no hard (n=1)
         | 
| 13 | 
            +
            11,56,yes/no standard
         | 
| 14 | 
            +
            12,137,yes/no hard (n=1)
         | 
| 15 | 
            +
            13,40,yes/no standard
         | 
| 16 | 
            +
            14,134,yes/no hard (n=1)
         | 
| 17 | 
            +
            15,130,yes/no hard (n=1)
         | 
| 18 | 
            +
            16,89,wh- hard (n=1)
         | 
| 19 | 
            +
            17,19,wh- standard
         | 
| 20 | 
            +
            18,58,yes/no standard
         | 
| 21 | 
            +
            19,81,wh- hard (n=1)
         | 
| 22 | 
            +
            20,5,wh- standard
         | 
| 23 | 
            +
            21,73,yes/no standard
         | 
| 24 | 
            +
            22,54,yes/no standard
         | 
| 25 | 
            +
            23,0,wh- standard
         | 
| 26 | 
            +
            24,14,wh- standard
         | 
| 27 | 
            +
            25,113,wh- hard (n=1)
         | 
| 28 | 
            +
            26,34,wh- standard
         | 
| 29 | 
            +
            27,159,yes/no hard (n=1)
         | 
| 30 | 
            +
            28,135,yes/no hard (n=1)
         | 
| 31 | 
            +
            29,2,wh- standard
         | 
| 32 | 
            +
            30,156,yes/no hard (n=1)
         | 
| 33 | 
            +
            31,30,wh- standard
         | 
| 34 | 
            +
            32,104,wh- hard (n=1)
         | 
| 35 | 
            +
            33,128,yes/no hard (n=1)
         | 
| 36 | 
            +
            34,18,wh- standard
         | 
| 37 | 
            +
            35,157,yes/no hard (n=1)
         | 
| 38 | 
            +
            36,1,wh- standard
         | 
| 39 | 
            +
            37,42,yes/no standard
         | 
| 40 | 
            +
            38,131,yes/no hard (n=1)
         | 
| 41 | 
            +
            39,115,wh- hard (n=1)
         | 
| 42 | 
            +
            40,120,yes/no hard (n=1)
         | 
| 43 | 
            +
            41,3,wh- standard
         | 
| 44 | 
            +
            42,63,yes/no standard
         | 
| 45 | 
            +
            43,65,yes/no standard
         | 
| 46 | 
            +
            44,103,wh- hard (n=1)
         | 
| 47 | 
            +
            45,124,yes/no hard (n=1)
         | 
| 48 | 
            +
            46,21,wh- standard
         | 
| 49 | 
            +
            47,72,yes/no standard
         | 
| 50 | 
            +
            48,62,yes/no standard
         | 
| 51 | 
            +
            49,47,yes/no standard
         | 
| 52 | 
            +
            50,78,yes/no standard
         | 
| 53 | 
            +
            51,109,wh- hard (n=1)
         | 
| 54 | 
            +
            52,136,yes/no hard (n=1)
         | 
| 55 | 
            +
            53,158,yes/no hard (n=1)
         | 
| 56 | 
            +
            54,61,yes/no standard
         | 
| 57 | 
            +
            55,27,wh- standard
         | 
| 58 | 
            +
            56,24,wh- standard
         | 
| 59 | 
            +
            57,123,yes/no hard (n=1)
         | 
| 60 | 
            +
            58,70,yes/no standard
         | 
| 61 | 
            +
            59,91,wh- hard (n=1)
         | 
| 62 | 
            +
            60,55,yes/no standard
         | 
| 63 | 
            +
            61,87,wh- hard (n=1)
         | 
| 64 | 
            +
            62,46,yes/no standard
         | 
| 65 | 
            +
            63,33,wh- standard
         | 
| 66 | 
            +
            64,16,wh- standard
         | 
| 67 | 
            +
            65,147,yes/no hard (n=1)
         | 
| 68 | 
            +
            66,85,wh- hard (n=1)
         | 
| 69 | 
            +
            67,59,yes/no standard
         | 
| 70 | 
            +
            68,99,wh- hard (n=1)
         | 
| 71 | 
            +
            69,117,wh- hard (n=1)
         | 
| 72 | 
            +
            70,9,wh- standard
         | 
| 73 | 
            +
            71,122,yes/no hard (n=1)
         | 
| 74 | 
            +
            72,53,yes/no standard
         | 
| 75 | 
            +
            73,22,wh- standard
         | 
| 76 | 
            +
            74,8,wh- standard
         | 
| 77 | 
            +
            75,29,wh- standard
         | 
| 78 | 
            +
            76,83,wh- hard (n=1)
         | 
| 79 | 
            +
            77,37,wh- standard
         | 
| 80 | 
            +
            78,66,yes/no standard
         | 
| 81 | 
            +
            79,41,yes/no standard
         | 
| 82 | 
            +
            80,94,wh- hard (n=1)
         | 
| 83 | 
            +
            81,98,wh- hard (n=1)
         | 
| 84 | 
            +
            82,110,wh- hard (n=1)
         | 
| 85 | 
            +
            83,77,yes/no standard
         | 
| 86 | 
            +
            84,151,yes/no hard (n=1)
         | 
| 87 | 
            +
            85,121,yes/no hard (n=1)
         | 
| 88 | 
            +
            86,6,wh- standard
         | 
| 89 | 
            +
            87,45,yes/no standard
         | 
| 90 | 
            +
            88,155,yes/no hard (n=1)
         | 
| 91 | 
            +
            89,88,wh- hard (n=1)
         | 
| 92 | 
            +
            90,96,wh- hard (n=1)
         | 
| 93 | 
            +
            91,75,yes/no standard
         | 
| 94 | 
            +
            92,112,wh- hard (n=1)
         | 
| 95 | 
            +
            93,49,yes/no standard
         | 
| 96 | 
            +
            94,152,yes/no hard (n=1)
         | 
| 97 | 
            +
            95,38,wh- standard
         | 
| 98 | 
            +
            96,7,wh- standard
         | 
| 99 | 
            +
            97,52,yes/no standard
         | 
| 100 | 
            +
            98,101,wh- hard (n=1)
         | 
| 101 | 
            +
            99,76,yes/no standard
         | 
| 102 | 
            +
            100,28,wh- standard
         | 
| 103 | 
            +
            101,114,wh- hard (n=1)
         | 
| 104 | 
            +
            102,139,yes/no hard (n=1)
         | 
| 105 | 
            +
            103,74,yes/no standard
         | 
| 106 | 
            +
            104,149,yes/no hard (n=1)
         | 
| 107 | 
            +
            105,84,wh- hard (n=1)
         | 
| 108 | 
            +
            106,79,yes/no standard
         | 
| 109 | 
            +
            107,127,yes/no hard (n=1)
         | 
| 110 | 
            +
            108,126,yes/no hard (n=1)
         | 
| 111 | 
            +
            109,116,wh- hard (n=1)
         | 
| 112 | 
            +
            110,71,yes/no standard
         | 
| 113 | 
            +
            111,67,yes/no standard
         | 
| 114 | 
            +
            112,10,wh- standard
         | 
| 115 | 
            +
            113,143,yes/no hard (n=1)
         | 
| 116 | 
            +
            114,132,yes/no hard (n=1)
         | 
| 117 | 
            +
            115,90,wh- hard (n=1)
         | 
| 118 | 
            +
            116,140,yes/no hard (n=1)
         | 
| 119 | 
            +
            117,144,yes/no hard (n=1)
         | 
| 120 | 
            +
            118,106,wh- hard (n=1)
         | 
| 121 | 
            +
            119,32,wh- standard
         | 
| 122 | 
            +
            120,154,yes/no hard (n=1)
         | 
| 123 | 
            +
            121,11,wh- standard
         | 
| 124 | 
            +
            122,17,wh- standard
         | 
| 125 | 
            +
            123,145,yes/no hard (n=1)
         | 
| 126 | 
            +
            124,118,wh- hard (n=1)
         | 
| 127 | 
            +
            125,48,yes/no standard
         | 
| 128 | 
            +
            126,148,yes/no hard (n=1)
         | 
| 129 | 
            +
            127,26,wh- standard
         | 
| 130 | 
            +
            128,51,yes/no standard
         | 
| 131 | 
            +
            129,13,wh- standard
         | 
| 132 | 
            +
            130,39,wh- standard
         | 
| 133 | 
            +
            131,153,yes/no hard (n=1)
         | 
| 134 | 
            +
            132,12,wh- standard
         | 
| 135 | 
            +
            133,93,wh- hard (n=1)
         | 
| 136 | 
            +
            134,107,wh- hard (n=1)
         | 
| 137 | 
            +
            135,86,wh- hard (n=1)
         | 
| 138 | 
            +
            136,31,wh- standard
         | 
| 139 | 
            +
            137,95,wh- hard (n=1)
         | 
| 140 | 
            +
            138,44,yes/no standard
         | 
| 141 | 
            +
            139,69,yes/no standard
         | 
| 142 | 
            +
            140,150,yes/no hard (n=1)
         | 
| 143 | 
            +
            141,4,wh- standard
         | 
| 144 | 
            +
            142,142,yes/no hard (n=1)
         | 
| 145 | 
            +
            143,43,yes/no standard
         | 
| 146 | 
            +
            144,50,yes/no standard
         | 
| 147 | 
            +
            145,100,wh- hard (n=1)
         | 
| 148 | 
            +
            146,129,yes/no hard (n=1)
         | 
| 149 | 
            +
            147,68,yes/no standard
         | 
| 150 | 
            +
            148,146,yes/no hard (n=1)
         | 
| 151 | 
            +
            149,64,yes/no standard
         | 
| 152 | 
            +
            150,23,wh- standard
         | 
| 153 | 
            +
            151,82,wh- hard (n=1)
         | 
| 154 | 
            +
            152,111,wh- hard (n=1)
         | 
| 155 | 
            +
            153,97,wh- hard (n=1)
         | 
| 156 | 
            +
            154,119,wh- hard (n=1)
         | 
| 157 | 
            +
            155,141,yes/no hard (n=1)
         | 
| 158 | 
            +
            156,20,wh- standard
         | 
| 159 | 
            +
            157,36,wh- standard
         | 
| 160 | 
            +
            158,138,yes/no hard (n=1)
         | 
| 161 | 
            +
            159,105,wh- hard (n=1)
         | 
    	
        response_db.py
    CHANGED
    
    | @@ -75,5 +75,4 @@ class StResponseDb(ResponseDb): | |
| 75 |  | 
| 76 | 
             
            if __name__ == "__main__":
         | 
| 77 | 
             
                db = ResponseDb()
         | 
| 78 | 
            -
                print(db.get_all())
         | 
| 79 | 
            -
             | 
|  | |
| 75 |  | 
| 76 | 
             
            if __name__ == "__main__":
         | 
| 77 | 
             
                db = ResponseDb()
         | 
| 78 | 
            +
                print(db.get_all())
         | 
|  |