| """ | |
| PyTorch Lightning DataModule for LLaVA dataset | |
| """ | |
| import lightning as L | |
| import torch | |
| from torch.utils.data import DataLoader, random_split | |
| from typing import Optional, Dict, Any | |
| import logging | |
| from .dataset import LLaVADataset, MultimodalCollator | |
| logger = logging.getLogger(__name__) | |


class LLaVADataModule(L.LightningDataModule):
    """Lightning DataModule for the LLaVA dataset."""

    def __init__(
        self,
        tokenizer,
        vision_processor,
        config: Dict[str, Any]
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.vision_processor = vision_processor
        self.config = config

        # Data configuration
        data_config = config["data"]
        self.batch_size = config["training"]["batch_size"]
        self.num_workers = data_config.get("num_workers", 4)
        self.pin_memory = data_config.get("pin_memory", True)
        self.persistent_workers = data_config.get("persistent_workers", True)

        # Dataset splits
        self.train_split = data_config.get("train_split", "train")
        self.val_split = data_config.get("val_split", "train")  # LLaVA has no separate val split
        self.val_size = data_config.get("val_size", 0.02)
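
        # Expected config layout -- a sketch inferred from the keys read above;
        # the authoritative schema lives in the project's config files:
        # {
        #     "training": {"batch_size": ...},
        #     "data": {
        #         "num_workers": 4,
        #         "pin_memory": True,
        #         "persistent_workers": True,
        #         "train_split": "train",
        #         "val_split": "train",
        #         "val_size": 0.02,
        #     },
        # }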

        # Initialize datasets to None (test_dataset included, so attribute
        # access is safe before setup("test") has run)
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        # Create collator
        self.collator = MultimodalCollator(
            tokenizer=self.tokenizer,
            vision_processor=self.vision_processor,
            config=self.config
        )

        logger.info("LLaVADataModule initialized")

    def prepare_data(self) -> None:
        """Download and prepare data (called only on rank 0)."""
        # Instantiating the dataset downloads it if not already cached
        try:
            LLaVADataset(
                config=self.config,
                split=self.train_split
            )
            logger.info("Dataset preparation completed")
        except Exception as e:
            logger.error(f"Failed to prepare dataset: {e}")
            raise

    def setup(self, stage: Optional[str] = None) -> None:
        """Set up datasets for training/validation/testing."""
        if stage == "fit" or stage is None:
            # Load the full training dataset
            full_dataset = LLaVADataset(
                config=self.config,
                split=self.train_split
            )

            # Split into train and validation
            total_size = len(full_dataset)
            val_size = int(total_size * self.val_size)
            train_size = total_size - val_size

            self.train_dataset, self.val_dataset = random_split(
                full_dataset,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(42)  # For reproducibility
            )
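            # random_split returns torch.utils.data.Subset objects; the
            # underlying LLaVADataset stays reachable via their `.dataset`
            # attribute (teardown() below relies on this for stats logging).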
| logger.info(f"Dataset split: {train_size} train, {val_size} validation") | |
| if stage == "test": | |
| # For testing, we'll use a small subset of the training data | |
| self.test_dataset = LLaVADataset( | |
| config=self.config, | |
| split=self.train_split | |
| ) | |
| if stage == "predict": | |
| # For prediction, setup can be done dynamically | |
| pass | |

    def train_dataloader(self) -> DataLoader:
        """Create the training dataloader."""
        if self.train_dataset is None:
            raise RuntimeError("Train dataset not initialized. Call setup() first.")

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers and self.num_workers > 0,
            collate_fn=self.collator,
            drop_last=True  # Drop incomplete batches for consistent training
        )

    def val_dataloader(self) -> DataLoader:
        """Create the validation dataloader."""
        if self.val_dataset is None:
            raise RuntimeError("Validation dataset not initialized. Call setup() first.")

        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers and self.num_workers > 0,
            collate_fn=self.collator,
            drop_last=False
        )

    def test_dataloader(self) -> DataLoader:
        """Create the test dataloader."""
        # test_dataset is initialized to None in __init__, so a plain None
        # check suffices here
        if self.test_dataset is None:
            raise RuntimeError("Test dataset not initialized. Call setup() first.")

        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self.collator,
            drop_last=False
        )

    def predict_dataloader(self) -> DataLoader:
        """Create the prediction dataloader (currently delegates to validation)."""
        # This can be reimplemented based on specific prediction needs
        return self.val_dataloader()

    def teardown(self, stage: Optional[str] = None) -> None:
        """Clean up after training/testing."""
        # Log dataset statistics if the underlying dataset exposes them.
        # After setup("fit") the train/val datasets are Subset objects, so
        # the original LLaVADataset is reached through `.dataset`; getattr
        # keeps this safe if a raw dataset was assigned instead.
        if self.train_dataset is not None:
            inner = getattr(self.train_dataset, "dataset", self.train_dataset)
            if hasattr(inner, "get_stats"):
                logger.info(f"Training dataset stats: {inner.get_stats()}")

        if self.val_dataset is not None:
            inner = getattr(self.val_dataset, "dataset", self.val_dataset)
            if hasattr(inner, "get_stats"):
                logger.info(f"Validation dataset stats: {inner.get_stats()}")

    def get_dataset_info(self) -> Dict[str, Any]:
        """Get information about the loaded datasets."""
        info = {}
        if self.train_dataset is not None:
            info["train_size"] = len(self.train_dataset)
        if self.val_dataset is not None:
            info["val_size"] = len(self.val_dataset)
        info["batch_size"] = self.batch_size
        info["num_workers"] = self.num_workers
        return info
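

# Minimal usage sketch. The YAML path and the "model" config keys below are
# placeholders, not defined by this module; only the LLaVADataModule calls
# mirror this file.
#
#     import yaml
#     from transformers import AutoTokenizer, AutoImageProcessor
#
#     with open("configs/train.yaml") as f:
#         config = yaml.safe_load(f)
#
#     tokenizer = AutoTokenizer.from_pretrained(config["model"]["llm_name"])
#     vision_processor = AutoImageProcessor.from_pretrained(config["model"]["vision_name"])
#
#     dm = LLaVADataModule(tokenizer, vision_processor, config)
#     dm.prepare_data()
#     dm.setup("fit")
#     batch = next(iter(dm.train_dataloader()))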