import streamlit as st

import os
import json
import glob
import re
import random

enable_summary_button = False

prefix_lst = [
  "pgj_d_4096", 
  "pgj_d_2048", 
  "pgj_d_1024_v2", 
  "pgj_d_1024_layer_14", 
  "pgj_d_1024_layer_7", 
  "pgj_d_1024_layer_2", 
  "pgj_d_1024_layer_1" ]

model_names = {
  prefix_lst[0]: 'PatentGPT-J-6B',
  prefix_lst[1]: 'PatentGPT-J-1.6B',
  prefix_lst[2]: 'PatentGPT-J-456M',
  prefix_lst[3]: 'PatentGPT-J-279M',
  prefix_lst[4]: 'PatentGPT-J-191M',
  prefix_lst[5]: 'PatentGPT-J-128M',
  prefix_lst[6]: 'PatentGPT-J-115M',
}

# experiment 3
folder = os.path.join('experiments', 'non_patent')
id_to_scroll = 1  # which of the above to scroll through
first_claim_only = True 

#experiment 2
# folder = os.path.join('experiments', 'ipg20220104_500')
# #folder = "device_serve_results"
# id_to_scroll = 1  # which of the above to scroll through
# first_claim_only = False

# prefix_lst = ["my_gptj_6b_tpu_size_8", "pgj_d_4096", "pgj_d_2048", "pgj_d_1024_layer_14", "pgj_d_1024_layer_7", "pgj_d_1024_layer_2", "pgj_d_1024_layer_1"]
# #, "pgj_large", "pgj_medium", "pgj_small", ]
# # "pgj_d_1024_layer_14"

# experiment 1
# folder = os.path.join('experiments', 'ipg22_500')
# # (previous) folder = "eval_ipg22_500"
# id_to_scroll = 1  # which of the above to scroll through
# first_claim_only = True 

ignore_outscope = True  # when True, tokens ranked outside the top 10 count as fully typed rather than picked

def handle_char_return(text):
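  """Treat the '(none)' placeholder (used for unicorn text) as an empty token."""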
  if text == '(none)':  # unicorn text placeholder
    text = ''

  return text

def calc_details(base_fn):
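  """Load one result JSON from `folder` and accumulate keystroke statistics.

  Each record in result['output'] is expected to carry 'actual_next_token_text',
  'actual_next_token_top_seq' (0-based rank of the actual token) and
  'actual_next_token_top_prob'. Returns an 11-tuple of
  (result, avg_pick, avg_prob, token_count, sum_pick, sum_prob,
  sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len, full_text);
  every element is None when the file does not exist.
  """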
  full_fn = os.path.join(folder, base_fn)
  if not os.path.exists(full_fn):
    # Match the arity of the success-path return so callers can always unpack.
    return (None,) * 11

  with open(full_fn) as f:
    result = json.loads(f.read())
    print("Loaded: %s" % full_fn)

  lst = result['output']
  recv = result['recv']
  sum_pick = 0
  sum_prob = 0 
  sum_outscope_count = 0
  sum_outscope_len = 0
  sum_hit_1 = 0
  sum_top_10_len = 0
  full_text = ''

  token_count = 0
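  # Walk the generated tokens (the final record is excluded) and tally
  # rank-1 hits, in-top-10 picks, and out-of-scope tokens.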
  for i, tk in enumerate(lst[:-1]):
    token_text = handle_char_return(tk['actual_next_token_text'])
    next_top_seq = int(tk['actual_next_token_top_seq'])
    next_top_prob = float(tk['actual_next_token_top_prob']) 

    full_text += token_text
    if next_top_seq == 0:
      sum_hit_1 += 1   # press "tab" for the top pick

    if ignore_outscope and next_top_seq>=10: 
      sum_outscope_count += 1
      sum_outscope_len += len(token_text)  # use length as keystrokes 
    else:
      sum_pick += min(next_top_seq+1, len(token_text))
      #sum_pick += (next_top_seq+1) # press "down" & "tab"
      sum_prob += next_top_prob
      sum_top_10_len += len(token_text)

    token_count += 1

  # Average pick rank and probability per token; guard against an empty token list.
  if token_count == 0:  # unlikely
    avg_pick = 0
    avg_prob = 0
  else:
    avg_pick = float(sum_pick) / token_count
    avg_prob = float(sum_prob) / token_count

  return result, avg_pick, avg_prob, token_count, sum_pick, sum_prob, sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len, full_text

def show_avg(base_fn, model_name, patent_claim_num, show_pick=False): 
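  """Render one result file as color-coded tokens for a single model.

  Each token is highlighted by the rank of the actual next token, preceded by a
  keystroke summary. Returns [rank-1 count, rank 2-10 count, out-of-top-10 count],
  or None when the result file is missing.
  """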
  result, avg_pick, avg_prob, token_count, sum_pick, sum_prob, sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len, full_text = calc_details(base_fn)  

  if result is None:
    return None

  lst = result['output']
  result = ''
  # [background color, text color, rank bucket label]
  colors = [
    ['00ff00', '000000', '1'],
    ['008800', 'ffffff', '2-10'],
    ['ff0000', 'ffffff', 'out of top 10'],
  ]
  # Count tokens per rank bucket; initialized once, before the token loop.
  sum_all = {item[2]: 0 for item in colors}

  for tk in lst[:-1]:

    token_text = handle_char_return(tk['actual_next_token_text'])
    if token_text == '<|end_of_claim|>': 
      break

    if token_text == '(none)': # for unicorn text
      break      

    pick = int(tk['actual_next_token_top_seq'])
    prob = float(tk['actual_next_token_top_prob'])

    if pick == 0:
      bg_color = colors[0][0]
      fg_color = colors[0][1]
      tag = colors[0][2]
      sum_all[tag] += 1
    elif pick >= 1 and pick < 10:
      bg_color = colors[1][0]
      fg_color = colors[1][1]
      tag = colors[1][2]
      sum_all[tag] += 1
    else:  # pick >= 10, i.e. outside the top 10
      bg_color = colors[2][0]
      fg_color = colors[2][1]
      tag = colors[2][2]
      sum_all[tag] += 1
  
    # Optionally append the token's rank after the highlighted token text.
    pick_label = '[%s]' % pick if show_pick else ''

    result += "<span style=background-color:#%s;color:#%s;border-radius:5px;>%s%s</span>  " % (bg_color, fg_color, token_text, pick_label)

  color_msg = ''
  for i, v in enumerate(colors):
    color_msg += "<span style=background-color:#%s;color:#%s;border-radius:5px;>&nbsp;%s&nbsp;</span> " % (v[0], v[1], v[2]) 


  # Keystrokes with autocomplete: picks within the top 10 (capped at token length)
  # plus the characters typed in full for tokens that fell outside the top 10.
  keys_with_auto = sum_pick + sum_outscope_len
  keys_without_auto = len(full_text)
  saved_ratio = float(keys_without_auto - keys_with_auto) / keys_without_auto * 100

  s = 'model: %s\n' \
    'Autocomplete Effectiveness: %.1f%% (keystrokes saved)\n' \
    'Total keystrokes: %s (with autocomplete), %s (without autocomplete)\n' \
    'Keystroke distribution: rank 1~10: %s (rank 1: %s), out of top 10: %s' % (model_name, saved_ratio, keys_with_auto, keys_without_auto,  sum_pick, sum_hit_1, sum_outscope_len)
  st.text(s)
  st.markdown(color_msg, unsafe_allow_html=True)
  st.markdown(result, unsafe_allow_html=True)
  sum_lst = [sum_all['1'], sum_all['2-10'], sum_all['out of top 10']]

  return sum_lst

def show_overall_summary(prefix_lst, select_lst):  
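  """Aggregate keystroke statistics across all selected claims for each model
  and print a plain-text summary plus a LaTeX table row."""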
  for prefix in prefix_lst:
    acc_token_count = 0
    acc_sum_pick = 0
    acc_sum_prob = 0 
    acc_sum_outscope_count = 0 
    acc_sum_outscope_len = 0 
    acc_sum_hit_1 = 0
    acc_sum_top_10_len = 0 
    acc_full_text_len = 0

    for num in select_lst:
      base_fn = '%s_%s_forward.json' % (prefix, num)
      result, avg_pick, avg_prob, token_count, sum_pick, sum_prob, sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len, full_text = calc_details(base_fn)
      if result is None:
        continue

      acc_token_count += token_count
      acc_sum_pick += sum_pick
      acc_sum_prob += sum_prob
      acc_sum_outscope_count += sum_outscope_count
      acc_sum_outscope_len += sum_outscope_len
      acc_sum_hit_1 += sum_hit_1
      acc_sum_top_10_len += sum_top_10_len
      acc_full_text_len += len(full_text)

    if acc_token_count > 0:
      # acc_sum_pick --> top 1~10
      keys_with_auto = acc_sum_pick + acc_sum_outscope_len
      keys_without_auto = acc_full_text_len
      saved_ratio = float(keys_without_auto-keys_with_auto)/keys_without_auto * 100

      st.text('[ %s ]\n' \
        'Autocomplete Effectiveness: %.1f%% (keystrokes saved)\n' \
        '(sum) keys_with_auto: %s, top_10_keys: %s, out_of_scope: %s, sum_hit_1: %s\n' \
        'keys_without_auto: %s, top_10_len: %s, prob: %.2f' % (
        model_names[prefix], saved_ratio, 
        '{:,}'.format(keys_with_auto),
        '{:,}'.format(acc_sum_pick), 
        '{:,}'.format(acc_sum_outscope_len),
        '{:,}'.format(acc_sum_hit_1), 
        '{:,}'.format(keys_without_auto),
        '{:,}'.format(acc_sum_top_10_len), 
        acc_sum_prob, 
        ))

      st.text('%s & %.1f\\%% & %s & %s & %s & %s & %s \\\\' % (model_names[prefix], saved_ratio, '{:,}'.format(keys_with_auto), '{:,}'.format(acc_sum_pick), '{:,}'.format(acc_sum_outscope_len), '{:,}'.format(acc_sum_hit_1), '{:,}'.format(keys_without_auto)))

      # st.text('* acc_token_count =%s --> (avg) hits: %.2f, keys: %.2f, prob: %.2f, outscope: %.2f' % (
      #     acc_token_count, 
      #     float(acc_sum_hit_1)/acc_token_count,
      #     float(acc_sum_pick)/acc_token_count, 
      #     float(acc_sum_prob)/acc_token_count, 
      #     float(acc_sum_outscope_count)/acc_token_count))

def main():
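  """Streamlit entry point: collect the claims covered by every model, pick one
  at random, and render each model's autocomplete view for it."""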
  st.set_page_config(
    layout="wide",                 # "centered" or "wide"
    initial_sidebar_state="auto",  # "auto", "expanded", or "collapsed"
    page_title="PatentGPT-J demo",
    page_icon=None,
  )
  st.subheader("PatentGPT-J Demo 3 (Autocomplete Effectiveness)")
  st.text("Data coverage: unicorn text")

  num_set = set()
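  # Collect every "<id>_<claim>" identifier that appears in the result files.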
  fn_lst = glob.glob(os.path.join(folder, '*'))
  for fn in fn_lst:
    for prefix in prefix_lst:
      v = re.search(r'(.*?)%s_(\d+_\d+)_(.*?)' % prefix, fn)
      if v is None:
        v = re.search(r'(.*?)%s_(\w+_\d+)_(.*?)' % prefix, fn)
        if v is None:
          continue

      v = v.group(2)
      if first_claim_only:
        if v.endswith('_1'):
          num_set.add(v)
      else: 
        num_set.add(v)

  num_lst = list(num_set)
  num_lst.sort()

  select_lst = []
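  # Keep only the claims for which every model has a result file.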
  for i, num in enumerate(num_lst):
    all_existed = True
    for prefix in prefix_lst:
      fn = os.path.join(folder, '%s_%s_forward.json' % (prefix, num))
      if not os.path.exists(fn):
        all_existed = False
        break
    if all_existed: 
      select_lst.append(num)
  select_lst.sort()

  show_patent_lst = [ s.replace('_', ' (claim ') + ')' for s in select_lst]
  if len(select_lst) == 0:
    st.error('No result files found under: %s' % folder)
    return

  # Show one randomly selected claim per page load.
  pick = random.randrange(len(select_lst))
  num = select_lst[pick]


  avgs = []
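  # Render each model's color-coded view of the same randomly selected claim.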
  for prefix in prefix_lst:
    base_fn = '%s_%s_forward.json' % (prefix, num)
    one_avg = show_avg(base_fn, model_names[prefix], num)
    if one_avg is not None:
      avgs.append(one_avg) 

if __name__ == "__main__":
  main()