import gradio as gr
import os
from interface_utils import *
# Maxim being labelled; it is stored in every label record and used as the
# prefix of the label file names.
maxim = 'quality'

# Questions shown to the annotator; checkbox_choices[k] holds the answer
# options offered for submaxims[k].
submaxims = [
    "The response is factual and supported by adequate evidence whenever possible.",
]
checkbox_choices = [
    ["Yes", "No", "NA"],
]

conversation_data_sliced = load_from_jsonl('./data/conversations_unlabeled_sliced.jsonl')

# The UI pre-allocates one Markdown slot per turn, so it needs the longest
# transcript in the dataset.
max_conversation_length = max(
    len(record['transcript']) for record in conversation_data_sliced
)

# Conversation shown when the page first loads.
conversation = get_conversation(conversation_data_sliced)
def save_labels(conv_id, slice_idx, skipped, submaxim_0=None):
    """Write one human-label record for a conversation slice to ``./labels``.

    The record is written to
    ``./labels/{maxim}_human_labels_{conv_id}_{slice_idx}.json``. An existing
    file for the same slice is overwritten.

    Args:
        conv_id: Identifier of the labelled conversation.
        slice_idx: Slice index; stored as ``int`` in the record, used verbatim
            in the file name.
        skipped: True when the annotator skipped this slice.
        submaxim_0: Answer chosen for ``submaxims[0]``, or ``None``.
    """
    # Import explicitly instead of relying on ``from interface_utils import *``
    # happening to re-export ``json``.
    import json

    data = {
        'conv_id': conv_id,
        'slice_idx': int(slice_idx),
        'maxim': maxim,
        'skipped': skipped,
        'submaxim_0': submaxim_0,
    }
    os.makedirs("./labels", exist_ok=True)
    # NOTE(review): conv_id/slice_idx come from hidden UI components whose
    # values are sent by the browser, and are interpolated into the path
    # unsanitized — confirm this is acceptable for the deployment.
    path = f"./labels/{maxim}_human_labels_{conv_id}_{slice_idx}.json"
    with open(path, 'w', encoding="utf-8") as f:
        json.dump(data, f, indent=4)
def update_interface(new_conversation):
    """Build the component updates that display ``new_conversation``.

    The returned list matches the event ``outputs`` order:
    conversation id, slice index, ``max_conversation_length`` Markdown slots,
    the last-response textbox, the submaxim radio (cleared), and the hidden
    conversation-length number.
    """
    conv_id = new_conversation['conv_id']
    slice_idx = new_conversation['slice_idx']
    transcript = new_conversation['transcript']

    def _turn_slot(idx):
        # Slots past the end of the transcript, and turns whose speaker is
        # empty, are rendered blank and hidden.
        if idx < len(transcript) and transcript[idx]['speaker'] != '':
            turn = transcript[idx]
            return gr.Markdown(
                f""" **{turn['speaker']}**: {turn['response']}""",
                visible=True)
        return gr.Markdown("", visible=False)

    turn_slots = [_turn_slot(idx) for idx in range(max_conversation_length)]

    response_box = gr.Text(
        value=get_last_response(transcript),
        label="",
        lines=1,
        container=False,
        interactive=False,
        autoscroll=True,
        visible=True,
    )
    answer_radio = gr.Radio(
        label=submaxims[0],
        choices=checkbox_choices[0],
        value=None,
        visible=True,
    )
    length_field = gr.Number(value=len(transcript), visible=False)

    return [conv_id, slice_idx, *turn_slots, response_box, answer_radio, length_field]
def submit(*args):
    """Save the annotator's answer for the current slice and show the next one.

    ``args`` are the values of the event inputs: ``args[0]`` is the
    conversation id, ``args[1]`` the slice index and ``args[-2]`` the radio
    answer (the final entry is the hidden conversation length).
    """
    conv_id, slice_idx, answer = args[0], args[1], args[-2]
    save_labels(conv_id, slice_idx, skipped=False, submaxim_0=answer)
    next_conversation = get_conversation(conversation_data_sliced)
    return update_interface(next_conversation)
def skip(*args):
    """Record the current slice as skipped and show the next conversation.

    ``args`` are the values of the event inputs; only ``args[0]``
    (conversation id) and ``args[1]`` (slice index) are used.
    """
    conv_id = args[0]
    slice_idx = args[1]
    save_labels(conv_id, slice_idx, skipped=True)
    new_conversation = get_conversation(conversation_data_sliced)
    # update_interface takes only the conversation to display; passing an
    # extra positional argument raises TypeError.
    return update_interface(new_conversation)
with gr.Blocks(theme=gr.themes.Default()) as interface:
    # Seed the page with the conversation chosen at import time.
    conv_id = conversation['conv_id']
    slice_idx = conversation['slice_idx']
    transcript = conversation['transcript']
    # Hidden field carrying the number of turns in the displayed transcript.
    conv_len = gr.Number(value=len(transcript), visible=False)
    markdown_blocks = [None] * max_conversation_length
    with gr.Column(scale=1, min_width=600):
        with gr.Group():
            gr.Markdown(""" **Conversational context** """,
                        visible=True)
            # One Markdown slot per possible turn so later updates can fill or
            # hide them without changing the layout.
            for i in range(max_conversation_length):
                # Same visibility rule as update_interface: hide slots beyond
                # the transcript and turns with an empty speaker.
                if i < len(transcript) and transcript[i]['speaker'] != '':
                    markdown_blocks[i] = gr.Markdown(
                        f""" **{transcript[i]['speaker']}**: {transcript[i]['response']}""",
                        visible=True)
                else:
                    markdown_blocks[i] = gr.Markdown("", visible=False)
        with gr.Row():
            with gr.Group(elem_classes="bottom-aligned-group"):
                speaker_adapted = gr.Markdown(
                    """ **Response to label** """,
                    visible=True)
                last_response = gr.Textbox(value=get_last_response(transcript),
                                           label="",
                                           lines=1,
                                           container=False,
                                           interactive=False,
                                           autoscroll=True,
                                           visible=True)
                radio_submaxim_0_base = gr.Radio(label=submaxims[0],
                                                 choices=checkbox_choices[0],
                                                 value=None,
                                                 visible=True)
        submit_button = gr.Button("Submit")
        skip_button = gr.Button("Skip")
    conv_id_element = gr.Text(value=conv_id, visible=False)
    slice_idx_element = gr.Text(value=slice_idx, visible=False)
    # Both handlers read these components and return updates for the same
    # components, in the order update_interface produces them.
    input_list = [
        conv_id_element,
        slice_idx_element,
        *markdown_blocks,
        last_response,
        radio_submaxim_0_base,
        conv_len,
    ]
    submit_button.click(
        fn=submit,
        inputs=input_list,
        outputs=input_list,
    )
    skip_button.click(
        fn=skip,
        inputs=input_list,
        outputs=input_list,
    )
# Custom page styles, attached to the Blocks app and then served by launch().
interface.css = css = """
#textbox_id textarea {
background-color: white;
}
.bottom-aligned-group {
display: flex;
flex-direction: column;
justify-content: flex-end;
height: 100%;
}
"""
interface.launch()