alexnasa committed on
Commit 8e16429 · verified · 1 Parent(s): 1ee8ca8

Upload 52 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +4 -0
  2. LICENSE +21 -0
  3. config/a100l.yaml +13 -0
  4. config/multi_gpu.yaml +17 -0
  5. etc/architecture_figure.png +3 -0
  6. etc/genvid_57_11_04453.gif +3 -0
  7. etc/genvid_64_48_08386.gif +3 -0
  8. etc/genvid_87_21_08924.gif +3 -0
  9. one_sample.ipynb +19 -0
  10. run_gen_videos.py +44 -0
  11. scripts/controlnet_train_action_multigpu.sh +58 -0
  12. scripts/controlnet_train_action_singlegpu.sh +58 -0
  13. scripts/mmau_train_video_diffusion_multigpu.sh +52 -0
  14. scripts/mmau_train_video_diffusion_singlegpu.sh +53 -0
  15. src/__init__.py +0 -0
  16. src/__pycache__/__init__.cpython-310.pyc +0 -0
  17. src/datasets/base_dataset.py +189 -0
  18. src/datasets/bbox_utils.py +68 -0
  19. src/datasets/bdd100k_dataset.py +185 -0
  20. src/datasets/dada2000_dataset.py +339 -0
  21. src/datasets/dataset_factory.py +36 -0
  22. src/datasets/dataset_utils.py +50 -0
  23. src/datasets/merged_dataset.py +54 -0
  24. src/datasets/mmau_dataset.py +549 -0
  25. src/datasets/nuscenes_dataset.py +298 -0
  26. src/datasets/russia_crash_dataset.py +173 -0
  27. src/eval/README.md +120 -0
  28. src/eval/__pycache__/generate_samples.cpython-310.pyc +0 -0
  29. src/eval/generate_samples.py +394 -0
  30. src/eval/video_dataset.py +79 -0
  31. src/eval/video_quality_metrics_fvd_gt_rand.py +458 -0
  32. src/eval/video_quality_metrics_fvd_pair.py +349 -0
  33. src/eval/video_quality_metrics_jedi_gt_rand.py +91 -0
  34. src/eval/video_quality_metrics_jedi_pair.py +92 -0
  35. src/models/__init__.py +2 -0
  36. src/models/controlnet.py +391 -0
  37. src/models/unet_spatio_temporal_condition.py +169 -0
  38. src/pipelines/__init__.py +4 -0
  39. src/pipelines/pipeline_video_control.py +408 -0
  40. src/pipelines/pipeline_video_control_factor_guidance.py +615 -0
  41. src/pipelines/pipeline_video_control_nullmodel.py +406 -0
  42. src/pipelines/pipeline_video_diffusion.py +305 -0
  43. src/preprocess/README.md +105 -0
  44. src/preprocess/filter_dataset_tool.py +315 -0
  45. src/preprocess/preprocess_cap_dataset.py +224 -0
  46. src/preprocess/preprocess_dada_dataset.py +222 -0
  47. src/preprocess/preprocess_russia_dataset.py +168 -0
  48. src/preprocess/yolo_sam.py +584 -0
  49. src/utils/__init__.py +2 -0
  50. src/utils/parser.py +472 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ etc/architecture_figure.png filter=lfs diff=lfs merge=lfs -text
+ etc/genvid_57_11_04453.gif filter=lfs diff=lfs merge=lfs -text
+ etc/genvid_64_48_08386.gif filter=lfs diff=lfs merge=lfs -text
+ etc/genvid_87_21_08924.gif filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
config/a100l.yaml ADDED
@@ -0,0 +1,13 @@
+ compute_environment: LOCAL_MACHINE
+ deepspeed_config: {}
+ distributed_type: NO
+ fsdp_config: {}
+ machine_rank: 0
+ main_process_ip: null
+ main_process_port: null
+ main_training_function: main
+ mixed_precision: fp16
+ num_machines: 1
+ num_processes: 1
+ use_cpu: false
+ gpu_ids: all
config/multi_gpu.yaml ADDED
@@ -0,0 +1,17 @@
+ compute_environment: LOCAL_MACHINE
+ debug: true
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ gpu_ids: all
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: fp16
+ num_machines: 1
+ num_processes: 4
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
etc/architecture_figure.png ADDED

Git LFS Details

  • SHA256: bc7fb8e3c488eaca947d612cb381338219d993330518db454ab18ae889e8513b
  • Pointer size: 131 Bytes
  • Size of remote file: 332 kB
etc/genvid_57_11_04453.gif ADDED

Git LFS Details

  • SHA256: 0af33f0855619ecea683552fdedb2f66df9ae6d953e9421e1e4dd03ffe8cde4a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.88 MB
etc/genvid_64_48_08386.gif ADDED

Git LFS Details

  • SHA256: 2df6a977385d2dd05e6c624ae0d13f05ddfb832c302e44b4de97d5d347611a94
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
etc/genvid_87_21_08924.gif ADDED

Git LFS Details

  • SHA256: bfb4e57880ce7d208a401eb1d78537173b60b20882d59a8313fe7e96ab1c4861
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
one_sample.ipynb ADDED
@@ -0,0 +1,19 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "dd192fe4",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "language_info": {
+    "name": "python"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
run_gen_videos.py ADDED
@@ -0,0 +1,44 @@
+ import argparse
+
+ from src.eval.generate_samples import generate_samples
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Generate test samples from MMAU dataset")
+     parser.add_argument('--model_path', type=str, required=True, help='Model checkpoint used for generation')
+     parser.add_argument('--data_root', type=str, required=True, help='Dataset root path')
+     parser.add_argument('--output_path', type=str, default="./output_videos", help='Video output path')
+     parser.add_argument('--disable_null_model', action="store_true", default=False, help='For uncond noise preds, whether to use a null model')
+     parser.add_argument('--use_factor_guidance', action="store_true", default=False, help='')
+     parser.add_argument('--num_demo_samples', type=int, default=10, help='Number of samples to collect for generation')
+     parser.add_argument('--max_output_vids', type=int, default=200, help='Exit program once this many videos have been generated')
+     parser.add_argument('--num_gens_per_sample', type=int, default=1, help='Number of videos to generate for each test case')
+     parser.add_argument('--eval_output', action="store_true", default=False, help='')
+     parser.add_argument('--seed', type=int, default=None, help='')
+     parser.add_argument('--dataset', type=str, default="mmau")
+     parser.add_argument(
+         "--bbox_mask_idx_batch",
+         nargs="+",
+         type=int,
+         default=[None],
+         choices=list(range(25+1)),
+         help="Where to start the masking, multiple values represent multiple different test cases for each sample",
+     )
+     parser.add_argument(
+         "--force_action_type_batch",
+         nargs="+",
+         type=int,
+         default=[None],
+         choices=[0, 1, 2, 3, 4],
+         help="Which action type to force, multiple values represent multiple different test cases for each sample",
+     )
+     parser.add_argument(
+         "--guidance_scales",
+         nargs="+",
+         type=int,
+         default=[(1, 9)],
+         help="Guidance progression to use, multiple values represent multiple different test cases for each sample",
+     )
+
+     args = parser.parse_args()
+
+     generate_samples(args)
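A minimal usage sketch (not part of the commit): since run_gen_videos.py simply forwards the parsed arguments to generate_samples, the same entry point can be driven programmatically with an argparse.Namespace. The checkpoint and dataset paths below are placeholders, and this assumes generate_samples only reads the attributes defined by the parser above.

    import argparse
    from src.eval.generate_samples import generate_samples

    args = argparse.Namespace(
        model_path="/path/to/checkpoint",   # placeholder
        data_root="/path/to/datasets",      # placeholder
        output_path="./output_videos",
        disable_null_model=False,
        use_factor_guidance=False,
        num_demo_samples=10,
        max_output_vids=200,
        num_gens_per_sample=1,
        eval_output=False,
        seed=1234,
        dataset="mmau",
        bbox_mask_idx_batch=[None],
        force_action_type_batch=[None],
        guidance_scales=[(1, 9)],
    )
    generate_samples(args)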
scripts/controlnet_train_action_multigpu.sh ADDED
@@ -0,0 +1,58 @@
+ # nvidia-smi | grep 'python' | awk '{ print $5 }' | xargs -n1 kill -9
+
+ # User-specific paths and settings
+ DATASET_PATH="<path/to/datasets>" # e.g., "/home/datasets_root"
+ NAME="<experiment_name>" # e.g., "box2video_experiment1"
+ OUT_DIR="<path/to/output>/${NAME}" # e.g., "/home/results/${NAME}"
+ PROJECT_NAME='<wandb_project_name>' # e.g., 'car_crash'
+ WANDB_ENTITY='<wandb_username>' # Your Weights & Biases username
+ PRETRAINED_MODEL_PATH="<path/to/pretrained/model>" # e.g., "/home/checkpoints_root/checkpoint"
+
+ # export HF_HOME=/path/to/root # Where the SVD pretrained models are/will be downloaded
+
+ # Create output directory
+ mkdir -p $OUT_DIR
+
+ # Save training script for reference
+ SCRIPT_PATH=$0
+ SAVE_SCRIPT_PATH="${OUT_DIR}/train_scripts.sh"
+ cp $SCRIPT_PATH $SAVE_SCRIPT_PATH
+ echo "Saved script to ${SAVE_SCRIPT_PATH}"
+
+ # Training command
+ CUDA_LAUNCH_BLOCKING=1 accelerate launch --config_file config/multi_gpu.yaml train_video_controlnet.py \
+     --run_name $NAME \
+     --data_root $DATASET_PATH \
+     --project_name $PROJECT_NAME \
+     --pretrained_model_name_or_path $PRETRAINED_MODEL_PATH \
+     --output_dir $OUT_DIR \
+     --variant fp16 \
+     --dataset_name mmau \
+     --train_batch_size 1 \
+     --learning_rate 4e-5 \
+     --checkpoints_total_limit 3 \
+     --checkpointing_steps 300 \
+     --checkpointing_time 10620 \
+     --gradient_accumulation_steps 5 \
+     --validation_steps 300 \
+     --enable_gradient_checkpointing \
+     --lr_scheduler constant \
+     --report_to wandb \
+     --seed 1234 \
+     --mixed_precision fp16 \
+     --clip_length 25 \
+     --fps 6 \
+     --min_guidance_scale 1.0 \
+     --max_guidance_scale 3.0 \
+     --noise_aug_strength 0.01 \
+     --num_demo_samples 15 \
+     --num_train_epochs 10 \
+     --dataloader_num_workers 0 \
+     --resume_from_checkpoint latest \
+     --wandb_entity $WANDB_ENTITY \
+     --train_H 320 \
+     --train_W 512 \
+     --use_action_conditioning \
+     --contiguous_bbox_masking_prob 0.75 \
+     --contiguous_bbox_masking_start_ratio 0.0 \
+     --val_on_first_step
scripts/controlnet_train_action_singlegpu.sh ADDED
@@ -0,0 +1,58 @@
+ # nvidia-smi | grep 'python' | awk '{ print $5 }' | xargs -n1 kill -9
+
+ # User-specific paths and settings
+ DATASET_PATH="<path/to/datasets>" # e.g., "/home/datasets_root"
+ NAME="<experiment_name>" # e.g., "box2video_experiment1"
+ OUT_DIR="<path/to/output>/${NAME}" # e.g., "/home/results/${NAME}"
+ PROJECT_NAME='<wandb_project_name>' # e.g., 'car_crash'
+ WANDB_ENTITY='<wandb_username>' # Your Weights & Biases username
+ PRETRAINED_MODEL_PATH="<path/to/pretrained/model>" # e.g., "/path/to/pretrained/checkpoint"
+
+ # export HF_HOME=/path/to/root # Where the SVD pretrained models are/will be downloaded
+
+ # Create output directory
+ mkdir -p $OUT_DIR
+
+ # Save training script for reference
+ SCRIPT_PATH=$0
+ SAVE_SCRIPT_PATH="${OUT_DIR}/train_scripts.sh"
+ cp $SCRIPT_PATH $SAVE_SCRIPT_PATH
+ echo "Saved script to ${SAVE_SCRIPT_PATH}"
+
+ # Training command
+ CUDA_LAUNCH_BLOCKING=1 accelerate launch --config_file config/a100l.yaml train_video_controlnet.py \
+     --run_name $NAME \
+     --data_root $DATASET_PATH \
+     --project_name $PROJECT_NAME \
+     --pretrained_model_name_or_path $PRETRAINED_MODEL_PATH \
+     --output_dir $OUT_DIR \
+     --variant fp16 \
+     --dataset_name mmau \
+     --train_batch_size 1 \
+     --learning_rate 4e-5 \
+     --checkpoints_total_limit 3 \
+     --checkpointing_steps 300 \
+     --checkpointing_time 10620 \
+     --gradient_accumulation_steps 5 \
+     --validation_steps 300 \
+     --enable_gradient_checkpointing \
+     --lr_scheduler constant \
+     --report_to wandb \
+     --seed 1234 \
+     --mixed_precision fp16 \
+     --clip_length 25 \
+     --fps 6 \
+     --min_guidance_scale 1.0 \
+     --max_guidance_scale 3.0 \
+     --noise_aug_strength 0.01 \
+     --num_demo_samples 15 \
+     --num_train_epochs 10 \
+     --dataloader_num_workers 0 \
+     --resume_from_checkpoint latest \
+     --wandb_entity $WANDB_ENTITY \
+     --train_H 320 \
+     --train_W 512 \
+     --use_action_conditioning \
+     --contiguous_bbox_masking_prob 0.75 \
+     --contiguous_bbox_masking_start_ratio 0.0 \
+     --val_on_first_step
scripts/mmau_train_video_diffusion_multigpu.sh ADDED
@@ -0,0 +1,52 @@
+ # nvidia-smi | grep 'python' | awk '{ print $5 }' | xargs -n1 kill -9
+
+ # User-specific paths and settings
+ DATASET_PATH="<path/to/datasets>" # e.g., "/home/datasets_root"
+ NAME="<experiment_name>" # e.g., "box2video_experiment1"
+ OUT_DIR="<path/to/output>/${NAME}" # e.g., "/home/results/${NAME}"
+ PROJECT_NAME='<wandb_project_name>' # e.g., 'car_crash'
+ WANDB_ENTITY='<wandb_username>' # Your Weights & Biases username
+ PRETRAINED_MODEL_PATH="stabilityai/stable-video-diffusion-img2vid-xt" # HuggingFace model ID
+
+ # export HF_HOME=/path/to/root # Where the SVD pretrained models are/will be downloaded
+
+ # Create output directory
+ mkdir -p $OUT_DIR
+
+ # Save training script for reference
+ SCRIPT_PATH=$0
+ SAVE_SCRIPT_PATH="${OUT_DIR}/train_scripts.sh"
+ cp $SCRIPT_PATH $SAVE_SCRIPT_PATH
+ echo "Saved script to ${SAVE_SCRIPT_PATH}"
+
+ # Training command
+ CUDA_LAUNCH_BLOCKING=1 accelerate launch --config_file config/multi_gpu.yaml train_video_diffusion.py \
+     --run_name $NAME \
+     --data_root $DATASET_PATH \
+     --project_name $PROJECT_NAME \
+     --pretrained_model_name_or_path $PRETRAINED_MODEL_PATH \
+     --output_dir $OUT_DIR \
+     --variant fp16 \
+     --dataset_name mmau \
+     --train_batch_size 1 \
+     --learning_rate 1e-5 \
+     --checkpoints_total_limit 3 \
+     --checkpointing_steps 300 \
+     --gradient_accumulation_steps 5 \
+     --validation_steps 300 \
+     --enable_gradient_checkpointing \
+     --lr_scheduler constant \
+     --report_to wandb \
+     --seed 1234 \
+     --mixed_precision fp16 \
+     --clip_length 25 \
+     --min_guidance_scale 1.0 \
+     --max_guidance_scale 3.0 \
+     --noise_aug_strength 0.01 \
+     --num_demo_samples 15 \
+     --backprop_temporal_blocks_start_iter -1 \
+     --num_train_epochs 30 \
+     --train_H 320 \
+     --train_W 512 \
+     --resume_from_checkpoint latest \
+     --wandb_entity $WANDB_ENTITY
scripts/mmau_train_video_diffusion_singlegpu.sh ADDED
@@ -0,0 +1,53 @@
+ # nvidia-smi | grep 'python' | awk '{ print $5 }' | xargs -n1 kill -9
+
+ # User-specific paths and settings
+ DATASET_PATH="<path/to/datasets>" # e.g., "/home/datasets_root"
+ NAME="<experiment_name>" # e.g., "box2video_experiment1"
+ OUT_DIR="<path/to/output>/${NAME}" # e.g., "/home/results/${NAME}"
+ PROJECT_NAME='<wandb_project_name>' # e.g., 'car_crash'
+ WANDB_ENTITY='<wandb_username>' # Your Weights & Biases username
+ PRETRAINED_MODEL_PATH="stabilityai/stable-video-diffusion-img2vid-xt" # HuggingFace model ID
+
+ # export HF_HOME=/path/to/root # Where the SVD pretrained models are/will be downloaded
+
+ # Create output directory
+ mkdir -p $OUT_DIR
+
+ # Save training script for reference
+ SCRIPT_PATH=$0
+ SAVE_SCRIPT_PATH="${OUT_DIR}/train_scripts.sh"
+ cp $SCRIPT_PATH $SAVE_SCRIPT_PATH
+ echo "Saved script to ${SAVE_SCRIPT_PATH}"
+
+ # Training command
+ CUDA_LAUNCH_BLOCKING=1 accelerate launch --config_file config/a100l.yaml train_video_diffusion.py \
+     --run_name $NAME \
+     --data_root $DATASET_PATH \
+     --project_name $PROJECT_NAME \
+     --pretrained_model_name_or_path $PRETRAINED_MODEL_PATH \
+     --output_dir $OUT_DIR \
+     --variant fp16 \
+     --dataset_name mmau \
+     --train_batch_size 1 \
+     --learning_rate 1e-5 \
+     --checkpoints_total_limit 3 \
+     --checkpointing_steps 300 \
+     --gradient_accumulation_steps 5 \
+     --validation_steps 300 \
+     --enable_gradient_checkpointing \
+     --lr_scheduler constant \
+     --report_to wandb \
+     --seed 1234 \
+     --mixed_precision fp16 \
+     --clip_length 25 \
+     --min_guidance_scale 1.0 \
+     --max_guidance_scale 3.0 \
+     --noise_aug_strength 0.01 \
+     --bbox_dropout_prob 0.1 \
+     --num_demo_samples 15 \
+     --backprop_temporal_blocks_start_iter -1 \
+     --num_train_epochs 30 \
+     --train_H 320 \
+     --train_W 512 \
+     --resume_from_checkpoint latest \
+     --wandb_entity $WANDB_ENTITY
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (153 Bytes).
src/datasets/base_dataset.py ADDED
@@ -0,0 +1,189 @@
+ import os
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+ import random
+
+ from src.datasets.bbox_utils import plot_2d_bbox
+
+
+ class BaseDataset:
+
+     def __init__(self,
+                  root='./datasets',
+                  train=True,
+                  clip_length=25,
+                  # orig_width=None, orig_height=None,
+                  resize_width=512, resize_height=320,
+                  non_overlapping_clips=False,
+                  bbox_masking_prob=0.0,
+                  sample_clip_from_end=True,
+                  ego_only=False,
+                  ignore_labels=False):
+
+         self.root = root
+         self.train = train
+         self.clip_length = clip_length
+         # self.orig_width = orig_width
+         # self.orig_height = orig_height
+         self.resize_width = resize_width
+         self.resize_height = resize_height
+
+         self.non_overlapping_clips = non_overlapping_clips
+         self.bbox_masking_prob = bbox_masking_prob
+         self.sample_clip_from_end = sample_clip_from_end
+         self.ego_only = ego_only
+         self.ignore_labels = ignore_labels
+
+         self.data_split = 'train' if self.train else 'val'
+
+         # Image transforms
+         self.transform = transforms.Compose([
+             transforms.Resize((self.resize_height, self.resize_width)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # map from [0,1] to [-1,1]
+         ])
+         self.revert_transform = transforms.Compose([
+             transforms.Normalize(mean=(-1, -1, -1), std=(2, 2, 2)),
+         ])
+
+         self.image_files = []   # Contains the paths of all the images in the dataset
+         self.clip_list = []     # Contains a list of image indices for each clip
+         self.frame_labels = []  # For each image file, contains a list of dicts (labels of each object in the frame)
+
+         self.disable_cache = True
+         if self.disable_cache:
+             print("Bbox image caching disabled")
+
+     def __len__(self):
+         return len(self.clip_list)
+
+     def __getitem__(self, index):
+         return self._getclipitem(index)
+
+     def _getclipitem(self, index):
+         frames_indices = self.clip_list[index]
+
+         images, labels, bboxes, image_paths = [], [], [], []
+         masked_track_ids = self._get_masked_track_ids(frames_indices)
+         for frame_idx in frames_indices:
+             ret_dict = self._getimageitem(frame_idx, masked_track_ids=masked_track_ids)
+             images.append(ret_dict["image"])
+             labels.append(ret_dict["labels"])
+             bboxes.append(ret_dict["bbox_image"])
+             image_paths.append(ret_dict["image_path"])
+         images = torch.stack(images)
+         prompt = ""  # NOTE: Currently not supporting prompts
+
+         action_type = 0  # Assume "normal" driving when unspecified
+         if hasattr(self, "action_type_list"):
+             action_type = self.action_type_list[index]
+
+         vid_name = self.image_files[frames_indices[0]].split("/")[-1].split(".")[0][:-5]
+
+         if not self.ignore_labels:
+             bboxes = torch.stack(bboxes)
+
+             # NOTE: Keys are plural because this makes more sense when batches get collated
+             ret_dict = {"clips": images,
+                         "prompts": prompt,
+                         "indices": index,
+                         "bbox_images": bboxes,
+                         "action_type": action_type,
+                         "vid_name": vid_name,
+                         "image_paths": image_paths
+                         }
+         else:
+             ret_dict = {"clips": images,
+                         "prompts": prompt,
+                         "indices": index}
+
+         return ret_dict
+
+     def _getimageitem(self, frame_index, masked_track_ids=None):
+         # Get the image
+         image_file = self.image_files[frame_index]
+         image = Image.open(image_file)
+         image = self.transform(image)
+
+         if not self.ignore_labels:
+
+             # Get the labels
+             labels = self.frame_labels[frame_index]
+
+             # Get the bbox image (from cache or draw a new one)
+             image_filename = image_file.split('/')[-1].split('.')[0]
+             cache_filename = f"{image_filename}_bboxes"
+             cache_file = os.path.join(self.bbox_image_dir, f"{cache_filename}.jpg")
+             redraw_for_masked_agents = masked_track_ids is not None and len(masked_track_ids) > 0
+             if not os.path.exists(cache_file) or redraw_for_masked_agents or self.disable_cache:
+                 bbox_im = self._draw_bbox(labels, cache_img_name=cache_filename, masked_track_ids=masked_track_ids, disable_cache=redraw_for_masked_agents or self.disable_cache)
+             else:
+                 bbox_im = Image.open(cache_file)
+                 bbox_im = self.transform(bbox_im)
+         else:
+             labels = None
+             bbox_im = None
+
+         ret_dict = {"image": image,
+                     "image_path": image_file,
+                     "labels": labels,
+                     "frame_index": frame_index,
+                     "bbox_image": bbox_im}
+
+         return ret_dict
+
+     def _draw_bbox(self, frame_labels, cache_img_name=None, masked_track_ids=None, disable_cache=False):
+         canvas = torch.zeros((3, self.orig_height, self.orig_width))
+         bbox_im = plot_2d_bbox(canvas, frame_labels, show_track_color=True, masked_track_ids=masked_track_ids)
+         transform = transforms.Compose([transforms.ToPILImage()])
+         bbox_pil = transform(bbox_im)
+
+         if cache_img_name is not None and not disable_cache:
+             if not os.path.exists(self.bbox_image_dir):
+                 os.makedirs(self.bbox_image_dir, exist_ok=True)
+             image_path = os.path.join(self.bbox_image_dir, f"{cache_img_name}.jpg")
+             bbox_pil.save(image_path)
+             print("Cached bbox file:", image_path)
+
+         bbox_im = self.transform(bbox_pil)
+         return bbox_im
+
+     def _get_masked_track_ids(self, frames_indices):
+         masked_track_ids = []
+         if self.bbox_masking_prob > 0:
+             # Find all the track IDs in the clip, randomly select some to mask and exclude from the bbox rendering
+             all_track_ids = set()
+             for frame_idx in frames_indices:
+                 frame_labels = self.frame_labels[frame_idx]  # self._parse_label(self.image_files[frame])
+                 for label in frame_labels:
+                     track_id = label['track_id']
+                     if track_id not in all_track_ids and random.random() <= self.bbox_masking_prob:
+                         # Mask out this agent
+                         masked_track_ids.append(track_id)
+                     all_track_ids.add(label['track_id'])
+
+         return masked_track_ids
+
+     def get_frame_file_by_index(self, index, timestep=0):
+         frames = self.clip_list[index]
+         if timestep is None:
+             ret = []
+             for frame in frames:
+                 ret.append(self.image_files[frame])
+             return ret
+         return self.image_files[frames[timestep]]
+
+     def get_bbox_image_file_by_index(self, index=None, image_file=None):
+         if image_file is None:
+             image_file = self.get_frame_file_by_index(index)
+
+         clip_name = image_file.split("/")[-2]
+         return image_file.replace(self.image_dir, self.bbox_image_dir).replace('/' + clip_name + '/', '/').replace(".jpg", "_bboxes.jpg")
src/datasets/bbox_utils.py ADDED
@@ -0,0 +1,68 @@
+ import cv2
+ import numpy as np
+ from collections import defaultdict
+
+ class CVCOLORS:
+     RED = (0,0,255)
+     GREEN = (0,255,0)
+     BLUE = (255,0,0)
+     PURPLE = (247,44,200)
+     ORANGE = (44,162,247)
+     MINT = (239,255,66)
+     YELLOW = (2,255,250)
+     BROWN = (42,42,165)
+     LIME = (51,255,153)
+     GRAY = (128, 128, 128)
+     LIGHTPINK = (222,209,255)
+     LIGHTGREEN = (204,255,204)
+     LIGHTBLUE = (255,235,207)
+     LIGHTPURPLE = (255,153,204)
+     LIGHTRED = (204,204,255)
+     WHITE = (255,255,255)
+     BLACK = (0,0,0)
+
+     TRACKID_LOOKUP = defaultdict(lambda: (np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255)))
+     TYPE_LOOKUP = [BROWN, BLUE, PURPLE, RED, ORANGE, YELLOW, LIGHTPINK, LIGHTPURPLE, GRAY, LIGHTRED, GREEN]
+     REVERT_CHANNEL_F = lambda x: (x[2], x[1], x[0])
+
+
+ # TODO: This could be moved to base dataset class (?)
+ def plot_2d_bbox(img, labels, show_track_color=False, channel_first=True, rgb2bgr=False, box_color=None, masked_track_ids=None, crash_border=False):
+
+     if channel_first:
+         img = img.permute((1, 2, 0)).detach().cpu().numpy().copy()*255
+     else:
+         img = img.detach().cpu().numpy().copy()*255
+
+     if rgb2bgr:
+         img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+     masked_track_ids = masked_track_ids or []
+
+     for i, label_info in enumerate(labels):
+         track_id = label_info['track_id']
+         if track_id in masked_track_ids:
+             continue
+
+         box_2d = label_info['bbox']
+
+         if not show_track_color:
+             type_color_i = np.array(CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TYPE_LOOKUP[label_info['class_id']])) / 255 if box_color is None else box_color
+             track_color_i = CVCOLORS.REVERT_CHANNEL_F((1, 1, 1))
+
+             cv2.rectangle(img, (int(box_2d[0]), int(box_2d[1])), (int(box_2d[2]), int(box_2d[3])), type_color_i, cv2.FILLED)
+             cv2.rectangle(img, (int(box_2d[0]), int(box_2d[1])), (int(box_2d[2]), int(box_2d[3])), track_color_i, 2)
+         else:
+             type_color_i = np.array(CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TYPE_LOOKUP[label_info['class_id']])) / 255 if box_color is None else box_color
+             track_color_i = CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TRACKID_LOOKUP[label_info['track_id']])
+
+             dim = min(box_2d[2] - box_2d[0], box_2d[3] - box_2d[1])
+             b_thick = min(max(dim * 0.1, 2), 8)
+             cv2.rectangle(img, (int(box_2d[0]), int(box_2d[1])), (int(box_2d[2]), int(box_2d[3])), type_color_i, cv2.FILLED)
+             cv2.rectangle(img, (int(box_2d[0] + b_thick), int(box_2d[1] + b_thick)), (int(box_2d[2] - b_thick), int(box_2d[3] - b_thick)), track_color_i, cv2.FILLED)
+
+     if crash_border:
+         thickness = 20
+         cv2.rectangle(img, (0, 0), (img.shape[1], img.shape[0]), color=(0, 1, 0), thickness=thickness, lineType=cv2.LINE_8)
+
+     return img
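A minimal sketch of rendering a bbox conditioning frame with the helper above (the label values are invented for illustration; class_id 3 corresponds to 'car' in the class maps used by the datasets in this commit):

    import torch
    from src.datasets.bbox_utils import plot_2d_bbox

    canvas = torch.zeros((3, 320, 512))  # blank CHW canvas, as BaseDataset._draw_bbox does
    labels = [{'track_id': 1, 'class_id': 3, 'bbox': [100, 120, 220, 200]}]  # hypothetical single car box
    bbox_image = plot_2d_bbox(canvas, labels, show_track_color=True)  # returns an H x W x 3 numpy array with the filled box drawn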
src/datasets/bdd100k_dataset.py ADDED
@@ -0,0 +1,185 @@
+ from .base_dataset import BaseDataset
+
+ from PIL import Image
+ import os
+ import json
+
+
+ class BDD100KDataset(BaseDataset):
+     CLASS_NAME_TO_ID = {
+         'pedestrian': 1,
+         'rider': 2,
+         'car': 3,
+         'truck': 4,
+         'bus': 5,
+         'train': 6,
+         'motorcycle': 7,
+         'bicycle': 8,
+         'traffic light': 9,
+         'traffic sign': 10,
+     }
+
+     TO_COCO_LABELS = {
+         1: 0,
+         2: 0,
+         3: 2,
+         4: 7,
+         5: 5,
+         6: 6,
+     }
+
+     TO_IMAGE_DIR = 'images/track'
+     TO_BBOX_DIR = 'bboxes/track'
+     TO_LABEL_DIR = 'labels'
+     TO_BBOX_LABELS = 'labels/box_track_20'
+     TO_SEG_LABELS = 'labels/seg_track_20/colormaps'
+     TO_POSE_LABELS = 'labels/pose_21'
+
+     def __init__(self,
+                  root='./datasets',
+                  train=True,
+                  clip_length=25,
+                  # orig_height=720, orig_width=1280,  # TODO: Define this (and use it)
+                  resize_height=320, resize_width=512,
+                  non_overlapping_clips=False,
+                  bbox_masking_prob=0.0,
+                  sample_clip_from_end=True,
+                  ego_only=False,
+                  ignore_labels=False,
+                  use_preplotted_bbox=True,
+                  specific_samples=None,
+                  specific_categories=None,
+                  force_clip_type=None):
+
+         super(BDD100KDataset, self).__init__(root=root,
+                                              train=train,
+                                              clip_length=clip_length,
+                                              resize_height=resize_height,
+                                              resize_width=resize_width,
+                                              non_overlapping_clips=non_overlapping_clips,
+                                              bbox_masking_prob=bbox_masking_prob,
+                                              sample_clip_from_end=sample_clip_from_end,
+                                              ego_only=ego_only,
+                                              ignore_labels=ignore_labels)
+
+         self.MAX_BOXES_PER_DATA = 30
+         self._location = 'train' if self.train else 'val'
+         self.version = 'bdd100k'
+         self.use_preplotted_bbox = use_preplotted_bbox
+
+         self.image_dir = os.path.join(self.root, self.version, BDD100KDataset.TO_IMAGE_DIR, self._location)
+         self.bbox_label_dir = os.path.join(self.root, self.version, BDD100KDataset.TO_BBOX_LABELS, self._location)
+         self.bbox_image_dir = os.path.join(self.root, self.version, BDD100KDataset.TO_BBOX_DIR, self._location)
+
+         if specific_categories is not None:
+             print("BDD100k does not support `specific_categories`")
+         if force_clip_type is not None:
+             print("BDD100k does not support `force_clip_type`")
+         self.specific_samples = specific_samples
+         if self.specific_samples is not None:
+             print("Only loading specific samples:", self.specific_samples)
+
+         listed_image_dir = os.listdir(self.image_dir)
+         try:
+             listed_image_dir.remove('pred')
+         except ValueError:
+             pass
+         self.clip_folders = sorted(listed_image_dir)
+         self.clip_folder_lengths = {k: len(os.listdir(os.path.join(self.image_dir, k))) for k in self.clip_folders}
+
+         for l in self.clip_folder_lengths.values():
+             assert l >= self.clip_length, f'clip length {self.clip_length} is too long for clip folder length {l}'
+
+         self._collect_clips()
+
+     def _collect_clips(self):
+         print("Collecting dataset clips...")
+
+         for clip_folder in self.clip_folders:
+             clip_path = os.path.join(self.image_dir, clip_folder)
+             clip_frames = sorted(os.listdir(clip_path))
+
+             if self.specific_samples is not None and clip_folder not in self.specific_samples:
+                 continue
+
+             # Add all images to image_files
+             image_indices = []
+             for frame in clip_frames:
+                 self.image_files.append(os.path.join(clip_path, frame))
+                 image_indices.append(len(self.image_files)-1)
+
+             # Create clips of length clip_length
+             if self.clip_length is not None:
+                 # Collect clips as overlapping clips (i.e. a video with 30 frames will yield 5 25-frame clips)
+                 for start_image_idx in range(0, len(clip_frames) - self.clip_length + 1):
+                     end_image_idx = start_image_idx + self.clip_length
+                     clip_indices = image_indices[start_image_idx:end_image_idx]
+                     self.clip_list.append(clip_indices)
+
+     def _parse_label(self, label_file, frame_id):
+         target = []
+         with open(label_file, 'r') as f:
+             label = json.load(f)
+             frame_i = int(frame_id[-11:-4])-1
+             assert frame_id == label[frame_i]['name']
+             for obj in label[frame_i-1]['labels']:
+                 if obj['category'] not in BDD100KDataset.CLASS_NAME_TO_ID:
+                     continue
+                 target.append({
+                     'frame_name': frame_id,
+                     'track_id': int(obj['id']),
+                     'bbox': [obj['box2d']['x1'], obj['box2d']['y1'], obj['box2d']['x2'], obj['box2d']['y2']],
+                     'class_id': BDD100KDataset.CLASS_NAME_TO_ID[obj['category']],
+                     'class_name': obj['category'],
+                 })
+                 if len(target) >= self.MAX_BOXES_PER_DATA:
+                     break
+         return target
+
+     def _getimageitem(self, frame_index, masked_track_ids=None):
+         # Get the image
+         image_file = self.image_files[frame_index]
+         image = Image.open(image_file)
+         image = self.transform(image)
+
+         if not self.ignore_labels:
+             # Get the labels
+             clip_id = image_file[:image_file.rfind('/')]
+             clip_id = clip_id[clip_id.rfind('/')+1:]
+             label_file = os.path.join(self.bbox_label_dir, f'{clip_id}.json')
+             frame_id = image_file[image_file.rfind('/')+1:]
+             labels = self._parse_label(label_file, frame_id)
+
+             # Get the bbox image
+             if self.use_preplotted_bbox:
+                 bbox_file = self.get_bbox_image_file_by_index(image_file=image_file)
+                 bbox_im = Image.open(bbox_file)
+                 bbox_im = self.transform(bbox_im)
+             else:
+                 bbox_im = self._draw_bbox(labels, masked_track_ids=masked_track_ids)
+         else:
+             labels = None
+             bbox_im = None
+
+         ret_dict = {"image": image,
+                     "image_path": image_file,
+                     "labels": labels,
+                     "frame_index": frame_index,
+                     "bbox_image": bbox_im}
+
+         return ret_dict
+
+     def get_bbox_image_file_by_index(self, index=None, image_file=None):
+         if image_file is None:
+             image_file = self.get_image_file_by_index(index)
+
+         return image_file.replace(BDD100KDataset.TO_IMAGE_DIR, BDD100KDataset.TO_BBOX_DIR)
+
+     def get_image_file_by_index(self, index):
+         return self.image_files[index]
+
+     def __len__(self):
+         return len(self.clip_list) if self.clip_length is not None else len(self.image_files)
+
+
+ if __name__ == "__main__":
+     dataset = BDD100KDataset()
src/datasets/dada2000_dataset.py ADDED
@@ -0,0 +1,339 @@
+ import os
+ import json
+ from tqdm import tqdm
+ import csv
+
+ from src.datasets.base_dataset import BaseDataset
+
+
+ class DADA2000Dataset(BaseDataset):
+
+     CLASS_NAME_TO_ID = {
+         'person': 1,
+         'car': 3,
+         'truck': 4,
+         'bus': 5,
+         'train': 6,
+         'motorcycle': 7,
+         'bicycle': 8,
+     }
+
+     def __init__(self,
+                  root='./datasets',
+                  train=True,
+                  clip_length=25,
+                  orig_height=660, orig_width=1056,
+                  resize_height=320, resize_width=512,
+                  non_overlapping_clips=False,
+                  bbox_masking_prob=0.0,
+                  sample_clip_from_end=True,
+                  ego_only=False,
+                  specific_samples=None):
+
+         super(DADA2000Dataset, self).__init__(root=root,
+                                               train=train,
+                                               clip_length=clip_length,
+                                               resize_height=resize_height,
+                                               resize_width=resize_width,
+                                               non_overlapping_clips=non_overlapping_clips,
+                                               bbox_masking_prob=bbox_masking_prob,
+                                               sample_clip_from_end=sample_clip_from_end,
+                                               ego_only=ego_only)
+
+         self.dataset_name = "preprocess_dada2000"
+
+         self.orig_width = orig_width
+         self.orig_height = orig_height
+         self.image_dir = os.path.join(self.root, self.dataset_name, "images", self.data_split)
+         self.label_dir = os.path.join(self.root, self.dataset_name, "labels", self.data_split)
+         self.bbox_image_dir = os.path.join(self.root, self.dataset_name, "bbox_images", self.data_split)
+         self.metadata_csv_path = os.path.join(self.root, self.dataset_name, "metadata.csv")  # TODO: This information could be transferred into each individual label file
+
+         self.strict_collision_filter = True
+         if self.strict_collision_filter:
+             print("Strict collision filter set for DADA2000")
+
+         self.specific_samples = specific_samples
+         if self.specific_samples is not None:
+             print("Only loading specific samples:", self.specific_samples)
+
+         self._collect_clips()
+
+     def _collect_clips(self):
+
+         accident_frame_metadata = {}
+         with open(self.metadata_csv_path) as csv_file:
+             csv_reader = csv.reader(csv_file)
+             for i, row in enumerate(csv_reader):
+                 if i == 0:
+                     continue
+
+                 video_num = row[0]
+                 video_type = row[5]
+                 abnormal_start_frame_idx = int(row[7])
+                 accident_frame_idx = int(row[8])
+                 abnormal_end_frame_idx = int(row[9])
+
+                 video_name = f"{video_type}_{video_num.rjust(3, '0')}"
+
+                 if accident_frame_idx == -1:
+                     # print("Skipping video:", video_name)
+                     continue
+
+                 # Need to convert the original frame idx to the closest downsampled frame index
+                 downsample_factor = 30/7  # Because we downsampled from 30fps to 7fps
+                 accident_frame_metadata[video_name] = (int(abnormal_start_frame_idx / downsample_factor),
+                                                        int(accident_frame_idx / downsample_factor + 0.5),
+                                                        int(abnormal_end_frame_idx / downsample_factor + 0.5))
+
+         self.clip_type_list = []  # crash, normal or abnormal (abnormal is a scene that has abnormal driving but doesn't contain the actual crash moment)
+         image_indices_by_clip = {}
+         for label_file in sorted(os.listdir(self.label_dir)):
+             if not label_file.endswith('.json'):
+                 continue
+
+             full_filename = os.path.join(self.label_dir, label_file)
+             with open(full_filename) as json_file:
+                 all_data = json.load(json_file)
+                 metadata = all_data['metadata']
+
+             if self.ego_only:
+                 print("Ego collisions only activated!")
+                 if metadata['ego_involved'] == False:
+                     continue
+
+             if self.strict_collision_filter and metadata["accident_type"] in [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 25, 26, 28, 29, 31, 32, 34, 35, 36]:
+                 # Reject videos where the collision is with "static" agents
+                 continue
+
+             # Some rejected clips
+             if all_data["video_source"] in ["10_001.mp4"]:
+                 continue
+
+             if self.specific_samples is not None and all_data["video_source"].split(".")[0] not in self.specific_samples:
+                 continue
+
+             clip_filename = label_file.split('.')[0]
+             clip_file = os.path.join(self.image_dir, clip_filename)
+             clip_data = all_data["data"]
+             clip_frames = sorted(os.listdir(clip_file))
+
+             num_frames = len(clip_frames)
+             if num_frames < self.clip_length:
+                 # print(f"{clip_filename} does not have enough frames: has {num_frames}, expected at least {self.clip_length}")
+                 continue
+
+             accident_metadata = accident_frame_metadata.get(clip_filename)
+             if accident_metadata is None:
+                 print(clip_filename, "no accident metadata found")
+                 continue
+
+             # Frames within the abnormal range are considered accidents and frames outside it are considered normal driving
+             clip_label_data = self._parse_clip_labels(clip_data)
+             self.frame_labels.extend(clip_label_data)  # Labels are already sorted, so they match up with the image indices
+
+             image_indices_by_clip[clip_filename] = []
+             for image_file in clip_frames:
+                 self.image_files.append(os.path.join(clip_file, image_file))
+                 image_indices_by_clip[clip_filename].append(len(self.image_files)-1)
+
+             assert len(self.frame_labels) == len(self.image_files)  # We assume a one-to-one association between images and labels
+
+             ab_start_idx, acc_idx, ab_end_idx = accident_metadata
+             def get_clip_type(image_idx, end_image_idx):
+                 clip_type = "normal"
+                 if image_idx <= acc_idx and end_image_idx > acc_idx:
+                     # Contains the accident frame
+                     clip_type = "crash"
+                 elif (image_idx >= ab_start_idx and image_idx <= ab_end_idx) or (end_image_idx > ab_start_idx and end_image_idx < ab_end_idx):
+                     # Does not contain the accident frame, but contains abnormal driving (moments before and after the accident)
+                     clip_type = "abnormal"
+
+                 return clip_type
+
+             # Cut the videos into clips of the correct length according to the chosen strategy
+             if not self.non_overlapping_clips:
+                 for image_idx in range(len(image_indices_by_clip[clip_filename]) - self.clip_length + 1):
+                     end_image_idx = image_idx + self.clip_length
+
+                     clip_type = get_clip_type(image_idx, end_image_idx)
+                     if clip_type == "abnormal":
+                         # Reject the abnormal clips
+                         continue
+
+                     self.clip_list.append(image_indices_by_clip[clip_filename][image_idx:end_image_idx])
+                     self.clip_type_list.append(clip_type)
+
+             else:
+                 if self.sample_clip_from_end:
+                     # In case self.clip_length << actual video sample length, we can create multiple non-overlapping clips for each sample.
+                     # Prioritize selecting clips from the end, to make sure the accident is included (it tends to be at the end of the videos)
+                     total_frames = len(image_indices_by_clip[clip_filename])
+                     for clip_i in range(total_frames // self.clip_length):
+                         start_image_idx = total_frames - (self.clip_length * (clip_i + 1))
+                         end_image_idx = total_frames - (self.clip_length * clip_i)
+
+                         clip_type = get_clip_type(start_image_idx, end_image_idx)
+                         if clip_type == "abnormal":
+                             # Reject the abnormal clips
+                             continue
+
+                         self.clip_list.append(image_indices_by_clip[clip_filename][start_image_idx:end_image_idx])
+                         self.clip_type_list.append(clip_type)
+                 else:
+                     total_frames = len(image_indices_by_clip[clip_filename])
+                     for clip_i in range(total_frames // self.clip_length):
+                         start_image_idx = clip_i * self.clip_length
+                         end_image_idx = start_image_idx + self.clip_length
+
+                         clip_type = get_clip_type(start_image_idx, end_image_idx)
+                         if clip_type == "abnormal":
+                             # Reject the abnormal clips
+                             continue
+
+                         self.clip_list.append(image_indices_by_clip[clip_filename][start_image_idx:end_image_idx])
+                         self.clip_type_list.append(clip_type)
+
+         print("Number of clips DADA2000:", len(self.clip_list), f"({self.data_split})")
+         crash_clip_count = 0
+         normal_clip_count = 0
+         for clip_type in self.clip_type_list:
+             if clip_type == "crash":
+                 crash_clip_count += 1
+             elif clip_type == "normal":
+                 normal_clip_count += 1
+         print(crash_clip_count, "crash clips", normal_clip_count, "normal clips")
+
+     def _parse_clip_labels(self, clip_data):
+         frame_labels = []
+         for frame_data in clip_data:
+             obj_data = frame_data['labels']
+
+             object_labels = []
+             for label in obj_data:
+                 # Only keep the classes of interest
+                 class_id = DADA2000Dataset.CLASS_NAME_TO_ID.get(label['name'])
+                 if class_id is None:
+                     continue
+
+                 # Convert bbox coordinates to pixel space w.r.t. the image size
+                 bbox = label['box']
+                 bbox_coords_pixel = [int(bbox[0] * self.orig_width),   # x1
+                                      int(bbox[1] * self.orig_height),  # y1
+                                      int(bbox[2] * self.orig_width),   # x2
+                                      int(bbox[3] * self.orig_height)]  # y2
+
+                 object_labels.append({
+                     'frame_name': frame_data["image_source"],
+                     'track_id': int(label['track_id']),
+                     'bbox': bbox_coords_pixel,
+                     'class_id': class_id,
+                     'class_name': label['name'],  # Class name of the object
+                 })
+
+             frame_labels.append(object_labels)
+
+         return frame_labels
+
+
+ def pre_cache_dataset(dataset_root):
+     # Trigger label and bbox image cache generation
+     from time import time
+     dataset_train = DADA2000Dataset(root=dataset_root, train=True, clip_length=25, non_overlapping_clips=False)
+     t = time()
+     for i in tqdm(range(len(dataset_train))):
+         d = dataset_train[i]
+         if i >= 100:
+             print("Time:", time() - t)
+             print("break")
+
+     dataset_val = DADA2000Dataset(root=dataset_root, train=False, clip_length=25, non_overlapping_clips=True)
+     for i in tqdm(range(len(dataset_val))):
+         d = dataset_val[i]
+
+     print("Done.")
+
+
+ if __name__ == "__main__":
+     dataset_root = "/path/to/Datasets"
+     pre_cache_dataset(dataset_root)
+
+
+ """
+ ACCIDENT TYPES
+ {
+     "ego_car_involved": {
+         "self_initiated": {
+             "out_of_control": [61]
+         },
+         "dynamic_participants": {
+             "person_centric": {
+                 "pedestrian": [1, 2],
+                 "cyclist": [3, 4]
+             },
+             "vehicle_centric": {
+                 "motorbike": [5, 6],
+                 "truck": [7, 8, 9],
+                 "car": [10, 11, 12]
+             }
+         },
+         "static_participants": {
+             "road_centric": {
+                 "large_roadblocks": [13],
+                 "curb": [14],
+                 "small_roadblocks_potholes": [15]
+             },
+             "other_semantics_centric": {
+                 "trees": [16],
+                 "telegraph_poles": [17],
+                 "other_road_facilities": [18]
+             }
+         }
+     },
+     "ego_car_uninvolved": {
+         "dynamic_participants": {
+             "vehicle_centric": {
+                 "motorbike_motorbike": [37, 38],
+                 "truck_truck": [39, 40, 41],
+                 "car_car": [42, 43, 44],
+                 "motorbike_truck": [45, 46, 47],
+                 "truck_car": [48, 49],
+                 "car_motorbike": [50, 51]
+             },
+             "person_centric": [52, 53, 54, 55, 56, 57, 58, 59, 60]
+         },
+         "static_participants": [19, 20, 21, 22, 25, 26, 28, 29, 31, 32, 34, 35, 36]
+     },
+     "summary": {
+         "ego_car_involved": {
+             "person_centric": [1, 2, 3, 4],
+             "vehicle_centric": [5, 6, 7, 8, 9, 10, 11, 12],
+             "static_participants": [13, 14, 15, 16, 17, 18],
+             "out_of_control": [61]
+         },
+         "ego_car_uninvolved": {
+             "static_participants": [19, 20, 21, 22, 25, 26, 28, 29, 31, 32, 34, 35, 36],
+             "vehicle_centric": [37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 50, 51],
+             "person_centric": [52, 53, 54, 55, 56, 57, 58, 59, 60]
+         }
+     }
+ }
+ """
src/datasets/dataset_factory.py ADDED
@@ -0,0 +1,36 @@
+ from src.datasets.merged_dataset import MergedDataset
+ from src.datasets.russia_crash_dataset import RussiaCrashDataset
+ from src.datasets.dada2000_dataset import DADA2000Dataset
+ from src.datasets.mmau_dataset import MMAUDataset
+ from src.datasets.bdd100k_dataset import BDD100KDataset
+ # from src.datasets.nuscenes_dataset import NuScenesDataset
+
+
+ def create_dataset(dataset_name, **kwargs):
+
+     if str.lower(dataset_name) == "russia_crash":
+         dataset = RussiaCrashDataset(**kwargs)
+     elif str.lower(dataset_name) == "nuscenes":
+         dataset = NuScenesDataset(**kwargs)
+     elif str.lower(dataset_name) == "dada2000":
+         dataset = DADA2000Dataset(**kwargs)
+     elif str.lower(dataset_name) == "mmau":
+         dataset = MMAUDataset(**kwargs)
+     elif str.lower(dataset_name) == "bdd100k":
+         dataset = BDD100KDataset(**kwargs)
+     else:
+         raise NotImplementedError(f"Dataset '{dataset_name}' not implemented")
+
+     return dataset
+
+ def dataset_factory(dataset_names, **kwargs):
+     if isinstance(dataset_names, str) or (isinstance(dataset_names, list) and len(dataset_names) == 1):
+         dataset_name = dataset_names[0] if isinstance(dataset_names, list) else dataset_names
+         # Init the single dataset
+         return create_dataset(dataset_name, **kwargs)
+     elif isinstance(dataset_names, list):
+         all_datasets = []
+         for dataset_name in dataset_names:
+             all_datasets.append(create_dataset(dataset_name, **kwargs))
+         return MergedDataset(all_datasets)
src/datasets/dataset_utils.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+ import os
+
+ from .dataset_factory import dataset_factory
+
+
+ def worker_init_fn(worker_id):
+     os.sched_setaffinity(0, range(os.cpu_count()))
+
+ def get_dataloader(dset_root,
+                    dset_names,
+                    if_train,
+                    batch_size,
+                    num_workers,
+                    clip_length=25,
+                    shuffle=True,
+                    image_height=None,
+                    image_width=None,
+                    non_overlapping_clips=False,
+                    ego_only=False,
+                    bbox_masking_prob=0.0,
+                    specific_samples=None,
+                    specific_categories=None,
+                    force_clip_type=None):
+
+     dataset = dataset_factory(dset_names,
+                               root=dset_root,
+                               train=if_train,
+                               clip_length=clip_length,
+                               resize_height=image_height,
+                               resize_width=image_width,
+                               non_overlapping_clips=non_overlapping_clips,
+                               bbox_masking_prob=bbox_masking_prob,
+                               ego_only=ego_only,
+                               specific_samples=specific_samples,
+                               specific_categories=specific_categories,
+                               force_clip_type=force_clip_type)
+
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         shuffle=shuffle,
+         pin_memory=True,
+         drop_last=True,
+         worker_init_fn=worker_init_fn
+     )
+
+     return dataset, dataloader
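A minimal sketch of building a loader with the helper above (assumes the preprocessed MMAU images/labels are present under the placeholder root; the values mirror the training scripts in this commit):

    from src.datasets.dataset_utils import get_dataloader

    dataset, dataloader = get_dataloader(
        dset_root="/path/to/datasets",  # placeholder
        dset_names="mmau",
        if_train=True,
        batch_size=1,
        num_workers=0,
        clip_length=25,
        image_height=320,
        image_width=512,
    )
    batch = next(iter(dataloader))
    print(batch["clips"].shape)  # expected (1, 25, 3, 320, 512) given the transforms in base_dataset.py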
src/datasets/merged_dataset.py ADDED
@@ -0,0 +1,54 @@
+ from torchvision import transforms
+
+ class MergedDataset:
+     """
+     Dataset wrapper to access many datasets as one
+     """
+
+     def __init__(self, dataset_list):
+
+         self.dataset_list = dataset_list
+
+         # TODO: Make sure this matches all datasets
+         self.resize_width = self.dataset_list[0].resize_width
+         self.resize_height = self.dataset_list[0].resize_height
+         self.revert_transform = self.dataset_list[0].revert_transform
+
+         print("TOTAL number of clips in merged dataset:", self.__len__(), f"({self.dataset_list[0].data_split})")
+
+     def __len__(self):
+         return sum([len(dset) for dset in self.dataset_list])
+
+     def __getitem__(self, global_index):
+
+         target_dset, rel_index = self.get_dataset_by_sample_index(global_index)
+         ret_dict = target_dset.__getitem__(rel_index)
+
+         # Overwrite the returned index with the global index
+         ret_dict["indices"] = global_index
+
+         return ret_dict
+
+     def get_dataset_by_sample_index(self, index):
+         total_idx = 0
+         target_dset = None
+         for dset in self.dataset_list:
+             total_idx += len(dset)
+             if index < total_idx:
+                 target_dset = dset
+                 break
+
+         return target_dset, (index - (total_idx - len(target_dset)))
+
+     def get_frame_file_by_index(self, index, timestep=None):
+         target_dset, rel_index = self.get_dataset_by_sample_index(index)
+         return target_dset.get_frame_file_by_index(rel_index, timestep=timestep)
+
+     def get_bbox_image_file_by_index(self, index, image_file=None):
+         target_dset, rel_index = self.get_dataset_by_sample_index(index)
+         return target_dset.get_bbox_image_file_by_index(index=rel_index)
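A small worked example of the global-to-relative index mapping in get_dataset_by_sample_index above (the dataset lengths are hypothetical):

    # Suppose dataset A has 100 clips and dataset B has 40 clips.
    # Global indices 0-99 map to (A, 0-99); indices 100-139 map to (B, 0-39).
    # For index 105: total_idx reaches 140 at B, so rel_index = 105 - (140 - 40) = 5.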
src/datasets/mmau_dataset.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from tqdm import tqdm
4
+ import csv
5
+ import json
6
+ import cv2
7
+
8
+ from src.datasets.base_dataset import BaseDataset
9
+
10
+ def load_json(filename):
11
+ if os.path.exists(filename):
12
+ with open(filename, "r") as f:
13
+ return json.load(f)
14
+ print(filename, "not found")
15
+ return []
16
+
17
+ def create_video_from_images(images_list, output_video, out_fps, start_frame=None, end_frame=None):
18
+
19
+ img0_path = images_list[0]
20
+ img0 = cv2.imread(img0_path)
21
+ height, width, _ = img0.shape
22
+
23
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
24
+ out = cv2.VideoWriter(output_video, fourcc, out_fps, (width, height))
25
+
26
+ for idx, frame_name in enumerate(images_list):
27
+
28
+ if start_frame is not None and idx < start_frame:
29
+ continue
30
+ if end_frame is not None and idx >= end_frame:
31
+ continue
32
+
33
+ img = cv2.imread(frame_name)
34
+ out.write(img)
35
+
36
+ out.release()
37
+ print("Saved video:", output_video)
38
+
39
+
40
+ class MMAUDataset(BaseDataset):
41
+
42
+ CLASS_NAME_TO_ID = {
43
+ 'person': 1,
44
+ 'car': 3,
45
+ 'truck': 4,
46
+ 'bus': 5,
47
+ 'train': 6,
48
+ 'motorcycle': 7,
49
+ 'bicycle': 8,
50
+ }
51
+
52
+ def __init__(self,
53
+ root='./datasets',
54
+ train=True,
55
+ clip_length=25,
56
+ orig_height=640, orig_width=1024,
57
+ resize_height=320, resize_width=512,
58
+ non_overlapping_clips=False,
59
+ bbox_masking_prob=0.0,
60
+ sample_clip_from_end=True,
61
+ ego_only=False,
62
+ specific_samples=None,
63
+ specific_categories=None,
64
+ dada_only=False,
65
+ cleanup_dataset=False,
66
+ force_clip_type=None
67
+ ):
68
+
69
+ self.ignore_labels = False
70
+ if self.ignore_labels:
71
+ print("IGNORING LABELS in MMAU dataset")
72
+
73
+ super(MMAUDataset, self).__init__(root=root,
74
+ train=train,
75
+ clip_length=clip_length,
76
+ resize_height=resize_height,
77
+ resize_width=resize_width,
78
+ non_overlapping_clips=non_overlapping_clips,
79
+ bbox_masking_prob=bbox_masking_prob,
80
+ sample_clip_from_end=sample_clip_from_end,
81
+ ego_only=ego_only,
82
+ ignore_labels=self.ignore_labels) # NOTE: Ignoring labels currently
83
+
84
+ self.dada_only = dada_only
85
+ self.cleanup_dataset = cleanup_dataset
86
+ self.dataset_name = "mmau_images_12fps" if not dada_only else "dada2000_images_12fps"
87
+
88
+ self.orig_width = orig_width
89
+ self.orig_height = orig_height
90
+ self.split = "train" if train else "val"
91
+
92
+ self.image_dir = os.path.join(self.root, self.dataset_name, "images")
93
+ self.label_dir = os.path.join(self.root, self.dataset_name, "labels")
94
+ self.bbox_image_dir = os.path.join(self.root, self.dataset_name, "bbox_images")
95
+
96
+ self.downsample_6fps = True
97
+ if self.downsample_6fps:
98
+ print("Downsampling MMAU clips to 6 fps")
99
+
100
+ self.ego_only = ego_only
101
+ if self.ego_only:
102
+ print("Ego collisions only filter set for MMAU dataset")
103
+
104
+ self.strict_collision_filter = False
105
+ if self.strict_collision_filter:
106
+ print("Strict collision filter set for MMAU dataset")
107
+
108
+ self.specific_samples = specific_samples
109
+ if self.specific_samples is not None:
110
+ print("Only loading specific samples:", self.specific_samples)
111
+
112
+ self.specific_categories = specific_categories
113
+ if self.specific_categories is not None:
114
+ print("Only loading specific categories:", self.specific_categories)
115
+
116
+ self.force_clip_type = force_clip_type
117
+ if self.force_clip_type is not None:
118
+ print("Only loading samples with type:", force_clip_type)
119
+
120
+ self._collect_clips()
121
+
122
+ def _collect_metadata_csv(self, metadata_csv_path):
123
+ accident_frame_metadata = {}
124
+ with open(metadata_csv_path) as csv_file:
125
+ csv_reader = csv.reader(csv_file)
126
+ for i, row in enumerate(csv_reader):
127
+ if i == 0:
128
+ continue
129
+
130
+ video_num = str(int(row[0]))
131
+ video_type = str(int(row[5]))
132
+ abnormal_start_frame_idx = int(row[7])
133
+ accident_frame_idx = int(row[9])
134
+ abnormal_end_frame_idx = int(row[8])
135
+
136
+ video_name = f"{video_type}_{video_num.rjust(5, '0')}"
137
+
138
+ if accident_frame_idx == -1:  # row[9] was parsed as an int above, so compare against the integer sentinel
139
+ # print("Skipping video:", video_name)
140
+ continue
141
+
142
+ downsample_factor = 30/12 if video_num.startswith("90") else 1  # Downsample factors differ: DADA clips are 30 fps, CAP clips are 12 fps
143
+ if self.downsample_6fps:
144
+ downsample_factor *= 2
145
+ accident_frame_metadata[video_name] = (int(abnormal_start_frame_idx / downsample_factor),
146
+ int(accident_frame_idx / downsample_factor + 0.5),
147
+ int(abnormal_end_frame_idx / downsample_factor + 0.5))
148
+
149
+ return accident_frame_metadata
150
+
151
+ def _collect_clips(self):
152
+ print("Collecting dataset clips...")
153
+
154
+ mmau_dataset = os.path.join(self.root, self.dataset_name)
155
+
156
+ # Load data split
157
+ datasplit_data = load_json(os.path.join(mmau_dataset, "mmau_datasplit.json"))
158
+
159
+ # Compile reject videos
160
+ auto_filtered_vids = load_json(os.path.join(mmau_dataset, "auto_low_quality.json"))
161
+ rejected_vids = load_json(os.path.join(mmau_dataset, "rejected.json"))
162
+ all_rejected_vids = auto_filtered_vids + rejected_vids
163
+
164
+ # Collect the accident moment information
165
+ accident_frame_metadata = self._collect_metadata_csv(os.path.join(mmau_dataset, "mmau_metadata.csv"))
166
+
167
+ self.clip_type_list = [] # crash or normal or abnormal (abnormal is a scene that has abnormal driving but doesn't contain the actual crash moment)
168
+ self.action_type_list = [] # 0-4, 0 is normal, 1-4 are different types of crashes
169
+ image_indices_by_clip = {}
170
+
171
+ null_labels = []
172
+ # Iterate datasplit file
173
+ count_vid = 0
174
+ for category, split in datasplit_data.items():
175
+ for split_name, vid_names in split.items():
176
+ if split_name != self.split:
177
+ continue
178
+
179
+ for vid_name in vid_names:
180
+ if vid_name in all_rejected_vids:
181
+ continue
182
+
183
+ if self.dada_only and not vid_name.split("_")[-1].startswith("90"): # NOTE: REMOVE THIS
184
+ continue
185
+
186
+ # Read image files
187
+ image_dir = os.path.join(mmau_dataset, "images")
188
+ clip_file = os.path.join(image_dir, category, vid_name)
189
+ clip_frames = sorted(os.listdir(clip_file))
190
+
191
+ if self.cleanup_dataset:
192
+ # NOTE: For renaming frames (can remove this later)
193
+ fix_label = False
194
+ for frame_name in clip_frames:
195
+ if vid_name not in frame_name:
196
+ fix_label = True
197
+ new_frame_name = f"{vid_name}_{frame_name}"
198
+ root_path = os.path.join(mmau_dataset, "images", category, vid_name)
199
+ os.rename(os.path.join(root_path, frame_name), os.path.join(root_path, new_frame_name))
200
+
201
+ image_dir = os.path.join(mmau_dataset, "images")
202
+ clip_file = os.path.join(image_dir, category, vid_name)
203
+ clip_frames = sorted(os.listdir(clip_file))
204
+
205
+ # Also rename in label file
206
+ label_file_path = os.path.join(self.label_dir, f"{vid_name}.json")
207
+ if os.path.exists(label_file_path) and fix_label:
208
+ with open(label_file_path, "r") as f:
209
+ data = json.load(f)
210
+ data_field = data["data"]
211
+ if data_field is None:
212
+ print(f"{vid_name}.json CLIP DATA IS NULL 2")
213
+ null_labels.append(vid_name)
214
+ else:
215
+ for i, frame_data in enumerate(data_field):
216
+ current_frame_name = frame_data["image_source"]
217
+ if vid_name not in current_frame_name:
218
+ new_frame_name = f"{vid_name}_{current_frame_name}"
219
+ data["data"][i]["image_source"] = new_frame_name
220
+
221
+ with open(label_file_path, "w") as f:
222
+ json.dump(data, f, indent=1)
223
+
224
+ num_frames = len(clip_frames) if not self.downsample_6fps else len(clip_frames) // 2
225
+ if num_frames < self.clip_length:
226
+ print(f"{vid_name} does not have enough frames: has {num_frames}, expected at least {self.clip_length}")
227
+ continue
228
+
229
+ accident_metadata = accident_frame_metadata.get(vid_name)
230
+ if accident_metadata is None:
231
+ print(vid_name, "no accident metadata found")
232
+ continue
233
+
234
+ step = 2 if self.downsample_6fps else 1
235
+ clip_frame_names = []
236
+ for image_idx in range(0, len(clip_frames), step):
237
+ image_file = clip_frames[image_idx]
238
+ clip_frame_names.append(image_file)
239
+
240
+ count_vid += 1
241
+ # Read label file
242
+ if not self.ignore_labels:
243
+ label_file_path = os.path.join(self.label_dir, f"{vid_name}.json")
244
+ if not os.path.exists(label_file_path):
245
+ if num_frames <= 300:
246
+ # Because a lot of the long videos were rejected because they were too long to process
247
+ # print(f"{label_file_path} does not exist")
248
+ pass
249
+ continue
250
+
251
+ with open(label_file_path) as json_file:
252
+ all_data = json.load(json_file)
253
+ metadata = all_data['metadata']
254
+
255
+ if self.ego_only:
256
+ if metadata['ego_involved'] == False:
257
+ continue
258
+
259
+ if self.strict_collision_filter and metadata["accident_type"] in [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 25, 26, 28, 29, 31, 32, 34, 35, 36]:
260
+ # Reject video where the collision is with "static" agents
261
+ continue
262
+
263
+ # Some post-hoc rejected clips
264
+ if all_data["video_source"] in ["10_90001.mp4"]:
265
+ continue
266
+
267
+ if self.specific_samples is not None and all_data["video_source"].split(".")[0] not in self.specific_samples:
268
+ continue
269
+
270
+ if self.specific_categories is not None and metadata["accident_type"] not in self.specific_categories:
271
+ continue
272
+
273
+ clip_data = all_data["data"]
274
+ if clip_data is None:
275
+ print(f"{vid_name}.json CLIP DATA IS NULL")
276
+ null_labels.append(vid_name)
277
+ continue
278
+
279
+ clip_label_data = self._parse_clip_labels(clip_data, clip_frame_names)
280
+ self.frame_labels.extend(clip_label_data) # In this case, labels are already sorted so they will match up to the image indices
281
+
282
+ image_indices_by_clip[vid_name] = []
283
+ for image_file in clip_frame_names:
284
+ self.image_files.append(os.path.join(clip_file, image_file))
285
+ image_indices_by_clip[vid_name].append(len(self.image_files)-1)
286
+
287
+ if not self.ignore_labels:
288
+ assert len(self.frame_labels) == len(self.image_files), f"{len(self.frame_labels)} frame labels != {len(self.image_files)} image files" # We assume a one-to-one association between images and labels
289
+
290
+ ab_start_idx, acc_idx, ab_end_idx = accident_metadata
291
+ def get_clip_type(image_idx, end_image_idx):
292
+ clip_type = "normal"
293
+ if image_idx <= acc_idx and end_image_idx >= acc_idx:
294
+ # Contains accident frame
295
+ clip_type = "crash"
296
+ elif (image_idx >= ab_start_idx and image_idx <= ab_end_idx) \
297
+ or (end_image_idx >= ab_start_idx and end_image_idx <= ab_end_idx) \
298
+ or image_idx > acc_idx: # Let's also consider "normal" driving clip that happen after the accident to be "abnormal" as they might show the aftermath (e.g. car damage)
299
+ # Does not contain accident frame, but contains abnormal driving (moment before and after accident)
300
+ clip_type = "abnormal"
301
+
302
+ return clip_type
303
+
304
+ # Cut the videos in clips of the correct length
305
+ # NOTE: Only implementing strategy of selecting two clips per video: 1 with normal driving and 1 with crash
306
+ # Select normal driving from beginning preferably and crash clip try to center it on the accident instant
307
+
308
+ # Find crash clip
309
+ crash_found = False
310
+ if self.force_clip_type is None or self.force_clip_type == "crash":
311
+ start_image_idx, end_image_idx = None, None
312
+ total_frames = len(image_indices_by_clip[vid_name])
313
+ if acc_idx is not None and self.clip_length is not None:
314
+ # Keep frame_count frames around accident frame
315
+ start_image_idx = acc_idx - int(self.clip_length/2 + 0.5)
316
+ end_image_idx = acc_idx + int(self.clip_length/2)
317
+
318
+ if total_frames < self.clip_length:
319
+ print(f"Not enough frames in '{vid_name}': {total_frames}, skipping")
320
+ else:
321
+ if start_image_idx < 0:
322
+ end_image_idx += -(start_image_idx)
323
+ start_image_idx = 0
324
+
325
+ if end_image_idx > total_frames:
326
+ start_image_idx -= (end_image_idx - total_frames)
327
+ end_image_idx = total_frames
328
+
329
+ self.clip_list.append(image_indices_by_clip[vid_name][start_image_idx:end_image_idx])
330
+ self.clip_type_list.append("crash")
331
+ action_type = self._get_action_type(metadata["accident_type"])
332
+ self.action_type_list.append(action_type)
333
+ crash_found = True
334
+
335
+ # Debug: #############
336
+ # frame_path_list = [self.image_files[i] for i in image_indices_by_clip[vid_name][start_image_idx:end_image_idx]]
337
+ # create_video_from_images(frame_path_list, f"outputs/sample_clip_{vid_name}_crash.mp4", out_fps=6 if self.downsample_6fps else 12)
338
+
339
+ # # Debug plot bboxes:
340
+ # out_bbox_path = os.path.join("outputs", f"{vid_name}_bboxes_crash")
341
+ # os.makedirs(out_bbox_path, exist_ok=True)
342
+ # for frame_path, label_data in zip(frame_path_list, clip_label_data[start_image_idx:end_image_idx]):
343
+ # plt.figure(figsize=(9, 6))
344
+ # plt.axis("off")
345
+ # img = Image.open(frame_path)
346
+ # plt.imshow(img)
347
+ # for obj in label_data:
348
+ # color = np.array(CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TYPE_LOOKUP[obj["class_id"]])) / 255.0
349
+ # show_box(obj["bbox"], plt.gca(), label=str(obj["track_id"]), color=color)
350
+
351
+ # frame_id_name = frame_path.split("_")[-1].split(".")[0]
352
+ # plt.savefig(os.path.join(out_bbox_path, f"bboxes_frame_{frame_id_name}.jpg"))
353
+ #######################3
354
+
355
+ if not crash_found:
356
+ print("Crash not found for", vid_name)
357
+
358
+ if crash_found: assert end_image_idx > start_image_idx  # indices are only defined when a crash clip was extracted
359
+
360
+ if self.force_clip_type is None or self.force_clip_type == "normal":
361
+ normal_found = False
362
+ for start_image_idx in range(len(image_indices_by_clip[vid_name]) - self.clip_length + 1):
363
+ end_image_idx = start_image_idx+self.clip_length
364
+
365
+ clip_type = get_clip_type(start_image_idx, end_image_idx)
366
+ if clip_type == "abnormal" or clip_type == "crash":
367
+ # Let's just reject the abnormal clips
368
+ continue
369
+
370
+ self.clip_list.append(image_indices_by_clip[vid_name][start_image_idx:end_image_idx])
371
+ self.clip_type_list.append(clip_type)
372
+ self.action_type_list.append(0)
373
+ normal_found = True
374
+
375
+ # Debug: ########
376
+ # frame_path_list = [self.image_files[i] for i in image_indices_by_clip[vid_name][start_image_idx:end_image_idx]]
377
+ # create_video_from_images(frame_path_list, f"outputs/sample_clip_{vid_name}_normal.mp4", out_fps=6 if self.downsample_6fps else 12)
378
+
379
+ # out_bbox_path = os.path.join("outputs", f"{vid_name}_bboxes_normal")
380
+ # os.makedirs(out_bbox_path, exist_ok=True)
381
+ # for frame_path, label_data in zip(frame_path_list, clip_label_data[start_image_idx:end_image_idx]):
382
+ # plt.figure(figsize=(9, 6))
383
+ # plt.axis("off")
384
+ # img = Image.open(frame_path)
385
+ # plt.imshow(img)
386
+ # for obj in label_data:
387
+ # color = np.array(CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TYPE_LOOKUP[obj["class_id"]])) / 255.0
388
+ # show_box(obj["bbox"], plt.gca(), label=str(obj["track_id"]), color=color)
389
+
390
+ # frame_id_name = frame_path.split("_")[-1].split(".")[0]
391
+ # plt.savefig(os.path.join(out_bbox_path, f"bboxes_frame_{frame_id_name}.jpg"))
392
+ #################
393
+
394
+ break
395
+
396
+ # if not normal_found:
397
+ # print("Normal not found for", vid_name)
398
+
399
+ assert len(self.clip_list) == len(self.clip_type_list) == len(self.action_type_list)
400
+
401
+ print("Number of clips MMAU:", len(self.clip_list), f"({self.data_split})", f"(from {count_vid} original videos)")
402
+ crash_clip_count = 0
403
+ normal_clip_count = 0
404
+ for clip_type in self.clip_type_list:
405
+ if clip_type == "crash":
406
+ crash_clip_count += 1
407
+ elif clip_type == "normal":
408
+ normal_clip_count += 1
409
+ print(crash_clip_count, "crash clips", normal_clip_count, "normal clips")
410
+
411
+ if self.cleanup_dataset and len(null_labels) > 0:
412
+ print("Null labels:", null_labels)
413
+ for label_name in null_labels:
414
+ label_file_path = os.path.join(self.label_dir, f"{label_name}.json")
415
+ if os.path.exists(label_file_path):
416
+ os.remove(label_file_path)
417
+ print("Removed label file:", label_file_path)
418
+
419
+ def _parse_clip_labels(self, clip_data, clip_frame_names):
420
+ frame_labels = []
421
+ for frame_data in clip_data:
422
+ obj_data = frame_data['labels']
423
+ image_source = frame_data["image_source"]
424
+
425
+ if self.downsample_6fps and image_source not in clip_frame_names:
426
+ # Only preserve even numbered frames
427
+ continue
428
+
429
+ object_labels = []
430
+ for label in obj_data:
431
+ # Only keep the classes of interest
432
+ class_id = MMAUDataset.CLASS_NAME_TO_ID.get(label['name'])
433
+ if class_id is None:
434
+ continue
435
+
436
+ # Convert bbox coordinates to pixel space wrt to image size
437
+ bbox = label['box']
438
+ bbox_coords_pixel = [int(bbox[0] * self.orig_width), # x1
439
+ int(bbox[1] * self.orig_height), # y1
440
+ int(bbox[2] * self.orig_width), # x2
441
+ int(bbox[3] * self.orig_height)] # y2
442
+
443
+ object_labels.append({
444
+ 'frame_name': image_source,
445
+ 'track_id': int(label['track_id']),
446
+ 'bbox': bbox_coords_pixel,
447
+ 'class_id': class_id,
448
+ 'class_name': label['name'], # Class name of the object
449
+ })
450
+
451
+ frame_labels.append(object_labels)
452
+
453
+ return frame_labels
454
+
455
+ def _get_action_type(self, accident_type):
456
+ # [0: normal, 1: ego, 2: ego/veh, 3: veh, 4: veh/veh]
457
+ accident_type = int(accident_type)
458
+ if accident_type in [61, 62, 13, 14, 15, 16, 17, 18]:
459
+ return 1
460
+ elif accident_type in range(1, 12 + 1):
461
+ return 2
462
+ elif accident_type in [37, 39, 41, 42, 44] + list(range(19, 36 + 1)) + list(range(52, 60 + 1)):
463
+ return 3
464
+ elif accident_type in [38, 40, 43, 45, 46, 47, 48, 49, 50, 51]:
465
+ return 4
466
+ else:
467
+ raise ValueError(f"Unknown accident type: {accident_type}")
468
+
469
+ def pre_cache_dataset(dataset_root):
470
+ # dset = MMAUDataset(dataset_root, train=False, cleanup_dataset=True, specific_categories=["42"])
471
+
472
+ dset = MMAUDataset(dataset_root, train=False, cleanup_dataset=True)
473
+ # s = dset.__getitem__(0)
474
+
475
+ # dset = MMAUDataset(dataset_root, train=False, cleanup_dataset=True)
476
+ # s = dset.__getitem__(0)
477
+ # Trigger label and bbox image cache generation
478
+ # from time import time
479
+ # dataset_train = DADA2000Dataset(root=dataset_root, train=True, clip_length=25, non_overlapping_clips=False)
480
+ # t = time()
481
+ # for i in tqdm(range(len(dataset_train))):
482
+ # d = dataset_train[i]
483
+ # if i >= 100:
484
+ # print("Time:", time() - t)
485
+ # print("break")
486
+
487
+ # dataset_val = DADA2000Dataset(root=dataset_root, train=False, clip_length=25, non_overlapping_clips=True)
488
+ # for i in tqdm(range(len(dataset_val))):
489
+ # d = dataset_val[i]
490
+
491
+ # print("Done.")
492
+
493
+ if __name__ == "__main__":
494
+ dataset_root = "/path/to/Datasets"
495
+ pre_cache_dataset(dataset_root)
496
+
497
+ MMAUDataset(dataset_root, train=True)
498
+
499
+
500
+ """
501
+ ACCIDENT TYPES
502
+ {
503
+ "ego_car_involved": {
504
+ "self_initiated": {
505
+ "out_of_control": [61]
506
+ },
507
+ "dynamic_participants": {
508
+ "person_centric": {
509
+ "pedestrian": [1, 2],
510
+ "cyclist": [3, 4]
511
+ },
512
+ "vehicle_centric": {
513
+ "motorbike": [5, 6],
514
+ "truck": [7, 8, 9],
515
+ "car": [10, 11, 12]
516
+ }
517
+ },
518
+ "static_participants": {
519
+ "road_crentric": {
520
+ "large_roadblocks": [13],
521
+ "curb": [14],
522
+ "small_roadblocks_potholes": [15]
523
+ },
524
+ "other_semantics_centric": {
525
+ "trees": [16],
526
+ "telegraph_poles": [17],
527
+ "other_road_facilities": [18]
528
+ }
529
+ }
530
+ },
531
+ "ego_car_uninvolved": {
532
+ "dynamic_participants": {
533
+ "vehicle_centric": {
534
+ "motorbike_motorbike": [37, 38],
535
+ "truck_truck": [39, 40, 41],
536
+ "car_car": [42, 43, 44],
537
+ "motorbike_truck": [45, 46, 47],
538
+ "truck_car": [48, 49],
539
+ "car_motorbike": [50, 51]
540
+ },
541
+ "person_centric": [52, 53, 54, 55, 56, 57, 58, 59, 60]
542
+ },
543
+ "static_participants" : [19, 20, 21, 22, 25, 26, 28, 29, 31, 32, 34, 35, 36]
544
+ },
545
+ }
546
+ """
547
+
548
+
549
+
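A note on the clip-typing logic above: `_collect_clips` classifies each candidate window as `crash`, `abnormal`, or `normal` relative to the accident metadata. The sketch below restates that rule as a standalone function so it can be checked in isolation; the helper name `clip_type` and the sample indices are illustrative, not part of the file.

```python
# Minimal sketch of the window-classification rule used in MMAUDataset._collect_clips.
# A window is "crash" if it spans the accident frame, "abnormal" if it overlaps the
# abnormal-driving interval or starts after the accident, and "normal" otherwise.

def clip_type(start_idx, end_idx, ab_start, acc_idx, ab_end):
    if start_idx <= acc_idx <= end_idx:
        return "crash"
    if (ab_start <= start_idx <= ab_end) or (ab_start <= end_idx <= ab_end) or start_idx > acc_idx:
        return "abnormal"
    return "normal"

if __name__ == "__main__":
    # Illustrative accident metadata (abnormal start, accident frame, abnormal end),
    # already downsampled to the frame indices used by the dataset.
    ab_start, acc_idx, ab_end = 40, 55, 70
    for start in (0, 20, 45, 60):
        end = start + 25  # clip_length
        print(start, end, clip_type(start, end, ab_start, acc_idx, ab_end))
```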
src/datasets/nuscenes_dataset.py ADDED
@@ -0,0 +1,298 @@
1
+ from nuscenes.nuscenes import NuScenes
2
+ from nuscenes.utils.geometry_utils import view_points
3
+
4
+ import numpy as np
5
+ from pyquaternion import Quaternion
6
+ import os
8
+ from nuscenes.utils.splits import create_splits_scenes
9
+ from shapely.geometry import MultiPoint, box
10
+ from typing import List, Tuple, Union
11
+ import json
12
+ from tqdm import tqdm
13
+
14
+ from src.datasets.base_dataset import BaseDataset
15
+
16
+
17
+ # "Singleton" that holds the data so we only have to load once for training & validation
18
+ nusc_data = None
19
+
20
+ class NuScenesDataset(BaseDataset):
21
+
22
+ CLASS_NAME_TO_ID = {
23
+ "animal": 1,
24
+ "human.pedestrian.adult": 1,
25
+ "human.pedestrian.child": 1,
26
+ "human.pedestrian.construction_worker": 1,
27
+ "human.pedestrian.personal_mobility": 1,
28
+ "human.pedestrian.police_officer": 1,
29
+ "human.pedestrian.stroller": 1,
30
+ "human.pedestrian.wheelchair": 1,
31
+
32
+ # "movable_object.barrier": 10,
33
+ # "movable_object.debris": 10,
34
+ # "movable_object.pushable_pullable": 10,
35
+ # "movable_object.trafficcone": 10,
36
+ # "static_object.bicycle_rack": 10,
37
+
38
+ "vehicle.bicycle": 8,
39
+
40
+ "vehicle.bus.bendy": 5,
41
+ "vehicle.bus.rigid": 5,
42
+
43
+ "vehicle.car": 3,
44
+ "vehicle.emergency.police": 3,
45
+
46
+ "vehicle.construction": 4,
47
+ "vehicle.emergency.ambulance": 4,
48
+ "vehicle.trailer": 4,
49
+ "vehicle.truck": 4,
50
+
51
+ "vehicle.motorcycle": 7,
52
+
53
+ "None": 10,
54
+ }
55
+
56
+ def __init__(self,
57
+ root='./datasets',
58
+ train=True,
59
+ clip_length=25,
60
+ orig_height=900, orig_width=1600,
61
+ resize_height=320, resize_width=512,
62
+ non_overlapping_clips=False,
63
+ bbox_masking_prob=0.0,
64
+ test_split=False,
65
+ ego_only=False):
66
+
67
+ super(NuScenesDataset, self).__init__(root=root,
68
+ train=train,
69
+ clip_length=clip_length,
70
+ resize_height=resize_height,
71
+ resize_width=resize_width,
72
+ non_overlapping_clips=non_overlapping_clips,
73
+ bbox_masking_prob=bbox_masking_prob)
74
+
75
+ self.dataset_name = 'nuscenes'
76
+ self.train = train
77
+ self.orig_width = orig_width
78
+ self.orig_height = orig_height
79
+ self.non_overlapping_clips = non_overlapping_clips
80
+ self.bbox_image_dir = os.path.join(self.root, self.dataset_name, "bbox_images", self.data_split)
81
+ self.label_dir = os.path.join(self.root, self.dataset_name, "labels", self.data_split)
82
+ os.makedirs(self.label_dir, exist_ok=True)
83
+
84
+ self.inst_token_to_track_id = {}
85
+
86
+ split_scenes = self._load_nusc(test_split)
87
+ self._collect_clips(split_scenes)
88
+
89
+
90
+ def _load_nusc(self, test_split):
91
+ global nusc_data
92
+ if nusc_data is None:
93
+ data_split = 'v1.0-trainval' if not test_split else 'v1.0-test'
94
+ # data_split = 'v1.0-mini' # Or: 'v1.0-mini' for testing
95
+ nusc_data = NuScenes(version=data_split,
96
+ dataroot=os.path.join(self.root, self.dataset_name),
97
+ verbose=True)
98
+ self.nusc = nusc_data
99
+
100
+ dataset_split = 'train' if self.train else 'val'
101
+ if test_split:
102
+ dataset_split = 'test'
103
+
104
+ split_scene_names = create_splits_scenes()[dataset_split] # [train: 700, val: 150, test: 150]
105
+ split_scenes = [scene for scene in nusc_data.scene if scene['name'] in split_scene_names]
106
+
107
+ return split_scenes
108
+
109
+
110
+ def _collect_clips(self, split_scenes):
111
+ image_indices_by_scene = {}
112
+
113
+ def collect_frame(scene_idx, sample_data):
114
+ # Get image
115
+ image_path = os.path.join(self.root, self.dataset_name, sample_data['filename'])
116
+ self.image_files.append(image_path)
117
+ if image_indices_by_scene.get(scene_idx) is None:
118
+ image_indices_by_scene[scene_idx] = []
119
+ image_indices_by_scene[scene_idx].append(len(self.image_files) - 1)
120
+
121
+ # Parse label
122
+ labels = self._parse_label(sample_data["token"])
123
+ self.frame_labels.append(labels)
124
+
125
+ # Interpolating annotations to increase the frame rate (nuscenes annotation fps=2Hz, video data fps=12Hz)
126
+ self.fps = 7
127
+ target_period = 1/self.fps # For fps downsampling
128
+ max_frames_per_scene = 75
129
+ print("Collecting nuscenes clips...")
130
+ for scene_i, scene in enumerate(split_scenes):
131
+
132
+ curr_data_token = self.nusc.get('sample', scene['first_sample_token'])['data']["CAM_FRONT"]
133
+ curr_sample_data = self.nusc.get('sample_data', curr_data_token)
134
+ collect_frame(scene_i, curr_sample_data)
135
+
136
+ cumul_delta = 0
137
+ total_delta = 0
138
+ t = 0
139
+ while curr_data_token:
140
+ curr_sample_data = self.nusc.get('sample_data', curr_data_token)
141
+
142
+ next_sample_data_token = curr_sample_data['next']
143
+ if not next_sample_data_token:
144
+ break
145
+ next_sample_data = self.nusc.get('sample_data', next_sample_data_token)
146
+
147
+ # FPS downsampling: only select certain frames based on elapsed times
148
+ delta = (next_sample_data['timestamp'] - curr_sample_data['timestamp']) / 1e6
149
+ cumul_delta += delta
150
+ total_delta += delta
151
+ if cumul_delta >= target_period:
152
+ collect_frame(scene_i, next_sample_data)
153
+ t += 1
154
+ cumul_delta = cumul_delta - target_period
155
+
156
+ curr_data_token = next_sample_data_token
157
+
158
+ if len(image_indices_by_scene[scene_i]) > max_frames_per_scene:
159
+ break
160
+
161
+ # print(f"Fps: {len(image_indices_by_scene[scene_i]) / total_delta:.4f}")
162
+
163
+ if not self.non_overlapping_clips:
164
+ for image_idx in range(len(image_indices_by_scene[scene_i]) - self.clip_length + 1):
165
+ self.clip_list.append(image_indices_by_scene[scene_i][image_idx:image_idx+self.clip_length])
166
+ else:
167
+ # In case self.clip_length << actual video sample length (~20s), we can create multiple non-overlapping clips for each sample
168
+ total_frames = len(image_indices_by_scene[scene_i])
169
+ for clip_i in range(total_frames // self.clip_length):
170
+ start_image_idx = clip_i * self.clip_length
171
+ self.clip_list.append(image_indices_by_scene[scene_i][start_image_idx:start_image_idx+self.clip_length])
172
+
173
+ print("Number of nuScenes clips:", len(self.clip_list), f"({'train' if self.train else 'val'})")
174
+
175
+
176
+ def _parse_label(self, token):
177
+
178
+ cam_front_data = self.nusc.get('sample_data', token)
179
+
180
+ # Check cache, if it doesn't exist, then create label file
181
+ filename = cam_front_data["filename"].split('/')[-1].split('.')[0]
182
+ label_file_path = os.path.join(self.label_dir, f"{filename}.json")
183
+ if os.path.exists(label_file_path):
184
+ with open(label_file_path, 'r') as json_file:
185
+ object_labels = json.load(json_file)
186
+
187
+ return object_labels
188
+ else:
189
+ front_camera_sensor = self.nusc.get('calibrated_sensor', cam_front_data['calibrated_sensor_token'])
190
+ camera_intrinsic = np.array(front_camera_sensor['camera_intrinsic'])
191
+ ego_pose = self.nusc.get('ego_pose', cam_front_data['ego_pose_token'])
192
+
193
+ object_labels = []
194
+ bbox_center_by_track_id = {}
195
+ for bbox_3d in self.nusc.get_boxes(token):
196
+
197
+ class_name = bbox_3d.name
198
+ if class_name not in NuScenesDataset.CLASS_NAME_TO_ID:
199
+ continue
200
+ class_id = NuScenesDataset.CLASS_NAME_TO_ID[class_name]
201
+
202
+ instance_token = self.nusc.get('sample_annotation', bbox_3d.token)['instance_token']
203
+ if instance_token not in self.inst_token_to_track_id:
204
+ self.inst_token_to_track_id[instance_token] = len(self.inst_token_to_track_id)
205
+
206
+ # Project 3D bboxes to 2D
207
+ # (Code adapted from: https://github.com/nutonomy/nuscenes-devkit/blob/master/python-sdk/nuscenes/scripts/export_2d_annotations_as_json.py)
208
+
209
+ # Move them to the ego-pose frame.
210
+ bbox_3d.translate(-np.array(ego_pose['translation']))
211
+ bbox_3d.rotate(Quaternion(ego_pose['rotation']).inverse)
212
+
213
+ # Move them to the calibrated sensor frame.
214
+ bbox_3d.translate(-np.array(front_camera_sensor['translation']))
215
+ bbox_3d.rotate(Quaternion(front_camera_sensor['rotation']).inverse)
216
+
217
+ # Filter out the corners that are not in front of the calibrated sensor.
218
+ corners_3d = bbox_3d.corners()
219
+ in_front = np.argwhere(corners_3d[2, :] > 0).flatten()
220
+ corners_3d = corners_3d[:, in_front]
221
+
222
+ # Project 3d box to 2d.
223
+ corner_coords = view_points(corners_3d, camera_intrinsic, True).T[:, :2].tolist()
224
+
225
+ # Keep only corners that fall within the image.
226
+ final_coords = self._post_process_coords(corner_coords)
227
+
228
+ # Skip if the convex hull of the re-projected corners does not intersect the image canvas.
229
+ if final_coords is None:
230
+ continue
231
+
232
+ min_x, min_y, max_x, max_y = final_coords
233
+ track_id = self.inst_token_to_track_id[instance_token]
234
+
235
+ bbox_center_by_track_id[track_id] = bbox_3d.center
236
+
237
+ obj_label = {
238
+ 'frame_name': cam_front_data["filename"],
239
+ 'track_id': track_id,
240
+ 'bbox': [min_x, min_y, max_x, max_y],
241
+ 'class_id': class_id,
242
+ 'class_name': class_name,
243
+ }
244
+
245
+ object_labels.append(obj_label)
246
+
247
+ # Render the furthest bboxes first (closer ones should be on top)
248
+ object_labels.sort(key=lambda label: np.linalg.norm(bbox_center_by_track_id[label["track_id"]]), reverse=True)
249
+
250
+ # Cache file
251
+ with open(label_file_path, 'w') as json_file:
252
+ json.dump(object_labels, json_file)
253
+ print("Cached labels:", label_file_path)
254
+
255
+ return object_labels
256
+
257
+
258
+ def _post_process_coords(self, corner_coords: List) -> Union[Tuple[float, float, float, float], None]:
259
+ """
260
+ Get the intersection of the convex hull of the reprojected bbox corners and the image canvas, return None if no intersection.
261
+ :param corner_coords: Corner coordinates of reprojected bounding box.
262
+ :param imsize: Size of the image canvas.
263
+ :return: Intersection of the convex hull of the 2D box corners and the image canvas.
264
+ """
265
+ imsize = (self.orig_width, self.orig_height)
266
+
267
+ polygon_from_2d_box = MultiPoint(corner_coords).convex_hull
268
+ img_canvas = box(0, 0, imsize[0], imsize[1])
269
+
270
+ if polygon_from_2d_box.intersects(img_canvas):
271
+ img_intersection = polygon_from_2d_box.intersection(img_canvas)
272
+ intersection_coords = np.array([coord for coord in img_intersection.exterior.coords])
273
+
274
+ min_x = min(intersection_coords[:, 0])
275
+ min_y = min(intersection_coords[:, 1])
276
+ max_x = max(intersection_coords[:, 0])
277
+ max_y = max(intersection_coords[:, 1])
278
+
279
+ return min_x, min_y, max_x, max_y
280
+ else:
281
+ return None
282
+
283
+
284
+ def pre_cache_dataset(dataset_root):
285
+ # Trigger label and bbox image cache generation
286
+ dataset_val = NuScenesDataset(root=dataset_root, train=False, clip_length=25, non_overlapping_clips=True)
287
+ for i in tqdm(range(len(dataset_val))):
288
+ d = dataset_val[i]
289
+
290
+ dataset_train = NuScenesDataset(root=dataset_root, train=True, clip_length=25, non_overlapping_clips=True)
291
+ for i in tqdm(range(len(dataset_train))):
292
+ d = dataset_train[i]
293
+
294
+ print("Done.")
295
+
296
+ if __name__ == "__main__":
297
+ dataset_root = "/path/to/Datasets"
298
+ pre_cache_dataset(dataset_root)
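The 2D-box construction in `_parse_label` relies on `_post_process_coords` to clip the re-projected 3D-box corners against the image canvas. Below is a minimal, self-contained sketch of that clipping step using the same shapely primitives; the function name and the example corner coordinates are illustrative only.

```python
# Minimal sketch of the corner-clipping step behind NuScenesDataset._post_process_coords:
# take the convex hull of the projected corners, intersect it with the image canvas,
# and keep the axis-aligned bounds of the intersection.
from shapely.geometry import MultiPoint, box

def clip_corners_to_canvas(corner_coords, img_w=1600, img_h=900):
    hull = MultiPoint(corner_coords).convex_hull
    canvas = box(0, 0, img_w, img_h)
    if not hull.intersects(canvas):
        return None  # the re-projected box lies entirely outside the image
    # .bounds returns (min_x, min_y, max_x, max_y) of the clipped polygon
    return hull.intersection(canvas).bounds

if __name__ == "__main__":
    # Corners partially outside the 1600x900 canvas get clipped to the image border
    corners = [(-50.0, 100.0), (200.0, 80.0), (220.0, 400.0), (-30.0, 420.0)]
    print(clip_corners_to_canvas(corners))  # approximately (0.0, 80.0, 220.0, 420.0)
```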
src/datasets/russia_crash_dataset.py ADDED
@@ -0,0 +1,173 @@
1
+ import os
2
+ import json
3
+ from tqdm import tqdm
+
4
+ from src.datasets.base_dataset import BaseDataset
5
+
6
+
7
+ class RussiaCrashDataset(BaseDataset):
8
+
9
+ CLASS_NAME_TO_ID = {
10
+ 'person': 1,
11
+ 'car': 3,
12
+ 'truck': 4,
13
+ 'bus': 5,
14
+ 'train': 6,
15
+ 'motorcycle': 7,
16
+ 'bicycle': 8,
17
+ }
18
+
19
+ def __init__(self,
20
+ root='./datasets',
21
+ train=True,
22
+ clip_length=25,
23
+ orig_height=555, orig_width=986,
24
+ resize_height=320, resize_width=512,
25
+ non_overlapping_clips=False,
26
+ bbox_masking_prob=0.0,
27
+ sample_clip_from_end=True,
28
+ ego_only=False,
29
+ specific_samples=None):
30
+
31
+ super(RussiaCrashDataset, self).__init__(root=root,
32
+ train=train,
33
+ clip_length=clip_length,
34
+ resize_height=resize_height,
35
+ resize_width=resize_width,
36
+ non_overlapping_clips=non_overlapping_clips,
37
+ bbox_masking_prob=bbox_masking_prob,
38
+ sample_clip_from_end=sample_clip_from_end,
39
+ ego_only=ego_only)
40
+
41
+ self.dataset_name = "preprocess_russia_crash"
42
+
43
+ self.orig_width = orig_width
44
+ self.orig_height = orig_height
45
+ self.image_dir = os.path.join(self.root, self.dataset_name, "images", self.data_split)
46
+ self.label_dir = os.path.join(self.root, self.dataset_name, "labels", self.data_split)
47
+ self.bbox_image_dir = os.path.join(self.root, self.dataset_name, "bbox_images", self.data_split)
48
+
49
+ self.specific_samples = specific_samples
50
+
51
+ self._collect_clips()
52
+
53
+ def _collect_clips(self):
54
+ image_indices_by_clip = {}
55
+ for label_file in sorted(os.listdir(self.label_dir)):
56
+ if not label_file.endswith('.json'):
57
+ continue
58
+
59
+ full_filename = os.path.join(self.label_dir, label_file)
60
+ with open(full_filename) as json_file:
61
+ all_data = json.load(json_file)
62
+ metadata = all_data['metadata']
63
+
64
+ # Only include dashcam samples
65
+ if metadata['camera'] != "Dashcam":
66
+ continue
67
+ # Exclude animal and "other" accidents
68
+ if metadata['accident_type'] == "Risk of collision/collision with an animal":
69
+ continue
70
+ if metadata['accident_type'] == 'Other types of traffic accidents':
71
+ continue
72
+ # NOTE uncomment to only include actual car collision (no close misses and dangerous events)
73
+ # if metadata['collision_type'] == "No Collision":
74
+ # continue
75
+ if self.ego_only:
76
+ print("Ego collisions only activated!")
77
+ if metadata['collision_type'] == "No Collision" or metadata["ego_car_involved"] != "Yes":
78
+ continue
79
+
80
+ clip_filename = label_file.split('.')[0]
81
+ clip_file = os.path.join(self.image_dir, clip_filename)
82
+
83
+ if self.specific_samples is not None and clip_filename not in self.specific_samples:
84
+ continue
85
+
86
+ if len(os.listdir(clip_file)) < self.clip_length:
87
+ # print(f"{clip_filename} does not have enough frames: has {len(os.listdir(clip_file))} expected at least {self.clip_length}")
88
+ continue
89
+
90
+ clip_label_data = self._parse_clip_labels(all_data["data"])
91
+ self.frame_labels.extend(clip_label_data) # In this case labels are already sorted so they will match up to the image indices
92
+
93
+ image_indices_by_clip[clip_filename] = []
94
+ for image_file in sorted(os.listdir(clip_file)):
95
+ self.image_files.append(os.path.join(clip_file, image_file))
96
+ image_indices_by_clip[clip_filename].append(len(self.image_files)-1)
97
+
98
+ assert len(self.frame_labels) == len(self.image_files) # We assume a one-to-one association between images and labels
99
+
100
+ # Cut the videos in clips of the correct length according to the strategies chosen
101
+ if not self.non_overlapping_clips:
102
+ for image_idx in range(len(image_indices_by_clip[clip_filename]) - self.clip_length + 1):
103
+ self.clip_list.append(image_indices_by_clip[clip_filename][image_idx:image_idx+self.clip_length])
104
+ else:
105
+ if self.sample_clip_from_end:
106
+ # In case self.clip_length << actual video sample length, we can create multiple non-overlapping clips for each sample
107
+ # Prioritize selecting clips from the end, to make sure the accident is included (it tends to be at the end of the videos)
108
+ total_frames = len(image_indices_by_clip[clip_filename])
109
+ for clip_i in range(total_frames // self.clip_length):
110
+ start_image_idx = total_frames - (self.clip_length * (clip_i + 1))
111
+ end_image_idx = total_frames - (self.clip_length * clip_i)
112
+ self.clip_list.append(image_indices_by_clip[clip_filename][start_image_idx:end_image_idx])
113
+ else:
114
+ total_frames = len(image_indices_by_clip[clip_filename])
115
+ for clip_i in range(total_frames // self.clip_length):
116
+ start_image_idx = clip_i * self.clip_length
117
+ end_image_idx = start_image_idx + self.clip_length
118
+ self.clip_list.append(image_indices_by_clip[clip_filename][start_image_idx:end_image_idx])
119
+
120
+ print("Number of clips Russia_crash:", len(self.clip_list), f"({self.data_split})")
121
+
122
+ def _parse_clip_labels(self, clip_data):
123
+ frame_labels = []
124
+ for frame_data in clip_data:
125
+ obj_data = frame_data['labels']
126
+
127
+ object_labels = []
128
+ for label in obj_data:
129
+ # Only keep the classes of interest
130
+ class_id = RussiaCrashDataset.CLASS_NAME_TO_ID.get(label['name'])
131
+ if class_id is None:
132
+ continue
133
+
134
+ # Convert bbox coordinates to pixel space wrt to image size
135
+ bbox = label['box']
136
+ bbox_coords_pixel = [int(bbox[0] * self.orig_width), # x1
137
+ int(bbox[1] * self.orig_height), # y1
138
+ int(bbox[2] * self.orig_width), # x2
139
+ int(bbox[3] * self.orig_height)] # y2
140
+
141
+ object_labels.append({
142
+ 'frame_name': frame_data["image_source"],
143
+ 'track_id': int(label['track_id']),
144
+ 'bbox': bbox_coords_pixel,
145
+ 'class_id': class_id,
146
+ 'class_name': label['name'], # Class name of the object
147
+ })
148
+
149
+ frame_labels.append(object_labels)
150
+
151
+ return frame_labels
152
+
153
+
154
+ def pre_cache_dataset(dataset_root):
155
+ # Trigger label and bbox image cache generation
156
+ dataset_val = RussiaCrashDataset(root=dataset_root, train=False, clip_length=25, non_overlapping_clips=True)
157
+ for i in tqdm(range(len(dataset_val))):
158
+ d = dataset_val[i]
159
+
160
+ dataset_train = RussiaCrashDataset(root=dataset_root, train=True, clip_length=25, non_overlapping_clips=True)
161
+ for i in tqdm(range(len(dataset_train))):
162
+ d = dataset_train[i]
163
+
164
+ print("Done.")
165
+
166
+ if __name__ == "__main__":
167
+ from tqdm import tqdm
168
+
169
+ dataset_root = "/path/to/Datasets"
170
+ pre_cache_dataset(dataset_root)
171
+
172
+
173
+
src/eval/README.md ADDED
@@ -0,0 +1,120 @@
1
+ # Video Quality Evaluation Tools
2
+
3
+ This directory contains scripts for evaluating video quality metrics between generated and ground truth videos. There are four main evaluation scripts:
4
+
5
+ 1. `video_quality_metrics_fvd_pair.py`: Evaluates FVD (Fréchet Video Distance) between paired generated and ground truth videos
6
+ 2. `video_quality_metrics_fvd_gt_rand.py`: Evaluates FVD using pre-computed ground truth statistics
7
+ 3. `video_quality_metrics_jedi_pair.py`: Evaluates JEDi metric between paired generated and ground truth videos
8
+ 4. `video_quality_metrics_jedi_gt_rand.py`: Evaluates JEDi metric using random ground truth samples
9
+
10
+ ## Video Generation
11
+
12
+ Before running the evaluation scripts, you'll need to generate video samples using the `run_gen_videos.py` script:
13
+
14
+ ```bash
15
+ python run_gen_videos.py \
16
+ --model_path /path/to/model/checkpoint \
17
+ --output_path /path/to/output/videos \
18
+ --data_root /path/to/dataset_root \
19
+ --num_demo_samples 10 \
20
+ --max_output_vids 200 \
21
+ --num_gens_per_sample 1 \
22
+ --eval_output
23
+ ```
24
+
25
+ ### Key Generation Arguments
26
+
27
+ ```bash
28
+ --model_path PATH # Path to model checkpoint (required)
29
+ --data_root PATH # Dataset root path
30
+ --output_path PATH # Where to save generated videos
31
+ --num_demo_samples N # Number of samples to collect for generation
32
+ --max_output_vids N # Maximum number of videos to generate
33
+ --num_gens_per_sample N # Videos to generate per test case
34
+
35
+ # Optional arguments for controlling generation
36
+ --bbox_mask_idx_batch N1 N2 ... # Where to start masking (0-25)
37
+ --force_action_type_batch N1 N2 ... # Force specific action types (0-4)
38
+ --guidance_scales N1 N2 ... # Guidance scales to use
39
+ --seed N # Random seed for reproducibility
40
+ --disable_null_model # Disable null model for unconditional noise
41
+ --use_factor_guidance # Use factor guidance during generation
42
+ --eval_output # Enable evaluation output
43
+ ```
44
+
45
+ ### Action Types
46
+ - 0: Normal driving
47
+ - 1-4: Different types of crash scenarios
48
+
49
+ ## Common Arguments for Evaluation
50
+
51
+ All evaluation scripts share some common command line arguments:
52
+
53
+ ```bash
54
+ --vid_root PATH # Root directory containing generated videos (required)
55
+ --samples N # Number of samples to evaluate (default: 200)
56
+ --num_frames N # Number of frames per video (default: 25)
57
+ --downsample_int N # Downsample interval for frames (default: 1)
58
+ --action_type N # Action type to filter videos (0: normal, 1-4: crash types)
59
+ --shuffle # Shuffle videos before evaluation
60
+ ```
61
+
62
+ ## FVD Evaluation
63
+
64
+ ### Paired Evaluation
65
+ ```bash
66
+ python video_quality_metrics_fvd_pair.py \
67
+ --vid_root /path/to/videos \
68
+ --samples 200 \
69
+ --num_frames 25 \
70
+ --downsample_int 1
71
+ ```
72
+
73
+ ### Ground Truth Statistics Evaluation
74
+ ```bash
75
+ # First, collect ground truth statistics
76
+ python video_quality_metrics_fvd_gt_rand.py \
77
+ --vid_root /path/to/videos \
78
+ --collect_stats \
79
+ --samples 500 \
80
+ --action_type 1
81
+
82
+ # Then evaluate using the collected statistics
83
+ python video_quality_metrics_fvd_gt_rand.py \
84
+ --vid_root /path/to/videos \
85
+ --gt_stats /path/to/stats.npz \
86
+ --samples 200 \
87
+ --shuffle
88
+ ```
89
+
90
+ ## JEDi Evaluation
91
+
92
+ ### Paired Evaluation
93
+ ```bash
94
+ python video_quality_metrics_jedi_pair.py \
95
+ --vid_root /path/to/videos \
96
+ --samples 200 \
97
+ --num_frames 25 \
98
+ --test_feature_path /path/to/features
99
+ ```
100
+
101
+ ### Ground Truth Random Evaluation
102
+ ```bash
103
+ python video_quality_metrics_jedi_gt_rand.py \
104
+ --vid_root /path/to/videos \
105
+ --samples 200 \
106
+ --gt_samples 500 \
107
+ --test_feature_path /path/to/features \
108
+ --action_type 1 \
109
+ --shuffle
110
+ ```
111
+
112
+ ## Additional Notes
113
+
114
+ - The `--action_type` argument can be used to filter videos by category:
115
+ - 0: Normal driving videos
116
+ - 1-4: Different types of crash videos
117
+ - For FVD evaluation with ground truth statistics, you can collect statistics once and reuse them for multiple evaluations
118
+ - The JEDi metric requires a test feature path for model loading
119
+ - All scripts support shuffling of videos before evaluation for more robust results
120
+ - The default resolution for videos is 320x512 pixels
src/eval/__pycache__/generate_samples.cpython-310.pyc ADDED
Binary file (9.13 kB).
 
src/eval/generate_samples.py ADDED
@@ -0,0 +1,394 @@
1
+ import os
2
+ from PIL import Image, ImageDraw
3
+ import cv2
4
+ from tqdm import tqdm
5
+ import json
6
+ import argparse
7
+
8
+ import warnings
9
+ import numpy as np
10
+ import torch
11
+ torch.cuda.empty_cache()
12
+ import torch.utils.checkpoint
13
+ from accelerate.utils import set_seed
14
+
15
+ with warnings.catch_warnings():
16
+ warnings.simplefilter("ignore")
17
+ from src.pipelines import StableVideoControlPipeline
18
+ from src.pipelines import StableVideoControlNullModelPipeline
19
+ from src.pipelines import StableVideoControlFactorGuidancePipeline
20
+
21
+ from src.models import UNetSpatioTemporalConditionModel, ControlNetModel
22
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
23
+ from diffusers.models import AutoencoderKLTemporalDecoder
24
+
25
+ from src.datasets.dataset_utils import get_dataloader
26
+ from src.utils import get_samples
27
+
28
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
29
+ print(f"Device: {device}")
30
+
31
+ generator = None #torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
32
+ CLIP_LENGTH = 25
33
+
34
+
35
+ def create_video_from_np(sample, video_path, fps=6):
36
+ video_filename = f"{video_path}.mp4"
37
+ frame_size = (512, 320)
38
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
39
+ video_writer_out = cv2.VideoWriter(video_filename, fourcc, fps, frame_size)
40
+
41
+ for img in sample:
42
+ img = np.transpose(img, (1, 2, 0))
43
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
44
+ video_writer_out.write(img)
45
+
46
+ video_writer_out.release()
47
+ print(f"Video saved: {video_filename}")
48
+
49
+
50
+ def export_to_video(video_frames, output_video_path=None, fps=6):
51
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
52
+ h, w, c = video_frames[0].shape
53
+ video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
54
+ for i in range(len(video_frames)):
55
+ img = cv2.cvtColor(video_frames[i].astype(np.uint8), cv2.COLOR_RGB2BGR)
56
+ video_writer.write(img)
57
+ return output_video_path
58
+
59
+
60
+ def label_frames_with_action_id(bbox_frames, action_id, masked_idx=None):
61
+ action_name = {0: "Normal", 1: "Ego", 2: "Ego/Veh", 3: "Veh", 4: "Veh/Veh"}
62
+ action_text = f"Action: {action_name[action_id]} ({action_id})"
63
+ for i in range(bbox_frames.shape[0]):
64
+ # Convert numpy array to PIL Image
65
+ frame = Image.fromarray(bbox_frames[i].transpose(1, 2, 0))
66
+ draw = ImageDraw.Draw(frame)
67
+
68
+ # Add text in top right corner
69
+ text_position = (frame.width - 10, 10) # 10 pixels from top, 10 pixels from right
70
+ if masked_idx is not None and masked_idx <= i:
71
+ text_color = (0, 0, 0)
72
+ action_text = f"Action: {action_name[action_id]} ({action_id}) [masked]"
73
+ else:
74
+ text_color = (255, 255, 255)
75
+ action_text = action_text
76
+
77
+ draw.text(text_position, action_text, fill=text_color, anchor="ra")
78
+
79
+ # Convert back to numpy array
80
+ bbox_frames[i] = np.array(frame).transpose(2, 0, 1)
81
+
82
+ return bbox_frames
83
+
84
+ def load_ctrlv_pipelines(model_dir, use_null_model=False, use_factor_guidance=False):
85
+ unet_variant = "fp16" if "stabilityai" in model_dir else None
86
+
87
+ unet = UNetSpatioTemporalConditionModel.from_pretrained(
88
+ model_dir,
89
+ subfolder="unet",
90
+ variant=unet_variant,
91
+ low_cpu_mem_usage=True,
92
+ num_frames=CLIP_LENGTH
93
+ )
94
+ ctrlnet = ControlNetModel.from_pretrained(
95
+ model_dir,
96
+ subfolder="control_net",
97
+ variant=unet_variant,
98
+ num_frames=25
99
+ )
100
+
101
+ if not use_null_model and not use_factor_guidance:
102
+ pipeline = StableVideoControlPipeline.from_pretrained(
103
+ "stabilityai/stable-video-diffusion-img2vid-xt",
104
+ controlnet=ctrlnet,
105
+ unet=unet,
106
+ variant=unet_variant
107
+ )
108
+
109
+ else:
110
+
111
+ # For null model prediction of uncond noise
112
+ null_model_path = "stabilityai/stable-video-diffusion-img2vid-xt"
113
+ null_model_unet = UNetSpatioTemporalConditionModel.from_pretrained(
114
+ null_model_path,
115
+ subfolder="unet",
116
+ variant=None,
117
+ low_cpu_mem_usage=True,
118
+ num_frames=CLIP_LENGTH
119
+ )
120
+
121
+ if use_null_model and not use_factor_guidance:
122
+ pipeline = StableVideoControlNullModelPipeline.from_pretrained(
123
+ "stabilityai/stable-video-diffusion-img2vid-xt",
124
+ controlnet=ctrlnet,
125
+ unet=unet,
126
+ null_model=null_model_unet,
127
+ variant=unet_variant
128
+ )
129
+ elif use_factor_guidance:
130
+ pipeline = StableVideoControlFactorGuidancePipeline.from_pretrained(
131
+ "stabilityai/stable-video-diffusion-img2vid-xt",
132
+ controlnet=ctrlnet,
133
+ unet=unet,
134
+ null_model=null_model_unet,
135
+ variant=unet_variant
136
+ )
137
+
138
+ pipeline = pipeline.to(device)
139
+ pipeline.set_progress_bar_config(disable=True)
140
+
141
+ unet.eval()
142
+ ctrlnet.eval()
143
+
144
+ return pipeline
145
+
146
+
147
+ def generate_video_ctrlv(sample, pipeline, video_path="video_out/genvid", json_path="video_out/gt_frames", bbox_mask_frames=None, action_type=None, use_factor_guidance=False, guidance=[1.0, 3.0], video_path2=None):
148
+ frame_size = (512, 320)
149
+ FPS = 6
150
+ CLIP_LENGTH = sample['bbox_images'].shape[0]
151
+
152
+ init_image = sample['image_init']
153
+ bbox_images = sample['bbox_images'].unsqueeze(0)
154
+ action_type = sample['action_type'].unsqueeze(0) if action_type is None else action_type
155
+
156
+ sample['bbox_images'].to(device)
157
+
158
+ # Save GT frame paths to json file
159
+ gt_frame_paths = [file_path[0] for file_path in sample['image_paths']]
160
+ with open(json_path, "w") as file:
161
+ json.dump(gt_frame_paths, file, indent=1)
162
+ print("Saved GT frames json file:", json_path)
163
+
164
+ if not use_factor_guidance:
165
+ frames = pipeline(init_image,
166
+ cond_images=bbox_images,
167
+ bbox_mask_frames=bbox_mask_frames,
168
+ action_type=action_type,
169
+ height=frame_size[1], width=frame_size[0],
170
+ decode_chunk_size=8, motion_bucket_id=127, fps=FPS,
171
+ num_inference_steps=30,
172
+ num_frames=CLIP_LENGTH,
173
+ control_condition_scale=1.0,
174
+ min_guidance_scale=guidance[0],
175
+ max_guidance_scale=guidance[1],
176
+ noise_aug_strength=0.01,
177
+ generator=generator, output_type='pt').frames[0]
178
+ else:
179
+ frames = pipeline(init_image,
180
+ cond_images=bbox_images,
181
+ bbox_mask_frames=bbox_mask_frames,
182
+ action_type=action_type,
183
+ height=frame_size[1], width=frame_size[0],
184
+ decode_chunk_size=8, motion_bucket_id=127, fps=FPS,
185
+ num_inference_steps=30,
186
+ num_frames=CLIP_LENGTH,
187
+ control_condition_scale=1.0,
188
+ min_guidance_scale_img=1.0,
189
+ max_guidance_scale_img=3.0,
190
+ min_guidance_scale_action=6.0,
191
+ max_guidance_scale_action=12.0,
192
+ min_guidance_scale_bbox=1.0,
193
+ max_guidance_scale_bbox=3.0,
194
+ noise_aug_strength=0.01,
195
+ generator=generator, output_type='pt').frames[0]
196
+
197
+ frames = frames.detach().cpu().numpy()*255
198
+ frames = frames.astype(np.uint8)
199
+
200
+ tmp = np.transpose(frames, (0, 2, 3, 1))  # (T, C, H, W) -> (T, H, W, C); the extra moveaxis was a no-op
201
+ output_video_path = f"{video_path}.mp4"
202
+ export_to_video(tmp, output_video_path, fps=FPS)
203
+ print(f"Video saved:", output_video_path)
204
+
205
+ if video_path2 is not None:
206
+ output_video_path2 = f"{video_path2}.mp4"
207
+ export_to_video(tmp, output_video_path2, fps=FPS)
208
+
209
+
210
+ def generate_samples(args):
211
+ model_path = args.model_path
212
+ print("Model path:", model_path)
213
+
214
+ if args.seed is not None:
215
+ set_seed(args.seed)
216
+ print("Set seed:", args.seed)
217
+
218
+ # LOAD PIPELINE
219
+ use_null_model = not args.disable_null_model
220
+ use_factor_guidance = args.use_factor_guidance
221
+ pipeline = load_ctrlv_pipelines(model_path, use_null_model=use_null_model, use_factor_guidance=use_factor_guidance)
222
+
223
+ # LOAD DATASET
224
+ data_root = args.data_root
225
+ dataset_name = args.dataset
226
+ train_set = False
227
+ val_dataset, val_loader = get_dataloader(
228
+ data_root, dataset_name, if_train=train_set, clip_length=CLIP_LENGTH,
229
+ batch_size=1, num_workers=0, shuffle=True,
230
+ image_height=320, image_width=512,
231
+ non_overlapping_clips=True, #specific_samples=specific_samples
232
+ )
233
+ if train_set:
234
+ print("WARNING: Currently using training split")
235
+
236
+ # COLLECT SAMPLES
237
+ num_demo_samples = args.num_demo_samples
238
+ demo_samples = get_samples(val_loader, num_demo_samples, show_progress=True)
239
+
240
+ sample_range = range(0, num_demo_samples)
241
+ num_samples = len(sample_range)
242
+
243
+ # video_dir_path = os.path.join(os.getcwd(), "video_out", "video_out_box2video_may1_eval_test")
244
+ video_dir_path = args.output_path
245
+ os.makedirs(video_dir_path, exist_ok=True)
246
+ video_counter = 0
247
+
248
+ # GENERATION PARAMETERS
249
+
250
+ # Set the bbox masking
251
+ bbox_mask_idx_batch = args.bbox_mask_idx_batch
252
+ condition_on_last_bbox = False
253
+
254
+ # Set the action type
255
+ force_action_type = None #1 # 0: Normal, 1: Ego, 2: Ego/Veh, 3: Veh, 4: Veh/Veh
256
+ force_action_type_batch = args.force_action_type_batch
257
+
258
+ num_gens_per_sample = args.num_gens_per_sample
259
+ guidance_scales = args.guidance_scales
260
+ eval_output = args.eval_output
261
+
262
+ # GENERATE VIDEOS
263
+
264
+ # Check for samples that were already done and do not compute them again
265
+ skip_samples = {}
266
+ out_video_path = f"{video_dir_path}/gt_ref"
267
+ if os.path.exists(out_video_path):
268
+ all_videos = os.listdir(out_video_path)
269
+ video_counter = len(all_videos)
270
+ for sample_name in all_videos:
271
+ vid_name = "_".join(sample_name.split("_")[1:])
272
+ skip_samples[vid_name] = True
273
+
274
+ print("SKIP SAMPLES:", skip_samples)
275
+
276
+ for guidance in guidance_scales or [-1]:
277
+
278
+ if guidance != -1:
279
+ print("Guidance:", force_action_type)
280
+ else:
281
+ guidance = [1, 3]
282
+
283
+ for _ in range(num_gens_per_sample):
284
+ for force_action_type in force_action_type_batch or [-1]:
285
+
286
+ if force_action_type != -1:
287
+ print("Force action type:", force_action_type)
288
+ else:
289
+ force_action_type = None
290
+
291
+ for bbox_mask_idx in bbox_mask_idx_batch or [-1]:
292
+
293
+ if bbox_mask_idx != -1:
294
+ print("Bbox masking:", bbox_mask_idx)
295
+ else:
296
+ bbox_mask_idx = None
297
+
298
+ for i, sample in tqdm(enumerate(demo_samples), desc="Generating samples", total=num_samples):
299
+ if i >= list(sample_range)[-1] + 1:
300
+ break
301
+ if i not in sample_range:
302
+ continue
303
+
304
+ if video_counter > args.max_output_vids:
305
+ print(f"MAX OUTPUT VIDS REACHED: {video_counter} >= {args.max_output_vids}")
306
+ exit()
307
+
308
+ vid_name = sample["vid_name"]
309
+
310
+ mask_hint = "" if bbox_mask_idx is None else f"_bframes:{str(bbox_mask_idx)}"
311
+ action_hint = "" if force_action_type is None else f"_action:{str(force_action_type)}"
312
+ guidance_hint = "" if guidance_scales is None else f"_guide{guidance[0]}:{guidance[1]}"
313
+ scene_name = f"{video_counter}_{vid_name}{mask_hint}{action_hint}{guidance_hint}"
314
+
315
+ scene_name_no_counter = "_".join(scene_name.split("_")[1:])
316
+ if scene_name_no_counter in skip_samples:
317
+ print(f"Skipping sample that was already computed: {vid_name}")
318
+ continue
319
+
320
+ print("Generating video for:", scene_name)
321
+
322
+ if eval_output:
323
+ os.makedirs(f"{video_dir_path}/gen_videos", exist_ok=True)
324
+ os.makedirs(f"{video_dir_path}/gt_frames", exist_ok=True)
325
+ os.makedirs(f"{video_dir_path}/gt_ref", exist_ok=True)
326
+
327
+ gt_vid_path = f"{video_dir_path}/gt_ref/{scene_name}/(1)gt_video_{scene_name}"
328
+ bbox_out_path_root = f"{video_dir_path}/gt_ref/{scene_name}"
329
+ out_video_path = f"{video_dir_path}/gen_videos/genvid_{video_counter}_{vid_name}"
330
+ out_json_path = os.path.join(video_dir_path, "gt_frames", f"gt_frames_{video_counter}_{vid_name}.json")
331
+
332
+ out_video_path2 = f"{bbox_out_path_root}/(3)genvid_adv_{scene_name}"
333
+
334
+ os.makedirs(bbox_out_path_root, exist_ok=True)
335
+ else:
336
+ os.makedirs(f"{video_dir_path}/{scene_name}", exist_ok=True)
337
+
338
+ gt_vid_path = f"{video_dir_path}/{scene_name}/(1)gt_video_{scene_name}"
339
+ bbox_out_path_root = f"{video_dir_path}/{scene_name}"
340
+ out_video_path = f"{video_dir_path}/{scene_name}/(3)genvid_adv_{scene_name}"
341
+ out_json_path = os.path.join(video_dir_path, scene_name, f"gt_frames_{sample['vid_name']}.json")
342
+ out_video_path2 = None
343
+
344
+ create_video_from_np(sample['gt_clip_np'], video_path=gt_vid_path)
345
+
346
+ # Add action type text to ground truth bounding box frames # TODO: Make sure the action type aligns if we change it for generation
347
+ action_type = sample['action_type'].unsqueeze(0)
348
+ og_action_type = action_type.item()
349
+ if force_action_type is not None:
350
+ action_type = torch.ones_like(action_type) * force_action_type
351
+
352
+ action_id = action_type.item()
353
+ bbox_frames = sample['bbox_images_np'].copy()
354
+ if bbox_mask_idx is not None:
355
+ # print(f"Masking bboxes after index {bbox_mask_idx}")
356
+
357
+ # Let's save a copy of the original bboxes for reference
358
+ bbox_frames_ref = sample['bbox_images_np'].copy()
359
+ label_frames_with_action_id(bbox_frames_ref, og_action_type)
360
+ create_video_from_np(bbox_frames_ref, video_path=f"{bbox_out_path_root}/(2)video_2dbboxes_{scene_name}_nomask")
361
+
362
+ # For display, let's mask with white
363
+ mask_cond = bbox_mask_idx <= np.arange(CLIP_LENGTH).reshape(CLIP_LENGTH, 1, 1, 1)
364
+ if condition_on_last_bbox:
365
+ mask_cond[-1, 0, 0, 0] = False
366
+ bbox_frames = np.where(mask_cond, np.ones_like(bbox_frames)*255, bbox_frames)
367
+ label_frames_with_action_id(bbox_frames, action_id, masked_idx=bbox_mask_idx)
368
+ else:
369
+ label_frames_with_action_id(bbox_frames, action_id)
370
+
371
+ create_video_from_np(bbox_frames, video_path=f"{bbox_out_path_root}/(2)video_2dbboxes_{scene_name}")
372
+
373
+ bbox_mask_frames = [False] * CLIP_LENGTH
374
+ if bbox_mask_idx is not None:
375
+ bbox_mask_frames[bbox_mask_idx:] = [True] * (len(bbox_mask_frames) - bbox_mask_idx)
376
+ if condition_on_last_bbox:
377
+ bbox_mask_frames[-1] = False
378
+
379
+ generate_video_ctrlv(
380
+ sample,
381
+ pipeline,
382
+ video_path=out_video_path,
383
+ json_path=out_json_path,
384
+ bbox_mask_frames=bbox_mask_frames,
385
+ action_type=action_type,
386
+ use_factor_guidance=use_factor_guidance,
387
+ guidance=guidance,
388
+ video_path2=out_video_path2
389
+ )
390
+
391
+ video_counter += 1
392
+
393
+ print("DONE")
394
+
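For clarity, the per-frame bounding-box masking that `generate_samples` passes to the pipeline boils down to a boolean list: frames at or after `bbox_mask_idx` have their bbox conditioning dropped, optionally keeping the last frame as an end-point condition. A minimal sketch (helper name illustrative):

```python
# Minimal sketch of how bbox_mask_frames is constructed in generate_samples:
# True means "drop the bbox conditioning for this frame".

def build_bbox_mask(clip_length, bbox_mask_idx=None, condition_on_last_bbox=False):
    mask = [False] * clip_length
    if bbox_mask_idx is not None:
        mask[bbox_mask_idx:] = [True] * (clip_length - bbox_mask_idx)
        if condition_on_last_bbox:
            mask[-1] = False  # re-enable the final bbox frame as an end-point condition
    return mask

if __name__ == "__main__":
    print(build_bbox_mask(10, bbox_mask_idx=6))
    # [False, False, False, False, False, False, True, True, True, True]
    print(build_bbox_mask(10, bbox_mask_idx=6, condition_on_last_bbox=True))
    # same as above, except the last entry flips back to False
```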
src/eval/video_dataset.py ADDED
@@ -0,0 +1,79 @@
1
+ import os
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ import cv2
5
+ import numpy as np
6
+
7
+ class VideoDataset(Dataset):
8
+ def __init__(self, video_root, num_frames=25, downsample_int=1, transform=None):
9
+ """
10
+ Args:
11
+ video_root (str): Directory with all the video files
12
+ num_frames (int): Number of frames to extract from each video
13
+ downsample_int (int): Interval between frames to extract
14
+ transform (callable, optional): Optional transform to be applied on frames
15
+ """
16
+ self.video_root = video_root
17
+ self.num_frames = num_frames
18
+ self.downsample_int = downsample_int
19
+ self.transform = transform
20
+
21
+ # Get list of video files
22
+ self.video_files = []
23
+ gen_videos = os.path.join(video_root, "gen_videos") if os.path.exists(os.path.join(video_root, "gen_videos")) else video_root
24
+ for fname in os.listdir(gen_videos):
25
+ if fname.endswith('.mp4'):
26
+ self.video_files.append(os.path.join(gen_videos, fname))
27
+
28
+ self.video_files.sort()
29
+
30
+ def __len__(self):
31
+ return len(self.video_files)
32
+
33
+ def get_frames_mp4(self, video_path):
34
+ """Extract frames from video file"""
35
+ cap = cv2.VideoCapture(video_path)
36
+ if not cap.isOpened():
37
+ raise ValueError(f"Could not open video file: {video_path}")
38
+
39
+ frames = []
40
+ frame_count = 0
41
+
42
+ while True:
43
+ ret, frame = cap.read()
44
+ if not ret:
45
+ break
46
+
47
+ frame = cv2.resize(frame, (512, 320))
48
+
49
+ if frame_count % self.downsample_int == 0:
50
+ frames.append(frame)
51
+
52
+ frame_count += 1
53
+
54
+ if len(frames) >= self.num_frames:
55
+ break
56
+
57
+ cap.release()
58
+
59
+ if len(frames) < self.num_frames:
60
+ # Pad with last frame if we don't have enough frames
61
+ last_frame = frames[-1] if frames else np.zeros((320, 512, 3), dtype=np.uint8)
62
+ while len(frames) < self.num_frames:
63
+ frames.append(last_frame)
64
+
65
+ return np.array(frames[:self.num_frames])
66
+
67
+ def __getitem__(self, idx):
68
+ video_path = self.video_files[idx]
69
+ frames = self.get_frames_mp4(video_path)
70
+
71
+ # Convert to torch tensor and normalize
72
+ frames = torch.from_numpy(frames).float()
73
+ frames = frames.permute(0, 3, 1, 2) # Change from (T, H, W, C) to (T, C, H, W)
74
+ frames = frames / (255/2.0) - 1.0 # Normalize to [-1, 1]
75
+
76
+ if self.transform:
77
+ frames = self.transform(frames)
78
+
79
+ return frames, []
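A usage sketch for this dataset, assuming a results folder of .mp4 clips (optionally under a gen_videos/ subdirectory) such as the one written by the generation script; the path is illustrative:

from torch.utils.data import DataLoader

from src.eval.video_dataset import VideoDataset

dataset = VideoDataset("outputs/eval_run", num_frames=25, downsample_int=1)
loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=4)

for frames, _ in loader:
    # frames: (B, T, C, H, W) float tensor in [-1, 1] with T=25, H=320, W=512
    print(frames.shape, frames.min().item(), frames.max().item())
    break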
src/eval/video_quality_metrics_fvd_gt_rand.py ADDED
@@ -0,0 +1,458 @@
1
+ import numpy as np
2
+ import torch
3
+ import scipy.linalg
4
+ from typing import Tuple
5
+ import torch.nn.functional as F
6
+ import math
7
+ import cv2
8
+ import json
9
+ import random
10
+ import os
11
+ import argparse
12
+ from tqdm import tqdm
13
+
14
+ import numpy as np
15
+ import io
16
+ import re
17
+ import requests
18
+ import html
19
+ import hashlib
20
+ import urllib
21
+ import urllib.request
22
+ import uuid
23
+
24
+ from distutils.util import strtobool
25
+ from typing import Any, List, Tuple, Union, Dict
26
+
27
+ from src.datasets.dataset_utils import get_dataloader
28
+ from src.utils import get_samples
29
+
30
+
31
+ def get_frames_from_path_list(path_list):
32
+ frames = []
33
+ for path in path_list:
34
+ img = cv2.imread(path)
35
+ img = cv2.resize(img, [512, 320])
36
+ frames.append(img)
37
+ return np.array(frames)
38
+
39
+ def get_frames_mp4(video_path: str, frame_interval: int = 1) -> np.ndarray:
40
+
41
+ # Open the video file
42
+ cap = cv2.VideoCapture(video_path)
43
+ if not cap.isOpened():
44
+ raise ValueError(f"Could not open video file: {video_path}")
45
+
46
+ frame_count = 0
47
+ saved_count = 0
48
+
49
+ frames = []
50
+ while True:
51
+ ret, frame = cap.read()
52
+ if not ret:
53
+ break
54
+
55
+ frame = cv2.resize(frame, (512, 320))
56
+
57
+ # Save frame if it's the right interval
58
+ if frame_count % frame_interval == 0:
59
+ frames.append(frame)
60
+ saved_count += 1
61
+
62
+ frame_count += 1
63
+
64
+ cap.release()
65
+ return np.array(frames)
66
+
67
+
68
+ def load_json(filename):
69
+ if os.path.exists(filename):
70
+ with open(filename, "r") as f:
71
+ return json.load(f)
72
+ print(filename, "not found")
73
+ return []
74
+
75
+
76
+ def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any:
77
+ """Download the given URL and return a binary-mode file object to access the data."""
78
+ assert num_attempts >= 1
79
+
80
+ # Doesn't look like an URL scheme so interpret it as a local filename.
81
+ if not re.match('^[a-z]+://', url):
82
+ return url if return_filename else open(url, "rb")
83
+
84
+ # Handle file URLs. This code handles unusual file:// patterns that
85
+ # arise on Windows:
86
+ #
87
+ # file:///c:/foo.txt
88
+ #
89
+ # which would translate to a local '/c:/foo.txt' filename that's
90
+ # invalid. Drop the forward slash for such pathnames.
91
+ #
92
+ # If you touch this code path, you should test it on both Linux and
93
+ # Windows.
94
+ #
95
+ # Some internet resources suggest using urllib.request.url2pathname() but
96
+     # that converts forward slashes to backslashes and this causes
97
+ # its own set of problems.
98
+ if url.startswith('file://'):
99
+ filename = urllib.parse.urlparse(url).path
100
+ if re.match(r'^/[a-zA-Z]:', filename):
101
+ filename = filename[1:]
102
+ return filename if return_filename else open(filename, "rb")
103
+
104
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
105
+
106
+ # Download.
107
+ url_name = None
108
+ url_data = None
109
+ with requests.Session() as session:
110
+ if verbose:
111
+ print("Downloading %s ..." % url, end="", flush=True)
112
+ for attempts_left in reversed(range(num_attempts)):
113
+ try:
114
+ with session.get(url) as res:
115
+ res.raise_for_status()
116
+ if len(res.content) == 0:
117
+ raise IOError("No data received")
118
+
119
+ if len(res.content) < 8192:
120
+ content_str = res.content.decode("utf-8")
121
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
122
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
123
+ if len(links) == 1:
124
+ url = requests.compat.urljoin(url, links[0])
125
+ raise IOError("Google Drive virus checker nag")
126
+ if "Google Drive - Quota exceeded" in content_str:
127
+ raise IOError("Google Drive download quota exceeded -- please try again later")
128
+
129
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
130
+ url_name = match[1] if match else url
131
+ url_data = res.content
132
+ if verbose:
133
+ print(" done")
134
+ break
135
+ except KeyboardInterrupt:
136
+ raise
137
+ except:
138
+ if not attempts_left:
139
+ if verbose:
140
+ print(" failed")
141
+ raise
142
+ if verbose:
143
+ print(".", end="", flush=True)
144
+
145
+ # Return data as file object.
146
+ assert not return_filename
147
+ return io.BytesIO(url_data)
148
+
149
+ """
150
+ Modified from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py
151
+ """
152
+ class FVD:
153
+ def __init__(self, device,
154
+ detector_url='https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1',
155
+ rescale=False, resize=False, return_features=True):
156
+
157
+ self.device = device
158
+ self.detector_kwargs = dict(rescale=False, resize=False, return_features=True)
159
+
160
+ with open_url(detector_url, verbose=False) as f:
161
+ self.detector = torch.jit.load(f).eval().to(device)
162
+
163
+ # Initialize ground truth statistics
164
+ self.mu_real = None
165
+ self.sigma_real = None
166
+
167
+ def to_device(self, device):
168
+ self.device = device
169
+ self.detector = self.detector.to(self.device)
170
+
171
+ def _compute_stats(self, feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
172
+ mu = feats.mean(axis=0) # [d]
173
+ sigma = np.cov(feats, rowvar=False) # [d, d]
174
+ return mu, sigma
175
+
176
+ def save_gt_stats(self, save_path: str):
177
+ """Save ground truth statistics to a file."""
178
+ if self.mu_real is None or self.sigma_real is None:
179
+ raise ValueError("Ground truth statistics not computed yet")
180
+
181
+ stats = {
182
+ 'mu_real': self.mu_real,
183
+ 'sigma_real': self.sigma_real
184
+ }
185
+ np.savez(save_path, **stats)
186
+
187
+ def load_gt_stats(self, load_path: str):
188
+ """Load ground truth statistics from a file."""
189
+ stats = np.load(load_path)
190
+ self.mu_real = stats['mu_real']
191
+ self.sigma_real = stats['sigma_real']
192
+
193
+ def preprocess_videos(self, videos, resolution=224, sequence_length=None):
194
+
195
+ b, t, c, h, w = videos.shape
196
+
197
+ # temporal crop
198
+ if sequence_length is not None:
199
+ assert sequence_length <= t
200
+ videos = videos[:, :sequence_length, ::]
201
+
202
+ # b*t x c x h x w
203
+ videos = videos.reshape(-1, c, h, w)
204
+ if c == 1:
205
+ videos = torch.cat([videos, videos, videos], 1)
206
+ c = 3
207
+
208
+ # scale shorter side to resolution
209
+ scale = resolution / min(h, w)
210
+ # import pdb; pdb.set_trace()
211
+ if h < w:
212
+ target_size = (resolution, math.ceil(w * scale))
213
+ else:
214
+ target_size = (math.ceil(h * scale), resolution)
215
+
216
+ videos = F.interpolate(videos, size=target_size).clamp(min=-1, max=1)
217
+
218
+ # center crop
219
+ _, c, h, w = videos.shape
220
+
221
+ h_start = (h - resolution) // 2
222
+ w_start = (w - resolution) // 2
223
+ videos = videos[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
224
+
225
+ # b, c, t, w, h
226
+ videos = videos.reshape(b, t, c, resolution, resolution).permute(0, 2, 1, 3, 4)
227
+
228
+ return videos.contiguous()
229
+
230
+ @torch.no_grad()
231
+ def evaluate(self, video_fake, video_real=None, res=224, use_saved_stats=False, save_stats_path=None):
232
+ """Evaluate FVD score.
233
+
234
+ Args:
235
+ video_fake: Generated videos
236
+ video_real: Ground truth videos (optional if use_saved_stats=True)
237
+ res: Resolution for preprocessing
238
+ use_saved_stats: Whether to use saved ground truth statistics
239
+ """
240
+ video_fake = self.preprocess_videos(video_fake, resolution=res)
241
+ feats_fake = self.detector(video_fake, **self.detector_kwargs).cpu().numpy()
242
+
243
+ if use_saved_stats:
244
+ if self.mu_real is None or self.sigma_real is None:
245
+ raise ValueError("Ground truth statistics not loaded. Call load_gt_stats() first.")
246
+ mu_real = self.mu_real
247
+ sigma_real = self.sigma_real
248
+ else:
249
+ if video_real is None:
250
+ raise ValueError("video_real must be provided when use_saved_stats=False")
251
+ video_real = self.preprocess_videos(video_real, resolution=res)
252
+ feats_real = self.detector(video_real, **self.detector_kwargs).cpu().numpy()
253
+ mu_real, sigma_real = self._compute_stats(feats_real)
254
+ # Save the computed statistics
255
+ self.mu_real = mu_real
256
+ self.sigma_real = sigma_real
257
+ if save_stats_path is not None:
258
+ self.save_gt_stats(save_stats_path)
259
+
260
+ mu_gen, sigma_gen = self._compute_stats(feats_fake)
261
+
262
+ m = np.square(mu_gen - mu_real).sum()
263
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False)
264
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
265
+ return fid
266
+
267
+
268
+ def collect_fvd_stats(data_root, samples=200, downsample_int=1, num_frames=25, save_path=None, action_type=None):
269
+ """Collect and save ground truth statistics for FVD evaluation."""
270
+
271
+ if save_path is None:
272
+ save_path = os.path.join(data_root, "gt_fvd_stats.npz")
273
+
274
+ # Set up category filtering if specified
275
+ specific_categories = None
276
+ force_clip_type = None
277
+ if action_type is not None:
278
+ if action_type == 0:
279
+ force_clip_type = "normal"
280
+ print("Collecting normal samples only")
281
+ else:
282
+ classes_by_action_type = {
283
+ 1: [61, 62, 13, 14, 15, 16, 17, 18],
284
+ 2: list(range(1, 12 + 1)),
285
+ 3: [37, 39, 41, 42, 44] + list(range(19, 36 + 1)) + list(range(52, 60 + 1)),
286
+ 4: [38, 40, 43, 45, 46, 47, 48, 49, 50, 51]
287
+ }
288
+ specific_categories = classes_by_action_type[action_type]
289
+ force_clip_type = "crash"
290
+ print("Collecting crash samples from categories:", specific_categories)
291
+
292
+ # Create dataset and dataloader
293
+ dataset_name = "mmau"
294
+ train_set = True
295
+ val_dataset, val_loader = get_dataloader(data_root, dataset_name,
296
+ if_train=train_set, clip_length=num_frames,
297
+ batch_size=1, num_workers=0, shuffle=True,
298
+ image_height=320, image_width=512,
299
+ non_overlapping_clips=True,
300
+ specific_categories=specific_categories,
301
+ force_clip_type=force_clip_type)
302
+
303
+ # Collect video paths
304
+ gt_videos = []
305
+ for sample in tqdm(val_loader, desc="Collecting samples", total=samples):
306
+ vid_path = os.path.dirname(sample["image_paths"][0][0])
307
+ gt_videos.append(vid_path)
308
+ if len(gt_videos) >= samples:
309
+ break
310
+
311
+ random.shuffle(gt_videos)
312
+
313
+ num_found_samples = len(gt_videos)
314
+ print(f"Found {num_found_samples} ground truth video directories")
315
+
316
+ # Initialize array for all videos
317
+ all_videos = torch.zeros((num_found_samples, num_frames, 3, 320, 512), device="cuda")
318
+
319
+ # Load and process videos
320
+ valid = 0
321
+ for idx, video_path in tqdm(enumerate(gt_videos), desc="Processing videos", total=num_found_samples):
322
+ if valid == num_found_samples:
323
+ break
324
+
325
+ # Get list of jpg files in directory
326
+ frame_files = sorted([f for f in os.listdir(video_path) if f.endswith('.jpg')])
327
+
328
+ if len(frame_files) < num_frames:
329
+ print(f"Skipping {video_path.split('/')[-1]}, insufficient frames: {len(frame_files)}")
330
+ continue
331
+
332
+ # Load frames
333
+ frames = []
334
+ for frame_file in frame_files[0:num_frames:downsample_int]:
335
+ frame_path = os.path.join(video_path, frame_file)
336
+ img = cv2.imread(frame_path)
337
+ img = cv2.resize(img, (512, 320))
338
+ frames.append(img)
339
+
340
+ frames = torch.tensor(np.array(frames), device="cuda")
341
+
342
+ # Process frames
343
+ frames = frames.unsqueeze(0).permute(0, 1, 4, 2, 3)
344
+ all_videos[valid] = frames[:, :num_frames, ::]
345
+ valid += 1
346
+
347
+ if valid == 0:
348
+ raise ValueError("No valid videos found")
349
+
350
+ # Convert to torch tensor and normalize
351
+ all_videos = all_videos.float()
352
+ all_videos.div_(255/2.0).sub_(1.0)
353
+
354
+ # Initialize FVD and compute statistics
355
+ with torch.no_grad():
356
+ fvd = FVD(device='cuda')
357
+ video_real = fvd.preprocess_videos(all_videos)
358
+ feats_real = fvd.detector(video_real, **fvd.detector_kwargs).cpu().numpy()
359
+ mu_real, sigma_real = fvd._compute_stats(feats_real)
360
+
361
+ # Save statistics
362
+ stats = {
363
+ 'mu_real': mu_real,
364
+ 'sigma_real': sigma_real,
365
+ 'num_videos': valid,
366
+ 'num_frames': num_frames,
367
+ 'resolution': 320
368
+ }
369
+ np.savez(save_path, **stats)
370
+ print(f"Saved ground truth statistics to {save_path}")
371
+
372
+ # Clean up
373
+ del fvd, all_videos, video_real, feats_real
374
+ torch.cuda.empty_cache()
375
+
376
+ return save_path
377
+
378
+ def evaluate_vids(vid_root, samples=200, downsample_int=1, num_frames=25, gt_stats=None, shuffle=False):
379
+ """Evaluate FVD score for generated videos using pre-computed ground truth statistics."""
380
+
381
+ # Initialize FVD and load ground truth statistics
382
+ fvd = FVD(device='cuda')
383
+ if gt_stats is not None:
384
+ fvd.load_gt_stats(gt_stats)
385
+
386
+ # Collect generated video paths
387
+ f_gen_vid = []
388
+ gen_videos = os.path.join(vid_root, "gen_videos") if os.path.exists(os.path.join(vid_root, "gen_videos")) else vid_root
389
+ for fname in os.listdir(gen_videos):
390
+ f_gen_vid.append(fname)
391
+
392
+ print(f"Number of generated videos: {len(f_gen_vid)}")
393
+
394
+ if not shuffle:
395
+ f_gen_vid.sort()
396
+ else:
397
+ random.shuffle(f_gen_vid)
398
+
399
+ f_gen_vid = f_gen_vid[:samples]
400
+
401
+ # Initialize array for all videos
402
+ all_gen = np.zeros((samples, num_frames, 3, 320, 512))
403
+
404
+ # Load and process videos
405
+ valid = 0
406
+ for idx, fgen in tqdm(enumerate(f_gen_vid)):
407
+ if valid == samples:
408
+ break
409
+
410
+ gen_vid_path = os.path.join(gen_videos, fgen)
411
+ gen_vid = get_frames_mp4(gen_vid_path, frame_interval=downsample_int)
412
+
413
+ if gen_vid.shape[0] < num_frames:
414
+ print("Skipping, wrong size:", gen_vid.shape[0])
415
+ continue
416
+
417
+ gen_vid = np.expand_dims(gen_vid, 0).transpose(0, 1, 4, 2, 3)
418
+ all_gen[valid] = gen_vid[:, :num_frames, ::]
419
+ valid += 1
420
+
421
+ # Convert to torch tensor and normalize
422
+ all_gen = torch.from_numpy(all_gen).cuda().float()
423
+ all_gen /= 255/2.0
424
+ all_gen -= 1.0
425
+
426
+ # Compute FVD score
427
+ fvd_score = fvd.evaluate(all_gen, video_real=None, use_saved_stats=True)
428
+ del fvd
429
+
430
+ print(f'FVD Score: {fvd_score}')
431
+
432
+ if __name__ == '__main__':
433
+ parser = argparse.ArgumentParser(description='Evaluate FVD score using pre-computed ground truth statistics')
434
+ parser.add_argument('--vid_root', type=str, required=True,
435
+ help='Root directory containing generated videos')
436
+ parser.add_argument('--samples', type=int, default=200,
437
+ help='Number of samples to evaluate (default: 200)')
438
+ parser.add_argument('--num_frames', type=int, default=25,
439
+ help='Number of frames per video (default: 25)')
440
+ parser.add_argument('--downsample_int', type=int, default=1,
441
+ help='Downsample interval for frames (default: 1)')
442
+ parser.add_argument('--gt_stats', type=str, default=None,
443
+ help='Path to ground truth statistics file (optional)')
444
+ parser.add_argument('--shuffle', action='store_true',
445
+ help='Shuffle videos before evaluation')
446
+ parser.add_argument('--collect_stats', action='store_true',
447
+ help='Collect and save ground truth statistics')
448
+ parser.add_argument('--data_root', type=str, required=True,
449
+ help='Root directory for datasets')
450
+ parser.add_argument('--action_type', type=int, default=None,
451
+ help='Action type to filter videos (0: normal, 1-4: crash types)')
452
+ args = parser.parse_args()
453
+
454
+ if args.collect_stats:
455
+ stats_path = collect_fvd_stats(args.data_root, args.samples, args.downsample_int, args.num_frames, args.gt_stats, args.action_type)
456
+ args.gt_stats = stats_path
457
+
458
+ evaluate_vids(args.vid_root, args.samples, args.downsample_int, args.num_frames, args.gt_stats, args.shuffle)
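The intended two-step workflow, sketched in Python with hypothetical paths; note that instantiating FVD downloads the I3D TorchScript detector from the Dropbox URL above:

from src.eval.video_quality_metrics_fvd_gt_rand import collect_fvd_stats, evaluate_vids

# 1) Compute and cache ground-truth I3D statistics once per dataset split / action type
stats_path = collect_fvd_stats("/path/to/Datasets", samples=200, num_frames=25,
                               save_path="/path/to/Datasets/gt_fvd_stats.npz", action_type=None)

# 2) Score any number of generation runs against the cached statistics
evaluate_vids("outputs/eval_run", samples=200, num_frames=25, gt_stats=stats_path)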
src/eval/video_quality_metrics_fvd_pair.py ADDED
@@ -0,0 +1,349 @@
1
+ import numpy as np
2
+ import torch
3
+ import scipy.linalg
4
+ from typing import Tuple
5
+ import torch.nn.functional as F
6
+ import math
7
+ from torchvision import transforms
8
+ import cv2
9
+ import json
10
+ import argparse
11
+ from tqdm import tqdm
12
+ import lpips
13
+ from skimage.metrics import structural_similarity as ssim
14
+ from skimage.metrics import peak_signal_noise_ratio as psnr
15
+
16
+ """
17
+ Copy-pasted from https://github.com/NVlabs/stylegan2-ada-pytorch
18
+ """
19
+ import ctypes
20
+ import fnmatch
21
+ import importlib
22
+ import inspect
23
+ import numpy as np
24
+ import os
25
+ import shutil
26
+ import sys
27
+ import types
28
+ import io
29
+ import pickle
30
+ import re
31
+ import requests
32
+ import html
33
+ import hashlib
34
+ import glob
35
+ import tempfile
36
+ import urllib
37
+ import urllib.request
38
+ import uuid
39
+ from tqdm import tqdm
40
+
41
+ from distutils.util import strtobool
42
+ from typing import Any, List, Tuple, Union, Dict
43
+
44
+
45
+ def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any:
46
+ """Download the given URL and return a binary-mode file object to access the data."""
47
+ assert num_attempts >= 1
48
+
49
+ # Doesn't look like an URL scheme so interpret it as a local filename.
50
+ if not re.match('^[a-z]+://', url):
51
+ return url if return_filename else open(url, "rb")
52
+
53
+ # Handle file URLs. This code handles unusual file:// patterns that
54
+ # arise on Windows:
55
+ #
56
+ # file:///c:/foo.txt
57
+ #
58
+ # which would translate to a local '/c:/foo.txt' filename that's
59
+ # invalid. Drop the forward slash for such pathnames.
60
+ #
61
+ # If you touch this code path, you should test it on both Linux and
62
+ # Windows.
63
+ #
64
+ # Some internet resources suggest using urllib.request.url2pathname() but
65
+     # that converts forward slashes to backslashes and this causes
66
+ # its own set of problems.
67
+ if url.startswith('file://'):
68
+ filename = urllib.parse.urlparse(url).path
69
+ if re.match(r'^/[a-zA-Z]:', filename):
70
+ filename = filename[1:]
71
+ return filename if return_filename else open(filename, "rb")
72
+
73
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
74
+
75
+ # Download.
76
+ url_name = None
77
+ url_data = None
78
+ with requests.Session() as session:
79
+ if verbose:
80
+ print("Downloading %s ..." % url, end="", flush=True)
81
+ for attempts_left in reversed(range(num_attempts)):
82
+ try:
83
+ with session.get(url) as res:
84
+ res.raise_for_status()
85
+ if len(res.content) == 0:
86
+ raise IOError("No data received")
87
+
88
+ if len(res.content) < 8192:
89
+ content_str = res.content.decode("utf-8")
90
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
91
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
92
+ if len(links) == 1:
93
+ url = requests.compat.urljoin(url, links[0])
94
+ raise IOError("Google Drive virus checker nag")
95
+ if "Google Drive - Quota exceeded" in content_str:
96
+ raise IOError("Google Drive download quota exceeded -- please try again later")
97
+
98
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
99
+ url_name = match[1] if match else url
100
+ url_data = res.content
101
+ if verbose:
102
+ print(" done")
103
+ break
104
+ except KeyboardInterrupt:
105
+ raise
106
+ except:
107
+ if not attempts_left:
108
+ if verbose:
109
+ print(" failed")
110
+ raise
111
+ if verbose:
112
+ print(".", end="", flush=True)
113
+
114
+ # Return data as file object.
115
+ assert not return_filename
116
+ return io.BytesIO(url_data)
117
+
118
+ """
119
+ Modified from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py
120
+ """
121
+ class FVD:
122
+ def __init__(self, device,
123
+ detector_url='https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1',
124
+ rescale=False, resize=False, return_features=True):
125
+
126
+ self.device = device
127
+ self.detector_kwargs = dict(rescale=False, resize=False, return_features=True)
128
+
129
+ with open_url(detector_url, verbose=False) as f:
130
+ self.detector = torch.jit.load(f).eval().to(device)
131
+
132
+ def to_device(self, device):
133
+ self.device = device
134
+ self.detector = self.detector.to(self.device)
135
+
136
+ def _compute_stats(self, feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
137
+ mu = feats.mean(axis=0) # [d]
138
+ sigma = np.cov(feats, rowvar=False) # [d, d]
139
+ return mu, sigma
140
+
141
+ def preprocess_videos(self, videos, resolution=224, sequence_length=None):
142
+
143
+ b, t, c, h, w = videos.shape
144
+
145
+ # temporal crop
146
+ if sequence_length is not None:
147
+ assert sequence_length <= t
148
+ videos = videos[:, :sequence_length, ::]
149
+
150
+ # b*t x c x h x w
151
+ videos = videos.reshape(-1, c, h, w)
152
+ if c == 1:
153
+ videos = torch.cat([videos, videos, videos], 1)
154
+ c = 3
155
+
156
+ # scale shorter side to resolution
157
+ scale = resolution / min(h, w)
158
+ # import pdb; pdb.set_trace()
159
+ if h < w:
160
+ target_size = (resolution, math.ceil(w * scale))
161
+ else:
162
+ target_size = (math.ceil(h * scale), resolution)
163
+
164
+ videos = F.interpolate(videos, size=target_size).clamp(min=-1, max=1)
165
+
166
+ # center crop
167
+ _, c, h, w = videos.shape
168
+
169
+ h_start = (h - resolution) // 2
170
+ w_start = (w - resolution) // 2
171
+ videos = videos[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
172
+
173
+ # b, c, t, w, h
174
+ videos = videos.reshape(b, t, c, resolution, resolution).permute(0, 2, 1, 3, 4)
175
+
176
+ return videos.contiguous()
177
+
178
+ @torch.no_grad()
179
+ def evaluate(self, video_fake, video_real, res=224):
180
+
181
+ video_fake = self.preprocess_videos(video_fake,resolution=res)
182
+ video_real = self.preprocess_videos(video_real,resolution=res)
183
+ feats_fake = self.detector(video_fake, **self.detector_kwargs).cpu().numpy()
184
+ feats_real = self.detector(video_real, **self.detector_kwargs).cpu().numpy()
185
+
186
+ mu_gen, sigma_gen = self._compute_stats(feats_fake)
187
+ mu_real, sigma_real = self._compute_stats(feats_real)
188
+
189
+ m = np.square(mu_gen - mu_real).sum()
190
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
191
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
192
+ return fid
193
+
194
+ def evaluate_vids(vid_root, samples=200, downsample=False, num_frames=25):
195
+ """Evaluate video quality metrics between generated and ground truth videos."""
196
+
197
+ # Collect video paths
198
+ vid_name_to_gt_frames = {}
199
+ gt_videos_refs = os.path.join(vid_root, "gt_frames")
200
+ for fname in os.listdir(gt_videos_refs):
201
+         vid_name = fname.removeprefix("gt_frames_").split(".")[0]
202
+ vid_name_to_gt_frames[vid_name] = fname
203
+
204
+ f_gen_vid = []
205
+ gen_videos = os.path.join(vid_root, "gen_videos")
206
+ for fname in os.listdir(gen_videos):
207
+ f_gen_vid.append(fname)
208
+         vid_name = fname.removeprefix("genvid_").split(".")[0]
209
+ assert vid_name_to_gt_frames.get(vid_name) is not None, f"{fname} has no matching gt frames"
210
+
211
+ print(f"Number of generated videos: {len(f_gen_vid)}")
212
+
213
+ # Initialize arrays for all videos
214
+ all_gt = np.zeros((samples, num_frames, 3, 320, 512))
215
+ all_gen = np.zeros((samples, num_frames, 3, 320, 512))
216
+
217
+ # Load and process videos
218
+ valid = 0
219
+ for idx, fgen in tqdm(enumerate(f_gen_vid), desc="Collecting video frames"):
220
+ if valid == samples:
221
+ break
222
+
223
+         vid_name = fgen.removeprefix("genvid_").split(".")[0]
224
+ fgt = vid_name_to_gt_frames[vid_name]
225
+
226
+ gen_vid_path = os.path.join(gen_videos, fgen)
227
+ gen_vid = get_frames_mp4(gen_vid_path)
228
+
229
+ with open(os.path.join(gt_videos_refs, fgt)) as gt_json:
230
+ gt_vid = get_frames_from_path_list(json.load(gt_json))
231
+
232
+ if gt_vid.shape[0] < num_frames or gen_vid.shape[0] < num_frames:
233
+ print("Skipping, wrong size:", gt_vid.shape[0], gen_vid.shape[0])
234
+ continue
235
+
236
+ gt_vid = np.expand_dims(gt_vid, 0).transpose(0, 1, 4, 2, 3)
237
+ gen_vid = np.expand_dims(gen_vid, 0).transpose(0, 1, 4, 2, 3)
238
+
239
+ all_gt[valid] = gt_vid[:, :num_frames, ::]
240
+ all_gen[valid] = gen_vid[:, :num_frames, ::]
241
+ valid += 1
242
+
243
+ # Convert to torch tensors and normalize
244
+ all_gt = torch.from_numpy(all_gt).cuda().float()
245
+ all_gt /= 255/2.0
246
+ all_gt -= 1.0
247
+ all_gen = torch.from_numpy(all_gen).cuda().float()
248
+ all_gen /= 255/2.0
249
+ all_gen -= 1.0
250
+
251
+ # Compute FVD score
252
+ fvd = FVD(device='cuda')
253
+ fvd_score = fvd.evaluate(all_gt, all_gen)
254
+ del fvd
255
+
256
+ # Compute LPIPS score
257
+ loss_fn_alex = lpips.LPIPS(net='alex').cuda()
258
+ lpips_score = 0
259
+ for idx in range(all_gen.shape[0]):
260
+ lpips_score += loss_fn_alex(all_gt[idx], all_gen[idx])/all_gen.shape[0]
261
+ lpips_score = lpips_score.mean().item()
262
+ del loss_fn_alex
263
+
264
+ # Compute SSIM and PSNR scores
265
+ all_gen = all_gen.detach().cpu().numpy()
266
+ all_gt = all_gt.detach().cpu().numpy()
267
+
268
+ ssim_score_vid = np.zeros(samples)
269
+ ssim_score_image = np.zeros((samples, num_frames))
270
+ psnr_score_vid = np.zeros(samples)
271
+ psnr_score_image = np.zeros((samples, num_frames))
272
+ psnr_score_all = psnr(all_gt, all_gen)
273
+
274
+ for vid_idx in tqdm(range(all_gen.shape[0]), desc="Computing SSIM and PSNR"):
275
+ for f_idx in range(all_gen.shape[1]):
276
+ img_gt = all_gt[vid_idx, f_idx]
277
+ img_gen = all_gen[vid_idx, f_idx]
278
+ data_range = max(img_gt.max(), img_gen.max()) - min(img_gt.min(), img_gen.min())
279
+ ssim_score_image[vid_idx, f_idx] = ssim(img_gt, img_gen, channel_axis=0, data_range=data_range, gaussian_weights=True, sigma=1.5)
280
+ psnr_score_image[vid_idx, f_idx] = psnr(img_gt, img_gen, data_range=data_range)
281
+
282
+ vid_gt = all_gt[vid_idx]
283
+ vid_gen = all_gen[vid_idx]
284
+ data_range = max(vid_gt.max(), vid_gen.max()) - min(vid_gt.min(), vid_gen.min())
285
+ ssim_score_vid[vid_idx] = ssim(vid_gt, vid_gen, channel_axis=1, data_range=data_range, gaussian_weights=True, sigma=1.5)
286
+ psnr_score_vid[vid_idx] = psnr(vid_gt, vid_gen, data_range=data_range)
287
+
288
+ ssim_score_image_error = np.sqrt(((ssim_score_image - ssim_score_image.mean())**2).sum()/200)
289
+ psnr_score_image_error = np.sqrt(((psnr_score_image - psnr_score_image.mean())**2).sum()/200)
290
+
291
+ # Print results
292
+ print(f'FVD Score: {fvd_score}')
293
+ print(f'LPIPS Score: {lpips_score}')
294
+ print(f'SSIM Score (per image): {ssim_score_image.mean()}')
295
+ print(f'SSIM Score Error: {ssim_score_image_error}')
296
+ print(f'PSNR Score (per image): {psnr_score_image.mean()}')
297
+ print(f'PSNR Score Error: {psnr_score_image_error}')
298
+
299
+ # Print copy-friendly format
300
+ print("\nCopy friendly format:")
301
+ print(f"{fvd_score}, {lpips_score}, {ssim_score_image.mean()}, {psnr_score_image.mean()}")
302
+
303
+ def get_frames_from_path_list(path_list):
304
+     frames = []
305
+     for path in path_list:
306
+         img = cv2.imread(path)
307
+         img = cv2.resize(img, [512, 320])
308
+         frames.append(img)
309
+     return np.array(frames)
310
+
311
+ def get_frames_mp4(video_path: str, frame_interval: int = 1) -> np.ndarray:
312
+
313
+     # Open the video file
314
+     cap = cv2.VideoCapture(video_path)
315
+     if not cap.isOpened():
316
+         raise ValueError(f"Could not open video file: {video_path}")
317
+
318
+     frame_count = 0
319
+     saved_count = 0
320
+
321
+     frames = []
322
+     while True:
323
+         ret, frame = cap.read()
324
+         if not ret:
325
+             break
326
+
327
+         # Save frame if it's the right interval
328
+         if frame_count % frame_interval == 0:
329
+             frames.append(frame)
330
+             saved_count += 1
331
+
332
+         frame_count += 1
333
+
334
+     cap.release()
335
+     return np.array(frames)
336
+
337
+ if __name__ == '__main__':
338
+     parser = argparse.ArgumentParser(description='Evaluate video quality metrics between generated and ground truth videos')
339
+     parser.add_argument('--vid_root', type=str, required=True,
340
+                         help='Root directory containing generated and ground truth videos')
341
+     parser.add_argument('--samples', type=int, default=200,
342
+                         help='Number of samples to evaluate (default: 200)')
343
+     parser.add_argument('--num_frames', type=int, default=25,
344
+                         help='Number of frames per video (default: 25)')
345
+     parser.add_argument('--downsample', action='store_true',
346
+                         help='Downsample videos during evaluation')
347
+     args = parser.parse_args()
348
+
349
+     evaluate_vids(args.vid_root, args.samples, args.downsample, args.num_frames)
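The file-name pairing this script relies on, made explicit with prefix removal (str.removeprefix needs Python 3.9+); the root path is illustrative:

import os

vid_root = "outputs/eval_run"                     # hypothetical output root
gen_dir = os.path.join(vid_root, "gen_videos")    # genvid_<counter>_<vid_name>.mp4
gt_dir = os.path.join(vid_root, "gt_frames")      # gt_frames_<counter>_<vid_name>.json

pairs = []
for fname in sorted(os.listdir(gen_dir)):
    if not fname.endswith(".mp4"):
        continue
    vid_name = fname.removeprefix("genvid_").rsplit(".", 1)[0]
    gt_json = os.path.join(gt_dir, f"gt_frames_{vid_name}.json")
    if os.path.exists(gt_json):
        pairs.append((os.path.join(gen_dir, fname), gt_json))

print(f"Matched {len(pairs)} generated clips to ground-truth frame lists")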
src/eval/video_quality_metrics_jedi_gt_rand.py ADDED
@@ -0,0 +1,91 @@
1
+ import numpy as np
2
+ import torch
3
+ import cv2
4
+ import json
5
+ import random
6
+ import os
7
+ import argparse
8
+ from tqdm import tqdm
9
+ from torch.utils.data import DataLoader
10
+
11
+ from videojedi import JEDiMetric
12
+ from .video_dataset import VideoDataset
13
+ from src.datasets.dataset_utils import get_dataloader
14
+
15
+ def custom_collate(batch):
16
+ """Custom collate function for DataLoader to handle video clips."""
17
+ videos, targets = [], []
18
+ for sample in batch:
19
+ clips = sample["clips"]
20
+ videos.append(clips)
21
+ return torch.utils.data.dataloader.default_collate(videos), targets
22
+
23
+ def evaluate_vids(vid_root, samples=200, downsample_int=1, num_frames=25, gt_samples=500, test_feature_path=None, action_type=None, shuffle=False):
24
+ """Evaluate JEDi metric between generated and ground truth videos."""
25
+
26
+ # Initialize JEDi metric
27
+ jedi = JEDiMetric(feature_path=vid_root,
28
+ test_feature_path=test_feature_path,
29
+ model_dir="/path/to/Models")
30
+
31
+ # Create dataset and dataloader for generated videos
32
+ gen_dataset = VideoDataset(vid_root, num_frames=num_frames, downsample_int=downsample_int)
33
+ gen_loader = DataLoader(gen_dataset, batch_size=1, shuffle=shuffle, num_workers=4)
34
+
35
+ # Set up category filtering if specified
36
+ specific_categories = None
37
+ force_clip_type = None
38
+ if action_type is not None:
39
+ if action_type == 0:
40
+ force_clip_type = "normal"
41
+ print("Collecting normal samples only")
42
+ else:
43
+ classes_by_action_type = {
44
+ 1: [61, 62, 13, 14, 15, 16, 17, 18],
45
+ 2: list(range(1, 12 + 1)),
46
+ 3: [37, 39, 41, 42, 44] + list(range(19, 36 + 1)) + list(range(52, 60 + 1)),
47
+ 4: [38, 40, 43, 45, 46, 47, 48, 49, 50, 51]
48
+ }
49
+ specific_categories = classes_by_action_type[action_type]
50
+ force_clip_type = "crash"
51
+ print("Collecting crash samples from categories:", specific_categories)
52
+
53
+ # Create dataset and dataloader for ground truth videos
54
+ dataset_name = "mmau"
55
+ train_set = True
56
+ val_dataset, _ = get_dataloader("path/to/Datasets", dataset_name,
57
+ if_train=train_set, clip_length=num_frames,
58
+ batch_size=1, num_workers=0, shuffle=True,
59
+ image_height=320, image_width=512,
60
+ non_overlapping_clips=True,
61
+ specific_categories=specific_categories,
62
+ force_clip_type=force_clip_type)
63
+ val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True, collate_fn=custom_collate)
64
+
65
+ # Compute JEDi metric
66
+ jedi.load_features(train_loader=gen_loader, test_loader=val_loader,
67
+ num_samples=samples, num_test_samples=gt_samples)
68
+ jedi_metric = jedi.compute_metric()
69
+ print(f"JEDi Metric: {jedi_metric}")
70
+
71
+ if __name__ == '__main__':
72
+ parser = argparse.ArgumentParser(description='Evaluate JEDi metric between generated and ground truth videos')
73
+ parser.add_argument('--vid_root', type=str, required=True,
74
+ help='Root directory containing generated videos')
75
+ parser.add_argument('--samples', type=int, default=200,
76
+ help='Number of samples to evaluate (default: 200)')
77
+ parser.add_argument('--gt_samples', type=int, default=500,
78
+ help='Number of ground truth samples to use (default: 500)')
79
+ parser.add_argument('--num_frames', type=int, default=25,
80
+ help='Number of frames per video (default: 25)')
81
+ parser.add_argument('--downsample_int', type=int, default=1,
82
+ help='Downsample interval for frames (default: 1)')
83
+ parser.add_argument('--test_feature_path', type=str, default=None,
84
+ help='Path to test features (optional)')
85
+ parser.add_argument('--action_type', type=int, default=None,
86
+ help='Action type to filter videos (0: normal, 1-4: crash types)')
87
+ parser.add_argument('--shuffle', action='store_true',
88
+ help='Shuffle videos before evaluation')
89
+ args = parser.parse_args()
90
+
91
+ evaluate_vids(args.vid_root, args.samples, args.downsample_int, args.num_frames, args.gt_samples, args.test_feature_path, args.action_type, args.shuffle)
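A quick check of what custom_collate is expected to produce, assuming each dataset sample is a dict whose "clips" entry is a (T, C, H, W) tensor (as in the MMAU dataset classes) and that videojedi is installed:

import torch

from src.eval.video_quality_metrics_jedi_gt_rand import custom_collate

fake_batch = [{"clips": torch.zeros(25, 3, 320, 512)} for _ in range(2)]
videos, targets = custom_collate(fake_batch)
print(videos.shape)   # torch.Size([2, 25, 3, 320, 512])
print(targets)        # [] -- JEDi only consumes the video tensors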
src/eval/video_quality_metrics_jedi_pair.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ import os
3
+ import argparse
4
+ from torch.utils.data import DataLoader
5
+
6
+ from videojedi import JEDiMetric
7
+ from .video_dataset import VideoDataset
8
+ from src.datasets.dataset_utils import get_dataloader
9
+
10
+ def custom_collate(batch):
11
+ videos, targets = [], []
12
+ for sample in batch:
13
+ clips = sample["clips"]
14
+ videos.append(clips)
15
+ return torch.utils.data.dataloader.default_collate(videos), targets
16
+
17
+ def evaluate_vids(vid_root, samples=200, downsample_int=1, num_frames=25, gt_samples=500, test_feature_path=None, action_type=None):
18
+ """Evaluate JEDi metric between generated and ground truth videos."""
19
+
20
+ # Initialize JEDi metric
21
+ jedi = JEDiMetric(feature_path=vid_root,
22
+ test_feature_path=test_feature_path,
23
+ model_dir="/path/to/Models")
24
+
25
+ # Create dataset and dataloader for generated videos
26
+ gen_dataset = VideoDataset(vid_root, num_frames=num_frames, downsample_int=downsample_int)
27
+ gen_loader = DataLoader(gen_dataset, batch_size=1, shuffle=False, num_workers=4)
28
+
29
+ # Set up category filtering if specified
30
+ specific_categories = None
31
+ force_clip_type = None
32
+ if action_type is not None:
33
+ if action_type == 0:
34
+ force_clip_type = "normal"
35
+ print("Collecting normal samples only")
36
+ else:
37
+ classes_by_action_type = {
38
+ 1: [61, 62, 13, 14, 15, 16, 17, 18],
39
+ 2: list(range(1, 12 + 1)),
40
+ 3: [37, 39, 41, 42, 44] + list(range(19, 36 + 1)) + list(range(52, 60 + 1)),
41
+ 4: [38, 40, 43, 45, 46, 47, 48, 49, 50, 51]
42
+ }
43
+ specific_categories = classes_by_action_type[action_type]
44
+ force_clip_type = "crash"
45
+ print("Collecting crash samples from categories:", specific_categories)
46
+
47
+ # Get specific samples to evaluate
48
+ specific_samples = []
49
+ gen_videos = os.path.join(vid_root, "gen_videos") if os.path.exists(f"{vid_root}/gen_videos") else vid_root
50
+ for fname in os.listdir(gen_videos):
51
+         vid_name = fname.removeprefix("genvid_").split(".")[0]
52
+ gt_vid_name = "_".join(vid_name.split("_")[1:])
53
+ specific_samples.append(gt_vid_name)
54
+
55
+ # Create dataset and dataloader for ground truth videos
56
+ dataset_name = "mmau"
57
+ train_set = False
58
+ val_dataset, _ = get_dataloader("/path/to/Datasets", dataset_name,
59
+ if_train=train_set, clip_length=num_frames,
60
+ batch_size=1, num_workers=0, shuffle=True,
61
+ image_height=320, image_width=512,
62
+ non_overlapping_clips=True,
63
+ specific_categories=specific_categories,
64
+ force_clip_type=force_clip_type,
65
+ specific_samples=specific_samples)
66
+ val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True, collate_fn=custom_collate)
67
+
68
+ # Compute JEDi metric
69
+ jedi.load_features(train_loader=gen_loader, test_loader=val_loader,
70
+ num_samples=samples, num_test_samples=gt_samples)
71
+ jedi_metric = jedi.compute_metric()
72
+ print(f"JEDi Metric: {jedi_metric}")
73
+
74
+ if __name__ == '__main__':
75
+ parser = argparse.ArgumentParser(description='Evaluate JEDi metric between generated and ground truth videos')
76
+ parser.add_argument('--vid_root', type=str, required=True,
77
+ help='Root directory containing generated videos')
78
+ parser.add_argument('--samples', type=int, default=200,
79
+ help='Number of samples to evaluate (default: 200)')
80
+ parser.add_argument('--gt_samples', type=int, default=500,
81
+ help='Number of ground truth samples to use (default: 500)')
82
+ parser.add_argument('--num_frames', type=int, default=25,
83
+ help='Number of frames per video (default: 25)')
84
+ parser.add_argument('--downsample_int', type=int, default=1,
85
+ help='Downsample interval for frames (default: 1)')
86
+ parser.add_argument('--test_feature_path', type=str, default=None,
87
+ help='Path to test features (optional)')
88
+ parser.add_argument('--action_type', type=int, default=None,
89
+ help='Action type to filter videos (0: normal, 1-4: crash types)')
90
+ args = parser.parse_args()
91
+
92
+ evaluate_vids(args.vid_root, args.samples, args.downsample_int, args.num_frames, args.gt_samples, args.test_feature_path, args.action_type)
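The action-type filtering shared by this script and the FVD/JEDi scripts above, factored into a small helper for clarity; the function name is ours, but the category ids mirror the mapping hard-coded in those scripts:

def action_type_to_filter(action_type):
    """Map an action-type id (0 = normal, 1-4 = crash types) to dataset filter arguments."""
    if action_type is None:
        return None, None                     # no filtering
    if action_type == 0:
        return None, "normal"
    classes_by_action_type = {
        1: [61, 62, 13, 14, 15, 16, 17, 18],
        2: list(range(1, 13)),
        3: [37, 39, 41, 42, 44] + list(range(19, 37)) + list(range(52, 61)),
        4: [38, 40, 43, 45, 46, 47, 48, 49, 50, 51],
    }
    return classes_by_action_type[action_type], "crash"

specific_categories, force_clip_type = action_type_to_filter(3)
print(force_clip_type, specific_categories[:5])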
src/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from src.models.controlnet import ControlNetModel
2
+ from src.models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
src/models/controlnet.py ADDED
@@ -0,0 +1,391 @@
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+ from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block
10
+ from diffusers.loaders import FromOriginalControlNetMixin
11
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
12
+ from diffusers.utils import logging
13
+ from diffusers.models import ControlNetModel as ControlNetModel_original
14
+ from diffusers.models.controlnet import ControlNetOutput, zero_module
15
+
16
+ from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
17
+
18
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
+
20
+ class ControlNetModel(ControlNetModel_original): # (ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
21
+
22
+ r"""
23
+ A controlnet for conditional Spatio-Temporal UNet model.
24
+
25
+ Parameters:
26
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
27
+ Height and width of input/output sample.
28
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
29
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
30
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
31
+ The tuple of downsample blocks to use.
32
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
33
+ The tuple of output channels for each block.
34
+ addition_time_embed_dim: (`int`, defaults to 256):
35
+             Dimension used to encode the additional time ids.
36
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
37
+ The dimension of the projection of encoded `added_time_ids`.
38
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
39
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
40
+ The dimension of the cross attention features.
41
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
42
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
43
+ [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
44
+ [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
45
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
46
+ The number of attention heads.
47
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
48
+ action_dim: (`int`, defaults to 256):
49
+ Dimension of the action features.
50
+ """
51
+
52
+ _supports_gradient_checkpointing = True
53
+
54
+ @register_to_config
55
+ def __init__(
56
+ self,
57
+ sample_size: Optional[int] = None,
58
+ in_channels: int = 8,
59
+ down_block_types: Tuple[str] = (
60
+ "CrossAttnDownBlockSpatioTemporal",
61
+ "CrossAttnDownBlockSpatioTemporal",
62
+ "CrossAttnDownBlockSpatioTemporal",
63
+ "DownBlockSpatioTemporal",
64
+ ),
65
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
66
+ addition_time_embed_dim: int = 256,
67
+ projection_class_embeddings_input_dim: int = 768,
68
+ layers_per_block: Union[int, Tuple[int]] = 2,
69
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
70
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
71
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
72
+ num_frames: int = 25,
73
+ action_dim: int = 5, # Dimension of the action features
74
+ bbox_embedding_shape: Tuple[int] = (4, 128, 128),
75
+ ):
76
+ # calling the super class constructors without calling ControlNetModel_original's
77
+ ModelMixin.__init__(self)
78
+ ConfigMixin.__init__(self)
79
+ FromOriginalControlNetMixin.__init__(self)
80
+
81
+ self.sample_size = sample_size
82
+
83
+ # Check inputs
84
+ if len(block_out_channels) != len(down_block_types):
85
+ raise ValueError(
86
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
87
+ )
88
+
89
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
90
+ raise ValueError(
91
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
92
+ )
93
+
94
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
95
+ raise ValueError(
96
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
97
+ )
98
+
99
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
100
+ raise ValueError(
101
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
102
+ )
103
+
104
+ # input
105
+ self.conv_in = nn.Conv2d(
106
+ in_channels,
107
+ block_out_channels[0],
108
+ kernel_size=3,
109
+ padding=1,
110
+ )
111
+
112
+ # time
113
+ time_embed_dim = block_out_channels[0] * 4
114
+
115
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
116
+ timestep_input_dim = block_out_channels[0]
117
+
118
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
119
+
120
+ self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
121
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
122
+
123
+ # Action projection layer
124
+ hidden_action_dim = 256
125
+ self.action_embedding = nn.Embedding(action_dim, hidden_action_dim)
126
+ self.action_proj = nn.Linear(hidden_action_dim, cross_attention_dim)
127
+
128
+ # Learnable null embedding for bbox masking
129
+ self.bbox_null_embedding = nn.Parameter(torch.randn(bbox_embedding_shape))
130
+
131
+ self.down_blocks = nn.ModuleList([])
132
+ self.controlnet_down_blocks = nn.ModuleList([])
133
+
134
+ if isinstance(num_attention_heads, int):
135
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
136
+
137
+ if isinstance(cross_attention_dim, int):
138
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
139
+
140
+ if isinstance(layers_per_block, int):
141
+ layers_per_block = [layers_per_block] * len(down_block_types)
142
+
143
+ if isinstance(transformer_layers_per_block, int):
144
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
145
+
146
+ blocks_time_embed_dim = time_embed_dim
147
+
148
+ self.control_conv_in = nn.Conv2d(
149
+ in_channels//2,
150
+ block_out_channels[0],
151
+ kernel_size=3,
152
+ padding=1,
153
+ )
154
+ # # Initialize the re-zero parameter
155
+ # self.rz_weight = nn.Parameter(torch.Tensor([0]))
156
+
157
+ # down
158
+ output_channel = block_out_channels[0]
159
+
160
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
161
+ controlnet_block = zero_module(controlnet_block)
162
+ self.controlnet_down_blocks.append(controlnet_block)
163
+
164
+ for i, down_block_type in enumerate(down_block_types):
165
+ input_channel = output_channel
166
+ output_channel = block_out_channels[i]
167
+ is_final_block = i == len(block_out_channels) - 1
168
+
169
+ down_block = get_down_block(
170
+ down_block_type,
171
+ num_layers=layers_per_block[i],
172
+ transformer_layers_per_block=transformer_layers_per_block[i],
173
+ in_channels=input_channel,
174
+ out_channels=output_channel,
175
+ temb_channels=blocks_time_embed_dim,
176
+ add_downsample=not is_final_block,
177
+ resnet_eps=1e-5,
178
+ cross_attention_dim=cross_attention_dim[i],
179
+ num_attention_heads=num_attention_heads[i],
180
+ resnet_act_fn="silu",
181
+ )
182
+ self.down_blocks.append(down_block)
183
+
184
+ for _ in range(layers_per_block[i]):
185
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
186
+ controlnet_block = zero_module(controlnet_block)
187
+ self.controlnet_down_blocks.append(controlnet_block)
188
+
189
+ if not is_final_block:
190
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
191
+ controlnet_block = zero_module(controlnet_block)
192
+ self.controlnet_down_blocks.append(controlnet_block)
193
+
194
+ # mid
195
+ controlnet_block = nn.Conv2d(block_out_channels[-1], block_out_channels[-1], kernel_size=1)
196
+ controlnet_block = zero_module(controlnet_block)
197
+ self.controlnet_mid_block = controlnet_block
198
+ self.mid_block = UNetMidBlockSpatioTemporal(
199
+ block_out_channels[-1],
200
+ temb_channels=blocks_time_embed_dim,
201
+ transformer_layers_per_block=transformer_layers_per_block[-1],
202
+ cross_attention_dim=cross_attention_dim[-1],
203
+ num_attention_heads=num_attention_heads[-1],
204
+ )
205
+
206
+ # count how many layers upsample the images
207
+ self.num_upsamplers = 0
208
+
209
+ @classmethod
210
+ def from_unet(cls,
211
+ unet: UNetSpatioTemporalConditionModel,
212
+ load_weights_from_unet: bool = True,
213
+ action_dim: int = 5,
214
+ bbox_embedding_shape: Tuple[int] = (4, 128, 128)):
215
+
216
+ ctrlnet = cls(
217
+ in_channels=unet.config.in_channels,
218
+ down_block_types=unet.config.down_block_types,
219
+ block_out_channels=unet.config.block_out_channels,
220
+ addition_time_embed_dim=unet.config.addition_time_embed_dim,
221
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
222
+ layers_per_block=unet.config.layers_per_block,
223
+ cross_attention_dim=unet.config.cross_attention_dim,
224
+ transformer_layers_per_block=unet.config.transformer_layers_per_block,
225
+ num_attention_heads=unet.config.num_attention_heads,
226
+ num_frames=unet.config.num_frames,
227
+ action_dim=action_dim,
228
+ bbox_embedding_shape=bbox_embedding_shape,
229
+ )
230
+ unet_keys = set(unet.state_dict().keys())
231
+ ctrl_keys = set(ctrlnet.state_dict().keys())
232
+ intersection_keys = ctrl_keys.intersection(unet_keys)
233
+ for key in ctrl_keys:
234
+ if key in intersection_keys:
235
+ if load_weights_from_unet:
236
+ ctrlnet.state_dict()[key].copy_(unet.state_dict()[key])
237
+ # else:
238
+ # logger.warning(f"Key {key} not found in UNet model, initializing it randomly.")
239
+
240
+ return ctrlnet
241
+
242
+ def forward(
243
+ self,
244
+ sample: torch.FloatTensor,
245
+ timestep: Union[torch.Tensor, float, int],
246
+ encoder_hidden_states: torch.Tensor,
247
+ added_time_ids: torch.Tensor,
248
+ control_cond: torch.FloatTensor = None,
249
+ action_type: torch.LongTensor = None,
250
+ conditioning_scale: float = 1.0,
251
+ return_dict: bool = True,
252
+ ) -> Union[ControlNetOutput, Tuple]:
253
+ r"""
254
+ This approach effectively integrates the forward method of the UNetSpatioTemporalConditionModel with the forward
255
+ method of the ControlNetModel
256
+
257
+ Args:
258
+ sample (`torch.FloatTensor`):
259
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
260
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
261
+ encoder_hidden_states (`torch.FloatTensor`):
262
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
263
+ added_time_ids: (`torch.FloatTensor`):
264
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
265
+ embeddings and added to the time embeddings.
266
+             control_cond (`torch.FloatTensor`):
267
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
268
+ action_type (`torch.LongTensor`):
269
+ The action type with shape `(batch_size)`.
270
+ conditioning_scale (`float`, defaults to `1.0`):
271
+ The scale factor for ControlNet outputs.
272
+ return_dict (`bool`, *optional*, defaults to `True`):
273
+                 Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] instead
274
+ of a plain tuple.
275
+ Returns:
276
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
277
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
278
+ returned where the first element is the sample tensor.
279
+ """
280
+ # 1. time
281
+ timesteps = timestep
282
+ if len(timesteps.shape) == 0:
283
+ timesteps = timesteps[None].to(sample.device)
284
+
285
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
286
+ batch_size, num_frames = sample.shape[:2]
287
+ timesteps = timesteps.expand(batch_size)
288
+
289
+ t_emb = self.time_proj(timesteps)
290
+
291
+ # `Timesteps` does not contain any weights and will always return f32 tensors
292
+ # but time_embedding might actually be running in fp16. so we need to cast here.
293
+ # there might be better ways to encapsulate this.
294
+ t_emb = t_emb.to(dtype=sample.dtype)
295
+
296
+ emb = self.time_embedding(t_emb)
297
+
298
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
299
+ time_embeds = time_embeds.reshape((batch_size, -1))
300
+ time_embeds = time_embeds.to(emb.dtype)
301
+ aug_emb = self.add_embedding(time_embeds)
302
+ emb = emb + aug_emb
303
+
304
+ # Flatten the batch and frames dimensions
305
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
306
+ sample = sample.flatten(0, 1)
307
+ # control_cond: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
308
+ control_cond = control_cond.flatten(0, 1)
309
+ # Repeat the embeddings num_video_frames times
310
+ # emb: [batch, channels] -> [batch * frames, channels]
311
+ emb = emb.repeat_interleave(num_frames, dim=0)
312
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
313
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
314
+
315
+ # Process action features if provided
316
+ if action_type is not None:
317
+
318
+ # Embed action features
319
+ action_type = action_type.to(encoder_hidden_states.device, dtype=torch.long)
320
+
321
+ # Flatten action features to match the batch*frames dimension
322
+ # action_features: [batch, action_dim] -> [batch * frames, action_dim]
323
+ if action_type.dim() == 1:
324
+ action_type = action_type.unsqueeze(0)
325
+ action_type = action_type.repeat_interleave(num_frames, dim=0)
326
+
327
+ # Project action features to match the embedding dimension
328
+ action_features = self.action_embedding(action_type)
329
+ action_emb = self.action_proj(action_features)
330
+
331
+ # Add action embeddings to the encoder_hidden_states
332
+ # Make sure not to add action embeddings to masked hidden states
333
+ is_masked_cond = (encoder_hidden_states == 0).all(dim=2).unsqueeze(-1)
334
+ encoder_hidden_states = torch.where(is_masked_cond, encoder_hidden_states, encoder_hidden_states + action_emb)
335
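+ # `is_masked_cond` flags encoder_hidden_states rows that are entirely zero, i.e. image embeddings
+ # that were dropped for classifier-free guidance; the torch.where above leaves those rows untouched
+ # so the action embedding never leaks into the unconditional branch.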
+
336
+ # 2. pre-process
337
+ sample = self.conv_in(sample)
338
+ control_cond = self.control_conv_in(control_cond)
339
+ sample = sample + control_cond  # * self.rz_weight
340
+
341
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
342
+
343
+ down_block_res_samples = (sample,)
344
+ for downsample_block in self.down_blocks:
345
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
346
+ sample, res_samples = downsample_block(
347
+ hidden_states=sample,
348
+ temb=emb,
349
+ encoder_hidden_states=encoder_hidden_states,
350
+ image_only_indicator=image_only_indicator,
351
+ )
352
+ else:
353
+ sample, res_samples = downsample_block(
354
+ hidden_states=sample,
355
+ temb=emb,
356
+ image_only_indicator=image_only_indicator,
357
+ )
358
+
359
+ down_block_res_samples += res_samples
360
+
361
+ # 4. mid
362
+ sample = self.mid_block(
363
+ hidden_states=sample,
364
+ temb=emb,
365
+ encoder_hidden_states=encoder_hidden_states,
366
+ image_only_indicator=image_only_indicator,
367
+ )
368
+
369
+ # 5. Control net blocks
370
+
371
+ controlnet_down_block_res_samples = ()
372
+
373
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
374
+ down_block_res_sample = controlnet_block(down_block_res_sample)
375
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
376
+
377
+ down_block_res_samples = controlnet_down_block_res_samples
378
+
379
+ mid_block_res_sample = self.controlnet_mid_block(sample)
380
+
381
+ # 6. scaling
382
+
383
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
384
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
385
+
386
+ if not return_dict:
387
+ return (down_block_res_samples, mid_block_res_sample)
388
+
389
+ return ControlNetOutput(
390
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
391
+ )
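For readers skimming the diff, the intended wiring between the ControlNet above and the spatio-temporal UNet in the next file looks roughly like the sketch below. It is an illustrative shape-level sketch, not part of the upload: the checkpoint path, the `from_unet`-style constructor name, and the tensor sizes are assumptions, and in practice the latents and conditioning come from the VAE, CLIP encoder and scheduler as in the pipelines further down.

```python
import torch
from src.models import UNetSpatioTemporalConditionModel, ControlNetModel

# Hypothetical checkpoint; any SVD-style UNet with a matching config would do.
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", subfolder="unet"
)
controlnet = ControlNetModel.from_unet(unet, load_weights_from_unet=True)  # constructor name assumed

b, f = 1, 25
sample = torch.randn(b, f, 8, 40, 64)        # noisy latents concatenated with first-frame latents
control_cond = torch.randn(b, f, 4, 40, 64)  # VAE latents of the rendered bbox frames (size assumed)
image_embeddings = torch.randn(b, 1, 1024)   # CLIP image embedding of the first frame
added_time_ids = torch.tensor([[6.0, 127.0, 0.02]])
t = torch.tensor(500.0)

# ControlNet produces residuals for every down block plus the mid block ...
down_res, mid_res = controlnet(
    sample, t, image_embeddings, added_time_ids,
    control_cond=control_cond, conditioning_scale=1.0, return_dict=False,
)

# ... which the UNet adds to its own skip connections before the up blocks.
noise_pred = unet(
    sample, t, image_embeddings, added_time_ids,
    down_block_additional_residuals=down_res,
    mid_block_additional_residuals=mid_res,
).sample
```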
src/models/unet_spatio_temporal_condition.py ADDED
@@ -0,0 +1,169 @@
1
+
2
+ from diffusers.loaders import PeftAdapterMixin
3
+ from diffusers import ModelMixin
4
+
5
+ from diffusers import UNetSpatioTemporalConditionModel as UNetSpatioTemporalConditionModel_orig
6
+ import torch
7
+ from einops import rearrange
8
+ from typing import Any, Dict, List, Optional, Tuple, Union
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
11
+
12
+ # NOTE: ModelMixin is added only to keep `from_pretrained` compatible with some older versions of diffusers
13
+ class UNetSpatioTemporalConditionModel(UNetSpatioTemporalConditionModel_orig, PeftAdapterMixin, ModelMixin):
14
+
15
+ def enable_grad(self, temporal_transformer_block=True, all=False):
+ # Collect the parameters to train: every parameter when `all` is set (note: shadows the builtin),
+ # otherwise only the temporal transformer blocks when `temporal_transformer_block` is True.
+ parameters_list = []
+ for name, param in self.named_parameters():
+ if all or ('temporal_transformer_block' in name and temporal_transformer_block):
+ parameters_list.append(param)
+ param.requires_grad = True
+ else:
+ param.requires_grad = False
+ return parameters_list
24
+
25
+ def get_parameters_with_grad(self):
26
+ return [param for param in self.parameters() if param.requires_grad]
27
+
28
+ def forward(
29
+ self,
30
+ sample: torch.FloatTensor,
31
+ timestep: Union[torch.Tensor, float, int],
32
+ encoder_hidden_states: torch.Tensor,
33
+ added_time_ids: torch.Tensor,
34
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
35
+ mid_block_additional_residuals: Optional[torch.Tensor] = None,
36
+ return_dict: bool = True,
37
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
38
+ r"""
39
+ The [`UNetSpatioTemporalConditionModel`] forward method.
40
+
41
+ Args:
42
+ sample (`torch.FloatTensor`):
43
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
44
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
45
+ encoder_hidden_states (`torch.FloatTensor`):
46
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
47
+ added_time_ids: (`torch.FloatTensor`):
48
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
49
+ embeddings and added to the time embeddings.
50
+ return_dict (`bool`, *optional*, defaults to `True`):
51
+ Whether or not to return a [`~models.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`]
+ instead of a plain tuple.
+ Returns:
+ [`~models.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is the sample tensor.
57
+ """
58
+ is_controlnet = mid_block_additional_residuals is not None and down_block_additional_residuals is not None
59
+
60
+ # 1. time
61
+ timesteps = timestep
62
+ if len(timesteps.shape) == 0:
63
+ timesteps = timesteps[None].to(sample.device)
64
+
65
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
66
+ batch_size, num_frames = sample.shape[:2]
67
+ timesteps = timesteps.expand(batch_size)
68
+
69
+ t_emb = self.time_proj(timesteps)
70
+
71
+ # `Timesteps` does not contain any weights and will always return f32 tensors
72
+ # but time_embedding might actually be running in fp16. so we need to cast here.
73
+ # there might be better ways to encapsulate this.
74
+ t_emb = t_emb.to(dtype=sample.dtype)
75
+
76
+ emb = self.time_embedding(t_emb)
77
+
78
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
79
+ time_embeds = time_embeds.reshape((batch_size, -1))
80
+ time_embeds = time_embeds.to(emb.dtype)
81
+ aug_emb = self.add_embedding(time_embeds)
82
+ emb = emb + aug_emb
83
+
84
+ # Flatten the batch and frames dimensions
85
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
86
+ sample = sample.flatten(0, 1)
87
+ # Repeat the embeddings num_video_frames times
88
+ # emb: [batch, channels] -> [batch * frames, channels]
89
+ emb = emb.repeat_interleave(num_frames, dim=0)
90
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
91
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
92
+
93
+ # 2. pre-process
94
+ sample = self.conv_in(sample)
95
+
96
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
97
+
98
+ down_block_res_samples = (sample,)
99
+ for downsample_block in self.down_blocks:
100
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
101
+ sample, res_samples = downsample_block(
102
+ hidden_states=sample,
103
+ temb=emb,
104
+ encoder_hidden_states=encoder_hidden_states,
105
+ image_only_indicator=image_only_indicator,
106
+ )
107
+ else:
108
+ sample, res_samples = downsample_block(
109
+ hidden_states=sample,
110
+ temb=emb,
111
+ image_only_indicator=image_only_indicator,
112
+ )
113
+
114
+ down_block_res_samples += res_samples
115
+
116
+ if is_controlnet:
117
+ new_down_block_res_samples = ()
118
+ for down_block_res_sample, down_block_additional_residual in zip(
119
+ down_block_res_samples, down_block_additional_residuals
120
+ ):
121
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
122
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
123
+
124
+ down_block_res_samples = new_down_block_res_samples
125
+
126
+ # 4. mid
127
+ sample = self.mid_block(
128
+ hidden_states=sample,
129
+ temb=emb,
130
+ encoder_hidden_states=encoder_hidden_states,
131
+ image_only_indicator=image_only_indicator,
132
+ )
133
+ if is_controlnet:
134
+ sample = sample + mid_block_additional_residuals
135
+
136
+ # 5. up
137
+ for i, upsample_block in enumerate(self.up_blocks):
138
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
139
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
140
+
141
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
142
+ sample = upsample_block(
143
+ hidden_states=sample,
144
+ temb=emb,
145
+ res_hidden_states_tuple=res_samples,
146
+ encoder_hidden_states=encoder_hidden_states,
147
+ image_only_indicator=image_only_indicator,
148
+ )
149
+ else:
150
+ sample = upsample_block(
151
+ hidden_states=sample,
152
+ temb=emb,
153
+ res_hidden_states_tuple=res_samples,
154
+ image_only_indicator=image_only_indicator,
155
+ )
156
+
157
+ # 6. post-process
158
+ sample = self.conv_norm_out(sample)
159
+ sample = self.conv_act(sample)
160
+ sample = self.conv_out(sample)
161
+
162
+ # 7. Reshape back to original shape
163
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
164
+
165
+ if not return_dict:
166
+ return (sample,)
167
+
168
+ return UNetSpatioTemporalConditionOutput(sample=sample)
169
+
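The `enable_grad` helper above is presumably what the training scripts use to restrict fine-tuning to the temporal transformer blocks. A minimal sketch of how it can be wired into an optimizer (the checkpoint path and learning rate are placeholders, not values taken from this repository):

```python
import torch
from src.models import UNetSpatioTemporalConditionModel

unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", subfolder="unet"
)

# Freeze everything except the temporal transformer blocks and hand those to the optimizer.
trainable_params = unet.enable_grad(temporal_transformer_block=True, all=False)
optimizer = torch.optim.AdamW(trainable_params, lr=1e-5)

# `get_parameters_with_grad` returns the same set and is handy as a sanity check.
n_trainable = sum(p.numel() for p in unet.get_parameters_with_grad())
n_total = sum(p.numel() for p in unet.parameters())
print(f"fine-tuning {n_trainable / 1e6:.1f}M of {n_total / 1e6:.1f}M parameters")
```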
src/pipelines/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .pipeline_video_control import StableVideoControlPipeline
+ from .pipeline_video_diffusion import VideoDiffusionPipeline
+ from .pipeline_video_control_nullmodel import StableVideoControlNullModelPipeline
+ from .pipeline_video_control_factor_guidance import StableVideoControlFactorGuidancePipeline
src/pipelines/pipeline_video_control.py ADDED
@@ -0,0 +1,408 @@
1
+ import torch
2
+ from typing import Callable, Dict, List, Optional, Union
3
+ import PIL.Image
4
+ from einops import rearrange
5
+ import time
6
+
7
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
8
+ tensor2vid,
9
+ StableVideoDiffusionPipelineOutput,
10
+ _append_dims,
11
+ EXAMPLE_DOC_STRING
12
+ )
13
+ from diffusers import StableVideoDiffusionPipeline as StableVideoDiffusionPipeline_original
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
+ from diffusers.utils import logging, replace_example_docstring
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+
18
+ from src.models import UNetSpatioTemporalConditionModel, ControlNetModel
19
+ from diffusers.models import AutoencoderKLTemporalDecoder
20
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
21
+ from diffusers import EulerDiscreteScheduler
22
+ from diffusers.image_processor import VaeImageProcessor
23
+
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+ class StableVideoControlPipeline(StableVideoDiffusionPipeline_original):
27
+
28
+ def __init__(
29
+ self,
30
+ vae: AutoencoderKLTemporalDecoder,
31
+ image_encoder: CLIPVisionModelWithProjection,
32
+ unet: UNetSpatioTemporalConditionModel,
33
+ controlnet: ControlNetModel,
34
+ scheduler: EulerDiscreteScheduler,
35
+ feature_extractor: CLIPImageProcessor,
36
+ null_model: UNetSpatioTemporalConditionModel = None
37
+ ):
38
+ # calling the super class constructors without calling StableVideoDiffusionPipeline_original's
39
+ DiffusionPipeline.__init__(self)
40
+
41
+ self.register_modules(
42
+ vae=vae,
43
+ image_encoder=image_encoder,
44
+ controlnet=controlnet,
45
+ unet=unet,
46
+ scheduler=scheduler,
47
+ feature_extractor=feature_extractor,
48
+ null_model=null_model,
49
+ )
50
+
51
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
52
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
53
+
54
+ def check_inputs(self, image, cond_images, height, width):
55
+ if (
56
+ not isinstance(image, torch.Tensor)
57
+ and not isinstance(image, PIL.Image.Image)
58
+ and not isinstance(image, list)
59
+ ):
60
+ raise ValueError(
61
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
62
+ f" {type(image)}"
63
+ )
64
+ if not isinstance(cond_images, torch.Tensor):
65
+ raise ValueError(
66
+ "`cond_images` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
67
+ f" {type(cond_images)}"
68
+ )
69
+
70
+ if height % 8 != 0 or width % 8 != 0:
71
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
72
+
73
+
74
+ def _encode_vae_condition(
75
+ self,
76
+ cond_image: torch.tensor,
77
+ device: Union[str, torch.device],
78
+ num_videos_per_prompt: int,
79
+ do_classifier_free_guidance: bool,
80
+ bbox_mask_frames: List[bool] = None
81
+ ):
82
+ video_length = cond_image.shape[1]
83
+ cond_image = cond_image.to(device=device)
84
+ cond_image = cond_image.to(dtype=self.vae.dtype)
85
+
86
+ if cond_image.shape[2] == 3:
87
+ cond_image = rearrange(cond_image, "b f c h w -> (b f) c h w")
88
+ cond_em = self.vae.encode(cond_image).latent_dist.mode()
89
+ cond_em = rearrange(cond_em, "(b f) c h w -> b f c h w", f=video_length)
90
+ else:
91
+ assert cond_image.shape[2] == 4, "The input tensor should have 3 or 4 channels. 3 for frames and 4 for latents."
92
+ cond_em = cond_image
93
+
94
+ # duplicate cond_em for each generation per prompt, using mps friendly method
95
+ cond_em = cond_em.repeat(num_videos_per_prompt, 1, 1, 1, 1)
96
+
97
+ # Bbox conditioning masking during inference (requiring the model to predict behaviour instead)
98
+ if bbox_mask_frames is not None:
99
+ mask_cond = torch.tensor(bbox_mask_frames, device=cond_em.device).view(num_videos_per_prompt, video_length, 1, 1, 1)
100
+ null_embedding = self.controlnet.bbox_null_embedding.repeat(num_videos_per_prompt, video_length, 1, 1, 1)
101
+ cond_em = torch.where(mask_cond, null_embedding, cond_em)
102
+
103
+ if do_classifier_free_guidance:
104
+ # negative_cond_em = torch.zeros_like(cond_em)
105
+ negative_cond_em = self.controlnet.bbox_null_embedding.repeat(num_videos_per_prompt, video_length, 1, 1, 1)
106
+
107
+ # For classifier free guidance, we need to do two forward passes.
108
+ # Here we concatenate the unconditional and text embeddings into a single batch
109
+ # to avoid doing two forward passes
110
+ cond_em = torch.cat([negative_cond_em, cond_em])
111
+
112
+ return cond_em
113
+
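+ # Example: with num_videos_per_prompt=1 and a 25-frame clip, bbox_mask_frames = [False]*5 + [True]*20
+ # keeps the encoded bbox conditioning for the first 5 frames and swaps the remaining 20 for the learned
+ # bbox_null_embedding, so the model has to extrapolate the agents' behaviour for the masked frames.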
114
+ @property
115
+ def do_classifier_free_guidance(self):
116
+ # Don't do the normal CFG when using null model. The null model will take care of computing the unconditional noise
117
+ if self.null_model is not None:
118
+ return False
119
+ if isinstance(self.guidance_scale, (int, float)):
120
+ return self.guidance_scale > 1
121
+ return self.guidance_scale.max() > 1
122
+
123
+ @torch.no_grad()
124
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
125
+ def __call__(
126
+ self,
127
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
128
+ cond_images: torch.FloatTensor = None,
129
+ bbox_mask_frames: List[bool] = None,
130
+ action_type: torch.FloatTensor = None,
131
+ height: int = 576,
132
+ width: int = 1024,
133
+ num_frames: Optional[int] = None,
134
+ num_inference_steps: int = 25,
135
+ min_guidance_scale: float = 1.0,
136
+ max_guidance_scale: float = 3.0,
137
+ control_condition_scale: float=1.0,
138
+ fps: int = 7,
139
+ motion_bucket_id: int = 127,
140
+ noise_aug_strength: float = 0.02,
141
+ decode_chunk_size: Optional[int] = None,
142
+ num_videos_per_prompt: Optional[int] = 1,
143
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
144
+ latents: Optional[torch.FloatTensor] = None,
145
+ output_type: Optional[str] = "pil",
146
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
147
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
148
+ return_dict: bool = True,
149
+ ):
150
+ r"""
151
+ The call function to the pipeline for generation.
152
+
153
+ Args:
154
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
155
+ Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
156
+ 1]`.
157
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
158
+ The height in pixels of the generated image.
159
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
160
+ The width in pixels of the generated image.
161
+ num_frames (`int`, *optional*):
162
+ The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
163
+ `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
164
+ num_inference_steps (`int`, *optional*, defaults to 25):
165
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
166
+ expense of slower inference. This parameter is modulated by `strength`.
167
+ min_guidance_scale (`float`, *optional*, defaults to 1.0):
168
+ The minimum guidance scale. Used for the classifier free guidance with first frame.
169
+ max_guidance_scale (`float`, *optional*, defaults to 3.0):
170
+ The maximum guidance scale. Used for the classifier free guidance with last frame.
171
+ fps (`int`, *optional*, defaults to 7):
172
+ Frames per second. The rate at which the generated images shall be exported to a video after
173
+ generation. Note that Stable Video Diffusion's UNet was micro-conditioned on fps-1 during training.
174
+ motion_bucket_id (`int`, *optional*, defaults to 127):
175
+ Used for conditioning the amount of motion for the generation. The higher the number the more motion
176
+ will be in the video.
177
+ noise_aug_strength (`float`, *optional*, defaults to 0.02):
178
+ The amount of noise added to the init image, the higher it is the less the video will look like the
179
+ init image. Increase it for more motion.
180
+ action_type (`torch.FloatTensor`, *optional*, defaults to None):
181
+ The action type to condition the generation. These features are used by the ControlNet
182
+ to influence the generation process. The features should be of shape `[batch_size, 1]`.
183
+ decode_chunk_size (`int`, *optional*):
184
+ The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
185
+ expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
186
+ For lower memory usage, reduce `decode_chunk_size`.
187
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
188
+ The number of videos to generate per prompt.
189
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
190
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
191
+ generation deterministic.
192
+ latents (`torch.FloatTensor`, *optional*):
193
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
194
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
195
+ tensor is generated by sampling using the supplied random `generator`.
196
+ output_type (`str`, *optional*, defaults to `"pil"`):
197
+ The output format of the generated image. Choose between `pil`, `np` or `pt`.
198
+ callback_on_step_end (`Callable`, *optional*):
199
+ A function that is called at the end of each denoising step during inference. The function is called
200
+ with the following arguments:
201
+ `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
202
+ `callback_kwargs` will include a list of all tensors as specified by
203
+ `callback_on_step_end_tensor_inputs`.
204
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
205
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
206
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
207
+ `._callback_tensor_inputs` attribute of your pipeline class.
208
+ return_dict (`bool`, *optional*, defaults to `True`):
209
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
210
+ plain tuple.
211
+
212
+ Examples:
213
+
214
+ Returns:
215
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
216
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
217
+ returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`)
218
+ is returned.
219
+ """
220
+ # 0. Default height and width to unet
221
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
222
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
223
+
224
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
225
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
226
+
227
+ # 1. Check inputs. Raise error if not correct
228
+ self.check_inputs(image, cond_images, height, width)
229
+
230
+ # 2. Define call parameters
231
+ if isinstance(image, PIL.Image.Image):
232
+ batch_size = 1
233
+ elif isinstance(image, list):
234
+ batch_size = len(image)
235
+ else:
236
+ batch_size = image.shape[0]
237
+ device = self._execution_device
238
+ vae_device = self.vae.device
239
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
240
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
241
+ # corresponds to doing no classifier free guidance.
242
+ self._guidance_scale = max_guidance_scale
243
+
244
+ # 3. Encode input image
245
+ image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
246
+
247
+ # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
248
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
249
+ fps = fps - 1
250
+
251
+ # 4. Encode input image using VAE
252
+ image = self.image_processor.preprocess(image, height=height, width=width).to(device)
253
+ noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
254
+ image = image + noise_aug_strength * noise
255
+
256
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
257
+ if needs_upcasting:
258
+ self.vae.to(dtype=torch.float32)
259
+
260
+ image_latents = self._encode_vae_image(
261
+ image,
262
+ device=device,
263
+ num_videos_per_prompt=num_videos_per_prompt,
264
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
265
+ )
266
+ image_latents = image_latents.to(image_embeddings.dtype)
267
+
268
+ # Repeat the image latents for each frame so we can concatenate them with the noise
269
+ # image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
270
+ image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
271
+ # 5. Get Added Time IDs
272
+ added_time_ids = self._get_add_time_ids(
273
+ fps,
274
+ motion_bucket_id,
275
+ noise_aug_strength,
276
+ image_embeddings.dtype,
277
+ batch_size,
278
+ num_videos_per_prompt,
279
+ self.do_classifier_free_guidance,
280
+ )
281
+ added_time_ids = added_time_ids.to(device)
282
+
283
+ # 6. Prepare timesteps
284
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
285
+ timesteps = self.scheduler.timesteps
286
+
287
+ # 7a. Prepare latent variables
288
+ num_channels_latents = self.unet.config.out_channels*2
289
+ latents = self.prepare_latents(
290
+ batch_size * num_videos_per_prompt,
291
+ num_frames,
292
+ num_channels_latents,
293
+ height,
294
+ width,
295
+ image_embeddings.dtype,
296
+ device,
297
+ generator,
298
+ latents,
299
+ )
300
+
301
+ # 7b. Prepare control latent embeds
302
+ if cond_images is not None:
303
+ cond_em = self._encode_vae_condition(cond_images,
304
+ device,
305
+ num_videos_per_prompt,
306
+ self.do_classifier_free_guidance,
307
+ bbox_mask_frames=bbox_mask_frames)
308
+ cond_em = cond_em.to(image_embeddings.dtype)
309
+ else:
310
+ cond_em = None
311
+
312
+ # 7c. Prepare action features
313
+ if action_type is not None:
314
+ if self.do_classifier_free_guidance:
315
+ action_type = torch.cat([torch.zeros_like(action_type).unsqueeze(0), action_type.unsqueeze(0)])
316
+ else:
317
+ action_type = None
318
+
319
+ # 8. Prepare guidance scale
320
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
321
+ guidance_scale = guidance_scale.to(device, latents.dtype)
322
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
323
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
324
+
325
+ self._guidance_scale = guidance_scale
326
+
327
+ # 9. Denoising loop
328
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
329
+ self._num_timesteps = len(timesteps)
330
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
331
+ for i, t in enumerate(timesteps):
332
+ # expand the latents if we are doing classifier free guidance
333
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
334
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
335
+
336
+ # Concatenate image_latents over channels dimension
337
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
338
+ down_block_additional_residuals, mid_block_additional_residuals = self.controlnet(
339
+ latent_model_input,
340
+ timestep=t,
341
+ encoder_hidden_states=image_embeddings,
342
+ added_time_ids=added_time_ids,
343
+ control_cond=cond_em,
344
+ action_type=action_type,
345
+ conditioning_scale=control_condition_scale,
346
+ return_dict=False,
347
+ )
348
+
349
+ # predict the noise residual
350
+ noise_pred = self.unet(
351
+ sample=latent_model_input,
352
+ timestep=t,
353
+ encoder_hidden_states=image_embeddings,
354
+ added_time_ids=added_time_ids,
355
+ down_block_additional_residuals=down_block_additional_residuals,
356
+ mid_block_additional_residuals=mid_block_additional_residuals,
357
+ return_dict=False,
358
+ )[0]
359
+
360
+ # Predict the unconditional noise with the frozen null model
+ if self.null_model is not None:
+ # Time the call with a separate variable so the scheduler timestep `t` is not overwritten
+ start_time = time.time()
+ noise_pred_uncond = self.null_model(
+ latent_model_input,
+ t,
+ encoder_hidden_states=image_embeddings,
+ added_time_ids=added_time_ids,
+ return_dict=False,
+ )[0]
+ logger.debug(f"Computed uncond noise in: {time.time()-start_time:.4f}s")
371
+
372
+ # perform guidance
373
+ if self.do_classifier_free_guidance:
374
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
375
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
376
+ elif self.null_model is not None:
377
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred - noise_pred_uncond)
378
+
379
+ # compute the previous noisy sample x_t -> x_t-1
380
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
381
+
382
+ if callback_on_step_end is not None:
383
+ callback_kwargs = {}
384
+ for k in callback_on_step_end_tensor_inputs:
385
+ callback_kwargs[k] = locals()[k]
386
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
387
+
388
+ latents = callback_outputs.pop("latents", latents)
389
+
390
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
391
+ progress_bar.update()
392
+
393
+ if not output_type == "latent":
394
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
395
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
396
+ else:
397
+ frames = latents
398
+
399
+ # cast back to fp16 if needed
400
+ if needs_upcasting:
401
+ self.vae.to(dtype=torch.float16)
402
+
403
+ self.maybe_free_model_hooks()
404
+
405
+ if not return_dict:
406
+ return frames
407
+
408
+ return StableVideoDiffusionPipelineOutput(frames=frames)
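End to end, the pipeline above is driven roughly as in the sketch below. This is an illustrative sketch rather than the repository's documented entry point (`run_gen_videos.py` is the actual script): the checkpoint locations, class id, and image sizes are placeholders, and `cond_images` stands for the preprocessed per-frame bbox renderings produced by the dataset code.

```python
import torch
from PIL import Image
from src.models import UNetSpatioTemporalConditionModel, ControlNetModel
from src.pipelines import StableVideoControlPipeline

# Placeholder locations; the custom UNet class is loaded explicitly so that the
# ControlNet residual kwargs in its forward are available.
controlnet = ControlNetModel.from_pretrained("path/to/trained_controlnet")
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", subfolder="unet"
)
pipe = StableVideoControlPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    unet=unet,
    controlnet=controlnet,
).to("cuda")

first_frame = Image.open("first_frame.png").convert("RGB")
cond_images = torch.rand(1, 25, 3, 320, 512)  # bbox-condition frames, preprocessed as in the dataset code
action_type = torch.tensor([2])               # dataset-specific action/accident class id (placeholder)

video = pipe(
    image=first_frame,
    cond_images=cond_images,
    action_type=action_type,
    height=320, width=512, num_frames=25,
    min_guidance_scale=1.0, max_guidance_scale=3.0,
    control_condition_scale=1.0,
    decode_chunk_size=8,
).frames[0]
```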
src/pipelines/pipeline_video_control_factor_guidance.py ADDED
@@ -0,0 +1,615 @@
1
+ import torch
2
+ from typing import Callable, Dict, List, Optional, Union
3
+ import PIL.Image
4
+ from einops import rearrange
5
+ import numpy as np
6
+
7
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
8
+ tensor2vid,
9
+ StableVideoDiffusionPipelineOutput,
10
+ _append_dims,
11
+ EXAMPLE_DOC_STRING
12
+ )
13
+ from diffusers import StableVideoDiffusionPipeline as StableVideoDiffusionPipeline_original
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
+ from diffusers.utils import logging, replace_example_docstring
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+
18
+ from src.models import UNetSpatioTemporalConditionModel, ControlNetModel
19
+ from diffusers.models import AutoencoderKLTemporalDecoder
20
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
21
+ from diffusers import EulerDiscreteScheduler
22
+ from diffusers.image_processor import VaeImageProcessor
23
+
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+
27
+ PipelineImageInput = Union[
28
+ PIL.Image.Image,
29
+ np.ndarray,
30
+ torch.FloatTensor,
31
+ List[PIL.Image.Image],
32
+ List[np.ndarray],
33
+ List[torch.FloatTensor],
34
+ ]
35
+
36
+ class StableVideoControlFactorGuidancePipeline(StableVideoDiffusionPipeline_original):
37
+
38
+ def __init__(
39
+ self,
40
+ vae: AutoencoderKLTemporalDecoder,
41
+ image_encoder: CLIPVisionModelWithProjection,
42
+ unet: UNetSpatioTemporalConditionModel,
43
+ controlnet: ControlNetModel,
44
+ scheduler: EulerDiscreteScheduler,
45
+ feature_extractor: CLIPImageProcessor,
46
+ null_model: UNetSpatioTemporalConditionModel,
47
+ ):
48
+ # calling the super class constructors without calling StableVideoDiffusionPipeline_original's
49
+ DiffusionPipeline.__init__(self)
50
+
51
+ self.register_modules(
52
+ vae=vae,
53
+ image_encoder=image_encoder,
54
+ controlnet=controlnet,
55
+ unet=unet,
56
+ scheduler=scheduler,
57
+ feature_extractor=feature_extractor,
58
+ null_model=null_model,
59
+ )
60
+
61
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
62
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
63
+
64
+ def check_inputs(self, image, cond_images, height, width):
65
+ if (
66
+ not isinstance(image, torch.Tensor)
67
+ and not isinstance(image, PIL.Image.Image)
68
+ and not isinstance(image, list)
69
+ ):
70
+ raise ValueError(
71
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
72
+ f" {type(image)}"
73
+ )
74
+ if not isinstance(cond_images, torch.Tensor):
75
+ raise ValueError(
76
+ "`cond_images` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
77
+ f" {type(cond_images)}"
78
+ )
79
+
80
+ if height % 8 != 0 or width % 8 != 0:
81
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
82
+
83
+ def _get_add_time_ids(
84
+ self,
85
+ fps: int,
86
+ motion_bucket_id: int,
87
+ noise_aug_strength: float,
88
+ dtype: torch.dtype,
89
+ batch_size: int,
90
+ num_videos_per_prompt: int,
91
+ ):
92
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
93
+
94
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
95
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
96
+
97
+ if expected_add_embed_dim != passed_add_embed_dim:
98
+ raise ValueError(
99
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
100
+ )
101
+
102
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
103
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
104
+
105
+ return add_time_ids
106
+
107
+ def _encode_image(
108
+ self,
109
+ image: PipelineImageInput,
110
+ device: Union[str, torch.device],
111
+ num_videos_per_prompt: int,
112
+ ) -> torch.FloatTensor:
113
+ dtype = next(self.image_encoder.parameters()).dtype
114
+
115
+ if not isinstance(image, torch.Tensor):
116
+ image = self.image_processor.pil_to_numpy(image)
117
+ image = self.image_processor.numpy_to_pt(image)
118
+
119
+ # We normalize the image before resizing to match with the original implementation.
120
+ # Then we unnormalize it after resizing.
121
+ image = image * 2.0 - 1.0
122
+ image = _resize_with_antialiasing(image, (224, 224))
123
+ image = (image + 1.0) / 2.0
124
+
125
+ # Normalize the image for CLIP input
126
+ image = self.feature_extractor(
127
+ images=image,
128
+ do_normalize=True,
129
+ do_center_crop=False,
130
+ do_resize=False,
131
+ do_rescale=False,
132
+ return_tensors="pt",
133
+ ).pixel_values
134
+
135
+ image = image.to(device=device, dtype=dtype)
136
+ image_embeddings = self.image_encoder(image).image_embeds
137
+ image_embeddings = image_embeddings.unsqueeze(1)
138
+
139
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
140
+ bs_embed, seq_len, _ = image_embeddings.shape
141
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
142
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
143
+
144
+ null_image_embeddings = torch.zeros_like(image_embeddings)
145
+
146
+ return image_embeddings, null_image_embeddings
147
+
148
+ def _encode_vae_image(
149
+ self,
150
+ image: torch.Tensor,
151
+ device: Union[str, torch.device],
152
+ num_videos_per_prompt: int,
153
+ ):
154
+ image = image.to(device=device)
155
+ image_latents = self.vae.encode(image).latent_dist.mode()
156
+
157
+ # duplicate image_latents for each generation per prompt, using mps friendly method
158
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
159
+ null_image_latents = torch.zeros_like(image_latents)
160
+
161
+ return image_latents, null_image_latents
162
+
163
+ def _encode_vae_condition(
164
+ self,
165
+ cond_image: torch.tensor,
166
+ device: Union[str, torch.device],
167
+ num_videos_per_prompt: int,
168
+ bbox_mask_frames: List[bool] = None
169
+ ):
170
+ video_length = cond_image.shape[1]
171
+ cond_image = cond_image.to(device=device)
172
+ cond_image = cond_image.to(dtype=self.vae.dtype)
173
+
174
+ if cond_image.shape[2] == 3:
175
+ cond_image = rearrange(cond_image, "b f c h w -> (b f) c h w")
176
+ cond_em = self.vae.encode(cond_image).latent_dist.mode()
177
+ cond_em = rearrange(cond_em, "(b f) c h w -> b f c h w", f=video_length)
178
+ else:
179
+ assert cond_image.shape[2] == 4, "The input tensor should have 3 or 4 channels. 3 for frames and 4 for latents."
180
+ cond_em = cond_image
181
+
182
+ # duplicate cond_em for each generation per prompt, using mps friendly method
183
+ cond_em = cond_em.repeat(num_videos_per_prompt, 1, 1, 1, 1)
184
+
185
+ # Bbox conditioning masking during inference (requiring the model to predict behaviour instead)
186
+ if bbox_mask_frames is not None:
187
+ mask_cond = torch.tensor(bbox_mask_frames, device=cond_em.device).view(num_videos_per_prompt, video_length, 1, 1, 1)
188
+ null_embedding = self.controlnet.bbox_null_embedding.repeat(num_videos_per_prompt, video_length, 1, 1, 1)
189
+ cond_em = torch.where(mask_cond, null_embedding, cond_em)
190
+
191
+ null_cond_em = self.controlnet.bbox_null_embedding.repeat(num_videos_per_prompt, video_length, 1, 1, 1)
192
+
193
+ return cond_em, null_cond_em
194
+
195
+
196
+ @torch.no_grad()
197
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
198
+ def __call__(
199
+ self,
200
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
201
+ cond_images: torch.FloatTensor = None,
202
+ bbox_mask_frames: List[bool] = None,
203
+ action_type: torch.FloatTensor = None,
204
+ height: int = 576,
205
+ width: int = 1024,
206
+ num_frames: Optional[int] = None,
207
+ num_inference_steps: int = 25,
208
+ min_guidance_scale_img: float = 1.0,
209
+ max_guidance_scale_img: float = 3.0,
210
+ min_guidance_scale_action: float = 1.0,
211
+ max_guidance_scale_action: float = 3.0,
212
+ min_guidance_scale_bbox: float = 1.0,
213
+ max_guidance_scale_bbox: float = 3.0,
214
+ control_condition_scale: float=1.0,
215
+ fps: int = 7,
216
+ motion_bucket_id: int = 127,
217
+ noise_aug_strength: float = 0.02,
218
+ decode_chunk_size: Optional[int] = None,
219
+ num_videos_per_prompt: Optional[int] = 1,
220
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
221
+ latents: Optional[torch.FloatTensor] = None,
222
+ output_type: Optional[str] = "pil",
223
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
224
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
225
+ return_dict: bool = True,
226
+ ):
227
+ r"""
228
+ The call function to the pipeline for generation.
229
+
230
+ Args:
231
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
232
+ Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
233
+ 1]`.
234
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
235
+ The height in pixels of the generated image.
236
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
237
+ The width in pixels of the generated image.
238
+ num_frames (`int`, *optional*):
239
+ The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
240
+ `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
241
+ num_inference_steps (`int`, *optional*, defaults to 25):
242
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
243
+ expense of slower inference. This parameter is modulated by `strength`.
244
+ min_guidance_scale_img / min_guidance_scale_bbox / min_guidance_scale_action (`float`, *optional*, defaults to 1.0):
+ The minimum guidance scale for the image, bounding-box and action conditioning terms. Applied to the first frame.
+ max_guidance_scale_img / max_guidance_scale_bbox / max_guidance_scale_action (`float`, *optional*, defaults to 3.0):
+ The maximum guidance scale for the image, bounding-box and action conditioning terms. Applied to the last frame.
248
+ fps (`int`, *optional*, defaults to 7):
249
+ Frames per second. The rate at which the generated images shall be exported to a video after
250
+ generation. Note that Stable Video Diffusion's UNet was micro-conditioned on fps-1 during training.
251
+ motion_bucket_id (`int`, *optional*, defaults to 127):
252
+ Used for conditioning the amount of motion for the generation. The higher the number the more motion
253
+ will be in the video.
254
+ noise_aug_strength (`float`, *optional*, defaults to 0.02):
255
+ The amount of noise added to the init image, the higher it is the less the video will look like the
256
+ init image. Increase it for more motion.
257
+ action_type (`torch.FloatTensor`, *optional*, defaults to None):
258
+ The action type to condition the generation. These features are used by the ControlNet
259
+ to influence the generation process. The features should be of shape `[batch_size, 1]`.
260
+ decode_chunk_size (`int`, *optional*):
261
+ The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
262
+ expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
263
+ For lower memory usage, reduce `decode_chunk_size`.
264
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
265
+ The number of videos to generate per prompt.
266
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
267
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
268
+ generation deterministic.
269
+ latents (`torch.FloatTensor`, *optional*):
270
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
271
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
272
+ tensor is generated by sampling using the supplied random `generator`.
273
+ output_type (`str`, *optional*, defaults to `"pil"`):
274
+ The output format of the generated image. Choose between `pil`, `np` or `pt`.
275
+ callback_on_step_end (`Callable`, *optional*):
276
+ A function that is called at the end of each denoising step during inference. The function is called
277
+ with the following arguments:
278
+ `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
279
+ `callback_kwargs` will include a list of all tensors as specified by
280
+ `callback_on_step_end_tensor_inputs`.
281
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
282
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
283
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
284
+ `._callback_tensor_inputs` attribute of your pipeline class.
285
+ return_dict (`bool`, *optional*, defaults to `True`):
286
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
287
+ plain tuple.
288
+
289
+ Examples:
290
+
291
+ Returns:
292
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
293
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
294
+ returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`)
295
+ is returned.
296
+ """
297
+ # 0. Default height and width to unet
298
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
299
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
300
+
301
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
302
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
303
+
304
+ # 1. Check inputs. Raise error if not correct
305
+ self.check_inputs(image, cond_images, height, width)
306
+
307
+ # 2. Define call parameters
308
+ if isinstance(image, PIL.Image.Image):
309
+ batch_size = 1
310
+ elif isinstance(image, list):
311
+ batch_size = len(image)
312
+ else:
313
+ batch_size = image.shape[0]
314
+ device = self._execution_device
315
+ vae_device = self.vae.device
316
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
317
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
318
+ # corresponds to doing no classifier free guidance.
319
+ # self._guidance_scale = max_guidance_scale
320
+
321
+ # 3. Encode input image
322
+ image_embeddings, null_image_embeddings = self._encode_image(image, device, num_videos_per_prompt)
323
+
324
+ # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
325
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
326
+ fps = fps - 1
327
+
328
+ # 4. Encode input image using VAE
329
+ image = self.image_processor.preprocess(image, height=height, width=width).to(device)
330
+ noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
331
+ image = image + noise_aug_strength * noise
332
+
333
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
334
+ if needs_upcasting:
335
+ self.vae.to(dtype=torch.float32)
336
+
337
+ image_latents, null_image_latents = self._encode_vae_image(
338
+ image,
339
+ device=device,
340
+ num_videos_per_prompt=num_videos_per_prompt,
341
+ )
342
+ image_latents = image_latents.to(image_embeddings.dtype)
343
+ null_image_latents = null_image_latents.to(image_embeddings.dtype)
344
+
345
+ # Repeat the image latents for each frame so we can concatenate them with the noise
346
+ # image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
347
+ image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
348
+ null_image_latents = null_image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
349
+
350
+ # 5. Get Added Time IDs
351
+ added_time_ids = self._get_add_time_ids(
352
+ fps,
353
+ motion_bucket_id,
354
+ noise_aug_strength,
355
+ image_embeddings.dtype,
356
+ batch_size,
357
+ num_videos_per_prompt,
358
+ )
359
+ added_time_ids = added_time_ids.to(device)
360
+
361
+ # TODO: reshape time ids for factor guidance
362
+
363
+ # 6. Prepare timesteps
364
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
365
+ timesteps = self.scheduler.timesteps
366
+
367
+ # 7a. Prepare latent variables
368
+ num_channels_latents = self.unet.config.out_channels*2
369
+ latents = self.prepare_latents(
370
+ batch_size * num_videos_per_prompt,
371
+ num_frames,
372
+ num_channels_latents,
373
+ height,
374
+ width,
375
+ image_embeddings.dtype,
376
+ device,
377
+ generator,
378
+ latents,
379
+ )
380
+
381
+ # 7b. Prepare control latent embeds
382
+ if cond_images is not None:
383
+ cond_em, null_cond_em = self._encode_vae_condition(cond_images,
384
+ device,
385
+ num_videos_per_prompt,
386
+ bbox_mask_frames=bbox_mask_frames)
387
+ cond_em = cond_em.to(image_embeddings.dtype)
388
+ null_cond_em = null_cond_em.to(image_embeddings.dtype)
389
+ else:
390
+ cond_em = None
391
+ null_cond_em = None
392
+
393
+ # 7c. Prepare action features
394
+ if action_type is not None:
395
+ action_type, null_action_type = action_type.unsqueeze(0), torch.zeros_like(action_type).unsqueeze(0)
396
+ else:
397
+ action_type = None
398
+ null_action_type = None
399
+
400
+ # 8. Prepare guidance scales
401
+ guidance_scale_img = torch.linspace(min_guidance_scale_img, max_guidance_scale_img, num_frames).unsqueeze(0)
402
+ guidance_scale_img = guidance_scale_img.to(device, latents.dtype)
403
+ guidance_scale_img = guidance_scale_img.repeat(batch_size * num_videos_per_prompt, 1)
404
+ guidance_scale_img = _append_dims(guidance_scale_img, latents.ndim)
405
+
406
+ guidance_scale_action = torch.linspace(min_guidance_scale_action, max_guidance_scale_action, num_frames).unsqueeze(0)
407
+ guidance_scale_action = guidance_scale_action.to(device, latents.dtype)
408
+ guidance_scale_action = guidance_scale_action.repeat(batch_size * num_videos_per_prompt, 1)
409
+ guidance_scale_action = _append_dims(guidance_scale_action, latents.ndim)
410
+
411
+ guidance_scale_bbox = torch.linspace(min_guidance_scale_bbox, max_guidance_scale_bbox, num_frames).unsqueeze(0)
412
+ guidance_scale_bbox = guidance_scale_bbox.to(device, latents.dtype)
413
+ guidance_scale_bbox = guidance_scale_bbox.repeat(batch_size * num_videos_per_prompt, 1)
414
+ guidance_scale_bbox = _append_dims(guidance_scale_bbox, latents.ndim)
415
+
416
+ # Build the tensors to batch the different levels of conditioning (used for the factored CFG)
417
+ # [image_and_bbox_embeddings, image_and_bbox_and_action_embeddings]
418
+ image_embeddings = torch.cat([image_embeddings, image_embeddings])
419
+ image_latents = torch.cat([image_latents, image_latents])
420
+ cond_em = torch.cat([cond_em, cond_em])
421
+ action_type = torch.cat([null_action_type, action_type])
422
+ added_time_ids = torch.cat([added_time_ids] * 2)
423
+ latents = torch.cat([latents] * 2)
424
+
425
+ # 9. Denoising loop
426
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
427
+ self._num_timesteps = len(timesteps)
428
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
429
+ for i, t in enumerate(timesteps):
430
+ latent_model_input = latents
431
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
432
+
433
+ # print(latent_model_input.shape, image_latents.shape)
434
+
435
+ # Concatenate image_latents over channels dimension
436
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
437
+ down_block_additional_residuals, mid_block_additional_residuals = self.controlnet(
438
+ latent_model_input,
439
+ timestep=t,
440
+ encoder_hidden_states=image_embeddings,
441
+ added_time_ids=added_time_ids,
442
+ control_cond=cond_em,
443
+ action_type=action_type,
444
+ conditioning_scale=control_condition_scale,
445
+ return_dict=False,
446
+ )
447
+
448
+ # predict the noise residual
449
+ noise_pred = self.unet(
450
+ sample=latent_model_input,
451
+ timestep=t,
452
+ encoder_hidden_states=image_embeddings,
453
+ added_time_ids=added_time_ids,
454
+ down_block_additional_residuals=down_block_additional_residuals,
455
+ mid_block_additional_residuals=mid_block_additional_residuals,
456
+ return_dict=False,
457
+ )[0]
458
+
459
+ # Predict unconditional noise
460
+ noise_pred_cond_img = self.null_model(
461
+ latent_model_input,
462
+ t,
463
+ encoder_hidden_states=image_embeddings,
464
+ added_time_ids=added_time_ids,
465
+ return_dict=False,
466
+ )[0]
467
+
468
+ # Perform factored CFG
469
+ # NOTE: Currently discarding the unconditional noise prediction from the finetuned model
470
+ noise_pred_cond_img_bbox, noise_pred_cond_all = noise_pred.chunk(2)
471
+
472
+ # NOTE: `noise_pred_uncond` is technically the same as `noise_pred_cond_img` since they both condition on the image.
473
+ # Therefore we could probably remove `noise_pred_cond_img` and get similar performances and faster inference
474
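+ # Factored classifier-free guidance: with eps_img the image-only prediction from the frozen null model,
+ # eps_img_bbox the image+bbox prediction and eps_all the image+bbox+action prediction, the update below is
+ # eps = eps_img + w_bbox * (eps_img_bbox - eps_img) + w_action * (eps_all - eps_img_bbox),
+ # so each guidance scale only amplifies the signal contributed by its own conditioning level.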
+ noise_pred = noise_pred_cond_img \
475
+ + guidance_scale_bbox * (noise_pred_cond_img_bbox - noise_pred_cond_img) \
476
+ + guidance_scale_action * (noise_pred_cond_all - noise_pred_cond_img_bbox)
477
+
478
+ # compute the previous noisy sample x_t -> x_t-1
479
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
480
+ # print("latents", latents.shape)
481
+
482
+ if callback_on_step_end is not None:
483
+ callback_kwargs = {}
484
+ for k in callback_on_step_end_tensor_inputs:
485
+ callback_kwargs[k] = locals()[k]
486
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
487
+
488
+ latents = callback_outputs.pop("latents", latents)
489
+
490
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
491
+ progress_bar.update()
492
+
493
+ if not output_type == "latent":
494
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
495
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
496
+ else:
497
+ frames = latents
498
+
499
+ # cast back to fp16 if needed
500
+ if needs_upcasting:
501
+ self.vae.to(dtype=torch.float16)
502
+
503
+ self.maybe_free_model_hooks()
504
+
505
+ if not return_dict:
506
+ return frames
507
+
508
+ return StableVideoDiffusionPipelineOutput(frames=frames)
509
+
510
+
511
+ # TODO: Some helper functions from Stable Video Diffusion that we could move elsewhere
512
+ def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
513
+ h, w = input.shape[-2:]
514
+ factors = (h / size[0], w / size[1])
515
+
516
+ # First, we have to determine sigma
517
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
518
+ sigmas = (
519
+ max((factors[0] - 1.0) / 2.0, 0.001),
520
+ max((factors[1] - 1.0) / 2.0, 0.001),
521
+ )
522
+
523
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
524
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
525
+ # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
526
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
527
+
528
+ # Make sure it is odd
529
+ if (ks[0] % 2) == 0:
530
+ ks = ks[0] + 1, ks[1]
531
+
532
+ if (ks[1] % 2) == 0:
533
+ ks = ks[0], ks[1] + 1
534
+
535
+ input = _gaussian_blur2d(input, ks, sigmas)
536
+
537
+ output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
538
+ return output
539
+
540
+
541
+ def _compute_padding(kernel_size):
542
+ """Compute padding tuple."""
543
+ # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
544
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
545
+ if len(kernel_size) < 2:
546
+ raise AssertionError(kernel_size)
547
+ computed = [k - 1 for k in kernel_size]
548
+
549
+ # for even kernels we need to do asymmetric padding :(
550
+ out_padding = 2 * len(kernel_size) * [0]
551
+
552
+ for i in range(len(kernel_size)):
553
+ computed_tmp = computed[-(i + 1)]
554
+
555
+ pad_front = computed_tmp // 2
556
+ pad_rear = computed_tmp - pad_front
557
+
558
+ out_padding[2 * i + 0] = pad_front
559
+ out_padding[2 * i + 1] = pad_rear
560
+
561
+ return out_padding
562
+
563
+ def _filter2d(input, kernel):
564
+ # prepare kernel
565
+ b, c, h, w = input.shape
566
+ tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
567
+
568
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
569
+
570
+ height, width = tmp_kernel.shape[-2:]
571
+
572
+ padding_shape: list[int] = _compute_padding([height, width])
573
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
574
+
575
+ # kernel and input tensor reshape to align element-wise or batch-wise params
576
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
577
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
578
+
579
+ # convolve the tensor with the kernel.
580
+ output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
581
+
582
+ out = output.view(b, c, h, w)
583
+ return out
584
+
585
+
586
+ def _gaussian(window_size: int, sigma):
587
+ if isinstance(sigma, float):
588
+ sigma = torch.tensor([[sigma]])
589
+
590
+ batch_size = sigma.shape[0]
591
+
592
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
593
+
594
+ if window_size % 2 == 0:
595
+ x = x + 0.5
596
+
597
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
598
+
599
+ return gauss / gauss.sum(-1, keepdim=True)
600
+
601
+
602
+ def _gaussian_blur2d(input, kernel_size, sigma):
603
+ if isinstance(sigma, tuple):
604
+ sigma = torch.tensor([sigma], dtype=input.dtype)
605
+ else:
606
+ sigma = sigma.to(dtype=input.dtype)
607
+
608
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
609
+ bs = sigma.shape[0]
610
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
611
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
612
+ out_x = _filter2d(input, kernel_x[..., None, :])
613
+ out = _filter2d(out_x, kernel_y[..., None])
614
+
615
+ return out
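A minimal sketch of the factorized guidance combination used in the denoising loop above, assuming the three conditional noise predictions are already available (the standalone helper below is illustrative; the pipeline computes this inline with the same variable names). With both scales set to 1.0 the expression collapses to `noise_pred_cond_all`, i.e. plain conditional sampling without extra guidance.

```python
import torch

def combine_factor_guidance(
    noise_pred_cond_img: torch.Tensor,       # conditioned on the first frame only
    noise_pred_cond_img_bbox: torch.Tensor,  # conditioned on first frame + bbox renderings
    noise_pred_cond_all: torch.Tensor,       # conditioned on first frame + bboxes + action
    guidance_scale_bbox: float = 2.0,
    guidance_scale_action: float = 2.0,
) -> torch.Tensor:
    """Compose guidance from two conditioning factors (hypothetical helper)."""
    return (
        noise_pred_cond_img
        + guidance_scale_bbox * (noise_pred_cond_img_bbox - noise_pred_cond_img)
        + guidance_scale_action * (noise_pred_cond_all - noise_pred_cond_img_bbox)
    )
```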
src/pipelines/pipeline_video_control_nullmodel.py ADDED
@@ -0,0 +1,406 @@
1
+ import torch
2
+ from typing import Callable, Dict, List, Optional, Union
3
+ import PIL.Image
4
+ from einops import rearrange
5
+
6
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
7
+ tensor2vid,
8
+ StableVideoDiffusionPipelineOutput,
9
+ _append_dims,
10
+ EXAMPLE_DOC_STRING
11
+ )
12
+ from diffusers import StableVideoDiffusionPipeline as StableVideoDiffusionPipeline_original
13
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
14
+ from diffusers.utils import logging, replace_example_docstring
15
+ from diffusers.utils.torch_utils import randn_tensor
16
+
17
+ from src.models import UNetSpatioTemporalConditionModel, ControlNetModel
18
+ from diffusers.models import AutoencoderKLTemporalDecoder
19
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
20
+ from diffusers import EulerDiscreteScheduler
21
+ from diffusers.image_processor import VaeImageProcessor
22
+
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+ class StableVideoControlNullModelPipeline(StableVideoDiffusionPipeline_original):
26
+
27
+ def __init__(
28
+ self,
29
+ vae: AutoencoderKLTemporalDecoder,
30
+ image_encoder: CLIPVisionModelWithProjection,
31
+ unet: UNetSpatioTemporalConditionModel,
32
+ controlnet: ControlNetModel,
33
+ scheduler: EulerDiscreteScheduler,
34
+ feature_extractor: CLIPImageProcessor,
35
+ null_model: UNetSpatioTemporalConditionModel,
36
+ ):
37
+ # Call the DiffusionPipeline constructor directly, bypassing StableVideoDiffusionPipeline_original's __init__
38
+ DiffusionPipeline.__init__(self)
39
+
40
+ self.register_modules(
41
+ vae=vae,
42
+ image_encoder=image_encoder,
43
+ controlnet=controlnet,
44
+ unet=unet,
45
+ scheduler=scheduler,
46
+ feature_extractor=feature_extractor,
47
+ null_model=null_model,
48
+ )
49
+
50
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
51
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
52
+
53
+ def check_inputs(self, image, cond_images, height, width):
54
+ if (
55
+ not isinstance(image, torch.Tensor)
56
+ and not isinstance(image, PIL.Image.Image)
57
+ and not isinstance(image, list)
58
+ ):
59
+ raise ValueError(
60
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
61
+ f" {type(image)}"
62
+ )
63
+ if not isinstance(cond_images, torch.Tensor):
64
+ raise ValueError(
65
+ "`cond_images` has to be of type `torch.Tensor` but is"
66
+ f" {type(cond_images)}"
67
+ )
68
+
69
+ if height % 8 != 0 or width % 8 != 0:
70
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
71
+
72
+
73
+ def _encode_vae_condition(
74
+ self,
75
+ cond_image: torch.tensor,
76
+ device: Union[str, torch.device],
77
+ num_videos_per_prompt: int,
78
+ do_classifier_free_guidance: bool,
79
+ bbox_mask_frames: List[bool] = None
80
+ ):
81
+ video_length = cond_image.shape[1]
82
+ cond_image = cond_image.to(device=device)
83
+ cond_image = cond_image.to(dtype=self.vae.dtype)
84
+
85
+ if cond_image.shape[2] == 3:
86
+ cond_image = rearrange(cond_image, "b f c h w -> (b f) c h w")
87
+ cond_em = self.vae.encode(cond_image).latent_dist.mode()
88
+ cond_em = rearrange(cond_em, "(b f) c h w -> b f c h w", f=video_length)
89
+ else:
90
+ assert cond_image.shape[2] == 4, "The input tensor should have 3 or 4 channels. 3 for frames and 4 for latents."
91
+ cond_em = cond_image
92
+
93
+ # duplicate cond_em for each generation per prompt, using mps friendly method
94
+ cond_em = cond_em.repeat(num_videos_per_prompt, 1, 1, 1, 1)
95
+
96
+ # Bbox conditioning masking during inference (requiring the model to predict behaviour instead)
97
+ if bbox_mask_frames is not None:
98
+ mask_cond = torch.tensor(bbox_mask_frames, device=cond_em.device).view(num_videos_per_prompt, video_length, 1, 1, 1)
99
+ null_embedding = self.controlnet.bbox_null_embedding.repeat(num_videos_per_prompt, video_length, 1, 1, 1)
100
+ cond_em = torch.where(mask_cond, null_embedding, cond_em)
101
+
102
+ if do_classifier_free_guidance:
103
+ # negative_cond_em = torch.zeros_like(cond_em)
104
+ negative_cond_em = self.controlnet.bbox_null_embedding.repeat(num_videos_per_prompt, video_length, 1, 1, 1)
105
+
106
+ # For classifier free guidance, we need to do two forward passes.
107
+ # Here we concatenate the unconditional and text embeddings into a single batch
108
+ # to avoid doing two forward passes
109
+ cond_em = torch.cat([negative_cond_em, cond_em])
110
+
111
+ return cond_em
112
+
113
+ @property
114
+ def do_classifier_free_guidance(self):
115
+ return False
116
+ # if isinstance(self.guidance_scale, (int, float)):
117
+ # return self.guidance_scale > 1
118
+ # return self.guidance_scale.max() > 1
119
+
120
+ @torch.no_grad()
121
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
122
+ def __call__(
123
+ self,
124
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
125
+ cond_images: torch.FloatTensor = None,
126
+ bbox_mask_frames: List[bool] = None,
127
+ action_type: torch.FloatTensor = None,
128
+ height: int = 576,
129
+ width: int = 1024,
130
+ num_frames: Optional[int] = None,
131
+ num_inference_steps: int = 25,
132
+ min_guidance_scale: float = 1.0,
133
+ max_guidance_scale: float = 3.0,
134
+ control_condition_scale: float=1.0,
135
+ fps: int = 7,
136
+ motion_bucket_id: int = 127,
137
+ noise_aug_strength: float = 0.02,
138
+ decode_chunk_size: Optional[int] = None,
139
+ num_videos_per_prompt: Optional[int] = 1,
140
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
141
+ latents: Optional[torch.FloatTensor] = None,
142
+ output_type: Optional[str] = "pil",
143
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
144
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
145
+ return_dict: bool = True,
146
+ ):
147
+ r"""
148
+ The call function to the pipeline for generation.
149
+
150
+ Args:
151
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
152
+ Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
153
+ 1]`.
154
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
155
+ The height in pixels of the generated image.
156
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
157
+ The width in pixels of the generated image.
158
+ num_frames (`int`, *optional*):
159
+ The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
160
+ `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
161
+ num_inference_steps (`int`, *optional*, defaults to 25):
162
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
163
+ expense of slower inference. This parameter is modulated by `strength`.
164
+ min_guidance_scale (`float`, *optional*, defaults to 1.0):
165
+ The minimum guidance scale. Used for the classifier free guidance with first frame.
166
+ max_guidance_scale (`float`, *optional*, defaults to 3.0):
167
+ The maximum guidance scale. Used for the classifier free guidance with last frame.
168
+ fps (`int`, *optional*, defaults to 7):
169
+ generation. Note that Stable Video Diffusion's UNet was micro-conditioned on fps-1 during training.
170
+ generation. Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
171
+ motion_bucket_id (`int`, *optional*, defaults to 127):
172
+ Used for conditioning the amount of motion for the generation. The higher the number the more motion
173
+ will be in the video.
174
+ noise_aug_strength (`float`, *optional*, defaults to 0.02):
175
+ The amount of noise added to the init image, the higher it is the less the video will look like the
176
+ init image. Increase it for more motion.
177
+ action_type (`torch.FloatTensor`, *optional*, defaults to None):
178
+ The action type to condition the generation. These features are used by the ControlNet
179
+ to influence the generation process. The features should be of shape `[batch_size, 1]`.
180
+ decode_chunk_size (`int`, *optional*):
181
+ The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
182
+ expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
183
+ For lower memory usage, reduce `decode_chunk_size`.
184
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
185
+ The number of videos to generate per prompt.
186
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
187
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
188
+ generation deterministic.
189
+ latents (`torch.FloatTensor`, *optional*):
190
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
191
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
192
+ tensor is generated by sampling using the supplied random `generator`.
193
+ output_type (`str`, *optional*, defaults to `"pil"`):
194
+ The output format of the generated image. Choose between `pil`, `np` or `pt`.
195
+ callback_on_step_end (`Callable`, *optional*):
196
+ A function that is called at the end of each denoising step during inference. The function is called
197
+ with the following arguments:
198
+ `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
199
+ `callback_kwargs` will include a list of all tensors as specified by
200
+ `callback_on_step_end_tensor_inputs`.
201
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
202
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
203
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
204
+ `._callback_tensor_inputs` attribute of your pipeline class.
205
+ return_dict (`bool`, *optional*, defaults to `True`):
206
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
207
+ plain tuple.
208
+
209
+ Examples:
210
+
211
+ Returns:
212
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
213
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
214
+ returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`)
215
+ is returned.
216
+ """
217
+ # 0. Default height and width to unet
218
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
219
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
220
+
221
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
222
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
223
+
224
+ # 1. Check inputs. Raise error if not correct
225
+ self.check_inputs(image, cond_images, height, width)
226
+
227
+ # 2. Define call parameters
228
+ if isinstance(image, PIL.Image.Image):
229
+ batch_size = 1
230
+ elif isinstance(image, list):
231
+ batch_size = len(image)
232
+ else:
233
+ batch_size = image.shape[0]
234
+ device = self._execution_device
235
+ vae_device = self.vae.device
236
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
237
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
238
+ # corresponds to doing no classifier free guidance.
239
+ self._guidance_scale = max_guidance_scale
240
+
241
+ # 3. Encode input image
242
+ image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
243
+
244
+ # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
245
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
246
+ fps = fps - 1
247
+
248
+ # 4. Encode input image using VAE
249
+ image = self.image_processor.preprocess(image, height=height, width=width).to(device)
250
+ noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
251
+ image = image + noise_aug_strength * noise
252
+
253
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
254
+ if needs_upcasting:
255
+ self.vae.to(dtype=torch.float32)
256
+
257
+ image_latents = self._encode_vae_image(
258
+ image,
259
+ device=device,
260
+ num_videos_per_prompt=num_videos_per_prompt,
261
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
262
+ )
263
+ image_latents = image_latents.to(image_embeddings.dtype)
264
+
265
+ # Repeat the image latents for each frame so we can concatenate them with the noise
266
+ # image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
267
+ image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
268
+ # 5. Get Added Time IDs
269
+ added_time_ids = self._get_add_time_ids(
270
+ fps,
271
+ motion_bucket_id,
272
+ noise_aug_strength,
273
+ image_embeddings.dtype,
274
+ batch_size,
275
+ num_videos_per_prompt,
276
+ self.do_classifier_free_guidance,
277
+ )
278
+ added_time_ids = added_time_ids.to(device)
279
+
280
+ # 6. Prepare timesteps
281
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
282
+ timesteps = self.scheduler.timesteps
283
+
284
+ # 7a. Prepare latent variables
285
+ num_channels_latents = self.unet.config.out_channels*2
286
+ latents = self.prepare_latents(
287
+ batch_size * num_videos_per_prompt,
288
+ num_frames,
289
+ num_channels_latents,
290
+ height,
291
+ width,
292
+ image_embeddings.dtype,
293
+ device,
294
+ generator,
295
+ latents,
296
+ )
297
+
298
+ # 7b. Prepare control latent embeds
299
+ if not cond_images is None:
300
+ cond_em = self._encode_vae_condition(cond_images,
301
+ device,
302
+ num_videos_per_prompt,
303
+ self.do_classifier_free_guidance,
304
+ bbox_mask_frames=bbox_mask_frames)
305
+ cond_em = cond_em.to(image_embeddings.dtype)
306
+ else:
307
+ cond_em = None
308
+
309
+ # 7c. Prepare action features
310
+ if not action_type is None:
311
+ if self.do_classifier_free_guidance:
312
+ action_type = torch.cat([torch.zeros_like(action_type).unsqueeze(0), action_type.unsqueeze(0)])
313
+ else:
314
+ action_type = None
315
+
316
+ # 8. Prepare guidance scale
317
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
318
+ guidance_scale = guidance_scale.to(device, latents.dtype)
319
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
320
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
321
+
322
+ self._guidance_scale = guidance_scale
323
+
324
+ # 9. Denoising loop
325
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
326
+ self._num_timesteps = len(timesteps)
327
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
328
+ for i, t in enumerate(timesteps):
329
+ # expand the latents if we are doing classifier free guidance
330
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
331
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
332
+
333
+ # print(latent_model_input.shape, image_latents.shape, self.do_classifier_free_guidance)
334
+
335
+ # Concatenate image_latents over channels dimension
336
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
337
+ # latent_model_input_null_model = latent_model_input.clone().detach()
338
+ down_block_additional_residuals, mid_block_additional_residuals = self.controlnet(
339
+ latent_model_input,
340
+ timestep=t,
341
+ encoder_hidden_states=image_embeddings,
342
+ added_time_ids=added_time_ids,
343
+ control_cond=cond_em,
344
+ action_type=action_type,
345
+ conditioning_scale=control_condition_scale,
346
+ return_dict=False,
347
+ )
348
+
349
+ # predict the noise residual
350
+ noise_pred = self.unet(
351
+ sample=latent_model_input,
352
+ timestep=t,
353
+ encoder_hidden_states=image_embeddings,
354
+ added_time_ids=added_time_ids,
355
+ down_block_additional_residuals=down_block_additional_residuals,
356
+ mid_block_additional_residuals=mid_block_additional_residuals,
357
+ return_dict=False,
358
+ )[0]
359
+
360
+ # Predict unconditional noise
361
+ noise_pred_uncond = self.null_model(
362
+ latent_model_input,
363
+ t,
364
+ encoder_hidden_states=image_embeddings,
365
+ added_time_ids=added_time_ids,
366
+ return_dict=False,
367
+ )[0]
368
+
369
+ # perform guidance
370
+ if self.do_classifier_free_guidance:
371
+ _, noise_pred_cond = noise_pred.chunk(2) # NOTE: Currently discarding the unconditional noise prediction from the finetuned model
372
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
373
+ else:
374
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred - noise_pred_uncond)
375
+
376
+ # compute the previous noisy sample x_t -> x_t-1
377
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
378
+ # print("latents", latents.shape)
379
+
380
+ if callback_on_step_end is not None:
381
+ callback_kwargs = {}
382
+ for k in callback_on_step_end_tensor_inputs:
383
+ callback_kwargs[k] = locals()[k]
384
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
385
+
386
+ latents = callback_outputs.pop("latents", latents)
387
+
388
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
389
+ progress_bar.update()
390
+
391
+ if not output_type == "latent":
392
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
393
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
394
+ else:
395
+ frames = latents
396
+
397
+ # cast back to fp16 if needed
398
+ if needs_upcasting:
399
+ self.vae.to(dtype=torch.float16)
400
+
401
+ self.maybe_free_model_hooks()
402
+
403
+ if not return_dict:
404
+ return frames
405
+
406
+ return StableVideoDiffusionPipelineOutput(frames=frames)
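The guidance step of this pipeline differs from standard classifier-free guidance only in where the unconditional prediction comes from: it is produced by the separate, frozen `null_model` UNet rather than by a zeroed-out conditioning pass through the ControlNet-augmented UNet. A minimal sketch of that mixing step (the helper is illustrative, not part of the pipeline):

```python
import torch

def null_model_guidance(
    noise_pred_cond: torch.Tensor,    # prediction from the ControlNet-conditioned UNet
    noise_pred_uncond: torch.Tensor,  # prediction from the frozen null (unconditional) model
    guidance_scale: torch.Tensor,     # per-frame scale, broadcastable to the latent shape
) -> torch.Tensor:
    # Same formula as classifier-free guidance, with the unconditional term
    # supplied by a separate model instead of a negative-conditioning pass.
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
```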
src/pipelines/pipeline_video_diffusion.py ADDED
@@ -0,0 +1,305 @@
1
+ from diffusers import StableVideoDiffusionPipeline as StableVideoDiffusionPipeline_original
2
+ import torch
3
+ from einops import rearrange
4
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
5
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
6
+ from typing import Callable, Dict, List, Tuple, Optional, Union
7
+ import PIL.Image
8
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
9
+ tensor2vid,
10
+ StableVideoDiffusionPipelineOutput,
11
+ _append_dims,
12
+ EXAMPLE_DOC_STRING
13
+ )
14
+
15
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
16
+
17
+ class VideoDiffusionPipeline(StableVideoDiffusionPipeline_original):
18
+
19
+ def _encode_vae_condition(
20
+ self,
21
+ cond_image: torch.tensor,
22
+ device: Union[str, torch.device],
23
+ num_videos_per_prompt: int,
24
+ do_classifier_free_guidance: bool,
25
+ ):
26
+ video_length = cond_image.shape[1]
27
+ cond_image = cond_image.to(device=device)
28
+ cond_image = cond_image.to(dtype=self.vae.dtype)
29
+ cond_image = rearrange(cond_image, "b f c h w -> (b f) c h w")
30
+ cond_em = self.vae.encode(cond_image).latent_dist.mode()
31
+ cond_em = rearrange(cond_em, "(b f) c h w -> b f c h w", f=video_length)
32
+
33
+ # duplicate cond_em for each generation per prompt, using mps friendly method
34
+ cond_em = cond_em.repeat(num_videos_per_prompt, 1, 1, 1, 1)
35
+
36
+ if do_classifier_free_guidance:
37
+ negative_cond_em = torch.zeros_like(cond_em)
38
+
39
+ # For classifier free guidance, we need to do two forward passes.
40
+ # Here we concatenate the unconditional and text embeddings into a single batch
41
+ # to avoid doing two forward passes
42
+ cond_em = torch.cat([negative_cond_em, cond_em])
43
+
44
+ return cond_em
45
+
46
+ def decode_latent_to_video(self, latents,
47
+ decode_chunk_size: Optional[int] = None,
48
+ num_frames: Optional[int] = None,
49
+ output_type: Optional[str] = "pil",):
50
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
51
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
+ return frames
52
+
53
+ @torch.no_grad()
54
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
55
+ def __call__(
56
+ self,
57
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
58
+ bbox_images: Optional[torch.FloatTensor] = None,
59
+ bbox_conditions: Optional[Dict[str, Union[torch.FloatTensor, List[Union[float, int]]]]] = None,
60
+ original_size: Optional[Tuple[int]] = (1242, 375),
61
+ height: int = 576,
62
+ width: int = 1024,
63
+ num_frames: Optional[int] = None,
64
+ num_inference_steps: int = 25,
65
+ min_guidance_scale: float = 1.0,
66
+ max_guidance_scale: float = 3.0,
67
+ fps: int = 7,
68
+ motion_bucket_id: int = 127,
69
+ noise_aug_strength: float = 0.02,
70
+ decode_chunk_size: Optional[int] = None,
71
+ num_videos_per_prompt: Optional[int] = 1,
72
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
73
+ latents: Optional[torch.FloatTensor] = None,
74
+ output_type: Optional[str] = "pil",
75
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
76
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
77
+ return_dict: bool = True,
78
+ num_cond_bbox_frames: int=3,
79
+ ):
80
+ r"""
81
+ The call function to the pipeline for generation.
82
+
83
+ Args:
84
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
85
+ Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
86
+ 1]`.
87
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
88
+ The height in pixels of the generated image.
89
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
90
+ The width in pixels of the generated image.
91
+ num_frames (`int`, *optional*):
92
+ The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
93
+ `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
94
+ num_inference_steps (`int`, *optional*, defaults to 25):
95
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
96
+ expense of slower inference. This parameter is modulated by `strength`.
97
+ min_guidance_scale (`float`, *optional*, defaults to 1.0):
98
+ The minimum guidance scale. Used for the classifier free guidance with first frame.
99
+ max_guidance_scale (`float`, *optional*, defaults to 3.0):
100
+ The maximum guidance scale. Used for the classifier free guidance with last frame.
101
+ fps (`int`, *optional*, defaults to 7):
102
+ Frames per second. The rate at which the generated images shall be exported to a video after
103
+ generation. Note that Stable Video Diffusion's UNet was micro-conditioned on fps-1 during training.
104
+ motion_bucket_id (`int`, *optional*, defaults to 127):
105
+ Used for conditioning the amount of motion for the generation. The higher the number the more motion
106
+ will be in the video.
107
+ noise_aug_strength (`float`, *optional*, defaults to 0.02):
108
+ The amount of noise added to the init image, the higher it is the less the video will look like the
109
+ init image. Increase it for more motion.
110
+ decode_chunk_size (`int`, *optional*):
111
+ The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
112
+ expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
113
+ For lower memory usage, reduce `decode_chunk_size`.
114
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
115
+ The number of videos to generate per prompt.
116
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
117
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
118
+ generation deterministic.
119
+ latents (`torch.FloatTensor`, *optional*):
120
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
121
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
122
+ tensor is generated by sampling using the supplied random `generator`.
123
+ output_type (`str`, *optional*, defaults to `"pil"`):
124
+ The output format of the generated image. Choose between `pil`, `np` or `pt`.
125
+ callback_on_step_end (`Callable`, *optional*):
126
+ A function that is called at the end of each denoising step during inference. The function is called
127
+ with the following arguments:
128
+ `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
129
+ `callback_kwargs` will include a list of all tensors as specified by
130
+ `callback_on_step_end_tensor_inputs`.
131
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
132
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
133
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
134
+ `._callback_tensor_inputs` attribute of your pipeline class.
135
+ return_dict (`bool`, *optional*, defaults to `True`):
136
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
137
+ plain tuple.
138
+
139
+ Examples:
140
+
141
+ Returns:
142
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
143
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
144
+ returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`)
145
+ is returned.
146
+ """
147
+ # 0. Default height and width to unet
148
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
149
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
150
+
151
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
152
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
153
+
154
+ # 1. Check inputs. Raise error if not correct
155
+ self.check_inputs(image, height, width)
156
+
157
+ # 2. Define call parameters
158
+ if isinstance(image, PIL.Image.Image):
159
+ batch_size = 1
160
+ elif isinstance(image, list):
161
+ batch_size = len(image)
162
+ else:
163
+ batch_size = image.shape[0]
164
+ device = self._execution_device
165
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
166
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
167
+ # corresponds to doing no classifier free guidance.
168
+ self._guidance_scale = max_guidance_scale
169
+
170
+ # 3. Encode input image
171
+ image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
172
+
173
+ # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
174
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
175
+ fps = fps - 1
176
+
177
+ # 4. Encode input image using VAE
178
+ image = self.image_processor.preprocess(image, height=height, width=width).to(device)
179
+ noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
180
+ image = image + noise_aug_strength * noise
181
+
182
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
183
+ if needs_upcasting:
184
+ self.vae.to(dtype=torch.float32)
185
+
186
+ image_latents = self._encode_vae_image(
187
+ image,
188
+ device=device,
189
+ num_videos_per_prompt=num_videos_per_prompt,
190
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
191
+ )
192
+ image_latents = image_latents.to(image_embeddings.dtype)
193
+
194
+ # Repeat the image latents for each frame so we can concatenate them with the noise
195
+ # image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
196
+ image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
197
+
198
+ # 7b. Prepare control latent embeds
199
+ if not bbox_images is None:
200
+ cond_latents = self._encode_vae_condition(bbox_images,
201
+ device,
202
+ num_videos_per_prompt,
203
+ self.do_classifier_free_guidance)
204
+ image_latents[:,0:num_cond_bbox_frames,::] = cond_latents[:,0:num_cond_bbox_frames,::]
205
+ image_latents[:,-1,::]=cond_latents[:,-1,::]
206
+
207
+ # 5. Get Added Time IDs
208
+ added_time_ids = self._get_add_time_ids(
209
+ fps,
210
+ motion_bucket_id,
211
+ noise_aug_strength,
212
+ image_embeddings.dtype,
213
+ batch_size,
214
+ num_videos_per_prompt,
215
+ self.do_classifier_free_guidance,
216
+ )
217
+ added_time_ids = added_time_ids.to(device)
218
+
219
+ # 6. Prepare timesteps
220
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
221
+ timesteps = self.scheduler.timesteps
222
+
223
+ # 7. Prepare latent variables
224
+ num_channels_latents = self.unet.config.out_channels*2
225
+ latents = self.prepare_latents(
226
+ batch_size * num_videos_per_prompt,
227
+ num_frames,
228
+ num_channels_latents,
229
+ height,
230
+ width,
231
+ image_embeddings.dtype,
232
+ device,
233
+ generator,
234
+ latents,
235
+ )
236
+
237
+ # 8. Prepare guidance scale
238
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
239
+ guidance_scale = guidance_scale.to(device, latents.dtype)
240
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
241
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
242
+
243
+ self._guidance_scale = guidance_scale
244
+
245
+ # 9. Denoising loop
246
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
247
+ self._num_timesteps = len(timesteps)
248
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
249
+ for i, t in enumerate(timesteps):
250
+ # expand the latents if we are doing classifier free guidance
251
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
252
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
253
+
254
+ # Concatenate image_latents over channels dimension
255
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
256
+
257
+ # predict the noise residual
258
+ noise_pred = self.unet(
259
+ latent_model_input,
260
+ t,
261
+ encoder_hidden_states=image_embeddings,
262
+ added_time_ids=added_time_ids,
263
+ return_dict=False,
264
+ )[0]
265
+
266
+ # perform guidance
267
+ if self.do_classifier_free_guidance:
268
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
269
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
270
+
271
+ # compute the previous noisy sample x_t -> x_t-1
272
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
273
+
274
+ if callback_on_step_end is not None:
275
+ callback_kwargs = {}
276
+ for k in callback_on_step_end_tensor_inputs:
277
+ callback_kwargs[k] = locals()[k]
278
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
279
+
280
+ latents = callback_outputs.pop("latents", latents)
281
+
282
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
283
+ progress_bar.update()
284
+
285
+ if not output_type == "latent":
286
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
287
+ frames = torch.clamp(frames, -1, 1)
288
+ # Not sure why this code was here; it zeroed out frames whose minimum value was above -0.9:
289
+ # for i in range(frames.shape[2]):
290
+ # frame = frames[:, :, i]
291
+ # if frame.min() > -0.9:
292
+ # frames[:,:,i] = torch.zeros_like(frame)
293
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
294
+ else:
295
+ frames = latents
296
+
297
+ # cast back to fp16 if needed
298
+ if needs_upcasting:
299
+ self.vae.to(dtype=torch.float16)
300
+ self.maybe_free_model_hooks()
301
+
302
+ if not return_dict:
303
+ return frames
304
+
305
+ return StableVideoDiffusionPipelineOutput(frames=frames)
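The bbox conditioning in this pipeline works by overwriting part of the per-frame conditioning latents before they are concatenated with the noise: the first `num_cond_bbox_frames` frames and the final frame take the VAE-encoded bbox renderings, while the remaining frames keep the repeated first-frame latent. A short sketch of that substitution, assuming the same `[batch, num_frames, channels, h, w]` layout used above (the helper is illustrative only):

```python
import torch

def inject_bbox_latents(
    image_latents: torch.Tensor,  # [batch, num_frames, channels, h, w], first-frame latent repeated
    cond_latents: torch.Tensor,   # [batch, num_frames, channels, h, w], VAE-encoded bbox renderings
    num_cond_bbox_frames: int = 3,
) -> torch.Tensor:
    # Overwrite the leading frames and the last frame with bbox latents,
    # mirroring step 7b of VideoDiffusionPipeline.__call__.
    out = image_latents.clone()
    out[:, :num_cond_bbox_frames] = cond_latents[:, :num_cond_bbox_frames]
    out[:, -1] = cond_latents[:, -1]
    return out
```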
src/preprocess/README.md ADDED
@@ -0,0 +1,105 @@
1
+ # Video Dataset Processing Tools
2
+
3
+ This directory contains tools for processing and filtering video datasets. There are two main types of tools:
4
+
5
+ 1. Dataset Preprocessing Tools (`preprocess_*.py`)
6
+ 2. Dataset Filtering Tool (`filter_dataset_tool.py`)
7
+
8
+ ## Dataset Preprocessing Tools
9
+
10
+ These scripts process raw video datasets (DADA2000, CAP, and Russia Car Crash) by:
11
+ - Extracting frames at specified FPS
12
+ - Cropping frames to desired dimensions
13
+ - Generating object detection labels
14
+ - Creating train/val splits
15
+
16
+ ### Usage
17
+
18
+ Basic usage with default settings:
19
+ ```bash
20
+ # For DADA2000 dataset
21
+ python preprocess_dada_dataset.py
22
+
23
+ # For CAP dataset
24
+ python preprocess_cap_dataset.py
25
+
26
+ # For Russia Car Crash dataset
27
+ python preprocess_russia_dataset.py
28
+ ```
29
+
30
+ Advanced usage with custom settings:
31
+ ```bash
32
+ python preprocess_dada_dataset.py \
33
+ --dataset_root /path/to/datasets \
34
+ --dataset_dir /path/to/raw/dataset \
35
+ --out_directory /path/to/output \
36
+ --out_fps 15 \
37
+ --skip_extraction \
38
+ --skip_labels \
39
+ --skip_split
40
+ ```
41
+
42
+ ### Common Arguments
43
+ - `--dataset_root`: Root directory for datasets
44
+ - `--dataset_dir`: Directory containing the raw dataset
45
+ - `--out_directory`: Output directory (defaults to {dataset_root}/dataset_name)
46
+ - `--skip_extraction`: Skip frame extraction step
47
+ - `--skip_labels`: Skip label generation step
48
+ - `--skip_split`: Skip train/val split step
49
+
50
+ ### Dataset-Specific Arguments
51
+ - DADA2000:
52
+ - `--out_fps`: Output frames per second (default: 12)
53
+ - CAP:
54
+ - `--reverse`: Process samples in reverse order
55
+ - Russia:
56
+ - `--process_train`: Process training set (default is validation set only)
57
+
58
+ ## Dataset Filtering Tool
59
+
60
+ A tool for manually reviewing and filtering video datasets. It provides an interactive interface to review video frames and mark them as high quality or rejected. The tool can also automatically detect upscaled videos and scene changes to help with the filtering process.
61
+
62
+ ### Features
63
+ - Interactive video frame review with keyboard controls
64
+ - Automatic detection of upscaled videos
65
+ - Scene change detection
66
+ - Caching support for faster processing
67
+ - Support for both single-category and multi-category datasets
68
+
69
+ ### Usage
70
+
71
+ Basic usage with default settings:
72
+ ```bash
73
+ python filter_dataset_tool.py --dataset_name my_dataset
74
+ ```
75
+
76
+ Advanced usage with all features enabled:
77
+ ```bash
78
+ python filter_dataset_tool.py \
79
+ --dataset_name my_dataset \
80
+ --start_idx 0 \
81
+ --data_dir ./custom/path/to/images \
82
+ --output_root ./custom/output/path \
83
+ --use_cache
84
+ ```
85
+
86
+ ### Keyboard Controls
87
+ - `w`: Next frame
88
+ - `s`: Previous frame
89
+ - `d`: Next video
90
+ - `a`: Previous video
91
+ - `r`: Reject video
92
+ - `h`: Mark as high quality
93
+ - `p`: Increase playback speed
94
+ - `l`: Decrease playback speed
95
+ - `ESC`: Exit
96
+
97
+ ### Command Line Arguments
98
+ - `--dataset_name`: Name of the dataset directory (required)
99
+ - `--start_idx`: Starting index for video review (default: 0)
100
+ - `--data_dir`: Custom data directory path (default: ./{dataset_name}/images)
101
+ - `--output_root`: Custom output root directory (default: ./{dataset_name})
102
+ - `--disable_sort_by_upsample`: Disable sorting by upsampling factor
103
+ - `--disable_check_scene_changes`: Disable scene change detection
104
+ - `--single_category`: Process videos from a single category directory
105
+ - `--use_cache`: Use cache to speed up processing
src/preprocess/filter_dataset_tool.py ADDED
@@ -0,0 +1,315 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ import numpy as np
5
+ from time import time
6
+ import glob
7
+ from tqdm import tqdm
8
+ import scenedetect as sd
9
+ import argparse
10
+
11
+ # Load existing JSON data if available
12
+ def load_json(filename):
13
+ if os.path.exists(filename):
14
+ with open(filename, "r") as f:
15
+ return json.load(f)
16
+ return []
17
+
18
+ def save_json(filename, data):
19
+ with open(filename, "w") as f:
20
+ json.dump(data, f, indent=4)
21
+
22
+ def estimate_upsizing_factor(image_path):
23
+ """Estimate how much an image was upsized before being resized to 720x1280"""
24
+
25
+ # Load image in grayscale
26
+ img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
27
+
28
+ if img is None:
29
+ print(f"Error loading image: {image_path}")
30
+ return None
31
+
32
+ # Compute the 2D Fourier Transform
33
+ f = np.fft.fft2(img)
34
+ fshift = np.fft.fftshift(f) # Center the low frequencies
35
+ magnitude_spectrum = np.abs(fshift)
36
+
37
+ # Compute high-frequency energy
38
+ h, w = img.shape
39
+ cx, cy = w // 2, h // 2 # Center of the image
40
+ radius = min(cx, cy) // 4 # Define a region for high frequencies
41
+
42
+ # Mask low frequencies (keep only high frequencies)
43
+ mask = np.zeros((h, w), np.uint8)
44
+ cv2.circle(mask, (cx, cy), radius, 1, thickness=-1)
45
+ high_freq_energy = np.sum(magnitude_spectrum * (1 - mask))
46
+
47
+ # Normalize energy by image size
48
+ energy_score = high_freq_energy / (h * w)
49
+
50
+ # Estimate how much the image was upscaled
51
+ upsize_factor = 1 / (1 + energy_score) # Inverse relation: lower energy → more upscaling
52
+
53
+ return upsize_factor
54
+
55
+ def check_upsample(video_paths, output_root, use_cache=True):
56
+ t = time()
57
+
58
+ if use_cache:
59
+ cache_file = f"{output_root}/upsample_scores.json"
60
+ cached_data = load_json(cache_file)
61
+
62
+ results = {}
63
+ num_frames = 5
64
+ for src_images in tqdm(video_paths, desc="Computing upscale"):
65
+ vid_name = src_images.split('/')[-1]
66
+ if use_cache and vid_name in cached_data:
67
+ results[src_images] = cached_data[vid_name]
68
+ continue
69
+
70
+ all_images = sorted(glob.glob(f"{src_images}/*.jpg"))
71
+
72
+ if len(all_images) < 5:
73
+ continue
74
+
75
+ frame_indices = np.linspace(0, len(all_images) - 1, num_frames).astype(int)
76
+
77
+ vid_scores = []
78
+ for frame_idx in frame_indices:
79
+ image_path = all_images[frame_idx]
80
+
81
+ upsize_factor = estimate_upsizing_factor(image_path)
82
+ # print(image_dir, upsize_factor)
83
+ vid_scores.append(upsize_factor)
84
+
85
+ results[src_images] = np.median(vid_scores).item()
86
+
87
+ sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)
88
+ sorted_results = {k: v for k, v in sorted_results}
89
+
90
+ if use_cache:
91
+ sorted_vids_by_names = {k.split('/')[-1]: v for k, v in sorted_results.items()}
92
+ save_json(cache_file, sorted_vids_by_names)
93
+
94
+ # print(f"Done in {time()-t:.2f}s")
95
+ return sorted_results
96
+
97
+ def detect_scenes(image_folder, threshold=27.0):
98
+ """Detects scene changes in a folder of images using PySceneDetect."""
99
+ image_files = [os.path.join(image_folder, f) for f in sorted(os.listdir(image_folder)) if f.lower().endswith(('.jpg', '.jpeg'))]
100
+ detector = sd.detectors.ContentDetector(threshold=threshold)
101
+ scene_list = []
102
+ prev_frame = None
103
+ frame_num = 0
104
+
105
+ for image_idx in range(0, len(image_files), 2): # Skip frames to go faster
106
+ image_file = image_files[image_idx]
107
+ frame = cv2.imread(image_file)
108
+ if frame is None:
109
+ continue
110
+
111
+ frame_num += 1
112
+ if prev_frame is not None:
113
+ if detector.process_frame(frame_num, frame):
114
+ scene_list.append(frame_num)
115
+
116
+ prev_frame = frame
117
+
118
+ return scene_list
119
+
120
+ def scan_scene_changes(video_paths, output_root, use_cache=True):
121
+
122
+ if use_cache:
123
+ cache_file = f"{output_root}/scene_changes.json"
124
+ cached_data = load_json(cache_file)
125
+
126
+ all_scene_change_vids = []
127
+ scene_changes_by_vid_name = {}
128
+ for folder_path in tqdm(video_paths, desc="Detecting scene changes"):
129
+
130
+ vid_name = folder_path.split('/')[-1]
131
+ if use_cache and vid_name in cached_data:
132
+ scene_changes = cached_data[vid_name]
133
+ else:
134
+ scene_changes = detect_scenes(folder_path)
135
+
136
+ scene_changes_by_vid_name[vid_name] = scene_changes
137
+
138
+ if len(scene_changes) > 0:
139
+ # print(f"{folder_path.split('/')[-1]} scene changes:", scene_changes)
140
+ all_scene_change_vids.append(folder_path)
141
+
142
+ if use_cache:
143
+ save_json(cache_file, scene_changes_by_vid_name)
144
+
145
+ print("Scene change vids:", len(all_scene_change_vids))
146
+ return all_scene_change_vids
147
+
148
+
149
+ def sort_tool(video_folders, rejected_file, highquality_file, start_idx=0):
150
+
151
+ rejected_videos = load_json(rejected_file)
152
+ highquality_videos = load_json(highquality_file)
153
+
154
+ rejected_videos_count = 0
155
+ for video_path in video_folders:
156
+ video_name = video_path.split("/")[-1]
157
+ if video_name in rejected_videos:
158
+ rejected_videos_count += 1
159
+ print(f"{rejected_videos_count}/{len(video_folders)} videos already rejected in this set")
160
+
161
+ video_idx = start_idx
162
+ frame_idx = 0
163
+ fps = 12
164
+ last_action_next = True
165
+
166
+ while True:
167
+ video_path = video_folders[video_idx]
168
+ video_name = video_path.split("/")[-1]
169
+ image_files = sorted([f for f in os.listdir(video_path) if f.endswith(".jpg")])
170
+
171
+ if not image_files:
172
+ print(f"No images found in {video_name}")
173
+ if last_action_next:
174
+ video_idx = (video_idx + 1) % len(video_folders)
175
+ else:
176
+ video_idx = (video_idx - 1) % len(video_folders)
177
+ continue
178
+
179
+ if video_name in rejected_videos or video_name in highquality_videos:
180
+ print(f"{video_name} already filtered")
181
+ if last_action_next:
182
+ video_idx = (video_idx + 1) % len(video_folders)
183
+ else:
184
+ video_idx = (video_idx - 1) % len(video_folders)
185
+ continue
186
+
187
+ frame_idx = 0
188
+ playing = True
189
+ paused = False
190
+
191
+ while playing:
192
+ frame_path = os.path.join(video_path, image_files[frame_idx])
193
+ frame = cv2.imread(frame_path)
194
+
195
+ if frame is None:
196
+ print(f"Failed to load {frame_path}")
197
+ continue
198
+
199
+ display_text = f"Video: {video_name} ({video_idx}/{len(video_folders)}) | Frame: {frame_idx + 1}/{len(image_files)} | fps: {fps}"
200
+ cv2.putText(frame, display_text, (20, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
201
+ cv2.imshow("Video Reviewer", frame)
202
+
203
+ key = cv2.waitKey(int(1000 / fps)) # Delay according to current fps
204
+
205
+ if key == ord('w'): # Next frame
206
+ frame_idx = min(len(image_files)-1, frame_idx + 1)
207
+ paused = True
208
+ elif key == ord('s'): # Previous frame
209
+ frame_idx = max(0, frame_idx - 1)
210
+ paused = True
211
+ elif key == ord('d'): # Next video
212
+ video_idx = (video_idx + 1) % len(video_folders)
213
+ last_action_next = True
214
+ break
215
+ elif key == ord('a'): # Previous video
216
+ video_idx = (video_idx - 1) % len(video_folders)
217
+ last_action_next = False
218
+ break
219
+ elif key == ord('r'): # Reject video
220
+ if video_name not in rejected_videos:
221
+ rejected_videos.append(video_name)
222
+ save_json(rejected_file, rejected_videos)
223
+ print(f"Rejected: {video_name}")
224
+ video_idx = (video_idx + 1) % len(video_folders)
225
+ break
226
+ elif key == ord('h'): # Mark as high quality
227
+ if video_name not in highquality_videos:
228
+ highquality_videos.append(video_name)
229
+ save_json(highquality_file, highquality_videos)
230
+ print(f"High Quality: {video_name}")
231
+ video_idx = (video_idx + 1) % len(video_folders)
232
+ break
233
+ elif key == ord('p'): # Increase fps
234
+ fps += 1
235
+ elif key == ord('l'): # Lower fps
236
+ fps = max(1, fps - 1)
237
+ elif key == 27: # ESC to exit
238
+ playing = False
239
+ break
240
+
241
+ if not paused:
242
+ frame_idx = (frame_idx + 1) % len(image_files)
243
+
244
+ if key == 27:
245
+ print(f"Last video: {video_name} ({video_idx})")
246
+ break
247
+
248
+ cv2.destroyAllWindows()
249
+
250
+ def collect_all_videos(data_dir, single_category=False):
251
+ all_video_paths = []
252
+ if single_category:
253
+ all_video_paths = sorted(glob.glob(f"{data_dir}/*"))
254
+ else:
255
+ for category in sorted(os.listdir(data_dir)):
256
+ category_path = os.path.join(data_dir, category)
257
+ if os.path.isdir(category_path):
258
+ all_video_paths.extend(sorted(glob.glob(f"{category_path}/*")))
259
+ return all_video_paths
260
+
261
+ if __name__ == "__main__":
262
+ parser = argparse.ArgumentParser(description='Filter and sort video dataset')
263
+ parser.add_argument('--dataset_name', type=str, required=True,
264
+ help='Name of the dataset directory')
265
+ parser.add_argument('--start_idx', type=int, default=0,
266
+ help='Starting index for video review')
267
+ parser.add_argument('--data_dir', type=str, default=None,
268
+ help='Custom data directory path (defaults to ./{dataset_name}/images)')
269
+ parser.add_argument('--output_root', type=str, default=None,
270
+ help='Custom output root directory (defaults to ./{dataset_name})')
271
+ parser.add_argument('--disable_sort_by_upsample', action='store_true',
272
+ help='Disable sorting videos by upsampling factor')
273
+ parser.add_argument('--disable_check_scene_changes', action='store_true',
274
+ help='Disable checking for scene changes in videos')
275
+ parser.add_argument('--single_category', action='store_true',
276
+ help='Process videos from a single category directory')
277
+ parser.add_argument('--use_cache', action='store_true',
278
+ help='Use cache to speed up processing')
279
+ args = parser.parse_args()
280
+
281
+ # Set default paths if not specified
282
+ if args.data_dir is None:
283
+ args.data_dir = f"./{args.dataset_name}/images"
284
+ if args.output_root is None:
285
+ args.output_root = f"./{args.dataset_name}"
286
+
287
+ # Output JSON files
288
+ rejected_file = f"{args.output_root}/rejected.json"
289
+ auto_low_quality = f"{args.output_root}/auto_low_quality.json"
290
+ highquality_file = f"{args.output_root}/highquality.json"
291
+
292
+ all_video_paths = collect_all_videos(args.data_dir, args.single_category)
293
+
294
+ if not args.disable_sort_by_upsample:
295
+ sorted_vids = check_upsample(all_video_paths, args.output_root, use_cache=args.use_cache)
296
+ video_folders = list(sorted_vids.keys())
297
+
298
+ # Save the worst to file
299
+ # auto_reject_vids = [v.split('/')[-1] for v in video_folders[:2000]]
300
+ # save_json(auto_low_quality, auto_reject_vids)
301
+ else:
302
+ video_folders = all_video_paths
303
+
304
+ # Prepend scene change samples
305
+ if not args.disable_check_scene_changes:
306
+ new_video_folders = []
307
+ scene_change_vids = scan_scene_changes(all_video_paths, args.output_root, use_cache=args.use_cache)
308
+ new_video_folders.extend(scene_change_vids)
309
+ for vid_name in video_folders:
310
+ if vid_name not in new_video_folders:
311
+ new_video_folders.append(vid_name)
312
+ video_folders = new_video_folders
313
+
314
+ # Start tool
315
+ sort_tool(video_folders, rejected_file, highquality_file, start_idx=args.start_idx)
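The tool's output is a pair of JSON lists (`rejected.json` and `highquality.json`) under `output_root`. A small, hypothetical consumer of those files, assuming the same `images/<category>/<video>` layout scanned by `collect_all_videos`:

```python
import json
import os
from glob import glob

def load_filtered_videos(output_root: str, data_dir: str, single_category: bool = False):
    """Return (kept video folders, high-quality video names) after manual review. Illustrative helper."""
    def _load(path):
        if not os.path.exists(path):
            return []
        with open(path) as f:
            return json.load(f)

    rejected = set(_load(os.path.join(output_root, "rejected.json")))
    high_quality = set(_load(os.path.join(output_root, "highquality.json")))

    pattern = f"{data_dir}/*" if single_category else f"{data_dir}/*/*"
    all_videos = sorted(p for p in glob(pattern) if os.path.isdir(p))
    kept = [v for v in all_videos if os.path.basename(v) not in rejected]
    return kept, high_quality
```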
src/preprocess/preprocess_cap_dataset.py ADDED
@@ -0,0 +1,224 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ from tqdm import tqdm
5
+ from glob import glob
6
+ import argparse
7
+
8
+ from yolo_sam import YoloSamProcessor
9
+
10
+ def load_json(filename):
11
+ if os.path.exists(filename):
12
+ with open(filename, "r") as f:
13
+ return json.load(f)
14
+ print(filename, "not found")
15
+ return []
16
+
17
+ def create_video(sample, video_path):
18
+ video_filename = f"{video_path}.mp4"
19
+ FPS = 12
20
+ frame_size = (1056, 660)  # (512, 320)
21
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
22
+ video_writer_out = cv2.VideoWriter(video_filename, fourcc, FPS, frame_size)
23
+
24
+ for img in sample:
25
+ video_writer_out.write(img)
26
+
27
+ video_writer_out.release()
28
+ print(f"Video saved: {video_filename}")
29
+
30
+
31
+ def crop_images(images_dir_path, output_dir_path, crop_extents=None):
32
+ """
33
+ Crop frames
34
+ """
35
+
36
+ all_images = sorted(glob(f"{images_dir_path}/*.jpg"))
37
+ total_frames = len(all_images)
38
+ sample_image = cv2.imread(all_images[0])
39
+ sample_name = str(int(images_dir_path.split('/')[-2])).zfill(5)
40
+ sample_category = images_dir_path.split('/')[-3]
41
+ out_vid_name = f"{sample_category}_{sample_name}"
42
+ src_height, src_width = sample_image.shape[:2]
43
+
44
+ if crop_extents:
45
+ src_height, src_width = (src_height + crop_extents[1]) - crop_extents[0], (src_width + crop_extents[3]) - crop_extents[2]
46
+
47
+ # print(f"Source images '{out_vid_name}': {src_height}x{src_width}")
48
+
49
+ image_output_folder = os.path.join(output_dir_path, out_vid_name)
50
+ os.makedirs(image_output_folder, exist_ok=True)
51
+
52
+ out_frame_count = 0
53
+ # sample_test = []
54
+ for frame_idx in range(total_frames):
55
+
56
+ frame_path = all_images[frame_idx]
57
+ frame = cv2.imread(frame_path)
58
+
59
+ # Crop frame
60
+ if crop_extents:
61
+ frame = frame[crop_extents[0]:crop_extents[1], crop_extents[2]:crop_extents[3]]
62
+
63
+ # Save frame
64
+ out_image_name = f"{out_vid_name}_{str(frame_idx).zfill(4)}.jpg"
65
+ out_image_path = os.path.join(image_output_folder, out_image_name)
66
+ cv2.imwrite(out_image_path, frame)
67
+
68
+ # sample_test.append(frame)
69
+
70
+ out_frame_count += 1
71
+
72
+ print(f"Done '{out_vid_name}': {src_height}x{src_width}, {out_frame_count} frames")
73
+
74
+ # create_video(sample_test, "path/to/sample_test_vid", fps=6)
75
+
76
+
77
+ def extract_frames(dataset_dir, out_directory, crop_extents=None, specific_videos=None):
78
+
79
+ # NOTE: We are excluding all crashes that involve visible humans (pedestrians, cyclists, motorbikes...)
80
+ video_types_to_exclude = [1, 2, 3, 4, 5, 6, 37, 38, 45, 46, 47, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60]
81
+
82
+ for category_dir in sorted(os.listdir(dataset_dir)):
83
+ category_dir_path = os.path.join(dataset_dir, category_dir)
84
+ if not os.path.isdir(category_dir_path):
85
+ continue
86
+
87
+ # Let's filter the videos we want right away
88
+ vid_type = int(category_dir)
89
+ if int(vid_type) in video_types_to_exclude:
90
+ continue
91
+
92
+ for vid_name in tqdm(sorted(os.listdir(category_dir_path))):
93
+ if specific_videos is not None and vid_name not in specific_videos:
94
+ continue
95
+
96
+ images_dir_path = os.path.join(category_dir_path, vid_name, "images")
97
+ out_path = os.path.join(out_directory, "images", category_dir)
98
+ crop_images(images_dir_path, out_path, crop_extents=crop_extents)
99
+
100
+ print("Extraction complete.")
101
+
102
+
103
+ def generate_labels(out_directory, vid_names=None, subdir="", in_directory=None, reverse_order=False):
104
+ label_output_folder = os.path.join(out_directory, "labels", subdir)
105
+ os.makedirs(label_output_folder, exist_ok=True)
106
+
107
+ # Checkpoint paths
108
+ yolo_ckpt = "yolov8x.pt" # Will auto-download with ultralytics
109
+ sam2_ckpt = "/network/scratch/x/xuolga/sam2/checkpoints/sam2.1_hiera_base_plus.pt"
110
+ sam2_cfg = "./configs/sam2.1/sam2.1_hiera_b+.yaml"
111
+ yolo_sam = YoloSamProcessor(yolo_ckpt, sam2_ckpt, sam2_cfg)
112
+
113
+ samples_run = 0
114
+
115
+ src_directory = in_directory if in_directory is not None else out_directory
116
+ video_dir_root = os.path.join(src_directory, "images", subdir)
117
+ for category in sorted(os.listdir(video_dir_root), reverse=reverse_order):
118
+ category_root = os.path.join(video_dir_root, category)
119
+ for video_name in tqdm(sorted(os.listdir(category_root), reverse=reverse_order)):
120
+ if vid_names is not None and video_name not in vid_names:
121
+ continue
122
+
123
+ video_dir = os.path.join(category_root, video_name)
124
+ if len(os.listdir(video_dir)) == 0:
125
+ print("Empty video dir:", video_dir)
126
+ continue
127
+
128
+ # Skip if label file already exists
129
+ out_label_path = os.path.join(label_output_folder, f"{video_name}.json")
130
+ if os.path.exists(out_label_path):
131
+ print(f"Skipping {video_name} - label file already exists")
132
+ continue
133
+
134
+ if len(os.listdir(video_dir)) > 300:
135
+ print(f"SKIPPING LONG VIDEO {video_name}")
136
+ continue
137
+
138
+ print(f"Computing bboxes for {video_name}...")
139
+ video_data = yolo_sam(video_dir, rel_bbox=True)
140
+
141
+ # Add metadata
142
+ vid_type = int(video_name.split('_')[0])
143
+ ego_involved = vid_type < 19 or vid_type == 61
144
+ final_out_data = {
145
+ "video_source": f"{video_name}.mp4",
146
+ "metadata": {
147
+ "ego_involved": ego_involved,
148
+ "accident_type": vid_type
149
+ },
150
+ "data": video_data
151
+ }
152
+
153
+ with open(out_label_path, 'w') as json_file:
154
+ json.dump(final_out_data, json_file, indent=1)
155
+
156
+ print("Saved label:", out_label_path)
157
+
158
+ samples_run += 1
159
+ if samples_run > 50:
160
+ print("Resetting Yolo_Sam in case of memory leak")
161
+ del yolo_sam
162
+ yolo_sam = YoloSamProcessor(yolo_ckpt, sam2_ckpt, sam2_cfg)
163
+ samples_run = 0
164
+
165
+
166
+ def make_train_val_split(out_directory):
167
+ image_folder = os.path.join(out_directory, "images")
168
+ label_folder = os.path.join(out_directory, "labels")
169
+
170
+ all_image_folders = os.listdir(image_folder)
171
+ split_idx = int(len(all_image_folders) * 0.9)
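+ # NOTE: os.listdir order is arbitrary, so this 90/10 split is neither shuffled nor seeded; sort or shuffle the list first if a reproducible split is needed.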
172
+
173
+ train_split = all_image_folders[:split_idx]
174
+ val_split = all_image_folders[split_idx:]
175
+
176
+ os.makedirs(os.path.join(image_folder, "train"), exist_ok=True)
177
+ os.makedirs(os.path.join(image_folder, "val"), exist_ok=True)
178
+ os.makedirs(os.path.join(label_folder, "train"), exist_ok=True)
179
+ os.makedirs(os.path.join(label_folder, "val"), exist_ok=True)
180
+
181
+ for filename in train_split:
182
+ os.rename(os.path.join(image_folder, filename), os.path.join(image_folder, "train", filename))
183
+ os.rename(os.path.join(label_folder, f"{filename}.json"), os.path.join(label_folder, "train", f"{filename}.json"))
184
+
185
+ for filename in val_split:
186
+ os.rename(os.path.join(image_folder, filename), os.path.join(image_folder, "val", filename))
187
+ os.rename(os.path.join(label_folder, f"{filename}.json"), os.path.join(label_folder, "val", f"{filename}.json"))
188
+
189
+
190
+ if __name__ == "__main__":
191
+ parser = argparse.ArgumentParser(description='Process CAP dataset')
192
+ parser.add_argument('--dataset_root', type=str, required=True,
193
+ help='Root directory for datasets')
194
+ parser.add_argument('--dataset_dir', type=str, default="/network/scratch/l/luis.lara/dev/MM-AU/CAP-DATA",
195
+ help='Directory containing the CAP dataset')
196
+ parser.add_argument('--out_directory', type=str, default=None,
197
+ help='Output directory (defaults to {dataset_root}/cap_images_12fps)')
198
+ parser.add_argument('--reverse', action='store_true',
199
+ help='Process samples in reverse order')
200
+ parser.add_argument('--skip_extraction', action='store_true',
201
+ help='Skip frame extraction step')
202
+ parser.add_argument('--skip_labels', action='store_true',
203
+ help='Skip label generation step')
204
+ parser.add_argument('--skip_split', action='store_true',
205
+ help='Skip train/val split step')
206
+ args = parser.parse_args()
207
+
208
+ # Set default output directory if not specified
209
+ if args.out_directory is None:
210
+ args.out_directory = os.path.join(args.dataset_root, "cap_images_12fps")
211
+
212
+ # Extract frames from videos
213
+ if not args.skip_extraction:
214
+ cap_crop_extents = [40, -40, 128, -128] # Custom crop for CAP dataset (to fix the aspect ratio)
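+ # crop_extents are [top, bottom, left, right] slice bounds; negative values count from the opposite edge (plain Python slicing).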
215
+ extract_frames(args.dataset_dir, args.out_directory, crop_extents=cap_crop_extents, specific_videos=None)
216
+
217
+ # Create labels (run bbox detector)
218
+ if not args.skip_labels:
219
+ in_directory = os.path.join(args.dataset_root, "cap_images_12fps")
220
+ generate_labels(args.out_directory, vid_names=None, in_directory=in_directory, reverse_order=args.reverse)
221
+
222
+ # Split into train and val sets
223
+ if not args.skip_split:
224
+ make_train_val_split(args.out_directory)
src/preprocess/preprocess_dada_dataset.py ADDED
@@ -0,0 +1,222 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ from tqdm import tqdm
5
+ import argparse
6
+
7
+ from yolo_sam import YoloSamProcessor
8
+
9
+
10
+ def create_video(sample, video_path):
11
+ video_filename = f"{video_path}.mp4"
12
+ FPS = 12
13
+ frame_size = (1056, 660)  # (512, 320)
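+ # NOTE: cv2.VideoWriter expects every frame to match frame_size exactly; mismatched frames can yield an empty/unplayable file.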
14
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
15
+ video_writer_out = cv2.VideoWriter(video_filename, fourcc, FPS, frame_size)
16
+
17
+ for img in sample:
18
+ video_writer_out.write(img)
19
+
20
+ video_writer_out.release()
21
+ print(f"Video saved: {video_filename}")
22
+
23
+
24
+ def downsample_and_crop_vid(video_path, output_dir, out_fps=12, crop_extents=None):
25
+ """
26
+ Downsample fps and crop frames
27
+ """
28
+
29
+ # Load video
30
+ cap = cv2.VideoCapture(video_path)
31
+ org_fps = int(cap.get(cv2.CAP_PROP_FPS))
32
+ src_width, src_height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
33
+ original_sample_name = video_path.split('/')[-1].split('.')[0]
34
+ category = original_sample_name.split("_")[0]
35
+ vid_num = '90'+str(int(original_sample_name.split("_")[-1])).zfill(3) # NOTE: We prepend a '90' to differentiate between DADA and CAP samples
36
+ sample_name = f"{category}_{vid_num}"
37
+ if crop_extents:
38
+ src_width, src_height = (src_width + crop_extents[3]) - crop_extents[2], (src_height + crop_extents[1]) - crop_extents[0]
39
+
40
+ print(f"Source video '{sample_name}': {src_width}x{src_height}, fps={org_fps}")
41
+
42
+ total_frames = 0
43
+ target_period = 1/out_fps
44
+ last_frame_time = target_period
45
+ out_frame_count = 0
46
+
47
+ image_output_folder = os.path.join(output_dir, "images", category, sample_name)
48
+ os.makedirs(image_output_folder, exist_ok=True)
49
+
50
+ # sample_test = []
51
+ while cap.isOpened():
52
+ success, frame = cap.read()
53
+
54
+ if not success:
55
+ break
56
+
57
+ # Extract frames according to desired fps
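+ # Accumulator-based resampling: track the elapsed time since the last kept frame and emit one whenever it reaches 1/out_fps, carrying the remainder over so the average output rate matches out_fps.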
58
+ if last_frame_time >= target_period:
59
+ out_frame_count += 1
60
+ last_frame_time = (last_frame_time - target_period)
61
+
62
+ # Crop frame
63
+ if crop_extents:
64
+ frame = frame[:, crop_extents[2] : crop_extents[3]]
65
+
66
+ # Save frame
67
+ out_image_name = f"{sample_name}_{str(total_frames).zfill(4)}.jpg"
68
+ out_image_path = os.path.join(image_output_folder, out_image_name)
69
+ cv2.imwrite(out_image_path, frame)
70
+
71
+ # sample_test.append(frame)
72
+
73
+ total_frames += 1
74
+ last_frame_time += 1/org_fps
75
+
76
+ print(f"Done '{sample_name}': {out_frame_count} frames, fps: {out_frame_count / (total_frames*1/org_fps)}")
77
+ cap.release()
78
+ # create_video(sample_test, "/path/to/sample_test_vid")
79
+
80
+
81
+ def extract_frames(dataset_dir, out_directory, crop_extents=None, out_fps=12):
82
+ dataset_video_dir = os.path.join(dataset_dir)
83
+
84
+ # NOTE: We are excluding all crashes that involve visible humans (pedestrians, cyclists, motorbikes...)
85
+ video_types_to_exclude = [1, 2, 3, 4, 5, 6, 37, 38, 45, 46, 47, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60]
86
+
87
+ for filename in tqdm(os.listdir(dataset_video_dir)):
88
+ if filename.split('.')[-1] != "mp4":
89
+ continue
90
+
91
+ # Let's filter the videos we want right away
92
+ vid_type = filename.split('_')[0]
93
+ if int(vid_type) in video_types_to_exclude:
94
+ continue
95
+
96
+ video_path = os.path.join(dataset_video_dir, filename)
97
+ downsample_and_crop_vid(video_path, out_directory, out_fps=out_fps, crop_extents=crop_extents)
98
+ print("Extraction complete.")
99
+
100
+
101
+ def generate_labels(out_directory, vid_names=None, subdir=""):
102
+ label_output_folder = os.path.join(out_directory, "labels", subdir)
103
+ os.makedirs(label_output_folder, exist_ok=True)
104
+
105
+ # Checkpoint paths
106
+ yolo_ckpt = "yolov8x.pt" # Will auto download with utltralytics
107
+ sam2_ckpt = "/network/scratch/x/xuolga/sam2/checkpoints/sam2.1_hiera_base_plus.pt"
108
+ sam2_cfg = "./configs/sam2.1/sam2.1_hiera_b+.yaml"
109
+ yolo_sam = YoloSamProcessor(yolo_ckpt, sam2_ckpt, sam2_cfg)
110
+
111
+ samples_run = 0
112
+
113
+ video_dir_root = os.path.join(out_directory, "images", subdir)
114
+ for category in sorted(os.listdir(video_dir_root), reverse=True):
115
+ category_root = os.path.join(video_dir_root, category)
116
+ for video_name in tqdm(os.listdir(category_root)):
117
+ if vid_names is not None and video_name not in vid_names:
118
+ continue
119
+
120
+ video_dir = os.path.join(category_root, video_name)
121
+ if len(os.listdir(video_dir)) == 0:
122
+ print("Empty video dir:", video_dir)
123
+ continue
124
+
125
+ if len(os.listdir(video_dir)) > 300:
126
+ print(f"SKIPPING LONG VIDEO {video_name}")
127
+ continue
128
+
129
+ # Skip if label file already exists
130
+ out_label_path = os.path.join(label_output_folder, f"{video_name}.json")
131
+ if os.path.exists(out_label_path):
132
+ print(f"Skipping {video_name} - label file already exists")
133
+ continue
134
+
135
+ video_data = yolo_sam(video_dir, rel_bbox=True)
136
+
137
+ if video_data is None:
138
+ print("COMPUTED VIDEO DATA IS NULL for video:", video_dir)
139
+
140
+ # Add metadata
141
+ vid_type = int(video_name.split('_')[0])
142
+ ego_involved = vid_type < 19 or vid_type == 61
143
+ final_out_data = {
144
+ "video_source": f"{video_name}.mp4",
145
+ "metadata": {
146
+ "ego_involved": ego_involved,
147
+ "accident_type": vid_type
148
+ },
149
+ "data": video_data
150
+ }
151
+
152
+ with open(out_label_path, 'w') as json_file:
153
+ json.dump(final_out_data, json_file, indent=1)
154
+
155
+ print("Saved label:", out_label_path)
156
+
157
+ samples_run += 1
158
+ if samples_run > 50:
159
+ print("Resetting Yolo_Sam in case of memory leak")
160
+ del yolo_sam
161
+ yolo_sam = YoloSamProcessor(yolo_ckpt, sam2_ckpt, sam2_cfg)
162
+ samples_run = 0
163
+
164
+
165
+ def make_train_val_split(out_directory):
166
+ image_folder = os.path.join(out_directory, "images")
167
+ label_folder = os.path.join(out_directory, "labels")
168
+
169
+ all_image_folders = os.listdir(image_folder)
170
+ split_idx = int(len(all_image_folders) * 0.9)
171
+
172
+ train_split = all_image_folders[:split_idx]
173
+ val_split = all_image_folders[split_idx:]
174
+
175
+ os.makedirs(os.path.join(image_folder, "train"), exist_ok=True)
176
+ os.makedirs(os.path.join(image_folder, "val"), exist_ok=True)
177
+ os.makedirs(os.path.join(label_folder, "train"), exist_ok=True)
178
+ os.makedirs(os.path.join(label_folder, "val"), exist_ok=True)
179
+
180
+ for filename in train_split:
181
+ os.rename(os.path.join(image_folder, filename), os.path.join(image_folder, "train", filename))
182
+ os.rename(os.path.join(label_folder, f"{filename}.json"), os.path.join(label_folder, "train", f"{filename}.json"))
183
+
184
+ for filename in val_split:
185
+ os.rename(os.path.join(image_folder, filename), os.path.join(image_folder, "val", filename))
186
+ os.rename(os.path.join(label_folder, f"{filename}.json"), os.path.join(label_folder, "val", f"{filename}.json"))
187
+
188
+
189
+ if __name__ == "__main__":
190
+ parser = argparse.ArgumentParser(description='Process DADA2000 dataset')
191
+ parser.add_argument('--dataset_root', type=str, required=True,
192
+ help='Root directory for datasets')
193
+ parser.add_argument('--dataset_dir', type=str, required=True,
194
+ help='Directory containing the DADA2000 dataset')
195
+ parser.add_argument('--out_directory', type=str, default=None,
196
+ help='Output directory (defaults to {dataset_root}/dada2000_images_12fps)')
197
+ parser.add_argument('--skip_extraction', action='store_true',
198
+ help='Skip frame extraction step')
199
+ parser.add_argument('--skip_labels', action='store_true',
200
+ help='Skip label generation step')
201
+ parser.add_argument('--skip_split', action='store_true',
202
+ help='Skip train/val split step')
203
+ parser.add_argument('--out_fps', type=int, default=12,
204
+ help='Output frames per second (default: 12)')
205
+ args = parser.parse_args()
206
+
207
+ # Set default output directory if not specified
208
+ if args.out_directory is None:
209
+ args.out_directory = os.path.join(args.dataset_root, "dada2000_images_12fps")
210
+
211
+ # Extract frames from videos
212
+ if not args.skip_extraction:
213
+ dada_crop_extents = [0, -0, 264, -264] # Custom crop for DADA2000 dataset (to fix the aspect ratio)
214
+ extract_frames(args.dataset_dir, args.out_directory, crop_extents=dada_crop_extents, out_fps=args.out_fps)
215
+
216
+ # Create labels (run bbox detector)
217
+ if not args.skip_labels:
218
+ generate_labels(args.out_directory, vid_names=None)
219
+
220
+ # Split into train and val sets
221
+ if not args.skip_split:
222
+ make_train_val_split(args.out_directory)
src/preprocess/preprocess_russia_dataset.py ADDED
@@ -0,0 +1,168 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ from tqdm import tqdm
5
+ import argparse
6
+
7
+ from yolo_sam import YoloSamProcessor
8
+
9
+ def downsample_and_crop_vid(video_path, output_dir, out_fps=7, crop_extents=None):
10
+ """
11
+ Downsample fps and crop frames
12
+ """
13
+
14
+ # Load video
15
+ cap = cv2.VideoCapture(video_path)
16
+ org_fps = int(cap.get(cv2.CAP_PROP_FPS))
17
+ src_width, src_height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
18
+ sample_name = video_path.split('/')[-1].split('.')[0]
19
+ if crop_extents:
20
+ src_width, src_height = (src_width + crop_extents[3]) - crop_extents[2], (src_height + crop_extents[1]) - crop_extents[0]
21
+
22
+ print(f"Source video '{sample_name}': {src_width}x{src_height}, fps={org_fps}")
23
+
24
+ total_frames = 0
25
+ target_period = 1/out_fps
26
+ last_frame_time = target_period
27
+ out_frame_count = 0
28
+
29
+ image_output_folder = os.path.join(output_dir, "images", sample_name)
30
+ os.makedirs(image_output_folder, exist_ok=True)
31
+
32
+ while cap.isOpened():
33
+ success, frame = cap.read()
34
+
35
+ if not success:
36
+ break
37
+
38
+ # Extract frames according to desired fps
39
+ if last_frame_time >= target_period:
40
+ out_frame_count += 1
41
+ last_frame_time = (last_frame_time - target_period)
42
+
43
+ # Crop frame
44
+ if crop_extents:
45
+ frame = frame[crop_extents[0] : crop_extents[1], crop_extents[2] : crop_extents[3]]
46
+
47
+ # Save frame
48
+ out_image_name = f"{sample_name}_{str(total_frames).zfill(4)}.jpg"
49
+ out_image_path = os.path.join(image_output_folder, out_image_name)
50
+ cv2.imwrite(out_image_path, frame)
51
+
52
+ total_frames += 1
53
+ last_frame_time += 1/org_fps
54
+
55
+ print(f"Done '{sample_name}': {out_frame_count} frames, fps: {out_frame_count / (total_frames*1/org_fps)}")
56
+ cap.release()
57
+
58
+ def extract_frames(dataset_dir, out_directory, crop_extents=None):
59
+ dataset_video_dir = os.path.join(dataset_dir, "video")
60
+ for filename in tqdm(os.listdir(dataset_video_dir)):
61
+ video_path = os.path.join(dataset_video_dir, filename)
62
+ fps = 7
63
+ downsample_and_crop_vid(video_path, out_directory, out_fps=fps, crop_extents=crop_extents)
64
+ print("Extraction complete.")
65
+
66
+
67
+ def generate_labels(dataset_dir, out_directory, video_subdir=''):
68
+ label_output_folder = os.path.join(out_directory, "labels", video_subdir)
69
+ os.makedirs(label_output_folder, exist_ok=True)
70
+
71
+ # Checkpoint paths
72
+ yolo_ckpt = "yolov8x.pt" # Will auto download with utltralytics
73
+ sam2_ckpt = "/network/scratch/x/xuolga/sam2/checkpoints/sam2.1_hiera_base_plus.pt"
74
+ sam2_cfg = "./configs/sam2.1/sam2.1_hiera_b+.yaml"
75
+ yolo_sam = YoloSamProcessor(yolo_ckpt, sam2_ckpt, sam2_cfg)
76
+
77
+ video_dir_root = os.path.join(out_directory, "images", video_subdir)
78
+ # for video_name in tqdm(os.listdir(video_dir_root)):
79
+ for video_name in tqdm(["w10_138", "w10_94", "w1_10", "w1_46", "w2_79", "w3_17", "w6_14", "w6_44", "w6_78", "w6_94", "w7_1", "w7_14"]):
80
+ video_dir = os.path.join(video_dir_root, video_name)
81
+ if len(os.listdir(video_dir)) == 0:
82
+ print("Empty video dir:", video_dir)
83
+ continue
84
+
85
+ video_data = yolo_sam(video_dir, rel_bbox=True)
86
+
87
+ # Add metadata
88
+ org_dataset_labels = os.path.join(dataset_dir, "label", "json")
89
+ orig_label_path = os.path.join(org_dataset_labels, f"{video_name}.json")
90
+ with open(orig_label_path, 'r') as json_file:
91
+ metadata = json.load(json_file)[0]['meta_data']
92
+
93
+ final_out_data = {
94
+ "video_source": f"{video_name}.mp4",
95
+ "metadata": metadata,
96
+ "data": video_data
97
+ }
98
+
99
+ out_label_path = os.path.join(label_output_folder, f"{video_name}.json")
100
+ with open(out_label_path, 'w') as json_file:
101
+ json.dump(final_out_data, json_file, indent=1)
102
+
103
+ print("Saved label:", out_label_path)
104
+
105
+
106
+ def make_train_val_split(out_directory):
107
+ image_folder = os.path.join(out_directory, "images")
108
+ label_folder = os.path.join(out_directory, "labels")
109
+
110
+ all_image_folders = os.listdir(image_folder)
111
+ split_idx = int(len(all_image_folders) * 0.9)
112
+
113
+ train_split = all_image_folders[:split_idx]
114
+ val_split = all_image_folders[split_idx:]
115
+
116
+ os.makedirs(os.path.join(image_folder, "train"), exist_ok=True)
117
+ os.makedirs(os.path.join(image_folder, "val"), exist_ok=True)
118
+ os.makedirs(os.path.join(label_folder, "train"), exist_ok=True)
119
+ os.makedirs(os.path.join(label_folder, "val"), exist_ok=True)
120
+
121
+ for filename in train_split:
122
+ os.rename(os.path.join(image_folder, filename), os.path.join(image_folder, "train", filename))
123
+ os.rename(os.path.join(label_folder, f"{filename}.json"), os.path.join(label_folder, "train", f"{filename}.json"))
124
+
125
+ for filename in val_split:
126
+ os.rename(os.path.join(image_folder, filename), os.path.join(image_folder, "val", filename))
127
+ os.rename(os.path.join(label_folder, f"{filename}.json"), os.path.join(label_folder, "val", f"{filename}.json"))
128
+
129
+
130
+ if __name__ == "__main__":
131
+ parser = argparse.ArgumentParser(description='Process Russia Car Crash dataset')
132
+ parser.add_argument('--dataset_root', type=str, required=True,
133
+ help='Root directory for datasets')
134
+ parser.add_argument('--dataset_dir', type=str, required=True,
135
+ help='Directory containing the Russia Car Crash dataset')
136
+ parser.add_argument('--out_directory', type=str, default=None,
137
+ help='Output directory (defaults to {dataset_root}/preprocess_russia_crash)')
138
+ parser.add_argument('--skip_extraction', action='store_true',
139
+ help='Skip frame extraction step')
140
+ parser.add_argument('--skip_labels', action='store_true',
141
+ help='Skip label generation step')
142
+ parser.add_argument('--skip_split', action='store_true',
143
+ help='Skip train/val split step')
144
+ parser.add_argument('--process_train', action='store_true',
145
+ help='Process training set (default is validation set only)')
146
+ args = parser.parse_args()
147
+
148
+ # Set default output directory if not specified
149
+ if args.out_directory is None:
150
+ args.out_directory = os.path.join(args.dataset_root, "preprocess_russia_crash")
151
+
152
+ # Custom crop for Russia dataset (hide largest watermarks)
153
+ src_height, src_width = 986, 555
154
+ russia_crop_extents = [int(0.032*src_height), -int(0.198*src_height), int(0.115*src_width), -int(0.115*src_width)]
155
+
156
+ # Extract frames from videos
157
+ if not args.skip_extraction:
158
+ extract_frames(args.dataset_dir, args.out_directory, crop_extents=russia_crop_extents)
159
+
160
+ # Create labels (run bbox detector)
161
+ if not args.skip_labels:
162
+ generate_labels(args.dataset_dir, args.out_directory, video_subdir='val')
163
+ if args.process_train:
164
+ generate_labels(args.dataset_dir, args.out_directory, video_subdir='train')
165
+
166
+ # Split into train and val sets
167
+ if not args.skip_split:
168
+ make_train_val_split(args.out_directory)
src/preprocess/yolo_sam.py ADDED
@@ -0,0 +1,584 @@
1
+ import os
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from PIL import Image
5
+ import torch
6
+ import random
7
+ from itertools import combinations
8
+ from tqdm import tqdm
9
+ import json
10
+ import cv2
11
+
12
+ from ultralytics import YOLO
13
+ from sam2.build_sam import build_sam2_video_predictor
14
+
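+ # NUM_LOOK_BACK_FRAMES: SAM masks are kept starting this many frames before an object's first YOLO detection; earlier masks are discarded.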
15
+ NUM_LOOK_BACK_FRAMES = 3
16
+ FPS = 12
17
+
18
+ CLASSES_TO_KEEP = { # Using YOLO ids
19
+ 0: 'person',
20
+ 1: 'bicycle',
21
+ 2: 'car',
22
+ 3: 'motorcycle',
23
+ 5: 'bus',
24
+ 6: 'train',
25
+ 7: 'truck',
26
+ }
27
+
28
+
29
+ def create_video_from_images(images_dir, output_video, out_fps, start_frame=None, end_frame=None):
30
+
31
+ images = sorted(os.listdir(images_dir))
32
+
33
+ img0_path = os.path.join(images_dir, images[0])
34
+ img0 = cv2.imread(img0_path)
35
+ height, width, _ = img0.shape
36
+
37
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
38
+ out = cv2.VideoWriter(output_video, fourcc, out_fps, (width, height))
39
+
40
+ for idx, frame_name in enumerate(images):
41
+
42
+ if start_frame is not None and idx < start_frame:
43
+ continue
44
+ if end_frame is not None and idx >= end_frame:
45
+ continue
46
+
47
+ img = cv2.imread(os.path.join(images_dir, frame_name))
48
+ out.write(img)
49
+
50
+ out.release()
51
+ print("Saved video:", output_video)
52
+
53
+
54
+ def show_mask(mask, ax, obj_id=None, random_color=False, label=None):
55
+ if random_color:
56
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
57
+ else:
58
+ cmap = plt.get_cmap("tab10")
59
+ cmap_idx = 0 if obj_id is None else obj_id
60
+ color = np.array([*cmap(cmap_idx)[:3], 0.6])
61
+ h, w = mask.shape[-2:]
62
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
63
+ if label is not None:
64
+ text_location = mask.nonzero()
65
+ if len(text_location[0]) > 0:
66
+ rand_point = random.randint(0, len(text_location[0]) - 1)
67
+ ax.text(text_location[2][rand_point], text_location[1][rand_point], label, color=(1, 1, 1))
68
+ ax.imshow(mask_image)
69
+
70
+ def show_box(box, ax, label=None, color=((1, 0.7, 0.7))):
71
+ x0, y0 = box[0], box[1]
72
+ w, h = box[2] - box[0], box[3] - box[1]
73
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor=(0, 0, 0, 0), lw=1))
74
+
75
+ if label is not None:
76
+ ax.text(x0 + w // 2, y0 + h // 2, label, color=color)
77
+
78
+
79
+ class TrackedObject:
80
+ def __init__(self, track_id, class_id, bbox, initial_frame_idx):
81
+ self.track_id = track_id
82
+ self.bbox = bbox
83
+ self.class_pred_counts = {class_id: 1}
84
+ self.initial_frame_idx = initial_frame_idx
85
+
86
+ self._top_class = class_id
87
+
88
+ @property
89
+ def class_id(self):
90
+ """
91
+ The class for the object is whichever class was predicted the most for it
92
+ """
93
+ if self._top_class is not None:
94
+ return self._top_class
95
+
96
+ top_class = None
97
+ top_count = 0
98
+ for class_id, count in self.class_pred_counts.items():
99
+ if count >= top_count:
100
+ top_count = count
101
+ top_class = class_id
102
+
103
+ self._top_class = top_class
104
+ return top_class
105
+
106
+ def new_pred(self, class_id):
107
+ if class_id not in self.class_pred_counts:
108
+ self.class_pred_counts[class_id] = 0
109
+ self.class_pred_counts[class_id] += 1
110
+ self._top_class = None # Remove cached top_class
111
+
112
+ def __repr__(self):
113
+ return f"id:{self.track_id}, class:{self.class_id}, bbox:{self.bbox}, init_frame:{self.initial_frame_idx}"
114
+
115
+ def __str__(self):
116
+ return f"id:{self.track_id}, class:{self.class_id}, bbox:{self.bbox}, init_frame:{self.initial_frame_idx}"
117
+
118
+
119
+ class YoloSamProcessor():
120
+
121
+ def __init__(self, yolo_ckpt, sam_ckpt, sam_cfg, gpu_id=0):
122
+ # Load models
123
+
124
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
125
+ self.device = f"cuda:{gpu_id}"
126
+
127
+ self.yolo_model = YOLO(yolo_ckpt)
128
+ self.sam_model = build_sam2_video_predictor(sam_cfg, sam_ckpt, device=self.device)
129
+
130
+ def __call__(self, video_dir, rel_bbox=True):
131
+ # Renaming the videos to be compatible with SAM2
132
+ prev_frame_paths = {}
133
+ src_width, src_height = None, None
134
+ self.num_frames = len(os.listdir(video_dir))
135
+ for vid in os.listdir(video_dir):
136
+ new_name = vid.split("_")[-1]
137
+ og_path = os.path.join(video_dir, vid)
138
+ new_path = os.path.join(video_dir, new_name)
139
+ os.rename(og_path, new_path)
140
+ prev_frame_paths[new_path] = og_path
141
+
142
+ if src_width is None:
143
+ img_sample = Image.open(new_path)
144
+ src_width, src_height = img_sample.size
145
+
146
+ self.out_data = None
147
+ try:
148
+ # YOLO model
149
+ self.run_yolo(video_dir)
150
+ self.filter_yolo_preds()
151
+
152
+ # SAM2 model
153
+ self.run_sam(video_dir)
154
+ self.filter_sam_preds()
155
+
156
+ self.extract_final_bboxes()
157
+
158
+ # Format the data
159
+ self.out_data = []
160
+ for frame_idx, frame_data in enumerate(self.final_bboxes):
161
+ frame_path = self.frame_path_by_idx[frame_idx]
162
+ frame_name = prev_frame_paths[frame_path].split('/')[-1]
163
+ self.out_data.append({
164
+ "image_source": frame_name,
165
+ "labels": []
166
+ })
167
+ for obj_id, bbox in frame_data.items():
168
+
169
+ bbox = bbox.tolist()
170
+ if rel_bbox:
171
+ # Save bbox coordinates as a ratio to image size
172
+ formatted_bbox = [bbox[0]/src_width, bbox[1]/src_height, bbox[2]/src_width, bbox[3]/src_height]
173
+ else:
174
+ # Keep in absolute coordinates
175
+ formatted_bbox = bbox
176
+
177
+ tracked_obj = self.initial_bboxes[obj_id]
178
+
179
+ out_label = {
180
+ "track_id": obj_id,
181
+ "name": CLASSES_TO_KEEP[tracked_obj.class_id],
182
+ "class": tracked_obj.class_id,
183
+ "box": formatted_bbox,
184
+ }
185
+ self.out_data[frame_idx]["labels"].append(out_label)
186
+ except Exception as e:
187
+ print("Yolo_Sam processor failed:", e)
188
+ finally:
189
+ # Revert back the names of the files
190
+ print("Restoring names in video dir after exception")
191
+ for new_path, old_path in prev_frame_paths.items():
192
+ os.rename(new_path, old_path)
193
+
194
+ return self.out_data
195
+
196
+ def run_yolo(self, video_dir):
197
+ self.initial_bboxes = {} # Store the first bbox for each track id
198
+ self.id_reassigns = {}
199
+ all_preds_by_track_id = []
200
+
201
+ self.frame_path_by_idx = {}
202
+
203
+ # Reset yolo's tracker for new video
204
+ if self.yolo_model.predictor is not None:
205
+ self.yolo_model.predictor.trackers[0].reset()
206
+
207
+ new_id_counter = 1
208
+ sorted_frames = sorted(os.listdir(video_dir))
209
+ for frame_idx, frame_file in enumerate(sorted_frames):
210
+ frame_path = os.path.join(video_dir, frame_file)
211
+ self.frame_path_by_idx[frame_idx] = frame_path
212
+ img = Image.open(frame_path)
213
+
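+ # persist=True keeps the Ultralytics tracker state across calls, so track ids stay consistent over the frames of one video.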
214
+ yolo_results = self.yolo_model.track(img, persist=True, conf=0.1, verbose=False, device=self.device)
215
+ yolo_boxes = yolo_results[0].boxes
216
+ all_preds_by_track_id.append({})
217
+
218
+ # If the detection has a new track id, then we record it
219
+ if yolo_boxes.is_track:
220
+ for idx in range(len(yolo_boxes)):
221
+ track_id = int(yolo_boxes.id[idx].item())
222
+ bbox = yolo_boxes.xyxy[idx].numpy()
223
+ class_id = int(yolo_boxes.cls[idx].item())
224
+
225
+ tracked_obj = self.initial_bboxes.get(track_id)
226
+ # Check if YOLO is trying to assign an id that was already predicted but was not present in the previous frame
227
+ # This means YOLO lost this obj and is attempting to reassign. We will assign a new id for this as we want to
228
+ # trust SAM with the tracking, not YOLO.
229
+ prev_track_id = track_id
230
+ if tracked_obj is not None and frame_idx > 0 and track_id not in all_preds_by_track_id[frame_idx - 1]:
231
+
232
+ # Check if id has been reassigned
233
+ if self.id_reassigns.get(track_id) is not None:
234
+ track_id = self.id_reassigns.get(track_id)
235
+
236
+ if track_id not in all_preds_by_track_id[frame_idx - 1]: # Check again because id might have changed
237
+ # Assign a new track id
238
+ track_id = 100 + new_id_counter
239
+ new_id_counter += 1
240
+ self.id_reassigns[prev_track_id] = track_id
241
+ tracked_obj = None
242
+ # print(f"Frame: {frame_idx} re-assigned id {prev_track_id}->{track_id}")
243
+
244
+ if tracked_obj is None:
245
+
246
+ # Check overlap with existing bboxes to make sure this isn't a double detection
247
+ reject_detection = False
248
+ for idx2 in range(len(yolo_boxes)):
249
+ track_id2 = int(yolo_boxes.id[idx2].item())
250
+ if track_id2 in [track_id, prev_track_id] or self.initial_bboxes.get(track_id2) is None:
251
+ continue
252
+ bbox2 = yolo_boxes.xyxy[idx2].numpy()
253
+ iou = self._bbox_iou(bbox, bbox2)
254
+ if iou >= 0.8:
255
+ reject_detection = True
256
+ # print("Redetection! Frame:", frame_idx, "Iou:", iou, track_id, track_id2)
257
+ break
258
+
259
+ if not reject_detection:
260
+ tracked_obj = TrackedObject(track_id, class_id, bbox, frame_idx)
261
+ self.initial_bboxes[track_id] = tracked_obj
262
+ else:
263
+ tracked_obj.new_pred(class_id)
264
+
265
+ if tracked_obj is not None:
266
+ all_preds_by_track_id[frame_idx][track_id] = TrackedObject(track_id, class_id, bbox, tracked_obj.initial_frame_idx)
267
+
268
+
269
+ def filter_yolo_preds(self):
270
+ # Smooth classes detected + filter out unwanted classes
271
+ self.filtered_objects = []
272
+ self.filtered_objects_by_frame = {}
273
+ for _, tracked_obj in self.initial_bboxes.items():
274
+ if tracked_obj.class_id in CLASSES_TO_KEEP:
275
+ self.filtered_objects.append(tracked_obj)
276
+
277
+ if self.filtered_objects_by_frame.get(tracked_obj.initial_frame_idx) is None:
278
+ self.filtered_objects_by_frame[tracked_obj.initial_frame_idx] = []
279
+ self.filtered_objects_by_frame[tracked_obj.initial_frame_idx].append(tracked_obj)
280
+
281
+ self.initial_frame_idx_by_track_id = {obj.track_id: obj.initial_frame_idx for obj in self.filtered_objects}
282
+
283
+ # print("Filtered objects:", self.filtered_objects)
284
+
285
+
286
+ def run_sam(self, video_dir):
287
+ self.video_segments = {} # video_segments contains the per-frame segmentation results
288
+ self.track_ids_to_reject = {} # {track_id: reject_all_before_frame_idx}
289
+
290
+ if self.filtered_objects is None or len(self.filtered_objects) == 0:
291
+ # There are no objects to track
292
+ return
293
+
294
+ inference_state = self.sam_model.init_state(video_path=video_dir) # NOTE: Kind of annoying that the model requires frames to be named with numbers only...
295
+
296
+ self.sam_model.reset_state(inference_state)
297
+ for obj in self.filtered_objects:
298
+ _, out_obj_ids, out_mask_logits = self.sam_model.add_new_points_or_box(
299
+ inference_state=inference_state,
300
+ frame_idx=obj.initial_frame_idx,
301
+ obj_id=obj.track_id,
302
+ box=obj.bbox,
303
+ )
304
+
305
+ def get_last_frame_occurrence(sam_track_ids_per_frame, track_id, current_idx):
306
+ for frame_idx in range(current_idx-1, -1, -1):
307
+ if track_id in sam_track_ids_per_frame.get(frame_idx, []):
308
+ return frame_idx
309
+ return -1
310
+
311
+ # run propagation throughout the video and collect the results in a dict
312
+ sam_track_ids_per_frame = {}
313
+ long_non_existence_track_ids = []
314
+ with torch.cuda.amp.autocast(): # Need this for some reason to fix some casting issues... (BFloat16 and Float16 mismatches)
315
+ for out_frame_idx, out_obj_ids, out_mask_logits in self.sam_model.propagate_in_video(inference_state):
316
+
317
+ sam_tracked_ids = []
318
+ for pred_idx, mask_logits in enumerate(out_mask_logits):
319
+ mask = (mask_logits > 0.0).cpu().numpy()
320
+ track_id = out_obj_ids[pred_idx]
321
+ if mask.sum() > 0 and out_frame_idx > (self.initial_frame_idx_by_track_id[track_id] - NUM_LOOK_BACK_FRAMES):
322
+ sam_tracked_ids.append(track_id)
323
+ sam_track_ids_per_frame[out_frame_idx] = sam_tracked_ids
324
+
325
+ # Compare *new* YOLO preds and make sure they don't overlap with existing SAM preds
326
+ for new_yolo_obj in self.filtered_objects_by_frame.get(out_frame_idx, []):
327
+ yolo_track_id = new_yolo_obj.track_id
328
+ yolo_bbox = new_yolo_obj.bbox
329
+
330
+ for ind, sam_track_id in enumerate(out_obj_ids):
331
+ if (sam_track_id == yolo_track_id) or (out_frame_idx < (self.initial_frame_idx_by_track_id[sam_track_id] - NUM_LOOK_BACK_FRAMES)):
332
+ continue
333
+
334
+ sam_mask = (out_mask_logits[ind] > 0.0).cpu().numpy()
335
+ sam_bbox = self._get_bbox_from_mask(sam_mask)
336
+ if sam_bbox is None:
337
+ continue
338
+
339
+ # Flag the SAM prediction only if this prediction has been lost for many frames and SAM is trying to recover it
340
+ # (in which case we should keep the YOLO pred)
341
+
342
+ last_occurrence_idx = get_last_frame_occurrence(sam_track_ids_per_frame, sam_track_id, out_frame_idx)
343
+ if (out_frame_idx - last_occurrence_idx) >= (FPS * 0.8) and last_occurrence_idx >= 0:
344
+ print(sam_track_id, "long non-existence:", out_frame_idx - last_occurrence_idx, "frames")
345
+ long_non_existence_track_ids.append(sam_track_id)
346
+
347
+ iou = self._bbox_iou(yolo_bbox, sam_bbox)
348
+ if iou > 0.8:
349
+ # Reject the SAM tracked object if it was lost for many frames
350
+ if sam_track_id in long_non_existence_track_ids:
351
+ rejected_track_id = sam_track_id
352
+ self.track_ids_to_reject[rejected_track_id] = self.initial_frame_idx_by_track_id[sam_track_id]
353
+ print(f"Frame {out_frame_idx}. {yolo_track_id} & {sam_track_id} iou: {iou:.2f}. Reject: {rejected_track_id} (all) for long non-existence")
354
+ else:
355
+ # Otherwise, choose the obj with the latest yolo initial frame detection to reject
356
+ yolo_initial_frame = self.initial_frame_idx_by_track_id[yolo_track_id] # This is just the current frame
357
+ sam_initial_frame = self.initial_frame_idx_by_track_id[sam_track_id]
358
+ yolo_error = yolo_initial_frame >= sam_initial_frame
359
+ rejected_track_id = yolo_track_id if yolo_error else sam_track_id
360
+ reject_all_before_frame_idx = self.num_frames if yolo_error else sam_initial_frame
361
+ self.track_ids_to_reject[rejected_track_id] = reject_all_before_frame_idx
362
+ print(f"Frame {out_frame_idx}. {yolo_track_id} & {sam_track_id} iou: {iou:.2f}. Reject: {rejected_track_id} ({'before frame #' + str(reject_all_before_frame_idx) if not yolo_error else 'all'})")
363
+
364
+ self.video_segments[out_frame_idx] = {
365
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
366
+ for i, out_obj_id in enumerate(out_obj_ids)
367
+ }
368
+
369
+
370
+ def filter_sam_preds(self):
371
+ self.filtered_sam_preds = []
372
+ self.filtered_yolo_preds = []
373
+ self.num_sam_existence_frames_by_track_id = {}
374
+
375
+ def check_reject_pred(obj_id, frame_idx):
376
+ return obj_id in self.track_ids_to_reject and frame_idx < self.track_ids_to_reject[obj_id]
377
+
378
+ for frame_idx in range(self.num_frames):
379
+ self.filtered_sam_preds.append({})
380
+ self.filtered_yolo_preds.append({})
381
+
382
+ if frame_idx not in self.video_segments.keys():
383
+ continue
384
+
385
+ for obj_id, mask in self.video_segments[frame_idx].items():
386
+ # Only keep mask predictions that happen after the initial frame (with a small buffer)
387
+ # (SAM will try to predict them before the prompt frame, and will often get them wrong)
388
+ if (frame_idx >= self.initial_frame_idx_by_track_id[obj_id] - NUM_LOOK_BACK_FRAMES) and not check_reject_pred(obj_id, frame_idx):
389
+ self.filtered_sam_preds[frame_idx][obj_id] = mask
390
+
391
+ if mask.sum() > 0:
392
+ if self.num_sam_existence_frames_by_track_id.get(obj_id) is None:
393
+ self.num_sam_existence_frames_by_track_id[obj_id] = 0
394
+ self.num_sam_existence_frames_by_track_id[obj_id] += 1
395
+
396
+ for obj in self.filtered_objects:
397
+ if obj.initial_frame_idx == frame_idx and not check_reject_pred(obj.track_id, frame_idx):
398
+ self.filtered_yolo_preds[frame_idx][obj.track_id] = obj.bbox
399
+
400
+
401
+ def extract_final_bboxes(self):
402
+ # Extract the bboxes from the predicted masks
403
+ # Also filter out any overlaps. At this stage, if there are overlapping masks it is likely a fault on SAM's side
404
+ # (id switching/collecting) and there is not much we can do about it.
405
+ self.final_bboxes = []
406
+ rejected_ids = []
407
+ for frame_idx in range(self.num_frames):
408
+ self.final_bboxes.append({})
409
+ for obj_id, mask in self.filtered_sam_preds[frame_idx].items():
410
+ mask_box = self._get_bbox_from_mask(mask)
411
+ if mask_box is not None:
412
+ self.final_bboxes[frame_idx][obj_id] = mask_box
413
+
414
+ # Compute IOU overlap and eliminate duplicates
415
+ items_to_compare = list(self.final_bboxes[frame_idx].items()) + list(self.final_bboxes[frame_idx].items())
416
+ for (id0, bbox0), (id1, bbox1) in combinations(items_to_compare, 2):
417
+ if id0 == id1:
418
+ continue
419
+
420
+ if id0 not in self.final_bboxes[frame_idx] or id1 not in self.final_bboxes[frame_idx]: # Could've been removed previously
421
+ continue
422
+
423
+ if id0 in rejected_ids:
424
+ del self.final_bboxes[frame_idx][id0]
425
+ continue
426
+ if id1 in rejected_ids:
427
+ del self.final_bboxes[frame_idx][id1]
428
+ continue
429
+
430
+ iou = self._bbox_iou(bbox0, bbox1)
431
+ if iou > 0.8:
432
+ # Rejecting the prediction that exists for the least amount of frames throughout the video
433
+ frame_count0 = self.num_sam_existence_frames_by_track_id[id0]
434
+ frame_count1 = self.num_sam_existence_frames_by_track_id[id1]
435
+ rejected_id = id0 if frame_count0 < frame_count1 else id1
436
+
437
+ del self.final_bboxes[frame_idx][rejected_id]
438
+ rejected_ids.append(rejected_id)
439
+ # print(f"Frame {frame_idx}. {id0} & {id1} iou: {iou}. Rejecting {rejected_id}")
440
+
441
+ def _bbox_iou(self, box1, box2):
442
+ """
443
+ Compute the Intersection over Union (IoU) between two bounding boxes.
444
+
445
+ Parameters:
446
+ box1: (x1, y1, x2, y2) coordinates of the first box.
447
+ box2: (x1, y1, x2, y2) coordinates of the second box.
448
+
449
+ Returns:
450
+ iou: Intersection over Union value (0 to 1).
451
+ """
452
+ # Get the coordinates of the intersection rectangle
453
+ x1 = max(box1[0], box2[0])
454
+ y1 = max(box1[1], box2[1])
455
+ x2 = min(box1[2], box2[2])
456
+ y2 = min(box1[3], box2[3])
457
+
458
+ # Compute intersection area
459
+ inter_width = max(0, x2 - x1)
460
+ inter_height = max(0, y2 - y1)
461
+ inter_area = inter_width * inter_height
462
+
463
+ # Compute areas of both bounding boxes
464
+ box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
465
+ box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
466
+
467
+ # Compute union area
468
+ union_area = box1_area + box2_area - inter_area
469
+
470
+ # Compute IoU
471
+ iou = inter_area / union_area if union_area > 0 else 0
472
+
473
+ return iou
474
+
475
+ def _get_bbox_from_mask(self, mask):
476
+ mask_points = mask.nonzero()
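+ # mask has shape (1, H, W), so nonzero() returns (channel, row, col) index arrays: columns are x, rows are y.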
477
+ if len(mask_points[0]) == 0:
478
+ return None
479
+
480
+ x0 = min(mask_points[2])
481
+ y0 = min(mask_points[1])
482
+ x1 = max(mask_points[2])
483
+ y1 = max(mask_points[1])
484
+
485
+ return np.array([x0, y0, x1, y1])
486
+
487
+ from collections import defaultdict
488
+ class CVCOLORS:
489
+ RED = (0,0,255)
490
+ GREEN = (0,255,0)
491
+ BLUE = (255,0,0)
492
+ PURPLE = (247,44,200)
493
+ ORANGE = (44,162,247)
494
+ MINT = (239,255,66)
495
+ YELLOW = (2,255,250)
496
+ BROWN = (42,42,165)
497
+ LIME=(51,255,153)
498
+ GRAY=(128, 128, 128)
499
+ LIGHTPINK = (222,209,255)
500
+ LIGHTGREEN = (204,255,204)
501
+ LIGHTBLUE = (255,235,207)
502
+ LIGHTPURPLE = (255,153,204)
503
+ LIGHTRED = (204,204,255)
504
+ WHITE = (255,255,255)
505
+ BLACK = (0,0,0)
506
+
507
+ TRACKID_LOOKUP = defaultdict(lambda: (np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255)))
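+ # Each unseen track id gets a random BGR color on first lookup, which the defaultdict then keeps for that id.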
508
+ TYPE_LOOKUP = [BROWN, BLUE, PURPLE, RED, ORANGE, YELLOW, GREEN, LIGHTPURPLE, LIGHTPINK, LIGHTRED, GRAY]
509
+ REVERT_CHANNEL_F = lambda x: (x[2], x[1], x[0])
510
+
511
+
512
+ if __name__ == "__main__":
513
+
514
+ # Checkpoint paths
515
+ yolo_ckpt = "yolov8x.pt" # Will auto download with utltralytics
516
+
517
+ # NOTE: Need to download beforehand from https://github.com/facebookresearch/sam2
518
+ sam2_ckpt = "/network/scratch/x/xuolga/sam2/checkpoints/sam2.1_hiera_base_plus.pt"
519
+ sam2_cfg = "./configs/sam2.1/sam2.1_hiera_b+.yaml"
520
+
521
+ yolo_sam = YoloSamProcessor(yolo_ckpt, sam2_ckpt, sam2_cfg)
522
+ output_dir = f"/path/to/output_dir"
523
+ bboxes_out_dir = os.path.join(output_dir, "bboxes")
524
+ labels_out_dir = os.path.join(output_dir, "json")
525
+ os.makedirs(bboxes_out_dir, exist_ok=True)
526
+ os.makedirs(labels_out_dir, exist_ok=True)
527
+
528
+ video_dir_root = f"/path/to/dada2000_images_12fps/images"
529
+ videos = ['8_90038', '8_90002', '10_90019', '10_90045', '10_90027', '10_90029', '10_90082', '10_90021', '10_90064', '10_90083', '10_90141', '10_90139', '10_90034', '10_90134', '10_90056', '10_90169', '10_90040', '11_90109', '11_90162', '11_90202', '11_90142', '11_90180', '11_90161', '11_90091', '11_90189', '11_90002', '11_90192', '11_90221', '11_90181', '12_90007', '12_90042', '13_90002', '13_90008', '13_90007', '14_90012', '14_90018', '14_90014', '14_90027', '24_90017', '24_90005', '24_90006', '24_90011', '42_90021', '43_90013', '48_90078', '48_90031', '48_90001', '48_90075', '49_90030', '49_90021', '61_90016', '61_90004']
530
+
531
+ for video_name in tqdm(videos):
532
+ cat = video_name.split("_")[0]
533
+ video_dir = os.path.join(video_dir_root, cat, video_name)
534
+
535
+ if len(os.listdir(video_dir)) > 300:
536
+ print("Skipping:", video_name)
537
+ continue
538
+ out_data = yolo_sam(video_dir, rel_bbox=False)
539
+
540
+ # Output format:
541
+ """
542
+ out_data = [ # List of Dicts, each inner dict represents one frame of the video
543
+ {
544
+ "image_source": img1.jpg,
545
+ "labels":
546
+ [ # List of Dicts, each dict represents one tracked object
547
+ {'track_id': 0, 'name': 'car', 'class': 2, 'box': [x0, y0, x1, y1]}, # Obj 0
548
+ {...}, # Obj 1
549
+ ]
550
+ }, # Frame 0
551
+
552
+ {...}, # Frame 1
553
+ ]
554
+ """
555
+
556
+ # Plot final bboxes and save to file
557
+ out_json_path = os.path.join(labels_out_dir, f"{video_name}.json")
558
+ os.makedirs(os.path.dirname(out_json_path), exist_ok=True)
559
+ with open(out_json_path, 'w') as json_file:
560
+ json.dump(out_data, json_file, indent=1)
561
+
562
+ og_frames = sorted(os.listdir(video_dir))
563
+ out_bbox_path = os.path.join(bboxes_out_dir, video_name)
564
+ os.makedirs(out_bbox_path, exist_ok=True)
565
+ for frame_idx, frame_data in enumerate(out_data):
566
+ plt.figure(figsize=(9, 6))
567
+ plt.axis("off")
568
+ img = Image.open(os.path.join(video_dir, og_frames[frame_idx]))
569
+ plt.imshow(img)
570
+ for obj in frame_data["labels"]:
571
+ color = np.array(CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TYPE_LOOKUP[obj["class"]])) / 255.0
572
+ show_box(obj["box"], plt.gca(), label=str(obj["track_id"]), color=color)
573
+
574
+ frame_id_name = og_frames[frame_idx].split("_")[-1].split(".")[0]
575
+ plt.savefig(os.path.join(out_bbox_path, f"bboxes_frame_{frame_id_name}"))
+ plt.close()
576
+
577
+ # Save videos of bboxes
578
+ videos_out_dir = os.path.join(output_dir, "videos")
579
+ out_video_path = os.path.join(videos_out_dir, f"{video_name}_with_bboxes.mp4")
580
+ os.makedirs(os.path.dirname(out_video_path), exist_ok=True)
581
+ create_video_from_images(out_bbox_path, out_video_path, out_fps=12)
582
+
583
+
584
+
src/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .parser import parse_args
2
+ from .utils import encode_video_image, get_add_time_ids, get_samples, get_model_attr
src/utils/parser.py ADDED
@@ -0,0 +1,472 @@
1
+ # TODO: Clean up unused args
2
+ def parse_args():
3
+ import argparse
4
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
5
+
6
+ parser.add_argument(
7
+ "--project_name",
8
+ type=str,
9
+ default="car_crash",
10
+ help="Name of the project."
11
+ )
12
+ parser.add_argument(
13
+ "--pretrained_model_name_or_path",
14
+ type=str,
15
+ default=None,
16
+ required=True,
17
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
18
+ )
19
+ parser.add_argument(
20
+ "--finetuned_svd_path",
21
+ type=str,
22
+ default=None,
23
+ required=False,
24
+ help="Path to pretrained unet model. Used to override 'pretrained_model_name_or_path', and will default to this if not provided",
25
+ )
26
+ parser.add_argument(
27
+ "--revision",
28
+ type=str,
29
+ default=None,
30
+ required=False,
31
+ help="Revision of pretrained model identifier from huggingface.co/models.",
32
+ )
33
+ parser.add_argument(
34
+ "--variant",
35
+ type=str,
36
+ default=None,
37
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
38
+ )
39
+ parser.add_argument(
40
+ "--output_dir",
41
+ type=str,
42
+ default="out",
43
+ help="The output directory where the model predictions and checkpoints will be written.",
44
+ )
45
+ parser.add_argument(
46
+ "--logging_dir",
47
+ type=str,
48
+ default="logs",
49
+ help=(
50
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
51
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
52
+ ),
53
+ )
54
+ parser.add_argument(
55
+ "--eval_dir",
56
+ type=str,
57
+ default="eval",
58
+ help=(
59
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
60
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
61
+ ),
62
+ )
63
+ parser.add_argument(
64
+ "--learning_rate",
65
+ type=float,
66
+ default=1e-4,
67
+ help="Initial learning rate (after the potential warmup period) to use.",
68
+ )
69
+ parser.add_argument(
70
+ "--object_net_lr_factor",
71
+ type=float,
72
+ default=1.0,
73
+ help="Factor to scale the learning rate of the object network.",
74
+ )
75
+ parser.add_argument(
76
+ "--gradient_accumulation_steps",
77
+ type=int,
78
+ default=1,
79
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
80
+ )
81
+ parser.add_argument(
82
+ "--scale_lr",
83
+ action="store_true",
84
+ default=False,
85
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
86
+ )
87
+ parser.add_argument(
88
+ "--lr_scheduler",
89
+ type=str,
90
+ default="constant",
91
+ help=(
92
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
93
+ ' "constant", "constant_with_warmup"]'
94
+ ),
95
+ )
96
+ parser.add_argument(
97
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
98
+ )
99
+ parser.add_argument(
100
+ "--snr_gamma",
101
+ type=float,
102
+ default=None,
103
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
104
+ "More details here: https://arxiv.org/abs/2303.09556.",
105
+ )
106
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
107
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
108
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
109
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
110
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
111
+ parser.add_argument(
112
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
113
+ )
114
+ parser.add_argument("--num_train_epochs", type=int, default=1)
115
+ parser.add_argument(
116
+ "--max_train_steps",
117
+ type=int,
118
+ default=None,
119
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
120
+ )
121
+ parser.add_argument(
122
+ "--dataloader_num_workers",
123
+ type=int,
124
+ default=0,
125
+ help=(
126
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
127
+ ),
128
+ )
129
+ parser.add_argument(
130
+ "--mixed_precision",
131
+ type=str,
132
+ default=None,
133
+ choices=["no", "fp16", "bf16"],
134
+ help=(
135
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
136
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
137
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
138
+ ),
139
+ )
140
+ parser.add_argument(
141
+ "--rank",
142
+ type=int,
143
+ default=4,
144
+ help=("The dimension of the LoRA update matrices."),
145
+ )
146
+ parser.add_argument(
147
+ "--checkpointing_steps",
148
+ type=int,
149
+ default=500,
150
+ help=(
151
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
152
+ " training using `--resume_from_checkpoint`."
153
+ ),
154
+ )
155
+ parser.add_argument(
156
+ "--checkpointing_time",
157
+ type=int,
158
+ default=-1,
159
+ help=(
160
+ "Save a checkpoint of the training state every X seconds. Useful to save checkpoint when using jobs with short timeouts. If <= 0, will not save any checkpoints based on time."
161
+ " training using `--resume_from_checkpoint`."
162
+ ),
163
+ )
164
+ parser.add_argument(
165
+ "--checkpoints_total_limit",
166
+ type=int,
167
+ default=None,
168
+ help=("Max number of checkpoints to store."),
169
+ )
170
+ parser.add_argument(
171
+ "--resume_from_checkpoint",
172
+ type=str,
173
+ default=None,
174
+ help=(
175
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
176
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
177
+ ),
178
+ )
179
+ parser.add_argument(
180
+ "--enable_gradient_checkpointing",
181
+ action="store_true",
182
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
183
+ )
184
+ parser.add_argument(
185
+ "--data_root",
186
+ type=str,
187
+ default="./data",
188
+ help="The root directory of the dataset.",
189
+ )
190
+ parser.add_argument(
191
+ "--validation_steps",
192
+ type=int,
193
+ default=50,
194
+ help=(
195
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
196
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
197
+ ),
198
+ )
199
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
200
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
201
+ parser.add_argument(
202
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
203
+ )
204
+ parser.add_argument("--disable_object_condition", action="store_true", help="Whether or not to disable object condition.")
205
+ parser.add_argument("--encoder_hid_dim_type", type=str, default=None, help="The type of unet's input hidden.")
206
+ parser.add_argument(
207
+ "--report_to",
208
+ type=str,
209
+ default="wandb",
210
+ help=(
211
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
212
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
213
+ ),
214
+ )
215
+ parser.add_argument(
216
+ "--run_name",
217
+ type=str,
218
+ default=None,
219
+ help="The name of the run log.",
220
+ )
221
+ parser.add_argument(
222
+ "--prediction_type",
223
+ type=str,
224
+ default=None,
225
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
226
+ )
227
+ parser.add_argument(
228
+ "--guidance_scale",
229
+ type=float,
230
+ default=1.0,
231
+ help="(Image only). A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`."
232
+ )
233
+ parser.add_argument(
234
+ "--guidance_rescale",
235
+ type=float,
236
+ default=0.0,
237
+ help="(Image only). Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR."
238
+ )
239
+ parser.add_argument(
240
+ "--min_guidance_scale",
241
+ type=float,
242
+ default=1.0,
243
+ help="(Video generation only). The minimum guidance scale. Used for the classifier free guidance with first frame."
244
+ )
245
+ parser.add_argument(
246
+ "--max_guidance_scale",
247
+ type=float,
248
+ default=3.0,
249
+ help="(Video generation only). The maximum guidance scale. Used for the classifier free guidance with last frame."
250
+ )
251
+ parser.add_argument(
252
+ "--conditioning_dropout_prob",
253
+ type=float,
254
+ default=0.1,
255
+ help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
256
+ )
+ parser.add_argument(
+     "--clip_length",
+     type=int,
+     default=25,
+     help="The number of frames in a clip.",
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use an EMA model.")
+ parser.add_argument(
+     "--non_ema_revision",
+     type=str,
+     default=None,
+     required=False,
+     help=(
+         "Revision of the pretrained non-EMA model identifier. Must be a branch, tag or git identifier of the local or"
+         " remote repository specified with --pretrained_model_name_or_path."
+     ),
+ )
+ parser.add_argument(
+     "--backprop_temporal_blocks_start_iter",
+     type=int,
+     default=-1,
+     help="(Video generation only). The iteration at which to start backpropagating only into the temporal blocks (-1 means always backpropagate through the entire network).",
+ )
+ parser.add_argument(
+     "--enable_lora",
+     action="store_true",
+     default=False,
+     help="Enable LoRA.",
+ )
+ parser.add_argument(
+     "--add_bbox_frame_conditioning",
+     action="store_true",
+     default=False,
+     help="(Video generation only). Add bbox frame conditioning.",
+ )
+ parser.add_argument(
+     "--bbox_dropout_prob",
+     type=float,
+     default=0.0,
+     help="(Video generation only). Bbox dropout probability. Drops out the bbox conditionings.",
+ )
+ parser.add_argument(
+     "--num_demo_samples",
+     type=int,
+     default=1,
+     help="Number of samples to generate during the demo.",
+ )
+ parser.add_argument(
+     "--noise_aug_strength",
+     type=float,
+     default=0.02,
+     help="(Video generation only). Strength of the noise augmentation.",
+ )
+ parser.add_argument(
+     "--num_inference_steps",
+     type=int,
+     default=25,
+     help="Number of inference denoising steps.",
+ )
+ parser.add_argument(
+     "--conditioning_scale",
+     type=float,
+     default=1.0,
+     help="(Controlnet only). The scale of the conditioning."
+ )
+ parser.add_argument(
+     "--train_H",
+     type=int,
+     default=None,
+     help="For training, the height of the image to use. If None, the default height is 320 for video diffusion and 512 for image diffusion."
+ )
+ parser.add_argument(
+     "--train_W",
+     type=int,
+     default=None,
+     help="For training, the width of the image to use. If None, the default width is 512."
+ )
+ parser.add_argument(
+     "--eval_H",
+     type=int,
+     default=None,
+     help="For evaluation, the height of the image to use. If None, the default height is 320 for video diffusion and 512 for image diffusion."
+ )
+ parser.add_argument(
+     "--generate_bbox",
+     action="store_true",
+     default=False,
+     help="(Controlnet only). Whether to generate bboxes."
+ )
+ parser.add_argument(
+     "--predict_bbox",
+     action="store_true",
+     default=False,
+     help="(Video diffusion only). Whether to predict bboxes."
+ )
+ parser.add_argument(
+     "--evaluate_only",
+     action="store_true",
+     default=False,
+     help="Whether to only evaluate the model."
+ )
+ parser.add_argument(
+     "--demo_path",
+     default=None,
+     type=str,
+     help="Path where the demo samples are saved."
+ )
+ parser.add_argument(
+     "--pretrained_bbox_model",
+     default=None,
+     type=str,
+     help="Path to the pretrained bbox model."
+ )
+ parser.add_argument(
+     "--if_last_frame_trajectory",
+     action="store_true",
+     default=False,
+     help="Whether to use the last frame as the trajectory."
+ )
+ parser.add_argument(
+     "--fps",
+     type=int,
+     default=None,
+     help="FPS of the video. If None, a dataset-dependent default is set after parsing."
+ )
+ parser.add_argument(
+     "--num_cond_bbox_frames",
+     type=int,
+     default=3,
+     help="Number of conditioning bbox frames."
+ )
+ parser.add_argument(
+     "--wandb_entity",
+     type=str,
+     default="",
+     help="The wandb entity to log to.",
+ )
+ parser.add_argument(
+     "--non_overlapping_clips",
+     action="store_true",
+     default=False,
+     help="Load clips that do not overlap in the dataset.",
+ )
+ parser.add_argument(
+     "--disable_wandb",
+     action="store_true",
+     default=False,
+     help="Disable wandb logging.",
+ )
+ parser.add_argument(
+     "--empty_cuda_cache",
+     action="store_true",
+     default=False,
+     help="Periodically call `torch.cuda.empty_cache()` during the training loop. Helps reduce the chance of memory leaks and OOM errors.",
+ )
+ parser.add_argument(
+     "--bbox_masking_prob",
+     type=float,
+     default=0.0,
+     help="Per-agent bbox dropout probability. Randomly drops the bbox conditioning of individual agents in the scene.",
+ )
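+ # Illustrative sketch (assumed bbox shape (num_frames, num_agents, 4); not necessarily this
+ # repo's dataloader code): per-agent masking keeps each agent for the whole clip with
+ # probability 1 - bbox_masking_prob, e.g.
+ #     keep_agent = torch.rand(num_agents) >= args.bbox_masking_prob
+ #     bboxes = bboxes * keep_agent[None, :, None]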
+
+ parser.add_argument(
+     "--dataset_name",
+     nargs="+",
+     type=str,
+     default="russia_crash",
+     choices=["russia_crash", "nuscenes", "dada2000", "mmau", "bdd100k"],
+     help=(
+         "The name of the dataset to train on. A list of dataset names can be specified to merge them together."
+     ),
+ )
+ parser.add_argument(
+     "--use_action_conditioning",
+     action="store_true",
+     default=False,
+     help="Whether to use action conditioning (i.e. accident-type flags)."
+ )
+ parser.add_argument(
+     "--contiguous_bbox_masking_prob",
+     type=float,
+     default=0.0,
+     help="Probability of masking out bbox conditionings in a contiguous manner (i.e. all bboxes after a certain frame). A random frame in [0, N] is selected and all subsequent frames are masked.",
+ )
+ parser.add_argument(
+     "--contiguous_bbox_masking_start_ratio",
+     type=float,
+     default=0.0,
+     help="Fraction of the contiguous bbox maskings (as defined by --contiguous_bbox_masking_prob) that are applied from the start of the video instead of the end.",
+ )
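+ # Illustrative sketch (not necessarily this repo's exact logic): with probability
+ # contiguous_bbox_masking_prob a cutoff frame t is drawn uniformly, and bboxes are
+ # zeroed from t to the end; a contiguous_bbox_masking_start_ratio fraction of those
+ # maskings instead zero the frames before t, e.g.
+ #     if random.random() < args.contiguous_bbox_masking_prob:
+ #         t = random.randint(0, num_frames - 1)
+ #         if random.random() < args.contiguous_bbox_masking_start_ratio:
+ #             bboxes[:t] = 0
+ #         else:
+ #             bboxes[t:] = 0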
+ parser.add_argument(
+     "--val_on_first_step",
+     action="store_true",
+     default=False,
+     help="Whether to run validation on the first step (useful for testing)."
+ )
+
+ args = parser.parse_args()
+ # Default to using the same revision for the non-EMA model if not specified.
+ if args.non_ema_revision is None:
+     args.non_ema_revision = args.revision
+
+ if args.enable_lora:
+     args.backprop_temporal_blocks_start_iter = -1
+
+ if args.evaluate_only:
+     assert args.resume_from_checkpoint is not None, "Must provide a checkpoint to evaluate the model."
+
+ if args.fps is None:
+     # --dataset_name uses nargs="+", so it may be a list (or the string default); normalise before checking.
+     dataset_names = args.dataset_name if isinstance(args.dataset_name, (list, tuple)) else [args.dataset_name]
+     if "bdd100k" in dataset_names:
+         args.fps = 5
+     elif "nuscenes" in dataset_names:
+         args.fps = 7  # NOTE: Must match the fps in the nuscenes dataloader (=2 if no interpolation)
+     else:
+         args.fps = 7
+ return args
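+
+ # Hypothetical usage (script name and launcher are placeholders, not taken from this repo):
+ #     args = parse_args()
+ #     # e.g. launched as:
+ #     #   accelerate launch train.py --dataset_name mmau russia_crash \
+ #     #       --data_root ./data --clip_length 25 --use_action_conditioning \
+ #     #       --report_to wandb --run_name my_run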