import re

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

from resources import banner, error_html_response

# NOTE(review): in this copy of the file every special-token string literal
# (the entries of `special_tokens`, the arguments to find()/rfind(), the regex
# delimiters and the prompt pieces built in make_recipe) is the empty string
# '', and the HTML templates are empty or reduced to list-bullet text. This
# looks like markup stripping rather than intent: with '' tokens,
# `s.find('')` is always 0 and `s.rfind('')` is len(s), so
# check_special_tokens_order() always returns False and no generation can
# ever be accepted. Restore the literals from the model's tokenizer vocabulary
# and the original HTML before deploying.

model_checkpoint = 'gastronomia-para-to2/gastronomia_para_to2'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForCausalLM.from_pretrained(model_checkpoint)

# Tokens that must all be present in an accepted generation
# (checked by rerun_model_output).
special_tokens = [
    '', '', '', '', '', '', '', '', '', '', '']

# Upper bound on model.generate() calls per generate_output() call. The
# original loop had no bound and could block the request handler forever.
MAX_GENERATION_ATTEMPTS = 10
# Number of generate_output() results make_recipe() tries before showing the
# error page (same cutoff as the original `if i == 4` check).
MAX_RECIPE_ATTEMPTS = 4


def frame_html_response(html_response):
    """Wrap a recipe HTML fragment for the Gradio HTML output.

    NOTE(review): the template is empty in this copy, so this returns '' and
    ignores `html_response`; the original markup appears to have been
    stripped and must be restored.
    """
    return f""""""


def check_special_tokens_order(pre_output):
    """Return True if the special tokens occur in the expected order.

    One chained comparison over token positions. For the three tokens checked
    with both find() and rfind(), every occurrence must lie after the first
    occurrence of the previous token and before the first occurrence of the
    next one.
    """
    return (pre_output.find('')
            < pre_output.find('')
            <= pre_output.rfind('')
            < pre_output.find('')
            < pre_output.find('')
            < pre_output.find('')
            <= pre_output.rfind('')
            < pre_output.find('')
            < pre_output.find('')
            < pre_output.find('')
            <= pre_output.rfind('')
            < pre_output.find('')
            < pre_output.find('')
            < pre_output.find(''))


def make_html_response(title, ingredients, instructions):
    """Render a recipe title, ingredient list and instruction list.

    Args:
        title: Recipe title text.
        ingredients: List of ingredient strings.
        instructions: List of instruction strings.

    Returns:
        The rendered recipe text.

    NOTE(review): the list prefixes ('  • ', '  1. ', '  2. ') look like the
    rendered remains of stripped list markup. As written, every instruction
    after the first is prefixed '2.'. Restore the original tags.
    """
    ingredients_html_list = '\n  • ' + '\n  • '.join(ingredients) + '\n'
    instructions_html_list = '\n  1. ' + '\n  2. '.join(instructions) + '\n'
    html_response = f'''

{title}

Ingredientes

{ingredients_html_list}

Instrucciones

{instructions_html_list} '''
    return html_response


def rerun_model_output(pre_output):
    """Return True if `pre_output` must be regenerated, False if acceptable.

    A candidate is rejected when it is None, lacks the token used to trim the
    output, is missing any entry of `special_tokens` before that point, has
    the tokens out of order, or has fewer than 75 whitespace-separated words.
    """
    if pre_output is None:
        return True
    elif '' not in pre_output:
        print(' not in pre_output')
        return True
    pre_output_trimmed = pre_output[:pre_output.find('')]
    if not all(special_token in pre_output_trimmed for special_token in special_tokens):
        print('Not all special tokens are in preoutput')
        return True
    elif not check_special_tokens_order(pre_output_trimmed):
        print('Special tokens are unordered in preoutput')
        return True
    elif len(pre_output_trimmed.split()) < 75:
        print('Length of the recipe is <75')
        return True
    else:
        return False


