# Training image: PyTorch CUDA runtime base, project sources copied to /app,
# and the data/openwebtext/prepare.py script executed at build time.
# Default command launches train.py.

# Use CUDA base image for GPU support
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime

# Set working directory (WORKDIR creates it if it does not exist)
WORKDIR /app

# Install system dependencies.
# update + install + list cleanup live in ONE layer so the apt index is never
# stale and never shipped in the image. --no-install-recommends keeps optional
# packages out; ca-certificates is listed explicitly because it is only a
# "recommended" package of git and HTTPS remotes need it.
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        ca-certificates \
        git \
    && rm -rf /var/lib/apt/lists/*

# Copy only the requirements manifest first so the dependency layers below
# stay cached until requirements.txt itself changes.
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt
# NOTE(review): these three packages are unpinned, so rebuilds may pick up
# different versions — consider pinning them (ideally in requirements.txt).
RUN pip install --no-cache-dir wandb tiktoken datasets

# Create output directory with proper permissions.
# NOTE(review): 777 makes /app/out world-writable — presumably so the container
# can be started with an arbitrary --user; tighten if that is not required.
RUN mkdir -p /app/out && chmod 777 /app/out

# Copy the project files (use a .dockerignore to keep .git, caches and any
# local secrets out of the build context)
COPY . .

# Prepare the OpenWebText dataset.
# WORKDIR is used instead of `RUN cd ... && ...` so the directory change is
# explicit in the build recipe.
WORKDIR /app/data/openwebtext
RUN python prepare.py

# Return to the project root so the default command resolves train.py
WORKDIR /app

# Command to run training (exec form, so python receives signals directly).
# NOTE(review): --wandb_log=True presumably enables Weights & Biases logging;
# supply credentials at run time (e.g. `docker run -e WANDB_API_KEY=...`),
# never bake them into the image.
CMD ["python", "train.py", "--wandb_log=True"]