File size: 167,048 Bytes
4b7f9b6
1
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"mount_file_id":"1InRgLMZPkNA6-Y62poIwyeNylGY8rYDk","authorship_tag":"ABX9TyMQONHdmulQ0Lve0gsrzSqR"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"widgets":{"application/vnd.jupyter.widget-state+json":{"7d50cc4d70a8453ea1023c081dd107b6":{"model_module":"@jupyter-widgets/controls","model_name":"VBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"VBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"VBoxView","box_style":"","children":["IPY_MODEL_d84e05020333431296fd65c7c3be9b2a","IPY_MODEL_b2f4e6f16f0a4d3b90717d0179ed642a","IPY_MODEL_3084ffb3b3664cfbb6b9931bea153f83","IPY_MODEL_68fd6d6b9e6c4adeb5431e29e735713d","IPY_MODEL_1d9521b2ef344ec4a8bf5a4d270d802b"],"layout":"IPY_MODEL_7a6f7ecb69c24b6583c3e55a8b22f4c8"}},"d84e05020333431296fd65c7c3be9b2a":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_5998e43917dc48618e326cfa5d148e79","placeholder":"​","style":"IPY_MODEL_b01977fca14b443d9f10b7aea7575c7a","value":"<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.svg\nalt='Hugging Face'> <br> Copy a token from <a\nhref=\"https://huggingface.co/settings/tokens\" target=\"_blank\">your Hugging Face\ntokens page</a> and paste it below. <br> Immediately click login after copying\nyour token or it might be stored in plain text in this notebook file. </center>"}},"b2f4e6f16f0a4d3b90717d0179ed642a":{"model_module":"@jupyter-widgets/controls","model_name":"PasswordModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"PasswordModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"PasswordView","continuous_update":true,"description":"Token:","description_tooltip":null,"disabled":false,"layout":"IPY_MODEL_e636d0177e1e4d0a8b65896085538ef0","placeholder":"​","style":"IPY_MODEL_471c56e7723e4bb892ac8bfe97523304","value":""}},"3084ffb3b3664cfbb6b9931bea153f83":{"model_module":"@jupyter-widgets/controls","model_name":"CheckboxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"CheckboxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"CheckboxView","description":"Add token as git credential?","description_tooltip":null,"disabled":false,"indent":true,"layout":"IPY_MODEL_6c5f07b8a0274ea48baf5825c496bfb1","style":"IPY_MODEL_219ec2d07f764f9fb137534c89682b7f","value":true}},"68fd6d6b9e6c4adeb5431e29e735713d":{"model_module":"@jupyter-widgets/controls","model_name":"ButtonModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ButtonModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ButtonView","button_style":"","description":"Login","disabled":false,"icon":"","layout":"IPY_MODEL_4abf542d1bd54db491bd53a4bd6ae39f","style":"IPY_MODEL_b8e6ba6cafd64e43a7c6dd40fc6cb377","tooltip":""}},"1d9521b2ef344ec4a8bf5a4d270d802b":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_a41054d6970c4fc2989e8942c1e8373f","placeholder":"​","style":"IPY_MODEL_621bc786ebaa46f3bb5e66312644a65c","value":"\n<b>Pro Tip:</b> If you don't already have one, you can create a dedicated\n'notebooks' token with 'write' access, that you can then easily reuse for all\nnotebooks. </center>"}},"7a6f7ecb69c24b6583c3e55a8b22f4c8":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":"center","align_self":null,"border":null,"bottom":null,"display":"flex","flex":null,"flex_flow":"column","grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"50%"}},"5998e43917dc48618e326cfa5d148e79":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"b01977fca14b443d9f10b7aea7575c7a":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"e636d0177e1e4d0a8b65896085538ef0":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"471c56e7723e4bb892ac8bfe97523304":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"6c5f07b8a0274ea48baf5825c496bfb1":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"219ec2d07f764f9fb137534c89682b7f":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"4abf542d1bd54db491bd53a4bd6ae39f":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"b8e6ba6cafd64e43a7c6dd40fc6cb377":{"model_module":"@jupyter-widgets/controls","model_name":"ButtonStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ButtonStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","button_color":null,"font_weight":""}},"a41054d6970c4fc2989e8942c1e8373f":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"621bc786ebaa46f3bb5e66312644a65c":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}}},"cells":[{"cell_type":"markdown","source":["## Preparing the enviroment"],"metadata":{"id":"ifSzFnZl1M5V"}},{"cell_type":"code","execution_count":1,"metadata":{"id":"Eyy-XD2tfVzx","executionInfo":{"status":"ok","timestamp":1720530566765,"user_tz":240,"elapsed":16497,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"outputs":[],"source":["import torch\n","from torch import nn\n","from torch.nn import functional as F\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import time\n","import pandas as pd\n","from collections import OrderedDict\n","from torch.utils.data import DataLoader, Dataset\n","from torch.utils.tensorboard import SummaryWriter\n","from datetime import datetime"]},{"cell_type":"code","source":["# diretório do córpus\n","%cd /content/drive/MyDrive/Mestrado/NLP_2024/TP_LLM"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"W6fVNbBU2lzH","executionInfo":{"status":"ok","timestamp":1720530569323,"user_tz":240,"elapsed":892,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"10c239a1-5353-4e99-9669-6de81321be7e"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["/content/drive/MyDrive/Mestrado/NLP_2024/TP_LLM\n"]}]},{"cell_type":"code","source":["# model hyperparameters\n","config = {\n","    'd_model': 64,\n","    'n_heads': 8,\n","    'n_layers': 6,\n","    'context_window': 22,\n","    'epochs': 100,\n","    'log_interval': 10,\n","}"],"metadata":{"id":"jz-w5gh2bUZ0","executionInfo":{"status":"ok","timestamp":1720530570641,"user_tz":240,"elapsed":4,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":3,"outputs":[]},{"cell_type":"markdown","source":["## Carregamento e pré-processamento dos dados\n","\n","---\n","Nessa seção, ocorrem as seguites etapas:\n","\n","\n","1.   Teinamento do tokenizador no córpus. O tokenizador utilizado utilizado foi o `sentencepiece` com o algoritmo `Byte-pair encoding`.\n","2.   Todo o córpus é tokenizado, e o vocabulário é construído a partir do conjunto de tokens originado.\n","3. O dataset é montado a partir da tokenização do córpus.\n","4. Partição do dataset em treino, teste e validação.\n","5. Transformação dessas partições para o tipo `Dataloader` do Pytorch, a fim de facilitar o treinamento em batches.\n","\n","\n","\n"],"metadata":{"id":"OSAlYt2R2hm7"}},{"cell_type":"code","source":["# ! pip install sentencepiece\n","import sentencepiece as spm"],"metadata":{"id":"RzBkAXYlfDjS","executionInfo":{"status":"ok","timestamp":1720530626588,"user_tz":240,"elapsed":355,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":15,"outputs":[]},{"cell_type":"code","source":["# treinamento do tokenizador SentencePiece\n","spm.SentencePieceTrainer.Train('--input=./input.txt --model_prefix=tiny_shakespeare --vocab_size=8000 --model_type=bpe')"],"metadata":{"id":"5sx2WRFzrMRX"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# carregamento do tokenizador\n","tokenizer = spm.SentencePieceProcessor()\n","tokenizer.load('tiny_shakespeare.model')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"37PjjLlvr2xC","executionInfo":{"status":"ok","timestamp":1720530629498,"user_tz":240,"elapsed":491,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"a4ec245a-050c-47ea-89d1-735471b7417d"},"execution_count":16,"outputs":[{"output_type":"execute_result","data":{"text/plain":["True"]},"metadata":{},"execution_count":16}]},{"cell_type":"code","source":["# construção do vocabulário (criação do vocabulário)\n","lines = open('./input.txt', 'r').read()  # Lê o conteúdo do arquivo de entrada\n","\n","vocab = sorted(set(tokenizer.EncodeAsPieces(lines)))  # Codifica o texto e cria um conjunto único e ordenado de tokens\n","itos = {i:ch for i, ch in enumerate(vocab)}  # Dicionário que mapeia índices para tokens (índice -> caractere)\n","stoi = {ch:i for i, ch in enumerate(vocab)}  # Dicionário que mapeia tokens para índices (caractere -> índice)\n","print('tamanho do vocabulário:', len(vocab))  # Imprime o tamanho do vocabulário (número de tokens únicos)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Z54bTde_tG3H","executionInfo":{"status":"ok","timestamp":1720530635292,"user_tz":240,"elapsed":2894,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"318deb54-485c-49ab-b938-b068783ddf8a"},"execution_count":17,"outputs":[{"output_type":"stream","name":"stdout","text":["tamanho do vocabulário: 7268\n"]}]},{"cell_type":"code","source":["def encode(s):\n","    # Codifica uma string `s` em uma lista de índices inteiros.\n","    return [stoi[token] for token in tokenizer.EncodeAsPieces(s)]\n","\n","def decode(l):\n","    # Decodifica uma lista de índices inteiros `l` de volta para uma string.\n","    return ''.join([itos[i] for i in l])"],"metadata":{"id":"jbxWLO-pilFX","executionInfo":{"status":"ok","timestamp":1720530637982,"user_tz":240,"elapsed":345,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":18,"outputs":[]},{"cell_type":"code","source":["encode(\"hello\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"o8JcNdqZfDr7","executionInfo":{"status":"ok","timestamp":1720495947860,"user_tz":240,"elapsed":309,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"d21dbd92-4c9b-42b0-81dc-526f32fb089f"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["[4670, 1160]"]},"metadata":{},"execution_count":41}]},{"cell_type":"code","source":["config"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"v6W_nS-TaaBI","executionInfo":{"status":"ok","timestamp":1720494835300,"user_tz":240,"elapsed":419,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"727f1f58-935c-4acb-f7cb-dda402bc7731"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["{'d_model': 64,\n"," 'n_heads': 8,\n"," 'n_layers': 6,\n"," 'context_window': 22,\n"," 'epochs': 100,\n"," 'log_interval': 10,\n"," 'vocab_size': 7268,\n"," 'batch_size': 32}"]},"metadata":{},"execution_count":112}]},{"cell_type":"code","source":["# construção do dataset a partir do córpus tokenizado e update das configs com o tamanho do vocbulário\n","config.update({\n","   \"vocab_size\": len(vocab),\n","})\n","\n","dataset = torch.tensor(encode(lines), dtype=torch.int16)\n","dataset.shape"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"9cWFh_-w4N4j","executionInfo":{"status":"ok","timestamp":1720530649649,"user_tz":240,"elapsed":333,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"28205fe6-19ae-4552-ba99-9bb4e8fc646d"},"execution_count":19,"outputs":[{"output_type":"execute_result","data":{"text/plain":["torch.Size([279996])"]},"metadata":{},"execution_count":19}]},{"cell_type":"code","source":["# Define a função para montar os splits de teste, treino e validação do dataset\n","def mount_dataset_splits(data, context_window, config=config):\n","\n","  # Calcula o tamanho de 80% do dataset para o treino\n","  train = data[:int(.8 * len(data))]\n","  # Calcula o tamanho de 10% do dataset para validação (entre 80% e 90%)\n","  val = data[int(.8 * len(data)): int(.9 * len(data))]\n","  # O restante do dataset (10%) fica para teste\n","  test = data[int(.9 * len(data)):]\n","\n","  # Dicionário para armazenar os splits\n","  data_splitted = {}\n","\n","  # Itera por cada split ('train', 'val', 'test')\n","  for split in ['train', 'val', 'test']:\n","    # Define os dados do batch de acordo com o split\n","    batch_data = train\n","    if split == 'val':\n","        batch_data = val  # Seleciona dados de validação\n","\n","    if split == 'test':\n","        batch_data = test  # Seleciona dados de teste\n","\n","    # Cria indices para janelas de contexto (tamanho do batch - janela - 1)\n","    ix = torch.arange(0, batch_data.size(0)-context_window-1, context_window)\n","\n","    # Cria o tensor 'x' com as janelas de contexto (entradas)\n","    x = torch.stack([batch_data[i:i+context_window] for i in ix]).long()\n","\n","    # Cria o tensor 'y' com os próximos valores (saídas)\n","    y = torch.stack([batch_data[i+1:i+context_window+1] for i in ix]).long()\n","\n","    # Adiciona o par (entradas, saídas) no dicionário para o split atual\n","    data_splitted[split] = (x, y)\n","\n","  # Retorna os splits de treino, validação e teste\n","  return data_splitted['train'], data_splitted['val'], data_splitted['test']\n"],"metadata":{"id":"VswFFIHD199G","executionInfo":{"status":"ok","timestamp":1720531305811,"user_tz":240,"elapsed":340,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":33,"outputs":[]},{"cell_type":"code","source":["class TorchDataset(Dataset):\n","    # Classe personalizada para encapsular um dataset do PyTorch.\n","    # Recebe um dataset existente e o encapsula para ser utilizado com o PyTorch.\n","\n","    def __init__(self, dataset):\n","        # Método construtor da classe.\n","        # dataset: O dataset existente que será encapsulado.\n","        self.dataset = dataset\n","\n","    def __len__(self):\n","        # Define o tamanho do dataset.\n","        # Retorna o comprimento da segunda lista dentro do dataset encapsulado, que geralmente representa os alvos.\n","        return len(self.dataset[1])  # Comprimento da lista de alvos\n","\n","    def __getitem__(self, idx):\n","        # Recupera um item específico do dataset.\n","        # idx: O índice do item a ser recuperado.\n","        # Retorna uma tupla contendo a entrada (primeira lista) e o alvo (segunda lista) correspondentes ao índice fornecido.\n","        input = self.dataset[0][idx]  # Entrada no índice especificado\n","        target = self.dataset[1][idx]  # Alvo no índice especificado\n","        return input, target\n","\n"],"metadata":{"id":"Cg1iyhi7PW8b","executionInfo":{"status":"ok","timestamp":1720531307332,"user_tz":240,"elapsed":4,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":34,"outputs":[]},{"cell_type":"code","source":["config.update({\n","    'batch_size': 32,\n","})"],"metadata":{"id":"kERNaCcUbtsM","executionInfo":{"status":"ok","timestamp":1720531308931,"user_tz":240,"elapsed":2,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":35,"outputs":[]},{"cell_type":"code","source":["# Divide o dataset em conjuntos de treino, validação e teste\n","train, test, val = mount_dataset_splits(dataset, config['context_window'])\n","\n","# Converte os conjuntos para o formato TorchDataset\n","train = TorchDataset(train)\n","test = TorchDataset(test)\n","val = TorchDataset(val)\n","\n","# Imprime o tamanho dos conjuntos de treino e validação\n","print('O conjunto de treino possui {} instâncias'.format(len(train)))\n","print('O conjunto de validação possui {} instâncias'.format(len(val)))\n","\n","# Carregadores de dados (dataloaders) para treino, validação e teste\n","train_dataloader = torch.utils.data.DataLoader(train, config['batch_size'], shuffle=True)\n","validation_loader = torch.utils.data.DataLoader(val, config['batch_size'], shuffle=False)\n","test_dataloader = torch.utils.data.DataLoader(test, config['batch_size'], shuffle=True)\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"D_q5Je_fc0h_","executionInfo":{"status":"ok","timestamp":1720531310551,"user_tz":240,"elapsed":488,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"1bd814ce-35e1-478b-f290-6bb69aca0a97"},"execution_count":36,"outputs":[{"output_type":"stream","name":"stdout","text":["O conjunto de treino possui 10181 instâncias\n","O conjunto de validação possui 1272 instâncias\n"]}]},{"cell_type":"code","source":["train_features, train_labels = next(iter(train_dataloader))\n","print(f\"Feature batch shape: {train_features.size()}\")\n","print(f\"Labels batch shape: {train_labels.size()}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"NMqX3YkIYDY5","executionInfo":{"status":"ok","timestamp":1720531313098,"user_tz":240,"elapsed":347,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"dd90cc18-a435-4dc7-c54c-dc413f666b98"},"execution_count":37,"outputs":[{"output_type":"stream","name":"stdout","text":["Feature batch shape: torch.Size([32, 22])\n","Labels batch shape: torch.Size([32, 22])\n"]}]},{"cell_type":"markdown","source":["## Defining the model"],"metadata":{"id":"aEZ52i4Z_VrZ"}},{"cell_type":"markdown","source":["### Componentes Llama\n","---\n","Implementação dos componentes principais da arquietetura do Llama:\n","\n","\n","* RMSNorm\n","* SwiGLU\n","* Rotary Embeddings\n","\n","\n"],"metadata":{"id":"p5kj5Yr1XJiJ"}},{"cell_type":"markdown","source":["RMSNorm"],"metadata":{"id":"7culAQVIXRGX"}},{"cell_type":"code","source":["class RMSNorm(torch.nn.Module):\n","  # Classe de normalização RMS (Root Mean Square).\n","\n","  # Atributos:\n","  #     input_shape (tupla): Dimensão da entrada do tensor.\n","  #     eps (float, opcional): Valor pequeno para evitar divisão por zero.\n","  #         Padrão é 1e-6.\n","  #     gi (nn.Parameter): Vetor de pesos aprendíveis, com o mesmo tamanho da\n","  #         entrada.\n","\n","  def __init__(self, input_shape, eps = 1e-6):\n","      super().__init__()\n","      self.eps = eps  # Valor para evitar divisão por zero\n","      self.input_shape = input_shape\n","      self.gi = nn.Parameter(torch.ones(input_shape))  # Vetor de pesos aprendíveis\n","      self.register_parameter(\"gi\", self.gi)\n","\n","  def forward(self, input_tensor):\n","    # Realiza a normalização RMS na entrada.\n","\n","    # Parâmetros:\n","    #     input_tensor (torch.Tensor): Tensor de entrada.\n","\n","    # Retorno:\n","    #     torch.Tensor: Tensor normalizado.\n","    norm_input = input_tensor.norm(2, dim=(1,2))  # Normalização longo dos eixos 1 e 2\n","    rms_input = norm_input * input_tensor[0].numel() ** (-1. / 2)  # Raiz quadrada média\n","\n","    # Normalização RMS com adição de epsilon para evitar divisão por zero\n","    input_normed = input_tensor / (rms_input.unsqueeze(-1).unsqueeze(-1) + self.eps)\n","    return input_normed * self.gi  # Aplica pesos aprendíveis\n"],"metadata":{"id":"AFfkupixbqPs","executionInfo":{"status":"ok","timestamp":1720530582239,"user_tz":240,"elapsed":368,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":4,"outputs":[]},{"cell_type":"markdown","source":["Rede Neural Feedforward com função de ativação SwuiGLU"],"metadata":{"id":"ztzrlx5aAouQ"}},{"cell_type":"code","source":["class FFN_SwiGLU(nn.Module):\n","\n","    def __init__(self, dim) -> None:\n","        super().__init__()\n","        # Camada linear 1: dim entradas -> dim saídas, sem viés (bias=False)\n","        self.w1 = nn.Linear(dim, dim, bias=False)\n","        # Camada linear 2: dim entradas -> dim saídas, sem viés (bias=False)\n","        self.w2 = nn.Linear(dim, dim, bias=False)\n","        # Camada linear 3: dim entradas -> dim saídas, sem viés (bias=False)\n","        self.w3 = nn.Linear(dim, dim, bias=False)\n","\n","    def forward(self, input_tensor):\n","        # Saída da camada linear 1\n","        x1 = F.linear(input_tensor, self.w1.weight)\n","        # Saída da camada linear 3\n","        x3 = F.linear(input_tensor, self.w3.weight)\n","        # Aplicação da função de ativação SiLU em x1 e multiplicação por x3\n","        hidden = F.silu(x1) * x3\n","        # Saída final: camada linear 2 aplicada em hidden\n","        return self.w2(hidden)\n"],"metadata":{"id":"pLzR_ZpHpdYh","executionInfo":{"status":"ok","timestamp":1720530584146,"user_tz":240,"elapsed":449,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":5,"outputs":[]},{"cell_type":"markdown","source":["Camada Multi-Attention com embeddings rotatórios"],"metadata":{"id":"mOWguICSBWpr"}},{"cell_type":"code","source":["def get_rotary_matrix(context_window, embedding_dim):\n","  # Esta função cria uma matriz de rotação para modelagem de posição.\n","\n","  # Args:\n","  #     context_window: Tamanho da janela de contexto (número de posições).\n","  #     embedding_dim: Dimensão do embedding.\n","\n","  # Returns:\n","  #     Uma matriz de rotação com dimensão (context_window, embedding_dim, embedding_dim).\n","\n","  R = torch.zeros((context_window, embedding_dim, embedding_dim), requires_grad=False)\n","  for position in range(context_window):\n","    # Loop para cada posição na janela de contexto.\n","    for i in range(embedding_dim // 2):\n","      # Loop para metade da dimensão do embedding (matriz rotacional é diagonal por blocos).\n","      theta = 10000. ** (-2. * (i - 1) / embedding_dim)\n","      m_theta = position * theta\n","      R[position, 2 * i, 2 * i] = np.cos(m_theta)\n","      R[position, 2 * i, 2 * i + 1] = -np.sin(m_theta)\n","      R[position, 2 * i + 1, 2 * i] = np.sin(m_theta)\n","      R[position, 2 * i + 1, 2 * i + 1] = np.cos(m_theta)\n","  return R\n"],"metadata":{"id":"8huMU-TA86t8","executionInfo":{"status":"ok","timestamp":1720530584531,"user_tz":240,"elapsed":3,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":6,"outputs":[]},{"cell_type":"code","source":["\n","class RopeAttentionHead(nn.Module):\n","  def __init__(self, config):\n","    super().__init__()\n","\n","    # Camada linear para mapeamento de consulta (query)\n","    self.w_q = nn.Linear(config['d_model'], config['d_model'], bias=False)\n","\n","    # Camada linear para mapeamento de chave (key)\n","    self.w_k = nn.Linear(config['d_model'], config['d_model'], bias=False)\n","\n","    # Camada linear para mapeamento de valor (value)\n","    self.w_v = nn.Linear(config['d_model'], config['d_model'], bias=False)\n","\n","    # Matriz de rotação para atenção posicional relativa\n","    self.R = get_rotary_matrix(config['context_window'], config['d_model'])\n","\n","  def forward(self, x):\n","    # x é um lote de sentenças\n","    b, m, d = x.shape  # b = tamanho do lote, m = tamanho da sequência, d = dimensão do embedding\n","\n","    q = self.w_q(x)  # Gera representações de consulta\n","    k = self.w_k(x)  # Gera representações de chave\n","    v = self.w_v(x)  # Gera representações de valor\n","\n","    # Aplica rotação posicional relativa usando a matriz R\n","    q_rotated = (torch.bmm(q.transpose(0, 1), self.R[:m])).transpose(0, 1)\n","    k_rotated = (torch.bmm(k.transpose(0, 1), self.R[:m])).transpose(0, 1)\n","\n","    # Calcula a atenção usando scaled dot-product attention com dropout\n","    a = F.scaled_dot_product_attention(q_rotated, k_rotated, v, dropout_p=.1, is_causal=True)\n","\n","    return a\n"],"metadata":{"id":"-5svowDC7Q_k","executionInfo":{"status":"ok","timestamp":1720530585334,"user_tz":240,"elapsed":2,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["class MultiheadRopeAttention(nn.Module):\n","  def __init__(self, config):\n","    super().__init__()\n","\n","    # Define o número de heads de atenção (cabeças de atenção)\n","    self.heads = nn.ModuleList([RopeAttentionHead(config) for _ in range(config['n_heads'])])\n","\n","    # Camada linear para projetar a saída concatenada das heads de volta para a dimensão original\n","    self.w0 = nn.Linear(config['n_heads'] * config['d_model'], config['d_model'])\n","\n","    # Dropout para regularização (evitar overfitting)\n","    self.dropout = nn.Dropout(.1)\n","\n","  def forward(self, x):\n","    # Processa a entrada 'x' por cada head de atenção individualmente\n","    heads = [head(x) for head in self.heads]\n","\n","    # Concatena as saídas de todas as heads de atenção (junta as dimensões)\n","    o = torch.cat(heads, dim=-1)\n","\n","    # Projeta a saída concatenada de volta para a dimensão original (d_model)\n","    o = self.w0(o)\n","\n","    # Aplica dropout para regularização\n","    o = self.dropout(o)\n","\n","    return o\n"],"metadata":{"id":"yYjUGRDFJWG-","executionInfo":{"status":"ok","timestamp":1720530585695,"user_tz":240,"elapsed":1,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":8,"outputs":[]},{"cell_type":"markdown","source":["### Llama block"],"metadata":{"id":"6JP7a_S2JSA8"}},{"cell_type":"code","source":["class LlamaBlock(nn.Module):\n","  def __init__(self, config):\n","    super().__init__()\n","    self.rms_norm = RMSNorm(config['d_model'])  # Normalização RMS\n","    self.multihead = MultiheadRopeAttention(config)  # Atenção Multi-head com Rope\n","    self.ffn_swiglu = FFN_SwiGLU(config['d_model'])  # Rede Feedforward com SwiGLU\n","\n","  def forward(self, x):\n","    x = self.rms_norm(x)  # Normalização RMS pré-ativação\n","    x = x + self.multihead(x)  # Soma o resultado da Atenção Multi-head com a entrada\n","\n","    x = self.rms_norm(x)  # Normalização RMS pré-ativação\n","    x = x + self.ffn_swiglu(x)  # Soma o resultado da Rede Feedforward com a entrada\n","    return x\n"],"metadata":{"id":"puERxeblJFrj","executionInfo":{"status":"ok","timestamp":1720530587052,"user_tz":240,"elapsed":0,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":9,"outputs":[]},{"cell_type":"markdown","source":["### Llama"],"metadata":{"id":"0GwxyFq78cxz"}},{"cell_type":"code","source":["class Llama(nn.Module):\n","  def __init__(self, config):\n","    super().__init__()\n","    # Armazena a configuração do modelo\n","    self.config = config\n","\n","    # Camada de embedding que transforma índices de palavras em vetores densos\n","    self.embedding_layer = nn.Embedding(config['vocab_size'], config['d_model'])\n","\n","    # Sequência de blocos Llama empilhados\n","    self.llama_blocks = nn.Sequential(OrderedDict([(f\"llama_{i}\", LlamaBlock(config)) for i in range(config['n_layers'])]))\n","\n","    # Normalização RMS (Root Mean Square)\n","    self.rms = RMSNorm(config['d_model'])\n","\n","    # Camada linear final para previsão do próximo token (cabeça do modelo de linguagem)\n","    self.lm_head = nn.Linear(config['d_model'], config['vocab_size'])\n","\n","    # Camada linear intermediária com FFN SwiGLU para tarefas auxiliares\n","    self.linear = nn.Sequential(\n","            nn.Linear(config['d_model'], config['d_model']),\n","            FFN_SwiGLU(config['d_model']),\n","            nn.Linear(config['d_model'], config['vocab_size']),\n","        )\n","\n","    # Imprime o número total de parâmetros do modelo\n","    print(\"model params:\", sum([m.numel() for m in self.parameters()]))\n","\n","  def forward(self, x):\n","    # Embutimento de palavras na camada de embedding\n","    x = self.embedding_layer(x)\n","\n","    # Passagem pela sequência de blocos Llama\n","    x = self.llama_blocks(x)\n","\n","    # Normalização RMS\n","    x = self.rms(x)\n","\n","    # Camada linear final (ou intermediária) para saída\n","    x = self.linear(x)\n","\n","    return x\n"],"metadata":{"id":"Ram7uk6n8B8r","executionInfo":{"status":"ok","timestamp":1720530587506,"user_tz":240,"elapsed":3,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":10,"outputs":[]},{"cell_type":"markdown","source":["## Treinamento"],"metadata":{"id":"8gIzLwCSBXVk"}},{"cell_type":"code","source":["config.update({\n","    'epochs': 50\n","})"],"metadata":{"id":"SRsg5NqVGLuw"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["config"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"K-Mv8B88GiPf","executionInfo":{"status":"ok","timestamp":1720475729998,"user_tz":240,"elapsed":3,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"2d47cbc1-cea7-4c91-ded9-bbda506095ad"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["{'d_model': 64,\n"," 'n_heads': 8,\n"," 'n_layers': 6,\n"," 'context_window': 22,\n"," 'epochs': 50,\n"," 'log_interval': 10,\n"," 'vocab_size': 7268,\n"," 'batch_size': 32}"]},"metadata":{},"execution_count":47}]},{"cell_type":"code","source":["model = Llama(config)\n","\n","loss_fn = torch.nn.CrossEntropyLoss()\n","optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ypO2rCbzG5qx","executionInfo":{"status":"ok","timestamp":1720475738713,"user_tz":240,"elapsed":3147,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"d77493d8-ab6b-4a09-83b0-804936918827"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["model params: 2287432\n"]}]},{"cell_type":"code","source":["def train_one_epoch(epoch_index, tb_writer):\n","    running_loss = 0.\n","    last_loss = 0.\n","\n","    # Here, we use enumerate(training_loader) instead of\n","    # iter(training_loader) so that we can track the batch\n","    # index and do some intra-epoch reporting\n","    for i, data in enumerate(train_dataloader):\n","      # Every data instance is an input + label pair\n","      inputs, labels = data\n","\n","      # Zero your gradients for every batch!\n","      optimizer.zero_grad()\n","\n","      # Make predictions for this batch\n","      outputs = model(inputs)\n","\n","      # Compute the loss and its gradients\n","      loss = F.cross_entropy(outputs.view(-1, config['vocab_size']), labels.view(-1))\n","      loss.backward()\n","\n","      # Adjust learning weights\n","      optimizer.step()\n","\n","      # Gather data and report\n","      running_loss += loss.item()\n","      if i % 100 == 99:\n","        last_loss = running_loss / 100 # loss per batch\n","        print('  batch {} loss: {}'.format(i + 1, last_loss))\n","        tb_x = epoch_index * len(train_dataloader) + i + 1\n","        tb_writer.add_scalar('Loss/train', last_loss, tb_x)\n","        running_loss = 0.\n","\n","    return last_loss"],"metadata":{"id":"VSR1SfhF_CjJ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def train_model(config, model):\n","  # Initializing in a separate cell so we can easily add more epochs to the same run\n","  timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')\n","  writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))\n","  epoch_number = 0\n","\n","  EPOCHS = 50\n","\n","  best_vloss = 1_000_000.\n","  losses = {'train': [], 'valid': []}\n","\n","  for epoch in range(EPOCHS):\n","      print('EPOCH {}:'.format(epoch_number + 1))\n","\n","      # Make sure gradient tracking is on, and do a pass over the data\n","      model.train(True)\n","      avg_loss = train_one_epoch(epoch_number, writer)\n","\n","\n","      running_vloss = 0.0\n","      # Set the model to evaluation mode, disabling dropout and using population\n","      # statistics for batch normalization.\n","      model.eval()\n","\n","      # Disable gradient computation and reduce memory consumption.\n","      with torch.no_grad():\n","          for i, vdata in enumerate(validation_loader):\n","              vinputs, vlabels = vdata\n","              voutputs = model(vinputs)\n","              vloss = F.cross_entropy(voutputs.view(-1, config['vocab_size']), vlabels.view(-1))\n","              running_vloss += vloss\n","\n","      avg_vloss = running_vloss / (i + 1)\n","      losses['train'].append(avg_loss)\n","      losses['valid'].append(avg_vloss.item())\n","      print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))\n","\n","      # Log the running loss averaged per batch\n","      # for both training and validation\n","      writer.add_scalars('Training vs. Validation Loss',\n","                      { 'Training' : avg_loss, 'Validation' : avg_vloss },\n","                      epoch_number + 1)\n","      writer.flush()\n","\n","      # Track best performance, and save the model's state\n","      if avg_vloss < best_vloss:\n","          best_vloss = avg_vloss\n","          model_path = 'model_{}_{}'.format(timestamp, epoch_number)\n","          torch.save(model.state_dict(), model_path)\n","\n","      epoch_number += 1\n","\n","  pd.DataFrame(losses).plot()"],"metadata":{"id":"FRJcTmVHCHH0"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Treinamento com `batch_size` = 16"],"metadata":{"id":"iY-mllSZvpJg"}},{"cell_type":"code","source":["train_model(config, model)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"id":"W-dA_qYDGcgA","outputId":"27d423b8-e2ba-4c7c-b252-67d5aa3083c8","executionInfo":{"status":"ok","timestamp":1720475286590,"user_tz":240,"elapsed":176741,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["EPOCH 1:\n","  batch 100 loss: 7.254485898017883\n","  batch 200 loss: 6.72920759677887\n","  batch 300 loss: 6.6483221530914305\n","  batch 400 loss: 6.690575227737427\n","  batch 500 loss: 6.618522758483887\n","  batch 600 loss: 6.599264216423035\n","LOSS train 6.599264216423035 valid 6.901430606842041\n","EPOCH 2:\n","  batch 100 loss: 6.580296430587769\n","  batch 200 loss: 6.578885645866394\n","  batch 300 loss: 6.6078906965255735\n","  batch 400 loss: 6.584995603561401\n","  batch 500 loss: 6.601514177322388\n","  batch 600 loss: 6.609987564086914\n","LOSS train 6.609987564086914 valid 6.904172420501709\n","EPOCH 3:\n","  batch 100 loss: 6.591444034576416\n","  batch 200 loss: 6.564644765853882\n","  batch 300 loss: 6.575415391921997\n","  batch 400 loss: 6.590406394004821\n","  batch 500 loss: 6.593808465003967\n","  batch 600 loss: 6.60360047340393\n","LOSS train 6.60360047340393 valid 6.879724025726318\n","EPOCH 4:\n","  batch 100 loss: 6.580120692253113\n","  batch 200 loss: 6.5862398529052735\n","  batch 300 loss: 6.593850221633911\n","  batch 400 loss: 6.588077220916748\n","  batch 500 loss: 6.581966590881348\n","  batch 600 loss: 6.5882522487640385\n","LOSS train 6.5882522487640385 valid 6.909512996673584\n","EPOCH 5:\n","  batch 100 loss: 6.578727827072144\n","  batch 200 loss: 6.5831868982315065\n","  batch 300 loss: 6.590605983734131\n","  batch 400 loss: 6.590784053802491\n","  batch 500 loss: 6.580624976158142\n","  batch 600 loss: 6.578814177513123\n","LOSS train 6.578814177513123 valid 6.917264461517334\n","EPOCH 6:\n","  batch 100 loss: 6.560040545463562\n","  batch 200 loss: 6.572971782684326\n","  batch 300 loss: 6.599514727592468\n","  batch 400 loss: 6.596052861213684\n","  batch 500 loss: 6.594382667541504\n","  batch 600 loss: 6.582794198989868\n","LOSS train 6.582794198989868 valid 6.931935787200928\n","EPOCH 7:\n","  batch 100 loss: 6.559836192131042\n","  batch 200 loss: 6.575530066490173\n","  batch 300 loss: 6.55260015964508\n","  batch 400 loss: 6.597680506706237\n","  batch 500 loss: 6.60317440032959\n","  batch 600 loss: 6.584929151535034\n","LOSS train 6.584929151535034 valid 6.937662601470947\n","EPOCH 8:\n","  batch 100 loss: 6.566373510360718\n","  batch 200 loss: 6.556069231033325\n","  batch 300 loss: 6.595703058242798\n","  batch 400 loss: 6.570441856384277\n","  batch 500 loss: 6.615760755538941\n","  batch 600 loss: 6.558448724746704\n","LOSS train 6.558448724746704 valid 6.941678047180176\n","EPOCH 9:\n","  batch 100 loss: 6.564875254631042\n","  batch 200 loss: 6.579329714775086\n","  batch 300 loss: 6.570184783935547\n","  batch 400 loss: 6.580846128463745\n","  batch 500 loss: 6.580609383583069\n","  batch 600 loss: 6.596588416099548\n","LOSS train 6.596588416099548 valid 6.976319789886475\n","EPOCH 10:\n","  batch 100 loss: 6.571156268119812\n","  batch 200 loss: 6.552039556503296\n","  batch 300 loss: 6.564516258239746\n","  batch 400 loss: 6.60458215713501\n","  batch 500 loss: 6.583310332298279\n","  batch 600 loss: 6.583385305404663\n","LOSS train 6.583385305404663 valid 6.983857154846191\n","EPOCH 11:\n","  batch 100 loss: 6.567958526611328\n","  batch 200 loss: 6.6028373384475705\n","  batch 300 loss: 6.566983342170715\n","  batch 400 loss: 6.575128974914551\n","  batch 500 loss: 6.571655740737915\n","  batch 600 loss: 6.566354598999023\n","LOSS train 6.566354598999023 valid 6.989124298095703\n","EPOCH 12:\n","  batch 100 loss: 6.592281746864319\n","  batch 200 loss: 6.534661345481872\n","  batch 300 loss: 6.572481174468994\n","  batch 400 loss: 6.588158221244812\n","  batch 500 loss: 6.591622776985169\n","  batch 600 loss: 6.568725719451904\n","LOSS train 6.568725719451904 valid 7.0073137283325195\n","EPOCH 13:\n","  batch 100 loss: 6.562606024742126\n","  batch 200 loss: 6.562768802642823\n","  batch 300 loss: 6.567439794540405\n","  batch 400 loss: 6.565530376434326\n","  batch 500 loss: 6.603497505187988\n","  batch 600 loss: 6.589612183570861\n","LOSS train 6.589612183570861 valid 7.013947486877441\n","EPOCH 14:\n","  batch 100 loss: 6.567205629348755\n","  batch 200 loss: 6.564379653930664\n","  batch 300 loss: 6.571182346343994\n","  batch 400 loss: 6.577518548965454\n","  batch 500 loss: 6.588018941879272\n","  batch 600 loss: 6.581626229286194\n","LOSS train 6.581626229286194 valid 7.031599521636963\n","EPOCH 15:\n","  batch 100 loss: 6.5464914464950565\n","  batch 200 loss: 6.577048268318176\n","  batch 300 loss: 6.581699676513672\n","  batch 400 loss: 6.581840267181397\n","  batch 500 loss: 6.562737092971802\n","  batch 600 loss: 6.586595730781555\n","LOSS train 6.586595730781555 valid 7.034348964691162\n","EPOCH 16:\n","  batch 100 loss: 6.544588537216186\n","  batch 200 loss: 6.582893800735474\n","  batch 300 loss: 6.6008024454116825\n","  batch 400 loss: 6.58988461971283\n","  batch 500 loss: 6.541111888885498\n","  batch 600 loss: 6.581812653541565\n","LOSS train 6.581812653541565 valid 7.0664381980896\n","EPOCH 17:\n","  batch 100 loss: 6.551539568901062\n","  batch 200 loss: 6.57414707660675\n","  batch 300 loss: 6.58046178817749\n","  batch 400 loss: 6.553614950180053\n","  batch 500 loss: 6.583387479782105\n","  batch 600 loss: 6.595599536895752\n","LOSS train 6.595599536895752 valid 7.06130838394165\n","EPOCH 18:\n","  batch 100 loss: 6.566115574836731\n","  batch 200 loss: 6.560071096420288\n","  batch 300 loss: 6.551137237548828\n","  batch 400 loss: 6.576573548316955\n","  batch 500 loss: 6.597802829742432\n","  batch 600 loss: 6.568505039215088\n","LOSS train 6.568505039215088 valid 7.07714319229126\n","EPOCH 19:\n","  batch 100 loss: 6.564255890846252\n","  batch 200 loss: 6.558865914344787\n","  batch 300 loss: 6.55938762664795\n","  batch 400 loss: 6.575385274887085\n","  batch 500 loss: 6.571145167350769\n","  batch 600 loss: 6.589157776832581\n","LOSS train 6.589157776832581 valid 7.0945611000061035\n","EPOCH 20:\n","  batch 100 loss: 6.5550105762481685\n","  batch 200 loss: 6.574929261207581\n","  batch 300 loss: 6.556095662117005\n","  batch 400 loss: 6.566363430023193\n","  batch 500 loss: 6.5743719005584715\n","  batch 600 loss: 6.583142786026001\n","LOSS train 6.583142786026001 valid 7.118529319763184\n","EPOCH 21:\n","  batch 100 loss: 6.548122506141663\n","  batch 200 loss: 6.553948926925659\n","  batch 300 loss: 6.5781525087356565\n","  batch 400 loss: 6.543014073371888\n","  batch 500 loss: 6.592370190620422\n","  batch 600 loss: 6.586455979347229\n","LOSS train 6.586455979347229 valid 7.109251976013184\n","EPOCH 22:\n","  batch 100 loss: 6.523948349952698\n","  batch 200 loss: 6.548625469207764\n","  batch 300 loss: 6.571464214324951\n","  batch 400 loss: 6.586494879722595\n","  batch 500 loss: 6.578914699554443\n","  batch 600 loss: 6.573911905288696\n","LOSS train 6.573911905288696 valid 7.121114253997803\n","EPOCH 23:\n","  batch 100 loss: 6.558890852928162\n","  batch 200 loss: 6.538014535903931\n","  batch 300 loss: 6.569413876533508\n","  batch 400 loss: 6.541698789596557\n","  batch 500 loss: 6.579856009483337\n","  batch 600 loss: 6.554808902740478\n","LOSS train 6.554808902740478 valid 7.123199462890625\n","EPOCH 24:\n","  batch 100 loss: 6.538546228408814\n","  batch 200 loss: 6.539492511749268\n","  batch 300 loss: 6.543084988594055\n","  batch 400 loss: 6.549443392753601\n","  batch 500 loss: 6.589137706756592\n","  batch 600 loss: 6.568361105918885\n","LOSS train 6.568361105918885 valid 7.141907691955566\n","EPOCH 25:\n","  batch 100 loss: 6.5412467575073245\n","  batch 200 loss: 6.554701061248779\n","  batch 300 loss: 6.558792958259582\n","  batch 400 loss: 6.532972960472107\n","  batch 500 loss: 6.547908911705017\n","  batch 600 loss: 6.553964881896973\n","LOSS train 6.553964881896973 valid 7.15984582901001\n","EPOCH 26:\n","  batch 100 loss: 6.502383580207825\n","  batch 200 loss: 6.506418261528015\n","  batch 300 loss: 6.656693043708802\n","  batch 400 loss: 6.56102306842804\n","  batch 500 loss: 6.527300033569336\n","  batch 600 loss: 6.5371735429763795\n","LOSS train 6.5371735429763795 valid 7.115063667297363\n","EPOCH 27:\n","  batch 100 loss: 6.513447532653808\n","  batch 200 loss: 6.497464432716369\n","  batch 300 loss: 6.5209917020797725\n","  batch 400 loss: 6.491918354034424\n","  batch 500 loss: 6.529042053222656\n","  batch 600 loss: 6.508379697799683\n","LOSS train 6.508379697799683 valid 7.098448753356934\n","EPOCH 28:\n","  batch 100 loss: 6.491382918357849\n","  batch 200 loss: 6.477219285964966\n","  batch 300 loss: 6.4809818887710575\n","  batch 400 loss: 6.493492069244385\n","  batch 500 loss: 6.502482190132141\n","  batch 600 loss: 6.486911187171936\n","LOSS train 6.486911187171936 valid 7.050586700439453\n","EPOCH 29:\n","  batch 100 loss: 6.447491722106934\n","  batch 200 loss: 6.47750153541565\n","  batch 300 loss: 6.462226090431213\n","  batch 400 loss: 6.471752362251282\n","  batch 500 loss: 6.455769844055176\n","  batch 600 loss: 6.481536073684692\n","LOSS train 6.481536073684692 valid 7.067141532897949\n","EPOCH 30:\n","  batch 100 loss: 6.431420063972473\n","  batch 200 loss: 6.429241437911987\n","  batch 300 loss: 6.446416010856629\n","  batch 400 loss: 6.452865800857544\n","  batch 500 loss: 6.441871075630188\n","  batch 600 loss: 6.481301641464233\n","LOSS train 6.481301641464233 valid 7.033505439758301\n","EPOCH 31:\n","  batch 100 loss: 6.4371893072128294\n","  batch 200 loss: 6.4184278392791745\n","  batch 300 loss: 6.4399604320526125\n","  batch 400 loss: 6.460789761543274\n","  batch 500 loss: 6.460857768058776\n","  batch 600 loss: 6.442670884132386\n","LOSS train 6.442670884132386 valid 7.030813694000244\n","EPOCH 32:\n","  batch 100 loss: 6.41266547203064\n","  batch 200 loss: 6.443150177001953\n","  batch 300 loss: 6.411829543113709\n","  batch 400 loss: 6.420839729309082\n","  batch 500 loss: 6.448990087509156\n","  batch 600 loss: 6.430414538383484\n","LOSS train 6.430414538383484 valid 7.043978691101074\n","EPOCH 33:\n","  batch 100 loss: 6.389787497520447\n","  batch 200 loss: 6.411245346069336\n","  batch 300 loss: 6.421475305557251\n","  batch 400 loss: 6.507261333465576\n","  batch 500 loss: 6.476127953529358\n","  batch 600 loss: 6.453008193969726\n","LOSS train 6.453008193969726 valid 7.03551721572876\n","EPOCH 34:\n","  batch 100 loss: 6.430651307106018\n","  batch 200 loss: 6.421110577583313\n","  batch 300 loss: 6.415716428756713\n","  batch 400 loss: 6.4142104482650755\n","  batch 500 loss: 6.428186640739441\n","  batch 600 loss: 6.435166039466858\n","LOSS train 6.435166039466858 valid 7.043789863586426\n","EPOCH 35:\n","  batch 100 loss: 6.385214333534241\n","  batch 200 loss: 6.395712161064148\n","  batch 300 loss: 6.428929119110108\n","  batch 400 loss: 6.413891558647156\n","  batch 500 loss: 6.583587484359741\n","  batch 600 loss: 6.578288340568543\n","LOSS train 6.578288340568543 valid 7.108015537261963\n","EPOCH 36:\n","  batch 100 loss: 6.540544028282166\n","  batch 200 loss: 6.543326025009155\n","  batch 300 loss: 6.5408097887039185\n","  batch 400 loss: 6.5395600748062135\n","  batch 500 loss: 6.522721891403198\n","  batch 600 loss: 6.519186248779297\n","LOSS train 6.519186248779297 valid 7.113694667816162\n","EPOCH 37:\n","  batch 100 loss: 6.498777656555176\n","  batch 200 loss: 6.495661149024963\n","  batch 300 loss: 6.501683101654053\n","  batch 400 loss: 6.506357712745666\n","  batch 500 loss: 6.513505058288574\n","  batch 600 loss: 6.494283967018127\n","LOSS train 6.494283967018127 valid 7.116103172302246\n","EPOCH 38:\n","  batch 100 loss: 6.4706564617156985\n","  batch 200 loss: 6.5286725950241085\n","  batch 300 loss: 6.5076979064941405\n","  batch 400 loss: 6.514157242774964\n","  batch 500 loss: 6.487441992759704\n","  batch 600 loss: 6.501930351257324\n","LOSS train 6.501930351257324 valid 7.126080513000488\n","EPOCH 39:\n","  batch 100 loss: 6.4892742586135865\n","  batch 200 loss: 6.463800644874572\n","  batch 300 loss: 6.491882257461548\n","  batch 400 loss: 6.4794546365737915\n","  batch 500 loss: 6.475056920051575\n","  batch 600 loss: 6.475516648292541\n","LOSS train 6.475516648292541 valid 7.129582405090332\n","EPOCH 40:\n","  batch 100 loss: 6.45053807258606\n","  batch 200 loss: 6.438479175567627\n","  batch 300 loss: 6.507103371620178\n","  batch 400 loss: 6.512490930557251\n","  batch 500 loss: 6.48032247543335\n","  batch 600 loss: 6.499647622108459\n","LOSS train 6.499647622108459 valid 7.112171173095703\n","EPOCH 41:\n","  batch 100 loss: 6.463958077430725\n","  batch 200 loss: 6.467284770011902\n","  batch 300 loss: 6.487425794601441\n","  batch 400 loss: 6.501281037330627\n","  batch 500 loss: 6.477662000656128\n","  batch 600 loss: 6.4896377849578855\n","LOSS train 6.4896377849578855 valid 7.126997947692871\n","EPOCH 42:\n","  batch 100 loss: 6.4809159088134765\n","  batch 200 loss: 6.468009572029114\n","  batch 300 loss: 6.495647053718567\n","  batch 400 loss: 6.51474552154541\n","  batch 500 loss: 6.457751355171204\n","  batch 600 loss: 6.470253939628601\n","LOSS train 6.470253939628601 valid 7.133437156677246\n","EPOCH 43:\n","  batch 100 loss: 6.4647204875946045\n","  batch 200 loss: 6.452745881080627\n","  batch 300 loss: 6.4637512636184695\n","  batch 400 loss: 6.442908539772033\n","  batch 500 loss: 6.453943176269531\n","  batch 600 loss: 6.427407455444336\n","LOSS train 6.427407455444336 valid 7.121382713317871\n","EPOCH 44:\n","  batch 100 loss: 6.42605507850647\n","  batch 200 loss: 6.426686096191406\n","  batch 300 loss: 6.428371777534485\n","  batch 400 loss: 6.428095831871032\n","  batch 500 loss: 6.42728180885315\n","  batch 600 loss: 6.41529278755188\n","LOSS train 6.41529278755188 valid 7.101837158203125\n","EPOCH 45:\n","  batch 100 loss: 6.388901100158692\n","  batch 200 loss: 6.396873302459717\n","  batch 300 loss: 6.430823044776917\n","  batch 400 loss: 6.435814719200135\n","  batch 500 loss: 6.45037733078003\n","  batch 600 loss: 6.419003801345825\n","LOSS train 6.419003801345825 valid 7.124403476715088\n","EPOCH 46:\n","  batch 100 loss: 6.378414149284363\n","  batch 200 loss: 6.421618952751159\n","  batch 300 loss: 6.40506724357605\n","  batch 400 loss: 6.412917518615723\n","  batch 500 loss: 6.398574676513672\n","  batch 600 loss: 6.422784910202027\n","LOSS train 6.422784910202027 valid 7.114327430725098\n","EPOCH 47:\n","  batch 100 loss: 6.382177243232727\n","  batch 200 loss: 6.3837039232254025\n","  batch 300 loss: 6.375428819656372\n","  batch 400 loss: 6.3952450180053715\n","  batch 500 loss: 6.422552003860473\n","  batch 600 loss: 6.401733222007752\n","LOSS train 6.401733222007752 valid 7.130518913269043\n","EPOCH 48:\n","  batch 100 loss: 6.37449089050293\n","  batch 200 loss: 6.384265007972718\n","  batch 300 loss: 6.380558428764343\n","  batch 400 loss: 6.383541660308838\n","  batch 500 loss: 6.400592350959778\n","  batch 600 loss: 6.406537475585938\n","LOSS train 6.406537475585938 valid 7.119816780090332\n","EPOCH 49:\n","  batch 100 loss: 6.36315957069397\n","  batch 200 loss: 6.367721962928772\n","  batch 300 loss: 6.369583930969238\n","  batch 400 loss: 6.386765179634094\n","  batch 500 loss: 6.3947264766693115\n","  batch 600 loss: 6.395237894058227\n","LOSS train 6.395237894058227 valid 7.132377624511719\n","EPOCH 50:\n","  batch 100 loss: 6.371091256141662\n","  batch 200 loss: 6.37949001789093\n","  batch 300 loss: 6.383036327362061\n","  batch 400 loss: 6.37596417427063\n","  batch 500 loss: 6.368903679847717\n","  batch 600 loss: 6.388664498329162\n","LOSS train 6.388664498329162 valid 7.146467685699463\n"]},{"output_type":"display_data","data":{"text/plain":["<Figure size 640x480 with 1 Axes>"],"image/png":"\n"},"metadata":{}}]},{"cell_type":"code","source":["torch.save(model, 'my_model.pth')"],"metadata":{"id":"teVMf8MBHlVl"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Treinamento com `batch_size` = 32"],"metadata":{"id":"ss7kU7QrvwWx"}},{"cell_type":"code","source":["config.update({\n","    'batch_size': 32\n","})"],"metadata":{"id":"tbsgd2fEnz6B"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_model(config, model)"],"metadata":{"id":"xwAZdMKTV8Rj","colab":{"base_uri":"https://localhost:8080/","height":1000},"executionInfo":{"status":"ok","timestamp":1720482937546,"user_tz":300,"elapsed":599124,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"c6a2e561-dfad-456a-c251-098abbcee8b0"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["EPOCH 1:\n","  batch 100 loss: 7.172227883338929\n","  batch 200 loss: 6.641011209487915\n","  batch 300 loss: 6.622778711318969\n","LOSS train 6.622778711318969 valid 6.869982719421387\n","EPOCH 2:\n","  batch 100 loss: 6.589296593666076\n","  batch 200 loss: 6.5783744764328\n","  batch 300 loss: 6.523212966918945\n","LOSS train 6.523212966918945 valid 6.668927192687988\n","EPOCH 3:\n","  batch 100 loss: 6.367052774429322\n","  batch 200 loss: 6.245095438957215\n","  batch 300 loss: 6.161773190498352\n","LOSS train 6.161773190498352 valid 6.4662580490112305\n","EPOCH 4:\n","  batch 100 loss: 5.980791273117066\n","  batch 200 loss: 5.94928249835968\n","  batch 300 loss: 5.883374304771423\n","LOSS train 5.883374304771423 valid 6.2536187171936035\n","EPOCH 5:\n","  batch 100 loss: 5.765328869819641\n","  batch 200 loss: 5.708793077468872\n","  batch 300 loss: 5.647432579994201\n","LOSS train 5.647432579994201 valid 6.056188583374023\n","EPOCH 6:\n","  batch 100 loss: 5.522223701477051\n","  batch 200 loss: 5.51192259311676\n","  batch 300 loss: 5.485039143562317\n","LOSS train 5.485039143562317 valid 6.002313613891602\n","EPOCH 7:\n","  batch 100 loss: 5.366386723518372\n","  batch 200 loss: 5.343162803649903\n","  batch 300 loss: 5.349799561500549\n","LOSS train 5.349799561500549 valid 5.914008617401123\n","EPOCH 8:\n","  batch 100 loss: 5.2450868082046505\n","  batch 200 loss: 5.221986150741577\n","  batch 300 loss: 5.214074268341064\n","LOSS train 5.214074268341064 valid 5.940862655639648\n","EPOCH 9:\n","  batch 100 loss: 5.127347750663757\n","  batch 200 loss: 5.106661410331726\n","  batch 300 loss: 5.179489970207214\n","LOSS train 5.179489970207214 valid 5.893185615539551\n","EPOCH 10:\n","  batch 100 loss: 5.025272417068481\n","  batch 200 loss: 5.040288195610047\n","  batch 300 loss: 5.023049664497376\n","LOSS train 5.023049664497376 valid 5.818782329559326\n","EPOCH 11:\n","  batch 100 loss: 4.944939289093018\n","  batch 200 loss: 4.938350229263306\n","  batch 300 loss: 4.940429320335388\n","LOSS train 4.940429320335388 valid 5.79094934463501\n","EPOCH 12:\n","  batch 100 loss: 4.847834324836731\n","  batch 200 loss: 4.86101731300354\n","  batch 300 loss: 4.900515027046204\n","LOSS train 4.900515027046204 valid 5.810030460357666\n","EPOCH 13:\n","  batch 100 loss: 4.772279562950135\n","  batch 200 loss: 4.808963198661804\n","  batch 300 loss: 4.830773520469665\n","LOSS train 4.830773520469665 valid 5.7970099449157715\n","EPOCH 14:\n","  batch 100 loss: 4.729389381408692\n","  batch 200 loss: 4.736234517097473\n","  batch 300 loss: 4.760689487457276\n","LOSS train 4.760689487457276 valid 5.748399257659912\n","EPOCH 15:\n","  batch 100 loss: 4.739435610771179\n","  batch 200 loss: 4.761809258460999\n","  batch 300 loss: 4.7103673934936525\n","LOSS train 4.7103673934936525 valid 5.8211445808410645\n","EPOCH 16:\n","  batch 100 loss: 4.623151025772095\n","  batch 200 loss: 4.634681401252746\n","  batch 300 loss: 4.6569615793228145\n","LOSS train 4.6569615793228145 valid 5.820262908935547\n","EPOCH 17:\n","  batch 100 loss: 4.55587375164032\n","  batch 200 loss: 4.587977967262268\n","  batch 300 loss: 4.6188200044631955\n","LOSS train 4.6188200044631955 valid 5.745680809020996\n","EPOCH 18:\n","  batch 100 loss: 4.520093288421631\n","  batch 200 loss: 4.539075546264648\n","  batch 300 loss: 4.573925409317017\n","LOSS train 4.573925409317017 valid 5.784363746643066\n","EPOCH 19:\n","  batch 100 loss: 4.460341238975525\n","  batch 200 loss: 4.519164433479309\n","  batch 300 loss: 4.531325259208679\n","LOSS train 4.531325259208679 valid 5.878668785095215\n","EPOCH 20:\n","  batch 100 loss: 4.425350794792175\n","  batch 200 loss: 4.47442994594574\n","  batch 300 loss: 4.484270219802856\n","LOSS train 4.484270219802856 valid 5.974778652191162\n","EPOCH 21:\n","  batch 100 loss: 4.377736864089965\n","  batch 200 loss: 4.396480851173401\n","  batch 300 loss: 4.45960497379303\n","LOSS train 4.45960497379303 valid 5.947024822235107\n","EPOCH 22:\n","  batch 100 loss: 4.341839108467102\n","  batch 200 loss: 4.384986872673035\n","  batch 300 loss: 4.396161489486694\n","LOSS train 4.396161489486694 valid 5.971501350402832\n","EPOCH 23:\n","  batch 100 loss: 4.322825553417206\n","  batch 200 loss: 4.3269310212135315\n","  batch 300 loss: 4.361510910987854\n","LOSS train 4.361510910987854 valid 5.950575351715088\n","EPOCH 24:\n","  batch 100 loss: 4.263336644172669\n","  batch 200 loss: 4.314425795078278\n","  batch 300 loss: 4.329500770568847\n","LOSS train 4.329500770568847 valid 5.965373516082764\n","EPOCH 25:\n","  batch 100 loss: 4.73459703207016\n","  batch 200 loss: 4.406391320228576\n","  batch 300 loss: 4.387493586540222\n","LOSS train 4.387493586540222 valid 5.971930027008057\n","EPOCH 26:\n","  batch 100 loss: 4.218559174537659\n","  batch 200 loss: 4.2838466620445255\n","  batch 300 loss: 4.296542391777039\n","LOSS train 4.296542391777039 valid 5.994229316711426\n","EPOCH 27:\n","  batch 100 loss: 4.169474675655365\n","  batch 200 loss: 4.240190186500549\n","  batch 300 loss: 4.236780226230621\n","LOSS train 4.236780226230621 valid 6.0544915199279785\n","EPOCH 28:\n","  batch 100 loss: 4.117045829296112\n","  batch 200 loss: 4.20404869556427\n","  batch 300 loss: 4.20469839811325\n","LOSS train 4.20469839811325 valid 6.056950092315674\n","EPOCH 29:\n","  batch 100 loss: 4.076124663352966\n","  batch 200 loss: 4.141241619586944\n","  batch 300 loss: 4.1852592587471005\n","LOSS train 4.1852592587471005 valid 6.109691619873047\n","EPOCH 30:\n","  batch 100 loss: 4.060036919116974\n","  batch 200 loss: 4.113920016288757\n","  batch 300 loss: 4.162429881095886\n","LOSS train 4.162429881095886 valid 6.045992374420166\n","EPOCH 31:\n","  batch 100 loss: 4.022013132572174\n","  batch 200 loss: 4.089898765087128\n","  batch 300 loss: 4.1225871348381045\n","LOSS train 4.1225871348381045 valid 6.123666763305664\n","EPOCH 32:\n","  batch 100 loss: 4.005324738025665\n","  batch 200 loss: 4.0577108716964725\n","  batch 300 loss: 4.0878568100929265\n","LOSS train 4.0878568100929265 valid 6.173917770385742\n","EPOCH 33:\n","  batch 100 loss: 3.9602249479293823\n","  batch 200 loss: 4.042502901554108\n","  batch 300 loss: 4.085650110244751\n","LOSS train 4.085650110244751 valid 6.122374057769775\n","EPOCH 34:\n","  batch 100 loss: 3.946941063404083\n","  batch 200 loss: 4.015143511295318\n","  batch 300 loss: 4.061908349990845\n","LOSS train 4.061908349990845 valid 6.207901477813721\n","EPOCH 35:\n","  batch 100 loss: 3.9436364507675172\n","  batch 200 loss: 3.9668062853813173\n","  batch 300 loss: 4.03125834941864\n","LOSS train 4.03125834941864 valid 6.376011371612549\n","EPOCH 36:\n","  batch 100 loss: 3.902646629810333\n","  batch 200 loss: 3.9415264105796814\n","  batch 300 loss: 3.998851013183594\n","LOSS train 3.998851013183594 valid 6.308480739593506\n","EPOCH 37:\n","  batch 100 loss: 3.8685570883750917\n","  batch 200 loss: 3.9376883912086487\n","  batch 300 loss: 3.9492770409584046\n","LOSS train 3.9492770409584046 valid 6.350978374481201\n","EPOCH 38:\n","  batch 100 loss: 3.85991007566452\n","  batch 200 loss: 3.9139299488067625\n","  batch 300 loss: 3.9430702471733095\n","LOSS train 3.9430702471733095 valid 6.307162284851074\n","EPOCH 39:\n","  batch 100 loss: 3.830420000553131\n","  batch 200 loss: 3.8678387451171874\n","  batch 300 loss: 3.924768602848053\n","LOSS train 3.924768602848053 valid 6.285574913024902\n","EPOCH 40:\n","  batch 100 loss: 3.8177448987960814\n","  batch 200 loss: 3.862979459762573\n","  batch 300 loss: 3.8775752711296083\n","LOSS train 3.8775752711296083 valid 6.423800468444824\n","EPOCH 41:\n","  batch 100 loss: 4.959664311408996\n","  batch 200 loss: 4.55969336271286\n","  batch 300 loss: 4.203508949279785\n","LOSS train 4.203508949279785 valid 6.352086067199707\n","EPOCH 42:\n","  batch 100 loss: 3.9864105868339537\n","  batch 200 loss: 3.9753793144226073\n","  batch 300 loss: 3.9787585139274597\n","LOSS train 3.9787585139274597 valid 6.345963954925537\n","EPOCH 43:\n","  batch 100 loss: 3.824919855594635\n","  batch 200 loss: 3.8663721799850466\n","  batch 300 loss: 3.8794995021820067\n","LOSS train 3.8794995021820067 valid 6.470861911773682\n","EPOCH 44:\n","  batch 100 loss: 3.7395345830917357\n","  batch 200 loss: 3.8164124035835267\n","  batch 300 loss: 3.8584246921539305\n","LOSS train 3.8584246921539305 valid 6.545734405517578\n","EPOCH 45:\n","  batch 100 loss: 3.7270442414283753\n","  batch 200 loss: 3.7668605852127075\n","  batch 300 loss: 3.804299988746643\n","LOSS train 3.804299988746643 valid 6.558882236480713\n","EPOCH 46:\n","  batch 100 loss: 3.6908179712295532\n","  batch 200 loss: 3.802176911830902\n","  batch 300 loss: 3.8035173773765565\n","LOSS train 3.8035173773765565 valid 6.528656005859375\n","EPOCH 47:\n","  batch 100 loss: 3.6601963949203493\n","  batch 200 loss: 3.7149740624427796\n","  batch 300 loss: 3.7762435460090638\n","LOSS train 3.7762435460090638 valid 6.582528114318848\n","EPOCH 48:\n","  batch 100 loss: 3.6323843812942505\n","  batch 200 loss: 3.703022177219391\n","  batch 300 loss: 3.75785231590271\n","LOSS train 3.75785231590271 valid 6.653554439544678\n","EPOCH 49:\n","  batch 100 loss: 3.609143373966217\n","  batch 200 loss: 3.6926794934272764\n","  batch 300 loss: 4.974221394062043\n","LOSS train 4.974221394062043 valid 6.254998207092285\n","EPOCH 50:\n","  batch 100 loss: 4.266746120452881\n","  batch 200 loss: 4.006846477985382\n","  batch 300 loss: 3.8907069182395935\n","LOSS train 3.8907069182395935 valid 6.411135673522949\n"]},{"output_type":"display_data","data":{"text/plain":["<Figure size 640x480 with 1 Axes>"],"image/png":"\n"},"metadata":{}}]},{"cell_type":"code","source":["torch.save(model, 'my_model2.pth')"],"metadata":{"id":"7no1isIrWc-8"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Avaliação do modelo"],"metadata":{"id":"pI2yLWI5M-yY"}},{"cell_type":"markdown","source":["### Funções de avaliação"],"metadata":{"id":"FShdol5xqIig"}},{"cell_type":"markdown","source":["Perplexidade"],"metadata":{"id":"ygz7I_El0t_5"}},{"cell_type":"code","source":["def calculate_perplexity(lm_model, test_dataset):\n","\n","  lm_model.eval()\n","\n","  total_loss = []\n","  with torch.no_grad():\n","    for inputs in test_dataset:\n","      input_ids = inputs[0]\n","      labels = inputs[1]\n","\n","      # Forward pass through the language model\n","      logits = lm_model(input_ids)\n","\n","      # Calculate cross-entropy loss (without reduction)\n","      loss = F.cross_entropy(logits.view(-1, config['vocab_size']), labels.view(-1))\n","      total_loss.append(loss.item())\n","\n","      # Accumulate loss and number of tokens (excluding padding)\n","\n","  # Calculate average loss and perplexity\n","  avg_loss = np.mean(total_loss)\n","  perplexity = np.exp(avg_loss)\n","\n","  return perplexity.item()\n"],"metadata":{"id":"0ObPz2OAuJO5","executionInfo":{"status":"ok","timestamp":1720532710812,"user_tz":240,"elapsed":333,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":49,"outputs":[]},{"cell_type":"markdown","source":["Acurácia"],"metadata":{"id":"to4kq9mBqpSZ"}},{"cell_type":"code","source":["import torch\n","\n","def calculate_accuracy(model, test_loader):\n","  \"\"\"\n","  Calculates the accuracy of a language model on a test dataset.\n","\n","  Args:\n","      model (torch.nn.Module): The language model to evaluate.\n","      test_loader (torch.utils.data.DataLoader): The test data loader.\n","      criterion (torch.nn.Module): The loss function used during training.\n","\n","  Returns:\n","      float: The accuracy of the model on the test dataset.\n","  \"\"\"\n","  # Set model to evaluation mode\n","  model.eval()\n","\n","  # Track variables\n","  correct = 0\n","  total = 0\n","\n","  # No gradient calculation needed during evaluation\n","  with torch.no_grad():\n","    for inputs, labels in test_loader:\n","      outputs = model(inputs)\n","      predicted = torch.argmax(outputs, dim=-1)  # Get the index of max probability\n","\n","      # Update counters\n","      total += labels.numel()\n","      correct += (predicted == labels).sum().item()\n","\n","\n","  # Calculate accuracy\n","  accuracy = correct / total\n","  return accuracy\n"],"metadata":{"id":"0qL4ridHpslh","executionInfo":{"status":"ok","timestamp":1720533351518,"user_tz":240,"elapsed":348,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":72,"outputs":[]},{"cell_type":"markdown","source":["### Avaliação dos dois modelos"],"metadata":{"id":"rs37ec6I01RL"}},{"cell_type":"markdown","source":["Modelo 1 (`batch_size` = 16)\n","\n"],"metadata":{"id":"qiM1Ey3504Ht"}},{"cell_type":"code","source":["model = torch.load('my_model.pth')\n","model.eval()\n","\n","ppl = calculate_perplexity(model, test_dataloader)\n","acc = calculate_accuracy(model, test_dataloader)\n","\n","print(f\"Perplexidade: {ppl}\")\n","print(f\"Acurácia: {acc}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"PqF1BGVTwJ42","executionInfo":{"status":"ok","timestamp":1720534223283,"user_tz":240,"elapsed":11791,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"50c6609d-763f-4c19-d063-1987ba0b2ff1"},"execution_count":74,"outputs":[{"output_type":"stream","name":"stdout","text":["Perplexidade: 880.1006243403579\n","Acurácia: 0.08547741566609492\n"]}]},{"cell_type":"markdown","source":["Modelo 2 (`batch_size` = 32)\n","\n"],"metadata":{"id":"A9JRZiwY1Qwe"}},{"cell_type":"code","source":["model2 = torch.load('my_model2.pth')\n","model.eval()\n","\n","ppl = calculate_perplexity(model, test_dataloader)\n","acc = calculate_accuracy(model, test_dataloader)\n","\n","print(f\"Perplexidade: {ppl}\")\n","print(f\"Acurácia: {acc}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"-95redLc1Jny","executionInfo":{"status":"ok","timestamp":1720534268235,"user_tz":240,"elapsed":9146,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"5e8104a8-c872-45fa-db86-fc8c1cf225a8"},"execution_count":75,"outputs":[{"output_type":"stream","name":"stdout","text":["Perplexidade: 880.1528531595403\n","Acurácia: 0.08501286449399657\n"]}]},{"cell_type":"markdown","source":["## Gerando compartilhamento no Huggingface"],"metadata":{"id":"qk542nk83Gbt"}},{"cell_type":"code","source":["! pip install huggingface_hub"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"DQ4N9cD31VO6","executionInfo":{"status":"ok","timestamp":1720534818648,"user_tz":240,"elapsed":8558,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"a8f0829f-f598-4f69-fdbf-e250ad64664f"},"execution_count":76,"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.10/dist-packages (0.23.4)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (3.15.4)\n","Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (2023.6.0)\n","Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (24.1)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (6.0.1)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (2.31.0)\n","Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (4.66.4)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (4.12.2)\n","Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub) (3.3.2)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub) (3.7)\n","Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub) (2.0.7)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub) (2024.6.2)\n"]}]},{"cell_type":"code","source":["from huggingface_hub import notebook_login\n","\n","notebook_login()"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":331,"referenced_widgets":["7d50cc4d70a8453ea1023c081dd107b6","d84e05020333431296fd65c7c3be9b2a","b2f4e6f16f0a4d3b90717d0179ed642a","3084ffb3b3664cfbb6b9931bea153f83","68fd6d6b9e6c4adeb5431e29e735713d","1d9521b2ef344ec4a8bf5a4d270d802b","7a6f7ecb69c24b6583c3e55a8b22f4c8","5998e43917dc48618e326cfa5d148e79","b01977fca14b443d9f10b7aea7575c7a","e636d0177e1e4d0a8b65896085538ef0","471c56e7723e4bb892ac8bfe97523304","6c5f07b8a0274ea48baf5825c496bfb1","219ec2d07f764f9fb137534c89682b7f","4abf542d1bd54db491bd53a4bd6ae39f","b8e6ba6cafd64e43a7c6dd40fc6cb377","a41054d6970c4fc2989e8942c1e8373f","621bc786ebaa46f3bb5e66312644a65c"]},"id":"nLaTKLtL3bww","executionInfo":{"status":"ok","timestamp":1720535647184,"user_tz":240,"elapsed":364,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"17fe7491-760d-4172-c8c8-cd12862e4334"},"execution_count":83,"outputs":[{"output_type":"display_data","data":{"text/plain":["VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"7d50cc4d70a8453ea1023c081dd107b6"}},"metadata":{}}]},{"cell_type":"code","source":["from huggingface_hub import HfApi\n","from huggingface_hub.utils import HfHubHTTPError\n","api = HfApi()\n","model_repo_name = \"gioandrade/llama1_TinyShakespeare\""],"metadata":{"id":"0Yoxa-ET3gBQ","executionInfo":{"status":"ok","timestamp":1720535616976,"user_tz":240,"elapsed":475,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}}},"execution_count":81,"outputs":[]},{"cell_type":"code","source":["model_path = '/content/drive/MyDrive/Mestrado/NLP_2024/TP_LLM/my_model2.pth'\n","torch.save(model2, model_path)\n","\n","api.upload_folder(\n","    folder_path='/content/drive/MyDrive/Mestrado/NLP_2024/TP_LLM/',\n","    repo_id=model_repo_name\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":602},"id":"pXTKJ0y24IYw","executionInfo":{"status":"error","timestamp":1720535626050,"user_tz":240,"elapsed":6922,"user":{"displayName":"GIOVANNA ANDRADE SANTOS","userId":"13165743699645260115"}},"outputId":"66cf70f2-ce6a-4bb2-887b-98dc6443a69b"},"execution_count":82,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n","The secret `HF_TOKEN` does not exist in your Colab secrets.\n","To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n","You will be able to reuse this secret in all of your notebooks.\n","Please note that authentication is recommended but still optional to access public models or datasets.\n","  warnings.warn(\n"]},{"output_type":"error","ename":"HfHubHTTPError","evalue":" (Request ID: Root=1-668d4a5a-564ed7ba76dfae4e5eb403b3;c12f819b-d5ae-4cd0-b490-5b7bb0a84619)\n\n403 Forbidden: Authorization error..\nCannot access content at: https://huggingface.co/gioandrade/llama1_TinyShakespeare.git/info/lfs/objects/batch.\nIf you are trying to create or update content,make sure you have a token with the `write` role.","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mHTTPError\u001b[0m                                 Traceback (most recent call last)","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_errors.py\u001b[0m in \u001b[0;36mhf_raise_for_status\u001b[0;34m(response, endpoint_name)\u001b[0m\n\u001b[1;32m    303\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 304\u001b[0;31m         \u001b[0mresponse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraise_for_status\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    305\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mHTTPError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/requests/models.py\u001b[0m in \u001b[0;36mraise_for_status\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1020\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mhttp_error_msg\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1021\u001b[0;31m             \u001b[0;32mraise\u001b[0m \u001b[0mHTTPError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhttp_error_msg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1022\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mHTTPError\u001b[0m: 403 Client Error: Forbidden for url: https://huggingface.co/gioandrade/llama1_TinyShakespeare.git/info/lfs/objects/batch","\nThe above exception was the direct cause of the following exception:\n","\u001b[0;31mHfHubHTTPError\u001b[0m                            Traceback (most recent call last)","\u001b[0;32m<ipython-input-82-aa3e9c1ccd77>\u001b[0m in \u001b[0;36m<cell line: 4>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m api.upload_folder(\n\u001b[0m\u001b[1;32m      5\u001b[0m     \u001b[0mfolder_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'/content/drive/MyDrive/Mestrado/NLP_2024/TP_LLM/'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m     \u001b[0mrepo_id\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel_repo_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_validators.py\u001b[0m in \u001b[0;36m_inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m             \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msmoothly_deprecate_use_auth_token\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_token\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhas_token\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    113\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0m_inner_fn\u001b[0m  \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/huggingface_hub/hf_api.py\u001b[0m in \u001b[0;36m_inner\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1284\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1285\u001b[0m         \u001b[0;31m# Otherwise, call the function normally\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1286\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1287\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1288\u001b[0m     \u001b[0m_inner\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_future_compatible\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m  \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/huggingface_hub/hf_api.py\u001b[0m in \u001b[0;36mupload_folder\u001b[0;34m(self, repo_id, folder_path, path_in_repo, commit_message, commit_description, token, repo_type, revision, create_pr, parent_commit, allow_patterns, ignore_patterns, delete_patterns, multi_commits, multi_commits_verbose, run_as_future)\u001b[0m\n\u001b[1;32m   4722\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mpr_url\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   4723\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4724\u001b[0;31m         commit_info = self.create_commit(\n\u001b[0m\u001b[1;32m   4725\u001b[0m             \u001b[0mrepo_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrepo_type\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   4726\u001b[0m             \u001b[0mrepo_id\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrepo_id\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_validators.py\u001b[0m in \u001b[0;36m_inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m             \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msmoothly_deprecate_use_auth_token\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_token\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhas_token\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    113\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0m_inner_fn\u001b[0m  \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/huggingface_hub/hf_api.py\u001b[0m in \u001b[0;36m_inner\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1284\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1285\u001b[0m         \u001b[0;31m# Otherwise, call the function normally\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1286\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1287\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1288\u001b[0m     \u001b[0m_inner\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_future_compatible\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m  \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/huggingface_hub/hf_api.py\u001b[0m in \u001b[0;36mcreate_commit\u001b[0;34m(self, repo_id, operations, commit_message, commit_description, token, repo_type, revision, create_pr, num_threads, parent_commit, run_as_future)\u001b[0m\n\u001b[1;32m   3675\u001b[0m         \u001b[0m_warn_on_overwriting_operations\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moperations\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3676\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3677\u001b[0;31m         self.preupload_lfs_files(\n\u001b[0m\u001b[1;32m   3678\u001b[0m             \u001b[0mrepo_id\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrepo_id\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3679\u001b[0m             \u001b[0madditions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0madditions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/huggingface_hub/hf_api.py\u001b[0m in \u001b[0;36mpreupload_lfs_files\u001b[0;34m(self, repo_id, additions, token, repo_type, revision, create_pr, num_threads, free_memory, gitignore_content)\u001b[0m\n\u001b[1;32m   4182\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   4183\u001b[0m         \u001b[0;31m# Upload new LFS files\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4184\u001b[0;31m         _upload_lfs_files(\n\u001b[0m\u001b[1;32m   4185\u001b[0m             \u001b[0madditions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnew_lfs_additions_to_upload\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   4186\u001b[0m             \u001b[0mrepo_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrepo_type\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_validators.py\u001b[0m in \u001b[0;36m_inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m             \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msmoothly_deprecate_use_auth_token\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_token\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhas_token\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    113\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0m_inner_fn\u001b[0m  \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/huggingface_hub/_commit_api.py\u001b[0m in \u001b[0;36m_upload_lfs_files\u001b[0;34m(additions, repo_type, repo_id, headers, endpoint, num_threads, revision)\u001b[0m\n\u001b[1;32m    359\u001b[0m     \u001b[0mbatch_actions\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mDict\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    360\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mchunk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mchunk_iterable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0madditions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchunk_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 361\u001b[0;31m         batch_actions_chunk, batch_errors_chunk = post_lfs_batch_info(\n\u001b[0m\u001b[1;32m    362\u001b[0m             \u001b[0mupload_infos\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupload_info\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mop\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mchunk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    363\u001b[0m             \u001b[0mrepo_id\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrepo_id\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_validators.py\u001b[0m in \u001b[0;36m_inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m             \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msmoothly_deprecate_use_auth_token\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_token\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhas_token\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    113\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0m_inner_fn\u001b[0m  \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/huggingface_hub/lfs.py\u001b[0m in \u001b[0;36mpost_lfs_batch_info\u001b[0;34m(upload_infos, token, repo_type, repo_id, revision, endpoint, headers)\u001b[0m\n\u001b[1;32m    165\u001b[0m     }\n\u001b[1;32m    166\u001b[0m     \u001b[0mresp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_session\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpost\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_url\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mheaders\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mheaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjson\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpayload\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m     \u001b[0mhf_raise_for_status\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    168\u001b[0m     \u001b[0mbatch_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjson\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    169\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_errors.py\u001b[0m in \u001b[0;36mhf_raise_for_status\u001b[0;34m(response, endpoint_name)\u001b[0m\n\u001b[1;32m    365\u001b[0m                 \u001b[0;34m+\u001b[0m \u001b[0;34m\"make sure you have a token with the `write` role.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    366\u001b[0m             )\n\u001b[0;32m--> 367\u001b[0;31m             \u001b[0;32mraise\u001b[0m \u001b[0mHfHubHTTPError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresponse\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    368\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    369\u001b[0m         \u001b[0;31m# Convert `HTTPError` into a `HfHubHTTPError` to display request information\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mHfHubHTTPError\u001b[0m:  (Request ID: Root=1-668d4a5a-564ed7ba76dfae4e5eb403b3;c12f819b-d5ae-4cd0-b490-5b7bb0a84619)\n\n403 Forbidden: Authorization error..\nCannot access content at: https://huggingface.co/gioandrade/llama1_TinyShakespeare.git/info/lfs/objects/batch.\nIf you are trying to create or update content,make sure you have a token with the `write` role."]}]},{"cell_type":"code","source":[],"metadata":{"id":"Q37l6dHG6ft-"},"execution_count":null,"outputs":[]}]}