"""Dynamic-batching data pipeline for pretraining on SMILES sequences.

Tokenizes raw sequences in parallel, packs them into length-sorted padded
batches, and serves each pre-built batch through a Lightning DataModule.
"""
import os
import sys
from multiprocessing import Pool

import torch
import lightning.pytorch as pl
from datasets import Dataset, DatasetDict, load_from_disk
from torch.utils.data import DataLoader, Dataset as TorchDataset
from tqdm import tqdm

sys.path.append('/home/yz927/projects/peptune/scripts/')
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

global_tokenizer = None
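

# init_pool runs once in each worker process and stores the tokenizer in the
# module-level global above, so standalone_tokenize_function can reach it
# without the tokenizer being shipped along with every sequence.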
def init_pool(tokenizer):
    global global_tokenizer
    global_tokenizer = tokenizer


class SequenceDataset:
    """Tokenizes raw sequences in parallel and packs them into padded batches."""

    def __init__(self, sequences, tokenizer, max_sequence_length, num_cores=8):
        self.sequences = sequences
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.num_cores = num_cores
        self.tokenized_sequences = []
        self.original_sequences = []

    def tokenize_sequences(self):
        print(f"Starting parallel tokenization using {self.num_cores} cores")
        with Pool(processes=self.num_cores, initializer=init_pool, initargs=(self.tokenizer,)) as pool:
            results = list(tqdm(
                pool.imap(standalone_tokenize_function, self.sequences),
                total=len(self.sequences)
            ))

        # Keep only sequences that tokenized successfully and fit the length limit.
        for result, seq in zip(results, self.sequences):
            if result is not None and len(result['input_ids'][0]) <= self.max_sequence_length:
                self.tokenized_sequences.append(result)
                self.original_sequences.append(seq)

    def process_sequences(self, batch_size):
        self.tokenize_sequences()

        # Sort by tokenized length so each batch groups sequences of similar size.
        lengths = [(len(seq['input_ids'][0]), i) for i, seq in enumerate(self.tokenized_sequences)]
        lengths.sort()

        batches = []
        sequence_batches = []
        current_batch = []
        current_sequence_batch = []
        current_length = 0

        # Greedily fill each batch until the token budget (max_sequence_length) or
        # the per-batch sequence count (batch_size) would be exceeded.
        for length, idx in tqdm(lengths):
            if current_length + length > self.max_sequence_length or len(current_batch) == batch_size:
                if current_batch:
                    batches.append([self.tokenized_sequences[i] for i in current_batch])
                    sequence_batches.append([self.original_sequences[i] for i in current_batch])
                current_batch = [idx]
                current_sequence_batch = [self.original_sequences[idx]]
                current_length = length
            else:
                current_batch.append(idx)
                current_sequence_batch.append(self.original_sequences[idx])
                current_length += length

        if current_batch:
            batches.append([self.tokenized_sequences[i] for i in current_batch])
            sequence_batches.append([self.original_sequences[i] for i in current_batch])

        # Pad every batch and store each one as a single row of the resulting Dataset.
        token_batch_fn = TokenizeBatch(self.tokenizer)
        processed_batches = [token_batch_fn(batch) for batch in tqdm(batches)]

        dataset = Dataset.from_dict({
            'attention_mask': [batch['attention_mask'] for batch in processed_batches],
            'input_ids': [batch['input_ids'] for batch in processed_batches],
            'labels': sequence_batches
        })

        return dataset


class DynamicBatchingDataset(TorchDataset):
    """
    Wraps a dynamically batched Hugging Face Dataset. Each row of the underlying
    Dataset is already a fully padded batch, so this class only needs to convert
    rows to tensors and pass them through (see the usage sketch after this class).
    """

    def __init__(self, dataset_dict):
        print('Initializing dataset...')
        self.dataset_dict = {
            'attention_mask': [torch.tensor(item) for item in dataset_dict['attention_mask']],
            'input_ids': [torch.tensor(item) for item in dataset_dict['input_ids']],
            'labels': dataset_dict['labels']
        }

    def __len__(self):
        return len(self.dataset_dict['attention_mask'])

    def __getitem__(self, idx):
        if isinstance(idx, int):
            return {
                'attention_mask': self.dataset_dict['attention_mask'][idx],
                'input_ids': self.dataset_dict['input_ids'][idx],
                'labels': self.dataset_dict['labels'][idx]
            }
        elif isinstance(idx, list):
            return {
                'attention_mask': [self.dataset_dict['attention_mask'][i] for i in idx],
                'input_ids': [self.dataset_dict['input_ids'][i] for i in idx],
                'labels': [self.dataset_dict['labels'][i] for i in idx]
            }
        else:
            raise ValueError(f"Expected idx to be int or list, but got {type(idx)}")

    @staticmethod
    def collate_fn(batch, verbose=False):
        # The DataLoader uses batch_size=1, so `batch` holds exactly one pre-built batch.
        item = batch[0]
        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'labels': item['labels']
        }
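

# Usage sketch: every row of the saved dataset is already a padded batch, so the
# DataLoader runs with batch_size=1 and collate_fn simply unwraps the single row.
# (`path` below is a placeholder for the processed dataset directory.)
#
#     ds = DynamicBatchingDataset(load_from_disk(path)['train'])
#     loader = DataLoader(ds, batch_size=1, collate_fn=DynamicBatchingDataset.collate_fn)
#     batch = next(iter(loader))   # batch['input_ids']: (n_sequences, padded_length)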


def standalone_tokenize_function(sequence):
    # Runs inside a Pool worker; global_tokenizer was set by init_pool.
    global global_tokenizer
    try:
        tokens = global_tokenizer(sequence)
        return {
            'input_ids': [tokens['input_ids']],
            'attention_mask': [tokens['attention_mask']]
        }
    except Exception as e:
        print(f"Error tokenizing sequence '{sequence}': {e}")
        return None


class TokenizeBatch:
    """Pads a list of tokenized sequences into a single batch of tensors."""

    def __init__(self, tokenizer):
        self.pad_token_id = tokenizer.pad_token_id

    def __call__(self, batches):
        data_tokens = [torch.tensor(batch['input_ids'][0]) for batch in batches]
        data_tokens_padded = torch.nn.utils.rnn.pad_sequence(data_tokens, batch_first=True, padding_value=self.pad_token_id)
        # Mask is 1 for real tokens and 0 for padding positions.
        attention_masks = (data_tokens_padded != self.pad_token_id).long()

        return {
            'input_ids': data_tokens_padded,
            'attention_mask': attention_masks,
        }


class PretrainSequenceDataModule(pl.LightningDataModule):
    def __init__(self,
                 tokenizer,
                 input_dataset_path,
                 output_dataset_path,
                 num_workers,
                 batch_size,
                 max_sequence_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.input_path = input_dataset_path
        self.output_path = output_dataset_path
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.max_sequence_length = max_sequence_length

    def prepare_data(self):
        # Tokenize and batch the raw text files only if no processed copy exists yet.
        if not os.path.exists(self.output_path):
            print("Loading text files")
            with open(f"{self.input_path}/train.txt", 'r') as f:
                train_sequences = [line.strip() for line in f if line.strip()]
            with open(f"{self.input_path}/val.txt", 'r') as f:
                val_sequences = [line.strip() for line in f if line.strip()]

            print("Processing training data")
            train_dataset = SequenceDataset(train_sequences,
                                            self.tokenizer,
                                            self.max_sequence_length)
            print("Processing validation data")
            val_dataset = SequenceDataset(val_sequences,
                                          self.tokenizer,
                                          self.max_sequence_length)

            processed_train = train_dataset.process_sequences(self.batch_size)
            processed_val = val_dataset.process_sequences(self.batch_size)

            print("Combining datasets")
            combined_dataset = DatasetDict({
                'train': processed_train,
                'val': processed_val,
            })

            print(f"Saving dataset to {self.output_path}")
            combined_dataset.save_to_disk(self.output_path)

    def setup(self, stage: str):
        print("Loading processed dataset")
        dataset = load_from_disk(self.output_path)
        self.train_dataset = DynamicBatchingDataset(dataset['train'])
        self.val_dataset = DynamicBatchingDataset(dataset['val'])

    def train_dataloader(self):
        print("Creating training dataloader")
        # batch_size=1 because each dataset row is already a full, padded batch.
        return DataLoader(self.train_dataset,
                          batch_size=1,
                          shuffle=False,
                          num_workers=self.num_workers,
                          collate_fn=DynamicBatchingDataset.collate_fn,
                          pin_memory=True)

    def val_dataloader(self):
        print("Creating validation dataloader")
        return DataLoader(self.val_dataset,
                          batch_size=1,
                          shuffle=False,
                          num_workers=self.num_workers,
                          collate_fn=DynamicBatchingDataset.collate_fn,
                          pin_memory=True)


if __name__ == '__main__':
    tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
                                     '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
    dm = PretrainSequenceDataModule(
        tokenizer=tokenizer,
        input_dataset_path='/home/yz927/projects/peptune/tokens/11M_smiles',
        output_dataset_path='/home/yz927/projects/peptune/tokenized/11M_smiles_old_tokenizer_no_limit',
        num_workers=8,
        batch_size=2000,
        max_sequence_length=16 * 1000,
    )
    dm.prepare_data()
    dm.setup('fit')
    dm.train_dataloader()
    dm.val_dataloader()
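
    # Quick sanity check (sketch, not part of the original pipeline): pull one
    # pre-built batch and report its shape and the number of label strings.
    sample = next(iter(dm.train_dataloader()))
    print(sample['input_ids'].shape, sample['attention_mask'].shape, len(sample['labels']))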