def generate_output(tokenized_input, max_attempts=MAX_GENERATION_ATTEMPTS):
    """Sample recipes until one passes rerun_model_output's checks.

    Each model.generate() call samples three sequences, and every one of them
    is decoded and checked. The original only looked at the first sequence.

    Args:
        tokenized_input: Tokenizer output (keyword arguments for generate()).
        max_attempts: Maximum number of generate() calls before giving up.

    Returns:
        The first accepted decoded text, truncated before the trimming token,
        or None if no candidate passed within `max_attempts` calls.
    """
    for _ in range(max_attempts):
        output = model.generate(
            **tokenized_input,
            max_length=600,
            do_sample=True,
            top_p=0.92,
            top_k=50,
            # no_repeat_ngram_size=2,
            num_return_sequences=3,
        )
        for sequence in output:
            pre_output = tokenizer.decode(sequence, skip_special_tokens=False)
            if not rerun_model_output(pre_output):
                return pre_output[:pre_output.find('')]
    return None


def check_wrong_ingredients(ingredients):
    """Return True if the parsed ingredient list is missing or rejected.

    An ingredient starting with 'De' is treated as a generation defect.
    NOTE(review): presumably a fragment split off a longer ingredient name;
    confirm.
    """
    if ingredients is None:
        return True
    if any(ingredient.startswith('De') for ingredient in ingredients):
        print('At least one ingredient starts with De')
        return True
    return False


def make_recipe(input_ingredients):
    """Gradio handler: turn user ingredients into a framed recipe HTML page.

    Args:
        input_ingredients: Comma-separated ingredients typed by the user;
            " y " is also accepted as a separator.

    Returns:
        frame_html_response() of the generated recipe, or of
        `error_html_response` if no acceptable recipe is produced.
    """
    # Treat the Spanish conjunction " y " ("and") as another separator.
    input_ingredients = re.sub(' y ', ', ', input_ingredients)
    # Split on bare commas and drop blanks, so "a,b" or a trailing comma does
    # not yield merged or empty ingredients.
    ingredient_list = [part.strip() for part in input_ingredients.split(',') if part.strip()]
    prompt = ' '
    prompt += ' ' + ' '.join(ingredient_list) + ' '
    prompt += ' '
    tokenized_input = tokenizer(prompt, return_tensors='pt')

    for _ in range(MAX_RECIPE_ATTEMPTS):
        pre_output_trimmed = generate_output(tokenized_input)
        if pre_output_trimmed is None:
            # generate_output exhausted its budget without a valid candidate.
            return frame_html_response(error_html_response)
        raw_ingredients = re.search(' (.*) ', pre_output_trimmed).group(1).split(' ')
        # dict.fromkeys de-duplicates while keeping the generated order
        # (list(set(...)) made the displayed order arbitrary).
        output_ingredients = [
            ingredient.capitalize()
            for ingredient in dict.fromkeys(item.strip() for item in raw_ingredients)
        ]
        if not check_wrong_ingredients(output_ingredients):
            break
    else:
        return frame_html_response(error_html_response)

    output_title = re.search(' (.*) ', pre_output_trimmed).group(1).strip().capitalize()
    output_instructions = re.search(' (.*) ', pre_output_trimmed).group(1).split(' ')
    html_response = make_html_response(output_title, output_ingredients, output_instructions)
    return frame_html_response(html_response)


# NOTE(review): gr.inputs / gr.outputs and launch(enable_queue=...) are legacy
# Gradio APIs (removed in Gradio 4); pin the Gradio version or migrate to
# gr.Textbox / gr.HTML and .queue().
iface = gr.Interface(
    fn=make_recipe,
    inputs=[
        gr.inputs.Textbox(lines=1, placeholder='ingrediente_1, ingrediente_2, ..., ingrediente_n',
                          label='Dime con qué ingredientes quieres que cocinemos hoy y te sugeriremos una receta tan pronto como nuestros fogones estén libres'),
    ],
    outputs=[
        gr.outputs.HTML(label="¡Esta es mi propuesta para ti! ¡Buen provecho!")
    ],
    examples=[
        ['salmón, zumo de naranja, aceite de oliva, sal, pimienta'],
        ['harina, azúcar, huevos, chocolate, levadura Royal'],
    ],
    description=banner)

iface.launch(enable_queue=True)