"""
PyTorch Lightning DataModule for LLaVA dataset
"""
import lightning as L
import torch
from torch.utils.data import DataLoader, random_split
from typing import Optional, Dict, Any
import logging
from .dataset import LLaVADataset, MultimodalCollator
logger = logging.getLogger(__name__)
class LLaVADataModule(L.LightningDataModule):
"""Lightning DataModule for LLaVA dataset"""
def __init__(
self,
tokenizer,
vision_processor,
config: Dict[str, Any]
):
super().__init__()
self.tokenizer = tokenizer
self.vision_processor = vision_processor
self.config = config
# Data configuration
data_config = config["data"]
self.batch_size = config["training"]["batch_size"]
self.num_workers = data_config.get("num_workers", 4)
self.pin_memory = data_config.get("pin_memory", True)
self.persistent_workers = data_config.get("persistent_workers", True)
# Dataset splits
self.train_split = data_config.get("train_split", "train")
        # LLaVA ships no separate validation split, so val_split defaults to
        # "train" and a validation subset is carved out of it in setup()
        self.val_split = data_config.get("val_split", "train")
        self.val_size = data_config.get("val_size", 0.02)  # fraction held out for validation
# Initialize datasets to None; they are populated in setup()
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
# Create collator
self.collator = MultimodalCollator(
tokenizer=self.tokenizer,
vision_processor=self.vision_processor,
config=self.config
)
logger.info("LLaVADataModule initialized")
def prepare_data(self) -> None:
"""Download and prepare data (called only on rank 0)"""
# This will download the dataset if not already cached
try:
LLaVADataset(
config=self.config,
split=self.train_split
)
logger.info("Dataset preparation completed")
except Exception as e:
logger.error(f"Failed to prepare dataset: {e}")
raise
def setup(self, stage: Optional[str] = None) -> None:
"""Setup datasets for training/validation/testing"""
if stage == "fit" or stage is None:
# Load full training dataset
full_dataset = LLaVADataset(
config=self.config,
split=self.train_split
)
# Split into train and validation (hold out at least one sample)
            total_size = len(full_dataset)
            val_size = max(1, int(total_size * self.val_size))
            train_size = total_size - val_size
self.train_dataset, self.val_dataset = random_split(
full_dataset,
[train_size, val_size],
generator=torch.Generator().manual_seed(42) # For reproducibility
)
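            # random_split returns Subset views; the underlying LLaVADataset
            # remains reachable via .dataset (used in teardown() below)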
logger.info(f"Dataset split: {train_size} train, {val_size} validation")
if stage == "test":
# For testing, we'll use a small subset of the training data
self.test_dataset = LLaVADataset(
config=self.config,
split=self.train_split
)
if stage == "predict":
# For prediction, setup can be done dynamically
pass
def train_dataloader(self) -> DataLoader:
"""Create training dataloader"""
if self.train_dataset is None:
raise RuntimeError("Train dataset not initialized. Call setup() first.")
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers and self.num_workers > 0,
collate_fn=self.collator,
drop_last=True # Drop incomplete batches for consistent training
)
def val_dataloader(self) -> DataLoader:
"""Create validation dataloader"""
if self.val_dataset is None:
raise RuntimeError("Validation dataset not initialized. Call setup() first.")
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers and self.num_workers > 0,
collate_fn=self.collator,
drop_last=False
)
def test_dataloader(self) -> DataLoader:
"""Create test dataloader"""
if self.test_dataset is None:
            raise RuntimeError("Test dataset not initialized. Call setup() first.")
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers and self.num_workers > 0,
            collate_fn=self.collator,
            drop_last=False
        )
def predict_dataloader(self) -> DataLoader:
"""Create prediction dataloader"""
        # Default to the validation dataloader; override for custom prediction inputs
return self.val_dataloader()
def teardown(self, stage: Optional[str] = None) -> None:
"""Clean up after training/testing"""
# Log dataset statistics if the underlying LLaVADataset exposes them;
        # train/val are Subset wrappers, so reach through .dataset
        if self.train_dataset is not None and hasattr(self.train_dataset.dataset, 'get_stats'):
            stats = self.train_dataset.dataset.get_stats()
            logger.info(f"Training dataset stats: {stats}")
        if self.val_dataset is not None and hasattr(self.val_dataset.dataset, 'get_stats'):
            stats = self.val_dataset.dataset.get_stats()
            logger.info(f"Validation dataset stats: {stats}")
def get_dataset_info(self) -> Dict[str, Any]:
"""Get information about the loaded datasets"""
info = {}
if self.train_dataset is not None:
info["train_size"] = len(self.train_dataset)
if self.val_dataset is not None:
info["val_size"] = len(self.val_dataset)
info["batch_size"] = self.batch_size
info["num_workers"] = self.num_workers
return info
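
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module's API): the tokenizer and
# vision_processor are assumed to be Hugging Face objects, and `load_config`
# is a hypothetical helper standing in for however this project loads its
# config files; the "model" keys below are placeholders, not pinned here.
#
#   from transformers import AutoTokenizer, AutoImageProcessor
#
#   config = load_config("configs/llava.yaml")   # hypothetical helper
#   tokenizer = AutoTokenizer.from_pretrained(config["model"]["language_model"])
#   vision_processor = AutoImageProcessor.from_pretrained(config["model"]["vision_tower"])
#
#   dm = LLaVADataModule(tokenizer, vision_processor, config)
#   dm.prepare_data()
#   dm.setup("fit")
#   batch = next(iter(dm.train_dataloader()))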