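"""Training entry point.

Fits the LightningModule on the train split while logging Benetech
string-prediction scores and prediction tables to Weights & Biases, and
checkpoints the underlying transformers models via save_pretrained.
"""
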
import os
import shutil

import pandas as pd
import pytorch_lightning as pl
import transformers
import wandb

from config import CONFIG
from data import (
    Batch,
    Split,
    build_dataloader,
    get_annotation_ground_truth_str_from_image_index,
    load_train_image_ids,
)
from metrics import benetech_score_string_prediction
from model import LightningModule, generate_token_strings
from utils import set_tokenizers_parallelism, set_torch_device_order_pci_bus


class MetricsCallback(pl.callbacks.Callback):
    def on_validation_batch_start(
        self, trainer, pl_module, batch: Batch, batch_idx, dataloader_idx=0
    ):
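        """Generate predictions for the batch and log Benetech scores to W&B."""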
        predicted_strings = generate_token_strings(pl_module.model, images=batch.images)
        for expected_data_index, predicted_string in zip(
            batch.data_indices, predicted_strings, strict=True
        ):
            benetech_score = benetech_score_string_prediction(
                expected_data_index=expected_data_index,
                predicted_string=predicted_string,
            )
            wandb.log(dict(benetech_score=benetech_score))
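        # Also log a W&B table pairing each ground-truth string with its prediction.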
        ground_truth_strings = [
            get_annotation_ground_truth_str_from_image_index(i)
            for i in batch.data_indices
        ]
        string_ids = [load_train_image_ids()[i] for i in batch.data_indices]
        strings_dataframe = pd.DataFrame(
            dict(
                string_ids=string_ids,
                ground_truth=ground_truth_strings,
                predicted=predicted_strings,
            )
        )
        wandb.log(dict(strings=wandb.Table(dataframe=strings_dataframe)))


class TransformersPreTrainedModelsCheckpointIO(pl.plugins.CheckpointIO):
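    """Checkpoint plugin that saves and loads models in the transformers format.

    Each wrapped object (anything exposing save_pretrained / from_pretrained,
    such as a processor or an encoder-decoder model) is written to the
    checkpoint path instead of a single Lightning .ckpt file.
    """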

    def __init__(
        self, pretrained_models: list[transformers.modeling_utils.PreTrainedModel]
    ):
        super().__init__()
        self.pretrained_models = pretrained_models

    def save_checkpoint(self, checkpoint, path, storage_options=None):
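        # Every wrapped model writes its files into the same checkpoint
        # directory; Lightning's own checkpoint dict is not persisted here.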
        for pretrained_model in self.pretrained_models:
            pretrained_model.save_pretrained(path)

    def load_checkpoint(self, path, map_location=None):
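        # Rebuild each wrapped model from the saved directory in place. Unlike
        # the base CheckpointIO contract, no checkpoint dict is returned, so
        # only the transformers weights are restored.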
        self.pretrained_models = [
            pm.from_pretrained(path) for pm in self.pretrained_models
        ]

    def remove_checkpoint(self, path):
        # save_pretrained writes a directory, not a single file, so the whole
        # tree must be removed when Lightning rotates checkpoints.
        shutil.rmtree(path)


def train():
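    """Build the module, callbacks, logger, and trainer, then run fit()."""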
    set_tokenizers_parallelism(False)
    set_torch_device_order_pci_bus()
    pl_module = LightningModule(CONFIG)
    model_checkpoint = pl.callbacks.ModelCheckpoint(
        dirpath=CONFIG.training_directory,
        monitor="val_loss",
        save_top_k=CONFIG.save_top_k_checkpoints,
    )
    metrics_callback = MetricsCallback()
    logger = pl.loggers.WandbLogger(
        project=CONFIG.wandb_project_name, save_dir=CONFIG.training_directory
    )
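    # Persist the processor and encoder-decoder in the transformers format
    # rather than Lightning's default single-file checkpoint.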
    plugin = TransformersPreTrainedModelsCheckpointIO(
        [pl_module.model.processor, pl_module.model.encoder_decoder]
    )
    trainer = pl.Trainer(
        accelerator=CONFIG.accelerator,
        devices=CONFIG.devices,
        plugins=[plugin],
        callbacks=[model_checkpoint, metrics_callback],
        logger=logger,
        limit_train_batches=CONFIG.limit_train_batches,
        limit_val_batches=CONFIG.limit_val_batches,
    )
    trainer.fit(
        model=pl_module,
        train_dataloaders=build_dataloader(
            Split.train, pl_module.model.batch_collate_function
        ),
        val_dataloaders=build_dataloader(
            Split.val, pl_module.model.batch_collate_function
        ),
    )


if __name__ == "__main__":
    train()