File size: 3,453 Bytes
72fb101
4e4df51
 
5b79e88
4e4df51
5b79e88
 
c4454c9
4e4df51
 
 
72fb101
 
 
 
 
 
3b91aca
4e4df51
 
72fb101
 
be53140
4e4df51
9febe95
4e4df51
 
 
 
 
72fb101
 
be53140
72fb101
 
 
 
 
 
 
 
 
435f95e
72fb101
4e4df51
 
435f95e
4e4df51
230aa4d
435f95e
 
 
 
 
230aa4d
435f95e
230aa4d
435f95e
 
 
4e4df51
435f95e
 
 
 
 
72fb101
 
5ee5175
4e4df51
b2c3e0e
4e4df51
 
72fb101
435f95e
4e4df51
 
 
 
 
 
 
72fb101
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import random
import gradio as gr
from datasets import load_dataset
import os

# Access token for the Hugging Face Hub, taken from the environment
# (None when the "auth_token" variable is not set).
auth_token = os.getenv("auth_token")

# Only the 'test' split of the dataset is used by this app.
# NOTE(review): recent `datasets` releases deprecate `use_auth_token` in
# favour of `token` — confirm against the version this app is pinned to.
whoops = load_dataset("nlphuji/whoops", use_auth_token=auth_token)["test"]
dataset_size = len(whoops)

# Startup diagnostics: one sample row and the total number of rows.
print("Loaded WMTIS, first example:")
print(whoops[0])
print(f"all dataset size: {dataset_size}")

# Column names of the dataset rows, as used throughout this file.
IMAGE = 'image'
IMAGE_DESIGNER = 'image_designer'
DESIGNER_EXPLANATION = 'designer_explanation'
CROWD_CAPTIONS = 'crowd_captions'
CROWD_EXPLANATIONS = 'crowd_explanations'
CROWD_UNDERSPECIFIED_CAPTIONS = 'crowd_underspecified_captions'
SELECTED_CAPTION = 'selected_caption'
COMMONSENSE_CATEGORY = 'commonsense_category'
QA = 'question_answering_pairs'
IMAGE_ID = 'image_id'

# Columns rendered directly in each grid cell.
left_side_columns = [IMAGE]
# Every remaining dataset column except the QA pairs, in dataset order.
_excluded_columns = set(left_side_columns) | {QA}
right_side_columns = [name for name in whoops.features if name not in _excluded_columns]
# List-valued columns that are displayed as a numbered list.
enumerate_cols = [CROWD_CAPTIONS, CROWD_EXPLANATIONS, CROWD_UNDERSPECIFIED_CAPTIONS]

# Emoji hints keyed by column name.
# The literals were previously mis-decoded UTF-8 (mojibake); they are restored
# to the intended emoji here.
# NOTE(review): the memo (📝) and magnifier (🔍) glyphs had lost a byte in the
# garbled text and were reconstructed from context — confirm they are the
# intended ones.
emoji_to_label = {
    IMAGE_DESIGNER: '🎨, 🧑‍🎨, 💻',
    DESIGNER_EXPLANATION: '💡, 🤔, 🧑‍🎨',
    CROWD_CAPTIONS: '👥, 💬, 📝',
    CROWD_EXPLANATIONS: '👥, 💡, 🤔',
    CROWD_UNDERSPECIFIED_CAPTIONS: '👥, 💬, 👎',
    QA: '❓, 🤔, 💡',
    IMAGE_ID: '🔍, 📄, 💾',
    COMMONSENSE_CATEGORY: '🤔, 📚, 💡',
    SELECTED_CAPTION: '📝, 👌, 💬',
}

# (width, height) in pixels used when resizing images for display.
target_size = (1024, 1024)

def get_instance_values(example):
    """Return the display values of one dataset row.

    The result has one entry per column, ordered as `left_side_columns`
    followed by `right_side_columns`. Columns listed in `enumerate_cols`
    are rendered as a numbered list, the QA column as a numbered list of
    "Q: ... A: ..." lines, and every other column is passed through as-is.
    """

    def _render(column):
        # Numbered list for the list-valued columns.
        if column in enumerate_cols:
            return list_to_string(example[column])
        # Each QA entry is indexed as (question, answer).
        if column == QA:
            pairs = [f"Q: {pair[0]} A: {pair[1]}" for pair in example[column]]
            return list_to_string(pairs)
        return example[column]

    return [_render(column) for column in left_side_columns + right_side_columns]


def list_to_string(lst):
    """Format the items of `lst` as a 1-based numbered list, one item per line."""
    numbered_lines = (f"{position}. {item}" for position, item in enumerate(lst, start=1))
    return '\n'.join(numbered_lines)

def _field_label(key):
    """Return the textbox caption for a column: prettified name plus emoji hint."""
    pretty = key.capitalize().replace("_", " ")
    # A column with no entry in emoji_to_label gets a plain caption instead of
    # raising KeyError while the page is being built.
    return f"{pretty} {emoji_to_label.get(key, '')}".rstrip()


def plot_image(index):
    """Create the Gradio components for the sample at `index` of `whoops_sample`.

    The image is resized to `target_size` and labelled with the sample's
    commonsense category. All right-side columns go into a collapsed
    "Click for details" accordion as textboxes.

    Components are created only for their side effect of being added to the
    enclosing Gradio layout (the caller invokes this inside `with gr.Column():`).
    Nothing is returned.
    """
    # Fetch the row once and reuse it; the previous version indexed the dataset
    # again for the image and for the category label.
    example = whoops_sample[index]
    instance_values = get_instance_values(example)

    # get_instance_values returns left-side values followed by right-side ones.
    split_at = len(left_side_columns)
    left_values = instance_values[:split_at]
    right_values = instance_values[split_at:]

    for key, value in zip(left_side_columns, left_values):
        if key == IMAGE:
            resized = example[IMAGE].resize(target_size)
            gr.Image(value=resized, label=example[COMMONSENSE_CATEGORY])
        else:
            gr.Textbox(value=value, label=_field_label(key))

    with gr.Accordion("Click for details", open=False):
        for key, value in zip(right_side_columns, right_values):
            gr.Textbox(value=value, label=_field_label(key))


# Grid dimensions of the explorer page. The row count is fixed; deriving it
# from dataset_size (dataset_size / columns_number) is intentionally not used.
columns_number = 3
rows_number = 25

# Random subset with exactly one sample per grid cell.
whoops_sample = whoops.shuffle().select(range(columns_number * rows_number))

# Running position in whoops_sample of the next cell to render.
index = 0

with gr.Blocks() as demo:
    gr.Markdown("# WHOOPS! Dataset Explorer")
    for _row in range(rows_number):
        with gr.Row():
            for _col in range(columns_number):
                with gr.Column():
                    plot_image(index)
                    index += 1

demo.launch()