import sys
import os
from multiprocessing import Pool

import torch
from torch.utils.data import DataLoader, Dataset as TorchDataset
from datasets import Dataset, DatasetDict, load_from_disk
from tqdm import tqdm
import lightning.pytorch as pl

sys.path.append('/home/yz927/projects/peptune/scripts/')
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

global_tokenizer = None


def init_pool(tokenizer):
    # Share the tokenizer with each Pool worker once (via initargs) instead of
    # re-pickling it with every sequence passed to imap.
    global global_tokenizer
    global_tokenizer = tokenizer


class SequenceDataset:
    def __init__(self, sequences, tokenizer, max_sequence_length, num_cores=8):
        self.sequences = sequences
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.num_cores = num_cores
        self.tokenized_sequences = []
        self.original_sequences = []

    def tokenize_sequences(self):
        print(f"Starting parallel tokenization using {self.num_cores} cores")
        with Pool(processes=self.num_cores, initializer=init_pool, initargs=(self.tokenizer,)) as pool:
            results = list(tqdm(
                pool.imap(standalone_tokenize_function, self.sequences),
                total=len(self.sequences)
            ))
        # Keep only sequences that tokenized successfully and fit within the length limit.
        for result, seq in zip(results, self.sequences):
            if result is not None and len(result['input_ids'][0]) <= self.max_sequence_length:
                self.tokenized_sequences.append(result)
                self.original_sequences.append(seq)
    def process_sequences(self, batch_size):
        self.tokenize_sequences()
        # Sort by token count so each batch groups similarly sized sequences and
        # padding overhead stays small.
        lengths = [(len(seq['input_ids'][0]), i) for i, seq in enumerate(self.tokenized_sequences)]
        lengths.sort()
        batches = []
        sequence_batches = []
        current_batch = []
        current_sequence_batch = []
        current_length = 0
        for length, idx in tqdm(lengths):
            # Close the current batch once the token budget or the batch size is reached.
            if current_length + length > self.max_sequence_length or len(current_batch) == batch_size:
                if current_batch:
                    batches.append([self.tokenized_sequences[i] for i in current_batch])
                    sequence_batches.append([self.original_sequences[i] for i in current_batch])
                current_batch = [idx]
                current_sequence_batch = [self.original_sequences[idx]]
                current_length = length
            else:
                current_batch.append(idx)
                current_sequence_batch.append(self.original_sequences[idx])
                current_length += length
        if current_batch:
            batches.append([self.tokenized_sequences[i] for i in current_batch])
            sequence_batches.append([self.original_sequences[i] for i in current_batch])
        # Pad each pre-formed batch and store it as one row of the resulting Dataset.
        token_batch_fn = TokenizeBatch(self.tokenizer)
        processed_batches = [token_batch_fn(batch) for batch in tqdm(batches)]
        dataset = Dataset.from_dict({
            'attention_mask': [batch['attention_mask'] for batch in processed_batches],
            'input_ids': [batch['input_ids'] for batch in processed_batches],
            'labels': sequence_batches
        })
        return dataset
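
# Layout of the Dataset returned by SequenceDataset.process_sequences (one row per
# pre-formed batch; shapes are illustrative, not taken from a real run):
#   row['input_ids']      -> [num_seqs_in_batch, padded_len] token ids
#   row['attention_mask'] -> [num_seqs_in_batch, padded_len], 1 for tokens, 0 for padding
#   row['labels']         -> the original (untokenized) sequences in that batch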


class DynamicBatchingDataset(TorchDataset):
    """
    Wraps the dynamically batched Hugging Face Dataset produced above. Each row
    already contains a fully padded batch, so the DataLoader reads one row at a
    time (batch_size=1) and collate_fn unwraps it.
    """
    def __init__(self, dataset_dict):
        print('Initializing dataset...')
        self.dataset_dict = {
            'attention_mask': [torch.tensor(item) for item in dataset_dict['attention_mask']],
            'input_ids': [torch.tensor(item) for item in dataset_dict['input_ids']],
            'labels': dataset_dict['labels']  # keep the original sequences as-is
        }

    def __len__(self):
        return len(self.dataset_dict['attention_mask'])

    def __getitem__(self, idx):
        if isinstance(idx, int):
            return {
                'attention_mask': self.dataset_dict['attention_mask'][idx],
                'input_ids': self.dataset_dict['input_ids'][idx],
                'labels': self.dataset_dict['labels'][idx]
            }
        elif isinstance(idx, list):
            return {
                'attention_mask': [self.dataset_dict['attention_mask'][i] for i in idx],
                'input_ids': [self.dataset_dict['input_ids'][i] for i in idx],
                'labels': [self.dataset_dict['labels'][i] for i in idx]
            }
        else:
            raise ValueError(f"Expected idx to be int or list, but got {type(idx)}")

    @staticmethod
    def collate_fn(batch, verbose=False):
        # With batch_size=1, `batch` is a singleton list holding one pre-padded batch.
        item = batch[0]
        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'labels': item['labels']
        }


def standalone_tokenize_function(sequence):
    global global_tokenizer
    try:
        tokens = global_tokenizer(sequence)
        # The tokenizer already returns lists of integers, so wrap them in another
        # list to match the expected [batch_size, sequence_length] format.
        return {
            'input_ids': [tokens['input_ids']],
            'attention_mask': [tokens['attention_mask']]
        }
    except Exception as e:
        print(f"Error tokenizing sequence '{sequence}': {e}")
        return None


class TokenizeBatch:
    """Pads a pre-formed batch of tokenized sequences and builds the attention masks."""
    def __init__(self, tokenizer):
        self.pad_token_id = tokenizer.pad_token_id

    def __call__(self, batches):
        data_tokens = [torch.tensor(batch['input_ids'][0]) for batch in batches]
        data_tokens_padded = torch.nn.utils.rnn.pad_sequence(
            data_tokens, batch_first=True, padding_value=self.pad_token_id
        )
        attention_masks = (data_tokens_padded != self.pad_token_id).long()
        return {
            'input_ids': data_tokens_padded,
            'attention_mask': attention_masks,
        }
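
# Example of TokenizeBatch padding (hypothetical token ids, assuming pad_token_id == 0):
#   TokenizeBatch(tokenizer)([{'input_ids': [[5, 6, 7]]}, {'input_ids': [[8, 9]]}])
#   -> input_ids:      [[5, 6, 7], [8, 9, 0]]
#      attention_mask: [[1, 1, 1], [1, 1, 0]]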


class PretrainSequenceDataModule(pl.LightningDataModule):
    def __init__(self,
                 tokenizer,
                 input_dataset_path,
                 output_dataset_path,
                 num_workers,
                 batch_size,
                 max_sequence_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.input_path = input_dataset_path
        self.output_path = output_dataset_path
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.max_sequence_length = max_sequence_length

    def prepare_data(self):
        # Tokenize and batch the raw text files only once; later runs reuse the saved dataset.
        if not os.path.exists(self.output_path):
            print("Loading text files")
            with open(f"{self.input_path}/train.txt", 'r') as f:
                train_sequences = [line.strip() for line in f if line.strip()]
            with open(f"{self.input_path}/val.txt", 'r') as f:
                val_sequences = [line.strip() for line in f if line.strip()]
            print("Processing training data")
            train_dataset = SequenceDataset(train_sequences,
                                            self.tokenizer,
                                            self.max_sequence_length)
            print("Processing validation data")
            val_dataset = SequenceDataset(val_sequences,
                                          self.tokenizer,
                                          self.max_sequence_length)
            processed_train = train_dataset.process_sequences(self.batch_size)
            processed_val = val_dataset.process_sequences(self.batch_size)
            print("Combining datasets")
            combined_dataset = DatasetDict({
                'train': processed_train,
                'val': processed_val,
            })
            print(f"Saving dataset to {self.output_path}")
            combined_dataset.save_to_disk(self.output_path)

    def setup(self, stage: str):
        print("Loading processed dataset")
        dataset = load_from_disk(self.output_path)
        self.train_dataset = DynamicBatchingDataset(dataset['train'])
        self.val_dataset = DynamicBatchingDataset(dataset['val'])

    def train_dataloader(self):
        print("Creating training dataloader")
        # batch_size=1 because every dataset row is already a full, padded batch.
        return DataLoader(self.train_dataset,
                          batch_size=1,
                          shuffle=False,
                          num_workers=self.num_workers,
                          collate_fn=DynamicBatchingDataset.collate_fn,
                          pin_memory=True)

    def val_dataloader(self):
        print("Creating validation dataloader")
        return DataLoader(self.val_dataset,
                          batch_size=1,
                          shuffle=False,
                          num_workers=self.num_workers,
                          collate_fn=DynamicBatchingDataset.collate_fn,
                          pin_memory=True)


if __name__ == '__main__':
    tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
                                     '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
    dm = PretrainSequenceDataModule(
        tokenizer=tokenizer,
        input_dataset_path='/home/yz927/projects/peptune/tokens/11M_smiles',
        output_dataset_path='/home/yz927/projects/peptune/tokenized/11M_smiles_old_tokenizer_no_limit',
        num_workers=8,
        batch_size=2000,
        max_sequence_length=16 * 1000,
    )
    dm.prepare_data()
    dm.setup('fit')
    dm.train_dataloader()
    dm.val_dataloader()
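
    # Optional sanity check (a minimal sketch): pull one pre-padded batch and
    # confirm the tensors and labels line up as expected.
    first_batch = next(iter(dm.train_dataloader()))
    print("input_ids:", first_batch['input_ids'].shape,
          "attention_mask:", first_batch['attention_mask'].shape,
          "labels:", len(first_batch['labels']))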