{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "4b23e3a6-544b-4e8c-b37c-70d8d5f04e46", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://download.pytorch.org/whl/cu121\n", "Collecting torch\n", " Downloading https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp312-cp312-linux_x86_64.whl (780.4 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m780.4/780.4 MB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:02\u001b[0m\n", "\u001b[?25hCollecting torchvision\n", " Downloading https://download.pytorch.org/whl/cu121/torchvision-0.20.1%2Bcu121-cp312-cp312-linux_x86_64.whl (7.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.3/7.3 MB\u001b[0m \u001b[31m14.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting torchaudio\n", " Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.5.1%2Bcu121-cp312-cp312-linux_x86_64.whl (3.4 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m15.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting filelock (from torch)\n", " Downloading https://download.pytorch.org/whl/filelock-3.13.1-py3-none-any.whl.metadata (2.8 kB)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.12/dist-packages (from torch) (4.14.0)\n", "Collecting networkx (from torch)\n", " Downloading https://download.pytorch.org/whl/networkx-3.3-py3-none-any.whl.metadata (5.1 kB)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch) (3.1.6)\n", "Collecting fsspec (from torch)\n", " Downloading https://download.pytorch.org/whl/fsspec-2024.6.1-py3-none-any.whl.metadata (11 kB)\n", "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from 
torch)\n", " Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m13.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m34.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m14.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting nvidia-cudnn-cu12==9.1.0.70 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu121/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m4.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:02\u001b[0m\n", "\u001b[?25hCollecting nvidia-cublas-cu12==12.1.3.1 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu121/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting nvidia-cufft-cu12==11.0.2.54 (from torch)\n", " Downloading 
https://download.pytorch.org/whl/cu121/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m8.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting nvidia-curand-cu12==10.3.2.106 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu121/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m10.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting nvidia-cusolver-cu12==11.4.5.107 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu121/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m9.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting nvidia-cusparse-cu12==12.1.0.106 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu121/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m7.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting nvidia-nccl-cu12==2.21.5 (from torch)\n", " Downloading https://download.pytorch.org/whl/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl (188.7 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m188.7/188.7 MB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting nvidia-nvtx-cu12==12.1.105 (from torch)\n", " Downloading 
https://download.pytorch.org/whl/cu121/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.1/99.1 kB\u001b[0m \u001b[31m11.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting triton==3.1.0 (from torch)\n", " Downloading https://download.pytorch.org/whl/triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.6/209.6 MB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hRequirement already satisfied: setuptools in /usr/lib/python3/dist-packages (from torch) (68.1.2)\n", "Collecting sympy==1.13.1 (from torch)\n", " Downloading https://download.pytorch.org/whl/sympy-1.13.1-py3-none-any.whl (6.2 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.2/6.2 MB\u001b[0m \u001b[31m17.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch)\n", " Downloading https://download.pytorch.org/whl/cu121/nvidia_nvjitlink_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (19.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m19.8/19.8 MB\u001b[0m \u001b[31m11.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting mpmath<1.4,>=1.1.0 (from sympy==1.13.1->torch)\n", " Downloading https://download.pytorch.org/whl/mpmath-1.3.0-py3-none-any.whl (536 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m536.2/536.2 kB\u001b[0m \u001b[31m13.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n", "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from torchvision) (2.3.1)\n", "Collecting 
pillow!=8.3.*,>=5.3.0 (from torchvision)\n", " Downloading https://download.pytorch.org/whl/pillow-11.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (9.1 kB)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch) (3.0.2)\n", "Downloading https://download.pytorch.org/whl/pillow-11.0.0-cp312-cp312-manylinux_2_28_x86_64.whl (4.4 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.4/4.4 MB\u001b[0m \u001b[31m13.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hDownloading https://download.pytorch.org/whl/filelock-3.13.1-py3-none-any.whl (11 kB)\n", "Downloading https://download.pytorch.org/whl/fsspec-2024.6.1-py3-none-any.whl (177 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m177.6/177.6 kB\u001b[0m \u001b[31m13.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading https://download.pytorch.org/whl/networkx-3.3-py3-none-any.whl (1.7 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m14.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hInstalling collected packages: mpmath, sympy, pillow, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, networkx, fsspec, filelock, triton, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch, torchvision, torchaudio\n", "Successfully installed filelock-3.13.1 fsspec-2024.6.1 mpmath-1.3.0 networkx-3.3 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 
nvidia-nccl-cu12-2.21.5 nvidia-nvjitlink-cu12-12.1.105 nvidia-nvtx-cu12-12.1.105 pillow-11.0.0 sympy-1.13.1 torch-2.5.1+cu121 torchaudio-2.5.1+cu121 torchvision-0.20.1+cu121 triton-3.1.0\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mCollecting diffusers\n", " Downloading diffusers-0.34.0-py3-none-any.whl.metadata (20 kB)\n", "Collecting transformers\n", " Downloading transformers-4.53.2-py3-none-any.whl.metadata (40 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.9/40.9 kB\u001b[0m \u001b[31m248.7 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hCollecting accelerate\n", " Downloading accelerate-1.9.0-py3-none-any.whl.metadata (19 kB)\n", "Collecting importlib_metadata (from diffusers)\n", " Downloading importlib_metadata-8.7.0-py3-none-any.whl.metadata (4.8 kB)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from diffusers) (3.13.1)\n", "Collecting huggingface-hub>=0.27.0 (from diffusers)\n", " Downloading huggingface_hub-0.33.4-py3-none-any.whl.metadata (14 kB)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from diffusers) (2.3.1)\n", "Collecting regex!=2019.12.17 (from diffusers)\n", " Downloading regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.5/40.5 kB\u001b[0m \u001b[31m311.1 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from diffusers) (2.32.4)\n", "Collecting safetensors>=0.3.1 (from diffusers)\n", " Downloading 
safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)\n", "Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from diffusers) (11.0.0)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (25.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers) (6.0.2)\n", "Collecting tokenizers<0.22,>=0.21 (from transformers)\n", " Downloading tokenizers-0.21.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)\n", "Collecting tqdm>=4.27 (from transformers)\n", " Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.7/57.7 kB\u001b[0m \u001b[31m252.1 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hRequirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from accelerate) (7.0.0)\n", "Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from accelerate) (2.5.1+cu121)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.27.0->diffusers) (2024.6.1)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.27.0->diffusers) (4.14.0)\n", "Collecting hf-xet<2.0.0,>=1.1.2 (from huggingface-hub>=0.27.0->diffusers)\n", " Downloading hf_xet-1.1.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (879 bytes)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.3)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.1.6)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in 
/usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.1.105)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.1.3.1)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (11.0.2.54)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (10.3.2.106)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (11.4.5.107)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.1.0.106)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.1.105)\n", "Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.1.0)\n", "Requirement already satisfied: setuptools in /usr/lib/python3/dist-packages (from torch>=2.0.0->accelerate) (68.1.2)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) 
(1.13.1)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.12/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=2.0.0->accelerate) (12.1.105)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate) (1.3.0)\n", "Collecting zipp>=3.20 (from importlib_metadata->diffusers)\n", " Downloading zipp-3.23.0-py3-none-any.whl.metadata (3.6 kB)\n", "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->diffusers) (3.4.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->diffusers) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->diffusers) (2.5.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->diffusers) (2025.6.15)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=2.0.0->accelerate) (3.0.2)\n", "Downloading diffusers-0.34.0-py3-none-any.whl (3.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.8/3.8 MB\u001b[0m \u001b[31m3.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m0m\n", "\u001b[?25hDownloading transformers-4.53.2-py3-none-any.whl (10.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.8/10.8 MB\u001b[0m \u001b[31m14.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m0:01\u001b[0m\n", "\u001b[?25hDownloading accelerate-1.9.0-py3-none-any.whl (367 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m367.1/367.1 kB\u001b[0m \u001b[31m27.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading huggingface_hub-0.33.4-py3-none-any.whl (515 kB)\n", 
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m515.3/515.3 kB\u001b[0m \u001b[31m28.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (796 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m796.9/796.9 kB\u001b[0m \u001b[31m30.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (471 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m471.6/471.6 kB\u001b[0m \u001b[31m30.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading tokenizers-0.21.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m36.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hDownloading tqdm-4.67.1-py3-none-any.whl (78 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.5/78.5 kB\u001b[0m \u001b[31m294.7 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hDownloading importlib_metadata-8.7.0-py3-none-any.whl (27 kB)\n", "Downloading hf_xet-1.1.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m35.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hDownloading zipp-3.23.0-py3-none-any.whl (10 kB)\n", "Installing collected packages: zipp, tqdm, safetensors, regex, hf-xet, importlib_metadata, huggingface-hub, tokenizers, diffusers, transformers, accelerate\n", "Successfully installed accelerate-1.9.0 diffusers-0.34.0 hf-xet-1.1.5 huggingface-hub-0.33.4 
importlib_metadata-8.7.0 regex-2024.11.6 safetensors-0.5.3 tokenizers-0.21.2 tqdm-4.67.1 transformers-4.53.2 zipp-3.23.0\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mCollecting datasets\n", " Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)\n", "Requirement already satisfied: pillow in /usr/local/lib/python3.12/dist-packages (11.0.0)\n", "Collecting matplotlib\n", " Downloading matplotlib-3.10.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (4.67.1)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from datasets) (3.13.1)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.3.1)\n", "Collecting pyarrow>=15.0.0 (from datasets)\n", " Downloading pyarrow-21.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.3 kB)\n", "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n", " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n", "Collecting pandas (from datasets)\n", " Downloading pandas-2.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (91 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m91.2/91.2 kB\u001b[0m \u001b[31m256.5 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hRequirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.32.4)\n", "Collecting xxhash (from datasets)\n", " Downloading xxhash-3.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n", "Collecting multiprocess<0.70.17 (from datasets)\n", " Downloading 
multiprocess-0.70.16-py312-none-any.whl.metadata (7.2 kB)\n", "Requirement already satisfied: fsspec<=2025.3.0,>=2023.1.0 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2024.6.1)\n", "Requirement already satisfied: huggingface-hub>=0.24.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.33.4)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from datasets) (25.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from datasets) (6.0.2)\n", "Collecting contourpy>=1.0.1 (from matplotlib)\n", " Downloading contourpy-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)\n", "Collecting cycler>=0.10 (from matplotlib)\n", " Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)\n", "Collecting fonttools>=4.22.0 (from matplotlib)\n", " Downloading fonttools-4.59.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl.metadata (107 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.9/107.9 kB\u001b[0m \u001b[31m681.9 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hCollecting kiwisolver>=1.3.1 (from matplotlib)\n", " Downloading kiwisolver-1.4.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.2 kB)\n", "Requirement already satisfied: pyparsing>=2.3.1 in /usr/lib/python3/dist-packages (from matplotlib) (3.1.1)\n", "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.12/dist-packages (from matplotlib) (2.9.0.post0)\n", "Collecting aiohttp!=4.0.0a0,!=4.0.0a1 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)\n", " Downloading aiohttp-3.12.14-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.6 kB)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from 
huggingface-hub>=0.24.0->datasets) (4.14.0)\n", "Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (1.1.5)\n", "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\n", "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.4.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2.5.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2025.6.15)\n", "Collecting pytz>=2020.1 (from pandas->datasets)\n", " Downloading pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)\n", "Collecting tzdata>=2022.7 (from pandas->datasets)\n", " Downloading tzdata-2025.2-py2.py3-none-any.whl.metadata (1.4 kB)\n", "Collecting aiohappyeyeballs>=2.5.0 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets)\n", " Downloading aiohappyeyeballs-2.6.1-py3-none-any.whl.metadata (5.9 kB)\n", "Collecting aiosignal>=1.4.0 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets)\n", " Downloading aiosignal-1.4.0-py3-none-any.whl.metadata (3.7 kB)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (25.3.0)\n", "Collecting frozenlist>=1.1.1 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets)\n", " Downloading frozenlist-1.7.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (18 kB)\n", "Collecting multidict<7.0,>=4.5 (from 
aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets)\n", " Downloading multidict-6.6.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (5.3 kB)\n", "Collecting propcache>=0.2.0 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets)\n", " Downloading propcache-0.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n", "Collecting yarl<2.0,>=1.17.0 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets)\n", " Downloading yarl-1.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (73 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m73.9/73.9 kB\u001b[0m \u001b[31m829.3 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hDownloading datasets-4.0.0-py3-none-any.whl (494 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m494.8/494.8 kB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0mm\n", "\u001b[?25hDownloading matplotlib-3.10.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.6/8.6 MB\u001b[0m \u001b[31m7.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hDownloading contourpy-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (323 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m323.7/323.7 kB\u001b[0m \u001b[31m15.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading cycler-0.12.1-py3-none-any.whl (8.3 kB)\n", "Downloading dill-0.3.8-py3-none-any.whl (116 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m10.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 
"\u001b[?25hDownloading fonttools-4.59.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl (4.9 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m17.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hDownloading kiwisolver-1.4.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.5 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.5/1.5 MB\u001b[0m \u001b[31m34.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n", "\u001b[?25hDownloading multiprocess-0.70.16-py312-none-any.whl (146 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m146.7/146.7 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading pyarrow-21.0.0-cp312-cp312-manylinux_2_28_x86_64.whl (42.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.8/42.8 MB\u001b[0m \u001b[31m13.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hDownloading pandas-2.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.0 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.0/12.0 MB\u001b[0m \u001b[31m14.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hDownloading xxhash-3.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.4/194.4 kB\u001b[0m \u001b[31m19.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading aiohttp-3.12.14-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m34.8 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n", "\u001b[?25hDownloading pytz-2025.2-py2.py3-none-any.whl (509 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m509.2/509.2 kB\u001b[0m \u001b[31m32.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading tzdata-2025.2-py2.py3-none-any.whl (347 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m347.8/347.8 kB\u001b[0m \u001b[31m27.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading aiohappyeyeballs-2.6.1-py3-none-any.whl (15 kB)\n", "Downloading aiosignal-1.4.0-py3-none-any.whl (7.5 kB)\n", "Downloading frozenlist-1.7.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (241 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m241.8/241.8 kB\u001b[0m \u001b[31m20.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading multidict-6.6.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (256 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m256.1/256.1 kB\u001b[0m \u001b[31m22.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading propcache-0.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (224 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m224.4/224.4 kB\u001b[0m \u001b[31m19.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading yarl-1.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (355 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m355.6/355.6 kB\u001b[0m \u001b[31m26.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: pytz, xxhash, tzdata, pyarrow, propcache, multidict, kiwisolver, frozenlist, fonttools, dill, cycler, contourpy, 
aiohappyeyeballs, yarl, pandas, multiprocess, matplotlib, aiosignal, aiohttp, datasets\n", "Successfully installed aiohappyeyeballs-2.6.1 aiohttp-3.12.14 aiosignal-1.4.0 contourpy-1.3.2 cycler-0.12.1 datasets-4.0.0 dill-0.3.8 fonttools-4.59.0 frozenlist-1.7.0 kiwisolver-1.4.8 matplotlib-3.10.3 multidict-6.6.3 multiprocess-0.70.16 pandas-2.3.1 propcache-0.3.2 pyarrow-21.0.0 pytz-2025.2 tzdata-2025.2 xxhash-3.5.0 yarl-1.20.1\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mCollecting wandb\n", " Downloading wandb-0.21.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)\n", "Requirement already satisfied: tensorboard in /usr/local/lib/python3.12/dist-packages (2.19.0)\n", "Collecting click!=8.0.0,>=7.1 (from wandb)\n", " Downloading click-8.2.1-py3-none-any.whl.metadata (2.5 kB)\n", "Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)\n", " Downloading GitPython-3.1.44-py3-none-any.whl.metadata (13 kB)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from wandb) (25.0)\n", "Requirement already satisfied: platformdirs in /usr/local/lib/python3.12/dist-packages (from wandb) (4.3.8)\n", "Requirement already satisfied: protobuf!=4.21.0,!=5.28.0,<7,>=3.19.0 in /usr/local/lib/python3.12/dist-packages (from wandb) (6.31.1)\n", "Collecting pydantic<3 (from wandb)\n", " Downloading pydantic-2.11.7-py3-none-any.whl.metadata (67 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m68.0/68.0 kB\u001b[0m \u001b[31m192.8 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pyyaml in /usr/local/lib/python3.12/dist-packages (from wandb) (6.0.2)\n", "Requirement already satisfied: requests<3,>=2.0.0 in 
/usr/local/lib/python3.12/dist-packages (from wandb) (2.32.4)\n", "Collecting sentry-sdk>=2.0.0 (from wandb)\n", " Downloading sentry_sdk-2.33.0-py2.py3-none-any.whl.metadata (10 kB)\n", "Requirement already satisfied: typing-extensions<5,>=4.8 in /usr/local/lib/python3.12/dist-packages (from wandb) (4.14.0)\n", "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.12/dist-packages (from tensorboard) (2.3.0)\n", "Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.12/dist-packages (from tensorboard) (1.73.1)\n", "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.12/dist-packages (from tensorboard) (3.8.2)\n", "Requirement already satisfied: numpy>=1.12.0 in /usr/local/lib/python3.12/dist-packages (from tensorboard) (2.3.1)\n", "Requirement already satisfied: setuptools>=41.0.0 in /usr/lib/python3/dist-packages (from tensorboard) (68.1.2)\n", "Requirement already satisfied: six>1.9 in /usr/lib/python3/dist-packages (from tensorboard) (1.16.0)\n", "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.12/dist-packages (from tensorboard) (0.7.2)\n", "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from tensorboard) (3.1.3)\n", "Collecting gitdb<5,>=4.0.1 (from gitpython!=3.1.29,>=1.0.0->wandb)\n", " Downloading gitdb-4.0.12-py3-none-any.whl.metadata (1.2 kB)\n", "Collecting annotated-types>=0.6.0 (from pydantic<3->wandb)\n", " Downloading annotated_types-0.7.0-py3-none-any.whl.metadata (15 kB)\n", "Collecting pydantic-core==2.33.2 (from pydantic<3->wandb)\n", " Downloading pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)\n", "Collecting typing-inspection>=0.4.0 (from pydantic<3->wandb)\n", " Downloading typing_inspection-0.4.1-py3-none-any.whl.metadata (2.6 kB)\n", "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from 
requests<3,>=2.0.0->wandb) (3.4.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests<3,>=2.0.0->wandb) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests<3,>=2.0.0->wandb) (2.5.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests<3,>=2.0.0->wandb) (2025.6.15)\n", "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.12/dist-packages (from werkzeug>=1.0.1->tensorboard) (3.0.2)\n", "Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb)\n", " Downloading smmap-5.0.2-py3-none-any.whl.metadata (4.3 kB)\n", "Downloading wandb-0.21.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.2 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m22.2/22.2 MB\u001b[0m \u001b[31m12.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hDownloading click-8.2.1-py3-none-any.whl (102 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m102.2/102.2 kB\u001b[0m \u001b[31m9.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading GitPython-3.1.44-py3-none-any.whl (207 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.6/207.6 kB\u001b[0m \u001b[31m17.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading pydantic-2.11.7-py3-none-any.whl (444 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m444.8/444.8 kB\u001b[0m \u001b[31m25.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m37.3 MB/s\u001b[0m 
eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hDownloading sentry_sdk-2.33.0-py2.py3-none-any.whl (356 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m356.4/356.4 kB\u001b[0m \u001b[31m32.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading annotated_types-0.7.0-py3-none-any.whl (13 kB)\n", "Downloading gitdb-4.0.12-py3-none-any.whl (62 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.8/62.8 kB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading typing_inspection-0.4.1-py3-none-any.whl (14 kB)\n", "Downloading smmap-5.0.2-py3-none-any.whl (24 kB)\n", "Installing collected packages: typing-inspection, smmap, sentry-sdk, pydantic-core, click, annotated-types, pydantic, gitdb, gitpython, wandb\n", "Successfully installed annotated-types-0.7.0 click-8.2.1 gitdb-4.0.12 gitpython-3.1.44 pydantic-2.11.7 pydantic-core-2.33.2 sentry-sdk-2.33.0 smmap-5.0.2 typing-inspection-0.4.1 wandb-0.21.0\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mCollecting opencv-python\n", " Downloading opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (19 kB)\n", "Collecting scikit-image\n", " Downloading scikit_image-0.25.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (14 kB)\n", "Collecting numpy<2.3.0,>=2 (from opencv-python)\n", " Downloading numpy-2.2.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.0/62.0 kB\u001b[0m \u001b[31m252.8 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hCollecting scipy>=1.11.4 (from scikit-image)\n", " Downloading scipy-1.16.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (61 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.9/61.9 kB\u001b[0m \u001b[31m679.3 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hRequirement already satisfied: networkx>=3.0 in /usr/local/lib/python3.12/dist-packages (from scikit-image) (3.3)\n", "Requirement already satisfied: pillow>=10.1 in /usr/local/lib/python3.12/dist-packages (from scikit-image) (11.0.0)\n", "Collecting imageio!=2.35.0,>=2.33 (from scikit-image)\n", " Downloading imageio-2.37.0-py3-none-any.whl.metadata (5.2 kB)\n", "Collecting tifffile>=2022.8.12 (from scikit-image)\n", " Downloading tifffile-2025.6.11-py3-none-any.whl.metadata (32 kB)\n", "Requirement already satisfied: packaging>=21 in /usr/local/lib/python3.12/dist-packages (from scikit-image) (25.0)\n", "Collecting lazy-loader>=0.4 (from scikit-image)\n", " Downloading lazy_loader-0.4-py3-none-any.whl.metadata (7.6 kB)\n", "Downloading opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (67.0 MB)\n", "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.0/67.0 MB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hDownloading scikit_image-0.25.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.0 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m15.0/15.0 MB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hDownloading imageio-2.37.0-py3-none-any.whl (315 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m315.8/315.8 kB\u001b[0m \u001b[31m20.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading lazy_loader-0.4-py3-none-any.whl (12 kB)\n", "Downloading numpy-2.2.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.5 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m16.5/16.5 MB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hDownloading scipy-1.16.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (35.1 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.1/35.1 MB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n", "\u001b[?25hDownloading tifffile-2025.6.11-py3-none-any.whl (230 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m230.8/230.8 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: numpy, lazy-loader, tifffile, scipy, opencv-python, imageio, scikit-image\n", " Attempting uninstall: numpy\n", " Found existing installation: numpy 2.3.1\n", " Uninstalling numpy-2.3.1:\n", " Successfully uninstalled numpy-2.3.1\n", "Successfully installed imageio-2.37.0 lazy-loader-0.4 numpy-2.2.6 
opencv-python-4.12.0.88 scikit-image-0.25.2 scipy-1.16.0 tifffile-2025.6.11\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mCollecting einops\n", " Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)\n", "Downloading einops-0.8.1-py3-none-any.whl (64 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.4/64.4 kB\u001b[0m \u001b[31m192.7 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hInstalling collected packages: einops\n", "Successfully installed einops-0.8.1\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "# Install required packages\n", "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121\n", "!pip install diffusers transformers accelerate\n", "!pip install datasets pillow matplotlib tqdm\n", "!pip install wandb tensorboard\n", "!pip install opencv-python scikit-image\n", "!pip install einops" ] }, { "cell_type": "code", "execution_count": 2, "id": "a66917c9-173b-4547-baa6-56f303707025", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Filesystem Size Used Avail Use% Mounted on\n", "overlay 131G 8.9G 123G 7% /\n", "tmpfs 64M 0 64M 0% /dev\n", "shm 15G 0 15G 0% /dev/shm\n", "/dev/loop6 1.7T 386G 1.3T 24% /etc/hosts\n", "/dev/nvme0n1p2 1.9T 562G 1.2T 32% /usr/bin/nvidia-smi\n", "tmpfs 13G 2.9M 13G 1% /run/nvidia-persistenced/socket\n", "tmpfs 63G 0 63G 0% /sys/fs/cgroup\n", "tmpfs 63G 0 63G 0% /proc/asound\n", "tmpfs 63G 0 63G 0% /proc/acpi\n", 
"tmpfs 63G 0 63G 0% /proc/scsi\n", "tmpfs 63G 0 63G 0% /sys/firmware\n", "tmpfs 63G 0 63G 0% /sys/devices/virtual/powercap\n" ] } ], "source": [ "!df -h" ] }, { "cell_type": "code", "execution_count": 1, "id": "88aec5bc-9914-44ee-8ab9-4b46378ee61a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n", "GPU: NVIDIA GeForce RTX 3060\n", "CUDA Version: 12.1\n", "Available VRAM: 11.66 GB\n", "Current VRAM usage: 0.00 GB\n" ] } ], "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision import datasets\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "import math\n",
    "from PIL import Image\n",
    "import random\n",
    "\n",
    "# One seed shared by every RNG below, so it only needs changing in one place\n",
    "SEED = 42\n",
    "\n",
    "# Check GPU setup\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
    "    print(f\"CUDA Version: {torch.version.cuda}\")\n",
    "    # NOTE: total_memory is the card's full capacity, not the currently free amount\n",
    "    print(f\"Available VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB\")\n",
    "    print(f\"Current VRAM usage: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB\")\n",
    "\n",
    "# Set random seeds for reproducibility.\n",
    "# torch.manual_seed seeds the CPU and every CUDA device, so no separate\n",
    "# torch.cuda.manual_seed call is needed; numpy and `random` keep their own\n",
    "# global generators and are seeded explicitly.\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)"
] }, { "cell_type": "code", "execution_count": 2, "id": "bea9f729-e07e-4e17-996b-608fb7aff7af", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Package Version\n", "------------------------- --------------\n", "absl-py 2.3.0\n", "accelerate 1.9.0\n", "aiohappyeyeballs 2.6.1\n", "aiohttp 3.12.14\n", "aiosignal 1.4.0\n", "annotated-types 0.7.0\n", "anyio 4.9.0\n", "argon2-cffi 25.1.0\n",
"argon2-cffi-bindings 21.2.0\n", "arrow 1.3.0\n", "asttokens 3.0.0\n", "async-lru 2.0.5\n", "attrs 25.3.0\n", "babel 2.17.0\n", "bash_kernel 0.10.0\n", "beautifulsoup4 4.13.4\n", "bleach 6.2.0\n", "blinker 1.7.0\n", "certifi 2025.6.15\n", "cffi 1.17.1\n", "charset-normalizer 3.4.2\n", "click 8.2.1\n", "comm 0.2.2\n", "conda-pack 0.8.1\n", "contourpy 1.3.2\n", "cryptography 41.0.7\n", "cycler 0.12.1\n", "datasets 4.0.0\n", "dbus-python 1.3.2\n", "debugpy 1.8.14\n", "decorator 5.2.1\n", "defusedxml 0.7.1\n", "diffusers 0.34.0\n", "dill 0.3.8\n", "distro 1.9.0\n", "einops 0.8.1\n", "executing 2.2.0\n", "fastjsonschema 2.21.1\n", "filelock 3.13.1\n", "filetype 1.2.0\n", "fonttools 4.59.0\n", "fqdn 1.5.1\n", "frozenlist 1.7.0\n", "fsspec 2024.6.1\n", "gitdb 4.0.12\n", "GitPython 3.1.44\n", "grpcio 1.73.1\n", "h11 0.16.0\n", "hf-xet 1.1.5\n", "httpcore 1.0.9\n", "httplib2 0.20.4\n", "httpx 0.28.1\n", "huggingface-hub 0.33.4\n", "idna 3.10\n", "imageio 2.37.0\n", "importlib_metadata 8.7.0\n", "iniconfig 2.1.0\n", "iotop 0.6\n", "ipykernel 6.29.5\n", "ipython 9.3.0\n", "ipython_pygments_lexers 1.1.1\n", "ipywidgets 8.1.7\n", "isoduration 20.11.0\n", "jedi 0.19.2\n", "Jinja2 3.1.6\n", "json5 0.12.0\n", "jsonpointer 3.0.0\n", "jsonschema 4.24.0\n", "jsonschema-specifications 2025.4.1\n", "jupyter 1.1.1\n", "jupyter-archive 3.4.0\n", "jupyter_client 8.6.3\n", "jupyter-console 6.6.3\n", "jupyter_core 5.8.1\n", "jupyter-events 0.12.0\n", "jupyter-http-over-ws 0.0.8\n", "jupyter-lsp 2.2.5\n", "jupyter_server 2.16.0\n", "jupyter_server_terminals 0.5.3\n", "jupyterlab 4.4.4\n", "jupyterlab_pygments 0.3.0\n", "jupyterlab_server 2.27.3\n", "jupyterlab_widgets 3.0.15\n", "kiwisolver 1.4.8\n", "launchpadlib 1.11.0\n", "lazr.restfulclient 0.14.6\n", "lazr.uri 1.0.6\n", "lazy_loader 0.4\n", "Markdown 3.8.2\n", "MarkupSafe 3.0.2\n", "matplotlib 3.10.3\n", "matplotlib-inline 0.1.7\n", "mistune 3.1.3\n", "mpmath 1.3.0\n", "multidict 6.6.3\n", "multiprocess 0.70.16\n", "nbclient 0.10.2\n", 
"nbconvert 7.16.6\n", "nbformat 5.10.4\n", "nbzip 0.1.0\n", "nest-asyncio 1.6.0\n", "networkx 3.3\n", "notebook 7.4.3\n", "notebook_shim 0.2.4\n", "numpy 2.2.6\n", "nvidia-cublas-cu12 12.1.3.1\n", "nvidia-cuda-cupti-cu12 12.1.105\n", "nvidia-cuda-nvrtc-cu12 12.1.105\n", "nvidia-cuda-runtime-cu12 12.1.105\n", "nvidia-cudnn-cu12 9.1.0.70\n", "nvidia-cufft-cu12 11.0.2.54\n", "nvidia-curand-cu12 10.3.2.106\n", "nvidia-cusolver-cu12 11.4.5.107\n", "nvidia-cusparse-cu12 12.1.0.106\n", "nvidia-nccl-cu12 2.21.5\n", "nvidia-nvjitlink-cu12 12.1.105\n", "nvidia-nvtx-cu12 12.1.105\n", "oauthlib 3.2.2\n", "opencv-python 4.12.0.88\n", "overrides 7.7.0\n", "packaging 25.0\n", "pandas 2.3.1\n", "pandocfilters 1.5.1\n", "parso 0.8.4\n", "pexpect 4.9.0\n", "pillow 11.0.0\n", "pip 24.0\n", "platformdirs 4.3.8\n", "pluggy 1.6.0\n", "prometheus_client 0.22.1\n", "prompt_toolkit 3.0.51\n", "propcache 0.3.2\n", "protobuf 6.31.1\n", "psutil 7.0.0\n", "ptyprocess 0.7.0\n", "pure_eval 0.2.3\n", "pyarrow 21.0.0\n", "pycparser 2.22\n", "pydantic 2.11.7\n", "pydantic_core 2.33.2\n", "Pygments 2.19.2\n", "PyGObject 3.48.2\n", "PyJWT 2.7.0\n", "pyparsing 3.1.1\n", "pytest 8.4.1\n", "python-apt 2.7.7+ubuntu4\n", "python-dateutil 2.9.0.post0\n", "python-json-logger 3.3.0\n", "pytz 2025.2\n", "PyYAML 6.0.2\n", "pyzmq 27.0.0\n", "referencing 0.36.2\n", "regex 2024.11.6\n", "requests 2.32.4\n", "rfc3339-validator 0.1.4\n", "rfc3986-validator 0.1.1\n", "rpds-py 0.25.1\n", "safetensors 0.5.3\n", "scikit-image 0.25.2\n", "scipy 1.16.0\n", "Send2Trash 1.8.3\n", "sentry-sdk 2.33.0\n", "setuptools 68.1.2\n", "six 1.16.0\n", "smmap 5.0.2\n", "sniffio 1.3.1\n", "soupsieve 2.7\n", "stack-data 0.6.3\n", "supervisor 4.2.5\n", "sympy 1.13.1\n", "tensorboard 2.19.0\n", "tensorboard-data-server 0.7.2\n", "terminado 0.18.1\n", "tifffile 2025.6.11\n", "tinycss2 1.4.0\n", "tokenizers 0.21.2\n", "torch 2.5.1+cu121\n", "torchaudio 2.5.1+cu121\n", "torchvision 0.20.1+cu121\n", "tornado 6.5.1\n", "tqdm 4.67.1\n", 
"traitlets 5.14.3\n", "transformers 4.53.2\n", "triton 3.1.0\n", "types-python-dateutil 2.9.0.20250516\n", "typing_extensions 4.14.0\n", "typing-inspection 0.4.1\n", "tzdata 2025.2\n", "uri-template 1.3.0\n", "urllib3 2.5.0\n", "uv 0.7.16\n", "wadllib 1.3.6\n", "wandb 0.21.0\n", "wcwidth 0.2.13\n", "webcolors 24.11.1\n", "webencodings 0.5.1\n", "websocket-client 1.8.0\n", "Werkzeug 3.1.3\n", "wheel 0.42.0\n", "widgetsnbextension 4.0.14\n", "xxhash 3.5.0\n", "yarl 1.20.1\n", "zipp 3.23.0\n" ] } ], "source": [ "!pip list" ] }, { "cell_type": "code", "execution_count": 3, "id": "aeb74a18-747e-40ef-8229-0de619fa3859", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading CIFAR-10 dataset...\n", "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 170M/170M [00:14<00:00, 11.9MB/s] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting ./data/cifar-10-python.tar.gz to ./data\n", "Dataset downloaded successfully!\n", "Number of training images: 50000\n", "Image shape: torch.Size([3, 32, 32])\n", "Classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n", "Batch size: 128\n", "Number of batches: 391\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
    "# Step 3: Download and prepare CIFAR-10 dataset\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "# Define data transforms. CIFAR-10 images are already 32x32, so no Resize is needed.\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),                                   # PIL image -> float tensor in [0, 1]\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]\n",
    "])\n",
    "\n",
    "# Download CIFAR-10 dataset\n",
    "print(\"Downloading CIFAR-10 dataset...\")\n",
    "train_dataset = datasets.CIFAR10(\n",
    "    root='./data',\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transform,\n",
    ")\n",
    "\n",
    "print(\"Dataset downloaded successfully!\")\n",
    "print(f\"Number of training images: {len(train_dataset)}\")\n",
    "print(f\"Image shape: {train_dataset[0][0].shape}\")\n",
    "print(f\"Classes: {train_dataset.classes}\")\n",
    "\n",
    "# Create dataloader\n",
    "batch_size = 128  # Good for your 12GB VRAM\n",
    "num_workers = 4\n",
    "train_loader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers=num_workers,\n",
    "    # Pinned host memory only helps when batches are copied to a GPU\n",
    "    pin_memory=torch.cuda.is_available(),\n",
    "    # Keep worker processes alive across epochs instead of re-spawning them each epoch\n",
    "    persistent_workers=num_workers > 0,\n",
    ")\n",
    "\n",
    "print(f\"Batch size: {batch_size}\")\n",
    "print(f\"Number of batches: {len(train_loader)}\")\n",
    "\n",
    "\n",
    "def show_samples(dataset, num_samples=8, ncols=4):\n",
    "    \"\"\"Plot the first `num_samples` images of `dataset` in a grid titled by class name.\n",
    "\n",
    "    Images are expected in [-1, 1] (as produced by `transform` above) and are\n",
    "    mapped back to [0, 1] for display. Unused grid cells are left blank.\n",
    "    \"\"\"\n",
    "    num_samples = min(num_samples, len(dataset))\n",
    "    nrows = max(1, math.ceil(num_samples / ncols))\n",
    "    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)\n",
    "    for i, ax in enumerate(axes.flat):\n",
    "        ax.axis('off')\n",
    "        if i >= num_samples:\n",
    "            continue\n",
    "        img, label = dataset[i]\n",
    "        # (C, H, W) in [-1, 1] -> (H, W, C) in [0, 1] for imshow\n",
    "        img = ((img + 1) / 2).clamp(0, 1).permute(1, 2, 0)\n",
    "        ax.imshow(img)\n",
    "        ax.set_title(f'Class: {dataset.classes[label]}')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show_samples(train_dataset)"
] }, { "cell_type": "code", "execution_count": 4, "id": "2a4d4d2c-8119-41a3-be04-22d98bbff38a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Diffusion model components defined successfully!\n", "Next: U-Net architecture\n" ] } ], "source": [
    "# Step 4: Define the U-Net building blocks for the diffusion model\n",
    "\n",
    "class TimeEmbedding(nn.Module):\n",
    "    \"\"\"Sinusoidal embedding of diffusion timesteps.\n",
    "\n",
    "    Maps a 1-D tensor of B timesteps to a (B, dim) float tensor: sines in the\n",
    "    first half of the features, cosines in the second, at geometrically spaced\n",
    "    frequencies. An odd `dim` gets one zero-padded column so the width is exact.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "\n",
    "    def forward(self, time):\n",
    "        half_dim = self.dim // 2\n",
    "        # max(..., 1) avoids a ZeroDivisionError when dim is 2 or 3\n",
    "        scale = math.log(10000) / max(half_dim - 1, 1)\n",
    "        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -scale)\n",
    "        args = time.float()[:, None] * freqs[None, :]\n",
    "        emb = torch.cat((args.sin(), args.cos()), dim=-1)\n",
    "        if self.dim % 2:\n",
    "            # Keep the output width equal to `dim`, which the following Linear layer expects\n",
    "            emb = F.pad(emb, (0, 1))\n",
    "        return emb\n",
    "\n",
    "\n",
    "class ResidualBlock(nn.Module):\n",
    "    \"\"\"(GroupNorm -> SiLU -> Conv) x 2 residual block conditioned on a time embedding.\n",
    "\n",
    "    The time embedding is projected to `out_channels` and added as a per-channel\n",
    "    bias between the two convolutions. A 1x1 conv matches channel counts on the\n",
    "    skip path when `in_channels != out_channels`. Both channel counts must be\n",
    "    divisible by 8 (the GroupNorm group count).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.time_mlp = nn.Linear(time_emb_dim, out_channels)\n",
    "\n",
    "        self.block1 = nn.Sequential(\n",
    "            nn.GroupNorm(8, in_channels),\n",
    "            nn.SiLU(),\n",
    "            nn.Conv2d(in_channels, out_channels, 3, padding=1),\n",
    "        )\n",
    "\n",
    "        self.block2 = nn.Sequential(\n",
    "            nn.GroupNorm(8, out_channels),\n",
    "            nn.SiLU(),\n",
    "            nn.Dropout(dropout),\n",
    "            nn.Conv2d(out_channels, out_channels, 3, padding=1),\n",
    "        )\n",
    "\n",
    "        self.shortcut = (\n",
    "            nn.Conv2d(in_channels, out_channels, 1)\n",
    "            if in_channels != out_channels\n",
    "            else nn.Identity()\n",
    "        )\n",
    "\n",
    "    def forward(self, x, time_emb):\n",
    "        h = self.block1(x)\n",
    "        # (B, out_channels) -> (B, out_channels, 1, 1) so it broadcasts over H and W\n",
    "        h = h + self.time_mlp(time_emb)[:, :, None, None]\n",
    "        h = self.block2(h)\n",
    "        return h + self.shortcut(x)\n",
    "\n",
    "\n",
    "class AttentionBlock(nn.Module):\n",
    "    \"\"\"Single-head self-attention across all spatial positions, with a residual connection.\"\"\"\n",
    "\n",
    "    def __init__(self, channels):\n",
    "        super().__init__()\n",
    "        self.channels = channels\n",
    "        self.group_norm = nn.GroupNorm(8, channels)\n",
    "        self.q = nn.Conv2d(channels, channels, 1)\n",
    "        self.k = nn.Conv2d(channels, channels, 1)\n",
    "        self.v = nn.Conv2d(channels, channels, 1)\n",
    "        self.proj_out = nn.Conv2d(channels, channels, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        B, C, H, W = x.shape\n",
    "        h = self.group_norm(x)\n",
    "\n",
    "        # Flatten spatial dims: q and v -> (B, HW, C), k -> (B, C, HW)\n",
    "        q = self.q(h).reshape(B, C, H * W).permute(0, 2, 1)\n",
    "        k = self.k(h).reshape(B, C, H * W)\n",
    "        v = self.v(h).reshape(B, C, H * W).permute(0, 2, 1)\n",
    "\n",
    "        # Scaled dot-product weights (B, HW, HW); each row sums to 1\n",
    "        attn = F.softmax(torch.bmm(q, k) * (C ** -0.5), dim=2)\n",
    "\n",
    "        h = torch.bmm(attn, v)                      # (B, HW, C)\n",
    "        h = h.permute(0, 2, 1).reshape(B, C, H, W)  # back to (B, C, H, W)\n",
    "        return x + self.proj_out(h)\n",
    "\n",
    "\n",
    "print(\"Diffusion model components defined successfully!\")\n",
    "print(\"Next: U-Net architecture\")"
] }, { "cell_type": "code", "execution_count": 7, "id": "1ea769ea-8e47-47fb-add3-f4139afcc5b9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model created with 16,808,835 parameters\n", "Test input shape: torch.Size([4, 3, 32, 32])\n", "Test output shape: torch.Size([4, 3, 32, 32])\n", "Model forward pass successful!\n", "VRAM usage after model creation: 0.22 GB\n" ] } ], "source": [
    "# Step 5: Simplified Working U-Net Model\n",
    "\n",
    "class SimpleUNet(nn.Module):\n",
    "    \"\"\"Small 3-level U-Net that predicts the noise in x_t given the timestep t.\n",
    "\n",
    "    Spatial path for 32x32 inputs: 32 -> 16 -> 8 -> 4 (middle, with self-attention)\n",
    "    -> 8 -> 16 -> 32, concatenating the matching encoder skip at each decoder level.\n",
    "    Input height/width should be divisible by 8 so the skips line up.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=128):\n",
    "        super().__init__()\n",
    "        emb_dim = time_emb_dim * 4  # width of the conditioning vector fed to every ResidualBlock\n",
    "\n",
    "        # Time embedding: sinusoidal features followed by a small MLP\n",
    "        self.time_embedding = TimeEmbedding(time_emb_dim)\n",
    "        self.time_mlp = nn.Sequential(\n",
    "            nn.Linear(time_emb_dim, emb_dim),\n",
    "            nn.SiLU(),\n",
    "            nn.Linear(emb_dim, emb_dim),\n",
    "        )\n",
    "\n",
    "        # Encoder\n",
    "        self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)\n",
    "        self.res1 = ResidualBlock(64, 64, emb_dim)\n",
    "        self.down1 = nn.Conv2d(64, 64, 3, stride=2, padding=1)  # 32->16\n",
    "\n",
    "        self.res2 = ResidualBlock(64, 128, emb_dim)\n",
    "        self.down2 = nn.Conv2d(128, 128, 3, stride=2, padding=1)  # 16->8\n",
    "\n",
    "        self.res3 = ResidualBlock(128, 256, emb_dim)\n",
    "        self.down3 = nn.Conv2d(256, 256, 3, stride=2, padding=1)  # 8->4\n",
    "\n",
    "        # Middle\n",
    "        self.mid1 = ResidualBlock(256, 512, emb_dim)\n",
    "        self.mid_attn = AttentionBlock(512)\n",
    "        self.mid2 = ResidualBlock(512, 512, emb_dim)\n",
    "\n",
    "        # Decoder: each ResidualBlock sees upsampled features concatenated with the skip\n",
    "        self.up3 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1)  # 4->8\n",
    "        self.res_up3 = ResidualBlock(256 + 256, 256, emb_dim)\n",
    "\n",
    "        self.up2 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1)  # 8->16\n",
    "        self.res_up2 = ResidualBlock(128 + 128, 128, emb_dim)\n",
    "\n",
    "        self.up1 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)  # 16->32\n",
    "        self.res_up1 = ResidualBlock(64 + 64, 64, emb_dim)\n",
    "\n",
    "        # Output\n",
    "        self.output = nn.Sequential(\n",
    "            nn.GroupNorm(8, 64),\n",
    "            nn.SiLU(),\n",
    "            nn.Conv2d(64, out_channels, 3, padding=1),\n",
    "        )\n",
    "\n",
    "    def forward(self, x, time):\n",
    "        \"\"\"Predict the noise in `x` (B, C, H, W) at timesteps `time` (B,).\"\"\"\n",
    "        time_emb = self.time_mlp(self.time_embedding(time))\n",
    "\n",
    "        # Encoder (x1, x2, x3 are kept as skip connections)\n",
    "        x1 = self.res1(self.conv1(x), time_emb)\n",
    "        x2 = self.res2(self.down1(x1), time_emb)\n",
    "        x3 = self.res3(self.down2(x2), time_emb)\n",
    "\n",
    "        # Middle\n",
    "        h = self.mid1(self.down3(x3), time_emb)\n",
    "        h = self.mid_attn(h)\n",
    "        h = self.mid2(h, time_emb)\n",
    "\n",
    "        # Decoder\n",
    "        h = self.res_up3(torch.cat([self.up3(h), x3], dim=1), time_emb)\n",
    "        h = self.res_up2(torch.cat([self.up2(h), x2], dim=1), time_emb)\n",
    "        h = self.res_up1(torch.cat([self.up1(h), x1], dim=1), time_emb)\n",
    "\n",
    "        return self.output(h)\n",
    "\n",
    "\n",
    "# Initialize model\n",
    "model = SimpleUNet().to(device)\n",
    "print(f\"Model created with {sum(p.numel() for p in model.parameters()):,} parameters\")\n",
    "\n",
    "# Test model forward pass\n",
    "with torch.no_grad():\n",
    "    test_x = torch.randn(4, 3, 32, 32).to(device)\n",
    "    test_t = torch.randint(0, 1000, (4,)).to(device)\n",
    "    test_output = model(test_x, test_t)\n",
    "    print(f\"Test input shape: {test_x.shape}\")\n",
    "    print(f\"Test output shape: {test_output.shape}\")\n",
    "    # The network predicts noise, so its output must have the input's shape\n",
    "    assert test_output.shape == test_x.shape\n",
    "    print(\"Model forward pass successful!\")\n",
    "\n",
    "# Check memory usage (only meaningful with a CUDA device)\n",
    "if torch.cuda.is_available():\n",
    "    print(f\"VRAM usage after model creation: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB\")"
] }, { "cell_type": "code", "execution_count": 9, "id": "52a6b46c-f17e-426c-80f5-1042f5529e44", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Diffusion scheduler created successfully!\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Forward diffusion process test completed!\n" ] } ], "source": [
    "# Step 6: Define the Diffusion Process (Fixed)\n",
    "\n",
    "class DDPMScheduler:\n",
    "    \"\"\"Linear-beta DDPM noise schedule with forward (noising) and reverse (denoising) steps.\n",
    "\n",
    "    Every schedule tensor below has shape (num_timesteps,) and lives on `device`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02, device='cuda'):\n",
    "        self.num_timesteps = num_timesteps\n",
    "        self.device = device\n",
    "\n",
    "        # Linear beta schedule\n",
    "        self.betas = torch.linspace(beta_start, beta_end, num_timesteps).to(device)\n",
    "        self.alphas = 1.0 - self.betas\n",
    "        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)\n",
    "        # alpha_cumprod shifted right by one step, with 1.0 at t = 0\n",
    "        self.alpha_cumprod_prev = torch.cat([torch.tensor([1.0]).to(device), self.alpha_cumprod[:-1]])\n",
    "\n",
    "        # Calculations for the forward process q(x_t | x_0)\n",
    "        self.sqrt_alpha_cumprod = torch.sqrt(self.alpha_cumprod)\n",
    "        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alpha_cumprod)\n",
    "\n",
    "        # Calculations for posterior q(x_{t-1} | x_t, x_0)\n",
    "        self.posterior_variance = self.betas * (1.0 - self.alpha_cumprod_prev) / (1.0 - self.alpha_cumprod)\n",
    "        # posterior_variance[0] is 0 (log would be -inf), so index 0 reuses the t = 1 value\n",
    "        self.posterior_log_variance_clipped = torch.log(\n",
    "            torch.cat([self.posterior_variance[1:2], self.posterior_variance[1:]])\n",
    "        )\n",
    "        self.posterior_mean_coef1 = self.betas * torch.sqrt(self.alpha_cumprod_prev) / (1.0 - self.alpha_cumprod)\n",
    "        self.posterior_mean_coef2 = (1.0 - self.alpha_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alpha_cumprod)\n",
    "\n",
    "    def _extract(self, values, timesteps):\n",
    "        \"\"\"Index a schedule tensor at `timesteps`, reshaped to (B, 1, 1, 1) to broadcast over images.\n",
    "\n",
    "        `timesteps` may be a Python int, a 0-d / 1-element tensor, or a (B,) tensor on any device.\n",
    "        \"\"\"\n",
    "        t = torch.as_tensor(timesteps, device=self.device).long().reshape(-1)\n",
    "        return values[t].reshape(-1, 1, 1, 1)\n",
    "\n",
    "    def add_noise(self, x_start, timesteps, noise=None):\n",
    "        \"\"\"Forward diffusion: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise.\n",
    "\n",
    "        Args:\n",
    "            x_start: clean images x_0, shape (B, C, H, W).\n",
    "            timesteps: integer timesteps in [0, num_timesteps), shape (B,).\n",
    "            noise: standard Gaussian noise shaped like `x_start`; sampled if None.\n",
    "        \"\"\"\n",
    "        if noise is None:\n",
    "            noise = torch.randn_like(x_start)\n",
    "\n",
    "        sqrt_alpha_cumprod_t = self._extract(self.sqrt_alpha_cumprod, timesteps)\n",
    "        sqrt_one_minus_alpha_cumprod_t = self._extract(self.sqrt_one_minus_alpha_cumprod, timesteps)\n",
    "\n",
    "        return sqrt_alpha_cumprod_t * x_start + sqrt_one_minus_alpha_cumprod_t * noise\n",
    "\n",
    "    def sample_prev_timestep(self, model_output, timestep, sample, noise=None, clip_sample=True):\n",
    "        \"\"\"Reverse diffusion: draw x_{t-1} ~ p(x_{t-1} | x_t) from the model's noise prediction.\n",
    "\n",
    "        Args:\n",
    "            model_output: predicted noise for `sample`, same shape as `sample`.\n",
    "            timestep: current timestep t (int, 0-d / 1-element tensor, or (B,) tensor).\n",
    "            sample: current noisy images x_t, shape (B, C, H, W).\n",
    "            noise: Gaussian noise shaped like `sample` for the stochastic term; sampled if None.\n",
    "            clip_sample: clamp the predicted x_0 to the data range [-1, 1] before\n",
    "                forming the posterior mean.\n",
    "\n",
    "        Returns:\n",
    "            x_{t-1}: the posterior mean plus sigma_t * noise for t > 0; the mean alone at t = 0.\n",
    "        \"\"\"\n",
    "        t = torch.as_tensor(timestep, device=self.device).long().reshape(-1)\n",
    "        alpha_prod_t = self._extract(self.alpha_cumprod, t)\n",
    "        beta_prod_t = 1.0 - alpha_prod_t\n",
    "\n",
    "        # Predicted original sample x_0 from the predicted noise\n",
    "        pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()\n",
    "        if clip_sample:\n",
    "            pred_original_sample = pred_original_sample.clamp(-1.0, 1.0)\n",
    "\n",
    "        # Posterior mean of q(x_{t-1} | x_t, x_0) from the precomputed coefficients\n",
    "        pred_prev_sample = (\n",
    "            self._extract(self.posterior_mean_coef1, t) * pred_original_sample\n",
    "            + self._extract(self.posterior_mean_coef2, t) * sample\n",
    "        )\n",
    "\n",
    "        # DDPM sampling is stochastic: add sigma_t * z at every step except the final one (t == 0).\n",
    "        if noise is None:\n",
    "            noise = torch.randn_like(sample)\n",
    "        nonzero_mask = (t > 0).to(sample.dtype).reshape(-1, 1, 1, 1)\n",
    "        sigma_t = torch.exp(0.5 * self._extract(self.posterior_log_variance_clipped, t))\n",
    "        return pred_prev_sample + nonzero_mask * sigma_t * noise\n",
    "\n",
    "\n",
    "# Initialize scheduler\n",
    "scheduler = DDPMScheduler(num_timesteps=1000, device=device)\n",
    "print(\"Diffusion scheduler created successfully!\")\n",
    "\n",
    "\n",
    "def test_diffusion_process(num_images=4):\n",
    "    \"\"\"Noise a few training images at random timesteps and plot originals above noisy versions.\n",
    "\n",
    "    Returns (images, noisy_images, timesteps, noise) for further inspection.\n",
    "    \"\"\"\n",
    "    # Take a batch from our dataset\n",
    "    images, _ = next(iter(train_loader))\n",
    "    images = images[:num_images].to(device)\n",
    "    num_images = images.shape[0]  # the batch may hold fewer images than requested\n",
    "\n",
    "    # Test forward process (adding noise)\n",
    "    timesteps = torch.randint(0, scheduler.num_timesteps, (num_images,))\n",
    "    noise = torch.randn_like(images)\n",
    "    noisy_images = scheduler.add_noise(images, timesteps, noise)\n",
    "\n",
    "    # Visualize original vs noisy images\n",
    "    fig, axes = plt.subplots(2, num_images, figsize=(3 * num_images, 6), squeeze=False)\n",
    "\n",
    "    for i in range(num_images):\n",
    "        # (C, H, W) in [-1, 1] -> (H, W, C) in [0, 1] for imshow\n",
    "        orig_img = ((images[i].cpu() + 1) / 2).clamp(0, 1).permute(1, 2, 0)\n",
    "        axes[0, i].imshow(orig_img)\n",
    "        axes[0, i].set_title('Original')\n",
    "        axes[0, i].axis('off')\n",
    "\n",
    "        noisy_img = ((noisy_images[i].cpu() + 1) / 2).clamp(0, 1).permute(1, 2, 0)\n",
    "        axes[1, i].imshow(noisy_img)\n",
    "        # .item() so the title reads 't=523' rather than 't=tensor(523)'\n",
    "        axes[1, i].set_title(f'Noisy (t={timesteps[i].item()})')\n",
    "        axes[1, i].axis('off')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    print(\"Forward diffusion process test completed!\")\n",
    "    return images, noisy_images, timesteps, noise\n",
    "\n",
    "\n",
    "# Run the test\n",
    "test_images, test_noisy, test_timesteps, test_noise = test_diffusion_process()"
] }, { "cell_type": "code", "execution_count": 10, "id": "54bba55a-5419-44b2-92d8-2a156e0a5eaa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training setup complete!\n", "Learning rate: 0.0001\n", "Number of epochs: 20\n", "Optimizer: AdamW\n", "Batch size: 128\n", "Number of batches per epoch: 391\n", "\n", "Testing loss computation...\n", "Test loss: 1.1121\n", "\n", "Memory usage:\n", "Current VRAM usage: 0.22 GB\n", "Max VRAM usage: 0.50 GB\n", "\n", "Ready to start training!\n", "Run the training in the next step...\n" ] } ], "source": [
    "# Step 7: Define Loss Function and Training Setup\n",
    "\n",
    "def compute_loss(model, batch, scheduler, device):\n",
    "    \"\"\"Simplified DDPM objective: MSE between the sampled noise and the model's prediction of it.\n",
    "\n",
    "    Args:\n",
    "        model: noise-prediction network, called as model(x_t, t).\n",
    "        batch: (images, labels) pair from the DataLoader; labels are ignored.\n",
    "        scheduler: DDPMScheduler supplying `num_timesteps` and `add_noise`.\n",
    "        device: device to move the images to.\n",
    "    \"\"\"\n",
    "    images, _ = batch\n",
    "    images = images.to(device)\n",
    "    batch_size = images.shape[0]\n",
    "\n",
    "    # One uniformly random timestep per image\n",
    "    timesteps = torch.randint(0, scheduler.num_timesteps, (batch_size,), device=device)\n",
    "\n",
    "    # Noise the images (forward diffusion), then ask the model to recover that noise\n",
    "    noise = torch.randn_like(images)\n",
    "    noisy_images = scheduler.add_noise(images, timesteps, noise)\n",
    "    predicted_noise = model(noisy_images, timesteps)\n",
    "\n",
    "    return F.mse_loss(predicted_noise, noise)\n",
    "\n",
    "\n",
    "# Training parameters\n",
    "learning_rate = 1e-4\n",
    "num_epochs = 20  # Start with fewer epochs for testing\n",
    "save_every = 5  # Save model every 5 epochs\n",
    "\n",
    "# Initialize optimizer\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-6)\n",
    "\n",
    "# Learning rate scheduler (T_max is in epochs, so step it once per epoch)\n",
    "lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)\n",
    "\n",
    "print(\"Training setup complete!\")\n",
    "print(f\"Learning rate: {learning_rate}\")\n",
    "print(f\"Number of epochs: {num_epochs}\")\n",
    "print(\"Optimizer: AdamW\")\n",
    "print(f\"Batch size: {batch_size}\")\n",
    "print(f\"Number of batches per epoch: {len(train_loader)}\")\n",
    "\n",
    "# Sanity-check the loss on one batch; eval mode so dropout does not perturb the value\n",
    "print(\"\\nTesting loss computation...\")\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    test_batch = next(iter(train_loader))\n",
    "    test_loss = compute_loss(model, test_batch, scheduler, device)\n",
    "    print(f\"Test loss: {test_loss.item():.4f}\")\n",
    "model.train()\n",
    "\n",
    "# VRAM statistics are only meaningful with a CUDA device\n",
    "if torch.cuda.is_available():\n",
    "    print(\"\\nMemory usage:\")\n",
    "    print(f\"Current VRAM usage: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB\")\n",
    "    print(f\"Max VRAM usage: {torch.cuda.max_memory_allocated(0) / 1024**3:.2f} GB\")\n",
    "\n",
    "\n",
    "# Training function\n",
    "def train_epoch(model, train_loader, optimizer, scheduler, device, epoch, total_epochs=None):\n",
    "    \"\"\"Train `model` for one pass over `train_loader` and return the mean per-batch loss.\n",
    "\n",
    "    Args:\n",
    "        scheduler: the DDPMScheduler passed to compute_loss (not the LR scheduler).\n",
    "        epoch: zero-based epoch index, used only in the progress-bar label.\n",
    "        total_epochs: epoch count shown in the label; defaults to the global `num_epochs`.\n",
    "    \"\"\"\n",
    "    if total_epochs is None:\n",
    "        total_epochs = num_epochs\n",
    "    model.train()\n",
    "    total_loss = 0.0\n",
    "    num_batches = len(train_loader)\n",
    "\n",
    "    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{total_epochs}')\n",
    "\n",
    "    for batch_idx, batch in enumerate(progress_bar):\n",
    "        optimizer.zero_grad(set_to_none=True)\n",
    "\n",
    "        loss = compute_loss(model, batch, scheduler, device)\n",
    "        loss.backward()\n",
    "\n",
    "        # Gradient clipping to prevent exploding gradients\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "        optimizer.step()\n",
    "\n",
    "        # .item() forces a device sync, so read the value once and reuse it\n",
    "        batch_loss = loss.item()\n",
    "        total_loss += batch_loss\n",
    "\n",
    "        progress_bar.set_postfix({\n",
    "            'loss': f'{batch_loss:.4f}',\n",
    "            'avg_loss': f'{total_loss / (batch_idx + 1):.4f}',\n",
    "            'lr': f'{optimizer.param_groups[0][\"lr\"]:.6f}',\n",
    "        })\n",
    "        # No periodic torch.cuda.empty_cache(): it only releases cached blocks that are not\n",
    "        # in use, which the next step would immediately have to re-allocate.\n",
    "\n",
    "    return total_loss / max(num_batches, 1)\n",
    "\n",
    "\n",
    "print(\"\\nReady to start training!\")\n",
    "print(\"Run the training in the next step...\")"
] }, { "cell_type": "code", "execution_count": null, "id": "9c2e7980-3007-474a-8cce-11477d4224af", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Starting training...\n", "==================================================\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 1/20: 100%|██████████| 391/391 [00:43<00:00, 8.95it/s, loss=0.0746, avg_loss=0.1349, lr=0.000100]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 1/20 completed in 43.68s\n", "Average loss: 0.1349\n", "Learning rate: 0.000099\n", "VRAM usage: 0.43 GB\n", "--------------------------------------------------\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 2/20: 13%|█▎ | 49/391 [00:05<00:37, 9.10it/s, loss=0.0544, avg_loss=0.0644, lr=0.000099]" ] } ], "source": [ "# Step 8: Start Training\n", "\n", "import time\n", "import os\n", "\n", "# Create directory to save models\n", "os.makedirs('checkpoints', exist_ok=True)\n", "\n", "# Training history\n", "train_losses = []\n", "start_time = time.time()\n", "\n", "print(\"Starting training...\")\n", "print(\"=\" * 50)\n", "\n", "try:\n", " for epoch in range(num_epochs):\n", " epoch_start_time = time.time()\n", " \n", " # Train one epoch\n", " avg_loss = train_epoch(model,
train_loader, optimizer, scheduler, device, epoch)\n", " \n", " # Update learning rate\n", " lr_scheduler.step()\n", " \n", " # Record loss\n", " train_losses.append(avg_loss)\n", " \n", " # Calculate epoch time\n", " epoch_time = time.time() - epoch_start_time\n", " \n", " # Print epoch summary\n", " print(f\"\\nEpoch {epoch+1}/{num_epochs} completed in {epoch_time:.2f}s\")\n", " print(f\"Average loss: {avg_loss:.4f}\")\n", " print(f\"Learning rate: {optimizer.param_groups[0]['lr']:.6f}\")\n", " print(f\"VRAM usage: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB\")\n", " \n", " # Save model checkpoint\n", " if (epoch + 1) % save_every == 0:\n", " checkpoint = {\n", " 'epoch': epoch + 1,\n", " 'model_state_dict': model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'scheduler_state_dict': lr_scheduler.state_dict(),\n", " 'loss': avg_loss,\n", " 'train_losses': train_losses\n", " }\n", " torch.save(checkpoint, f'checkpoints/diffusion_model_epoch_{epoch+1}.pth')\n", " print(f\"Model saved: checkpoints/diffusion_model_epoch_{epoch+1}.pth\")\n", " \n", " print(\"-\" * 50)\n", " \n", " # Clear cache\n", " torch.cuda.empty_cache()\n", "\n", "except KeyboardInterrupt:\n", " print(\"\\nTraining interrupted by user\")\n", " \n", "# Training completed\n", "total_time = time.time() - start_time\n", "print(f\"\\nTraining completed in {total_time/60:.2f} minutes\")\n", "\n", "# Plot training loss\n", "plt.figure(figsize=(10, 6))\n", "plt.plot(train_losses, 'b-', linewidth=2, label='Training Loss')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Loss')\n", "plt.title('Training Loss Over Time')\n", "plt.legend()\n", "plt.grid(True, alpha=0.3)\n", "plt.show()\n", "\n", "# Save final model\n", "final_checkpoint = {\n", " 'epoch': num_epochs,\n", " 'model_state_dict': model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'scheduler_state_dict': lr_scheduler.state_dict(),\n", " 'loss': train_losses[-1],\n", " 'train_losses': 
train_losses\n", "}\n", "torch.save(final_checkpoint, 'checkpoints/diffusion_model_final.pth')\n", "print(\"Final model saved: checkpoints/diffusion_model_final.pth\")\n", "\n", "print(f\"\\nFinal training loss: {train_losses[-1]:.4f}\")\n", "print(f\"Best training loss: {min(train_losses):.4f} at epoch {train_losses.index(min(train_losses))+1}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "e2891971-bc45-4811-ad03-a6fe443a35a7", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python3 (System)", "language": "python", "name": "system-python" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }