"""
PyTorch Lightning DataModule for LLaVA dataset
"""
import lightning as L
import torch
from torch.utils.data import DataLoader, random_split
from typing import Optional, Dict, Any
import logging
from .dataset import LLaVADataset, MultimodalCollator
logger = logging.getLogger(__name__)
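
# Illustrative shape of the `config` dict this module reads. The keys below
# are exactly the ones accessed in __init__(); the values are placeholder
# defaults, not authoritative settings:
#
#   config = {
#       "training": {"batch_size": 8},
#       "data": {
#           "num_workers": 4,
#           "pin_memory": True,
#           "persistent_workers": True,
#           "train_split": "train",
#           "val_split": "train",  # LLaVA has no separate validation split
#           "val_size": 0.02,      # fraction of train held out for validation
#       },
#   }
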
class LLaVADataModule(L.LightningDataModule):
    """Lightning DataModule for the LLaVA dataset."""

    def __init__(
        self,
        tokenizer,
        vision_processor,
        config: Dict[str, Any],
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.vision_processor = vision_processor
        self.config = config

        # Data configuration
        data_config = config["data"]
        self.batch_size = config["training"]["batch_size"]
        self.num_workers = data_config.get("num_workers", 4)
        self.pin_memory = data_config.get("pin_memory", True)
        self.persistent_workers = data_config.get("persistent_workers", True)

        # Dataset splits. LLaVA ships no separate validation split, so the
        # validation set is carved out of the training split in setup().
        self.train_split = data_config.get("train_split", "train")
        self.val_split = data_config.get("val_split", "train")
        self.val_size = data_config.get("val_size", 0.02)

        # Datasets are created lazily in setup()
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        # Collator that tokenizes text and preprocesses images per batch
        self.collator = MultimodalCollator(
            tokenizer=self.tokenizer,
            vision_processor=self.vision_processor,
            config=self.config,
        )
        logger.info("LLaVADataModule initialized")

    def prepare_data(self) -> None:
        """Download and cache the dataset (called only on rank 0)."""
        # Instantiate the dataset once so it is downloaded to the local cache.
        # No state is assigned here, because prepare_data() does not run on
        # every process in distributed training.
        try:
            LLaVADataset(
                config=self.config,
                split=self.train_split,
            )
            logger.info("Dataset preparation completed")
        except Exception as e:
            logger.error(f"Failed to prepare dataset: {e}")
            raise

    def setup(self, stage: Optional[str] = None) -> None:
        """Set up datasets for training/validation/testing."""
        if stage == "fit" or stage is None:
            # Load the full training dataset, then hold out a fraction
            # (val_size) of it for validation.
            full_dataset = LLaVADataset(
                config=self.config,
                split=self.train_split,
            )
            total_size = len(full_dataset)
            val_size = int(total_size * self.val_size)
            train_size = total_size - val_size
            self.train_dataset, self.val_dataset = random_split(
                full_dataset,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(42),  # reproducible split
            )
            logger.info(f"Dataset split: {train_size} train, {val_size} validation")

        if stage == "test":
            # LLaVA has no dedicated test split, so reuse the training split
            self.test_dataset = LLaVADataset(
                config=self.config,
                split=self.train_split,
            )

        if stage == "predict":
            # Prediction data is set up on demand (see predict_dataloader)
            pass

    def train_dataloader(self) -> DataLoader:
        """Create the training dataloader."""
        if self.train_dataset is None:
            raise RuntimeError("Train dataset not initialized. Call setup() first.")
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers and self.num_workers > 0,
            collate_fn=self.collator,
            drop_last=True,  # drop incomplete batches for consistent training
        )

    def val_dataloader(self) -> DataLoader:
        """Create the validation dataloader."""
        if self.val_dataset is None:
            raise RuntimeError("Validation dataset not initialized. Call setup() first.")
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers and self.num_workers > 0,
            collate_fn=self.collator,
            drop_last=False,
        )

    def test_dataloader(self) -> DataLoader:
        """Create the test dataloader."""
        if self.test_dataset is None:
            raise RuntimeError("Test dataset not initialized. Call setup('test') first.")
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self.collator,
            drop_last=False,
        )

    def predict_dataloader(self) -> DataLoader:
        """Create the prediction dataloader."""
        # Can be specialized for specific prediction needs; for now the
        # validation dataloader is a reasonable default.
        return self.val_dataloader()

    def teardown(self, stage: Optional[str] = None) -> None:
        """Clean up after training/testing."""
        # random_split() returns Subset objects, so the underlying
        # LLaVADataset (and its optional get_stats()) lives at `.dataset`.
        if self.train_dataset is not None and hasattr(self.train_dataset.dataset, "get_stats"):
            stats = self.train_dataset.dataset.get_stats()
            logger.info(f"Training dataset stats: {stats}")
        if self.val_dataset is not None and hasattr(self.val_dataset.dataset, "get_stats"):
            stats = self.val_dataset.dataset.get_stats()
            logger.info(f"Validation dataset stats: {stats}")

    def get_dataset_info(self) -> Dict[str, Any]:
        """Return basic information about the loaded datasets."""
        info = {}
        if self.train_dataset is not None:
            info["train_size"] = len(self.train_dataset)
        if self.val_dataset is not None:
            info["val_size"] = len(self.val_dataset)
        info["batch_size"] = self.batch_size
        info["num_workers"] = self.num_workers
        return info
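
if __name__ == "__main__":
    # Minimal usage sketch, assuming Hugging Face `transformers` is installed.
    # The checkpoints and config values below are placeholders for
    # illustration; the real tokenizer and vision processor should match the
    # model being trained.
    from transformers import AutoImageProcessor, AutoTokenizer

    config = {
        "training": {"batch_size": 2},
        "data": {"num_workers": 0},
    }
    tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")  # placeholder
    vision_processor = AutoImageProcessor.from_pretrained(
        "openai/clip-vit-large-patch14"  # placeholder
    )

    dm = LLaVADataModule(tokenizer, vision_processor, config)
    dm.prepare_data()
    dm.setup("fit")
    print(dm.get_dataset_info())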