import gradio as gr
import datetime
import json
import requests
from constants import *
def process(query_type, index_desc, **kwargs):
    """Send one query to the backend API and return its decoded JSON reply.

    Args:
        query_type: Name of the query kind; sent as the ``query_type`` field.
        index_desc: Human-readable corpus description; it is looked up in
            ``INDEX_BY_DESC`` to obtain the value sent as ``index``.
        **kwargs: Extra payload fields. They are merged last, so a key that
            collides with a built-in field overrides it.

    Returns:
        The decoded JSON body of the HTTP 200 response.

    Raises:
        KeyError: If ``index_desc`` is not a key of ``INDEX_BY_DESC``.
        ValueError: If ``API_URL`` is unset, the request times out or fails,
            the server answers with a non-200 status, or the body of a 200
            response is not valid JSON.
    """
    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    index = INDEX_BY_DESC[index_desc]
    data = {
        'source': 'hf' if not DEBUG else 'hf-dev',
        'timestamp': timestamp,
        'query_type': query_type,
        'index': index,
    }
    data.update(kwargs)
    # The outgoing payload is always logged to stdout.
    print(json.dumps(data))

    if API_URL is None:
        raise ValueError('API_URL envvar is not set!')

    try:
        response = requests.post(API_URL, json=data, timeout=30)
    except requests.exceptions.Timeout as e:
        raise ValueError('Web request timed out. Please try again later.') from e
    except requests.exceptions.RequestException as e:
        raise ValueError(f'Web request error: {e}') from e

    if response.status_code != 200:
        # Error bodies are not guaranteed to be JSON (e.g. an HTML gateway
        # page or an empty body). Fall back to the raw text so the HTTP
        # status is reported instead of being masked by a decode error.
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        raise ValueError(f'HTTP error {response.status_code}: {detail}')

    try:
        result = response.json()
    except ValueError as e:
        # Every JSON decode error raised by requests is a ValueError subclass.
        raise ValueError(f'Invalid JSON in response: {e}') from e

    if DEBUG:
        print(result)
    return result
# Compute the Creativity Index of `query` against the corpus chosen by
# `index_desc`, plus one HTML rendering of the tokens per n-gram length.
#
# NOTE(review): this function appears to have been damaged by extraction:
# indentation is gone, the string literals that span two lines below look like
# they once held HTML markup, and the end of the function (including its final
# return) is not visible. Restore it from version control before relying on
# these notes.
def creativity(index_desc, query):
# Delegate the backend call; `query` is forwarded as an extra payload field.
result = process('creativity', index_desc, query=query)
# Latency is shown with three decimals, or left blank if the backend omitted it.
latency = '' if 'latency' not in result else f'{result["latency"]:.3f}'
# On a backend-reported error, put the error message where the index would go
# and return one empty string per n in [NGRAM_LEN_MIN, NGRAM_LEN_MAX].
if 'error' in result:
ci = result['error']
htmls = [''] * (NGRAM_LEN_MAX - NGRAM_LEN_MIN + 1)
return tuple([latency, ci] + htmls)
# presumably rs[l] is the exclusive end index of the longest matched span that
# starts at token l -- TODO confirm against the backend API.
rs = result['rs']
tokens = result['tokens']
highlighteds_by_n = {}
uniqueness_by_n = {}
# For every n-gram length, flag the tokens covered by a span of at least n tokens.
for n in range(NGRAM_LEN_MIN, NGRAM_LEN_MAX + 1):
highlighteds = [False] * len(tokens)
last_r = 0
for l, r in enumerate(rs):
# Spans shorter than n tokens do not count at this n.
if r - l < n:
continue
# Start at max(last_r, l) so positions below last_r are not marked again.
for i in range(max(last_r, l), r):
highlighteds[i] = True
last_r = r
# Uniqueness is the fraction of tokens that were not flagged.
# NOTE(review): this divides by len(tokens), so an empty `tokens` list raises
# ZeroDivisionError.
uniqueness = sum([1 for h in highlighteds if not h]) / len(highlighteds)
highlighteds_by_n[n] = highlighteds
uniqueness_by_n[n] = uniqueness
# The index is the mean uniqueness over all n, formatted as a percentage with
# two decimals.
ci = sum(uniqueness_by_n.values()) / len(uniqueness_by_n)
ci = f'{ci:.2%}'
htmls = []
# Build one rendering per n-gram length.
for n in range(NGRAM_LEN_MIN, NGRAM_LEN_MAX + 1):
html = ''
highlighteds = highlighteds_by_n[n]
line_len = 0
for i, (token, highlighted) in enumerate(zip(tokens, highlighteds)):
# Wrap only at a token starting with the '▁' marker once the current line has
# reached MAX_DISP_CHARS_PER_LINE characters.
# NOTE(review): the literal appended below is split over two lines; it
# presumably contained a line-break tag -- verify.
if line_len >= MAX_DISP_CHARS_PER_LINE and token.startswith('▁'):
html += '
'
line_len = 0
# NOTE(review): `color` is never read in the visible code; presumably it was
# used by highlight markup that has been lost -- verify.
color = '(255, 128, 128, 0.5)'
# The '<0x0A>' token is displayed as a literal backslash-n and forces a line break.
if token == '<0x0A>':
disp_token = '\\n'
is_linebreak = True
else:
disp_token = token.replace('▁', ' ')
is_linebreak = False
# NOTE(review): as visible here both branches append the same text, so
# highlighted tokens are not visually distinguished. Token text is also appended
# without HTML escaping -- confirm whether it needs escaping before rendering.
if highlighted:
html += f'{disp_token}'
else:
html += disp_token
if is_linebreak:
html += '
'
line_len = 0
else:
# Line length counts the raw token, including any '▁' marker.
line_len += len(token)
# NOTE(review): the statement below is truncated; the markup around
# html.strip(' ') and the rest of the function are not visible.
html = '
' + html.strip(' ') + '
Compute the Creativity Index of a piece of text.
The computed Creativity Index is based on verbatim match and is supported by infini-gram.
''' ) with gr.Row(): with gr.Column(scale=1, min_width=240): index_desc = gr.Radio(choices=INDEX_DESCS, label='Corpus', value=INDEX_DESCS[0]) with gr.Column(scale=3): creativity_query = gr.Textbox(placeholder='Enter a piece of text here', label='Input', interactive=True, lines=10) with gr.Row(): creativity_clear = gr.ClearButton(value='Clear', variant='secondary', visible=True) creativity_submit = gr.Button(value='Submit', variant='primary', visible=True) creativity_latency = gr.Textbox(label='Latency (milliseconds)', interactive=False, lines=1) with gr.Column(scale=4): creativity_ci = gr.Label(value='', label='Creativity Index') creativity_htmls = [] for n in range(NGRAM_LEN_MIN, NGRAM_LEN_MAX + 1): with gr.Tab(label=f'n={n}'): creativity_htmls.append(gr.HTML(value='', label=f'n={n}')) creativity_clear.add([creativity_query, creativity_latency, creativity_ci] + creativity_htmls) creativity_submit.click(creativity, inputs=[index_desc, creativity_query], outputs=[creativity_latency, creativity_ci] + creativity_htmls, api_name=False) demo.queue( default_concurrency_limit=DEFAULT_CONCURRENCY_LIMIT, max_size=MAX_SIZE, api_open=False, ).launch( max_threads=MAX_THREADS, debug=DEBUG, show_api=False, )