Spaces:
Paused
Paused
Upload 52 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +4 -0
- LICENSE +21 -0
- config/a100l.yaml +13 -0
- config/multi_gpu.yaml +17 -0
- etc/architecture_figure.png +3 -0
- etc/genvid_57_11_04453.gif +3 -0
- etc/genvid_64_48_08386.gif +3 -0
- etc/genvid_87_21_08924.gif +3 -0
- one_sample.ipynb +19 -0
- run_gen_videos.py +44 -0
- scripts/controlnet_train_action_multigpu.sh +58 -0
- scripts/controlnet_train_action_singlegpu.sh +58 -0
- scripts/mmau_train_video_diffusion_multigpu.sh +52 -0
- scripts/mmau_train_video_diffusion_singlegpu.sh +53 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/datasets/base_dataset.py +189 -0
- src/datasets/bbox_utils.py +68 -0
- src/datasets/bdd100k_dataset.py +185 -0
- src/datasets/dada2000_dataset.py +339 -0
- src/datasets/dataset_factory.py +36 -0
- src/datasets/dataset_utils.py +50 -0
- src/datasets/merged_dataset.py +54 -0
- src/datasets/mmau_dataset.py +549 -0
- src/datasets/nuscenes_dataset.py +298 -0
- src/datasets/russia_crash_dataset.py +173 -0
- src/eval/README.md +120 -0
- src/eval/__pycache__/generate_samples.cpython-310.pyc +0 -0
- src/eval/generate_samples.py +394 -0
- src/eval/video_dataset.py +79 -0
- src/eval/video_quality_metrics_fvd_gt_rand.py +458 -0
- src/eval/video_quality_metrics_fvd_pair.py +349 -0
- src/eval/video_quality_metrics_jedi_gt_rand.py +91 -0
- src/eval/video_quality_metrics_jedi_pair.py +92 -0
- src/models/__init__.py +2 -0
- src/models/controlnet.py +391 -0
- src/models/unet_spatio_temporal_condition.py +169 -0
- src/pipelines/__init__.py +4 -0
- src/pipelines/pipeline_video_control.py +408 -0
- src/pipelines/pipeline_video_control_factor_guidance.py +615 -0
- src/pipelines/pipeline_video_control_nullmodel.py +406 -0
- src/pipelines/pipeline_video_diffusion.py +305 -0
- src/preprocess/README.md +105 -0
- src/preprocess/filter_dataset_tool.py +315 -0
- src/preprocess/preprocess_cap_dataset.py +224 -0
- src/preprocess/preprocess_dada_dataset.py +222 -0
- src/preprocess/preprocess_russia_dataset.py +168 -0
- src/preprocess/yolo_sam.py +584 -0
- src/utils/__init__.py +2 -0
- src/utils/parser.py +472 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
etc/architecture_figure.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
etc/genvid_57_11_04453.gif filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
etc/genvid_64_48_08386.gif filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
etc/genvid_87_21_08924.gif filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
config/a100l.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
deepspeed_config: {}
|
| 3 |
+
distributed_type: NO
|
| 4 |
+
fsdp_config: {}
|
| 5 |
+
machine_rank: 0
|
| 6 |
+
main_process_ip: null
|
| 7 |
+
main_process_port: null
|
| 8 |
+
main_training_function: main
|
| 9 |
+
mixed_precision: fp16
|
| 10 |
+
num_machines: 1
|
| 11 |
+
num_processes: 1
|
| 12 |
+
use_cpu: false
|
| 13 |
+
gpu_ids: all
|
config/multi_gpu.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: true
|
| 3 |
+
distributed_type: MULTI_GPU
|
| 4 |
+
downcast_bf16: 'no'
|
| 5 |
+
enable_cpu_affinity: false
|
| 6 |
+
gpu_ids: all
|
| 7 |
+
machine_rank: 0
|
| 8 |
+
main_training_function: main
|
| 9 |
+
mixed_precision: fp16
|
| 10 |
+
num_machines: 1
|
| 11 |
+
num_processes: 4
|
| 12 |
+
rdzv_backend: static
|
| 13 |
+
same_network: true
|
| 14 |
+
tpu_env: []
|
| 15 |
+
tpu_use_cluster: false
|
| 16 |
+
tpu_use_sudo: false
|
| 17 |
+
use_cpu: false
|
etc/architecture_figure.png
ADDED
|
Git LFS Details
|
etc/genvid_57_11_04453.gif
ADDED
|
Git LFS Details
|
etc/genvid_64_48_08386.gif
ADDED
|
Git LFS Details
|
etc/genvid_87_21_08924.gif
ADDED
|
Git LFS Details
|
one_sample.ipynb
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "dd192fe4",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": []
|
| 10 |
+
}
|
| 11 |
+
],
|
| 12 |
+
"metadata": {
|
| 13 |
+
"language_info": {
|
| 14 |
+
"name": "python"
|
| 15 |
+
}
|
| 16 |
+
},
|
| 17 |
+
"nbformat": 4,
|
| 18 |
+
"nbformat_minor": 5
|
| 19 |
+
}
|
run_gen_videos.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse #test
|
| 2 |
+
|
| 3 |
+
from src.eval.generate_samples import generate_samples
|
| 4 |
+
|
| 5 |
+
if __name__ == "__main__":
|
| 6 |
+
parser = argparse.ArgumentParser(description="Generate test samples from MMAU dataset")
|
| 7 |
+
parser.add_argument('--model_path', type=str, required=True, help='Model checkpoint used for generation')
|
| 8 |
+
parser.add_argument('--data_root', type=str, required=True, help='Dataset root path')
|
| 9 |
+
parser.add_argument('--output_path', type=str, default="./output_videos", help='Video output path')
|
| 10 |
+
parser.add_argument('--disable_null_model', action="store_true", default=False, help='For uncond noise preds, whether to use a null model')
|
| 11 |
+
parser.add_argument('--use_factor_guidance', action="store_true", default=False, help='')
|
| 12 |
+
parser.add_argument('--num_demo_samples', type=int, default=10, help='Number of samples to collect for generation')
|
| 13 |
+
parser.add_argument('--max_output_vids', type=int, default=200, help='Exit program once this many videos have been generated')
|
| 14 |
+
parser.add_argument('--num_gens_per_sample', type=int, default=1, help='Number videos to generate for each test case')
|
| 15 |
+
parser.add_argument('--eval_output', action="store_true", default=False, help='')
|
| 16 |
+
parser.add_argument('--seed', type=int, default=None, help='')
|
| 17 |
+
parser.add_argument('--dataset', type=str, default="mmau")
|
| 18 |
+
parser.add_argument(
|
| 19 |
+
"--bbox_mask_idx_batch",
|
| 20 |
+
nargs="+",
|
| 21 |
+
type=int,
|
| 22 |
+
default=[None],
|
| 23 |
+
choices=list(range(25+1)),
|
| 24 |
+
help="Where to start the masking, multiple values represent multiple different test cases for each sample",
|
| 25 |
+
)
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--force_action_type_batch",
|
| 28 |
+
nargs="+",
|
| 29 |
+
type=int,
|
| 30 |
+
default=[None],
|
| 31 |
+
choices=[0, 1, 2, 3, 4],
|
| 32 |
+
help="Which action type to force, multiple values represent multiple different test cases for each sample",
|
| 33 |
+
)
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--guidance_scales",
|
| 36 |
+
nargs="+",
|
| 37 |
+
type=int,
|
| 38 |
+
default=[(1, 9)],
|
| 39 |
+
help="Guidance progression to use, multiple values represent multiple different test cases for each sample",
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
args = parser.parse_args()
|
| 43 |
+
|
| 44 |
+
generate_samples(args)
|
scripts/controlnet_train_action_multigpu.sh
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# nvidia-smi | grep 'python' | awk '{ print $5 }' | xargs -n1 kill -9
|
| 2 |
+
|
| 3 |
+
# User-specific paths and settings
|
| 4 |
+
DATASET_PATH="<path/to/datasets>" # e.g., "/home/datasets_root"
|
| 5 |
+
NAME="<experiment_name>" # e.g., "box2video_experiment1"
|
| 6 |
+
OUT_DIR="<path/to/output>/${NAME}" # e.g., "/home/results/${NAME}"
|
| 7 |
+
PROJECT_NAME='<wandb_project_name>' # e.g., 'car_crash'
|
| 8 |
+
WANDB_ENTITY='<wandb_username>' # Your Weights & Biases username
|
| 9 |
+
PRETRAINED_MODEL_PATH="<path/to/pretrained/model>" # e.g., "/home/checkpoints_root/checkpoint"
|
| 10 |
+
|
| 11 |
+
# export HF_HOME=/path/to/root # Where the SVD pretrained models are/will be downloaded
|
| 12 |
+
|
| 13 |
+
# Create output directory
|
| 14 |
+
mkdir -p $OUT_DIR
|
| 15 |
+
|
| 16 |
+
# Save training script for reference
|
| 17 |
+
SCRIPT_PATH=$0
|
| 18 |
+
SAVE_SCRIPT_PATH="${OUT_DIR}/train_scripts.sh"
|
| 19 |
+
cp $SCRIPT_PATH $SAVE_SCRIPT_PATH
|
| 20 |
+
echo "Saved script to ${SAVE_SCRIPT_PATH}"
|
| 21 |
+
|
| 22 |
+
# Training command
|
| 23 |
+
CUDA_LAUNCH_BLOCKING=1 accelerate launch --config_file config/multi_gpu.yaml train_video_controlnet.py \
|
| 24 |
+
--run_name $NAME \
|
| 25 |
+
--data_root $DATASET_PATH \
|
| 26 |
+
--project_name $PROJECT_NAME \
|
| 27 |
+
--pretrained_model_name_or_path $PRETRAINED_MODEL_PATH \
|
| 28 |
+
--output_dir $OUT_DIR \
|
| 29 |
+
--variant fp16 \
|
| 30 |
+
--dataset_name mmau \
|
| 31 |
+
--train_batch_size 1 \
|
| 32 |
+
--learning_rate 4e-5 \
|
| 33 |
+
--checkpoints_total_limit 3 \
|
| 34 |
+
--checkpointing_steps 300 \
|
| 35 |
+
--checkpointing_time 10620 \
|
| 36 |
+
--gradient_accumulation_steps 5 \
|
| 37 |
+
--validation_steps 300 \
|
| 38 |
+
--enable_gradient_checkpointing \
|
| 39 |
+
--lr_scheduler constant \
|
| 40 |
+
--report_to wandb \
|
| 41 |
+
--seed 1234 \
|
| 42 |
+
--mixed_precision fp16 \
|
| 43 |
+
--clip_length 25 \
|
| 44 |
+
--fps 6 \
|
| 45 |
+
--min_guidance_scale 1.0 \
|
| 46 |
+
--max_guidance_scale 3.0 \
|
| 47 |
+
--noise_aug_strength 0.01 \
|
| 48 |
+
--num_demo_samples 15 \
|
| 49 |
+
--num_train_epochs 10 \
|
| 50 |
+
--dataloader_num_workers 0 \
|
| 51 |
+
--resume_from_checkpoint latest \
|
| 52 |
+
--wandb_entity $WANDB_ENTITY \
|
| 53 |
+
--train_H 320 \
|
| 54 |
+
--train_W 512 \
|
| 55 |
+
--use_action_conditioning \
|
| 56 |
+
--contiguous_bbox_masking_prob 0.75 \
|
| 57 |
+
--contiguous_bbox_masking_start_ratio 0.0 \
|
| 58 |
+
--val_on_first_step
|
scripts/controlnet_train_action_singlegpu.sh
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# nvidia-smi | grep 'python' | awk '{ print $5 }' | xargs -n1 kill -9
|
| 2 |
+
|
| 3 |
+
# User-specific paths and settings
|
| 4 |
+
DATASET_PATH="<path/to/datasets>" # e.g., "/home/datasets_root"
|
| 5 |
+
NAME="<experiment_name>" # e.g., "box2video_experiment1"
|
| 6 |
+
OUT_DIR="<path/to/output>/${NAME}" # e.g., "/home/results/${NAME}"
|
| 7 |
+
PROJECT_NAME='<wandb_project_name>' # e.g., 'car_crash'
|
| 8 |
+
WANDB_ENTITY='<wandb_username>' # Your Weights & Biases username
|
| 9 |
+
PRETRAINED_MODEL_PATH="<path/to/pretrained/model>" # e.g., "/path/to/pretrained/checkpoint"
|
| 10 |
+
|
| 11 |
+
# export HF_HOME=/path/to/root # Where the SVD pretrained models are/will be downloaded
|
| 12 |
+
|
| 13 |
+
# Create output directory
|
| 14 |
+
mkdir -p $OUT_DIR
|
| 15 |
+
|
| 16 |
+
# Save training script for reference
|
| 17 |
+
SCRIPT_PATH=$0
|
| 18 |
+
SAVE_SCRIPT_PATH="${OUT_DIR}/train_scripts.sh"
|
| 19 |
+
cp $SCRIPT_PATH $SAVE_SCRIPT_PATH
|
| 20 |
+
echo "Saved script to ${SAVE_SCRIPT_PATH}"
|
| 21 |
+
|
| 22 |
+
# Training command
|
| 23 |
+
CUDA_LAUNCH_BLOCKING=1 accelerate launch --config_file config/a100l.yaml train_video_controlnet.py \
|
| 24 |
+
--run_name $NAME \
|
| 25 |
+
--data_root $DATASET_PATH \
|
| 26 |
+
--project_name $PROJECT_NAME \
|
| 27 |
+
--pretrained_model_name_or_path $PRETRAINED_MODEL_PATH \
|
| 28 |
+
--output_dir $OUT_DIR \
|
| 29 |
+
--variant fp16 \
|
| 30 |
+
--dataset_name mmau \
|
| 31 |
+
--train_batch_size 1 \
|
| 32 |
+
--learning_rate 4e-5 \
|
| 33 |
+
--checkpoints_total_limit 3 \
|
| 34 |
+
--checkpointing_steps 300 \
|
| 35 |
+
--checkpointing_time 10620 \
|
| 36 |
+
--gradient_accumulation_steps 5 \
|
| 37 |
+
--validation_steps 300 \
|
| 38 |
+
--enable_gradient_checkpointing \
|
| 39 |
+
--lr_scheduler constant \
|
| 40 |
+
--report_to wandb \
|
| 41 |
+
--seed 1234 \
|
| 42 |
+
--mixed_precision fp16 \
|
| 43 |
+
--clip_length 25 \
|
| 44 |
+
--fps 6 \
|
| 45 |
+
--min_guidance_scale 1.0 \
|
| 46 |
+
--max_guidance_scale 3.0 \
|
| 47 |
+
--noise_aug_strength 0.01 \
|
| 48 |
+
--num_demo_samples 15 \
|
| 49 |
+
--num_train_epochs 10 \
|
| 50 |
+
--dataloader_num_workers 0 \
|
| 51 |
+
--resume_from_checkpoint latest \
|
| 52 |
+
--wandb_entity $WANDB_ENTITY \
|
| 53 |
+
--train_H 320 \
|
| 54 |
+
--train_W 512 \
|
| 55 |
+
--use_action_conditioning \
|
| 56 |
+
--contiguous_bbox_masking_prob 0.75 \
|
| 57 |
+
--contiguous_bbox_masking_start_ratio 0.0 \
|
| 58 |
+
--val_on_first_step
|
scripts/mmau_train_video_diffusion_multigpu.sh
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# nvidia-smi | grep 'python' | awk '{ print $5 }' | xargs -n1 kill -9
|
| 2 |
+
|
| 3 |
+
# User-specific paths and settings
|
| 4 |
+
DATASET_PATH="<path/to/datasets>" # e.g., "/home/datasets_root"
|
| 5 |
+
NAME="<experiment_name>" # e.g., "box2video_experiment1"
|
| 6 |
+
OUT_DIR="<path/to/output>/${NAME}" # e.g., "/home/results/${NAME}"
|
| 7 |
+
PROJECT_NAME='<wandb_project_name>' # e.g., 'car_crash'
|
| 8 |
+
WANDB_ENTITY='<wandb_username>' # Your Weights & Biases username
|
| 9 |
+
PRETRAINED_MODEL_PATH="stabilityai/stable-video-diffusion-img2vid-xt" # HuggingFace model ID
|
| 10 |
+
|
| 11 |
+
# export HF_HOME=/path/to/root # Where the SVD pretrained models are/will be downloaded
|
| 12 |
+
|
| 13 |
+
# Create output directory
|
| 14 |
+
mkdir -p $OUT_DIR
|
| 15 |
+
|
| 16 |
+
# Save training script for reference
|
| 17 |
+
SCRIPT_PATH=$0
|
| 18 |
+
SAVE_SCRIPT_PATH="${OUT_DIR}/train_scripts.sh"
|
| 19 |
+
cp $SCRIPT_PATH $SAVE_SCRIPT_PATH
|
| 20 |
+
echo "Saved script to ${SAVE_SCRIPT_PATH}"
|
| 21 |
+
|
| 22 |
+
# Training command
|
| 23 |
+
CUDA_LAUNCH_BLOCKING=1 accelerate launch --config_file config/multi_gpu.yaml train_video_diffusion.py \
|
| 24 |
+
--run_name $NAME \
|
| 25 |
+
--data_root $DATASET_PATH \
|
| 26 |
+
--project_name $PROJECT_NAME \
|
| 27 |
+
--pretrained_model_name_or_path $PRETRAINED_MODEL_PATH \
|
| 28 |
+
--output_dir $OUT_DIR \
|
| 29 |
+
--variant fp16 \
|
| 30 |
+
--dataset_name mmau \
|
| 31 |
+
--train_batch_size 1 \
|
| 32 |
+
--learning_rate 1e-5 \
|
| 33 |
+
--checkpoints_total_limit 3 \
|
| 34 |
+
--checkpointing_steps 300 \
|
| 35 |
+
--gradient_accumulation_steps 5 \
|
| 36 |
+
--validation_steps 300 \
|
| 37 |
+
--enable_gradient_checkpointing \
|
| 38 |
+
--lr_scheduler constant \
|
| 39 |
+
--report_to wandb \
|
| 40 |
+
--seed 1234 \
|
| 41 |
+
--mixed_precision fp16 \
|
| 42 |
+
--clip_length 25 \
|
| 43 |
+
--min_guidance_scale 1.0 \
|
| 44 |
+
--max_guidance_scale 3.0 \
|
| 45 |
+
--noise_aug_strength 0.01 \
|
| 46 |
+
--num_demo_samples 15 \
|
| 47 |
+
--backprop_temporal_blocks_start_iter -1 \
|
| 48 |
+
--num_train_epochs 30 \
|
| 49 |
+
--train_H 320 \
|
| 50 |
+
--train_W 512 \
|
| 51 |
+
--resume_from_checkpoint latest \
|
| 52 |
+
--wandb_entity $WANDB_ENTITY
|
scripts/mmau_train_video_diffusion_singlegpu.sh
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# nvidia-smi | grep 'python' | awk '{ print $5 }' | xargs -n1 kill -9
|
| 2 |
+
|
| 3 |
+
# User-specific paths and settings
|
| 4 |
+
DATASET_PATH="<path/to/datasets>" # e.g., "/home/datasets_root"
|
| 5 |
+
NAME="<experiment_name>" # e.g., "box2video_experiment1"
|
| 6 |
+
OUT_DIR="<path/to/output>/${NAME}" # e.g., "/home/results/${NAME}"
|
| 7 |
+
PROJECT_NAME='<wandb_project_name>' # e.g., 'car_crash'
|
| 8 |
+
WANDB_ENTITY='<wandb_username>' # Your Weights & Biases usernames
|
| 9 |
+
PRETRAINED_MODEL_PATH="stabilityai/stable-video-diffusion-img2vid-xt" # HuggingFace model ID
|
| 10 |
+
|
| 11 |
+
# export HF_HOME=/path/to/root # Where the SVD pretrained models are/will be downloaded
|
| 12 |
+
|
| 13 |
+
# Create output directory
|
| 14 |
+
mkdir -p $OUT_DIR
|
| 15 |
+
|
| 16 |
+
# Save training script for reference
|
| 17 |
+
SCRIPT_PATH=$0
|
| 18 |
+
SAVE_SCRIPT_PATH="${OUT_DIR}/train_scripts.sh"
|
| 19 |
+
cp $SCRIPT_PATH $SAVE_SCRIPT_PATH
|
| 20 |
+
echo "Saved script to ${SAVE_SCRIPT_PATH}"
|
| 21 |
+
|
| 22 |
+
# Training command
|
| 23 |
+
CUDA_LAUNCH_BLOCKING=1 accelerate launch --config_file config/a100l.yaml train_video_diffusion.py \
|
| 24 |
+
--run_name $NAME \
|
| 25 |
+
--data_root $DATASET_PATH \
|
| 26 |
+
--project_name $PROJECT_NAME \
|
| 27 |
+
--pretrained_model_name_or_path $PRETRAINED_MODEL_PATH \
|
| 28 |
+
--output_dir $OUT_DIR \
|
| 29 |
+
--variant fp16 \
|
| 30 |
+
--dataset_name mmau \
|
| 31 |
+
--train_batch_size 1 \
|
| 32 |
+
--learning_rate 1e-5 \
|
| 33 |
+
--checkpoints_total_limit 3 \
|
| 34 |
+
--checkpointing_steps 300 \
|
| 35 |
+
--gradient_accumulation_steps 5 \
|
| 36 |
+
--validation_steps 300 \
|
| 37 |
+
--enable_gradient_checkpointing \
|
| 38 |
+
--lr_scheduler constant \
|
| 39 |
+
--report_to wandb \
|
| 40 |
+
--seed 1234 \
|
| 41 |
+
--mixed_precision fp16 \
|
| 42 |
+
--clip_length 25 \
|
| 43 |
+
--min_guidance_scale 1.0 \
|
| 44 |
+
--max_guidance_scale 3.0 \
|
| 45 |
+
--noise_aug_strength 0.01 \
|
| 46 |
+
--bbox_dropout_prob 0.1 \
|
| 47 |
+
--num_demo_samples 15 \
|
| 48 |
+
--backprop_temporal_blocks_start_iter -1 \
|
| 49 |
+
--num_train_epochs 30 \
|
| 50 |
+
--train_H 320 \
|
| 51 |
+
--train_W 512 \
|
| 52 |
+
--resume_from_checkpoint latest \
|
| 53 |
+
--wandb_entity $WANDB_ENTITY
|
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (153 Bytes). View file
|
|
|
src/datasets/base_dataset.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from torchvision import transforms
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
from src.datasets.bbox_utils import plot_2d_bbox
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BaseDataset:
|
| 11 |
+
|
| 12 |
+
def __init__(self,
|
| 13 |
+
root='./datasets',
|
| 14 |
+
train=True,
|
| 15 |
+
clip_length=25,
|
| 16 |
+
# orig_width=None, orig_height=None,
|
| 17 |
+
resize_width=512, resize_height=320,
|
| 18 |
+
non_overlapping_clips=False,
|
| 19 |
+
bbox_masking_prob=0.0,
|
| 20 |
+
sample_clip_from_end=True,
|
| 21 |
+
ego_only=False,
|
| 22 |
+
ignore_labels=False):
|
| 23 |
+
|
| 24 |
+
self.root = root
|
| 25 |
+
self.train = train
|
| 26 |
+
self.clip_length = clip_length
|
| 27 |
+
# self.orig_width = orig_width
|
| 28 |
+
# self.orig_height = orig_height
|
| 29 |
+
self.resize_width = resize_width
|
| 30 |
+
self.resize_height = resize_height
|
| 31 |
+
|
| 32 |
+
self.non_overlapping_clips = non_overlapping_clips
|
| 33 |
+
self.bbox_masking_prob = bbox_masking_prob
|
| 34 |
+
self.sample_clip_from_end = sample_clip_from_end
|
| 35 |
+
self.ego_only = ego_only
|
| 36 |
+
self.ignore_labels = ignore_labels
|
| 37 |
+
|
| 38 |
+
self.data_split = 'train' if self.train else 'val'
|
| 39 |
+
|
| 40 |
+
# Image transforms
|
| 41 |
+
self.transform = transforms.Compose([
|
| 42 |
+
transforms.Resize((self.resize_height, self.resize_width)),
|
| 43 |
+
transforms.ToTensor(),
|
| 44 |
+
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), # map from [0,1] to [-1,1]
|
| 45 |
+
])
|
| 46 |
+
self.revert_transform = transforms.Compose([
|
| 47 |
+
transforms.Normalize(mean=(-1, -1, -1), std=(2, 2, 2)),
|
| 48 |
+
])
|
| 49 |
+
|
| 50 |
+
self.image_files = [] # Contains the paths of all the images in the dataset
|
| 51 |
+
self.clip_list = [] # Contains a list of image indices for each clip
|
| 52 |
+
self.frame_labels = [] # For each image file, contains a list of dicts (labels of each object in the frame)
|
| 53 |
+
|
| 54 |
+
self.disable_cache = True
|
| 55 |
+
if self.disable_cache:
|
| 56 |
+
print("Bbox image caching disabled")
|
| 57 |
+
|
| 58 |
+
def __len__(self):
|
| 59 |
+
return len(self.clip_list)
|
| 60 |
+
|
| 61 |
+
def __getitem__(self, index):
|
| 62 |
+
return self._getclipitem(index)
|
| 63 |
+
|
| 64 |
+
def _getclipitem(self, index):
|
| 65 |
+
frames_indices = self.clip_list[index]
|
| 66 |
+
|
| 67 |
+
images, labels, bboxes, image_paths = [], [], [], []
|
| 68 |
+
masked_track_ids = self._get_masked_track_ids(frames_indices)
|
| 69 |
+
for frame_idx in frames_indices:
|
| 70 |
+
ret_dict = self._getimageitem(frame_idx, masked_track_ids=masked_track_ids)
|
| 71 |
+
images.append(ret_dict["image"])
|
| 72 |
+
labels.append(ret_dict["labels"])
|
| 73 |
+
bboxes.append(ret_dict["bbox_image"])
|
| 74 |
+
image_paths.append(ret_dict["image_path"])
|
| 75 |
+
images = torch.stack(images)
|
| 76 |
+
prompt = "" # NOTE: Currently not supporting prompts
|
| 77 |
+
|
| 78 |
+
action_type = 0 # Assume "normal" driving when unspecified
|
| 79 |
+
if hasattr(self, "action_type_list"):
|
| 80 |
+
action_type = self.action_type_list[index]
|
| 81 |
+
|
| 82 |
+
vid_name = self.image_files[frames_indices[0]].split("/")[-1].split(".")[0][:-5]
|
| 83 |
+
|
| 84 |
+
if not self.ignore_labels:
|
| 85 |
+
bboxes = torch.stack(bboxes)
|
| 86 |
+
|
| 87 |
+
# NOTE: Keys are plural because this makes more sense when batches get collated
|
| 88 |
+
ret_dict = {"clips": images,
|
| 89 |
+
"prompts": prompt,
|
| 90 |
+
"indices": index,
|
| 91 |
+
"bbox_images": bboxes,
|
| 92 |
+
"action_type": action_type,
|
| 93 |
+
"vid_name": vid_name,
|
| 94 |
+
"image_paths": image_paths
|
| 95 |
+
}
|
| 96 |
+
else:
|
| 97 |
+
ret_dict = {"clips": images,
|
| 98 |
+
"prompts": prompt,
|
| 99 |
+
"indices": index}
|
| 100 |
+
|
| 101 |
+
return ret_dict
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _getimageitem(self, frame_index, masked_track_ids=None):
|
| 105 |
+
# Get the image
|
| 106 |
+
image_file = self.image_files[frame_index]
|
| 107 |
+
image = Image.open(image_file)
|
| 108 |
+
image = self.transform(image)
|
| 109 |
+
|
| 110 |
+
if not self.ignore_labels:
|
| 111 |
+
|
| 112 |
+
# Get the labels
|
| 113 |
+
labels = self.frame_labels[frame_index]
|
| 114 |
+
|
| 115 |
+
# Get the bbox image (from cache or draw new one)
|
| 116 |
+
image_filename = image_file.split('/')[-1].split('.')[0]
|
| 117 |
+
cache_filename = f"{image_filename}_bboxes"
|
| 118 |
+
cache_file = os.path.join(self.bbox_image_dir, f"{cache_filename}.jpg")
|
| 119 |
+
redraw_for_masked_agents = masked_track_ids is not None and len(masked_track_ids) > 0
|
| 120 |
+
if not os.path.exists(cache_file) or redraw_for_masked_agents or self.disable_cache:
|
| 121 |
+
bbox_im = self._draw_bbox(labels, cache_img_name=cache_filename, masked_track_ids=masked_track_ids, disable_cache=redraw_for_masked_agents or self.disable_cache)
|
| 122 |
+
else:
|
| 123 |
+
bbox_im = Image.open(cache_file)
|
| 124 |
+
bbox_im = self.transform(bbox_im)
|
| 125 |
+
else:
|
| 126 |
+
labels = None
|
| 127 |
+
bbox_im = None
|
| 128 |
+
|
| 129 |
+
ret_dict = {"image": image,
|
| 130 |
+
"image_path": image_file,
|
| 131 |
+
"labels": labels,
|
| 132 |
+
"frame_index": frame_index,
|
| 133 |
+
"bbox_image": bbox_im}
|
| 134 |
+
|
| 135 |
+
return ret_dict
|
| 136 |
+
|
| 137 |
+
def _draw_bbox(self, frame_labels, cache_img_name=None, masked_track_ids=None, disable_cache=False):
|
| 138 |
+
canvas = torch.zeros((3, self.orig_height, self.orig_width))
|
| 139 |
+
bbox_im = plot_2d_bbox(canvas, frame_labels, show_track_color=True, masked_track_ids=masked_track_ids)
|
| 140 |
+
transform = transforms.Compose([transforms.ToPILImage()])
|
| 141 |
+
bbox_pil = transform(bbox_im)
|
| 142 |
+
|
| 143 |
+
if cache_img_name is not None and not disable_cache:
|
| 144 |
+
if not os.path.exists(self.bbox_image_dir):
|
| 145 |
+
os.makedirs(self.bbox_image_dir, exist_ok=True)
|
| 146 |
+
image_path = os.path.join(self.bbox_image_dir, f"{cache_img_name}.jpg")
|
| 147 |
+
bbox_pil.save(image_path)
|
| 148 |
+
print("Cached bbox file:", image_path)
|
| 149 |
+
|
| 150 |
+
bbox_im = self.transform(bbox_pil)
|
| 151 |
+
return bbox_im
|
| 152 |
+
|
| 153 |
+
def _get_masked_track_ids(self, frames_indices):
|
| 154 |
+
masked_track_ids = []
|
| 155 |
+
if self.bbox_masking_prob > 0:
|
| 156 |
+
# Find all the trackIDs in the clip, randomly select some to mask and exclude from the bbox rendering
|
| 157 |
+
all_track_ids = set()
|
| 158 |
+
for frame_idx in frames_indices:
|
| 159 |
+
frame_labels = self.frame_labels[frame_idx] #self._parse_label(self.image_files[frame])
|
| 160 |
+
for label in frame_labels:
|
| 161 |
+
track_id = label['track_id']
|
| 162 |
+
if track_id not in all_track_ids and random.random() <= self.bbox_masking_prob:
|
| 163 |
+
# Mask out this agent
|
| 164 |
+
masked_track_ids.append(track_id)
|
| 165 |
+
all_track_ids.add(label['track_id'])
|
| 166 |
+
|
| 167 |
+
return masked_track_ids
|
| 168 |
+
|
| 169 |
+
def get_frame_file_by_index(self, index, timestep=0):
|
| 170 |
+
frames = self.clip_list[index]
|
| 171 |
+
if timestep is None:
|
| 172 |
+
ret = []
|
| 173 |
+
for frame in frames:
|
| 174 |
+
ret.append(self.image_files[frame])
|
| 175 |
+
return ret
|
| 176 |
+
return self.image_files[frames[timestep]]
|
| 177 |
+
|
| 178 |
+
def get_bbox_image_file_by_index(self, index=None, image_file=None):
|
| 179 |
+
if image_file is None:
|
| 180 |
+
image_file = self.get_frame_file_by_index(index)
|
| 181 |
+
|
| 182 |
+
clip_name = image_file.split("/")[-2]
|
| 183 |
+
return image_file.replace(self.image_dir, self.bbox_image_dir).replace('/'+clip_name+'/', '/').replace(".jpg", "_bboxes.jpg")
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
|
src/datasets/bbox_utils.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
|
| 5 |
+
class CVCOLORS:
|
| 6 |
+
RED = (0,0,255)
|
| 7 |
+
GREEN = (0,255,0)
|
| 8 |
+
BLUE = (255,0,0)
|
| 9 |
+
PURPLE = (247,44,200)
|
| 10 |
+
ORANGE = (44,162,247)
|
| 11 |
+
MINT = (239,255,66)
|
| 12 |
+
YELLOW = (2,255,250)
|
| 13 |
+
BROWN = (42,42,165)
|
| 14 |
+
LIME=(51,255,153)
|
| 15 |
+
GRAY=(128, 128, 128)
|
| 16 |
+
LIGHTPINK = (222,209,255)
|
| 17 |
+
LIGHTGREEN = (204,255,204)
|
| 18 |
+
LIGHTBLUE = (255,235,207)
|
| 19 |
+
LIGHTPURPLE = (255,153,204)
|
| 20 |
+
LIGHTRED = (204,204,255)
|
| 21 |
+
WHITE = (255,255,255)
|
| 22 |
+
BLACK = (0,0,0)
|
| 23 |
+
|
| 24 |
+
TRACKID_LOOKUP = defaultdict(lambda: (np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255)))
|
| 25 |
+
TYPE_LOOKUP = [BROWN, BLUE, PURPLE, RED, ORANGE, YELLOW, LIGHTPINK, LIGHTPURPLE, GRAY, LIGHTRED, GREEN]
|
| 26 |
+
REVERT_CHANNEL_F = lambda x: (x[2], x[1], x[0])
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# TODO: This could be moved to base dataset class (?)
|
| 30 |
+
def plot_2d_bbox(img, labels, show_track_color=False, channel_first=True, rgb2bgr=False, box_color=None, masked_track_ids=None, crash_border=False):
|
| 31 |
+
|
| 32 |
+
if channel_first:
|
| 33 |
+
img = img.permute((1, 2, 0)).detach().cpu().numpy().copy()*255
|
| 34 |
+
else:
|
| 35 |
+
img = img.detach().cpu().numpy().copy()*255
|
| 36 |
+
|
| 37 |
+
if rgb2bgr:
|
| 38 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
| 39 |
+
|
| 40 |
+
masked_track_ids = masked_track_ids or []
|
| 41 |
+
|
| 42 |
+
for i, label_info in enumerate(labels):
|
| 43 |
+
track_id = label_info['track_id']
|
| 44 |
+
if track_id in masked_track_ids:
|
| 45 |
+
continue
|
| 46 |
+
|
| 47 |
+
box_2d = label_info['bbox']
|
| 48 |
+
|
| 49 |
+
if not show_track_color:
|
| 50 |
+
type_color_i = np.array(CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TYPE_LOOKUP[label_info['class_id']])) / 255 if box_color is None else box_color
|
| 51 |
+
track_color_i = CVCOLORS.REVERT_CHANNEL_F((1, 1, 1))
|
| 52 |
+
|
| 53 |
+
cv2.rectangle(img, (int(box_2d[0]), int(box_2d[1])), (int(box_2d[2]), int(box_2d[3])), type_color_i, cv2.FILLED)
|
| 54 |
+
cv2.rectangle(img, (int(box_2d[0]), int(box_2d[1])), (int(box_2d[2]), int(box_2d[3])), track_color_i, 2)
|
| 55 |
+
else:
|
| 56 |
+
type_color_i = np.array(CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TYPE_LOOKUP[label_info['class_id']])) / 255 if box_color is None else box_color
|
| 57 |
+
track_color_i = CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TRACKID_LOOKUP[label_info['track_id']])
|
| 58 |
+
|
| 59 |
+
dim = min(box_2d[2] - box_2d[0], box_2d[3] - box_2d[1])
|
| 60 |
+
b_thick = min(max(dim * 0.1, 2), 8)
|
| 61 |
+
cv2.rectangle(img, (int(box_2d[0]), int(box_2d[1])), (int(box_2d[2]), int(box_2d[3])), type_color_i, cv2.FILLED)
|
| 62 |
+
cv2.rectangle(img, (int(box_2d[0] + b_thick), int(box_2d[1] + b_thick)), (int(box_2d[2] - b_thick), int(box_2d[3] - b_thick)), track_color_i, cv2.FILLED)
|
| 63 |
+
|
| 64 |
+
if crash_border:
|
| 65 |
+
thickness = 20
|
| 66 |
+
cv2.rectangle(img, (0, 0), (img.shape[1], img.shape[0]), color=(0, 1, 0), thickness=thickness, lineType=cv2.LINE_8)
|
| 67 |
+
|
| 68 |
+
return img
|
src/datasets/bdd100k_dataset.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_dataset import BaseDataset
|
| 2 |
+
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BDD100KDataset(BaseDataset):
|
| 9 |
+
CLASS_NAME_TO_ID = {
|
| 10 |
+
'pedestrian': 1,
|
| 11 |
+
'rider': 2,
|
| 12 |
+
'car': 3,
|
| 13 |
+
'truck': 4,
|
| 14 |
+
'bus': 5,
|
| 15 |
+
'train': 6,
|
| 16 |
+
'motorcycle': 7,
|
| 17 |
+
'bicycle': 8,
|
| 18 |
+
'traffic light': 9,
|
| 19 |
+
'traffic sign': 10,
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
TO_COCO_LABELS = {
|
| 23 |
+
1: 0,
|
| 24 |
+
2: 0,
|
| 25 |
+
3: 2,
|
| 26 |
+
4: 7,
|
| 27 |
+
5: 5,
|
| 28 |
+
6: 6,
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
TO_IMAGE_DIR = 'images/track'
|
| 32 |
+
TO_BBOX_DIR = 'bboxes/track'
|
| 33 |
+
TO_LABEL_DIR = 'labels'
|
| 34 |
+
TO_BBOX_LABELS = 'labels/box_track_20'
|
| 35 |
+
TO_SEG_LABELS = 'labels/seg_track_20/colormaps'
|
| 36 |
+
TO_POSE_LABELS = 'labels/pose_21'
|
| 37 |
+
|
| 38 |
+
def __init__(self,
|
| 39 |
+
root='./datasets',
|
| 40 |
+
train=True,
|
| 41 |
+
clip_length=25,
|
| 42 |
+
#orig_height=720, orig_width=1280, # TODO: Define this (and use it)
|
| 43 |
+
resize_height=320, resize_width=512,
|
| 44 |
+
non_overlapping_clips=False,
|
| 45 |
+
bbox_masking_prob=0.0,
|
| 46 |
+
sample_clip_from_end=True,
|
| 47 |
+
ego_only=False,
|
| 48 |
+
ignore_labels=False,
|
| 49 |
+
use_preplotted_bbox=True,
|
| 50 |
+
specific_samples=None,
|
| 51 |
+
specific_categories=None,
|
| 52 |
+
force_clip_type=None):
|
| 53 |
+
|
| 54 |
+
super(BDD100KDataset, self).__init__(root=root,
|
| 55 |
+
train=train,
|
| 56 |
+
clip_length=clip_length,
|
| 57 |
+
resize_height=resize_height,
|
| 58 |
+
resize_width=resize_width,
|
| 59 |
+
non_overlapping_clips=non_overlapping_clips,
|
| 60 |
+
bbox_masking_prob=bbox_masking_prob,
|
| 61 |
+
sample_clip_from_end=sample_clip_from_end,
|
| 62 |
+
ego_only=ego_only,
|
| 63 |
+
ignore_labels=ignore_labels)
|
| 64 |
+
|
| 65 |
+
self.MAX_BOXES_PER_DATA = 30
|
| 66 |
+
self._location = 'train' if self.train else 'val'
|
| 67 |
+
self.version = 'bdd100k'
|
| 68 |
+
self.use_preplotted_bbox = use_preplotted_bbox
|
| 69 |
+
|
| 70 |
+
self.image_dir = os.path.join(self.root, self.version, BDD100KDataset.TO_IMAGE_DIR, self._location)
|
| 71 |
+
self.bbox_label_dir = os.path.join(self.root, self.version, BDD100KDataset.TO_BBOX_LABELS, self._location)
|
| 72 |
+
self.bbox_image_dir = os.path.join(self.root, self.version, BDD100KDataset.TO_BBOX_DIR, self._location)
|
| 73 |
+
|
| 74 |
+
if specific_categories is not None:
|
| 75 |
+
print("BDD100k does not support `specific_categories`")
|
| 76 |
+
if force_clip_type is not None:
|
| 77 |
+
print("BDD100k does not support `force_clip_type`")
|
| 78 |
+
self.specific_samples = specific_samples
|
| 79 |
+
if self.specific_samples is not None:
|
| 80 |
+
print("Only loading specific samples:", self.specific_samples)
|
| 81 |
+
|
| 82 |
+
listed_image_dir = os.listdir(self.image_dir)
|
| 83 |
+
try:
|
| 84 |
+
listed_image_dir.remove('pred')
|
| 85 |
+
except:
|
| 86 |
+
pass
|
| 87 |
+
self.clip_folders = sorted(listed_image_dir)
|
| 88 |
+
self.clip_folder_lengths = {k:len(os.listdir(os.path.join(self.image_dir, k))) for k in self.clip_folders}
|
| 89 |
+
|
| 90 |
+
for l in self.clip_folder_lengths.values():
|
| 91 |
+
assert l >= self.clip_length, f'clip length {self.clip_length} is too long for clip folder length {l}'
|
| 92 |
+
|
| 93 |
+
self._collect_clips()
|
| 94 |
+
|
| 95 |
+
def _collect_clips(self):
|
| 96 |
+
print("Collecting dataset clips...")
|
| 97 |
+
|
| 98 |
+
for clip_folder in self.clip_folders:
|
| 99 |
+
clip_path = os.path.join(self.image_dir, clip_folder)
|
| 100 |
+
clip_frames = sorted(os.listdir(clip_path))
|
| 101 |
+
|
| 102 |
+
if self.specific_samples is not None and clip_folder not in self.specific_samples:
|
| 103 |
+
continue
|
| 104 |
+
|
| 105 |
+
# Add all images to image_files
|
| 106 |
+
image_indices = []
|
| 107 |
+
for frame in clip_frames:
|
| 108 |
+
self.image_files.append(os.path.join(clip_path, frame))
|
| 109 |
+
image_indices.append(len(self.image_files)-1)
|
| 110 |
+
|
| 111 |
+
# Create clips of length clip_length
|
| 112 |
+
if self.clip_length is not None:
|
| 113 |
+
# Collect clips as overlapping clips (i.e. A video with 30 frames will yield 5 25-frame clips)
|
| 114 |
+
for start_image_idx in range(0, len(clip_frames) - self.clip_length + 1):
|
| 115 |
+
end_image_idx = start_image_idx + self.clip_length
|
| 116 |
+
clip_indices = image_indices[start_image_idx:end_image_idx]
|
| 117 |
+
self.clip_list.append(clip_indices)
|
| 118 |
+
|
| 119 |
+
def _parse_label(self, label_file, frame_id):
|
| 120 |
+
target = []
|
| 121 |
+
with open(label_file, 'r') as f:
|
| 122 |
+
label = json.load(f)
|
| 123 |
+
frame_i = int(frame_id[-11:-4])-1
|
| 124 |
+
assert frame_id == label[frame_i]['name']
|
| 125 |
+
for obj in label[frame_i-1]['labels']:
|
| 126 |
+
if obj['category'] not in BDD100KDataset.CLASS_NAME_TO_ID:
|
| 127 |
+
continue
|
| 128 |
+
target.append({
|
| 129 |
+
'frame_name': frame_id,
|
| 130 |
+
'track_id': int(obj['id']),
|
| 131 |
+
'bbox': [obj['box2d']['x1'], obj['box2d']['y1'], obj['box2d']['x2'], obj['box2d']['y2']],
|
| 132 |
+
'class_id': BDD100KDataset.CLASS_NAME_TO_ID[obj['category']],
|
| 133 |
+
'class_name': obj['category'],
|
| 134 |
+
})
|
| 135 |
+
if len(target) >= self.MAX_BOXES_PER_DATA:
|
| 136 |
+
break
|
| 137 |
+
return target
|
| 138 |
+
|
| 139 |
+
def _getimageitem(self, frame_index, masked_track_ids=None):
|
| 140 |
+
# Get the image
|
| 141 |
+
image_file = self.image_files[frame_index]
|
| 142 |
+
image = Image.open(image_file)
|
| 143 |
+
image = self.transform(image)
|
| 144 |
+
|
| 145 |
+
if not self.ignore_labels:
|
| 146 |
+
# Get the labels
|
| 147 |
+
clip_id = image_file[:image_file.rfind('/')]
|
| 148 |
+
clip_id = clip_id[clip_id.rfind('/')+1:]
|
| 149 |
+
label_file = os.path.join(self.bbox_label_dir, f'{clip_id}.json')
|
| 150 |
+
frame_id = image_file[image_file.rfind('/')+1:]
|
| 151 |
+
labels = self._parse_label(label_file, frame_id)
|
| 152 |
+
|
| 153 |
+
# Get the bbox image
|
| 154 |
+
if self.use_preplotted_bbox:
|
| 155 |
+
bbox_file = self.get_bbox_image_file_by_index(image_file=image_file)
|
| 156 |
+
bbox_im = Image.open(bbox_file)
|
| 157 |
+
bbox_im = self.transform(bbox_im)
|
| 158 |
+
else:
|
| 159 |
+
bbox_im = self._draw_bbox(labels, masked_track_ids=masked_track_ids)
|
| 160 |
+
else:
|
| 161 |
+
labels = None
|
| 162 |
+
bbox_im = None
|
| 163 |
+
|
| 164 |
+
ret_dict = {"image": image,
|
| 165 |
+
"image_path": image_file,
|
| 166 |
+
"labels": labels,
|
| 167 |
+
"frame_index": frame_index,
|
| 168 |
+
"bbox_image": bbox_im}
|
| 169 |
+
|
| 170 |
+
return ret_dict
|
| 171 |
+
|
| 172 |
+
def get_bbox_image_file_by_index(self, index=None, image_file=None):
|
| 173 |
+
if image_file is None:
|
| 174 |
+
image_file = self.get_image_file_by_index(index)
|
| 175 |
+
|
| 176 |
+
return image_file.replace(BDD100KDataset.TO_IMAGE_DIR, BDD100KDataset.TO_BBOX_DIR)
|
| 177 |
+
|
| 178 |
+
def get_image_file_by_index(self, index):
|
| 179 |
+
return self.image_files[index]
|
| 180 |
+
|
| 181 |
+
def __len__(self):
|
| 182 |
+
return len(self.clip_list) if self.clip_length is not None else len(self.image_files)
|
| 183 |
+
|
| 184 |
+
if __name__ == "__init__":
|
| 185 |
+
dataset = BDD100KDataset()
|
src/datasets/dada2000_dataset.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import csv
|
| 5 |
+
|
| 6 |
+
from src.datasets.base_dataset import BaseDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DADA2000Dataset(BaseDataset):
|
| 10 |
+
|
| 11 |
+
CLASS_NAME_TO_ID = {
|
| 12 |
+
'person': 1,
|
| 13 |
+
'car': 3,
|
| 14 |
+
'truck': 4,
|
| 15 |
+
'bus': 5,
|
| 16 |
+
'train': 6,
|
| 17 |
+
'motorcycle': 7,
|
| 18 |
+
'bicycle': 8,
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
def __init__(self,
|
| 22 |
+
root='./datasets',
|
| 23 |
+
train=True,
|
| 24 |
+
clip_length=25,
|
| 25 |
+
orig_height=660, orig_width=1056,
|
| 26 |
+
resize_height=320, resize_width=512,
|
| 27 |
+
non_overlapping_clips=False,
|
| 28 |
+
bbox_masking_prob=0.0,
|
| 29 |
+
sample_clip_from_end=True,
|
| 30 |
+
ego_only=False,
|
| 31 |
+
specific_samples=None):
|
| 32 |
+
|
| 33 |
+
super(DADA2000Dataset, self).__init__(root=root,
|
| 34 |
+
train=train,
|
| 35 |
+
clip_length=clip_length,
|
| 36 |
+
resize_height=resize_height,
|
| 37 |
+
resize_width=resize_width,
|
| 38 |
+
non_overlapping_clips=non_overlapping_clips,
|
| 39 |
+
bbox_masking_prob=bbox_masking_prob,
|
| 40 |
+
sample_clip_from_end=sample_clip_from_end,
|
| 41 |
+
ego_only=ego_only)
|
| 42 |
+
|
| 43 |
+
self.dataset_name = "preprocess_dada2000"
|
| 44 |
+
|
| 45 |
+
self.orig_width = orig_width
|
| 46 |
+
self.orig_height = orig_height
|
| 47 |
+
self.image_dir = os.path.join(self.root, self.dataset_name, "images", self.data_split)
|
| 48 |
+
self.label_dir = os.path.join(self.root, self.dataset_name, "labels", self.data_split)
|
| 49 |
+
self.bbox_image_dir = os.path.join(self.root, self.dataset_name, "bbox_images", self.data_split)
|
| 50 |
+
self.metadata_csv_path = os.path.join(self.root, self.dataset_name, "metadata.csv") # TODO: This information could be transfered into each individual label file
|
| 51 |
+
|
| 52 |
+
self.strict_collision_filter = True
|
| 53 |
+
if self.strict_collision_filter:
|
| 54 |
+
print("Strict collision filter set for DADA2000")
|
| 55 |
+
|
| 56 |
+
self.specific_samples = specific_samples
|
| 57 |
+
if self.specific_samples is not None:
|
| 58 |
+
print("Only loading specific samples:", self.specific_samples)
|
| 59 |
+
|
| 60 |
+
self._collect_clips()
|
| 61 |
+
|
| 62 |
+
def _collect_clips(self):
|
| 63 |
+
|
| 64 |
+
accident_frame_metadata = {}
|
| 65 |
+
with open(self.metadata_csv_path) as csv_file:
|
| 66 |
+
csv_reader = csv.reader(csv_file)
|
| 67 |
+
for i, row in enumerate(csv_reader):
|
| 68 |
+
if i == 0:
|
| 69 |
+
continue
|
| 70 |
+
|
| 71 |
+
video_num = row[0]
|
| 72 |
+
video_type = row[5]
|
| 73 |
+
abnormal_start_frame_idx = int(row[7])
|
| 74 |
+
accident_frame_idx = int(row[8])
|
| 75 |
+
abnormal_end_frame_idx = int(row[9])
|
| 76 |
+
|
| 77 |
+
video_name = f"{video_type}_{video_num.rjust(3, '0')}"
|
| 78 |
+
|
| 79 |
+
if accident_frame_idx == "-1":
|
| 80 |
+
# print("Skipping video:", video_name)
|
| 81 |
+
continue
|
| 82 |
+
|
| 83 |
+
# Need to convert the original frame idx to closest downsampled frame index
|
| 84 |
+
downsample_factor = 30/7 # Because we downsampled from 30fps to 7fps
|
| 85 |
+
accident_frame_metadata[video_name] = (int(abnormal_start_frame_idx / downsample_factor), int(accident_frame_idx / downsample_factor + 0.5), int(abnormal_end_frame_idx / downsample_factor + 0.5))
|
| 86 |
+
|
| 87 |
+
self.clip_type_list = [] # crash or normal or abnormal (abnormal is a scene that has abnormal driving but doesn't contain the actual crash moment)
|
| 88 |
+
image_indices_by_clip = {}
|
| 89 |
+
for label_file in sorted(os.listdir(self.label_dir)):
|
| 90 |
+
if not label_file.endswith('.json'):
|
| 91 |
+
continue
|
| 92 |
+
|
| 93 |
+
full_filename = os.path.join(self.label_dir, label_file)
|
| 94 |
+
with open(full_filename) as json_file:
|
| 95 |
+
all_data = json.load(json_file)
|
| 96 |
+
metadata = all_data['metadata']
|
| 97 |
+
|
| 98 |
+
if self.ego_only:
|
| 99 |
+
print("Ego collisions only activated!")
|
| 100 |
+
if metadata['ego_involved'] == False:
|
| 101 |
+
continue
|
| 102 |
+
|
| 103 |
+
if self.strict_collision_filter and metadata["accident_type"] in [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 25, 26, 28, 29, 31, 32, 34, 35, 36]:
|
| 104 |
+
# Reject video where the collision is with "static" agents
|
| 105 |
+
continue
|
| 106 |
+
|
| 107 |
+
# Some rejected clips
|
| 108 |
+
if all_data["video_source"] in ["10_001.mp4"]:
|
| 109 |
+
continue
|
| 110 |
+
|
| 111 |
+
if self.specific_samples is not None and all_data["video_source"].split(".")[0] not in self.specific_samples:
|
| 112 |
+
continue
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
clip_filename = label_file.split('.')[0]
|
| 116 |
+
clip_file = os.path.join(self.image_dir, clip_filename)
|
| 117 |
+
clip_data = all_data["data"]
|
| 118 |
+
clip_frames = sorted(os.listdir(clip_file))
|
| 119 |
+
|
| 120 |
+
num_frames = len(clip_frames)
|
| 121 |
+
if num_frames < self.clip_length:
|
| 122 |
+
# print(f"{clip_filename} does not have enough frames: has {num_frames} expected at least {self.clip_length}")
|
| 123 |
+
continue
|
| 124 |
+
|
| 125 |
+
accident_metadata = accident_frame_metadata.get(clip_filename)
|
| 126 |
+
if accident_metadata is None:
|
| 127 |
+
print(clip_filename, "no accident metadata found")
|
| 128 |
+
continue
|
| 129 |
+
|
| 130 |
+
# Frames within the abnormal range are considered accidents and outside are considered normal driving
|
| 131 |
+
# ab_start_idx, acc_idx, ab_end_idx = accident_metadata
|
| 132 |
+
# if ab_end_idx - ab_start_idx >= self.clip_length:
|
| 133 |
+
# # We can just feed abnormal clip frames
|
| 134 |
+
# clip_data = clip_data[ab_start_idx:ab_end_idx]
|
| 135 |
+
# clip_frames = clip_frames[ab_start_idx:ab_end_idx]
|
| 136 |
+
# num_frames = len(clip_frames)
|
| 137 |
+
# else:
|
| 138 |
+
# # print(clip_filename, "no enough abnormal frames:", ab_end_idx - ab_start_idx)
|
| 139 |
+
# continue
|
| 140 |
+
|
| 141 |
+
# #NOTE: Let's trim really long videos: videos over 75 frames will get the first frames trimmed
|
| 142 |
+
# if num_frames > 75:
|
| 143 |
+
# # print("Long video:", clip_filename, len(num_frames), "frames")
|
| 144 |
+
# frames_to_trim = num_frames - 75
|
| 145 |
+
# clip_data = clip_data[frames_to_trim:]
|
| 146 |
+
# clip_frames = clip_frames[frames_to_trim:]
|
| 147 |
+
|
| 148 |
+
clip_label_data = self._parse_clip_labels(clip_data)
|
| 149 |
+
self.frame_labels.extend(clip_label_data) # In this case labels are already sorted so they will match up to the image indices
|
| 150 |
+
|
| 151 |
+
image_indices_by_clip[clip_filename] = []
|
| 152 |
+
for image_file in clip_frames:
|
| 153 |
+
self.image_files.append(os.path.join(clip_file, image_file))
|
| 154 |
+
image_indices_by_clip[clip_filename].append(len(self.image_files)-1)
|
| 155 |
+
|
| 156 |
+
assert len(self.frame_labels) == len(self.image_files) # We assume a one-to-one association between images and labels
|
| 157 |
+
|
| 158 |
+
ab_start_idx, acc_idx, ab_end_idx = accident_metadata
|
| 159 |
+
def get_clip_type(image_idx, end_image_idx):
|
| 160 |
+
clip_type = "normal"
|
| 161 |
+
if image_idx <= acc_idx and end_image_idx > acc_idx:
|
| 162 |
+
# Contains accident frame
|
| 163 |
+
clip_type = "crash"
|
| 164 |
+
elif (image_idx >= ab_start_idx and image_idx <= ab_end_idx) or (end_image_idx > ab_start_idx and end_image_idx < ab_end_idx):
|
| 165 |
+
# Does not contain accident frame, but contains abnormal driving (moment before and after accident)
|
| 166 |
+
clip_type = "abnormal"
|
| 167 |
+
|
| 168 |
+
return clip_type
|
| 169 |
+
|
| 170 |
+
# Cut the videos in clips of the correct length according to the strategies chosen
|
| 171 |
+
if not self.non_overlapping_clips:
|
| 172 |
+
for image_idx in range(len(image_indices_by_clip[clip_filename]) - self.clip_length + 1):
|
| 173 |
+
end_image_idx = image_idx+self.clip_length
|
| 174 |
+
|
| 175 |
+
clip_type = get_clip_type(image_idx, end_image_idx)
|
| 176 |
+
if clip_type == "abnormal":
|
| 177 |
+
# Let's just reject the abnormal clips
|
| 178 |
+
continue
|
| 179 |
+
|
| 180 |
+
self.clip_list.append(image_indices_by_clip[clip_filename][image_idx:end_image_idx])
|
| 181 |
+
self.clip_type_list.append(clip_type)
|
| 182 |
+
|
| 183 |
+
else:
|
| 184 |
+
if self.sample_clip_from_end:
|
| 185 |
+
# In case self.clip_length << actual video sample length, we can create multiple non-overlapping clips for each sample
|
| 186 |
+
# Prioritize selecting clips from the end, to make sur the accident is included (which tends to be at the end of the videos)
|
| 187 |
+
total_frames = len(image_indices_by_clip[clip_filename])
|
| 188 |
+
for clip_i in range(total_frames // self.clip_length):
|
| 189 |
+
start_image_idx = total_frames - (self.clip_length * (clip_i + 1))
|
| 190 |
+
end_image_idx = total_frames - (self.clip_length * clip_i)
|
| 191 |
+
|
| 192 |
+
clip_type = get_clip_type(start_image_idx, end_image_idx)
|
| 193 |
+
if clip_type == "abnormal":
|
| 194 |
+
# Let's just reject the abnormal clips
|
| 195 |
+
continue
|
| 196 |
+
|
| 197 |
+
self.clip_list.append(image_indices_by_clip[clip_filename][start_image_idx:end_image_idx])
|
| 198 |
+
self.clip_type_list.append(clip_type)
|
| 199 |
+
else:
|
| 200 |
+
total_frames = len(image_indices_by_clip[clip_filename])
|
| 201 |
+
for clip_i in range(total_frames // self.clip_length):
|
| 202 |
+
start_image_idx = clip_i * self.clip_length
|
| 203 |
+
end_image_idx = start_image_idx + self.clip_length
|
| 204 |
+
|
| 205 |
+
clip_type = get_clip_type(start_image_idx, end_image_idx)
|
| 206 |
+
if clip_type == "abnormal":
|
| 207 |
+
# Let's just reject the abnormal clips
|
| 208 |
+
continue
|
| 209 |
+
|
| 210 |
+
self.clip_list.append(image_indices_by_clip[clip_filename][start_image_idx:end_image_idx])
|
| 211 |
+
self.clip_type_list.append(clip_type)
|
| 212 |
+
|
| 213 |
+
print("Number of clips DADA2000:", len(self.clip_list), f"({self.data_split})")
|
| 214 |
+
crash_clip_count = 0
|
| 215 |
+
normal_clip_count = 0
|
| 216 |
+
for clip_type in self.clip_type_list:
|
| 217 |
+
if clip_type == "crash":
|
| 218 |
+
crash_clip_count += 1
|
| 219 |
+
elif clip_type == "normal":
|
| 220 |
+
normal_clip_count += 1
|
| 221 |
+
print(crash_clip_count, "crash clips", normal_clip_count, "normal clips")
|
| 222 |
+
|
| 223 |
+
def _parse_clip_labels(self, clip_data):
|
| 224 |
+
frame_labels = []
|
| 225 |
+
for frame_data in clip_data:
|
| 226 |
+
obj_data = frame_data['labels']
|
| 227 |
+
|
| 228 |
+
object_labels = []
|
| 229 |
+
for label in obj_data:
|
| 230 |
+
# Only keep the classes of interest
|
| 231 |
+
class_id = DADA2000Dataset.CLASS_NAME_TO_ID.get(label['name'])
|
| 232 |
+
if class_id is None:
|
| 233 |
+
continue
|
| 234 |
+
|
| 235 |
+
# Convert bbox coordinates to pixel space wrt to image size
|
| 236 |
+
bbox = label['box']
|
| 237 |
+
bbox_coords_pixel = [int(bbox[0] * self.orig_width), # x1
|
| 238 |
+
int(bbox[1] * self.orig_height), # y1
|
| 239 |
+
int(bbox[2] * self.orig_width), # x2
|
| 240 |
+
int(bbox[3] * self.orig_height)] # y2
|
| 241 |
+
|
| 242 |
+
object_labels.append({
|
| 243 |
+
'frame_name': frame_data["image_source"],
|
| 244 |
+
'track_id': int(label['track_id']),
|
| 245 |
+
'bbox': bbox_coords_pixel,
|
| 246 |
+
'class_id': class_id,
|
| 247 |
+
'class_name': label['name'], # Class name of the object
|
| 248 |
+
})
|
| 249 |
+
|
| 250 |
+
frame_labels.append(object_labels)
|
| 251 |
+
|
| 252 |
+
return frame_labels
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def pre_cache_dataset(dataset_root):
|
| 256 |
+
# Trigger label and bbox image cache generation
|
| 257 |
+
from time import time
|
| 258 |
+
dataset_train = DADA2000Dataset(root=dataset_root, train=True, clip_length=25, non_overlapping_clips=False)
|
| 259 |
+
t = time()
|
| 260 |
+
for i in tqdm(range(len(dataset_train))):
|
| 261 |
+
d = dataset_train[i]
|
| 262 |
+
if i >= 100:
|
| 263 |
+
print("Time:", time() - t)
|
| 264 |
+
print("break")
|
| 265 |
+
|
| 266 |
+
dataset_val = DADA2000Dataset(root=dataset_root, train=False, clip_length=25, non_overlapping_clips=True)
|
| 267 |
+
for i in tqdm(range(len(dataset_val))):
|
| 268 |
+
d = dataset_val[i]
|
| 269 |
+
|
| 270 |
+
print("Done.")
|
| 271 |
+
|
| 272 |
+
if __name__ == "__main__":
|
| 273 |
+
dataset_root = "/path/to/Datasets"
|
| 274 |
+
pre_cache_dataset(dataset_root)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
"""
|
| 278 |
+
ACCIDENT TYPES
|
| 279 |
+
{
|
| 280 |
+
"ego_car_involved": {
|
| 281 |
+
"self_initiated": {
|
| 282 |
+
"out_of_control": [61]
|
| 283 |
+
},
|
| 284 |
+
"dynamic_participants": {
|
| 285 |
+
"person_centric": {
|
| 286 |
+
"pedestrian": [1, 2],
|
| 287 |
+
"cyclist": [3, 4]
|
| 288 |
+
},
|
| 289 |
+
"vehicle_centric": {
|
| 290 |
+
"motorbike": [5, 6],
|
| 291 |
+
"truck": [7, 8, 9],
|
| 292 |
+
"car": [10, 11, 12]
|
| 293 |
+
}
|
| 294 |
+
},
|
| 295 |
+
"static_participants": {
|
| 296 |
+
"road_crentric": {
|
| 297 |
+
"large_roadblocks": [13],
|
| 298 |
+
"curb": [14],
|
| 299 |
+
"small_roadblocks_potholes": [15]
|
| 300 |
+
},
|
| 301 |
+
"other_semantics_centric": {
|
| 302 |
+
"trees": [16],
|
| 303 |
+
"telegraph_poles": [17],
|
| 304 |
+
"other_road_facilities": [18]
|
| 305 |
+
}
|
| 306 |
+
}
|
| 307 |
+
},
|
| 308 |
+
"ego_car_uninvolved": {
|
| 309 |
+
"dynamic_participants": {
|
| 310 |
+
"vehicle_centric": {
|
| 311 |
+
"motorbike_motorbike": [37, 38],
|
| 312 |
+
"truck_truck": [39, 40, 41],
|
| 313 |
+
"car_car": [42, 43, 44],
|
| 314 |
+
"motorbike_truck": [45, 46, 47],
|
| 315 |
+
"truck_car": [48, 49],
|
| 316 |
+
"car_motorbike": [50, 51]
|
| 317 |
+
},
|
| 318 |
+
"person_centric": [52, 53, 54, 55, 56, 57, 58, 59, 60]
|
| 319 |
+
},
|
| 320 |
+
"static_participants" : [19, 20, 21, 22, 25, 26, 28, 29, 31, 32, 34, 35, 36]
|
| 321 |
+
},
|
| 322 |
+
"summary": {
|
| 323 |
+
"ego_car_involved": {
|
| 324 |
+
"person_centric": [1, 2, 3, 4],
|
| 325 |
+
"vehicle_centric": [5, 6, 7, 8, 9, 10, 11, 12],
|
| 326 |
+
"static_participants": [13, 14, 15, 16, 17, 18],
|
| 327 |
+
"out_of_control": [61]
|
| 328 |
+
},
|
| 329 |
+
"ego_car_uninvolved": {
|
| 330 |
+
"static_participants": [19, 20, 21, 22, 25, 26, 28, 29, 31, 32, 34, 35, 36],
|
| 331 |
+
"vehicle_centric": [37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 50, 51],
|
| 332 |
+
"person_centric": [52, 53, 54, 55, 56, 57, 58, 59, 60]
|
| 333 |
+
}
|
| 334 |
+
}
|
| 335 |
+
}
|
| 336 |
+
"""
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
|
src/datasets/dataset_factory.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.datasets.merged_dataset import MergedDataset
|
| 2 |
+
from src.datasets.russia_crash_dataset import RussiaCrashDataset
|
| 3 |
+
from src.datasets.dada2000_dataset import DADA2000Dataset
|
| 4 |
+
from src.datasets.mmau_dataset import MMAUDataset
|
| 5 |
+
from src.datasets.bdd100k_dataset import BDD100KDataset
|
| 6 |
+
# from src.datasets.nuscenes_dataset import NuScenesDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def create_dataset(dataset_name, **kwargs):
|
| 10 |
+
|
| 11 |
+
if str.lower(dataset_name) == "russia_crash":
|
| 12 |
+
dataset = RussiaCrashDataset(**kwargs)
|
| 13 |
+
elif str.lower(dataset_name) == "nuscenes":
|
| 14 |
+
dataset = NuScenesDataset(**kwargs)
|
| 15 |
+
elif str.lower(dataset_name) == "dada2000":
|
| 16 |
+
dataset = DADA2000Dataset(**kwargs)
|
| 17 |
+
elif str.lower(dataset_name) == "mmau":
|
| 18 |
+
dataset = MMAUDataset(**kwargs)
|
| 19 |
+
elif str.lower(dataset_name) == "bdd100k":
|
| 20 |
+
dataset = BDD100KDataset(**kwargs)
|
| 21 |
+
else:
|
| 22 |
+
raise NotImplementedError(f"Dataset '{dataset_name}' not implemented")
|
| 23 |
+
|
| 24 |
+
return dataset
|
| 25 |
+
|
| 26 |
+
def dataset_factory(dataset_names, **kwargs):
|
| 27 |
+
if isinstance(dataset_names, str) or (isinstance(dataset_names, list) and len(dataset_names) == 1):
|
| 28 |
+
dataset_name = dataset_names[0] if isinstance(dataset_names, list) else dataset_names
|
| 29 |
+
# Init the single dataset
|
| 30 |
+
return create_dataset(dataset_name, **kwargs)
|
| 31 |
+
elif isinstance(dataset_names, list):
|
| 32 |
+
all_datasets = []
|
| 33 |
+
for dataset_name in dataset_names:
|
| 34 |
+
all_datasets.append(create_dataset(dataset_name, **kwargs))
|
| 35 |
+
return MergedDataset(all_datasets)
|
| 36 |
+
|
src/datasets/dataset_utils.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from .dataset_factory import dataset_factory
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def worker_init_fn(worker_id):
|
| 8 |
+
os.sched_setaffinity(0, range(os.cpu_count()))
|
| 9 |
+
|
| 10 |
+
def get_dataloader(dset_root,
|
| 11 |
+
dset_names,
|
| 12 |
+
if_train,
|
| 13 |
+
batch_size,
|
| 14 |
+
num_workers,
|
| 15 |
+
clip_length=25,
|
| 16 |
+
shuffle=True,
|
| 17 |
+
image_height=None,
|
| 18 |
+
image_width=None,
|
| 19 |
+
non_overlapping_clips=False,
|
| 20 |
+
ego_only=False,
|
| 21 |
+
bbox_masking_prob=0.0,
|
| 22 |
+
specific_samples=None,
|
| 23 |
+
specific_categories=None,
|
| 24 |
+
force_clip_type=None):
|
| 25 |
+
|
| 26 |
+
dataset = dataset_factory(dset_names,
|
| 27 |
+
root=dset_root,
|
| 28 |
+
train=if_train,
|
| 29 |
+
clip_length=clip_length,
|
| 30 |
+
resize_height=image_height,
|
| 31 |
+
resize_width=image_width,
|
| 32 |
+
non_overlapping_clips=non_overlapping_clips,
|
| 33 |
+
bbox_masking_prob=bbox_masking_prob,
|
| 34 |
+
ego_only=ego_only,
|
| 35 |
+
specific_samples=specific_samples,
|
| 36 |
+
specific_categories=specific_categories,
|
| 37 |
+
force_clip_type=force_clip_type)
|
| 38 |
+
|
| 39 |
+
dataloader = torch.utils.data.DataLoader(
|
| 40 |
+
dataset,
|
| 41 |
+
batch_size=batch_size,
|
| 42 |
+
num_workers=num_workers,
|
| 43 |
+
shuffle=shuffle,
|
| 44 |
+
pin_memory=True,
|
| 45 |
+
drop_last=True,
|
| 46 |
+
worker_init_fn=worker_init_fn
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
return dataset, dataloader
|
| 50 |
+
|
src/datasets/merged_dataset.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision import transforms
|
| 2 |
+
|
| 3 |
+
class MergedDataset:
|
| 4 |
+
"""
|
| 5 |
+
Dataset wrapper to access many datasets as one
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
def __init__(self, dataset_list):
|
| 9 |
+
|
| 10 |
+
self.dataset_list = dataset_list
|
| 11 |
+
|
| 12 |
+
# TODO: Make sure this matches all datasets
|
| 13 |
+
self.resize_width = self.dataset_list[0].resize_width
|
| 14 |
+
self.resize_height = self.dataset_list[0].resize_height
|
| 15 |
+
self.revert_transform = self.dataset_list[0].revert_transform
|
| 16 |
+
|
| 17 |
+
print("TOTAL number of clips in merged dataset:", self.__len__(), f"({self.dataset_list[0].data_split})")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def __len__(self):
|
| 21 |
+
return sum([len(dset) for dset in self.dataset_list])
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def __getitem__(self, global_index):
|
| 25 |
+
|
| 26 |
+
target_dset, rel_index = self.get_dataset_by_sample_index(global_index)
|
| 27 |
+
ret_dict = target_dset.__getitem__(rel_index)
|
| 28 |
+
|
| 29 |
+
# Overwrite returned index with the global index
|
| 30 |
+
ret_dict["indices"] = global_index
|
| 31 |
+
|
| 32 |
+
return ret_dict
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_dataset_by_sample_index(self, index):
|
| 36 |
+
total_idx = 0
|
| 37 |
+
target_dset = None
|
| 38 |
+
for dset in self.dataset_list:
|
| 39 |
+
total_idx += len(dset)
|
| 40 |
+
if index < total_idx:
|
| 41 |
+
target_dset = dset
|
| 42 |
+
break
|
| 43 |
+
|
| 44 |
+
return target_dset, (index - (total_idx - len(target_dset)))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_frame_file_by_index(self, index, timestep=None):
|
| 48 |
+
target_dset, rel_index = self.get_dataset_by_sample_index(index)
|
| 49 |
+
return target_dset.get_frame_file_by_index(rel_index, timestep=timestep)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_bbox_image_file_by_index(self, index, image_file=None):
|
| 53 |
+
target_dset, rel_index = self.get_dataset_by_sample_index(index)
|
| 54 |
+
return target_dset.get_bbox_image_file_by_index(index=rel_index)
|
src/datasets/mmau_dataset.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import csv
|
| 5 |
+
import json
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
from src.datasets.base_dataset import BaseDataset
|
| 9 |
+
|
| 10 |
+
def load_json(filename):
|
| 11 |
+
if os.path.exists(filename):
|
| 12 |
+
with open(filename, "r") as f:
|
| 13 |
+
return json.load(f)
|
| 14 |
+
print(filename, "not found")
|
| 15 |
+
return []
|
| 16 |
+
|
| 17 |
+
def create_video_from_images(images_list, output_video, out_fps, start_frame=None, end_frame=None):
|
| 18 |
+
|
| 19 |
+
img0_path = images_list[0]
|
| 20 |
+
img0 = cv2.imread(img0_path)
|
| 21 |
+
height, width, _ = img0.shape
|
| 22 |
+
|
| 23 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 24 |
+
out = cv2.VideoWriter(output_video, fourcc, out_fps, (width, height))
|
| 25 |
+
|
| 26 |
+
for idx, frame_name in enumerate(images_list):
|
| 27 |
+
|
| 28 |
+
if start_frame is not None and idx < start_frame:
|
| 29 |
+
continue
|
| 30 |
+
if end_frame is not None and idx >= end_frame:
|
| 31 |
+
continue
|
| 32 |
+
|
| 33 |
+
img = cv2.imread(frame_name)
|
| 34 |
+
out.write(img)
|
| 35 |
+
|
| 36 |
+
out.release()
|
| 37 |
+
print("Saved video:", output_video)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MMAUDataset(BaseDataset):
|
| 41 |
+
|
| 42 |
+
CLASS_NAME_TO_ID = {
|
| 43 |
+
'person': 1,
|
| 44 |
+
'car': 3,
|
| 45 |
+
'truck': 4,
|
| 46 |
+
'bus': 5,
|
| 47 |
+
'train': 6,
|
| 48 |
+
'motorcycle': 7,
|
| 49 |
+
'bicycle': 8,
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
def __init__(self,
|
| 53 |
+
root='./datasets',
|
| 54 |
+
train=True,
|
| 55 |
+
clip_length=25,
|
| 56 |
+
orig_height=640, orig_width=1024,
|
| 57 |
+
resize_height=320, resize_width=512,
|
| 58 |
+
non_overlapping_clips=False,
|
| 59 |
+
bbox_masking_prob=0.0,
|
| 60 |
+
sample_clip_from_end=True,
|
| 61 |
+
ego_only=False,
|
| 62 |
+
specific_samples=None,
|
| 63 |
+
specific_categories=None,
|
| 64 |
+
dada_only=False,
|
| 65 |
+
cleanup_dataset=False,
|
| 66 |
+
force_clip_type=None
|
| 67 |
+
):
|
| 68 |
+
|
| 69 |
+
self.ignore_labels = False
|
| 70 |
+
if self.ignore_labels:
|
| 71 |
+
print("IGNORING LABELS in MMAU dataset")
|
| 72 |
+
|
| 73 |
+
super(MMAUDataset, self).__init__(root=root,
|
| 74 |
+
train=train,
|
| 75 |
+
clip_length=clip_length,
|
| 76 |
+
resize_height=resize_height,
|
| 77 |
+
resize_width=resize_width,
|
| 78 |
+
non_overlapping_clips=non_overlapping_clips,
|
| 79 |
+
bbox_masking_prob=bbox_masking_prob,
|
| 80 |
+
sample_clip_from_end=sample_clip_from_end,
|
| 81 |
+
ego_only=ego_only,
|
| 82 |
+
ignore_labels=self.ignore_labels) # NOTE: Ignoring labels currently
|
| 83 |
+
|
| 84 |
+
self.dada_only = dada_only
|
| 85 |
+
self.cleanup_dataset = cleanup_dataset
|
| 86 |
+
self.dataset_name = "mmau_images_12fps" if not dada_only else "dada2000_images_12fps"
|
| 87 |
+
|
| 88 |
+
self.orig_width = orig_width
|
| 89 |
+
self.orig_height = orig_height
|
| 90 |
+
self.split = "train" if train else "val"
|
| 91 |
+
|
| 92 |
+
self.image_dir = os.path.join(self.root, self.dataset_name, "images")
|
| 93 |
+
self.label_dir = os.path.join(self.root, self.dataset_name, "labels")
|
| 94 |
+
self.bbox_image_dir = os.path.join(self.root, self.dataset_name, "bbox_images")
|
| 95 |
+
|
| 96 |
+
self.downsample_6fps = True
|
| 97 |
+
if self.downsample_6fps:
|
| 98 |
+
print("Downsampling MMAU clips to 6 fps")
|
| 99 |
+
|
| 100 |
+
self.ego_only = ego_only
|
| 101 |
+
if self.ego_only:
|
| 102 |
+
print("Ego collisions only filter set for MMAU dataset")
|
| 103 |
+
|
| 104 |
+
self.strict_collision_filter = False
|
| 105 |
+
if self.strict_collision_filter:
|
| 106 |
+
print("Strict collision filter set for MMAU dataset")
|
| 107 |
+
|
| 108 |
+
self.specific_samples = specific_samples
|
| 109 |
+
if self.specific_samples is not None:
|
| 110 |
+
print("Only loading specific samples:", self.specific_samples)
|
| 111 |
+
|
| 112 |
+
self.specific_categories = specific_categories
|
| 113 |
+
if self.specific_categories is not None:
|
| 114 |
+
print("Only loading specific categories:", self.specific_categories)
|
| 115 |
+
|
| 116 |
+
self.force_clip_type = force_clip_type
|
| 117 |
+
if self.force_clip_type is not None:
|
| 118 |
+
print("Only loading samples with type:", force_clip_type)
|
| 119 |
+
|
| 120 |
+
self._collect_clips()
|
| 121 |
+
|
| 122 |
+
def _collect_metadata_csv(self, metadata_csv_path):
|
| 123 |
+
accident_frame_metadata = {}
|
| 124 |
+
with open(metadata_csv_path) as csv_file:
|
| 125 |
+
csv_reader = csv.reader(csv_file)
|
| 126 |
+
for i, row in enumerate(csv_reader):
|
| 127 |
+
if i == 0:
|
| 128 |
+
continue
|
| 129 |
+
|
| 130 |
+
video_num = str(int(row[0]))
|
| 131 |
+
video_type = str(int(row[5]))
|
| 132 |
+
abnormal_start_frame_idx = int(row[7])
|
| 133 |
+
accident_frame_idx = int(row[9])
|
| 134 |
+
abnormal_end_frame_idx = int(row[8])
|
| 135 |
+
|
| 136 |
+
video_name = f"{video_type}_{video_num.rjust(5, '0')}"
|
| 137 |
+
|
| 138 |
+
if accident_frame_idx == "-1":
|
| 139 |
+
# print("Skipping video:", video_name)
|
| 140 |
+
continue
|
| 141 |
+
|
| 142 |
+
downsample_factor = 30/12 if video_num.startswith("90") else 1 # Downsample for DADA (30fps) and CAP (12 fps) are not the same
|
| 143 |
+
if self.downsample_6fps:
|
| 144 |
+
downsample_factor *= 2
|
| 145 |
+
accident_frame_metadata[video_name] = (int(abnormal_start_frame_idx / downsample_factor),
|
| 146 |
+
int(accident_frame_idx / downsample_factor + 0.5),
|
| 147 |
+
int(abnormal_end_frame_idx / downsample_factor + 0.5))
|
| 148 |
+
|
| 149 |
+
return accident_frame_metadata
|
| 150 |
+
|
| 151 |
+
def _collect_clips(self):
|
| 152 |
+
print("Collecting dataset clips...")
|
| 153 |
+
|
| 154 |
+
mmau_dataset = os.path.join(self.root, self.dataset_name)
|
| 155 |
+
|
| 156 |
+
# Load data split
|
| 157 |
+
datasplit_data = load_json(os.path.join(mmau_dataset, "mmau_datasplit.json"))
|
| 158 |
+
|
| 159 |
+
# Compile reject videos
|
| 160 |
+
auto_filtered_vids = load_json(os.path.join(mmau_dataset, "auto_low_quality.json"))
|
| 161 |
+
rejected_vids = load_json(os.path.join(mmau_dataset, "rejected.json"))
|
| 162 |
+
all_rejected_vids = auto_filtered_vids + rejected_vids
|
| 163 |
+
|
| 164 |
+
# Collect the accident moment information
|
| 165 |
+
accident_frame_metadata = self._collect_metadata_csv(os.path.join(mmau_dataset, "mmau_metadata.csv"))
|
| 166 |
+
|
| 167 |
+
self.clip_type_list = [] # crash or normal or abnormal (abnormal is a scene that has abnormal driving but doesn't contain the actual crash moment)
|
| 168 |
+
self.action_type_list = [] # 0-4, 0 is normal, 1-4 are different types of crashes
|
| 169 |
+
image_indices_by_clip = {}
|
| 170 |
+
|
| 171 |
+
null_labels = []
|
| 172 |
+
# Iterate datasplit file
|
| 173 |
+
count_vid = 0
|
| 174 |
+
for category, split in datasplit_data.items():
|
| 175 |
+
for split_name, vid_names in split.items():
|
| 176 |
+
if split_name != self.split:
|
| 177 |
+
continue
|
| 178 |
+
|
| 179 |
+
for vid_name in vid_names:
|
| 180 |
+
if vid_name in all_rejected_vids:
|
| 181 |
+
continue
|
| 182 |
+
|
| 183 |
+
if self.dada_only and not vid_name.split("_")[-1].startswith("90"): # NOTE: REMOVE THIS
|
| 184 |
+
continue
|
| 185 |
+
|
| 186 |
+
# Read image files
|
| 187 |
+
image_dir = os.path.join(mmau_dataset, "images")
|
| 188 |
+
clip_file = os.path.join(image_dir, category, vid_name)
|
| 189 |
+
clip_frames = sorted(os.listdir(clip_file))
|
| 190 |
+
|
| 191 |
+
if self.cleanup_dataset:
|
| 192 |
+
# NOTE: For renaming frames (can remove this later)
|
| 193 |
+
fix_label = False
|
| 194 |
+
for frame_name in clip_frames:
|
| 195 |
+
if vid_name not in frame_name:
|
| 196 |
+
fix_label = True
|
| 197 |
+
new_frame_name = f"{vid_name}_{frame_name}"
|
| 198 |
+
root_path = os.path.join(mmau_dataset, "images", category, vid_name)
|
| 199 |
+
os.rename(os.path.join(root_path, frame_name), os.path.join(root_path, new_frame_name))
|
| 200 |
+
|
| 201 |
+
image_dir = os.path.join(mmau_dataset, "images")
|
| 202 |
+
clip_file = os.path.join(image_dir, category, vid_name)
|
| 203 |
+
clip_frames = sorted(os.listdir(clip_file))
|
| 204 |
+
|
| 205 |
+
# Also rename in label file
|
| 206 |
+
label_file_path = os.path.join(self.label_dir, f"{vid_name}.json")
|
| 207 |
+
if os.path.exists(label_file_path) and fix_label:
|
| 208 |
+
with open(label_file_path, "r") as f:
|
| 209 |
+
data = json.load(f)
|
| 210 |
+
data_field = data["data"]
|
| 211 |
+
if data_field is None:
|
| 212 |
+
print(f"{vid_name}.json CLIP DATA IS NULL 2")
|
| 213 |
+
null_labels.append(vid_name)
|
| 214 |
+
else:
|
| 215 |
+
for i, frame_data in enumerate(data_field):
|
| 216 |
+
current_frame_name = frame_data["image_source"]
|
| 217 |
+
if vid_name not in current_frame_name:
|
| 218 |
+
new_frame_name = f"{vid_name}_{current_frame_name}"
|
| 219 |
+
data["data"][i]["image_source"] = new_frame_name
|
| 220 |
+
|
| 221 |
+
with open(label_file_path, "w") as f:
|
| 222 |
+
json.dump(data, f, indent=1)
|
| 223 |
+
|
| 224 |
+
num_frames = len(clip_frames) if not self.downsample_6fps else len(clip_frames) // 2
|
| 225 |
+
if num_frames < self.clip_length:
|
| 226 |
+
print(f"{vid_name} does not have enough frames: has {num_frames}, expected at least {self.clip_length}")
|
| 227 |
+
continue
|
| 228 |
+
|
| 229 |
+
accident_metadata = accident_frame_metadata.get(vid_name)
|
| 230 |
+
if accident_metadata is None:
|
| 231 |
+
print(vid_name, "no accident metadata found")
|
| 232 |
+
continue
|
| 233 |
+
|
| 234 |
+
step = 2 if self.downsample_6fps else 1
|
| 235 |
+
clip_frame_names = []
|
| 236 |
+
for image_idx in range(0, len(clip_frames), step):
|
| 237 |
+
image_file = clip_frames[image_idx]
|
| 238 |
+
clip_frame_names.append(image_file)
|
| 239 |
+
|
| 240 |
+
count_vid += 1
|
| 241 |
+
# Read label file
|
| 242 |
+
if not self.ignore_labels:
|
| 243 |
+
label_file_path = os.path.join(self.label_dir, f"{vid_name}.json")
|
| 244 |
+
if not os.path.exists(label_file_path):
|
| 245 |
+
if num_frames <= 300:
|
| 246 |
+
# Because a lot of the long videos were rejected because they were too long to process
|
| 247 |
+
# print(f"{label_file_path} does not exist")
|
| 248 |
+
pass
|
| 249 |
+
continue
|
| 250 |
+
|
| 251 |
+
with open(label_file_path) as json_file:
|
| 252 |
+
all_data = json.load(json_file)
|
| 253 |
+
metadata = all_data['metadata']
|
| 254 |
+
|
| 255 |
+
if self.ego_only:
|
| 256 |
+
if metadata['ego_involved'] == False:
|
| 257 |
+
continue
|
| 258 |
+
|
| 259 |
+
if self.strict_collision_filter and metadata["accident_type"] in [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 25, 26, 28, 29, 31, 32, 34, 35, 36]:
|
| 260 |
+
# Reject video where the collision is with "static" agents
|
| 261 |
+
continue
|
| 262 |
+
|
| 263 |
+
# Some post-hoc rejected clips
|
| 264 |
+
if all_data["video_source"] in ["10_90001.mp4"]:
|
| 265 |
+
continue
|
| 266 |
+
|
| 267 |
+
if self.specific_samples is not None and all_data["video_source"].split(".")[0] not in self.specific_samples:
|
| 268 |
+
continue
|
| 269 |
+
|
| 270 |
+
if self.specific_categories is not None and metadata["accident_type"] not in self.specific_categories:
|
| 271 |
+
continue
|
| 272 |
+
|
| 273 |
+
clip_data = all_data["data"]
|
| 274 |
+
if clip_data is None:
|
| 275 |
+
print(f"{vid_name}.json CLIP DATA IS NULL")
|
| 276 |
+
null_labels.append(vid_name)
|
| 277 |
+
continue
|
| 278 |
+
|
| 279 |
+
clip_label_data = self._parse_clip_labels(clip_data, clip_frame_names)
|
| 280 |
+
self.frame_labels.extend(clip_label_data) # In this case, labels are already sorted so they will match up to the image indices
|
| 281 |
+
|
| 282 |
+
image_indices_by_clip[vid_name] = []
|
| 283 |
+
for image_file in clip_frame_names:
|
| 284 |
+
self.image_files.append(os.path.join(clip_file, image_file))
|
| 285 |
+
image_indices_by_clip[vid_name].append(len(self.image_files)-1)
|
| 286 |
+
|
| 287 |
+
if not self.ignore_labels:
|
| 288 |
+
assert len(self.frame_labels) == len(self.image_files), f"{len(self.frame_labels)} frame labels != {len(self.image_files)} image files" # We assume a one-to-one association between images and labels
|
| 289 |
+
|
| 290 |
+
ab_start_idx, acc_idx, ab_end_idx = accident_metadata
|
| 291 |
+
def get_clip_type(image_idx, end_image_idx):
|
| 292 |
+
clip_type = "normal"
|
| 293 |
+
if image_idx <= acc_idx and end_image_idx >= acc_idx:
|
| 294 |
+
# Contains accident frame
|
| 295 |
+
clip_type = "crash"
|
| 296 |
+
elif (image_idx >= ab_start_idx and image_idx <= ab_end_idx) \
|
| 297 |
+
or (end_image_idx >= ab_start_idx and end_image_idx <= ab_end_idx) \
|
| 298 |
+
or image_idx > acc_idx: # Let's also consider "normal" driving clip that happen after the accident to be "abnormal" as they might show the aftermath (e.g. car damage)
|
| 299 |
+
# Does not contain accident frame, but contains abnormal driving (moment before and after accident)
|
| 300 |
+
clip_type = "abnormal"
|
| 301 |
+
|
| 302 |
+
return clip_type
|
| 303 |
+
|
| 304 |
+
# Cut the videos in clips of the correct length
|
| 305 |
+
# NOTE: Only implementing strategy of selecting two clips per video: 1 with normal driving and 1 with crash
|
| 306 |
+
# Select normal driving from beginning preferably and crash clip try to center it on the accident instant
|
| 307 |
+
|
| 308 |
+
# Find crash clip
|
| 309 |
+
crash_found = False
|
| 310 |
+
if self.force_clip_type is None or self.force_clip_type == "crash":
|
| 311 |
+
start_image_idx, end_image_idx = None, None
|
| 312 |
+
total_frames = len(image_indices_by_clip[vid_name])
|
| 313 |
+
if acc_idx is not None and self.clip_length is not None:
|
| 314 |
+
# Keep frame_count frames around accident frame
|
| 315 |
+
start_image_idx = acc_idx - int(self.clip_length/2 + 0.5)
|
| 316 |
+
end_image_idx = acc_idx + int(self.clip_length/2)
|
| 317 |
+
|
| 318 |
+
if total_frames < self.clip_length:
|
| 319 |
+
print(f"Not enough frames in '{vid_name}': {total_frames}, skipping")
|
| 320 |
+
else:
|
| 321 |
+
if start_image_idx < 0:
|
| 322 |
+
end_image_idx += -(start_image_idx)
|
| 323 |
+
start_image_idx = 0
|
| 324 |
+
|
| 325 |
+
if end_image_idx > total_frames:
|
| 326 |
+
start_image_idx -= (end_image_idx - total_frames)
|
| 327 |
+
end_image_idx = total_frames
|
| 328 |
+
|
| 329 |
+
self.clip_list.append(image_indices_by_clip[vid_name][start_image_idx:end_image_idx])
|
| 330 |
+
self.clip_type_list.append("crash")
|
| 331 |
+
action_type = self._get_action_type(metadata["accident_type"])
|
| 332 |
+
self.action_type_list.append(action_type)
|
| 333 |
+
crash_found = True
|
| 334 |
+
|
| 335 |
+
# Debug: #############
|
| 336 |
+
# frame_path_list = [self.image_files[i] for i in image_indices_by_clip[vid_name][start_image_idx:end_image_idx]]
|
| 337 |
+
# create_video_from_images(frame_path_list, f"outputs/sample_clip_{vid_name}_crash.mp4", out_fps=6 if self.downsample_6fps else 12)
|
| 338 |
+
|
| 339 |
+
# # Debug plot bboxes:
|
| 340 |
+
# out_bbox_path = os.path.join("outputs", f"{vid_name}_bboxes_crash")
|
| 341 |
+
# os.makedirs(out_bbox_path, exist_ok=True)
|
| 342 |
+
# for frame_path, label_data in zip(frame_path_list, clip_label_data[start_image_idx:end_image_idx]):
|
| 343 |
+
# plt.figure(figsize=(9, 6))
|
| 344 |
+
# plt.axis("off")
|
| 345 |
+
# img = Image.open(frame_path)
|
| 346 |
+
# plt.imshow(img)
|
| 347 |
+
# for obj in label_data:
|
| 348 |
+
# color = np.array(CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TYPE_LOOKUP[obj["class_id"]])) / 255.0
|
| 349 |
+
# show_box(obj["bbox"], plt.gca(), label=str(obj["track_id"]), color=color)
|
| 350 |
+
|
| 351 |
+
# frame_id_name = frame_path.split("_")[-1].split(".")[0]
|
| 352 |
+
# plt.savefig(os.path.join(out_bbox_path, f"bboxes_frame_{frame_id_name}.jpg"))
|
| 353 |
+
#######################3
|
| 354 |
+
|
| 355 |
+
if not crash_found:
|
| 356 |
+
print("Crash not found for", vid_name)
|
| 357 |
+
|
| 358 |
+
assert end_image_idx > start_image_idx
|
| 359 |
+
|
| 360 |
+
if self.force_clip_type is None or self.force_clip_type == "normal":
|
| 361 |
+
normal_found = False
|
| 362 |
+
for start_image_idx in range(len(image_indices_by_clip[vid_name]) - self.clip_length + 1):
|
| 363 |
+
end_image_idx = start_image_idx+self.clip_length
|
| 364 |
+
|
| 365 |
+
clip_type = get_clip_type(start_image_idx, end_image_idx)
|
| 366 |
+
if clip_type == "abnormal" or clip_type == "crash":
|
| 367 |
+
# Let's just reject the abnormal clips
|
| 368 |
+
continue
|
| 369 |
+
|
| 370 |
+
self.clip_list.append(image_indices_by_clip[vid_name][start_image_idx:end_image_idx])
|
| 371 |
+
self.clip_type_list.append(clip_type)
|
| 372 |
+
self.action_type_list.append(0)
|
| 373 |
+
normal_found = True
|
| 374 |
+
|
| 375 |
+
# Debug: ########
|
| 376 |
+
# frame_path_list = [self.image_files[i] for i in image_indices_by_clip[vid_name][start_image_idx:end_image_idx]]
|
| 377 |
+
# create_video_from_images(frame_path_list, f"outputs/sample_clip_{vid_name}_normal.mp4", out_fps=6 if self.downsample_6fps else 12)
|
| 378 |
+
|
| 379 |
+
# out_bbox_path = os.path.join("outputs", f"{vid_name}_bboxes_normal")
|
| 380 |
+
# os.makedirs(out_bbox_path, exist_ok=True)
|
| 381 |
+
# for frame_path, label_data in zip(frame_path_list, clip_label_data[start_image_idx:end_image_idx]):
|
| 382 |
+
# plt.figure(figsize=(9, 6))
|
| 383 |
+
# plt.axis("off")
|
| 384 |
+
# img = Image.open(frame_path)
|
| 385 |
+
# plt.imshow(img)
|
| 386 |
+
# for obj in label_data:
|
| 387 |
+
# color = np.array(CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TYPE_LOOKUP[obj["class_id"]])) / 255.0
|
| 388 |
+
# show_box(obj["bbox"], plt.gca(), label=str(obj["track_id"]), color=color)
|
| 389 |
+
|
| 390 |
+
# frame_id_name = frame_path.split("_")[-1].split(".")[0]
|
| 391 |
+
# plt.savefig(os.path.join(out_bbox_path, f"bboxes_frame_{frame_id_name}.jpg"))
|
| 392 |
+
#################
|
| 393 |
+
|
| 394 |
+
break
|
| 395 |
+
|
| 396 |
+
# if not normal_found:
|
| 397 |
+
# print("Normal not found for", vid_name)
|
| 398 |
+
|
| 399 |
+
assert len(self.clip_list) == len(self.clip_type_list) == len(self.action_type_list)
|
| 400 |
+
|
| 401 |
+
print("Number of clips MMAU:", len(self.clip_list), f"({self.data_split})", f"(from {count_vid} original videos)")
|
| 402 |
+
crash_clip_count = 0
|
| 403 |
+
normal_clip_count = 0
|
| 404 |
+
for clip_type in self.clip_type_list:
|
| 405 |
+
if clip_type == "crash":
|
| 406 |
+
crash_clip_count += 1
|
| 407 |
+
elif clip_type == "normal":
|
| 408 |
+
normal_clip_count += 1
|
| 409 |
+
print(crash_clip_count, "crash clips", normal_clip_count, "normal clips")
|
| 410 |
+
|
| 411 |
+
if self.cleanup_dataset and len(null_labels) > 0:
|
| 412 |
+
print("Null labels:", null_labels)
|
| 413 |
+
for label_name in null_labels:
|
| 414 |
+
label_file_path = os.path.join(self.label_dir, f"{label_name}.json")
|
| 415 |
+
if os.path.exists(label_file_path):
|
| 416 |
+
os.remove(label_file_path)
|
| 417 |
+
print("Removed label file:", label_file_path)
|
| 418 |
+
|
| 419 |
+
def _parse_clip_labels(self, clip_data, clip_frame_names):
|
| 420 |
+
frame_labels = []
|
| 421 |
+
for frame_data in clip_data:
|
| 422 |
+
obj_data = frame_data['labels']
|
| 423 |
+
image_source = frame_data["image_source"]
|
| 424 |
+
|
| 425 |
+
if self.downsample_6fps and image_source not in clip_frame_names:
|
| 426 |
+
# Only preserve even numbered frames
|
| 427 |
+
continue
|
| 428 |
+
|
| 429 |
+
object_labels = []
|
| 430 |
+
for label in obj_data:
|
| 431 |
+
# Only keep the classes of interest
|
| 432 |
+
class_id = MMAUDataset.CLASS_NAME_TO_ID.get(label['name'])
|
| 433 |
+
if class_id is None:
|
| 434 |
+
continue
|
| 435 |
+
|
| 436 |
+
# Convert bbox coordinates to pixel space wrt to image size
|
| 437 |
+
bbox = label['box']
|
| 438 |
+
bbox_coords_pixel = [int(bbox[0] * self.orig_width), # x1
|
| 439 |
+
int(bbox[1] * self.orig_height), # y1
|
| 440 |
+
int(bbox[2] * self.orig_width), # x2
|
| 441 |
+
int(bbox[3] * self.orig_height)] # y2
|
| 442 |
+
|
| 443 |
+
object_labels.append({
|
| 444 |
+
'frame_name': image_source,
|
| 445 |
+
'track_id': int(label['track_id']),
|
| 446 |
+
'bbox': bbox_coords_pixel,
|
| 447 |
+
'class_id': class_id,
|
| 448 |
+
'class_name': label['name'], # Class name of the object
|
| 449 |
+
})
|
| 450 |
+
|
| 451 |
+
frame_labels.append(object_labels)
|
| 452 |
+
|
| 453 |
+
return frame_labels
|
| 454 |
+
|
| 455 |
+
def _get_action_type(self, accident_type):
|
| 456 |
+
# [0: normal, 1: ego, 2: ego/veh, 3: veh, 4: veh/veh]
|
| 457 |
+
accident_type = int(accident_type)
|
| 458 |
+
if accident_type in [61, 62, 13, 14, 15, 16, 17, 18]:
|
| 459 |
+
return 1
|
| 460 |
+
elif accident_type in range(1, 12 + 1):
|
| 461 |
+
return 2
|
| 462 |
+
elif accident_type in [37, 39, 41, 42, 44] + list(range(19, 36 + 1)) + list(range(52, 60 + 1)):
|
| 463 |
+
return 3
|
| 464 |
+
elif accident_type in [38, 40, 43, 45, 46, 47, 48, 49, 50, 51]:
|
| 465 |
+
return 4
|
| 466 |
+
else:
|
| 467 |
+
raise ValueError(f"Unknown accident type: {accident_type}")
|
| 468 |
+
|
| 469 |
+
def pre_cache_dataset(dataset_root):
|
| 470 |
+
# dset = MMAUDataset(dataset_root, train=False, cleanup_dataset=True, specific_categories=["42"])
|
| 471 |
+
|
| 472 |
+
dset = MMAUDataset(dataset_root, train=False, cleanup_dataset=True)
|
| 473 |
+
# s = dset.__getitem__(0)
|
| 474 |
+
|
| 475 |
+
# dset = MMAUDataset(dataset_root, train=False, cleanup_dataset=True)
|
| 476 |
+
# s = dset.__getitem__(0)
|
| 477 |
+
# Trigger label and bbox image cache generation
|
| 478 |
+
# from time import time
|
| 479 |
+
# dataset_train = DADA2000Dataset(root=dataset_root, train=True, clip_length=25, non_overlapping_clips=False)
|
| 480 |
+
# t = time()
|
| 481 |
+
# for i in tqdm(range(len(dataset_train))):
|
| 482 |
+
# d = dataset_train[i]
|
| 483 |
+
# if i >= 100:
|
| 484 |
+
# print("Time:", time() - t)
|
| 485 |
+
# print("break")
|
| 486 |
+
|
| 487 |
+
# dataset_val = DADA2000Dataset(root=dataset_root, train=False, clip_length=25, non_overlapping_clips=True)
|
| 488 |
+
# for i in tqdm(range(len(dataset_val))):
|
| 489 |
+
# d = dataset_val[i]
|
| 490 |
+
|
| 491 |
+
# print("Done.")
|
| 492 |
+
|
| 493 |
+
if __name__ == "__main__":
|
| 494 |
+
dataset_root = "/path/to/Datasets"
|
| 495 |
+
pre_cache_dataset(dataset_root)
|
| 496 |
+
|
| 497 |
+
MMAUDataset(dataset_root, train=True)
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
"""
|
| 501 |
+
ACCIDENT TYPES
|
| 502 |
+
{
|
| 503 |
+
"ego_car_involved": {
|
| 504 |
+
"self_initiated": {
|
| 505 |
+
"out_of_control": [61]
|
| 506 |
+
},
|
| 507 |
+
"dynamic_participants": {
|
| 508 |
+
"person_centric": {
|
| 509 |
+
"pedestrian": [1, 2],
|
| 510 |
+
"cyclist": [3, 4]
|
| 511 |
+
},
|
| 512 |
+
"vehicle_centric": {
|
| 513 |
+
"motorbike": [5, 6],
|
| 514 |
+
"truck": [7, 8, 9],
|
| 515 |
+
"car": [10, 11, 12]
|
| 516 |
+
}
|
| 517 |
+
},
|
| 518 |
+
"static_participants": {
|
| 519 |
+
"road_crentric": {
|
| 520 |
+
"large_roadblocks": [13],
|
| 521 |
+
"curb": [14],
|
| 522 |
+
"small_roadblocks_potholes": [15]
|
| 523 |
+
},
|
| 524 |
+
"other_semantics_centric": {
|
| 525 |
+
"trees": [16],
|
| 526 |
+
"telegraph_poles": [17],
|
| 527 |
+
"other_road_facilities": [18]
|
| 528 |
+
}
|
| 529 |
+
}
|
| 530 |
+
},
|
| 531 |
+
"ego_car_uninvolved": {
|
| 532 |
+
"dynamic_participants": {
|
| 533 |
+
"vehicle_centric": {
|
| 534 |
+
"motorbike_motorbike": [37, 38],
|
| 535 |
+
"truck_truck": [39, 40, 41],
|
| 536 |
+
"car_car": [42, 43, 44],
|
| 537 |
+
"motorbike_truck": [45, 46, 47],
|
| 538 |
+
"truck_car": [48, 49],
|
| 539 |
+
"car_motorbike": [50, 51]
|
| 540 |
+
},
|
| 541 |
+
"person_centric": [52, 53, 54, 55, 56, 57, 58, 59, 60]
|
| 542 |
+
},
|
| 543 |
+
"static_participants" : [19, 20, 21, 22, 25, 26, 28, 29, 31, 32, 34, 35, 36]
|
| 544 |
+
},
|
| 545 |
+
}
|
| 546 |
+
"""
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
|
src/datasets/nuscenes_dataset.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from nuscenes.nuscenes import NuScenes
|
| 2 |
+
from nuscenes.utils.geometry_utils import view_points
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from pyquaternion import Quaternion
|
| 6 |
+
import os
|
| 7 |
+
from typing import Tuple
|
| 8 |
+
from nuscenes.utils.splits import create_splits_scenes
|
| 9 |
+
from shapely.geometry import MultiPoint, box
|
| 10 |
+
from typing import List, Tuple, Union
|
| 11 |
+
import json
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from src.datasets.base_dataset import BaseDataset
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# "Singleton" that holds the data so we only have to load once for training & validation
|
| 18 |
+
nusc_data = None
|
| 19 |
+
|
| 20 |
+
class NuScenesDataset(BaseDataset):
|
| 21 |
+
|
| 22 |
+
CLASS_NAME_TO_ID = {
|
| 23 |
+
"animal": 1,
|
| 24 |
+
"human.pedestrian.adult": 1,
|
| 25 |
+
"human.pedestrian.child": 1,
|
| 26 |
+
"human.pedestrian.construction_worker": 1,
|
| 27 |
+
"human.pedestrian.personal_mobility": 1,
|
| 28 |
+
"human.pedestrian.police_officer": 1,
|
| 29 |
+
"human.pedestrian.stroller": 1,
|
| 30 |
+
"human.pedestrian.wheelchair": 1,
|
| 31 |
+
|
| 32 |
+
# "movable_object.barrier": 10,
|
| 33 |
+
# "movable_object.debris": 10,
|
| 34 |
+
# "movable_object.pushable_pullable": 10,
|
| 35 |
+
# "movable_object.trafficcone": 10,
|
| 36 |
+
# "static_object.bicycle_rack": 10,
|
| 37 |
+
|
| 38 |
+
"vehicle.bicycle": 8,
|
| 39 |
+
|
| 40 |
+
"vehicle.bus.bendy": 5,
|
| 41 |
+
"vehicle.bus.rigid": 5,
|
| 42 |
+
|
| 43 |
+
"vehicle.car": 3,
|
| 44 |
+
"vehicle.emergency.police": 3,
|
| 45 |
+
|
| 46 |
+
"vehicle.construction": 4,
|
| 47 |
+
"vehicle.emergency.ambulance": 4,
|
| 48 |
+
"vehicle.trailer": 4,
|
| 49 |
+
"vehicle.truck": 4,
|
| 50 |
+
|
| 51 |
+
"vehicle.motorcycle": 7,
|
| 52 |
+
|
| 53 |
+
"None": 10,
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
def __init__(self,
|
| 57 |
+
root='./datasets',
|
| 58 |
+
train=True,
|
| 59 |
+
clip_length=25,
|
| 60 |
+
orig_height=900, orig_width=1600,
|
| 61 |
+
resize_height=320, resize_width=512,
|
| 62 |
+
non_overlapping_clips=False,
|
| 63 |
+
bbox_masking_prob=0.0,
|
| 64 |
+
test_split=False,
|
| 65 |
+
ego_only=False):
|
| 66 |
+
|
| 67 |
+
super(NuScenesDataset, self).__init__(root=root,
|
| 68 |
+
train=train,
|
| 69 |
+
clip_length=clip_length,
|
| 70 |
+
resize_height=resize_height,
|
| 71 |
+
resize_width=resize_width,
|
| 72 |
+
non_overlapping_clips=non_overlapping_clips,
|
| 73 |
+
bbox_masking_prob=bbox_masking_prob)
|
| 74 |
+
|
| 75 |
+
self.dataset_name = 'nuscenes'
|
| 76 |
+
self.train = train
|
| 77 |
+
self.orig_width = orig_width
|
| 78 |
+
self.orig_height = orig_height
|
| 79 |
+
self.non_overlapping_clips = non_overlapping_clips
|
| 80 |
+
self.bbox_image_dir = os.path.join(self.root, self.dataset_name, "bbox_images", self.data_split)
|
| 81 |
+
self.label_dir = os.path.join(self.root, self.dataset_name, "labels", self.data_split)
|
| 82 |
+
os.makedirs(self.label_dir, exist_ok=True)
|
| 83 |
+
|
| 84 |
+
self.inst_token_to_track_id = {}
|
| 85 |
+
|
| 86 |
+
split_scenes = self._load_nusc(test_split)
|
| 87 |
+
self._collect_clips(split_scenes)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _load_nusc(self, test_split):
|
| 91 |
+
global nusc_data
|
| 92 |
+
if nusc_data is None:
|
| 93 |
+
data_split = 'v1.0-trainval' if not test_split else 'v1.0-test'
|
| 94 |
+
# data_split = 'v1.0-mini' # Or: 'v1.0-mini' for testing
|
| 95 |
+
nusc_data = NuScenes(version=data_split,
|
| 96 |
+
dataroot=os.path.join(self.root, self.dataset_name),
|
| 97 |
+
verbose=True)
|
| 98 |
+
self.nusc = nusc_data
|
| 99 |
+
|
| 100 |
+
dataset_split = 'train' if self.train else 'val'
|
| 101 |
+
if test_split:
|
| 102 |
+
dataset_split = 'test'
|
| 103 |
+
|
| 104 |
+
split_scene_names = create_splits_scenes()[dataset_split] # [train: 700, val: 150, test: 150]
|
| 105 |
+
split_scenes = [scene for scene in nusc_data.scene if scene['name'] in split_scene_names]
|
| 106 |
+
|
| 107 |
+
return split_scenes
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _collect_clips(self, split_scenes):
|
| 111 |
+
image_indices_by_scene = {}
|
| 112 |
+
|
| 113 |
+
def collect_frame(scene_idx, sample_data):
|
| 114 |
+
# Get image
|
| 115 |
+
image_path = os.path.join(self.root, self.dataset_name, sample_data['filename'])
|
| 116 |
+
self.image_files.append(image_path)
|
| 117 |
+
if image_indices_by_scene.get(scene_idx) is None:
|
| 118 |
+
image_indices_by_scene[scene_idx] = []
|
| 119 |
+
image_indices_by_scene[scene_idx].append(len(self.image_files) - 1)
|
| 120 |
+
|
| 121 |
+
# Parse label
|
| 122 |
+
labels = self._parse_label(sample_data["token"])
|
| 123 |
+
self.frame_labels.append(labels)
|
| 124 |
+
|
| 125 |
+
# Interpolating annotations to increase the frame rate (nuscenes annotation fps=2Hz, video data fps=12Hz)
|
| 126 |
+
self.fps = 7
|
| 127 |
+
target_period = 1/self.fps # For fps downsampling
|
| 128 |
+
max_frames_per_scene = 75
|
| 129 |
+
print("Collecting nuscenes clips...")
|
| 130 |
+
for scene_i, scene in enumerate(split_scenes):
|
| 131 |
+
|
| 132 |
+
curr_data_token = self.nusc.get('sample', scene['first_sample_token'])['data']["CAM_FRONT"]
|
| 133 |
+
curr_sample_data = self.nusc.get('sample_data', curr_data_token)
|
| 134 |
+
collect_frame(scene_i, curr_sample_data)
|
| 135 |
+
|
| 136 |
+
cumul_delta = 0
|
| 137 |
+
total_delta = 0
|
| 138 |
+
t = 0
|
| 139 |
+
while curr_data_token:
|
| 140 |
+
curr_sample_data = self.nusc.get('sample_data', curr_data_token)
|
| 141 |
+
|
| 142 |
+
next_sample_data_token = curr_sample_data['next']
|
| 143 |
+
if not next_sample_data_token:
|
| 144 |
+
break
|
| 145 |
+
next_sample_data = self.nusc.get('sample_data', next_sample_data_token)
|
| 146 |
+
|
| 147 |
+
# FPS downsampling: only select certain frames based on elapsed times
|
| 148 |
+
delta = (next_sample_data['timestamp'] - curr_sample_data['timestamp']) / 1e6
|
| 149 |
+
cumul_delta += delta
|
| 150 |
+
total_delta += delta
|
| 151 |
+
if cumul_delta >= target_period:
|
| 152 |
+
collect_frame(scene_i, next_sample_data)
|
| 153 |
+
t += 1
|
| 154 |
+
cumul_delta = cumul_delta - target_period
|
| 155 |
+
|
| 156 |
+
curr_data_token = next_sample_data_token
|
| 157 |
+
|
| 158 |
+
if len(image_indices_by_scene[scene_i]) > max_frames_per_scene:
|
| 159 |
+
break
|
| 160 |
+
|
| 161 |
+
# print(f"Fps: {len(image_indices_by_scene[scene_i]) / total_delta:.4f}")
|
| 162 |
+
|
| 163 |
+
if not self.non_overlapping_clips:
|
| 164 |
+
for image_idx in range(len(image_indices_by_scene[scene_i]) - self.clip_length + 1):
|
| 165 |
+
self.clip_list.append(image_indices_by_scene[scene_i][image_idx:image_idx+self.clip_length])
|
| 166 |
+
else:
|
| 167 |
+
# In case self.clip_length << actual video sample length (~20s), we can create multiple non-overlapping clips for each sample
|
| 168 |
+
total_frames = len(image_indices_by_scene[scene_i])
|
| 169 |
+
for clip_i in range(total_frames // self.clip_length):
|
| 170 |
+
start_image_idx = clip_i * self.clip_length
|
| 171 |
+
self.clip_list.append(image_indices_by_scene[scene_i][start_image_idx:start_image_idx+self.clip_length])
|
| 172 |
+
|
| 173 |
+
print("Number of nuScenes clips:", len(self.clip_list), f"({'train' if self.train else 'val'})")
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _parse_label(self, token):
|
| 177 |
+
|
| 178 |
+
cam_front_data = self.nusc.get('sample_data', token)
|
| 179 |
+
|
| 180 |
+
# Check cache, if it doesn't exist, then create label file
|
| 181 |
+
filename = cam_front_data["filename"].split('/')[-1].split('.')[0]
|
| 182 |
+
label_file_path = os.path.join(self.label_dir, f"{filename}.json")
|
| 183 |
+
if os.path.exists(label_file_path):
|
| 184 |
+
with open(label_file_path, 'r') as json_file:
|
| 185 |
+
object_labels = json.load(json_file)
|
| 186 |
+
|
| 187 |
+
return object_labels
|
| 188 |
+
else:
|
| 189 |
+
front_camera_sensor = self.nusc.get('calibrated_sensor', cam_front_data['calibrated_sensor_token'])
|
| 190 |
+
camera_intrinsic = np.array(front_camera_sensor['camera_intrinsic'])
|
| 191 |
+
ego_pose = self.nusc.get('ego_pose', cam_front_data['ego_pose_token'])
|
| 192 |
+
|
| 193 |
+
object_labels = []
|
| 194 |
+
bbox_center_by_track_id = {}
|
| 195 |
+
for bbox_3d in self.nusc.get_boxes(token):
|
| 196 |
+
|
| 197 |
+
class_name = bbox_3d.name
|
| 198 |
+
if class_name not in NuScenesDataset.CLASS_NAME_TO_ID:
|
| 199 |
+
continue
|
| 200 |
+
class_id = NuScenesDataset.CLASS_NAME_TO_ID[class_name]
|
| 201 |
+
|
| 202 |
+
instance_token = self.nusc.get('sample_annotation', bbox_3d.token)['instance_token']
|
| 203 |
+
if instance_token not in self.inst_token_to_track_id:
|
| 204 |
+
self.inst_token_to_track_id[instance_token] = len(self.inst_token_to_track_id)
|
| 205 |
+
|
| 206 |
+
# Project 3D bboxes to 2D
|
| 207 |
+
# (Code adapted from: https://github.com/nutonomy/nuscenes-devkit/blob/master/python-sdk/nuscenes/scripts/export_2d_annotations_as_json.py)
|
| 208 |
+
|
| 209 |
+
# Move them to the ego-pose frame.
|
| 210 |
+
bbox_3d.translate(-np.array(ego_pose['translation']))
|
| 211 |
+
bbox_3d.rotate(Quaternion(ego_pose['rotation']).inverse)
|
| 212 |
+
|
| 213 |
+
# Move them to the calibrated sensor frame.
|
| 214 |
+
bbox_3d.translate(-np.array(front_camera_sensor['translation']))
|
| 215 |
+
bbox_3d.rotate(Quaternion(front_camera_sensor['rotation']).inverse)
|
| 216 |
+
|
| 217 |
+
# Filter out the corners that are not in front of the calibrated sensor.
|
| 218 |
+
corners_3d = bbox_3d.corners()
|
| 219 |
+
in_front = np.argwhere(corners_3d[2, :] > 0).flatten()
|
| 220 |
+
corners_3d = corners_3d[:, in_front]
|
| 221 |
+
|
| 222 |
+
# Project 3d box to 2d.
|
| 223 |
+
corner_coords = view_points(corners_3d, camera_intrinsic, True).T[:, :2].tolist()
|
| 224 |
+
|
| 225 |
+
# Keep only corners that fall within the image.
|
| 226 |
+
final_coords = self._post_process_coords(corner_coords)
|
| 227 |
+
|
| 228 |
+
# Skip if the convex hull of the re-projected corners does not intersect the image canvas.
|
| 229 |
+
if final_coords is None:
|
| 230 |
+
continue
|
| 231 |
+
|
| 232 |
+
min_x, min_y, max_x, max_y = final_coords
|
| 233 |
+
track_id = self.inst_token_to_track_id[instance_token]
|
| 234 |
+
|
| 235 |
+
bbox_center_by_track_id[track_id] = bbox_3d.center
|
| 236 |
+
|
| 237 |
+
obj_label = {
|
| 238 |
+
'frame_name': cam_front_data["filename"],
|
| 239 |
+
'track_id': track_id,
|
| 240 |
+
'bbox': [min_x, min_y, max_x, max_y],
|
| 241 |
+
'class_id': class_id,
|
| 242 |
+
'class_name': class_name,
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
object_labels.append(obj_label)
|
| 246 |
+
|
| 247 |
+
# Render the furthest bboxes first (closer ones should be on top)
|
| 248 |
+
object_labels.sort(key=lambda label: np.linalg.norm(bbox_center_by_track_id[label["track_id"]]), reverse=True)
|
| 249 |
+
|
| 250 |
+
# Cache file
|
| 251 |
+
with open(label_file_path, 'w') as json_file:
|
| 252 |
+
json.dump(object_labels, json_file)
|
| 253 |
+
print("Cached labels:", label_file_path)
|
| 254 |
+
|
| 255 |
+
return object_labels
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def _post_process_coords(self, corner_coords: List) -> Union[Tuple[float, float, float, float], None]:
|
| 259 |
+
"""
|
| 260 |
+
Get the intersection of the convex hull of the reprojected bbox corners and the image canvas, return None if no intersection.
|
| 261 |
+
:param corner_coords: Corner coordinates of reprojected bounding box.
|
| 262 |
+
:param imsize: Size of the image canvas.
|
| 263 |
+
:return: Intersection of the convex hull of the 2D box corners and the image canvas.
|
| 264 |
+
"""
|
| 265 |
+
imsize = (self.orig_width, self.orig_height)
|
| 266 |
+
|
| 267 |
+
polygon_from_2d_box = MultiPoint(corner_coords).convex_hull
|
| 268 |
+
img_canvas = box(0, 0, imsize[0], imsize[1])
|
| 269 |
+
|
| 270 |
+
if polygon_from_2d_box.intersects(img_canvas):
|
| 271 |
+
img_intersection = polygon_from_2d_box.intersection(img_canvas)
|
| 272 |
+
intersection_coords = np.array([coord for coord in img_intersection.exterior.coords])
|
| 273 |
+
|
| 274 |
+
min_x = min(intersection_coords[:, 0])
|
| 275 |
+
min_y = min(intersection_coords[:, 1])
|
| 276 |
+
max_x = max(intersection_coords[:, 0])
|
| 277 |
+
max_y = max(intersection_coords[:, 1])
|
| 278 |
+
|
| 279 |
+
return min_x, min_y, max_x, max_y
|
| 280 |
+
else:
|
| 281 |
+
return None
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def pre_cache_dataset(dataset_root):
|
| 285 |
+
# Trigger label and bbox image cache generation
|
| 286 |
+
dataset_val = NuScenesDataset(root=dataset_root, train=False, clip_length=25, non_overlapping_clips=True)
|
| 287 |
+
for i in tqdm(range(len(dataset_val))):
|
| 288 |
+
d = dataset_val[i]
|
| 289 |
+
|
| 290 |
+
dataset_train = NuScenesDataset(root=dataset_root, train=True, clip_length=25, non_overlapping_clips=True)
|
| 291 |
+
for i in tqdm(range(len(dataset_train))):
|
| 292 |
+
d = dataset_train[i]
|
| 293 |
+
|
| 294 |
+
print("Done.")
|
| 295 |
+
|
| 296 |
+
if __name__ == "__main__":
|
| 297 |
+
dataset_root = "/path/to/Datasets"
|
| 298 |
+
pre_cache_dataset(dataset_root)
|
src/datasets/russia_crash_dataset.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
from src.datasets.base_dataset import BaseDataset
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class RussiaCrashDataset(BaseDataset):
|
| 8 |
+
|
| 9 |
+
CLASS_NAME_TO_ID = {
|
| 10 |
+
'person': 1,
|
| 11 |
+
'car': 3,
|
| 12 |
+
'truck': 4,
|
| 13 |
+
'bus': 5,
|
| 14 |
+
'train': 6,
|
| 15 |
+
'motorcycle': 7,
|
| 16 |
+
'bicycle': 8,
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
def __init__(self,
|
| 20 |
+
root='./datasets',
|
| 21 |
+
train=True,
|
| 22 |
+
clip_length=25,
|
| 23 |
+
orig_height=555, orig_width=986,
|
| 24 |
+
resize_height=320, resize_width=512,
|
| 25 |
+
non_overlapping_clips=False,
|
| 26 |
+
bbox_masking_prob=0.0,
|
| 27 |
+
sample_clip_from_end=True,
|
| 28 |
+
ego_only=False,
|
| 29 |
+
specific_samples=None):
|
| 30 |
+
|
| 31 |
+
super(RussiaCrashDataset, self).__init__(root=root,
|
| 32 |
+
train=train,
|
| 33 |
+
clip_length=clip_length,
|
| 34 |
+
resize_height=resize_height,
|
| 35 |
+
resize_width=resize_width,
|
| 36 |
+
non_overlapping_clips=non_overlapping_clips,
|
| 37 |
+
bbox_masking_prob=bbox_masking_prob,
|
| 38 |
+
sample_clip_from_end=sample_clip_from_end,
|
| 39 |
+
ego_only=ego_only)
|
| 40 |
+
|
| 41 |
+
self.dataset_name = "preprocess_russia_crash"
|
| 42 |
+
|
| 43 |
+
self.orig_width = orig_width
|
| 44 |
+
self.orig_height = orig_height
|
| 45 |
+
self.image_dir = os.path.join(self.root, self.dataset_name, "images", self.data_split)
|
| 46 |
+
self.label_dir = os.path.join(self.root, self.dataset_name, "labels", self.data_split)
|
| 47 |
+
self.bbox_image_dir = os.path.join(self.root, self.dataset_name, "bbox_images", self.data_split)
|
| 48 |
+
|
| 49 |
+
self.specific_samples = specific_samples
|
| 50 |
+
|
| 51 |
+
self._collect_clips()
|
| 52 |
+
|
| 53 |
+
def _collect_clips(self):
|
| 54 |
+
image_indices_by_clip = {}
|
| 55 |
+
for label_file in sorted(os.listdir(self.label_dir)):
|
| 56 |
+
if not label_file.endswith('.json'):
|
| 57 |
+
continue
|
| 58 |
+
|
| 59 |
+
full_filename = os.path.join(self.label_dir, label_file)
|
| 60 |
+
with open(full_filename) as json_file:
|
| 61 |
+
all_data = json.load(json_file)
|
| 62 |
+
metadata = all_data['metadata']
|
| 63 |
+
|
| 64 |
+
# Only include dashcam samples
|
| 65 |
+
if metadata['camera'] != "Dashcam":
|
| 66 |
+
continue
|
| 67 |
+
# Exclude animal and "other" accidents
|
| 68 |
+
if metadata['accident_type'] == "Risk of collision/collision with an animal":
|
| 69 |
+
continue
|
| 70 |
+
if metadata['accident_type'] == 'Other types of traffic accidents':
|
| 71 |
+
continue
|
| 72 |
+
# NOTE uncomment to only include actual car collision (no close misses and dangerous events)
|
| 73 |
+
# if metadata['collision_type'] == "No Collision":
|
| 74 |
+
# continue
|
| 75 |
+
if self.ego_only:
|
| 76 |
+
print("Ego collisions only activated!")
|
| 77 |
+
if metadata['collision_type'] == "No Collision" or metadata["ego_car_involved"] != "Yes":
|
| 78 |
+
continue
|
| 79 |
+
|
| 80 |
+
clip_filename = label_file.split('.')[0]
|
| 81 |
+
clip_file = os.path.join(self.image_dir, clip_filename)
|
| 82 |
+
|
| 83 |
+
if self.specific_samples is not None and clip_filename not in self.specific_samples:
|
| 84 |
+
continue
|
| 85 |
+
|
| 86 |
+
if len(os.listdir(clip_file)) < self.clip_length:
|
| 87 |
+
# print(f"{clip_filename} does not have enough frames: has {len(os.listdir(clip_file))} expected at least {self.clip_length}")
|
| 88 |
+
continue
|
| 89 |
+
|
| 90 |
+
clip_label_data = self._parse_clip_labels(all_data["data"])
|
| 91 |
+
self.frame_labels.extend(clip_label_data) # In this case labels are already sorted so they will match up to the image indices
|
| 92 |
+
|
| 93 |
+
image_indices_by_clip[clip_filename] = []
|
| 94 |
+
for image_file in sorted(os.listdir(clip_file)):
|
| 95 |
+
self.image_files.append(os.path.join(clip_file, image_file))
|
| 96 |
+
image_indices_by_clip[clip_filename].append(len(self.image_files)-1)
|
| 97 |
+
|
| 98 |
+
assert len(self.frame_labels) == len(self.image_files) # We assume a one-to-one association between images and labels
|
| 99 |
+
|
| 100 |
+
# Cut the videos in clips of the correct length according to the strategies chosen
|
| 101 |
+
if not self.non_overlapping_clips:
|
| 102 |
+
for image_idx in range(len(image_indices_by_clip[clip_filename]) - self.clip_length + 1):
|
| 103 |
+
self.clip_list.append(image_indices_by_clip[clip_filename][image_idx:image_idx+self.clip_length])
|
| 104 |
+
else:
|
| 105 |
+
if self.sample_clip_from_end:
|
| 106 |
+
# In case self.clip_length << actual video sample length, we can create multiple non-overlapping clips for each sample
|
| 107 |
+
# Prioritize selecting clips from the end, to make sur the accident is included (which tends to be at the end of the videos)
|
| 108 |
+
total_frames = len(image_indices_by_clip[clip_filename])
|
| 109 |
+
for clip_i in range(total_frames // self.clip_length):
|
| 110 |
+
start_image_idx = total_frames - (self.clip_length * (clip_i + 1))
|
| 111 |
+
end_image_idx = total_frames - (self.clip_length * clip_i)
|
| 112 |
+
self.clip_list.append(image_indices_by_clip[clip_filename][start_image_idx:end_image_idx])
|
| 113 |
+
else:
|
| 114 |
+
total_frames = len(image_indices_by_clip[clip_filename])
|
| 115 |
+
for clip_i in range(total_frames // self.clip_length):
|
| 116 |
+
start_image_idx = clip_i * self.clip_length
|
| 117 |
+
end_image_idx = start_image_idx + self.clip_length
|
| 118 |
+
self.clip_list.append(image_indices_by_clip[clip_filename][start_image_idx:end_image_idx])
|
| 119 |
+
|
| 120 |
+
print("Number of clips Russia_crash:", len(self.clip_list), f"({self.data_split})")
|
| 121 |
+
|
| 122 |
+
def _parse_clip_labels(self, clip_data):
|
| 123 |
+
frame_labels = []
|
| 124 |
+
for frame_data in clip_data:
|
| 125 |
+
obj_data = frame_data['labels']
|
| 126 |
+
|
| 127 |
+
object_labels = []
|
| 128 |
+
for label in obj_data:
|
| 129 |
+
# Only keep the classes of interest
|
| 130 |
+
class_id = RussiaCrashDataset.CLASS_NAME_TO_ID.get(label['name'])
|
| 131 |
+
if class_id is None:
|
| 132 |
+
continue
|
| 133 |
+
|
| 134 |
+
# Convert bbox coordinates to pixel space wrt to image size
|
| 135 |
+
bbox = label['box']
|
| 136 |
+
bbox_coords_pixel = [int(bbox[0] * self.orig_width), # x1
|
| 137 |
+
int(bbox[1] * self.orig_height), # y1
|
| 138 |
+
int(bbox[2] * self.orig_width), # x2
|
| 139 |
+
int(bbox[3] * self.orig_height)] # y2
|
| 140 |
+
|
| 141 |
+
object_labels.append({
|
| 142 |
+
'frame_name': frame_data["image_source"],
|
| 143 |
+
'track_id': int(label['track_id']),
|
| 144 |
+
'bbox': bbox_coords_pixel,
|
| 145 |
+
'class_id': class_id,
|
| 146 |
+
'class_name': label['name'], # Class name of the object
|
| 147 |
+
})
|
| 148 |
+
|
| 149 |
+
frame_labels.append(object_labels)
|
| 150 |
+
|
| 151 |
+
return frame_labels
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def pre_cache_dataset(dataset_root):
|
| 155 |
+
# Trigger label and bbox image cache generation
|
| 156 |
+
dataset_val = RussiaCrashDataset(root=dataset_root, train=False, clip_length=25, non_overlapping_clips=True)
|
| 157 |
+
for i in tqdm(range(len(dataset_val))):
|
| 158 |
+
d = dataset_val[i]
|
| 159 |
+
|
| 160 |
+
dataset_train = RussiaCrashDataset(root=dataset_root, train=True, clip_length=25, non_overlapping_clips=True)
|
| 161 |
+
for i in tqdm(range(len(dataset_train))):
|
| 162 |
+
d = dataset_train[i]
|
| 163 |
+
|
| 164 |
+
print("Done.")
|
| 165 |
+
|
| 166 |
+
if __name__ == "__main__":
|
| 167 |
+
from tqdm import tqdm
|
| 168 |
+
|
| 169 |
+
dataset_root = "/path/to/Datasets"
|
| 170 |
+
pre_cache_dataset(dataset_root)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
|
src/eval/README.md
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Video Quality Evaluation Tools
|
| 2 |
+
|
| 3 |
+
This directory contains scripts for evaluating video quality metrics between generated and ground truth videos. There are four main evaluation scripts:
|
| 4 |
+
|
| 5 |
+
1. `video_quality_metrics_fvd_pair.py`: Evaluates FVD (Fréchet Video Distance) between paired generated and ground truth videos
|
| 6 |
+
2. `video_quality_metrics_fvd_gt_rand.py`: Evaluates FVD using pre-computed ground truth statistics
|
| 7 |
+
3. `video_quality_metrics_jedi_pair.py`: Evaluates JEDi metric between paired generated and ground truth videos
|
| 8 |
+
4. `video_quality_metrics_jedi_gt_rand.py`: Evaluates JEDi metric using random ground truth samples
|
| 9 |
+
|
| 10 |
+
## Video Generation
|
| 11 |
+
|
| 12 |
+
Before running the evaluation scripts, you'll need to generate video samples using the `run_gen_videos.py` script:
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
python run_gen_videos.py \
|
| 16 |
+
--model_path /path/to/model/checkpoint \
|
| 17 |
+
--output_path /path/to/output/videos \
|
| 18 |
+
--data_root /path/to/dataset_root \
|
| 19 |
+
--num_demo_samples 10 \
|
| 20 |
+
--max_output_vids 200 \
|
| 21 |
+
--num_gens_per_sample 1 \
|
| 22 |
+
--eval_output
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### Key Generation Arguments
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
--model_path PATH # Path to model checkpoint (required)
|
| 29 |
+
--data_root PATH # Dataset root path
|
| 30 |
+
--output_path PATH # Where to save generated videos
|
| 31 |
+
--num_demo_samples N # Number of samples to collect for generation
|
| 32 |
+
--max_output_vids N # Maximum number of videos to generate
|
| 33 |
+
--num_gens_per_sample N # Videos to generate per test case
|
| 34 |
+
|
| 35 |
+
# Optional arguments for controlling generation
|
| 36 |
+
--bbox_mask_idx_batch N1 N2 ... # Where to start masking (0-25)
|
| 37 |
+
--force_action_type_batch N1 N2 ... # Force specific action types (0-4)
|
| 38 |
+
--guidance_scales N1 N2 ... # Guidance scales to use
|
| 39 |
+
--seed N # Random seed for reproducibility
|
| 40 |
+
--disable_null_model # Disable null model for unconditional noise
|
| 41 |
+
--use_factor_guidance # Use factor guidance during generation
|
| 42 |
+
--eval_output # Enable evaluation output
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
### Action Types
|
| 46 |
+
- 0: Normal driving
|
| 47 |
+
- 1-4: Different types of crash scenarios
|
| 48 |
+
|
| 49 |
+
## Common Arguments for Evaluation
|
| 50 |
+
|
| 51 |
+
All evaluation scripts share some common command line arguments:
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
--vid_root PATH # Root directory containing generated videos (required)
|
| 55 |
+
--samples N # Number of samples to evaluate (default: 200)
|
| 56 |
+
--num_frames N # Number of frames per video (default: 25)
|
| 57 |
+
--downsample_int N # Downsample interval for frames (default: 1)
|
| 58 |
+
--action_type N # Action type to filter videos (0: normal, 1-4: crash types)
|
| 59 |
+
--shuffle # Shuffle videos before evaluation
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
## FVD Evaluation
|
| 63 |
+
|
| 64 |
+
### Paired Evaluation
|
| 65 |
+
```bash
|
| 66 |
+
python video_quality_metrics_fvd_pair.py \
|
| 67 |
+
--vid_root /path/to/videos \
|
| 68 |
+
--samples 200 \
|
| 69 |
+
--num_frames 25 \
|
| 70 |
+
--downsample
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
### Ground Truth Statistics Evaluation
|
| 74 |
+
```bash
|
| 75 |
+
# First, collect ground truth statistics
|
| 76 |
+
python video_quality_metrics_fvd_gt_rand.py \
|
| 77 |
+
--vid_root /path/to/videos \
|
| 78 |
+
--collect_stats \
|
| 79 |
+
--samples 500 \
|
| 80 |
+
--action_type 1
|
| 81 |
+
|
| 82 |
+
# Then evaluate using the collected statistics
|
| 83 |
+
python video_quality_metrics_fvd_gt_rand.py \
|
| 84 |
+
--vid_root /path/to/videos \
|
| 85 |
+
--gt_stats /path/to/stats.npz \
|
| 86 |
+
--samples 200 \
|
| 87 |
+
--shuffle
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
## JEDi Evaluation
|
| 91 |
+
|
| 92 |
+
### Paired Evaluation
|
| 93 |
+
```bash
|
| 94 |
+
python video_quality_metrics_jedi_pair.py \
|
| 95 |
+
--vid_root /path/to/videos \
|
| 96 |
+
--samples 200 \
|
| 97 |
+
--num_frames 25 \
|
| 98 |
+
--test_feature_path /path/to/features
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
### Ground Truth Random Evaluation
|
| 102 |
+
```bash
|
| 103 |
+
python video_quality_metrics_jedi_gt_rand.py \
|
| 104 |
+
--vid_root /path/to/videos \
|
| 105 |
+
--samples 200 \
|
| 106 |
+
--gt_samples 500 \
|
| 107 |
+
--test_feature_path /path/to/features \
|
| 108 |
+
--action_type 1 \
|
| 109 |
+
--shuffle
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
## Additional Notes
|
| 113 |
+
|
| 114 |
+
- The `--action_type` argument can be used to filter videos by category:
|
| 115 |
+
- 0: Normal driving videos
|
| 116 |
+
- 1-4: Different types of crash videos
|
| 117 |
+
- For FVD evaluation with ground truth statistics, you can collect statistics once and reuse them for multiple evaluations
|
| 118 |
+
- The JEDi metric requires a test feature path for model loading
|
| 119 |
+
- All scripts support shuffling of videos before evaluation for more robust results
|
| 120 |
+
- The default resolution for videos is 320x512 pixels
|
src/eval/__pycache__/generate_samples.cpython-310.pyc
ADDED
|
Binary file (9.13 kB). View file
|
|
|
src/eval/generate_samples.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from PIL import Image, ImageDraw
|
| 3 |
+
import cv2
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import json
|
| 6 |
+
import argparse
|
| 7 |
+
|
| 8 |
+
import warnings
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
torch.cuda.empty_cache()
|
| 12 |
+
import torch.utils.checkpoint
|
| 13 |
+
from accelerate.utils import set_seed
|
| 14 |
+
|
| 15 |
+
with warnings.catch_warnings():
|
| 16 |
+
warnings.simplefilter("ignore")
|
| 17 |
+
from src.pipelines import StableVideoControlPipeline
|
| 18 |
+
from src.pipelines import StableVideoControlNullModelPipeline
|
| 19 |
+
from src.pipelines import StableVideoControlFactorGuidancePipeline
|
| 20 |
+
|
| 21 |
+
from src.models import UNetSpatioTemporalConditionModel, ControlNetModel
|
| 22 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
| 23 |
+
from diffusers.models import AutoencoderKLTemporalDecoder
|
| 24 |
+
|
| 25 |
+
from src.datasets.dataset_utils import get_dataloader
|
| 26 |
+
from src.utils import get_samples
|
| 27 |
+
|
| 28 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 29 |
+
print(f"Device: {device}")
|
| 30 |
+
|
| 31 |
+
generator = None #torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
| 32 |
+
CLIP_LENGTH = 25
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def create_video_from_np(sample, video_path, fps=6):
|
| 36 |
+
video_filename = f"{video_path}.mp4"
|
| 37 |
+
frame_size = (512, 320)
|
| 38 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 39 |
+
video_writer_out = cv2.VideoWriter(video_filename, fourcc, fps, frame_size)
|
| 40 |
+
|
| 41 |
+
for img in sample:
|
| 42 |
+
img = np.transpose(img, (1, 2, 0))
|
| 43 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
| 44 |
+
video_writer_out.write(img)
|
| 45 |
+
|
| 46 |
+
video_writer_out.release()
|
| 47 |
+
print(f"Video saved: {video_filename}")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def export_to_video(video_frames, output_video_path=None, fps=6):
|
| 51 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
| 52 |
+
h, w, c = video_frames[0].shape
|
| 53 |
+
video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
|
| 54 |
+
for i in range(len(video_frames)):
|
| 55 |
+
img = cv2.cvtColor(video_frames[i].astype(np.uint8), cv2.COLOR_RGB2BGR)
|
| 56 |
+
video_writer.write(img)
|
| 57 |
+
return output_video_path
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def label_frames_with_action_id(bbox_frames, action_id, masked_idx=None):
|
| 61 |
+
action_name = {0: "Normal", 1: "Ego", 2: "Ego/Veh", 3: "Veh", 4: "Veh/Veh"}
|
| 62 |
+
action_text = f"Action: {action_name[action_id]} ({action_id})"
|
| 63 |
+
for i in range(bbox_frames.shape[0]):
|
| 64 |
+
# Convert numpy array to PIL Image
|
| 65 |
+
frame = Image.fromarray(bbox_frames[i].transpose(1, 2, 0))
|
| 66 |
+
draw = ImageDraw.Draw(frame)
|
| 67 |
+
|
| 68 |
+
# Add text in top right corner
|
| 69 |
+
text_position = (frame.width - 10, 10) # 10 pixels from top, 10 pixels from right
|
| 70 |
+
if masked_idx is not None and masked_idx <= i:
|
| 71 |
+
text_color = (0, 0, 0)
|
| 72 |
+
action_text = f"Action: {action_name[action_id]} ({action_id}) [masked]"
|
| 73 |
+
else:
|
| 74 |
+
text_color = (255, 255, 255)
|
| 75 |
+
action_text = action_text
|
| 76 |
+
|
| 77 |
+
draw.text(text_position, action_text, fill=text_color, anchor="ra")
|
| 78 |
+
|
| 79 |
+
# Convert back to numpy array
|
| 80 |
+
bbox_frames[i] = np.array(frame).transpose(2, 0, 1)
|
| 81 |
+
|
| 82 |
+
return bbox_frames
|
| 83 |
+
|
| 84 |
+
def load_ctrlv_pipelines(model_dir, use_null_model=False, use_factor_guidance=False):
|
| 85 |
+
unet_variant = "fp16" if "stabilityai" in model_dir else None
|
| 86 |
+
|
| 87 |
+
unet = UNetSpatioTemporalConditionModel.from_pretrained(
|
| 88 |
+
model_dir,
|
| 89 |
+
subfolder="unet",
|
| 90 |
+
variant=unet_variant,
|
| 91 |
+
low_cpu_mem_usage=True,
|
| 92 |
+
num_frames=CLIP_LENGTH
|
| 93 |
+
)
|
| 94 |
+
ctrlnet = ControlNetModel.from_pretrained(
|
| 95 |
+
model_dir,
|
| 96 |
+
subfolder="control_net",
|
| 97 |
+
variant=unet_variant,
|
| 98 |
+
num_frames=25
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
if not use_null_model and not use_factor_guidance:
|
| 102 |
+
pipeline = StableVideoControlPipeline.from_pretrained(
|
| 103 |
+
"stabilityai/stable-video-diffusion-img2vid-xt",
|
| 104 |
+
controlnet=ctrlnet,
|
| 105 |
+
unet=unet,
|
| 106 |
+
variant=unet_variant
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
else:
|
| 110 |
+
|
| 111 |
+
# For null model prediction of uncond noise
|
| 112 |
+
null_model_path = "stabilityai/stable-video-diffusion-img2vid-xt"
|
| 113 |
+
null_model_unet = UNetSpatioTemporalConditionModel.from_pretrained(
|
| 114 |
+
null_model_path,
|
| 115 |
+
subfolder="unet",
|
| 116 |
+
variant=None,
|
| 117 |
+
low_cpu_mem_usage=True,
|
| 118 |
+
num_frames=CLIP_LENGTH
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
if use_null_model and not use_factor_guidance:
|
| 122 |
+
pipeline = StableVideoControlNullModelPipeline.from_pretrained(
|
| 123 |
+
"stabilityai/stable-video-diffusion-img2vid-xt",
|
| 124 |
+
controlnet=ctrlnet,
|
| 125 |
+
unet=unet,
|
| 126 |
+
null_model=null_model_unet,
|
| 127 |
+
variant=unet_variant
|
| 128 |
+
)
|
| 129 |
+
elif use_factor_guidance:
|
| 130 |
+
pipeline = StableVideoControlFactorGuidancePipeline.from_pretrained(
|
| 131 |
+
"stabilityai/stable-video-diffusion-img2vid-xt",
|
| 132 |
+
controlnet=ctrlnet,
|
| 133 |
+
unet=unet,
|
| 134 |
+
null_model=null_model_unet,
|
| 135 |
+
variant=unet_variant
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
pipeline = pipeline.to(device)
|
| 139 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 140 |
+
|
| 141 |
+
unet.eval()
|
| 142 |
+
ctrlnet.eval()
|
| 143 |
+
|
| 144 |
+
return pipeline
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def generate_video_ctrlv(sample, pipeline, video_path="video_out/genvid", json_path="video_out/gt_frames", bbox_mask_frames=None, action_type=None, use_factor_guidance=False, guidance=[1.0, 3.0], video_path2=None):
|
| 148 |
+
frame_size = (512, 320)
|
| 149 |
+
FPS = 6
|
| 150 |
+
CLIP_LENGTH = sample['bbox_images'].shape[0]
|
| 151 |
+
|
| 152 |
+
init_image = sample['image_init']
|
| 153 |
+
bbox_images = sample['bbox_images'].unsqueeze(0)
|
| 154 |
+
action_type = sample['action_type'].unsqueeze(0) if action_type is None else action_type
|
| 155 |
+
|
| 156 |
+
sample['bbox_images'].to(device)
|
| 157 |
+
|
| 158 |
+
# Save GT frame paths to json file
|
| 159 |
+
gt_frame_paths = [file_path[0] for file_path in sample['image_paths']]
|
| 160 |
+
with open(json_path, "w") as file:
|
| 161 |
+
json.dump(gt_frame_paths, file, indent=1)
|
| 162 |
+
print("Saved GT frames json file:", json_path)
|
| 163 |
+
|
| 164 |
+
if not use_factor_guidance:
|
| 165 |
+
frames = pipeline(init_image,
|
| 166 |
+
cond_images=bbox_images,
|
| 167 |
+
bbox_mask_frames=bbox_mask_frames,
|
| 168 |
+
action_type=action_type,
|
| 169 |
+
height=frame_size[1], width=frame_size[0],
|
| 170 |
+
decode_chunk_size=8, motion_bucket_id=127, fps=FPS,
|
| 171 |
+
num_inference_steps=30,
|
| 172 |
+
num_frames=CLIP_LENGTH,
|
| 173 |
+
control_condition_scale=1.0,
|
| 174 |
+
min_guidance_scale=guidance[0],
|
| 175 |
+
max_guidance_scale=guidance[1],
|
| 176 |
+
noise_aug_strength=0.01,
|
| 177 |
+
generator=generator, output_type='pt').frames[0]
|
| 178 |
+
else:
|
| 179 |
+
frames = pipeline(init_image,
|
| 180 |
+
cond_images=bbox_images,
|
| 181 |
+
bbox_mask_frames=bbox_mask_frames,
|
| 182 |
+
action_type=action_type,
|
| 183 |
+
height=frame_size[1], width=frame_size[0],
|
| 184 |
+
decode_chunk_size=8, motion_bucket_id=127, fps=FPS,
|
| 185 |
+
num_inference_steps=30,
|
| 186 |
+
num_frames=CLIP_LENGTH,
|
| 187 |
+
control_condition_scale=1.0,
|
| 188 |
+
min_guidance_scale_img=1.0,
|
| 189 |
+
max_guidance_scale_img=3.0,
|
| 190 |
+
min_guidance_scale_action=6.0,
|
| 191 |
+
max_guidance_scale_action=12.0,
|
| 192 |
+
min_guidance_scale_bbox=1.0,
|
| 193 |
+
max_guidance_scale_bbox=3.0,
|
| 194 |
+
noise_aug_strength=0.01,
|
| 195 |
+
generator=generator, output_type='pt').frames[0]
|
| 196 |
+
|
| 197 |
+
frames = frames.detach().cpu().numpy()*255
|
| 198 |
+
frames = frames.astype(np.uint8)
|
| 199 |
+
|
| 200 |
+
tmp = np.moveaxis(np.transpose(frames, (0, 2, 3, 1)), 0, 0)
|
| 201 |
+
output_video_path = f"{video_path}.mp4"
|
| 202 |
+
export_to_video(tmp, output_video_path, fps=FPS)
|
| 203 |
+
print(f"Video saved:", output_video_path)
|
| 204 |
+
|
| 205 |
+
if video_path2 is not None:
|
| 206 |
+
output_video_path2 = f"{video_path2}.mp4"
|
| 207 |
+
export_to_video(tmp, output_video_path2, fps=FPS)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def generate_samples(args):
|
| 211 |
+
model_path = args.model_path
|
| 212 |
+
print("Model path:", model_path)
|
| 213 |
+
|
| 214 |
+
if args.seed is not None:
|
| 215 |
+
set_seed(args.seed)
|
| 216 |
+
print("Set seed:", args.seed)
|
| 217 |
+
|
| 218 |
+
# LOAD PIPELINE
|
| 219 |
+
use_null_model = not args.disable_null_model
|
| 220 |
+
use_factor_guidance = args.use_factor_guidance
|
| 221 |
+
pipeline = load_ctrlv_pipelines(model_path, use_null_model=use_null_model, use_factor_guidance=use_factor_guidance)
|
| 222 |
+
|
| 223 |
+
# LOAD DATASET
|
| 224 |
+
data_root = args.data_root
|
| 225 |
+
dataset_name = args.dataset
|
| 226 |
+
train_set = False
|
| 227 |
+
val_dataset, val_loader = get_dataloader(
|
| 228 |
+
data_root, dataset_name, if_train=train_set, clip_length=CLIP_LENGTH,
|
| 229 |
+
batch_size=1, num_workers=0, shuffle=True,
|
| 230 |
+
image_height=320, image_width=512,
|
| 231 |
+
non_overlapping_clips=True, #specific_samples=specific_samples
|
| 232 |
+
)
|
| 233 |
+
if train_set:
|
| 234 |
+
print("WARNING: Currently using training split")
|
| 235 |
+
|
| 236 |
+
# COLLECT SAMPLES
|
| 237 |
+
num_demo_samples = args.num_demo_samples
|
| 238 |
+
demo_samples = get_samples(val_loader, num_demo_samples, show_progress=True)
|
| 239 |
+
|
| 240 |
+
sample_range = range(0, num_demo_samples)
|
| 241 |
+
num_samples = len(sample_range)
|
| 242 |
+
|
| 243 |
+
# video_dir_path = os.path.join(os.getcwd(), "video_out", "video_out_box2video_may1_eval_test")
|
| 244 |
+
video_dir_path = args.output_path
|
| 245 |
+
os.makedirs(video_dir_path, exist_ok=True)
|
| 246 |
+
video_counter = 0
|
| 247 |
+
|
| 248 |
+
# GENERATION PARAMETERS
|
| 249 |
+
|
| 250 |
+
# Set the bbox masking
|
| 251 |
+
bbox_mask_idx_batch = args.bbox_mask_idx_batch
|
| 252 |
+
condition_on_last_bbox = False
|
| 253 |
+
|
| 254 |
+
# Set the action type
|
| 255 |
+
force_action_type = None #1 # 0: Normal, 1: Ego, 2: Ego/Veh, 3: Veh, 4: Veh/Veh
|
| 256 |
+
force_action_type_batch = args.force_action_type_batch
|
| 257 |
+
|
| 258 |
+
num_gens_per_sample = args.num_gens_per_sample
|
| 259 |
+
guidance_scales = args.guidance_scales
|
| 260 |
+
eval_output = args.eval_output
|
| 261 |
+
|
| 262 |
+
# GENERATE VIDEOS
|
| 263 |
+
|
| 264 |
+
# Check for samples that were already done and do not compute them again
|
| 265 |
+
skip_samples = {}
|
| 266 |
+
out_video_path = f"{video_dir_path}/gt_ref"
|
| 267 |
+
if os.path.exists(out_video_path):
|
| 268 |
+
all_videos = os.listdir(out_video_path)
|
| 269 |
+
video_counter = len(all_videos)
|
| 270 |
+
for sample_name in all_videos:
|
| 271 |
+
vid_name = "_".join(sample_name.split("_")[1:])
|
| 272 |
+
skip_samples[vid_name] = True
|
| 273 |
+
|
| 274 |
+
print("SKIP SAMPLES:", skip_samples)
|
| 275 |
+
|
| 276 |
+
for guidance in guidance_scales or [-1]:
|
| 277 |
+
|
| 278 |
+
if guidance != -1:
|
| 279 |
+
print("Guidance:", force_action_type)
|
| 280 |
+
else:
|
| 281 |
+
guidance = [1, 3]
|
| 282 |
+
|
| 283 |
+
for _ in range(num_gens_per_sample):
|
| 284 |
+
for force_action_type in force_action_type_batch or [-1]:
|
| 285 |
+
|
| 286 |
+
if force_action_type != -1:
|
| 287 |
+
print("Force action type:", force_action_type)
|
| 288 |
+
else:
|
| 289 |
+
force_action_type = None
|
| 290 |
+
|
| 291 |
+
for bbox_mask_idx in bbox_mask_idx_batch or [-1]:
|
| 292 |
+
|
| 293 |
+
if bbox_mask_idx != -1:
|
| 294 |
+
print("Bbox masking:", bbox_mask_idx)
|
| 295 |
+
else:
|
| 296 |
+
bbox_mask_idx = None
|
| 297 |
+
|
| 298 |
+
for i, sample in tqdm(enumerate(demo_samples), desc="Generating samples", total=num_samples):
|
| 299 |
+
if i >= list(sample_range)[-1] + 1:
|
| 300 |
+
break
|
| 301 |
+
if i not in sample_range:
|
| 302 |
+
continue
|
| 303 |
+
|
| 304 |
+
if video_counter > args.max_output_vids:
|
| 305 |
+
print(f"MAX OUTPUT VIDS REACHED: {video_counter} >= {args.max_output_vids}")
|
| 306 |
+
exit()
|
| 307 |
+
|
| 308 |
+
vid_name = sample["vid_name"]
|
| 309 |
+
|
| 310 |
+
mask_hint = "" if bbox_mask_idx is None else f"_bframes:{str(bbox_mask_idx)}"
|
| 311 |
+
action_hint = "" if force_action_type is None else f"_action:{str(force_action_type)}"
|
| 312 |
+
guidance_hint = "" if guidance_scales is None else f"_guide{guidance[0]}:{guidance[1]}"
|
| 313 |
+
scene_name = f"{video_counter}_{vid_name}{mask_hint}{action_hint}{guidance_hint}"
|
| 314 |
+
|
| 315 |
+
scene_name_no_counter = "_".join(scene_name.split("_")[1:])
|
| 316 |
+
if scene_name_no_counter in skip_samples:
|
| 317 |
+
print(f"Skipping sample that was already computed: {vid_name}")
|
| 318 |
+
continue
|
| 319 |
+
|
| 320 |
+
print("Generating video for:", scene_name)
|
| 321 |
+
|
| 322 |
+
if eval_output:
|
| 323 |
+
os.makedirs(f"{video_dir_path}/gen_videos", exist_ok=True)
|
| 324 |
+
os.makedirs(f"{video_dir_path}/gt_frames", exist_ok=True)
|
| 325 |
+
os.makedirs(f"{video_dir_path}/gt_ref", exist_ok=True)
|
| 326 |
+
|
| 327 |
+
gt_vid_path = f"{video_dir_path}/gt_ref/{scene_name}/(1)gt_video_{scene_name}"
|
| 328 |
+
bbox_out_path_root = f"{video_dir_path}/gt_ref/{scene_name}"
|
| 329 |
+
out_video_path = f"{video_dir_path}/gen_videos/genvid_{video_counter}_{vid_name}"
|
| 330 |
+
out_json_path = os.path.join(video_dir_path, "gt_frames", f"gt_frames_{video_counter}_{vid_name}.json")
|
| 331 |
+
|
| 332 |
+
out_video_path2 = f"{bbox_out_path_root}/(3)genvid_adv_{scene_name}"
|
| 333 |
+
|
| 334 |
+
os.makedirs(bbox_out_path_root, exist_ok=True)
|
| 335 |
+
else:
|
| 336 |
+
os.makedirs(f"{video_dir_path}/{scene_name}", exist_ok=True)
|
| 337 |
+
|
| 338 |
+
gt_vid_path = f"{video_dir_path}/{scene_name}/(1)gt_video_{scene_name}"
|
| 339 |
+
bbox_out_path_root = f"{video_dir_path}/{scene_name}"
|
| 340 |
+
out_video_path = f"{video_dir_path}/{scene_name}/(3)genvid_adv_{scene_name}"
|
| 341 |
+
out_json_path = os.path.join(video_dir_path, scene_name, f"gt_frames_{sample['vid_name']}.json")
|
| 342 |
+
out_video_path2 = None
|
| 343 |
+
|
| 344 |
+
create_video_from_np(sample['gt_clip_np'], video_path=gt_vid_path)
|
| 345 |
+
|
| 346 |
+
# Add action type text to ground truth bounding box frames # TODO: Make sure the action type aligns if we change it for generation
|
| 347 |
+
action_type = sample['action_type'].unsqueeze(0)
|
| 348 |
+
og_action_type = action_type.item()
|
| 349 |
+
if force_action_type is not None:
|
| 350 |
+
action_type = torch.ones_like(action_type) * force_action_type
|
| 351 |
+
|
| 352 |
+
action_id = action_type.item()
|
| 353 |
+
bbox_frames = sample['bbox_images_np'].copy()
|
| 354 |
+
if bbox_mask_idx is not None:
|
| 355 |
+
# print(f"Masking bboxes after index {bbox_mask_idx}")
|
| 356 |
+
|
| 357 |
+
# Let's save a copy of the original bboxes for reference
|
| 358 |
+
bbox_frames_ref = sample['bbox_images_np'].copy()
|
| 359 |
+
label_frames_with_action_id(bbox_frames_ref, og_action_type)
|
| 360 |
+
create_video_from_np(bbox_frames_ref, video_path=f"{bbox_out_path_root}/(2)video_2dbboxes_{scene_name}_nomask")
|
| 361 |
+
|
| 362 |
+
# For display, let's mask with white
|
| 363 |
+
mask_cond = bbox_mask_idx <= np.arange(CLIP_LENGTH).reshape(CLIP_LENGTH, 1, 1, 1)
|
| 364 |
+
if condition_on_last_bbox:
|
| 365 |
+
mask_cond[-1, 0, 0, 0] = False
|
| 366 |
+
bbox_frames = np.where(mask_cond, np.ones_like(bbox_frames)*255, bbox_frames)
|
| 367 |
+
label_frames_with_action_id(bbox_frames, action_id, masked_idx=bbox_mask_idx)
|
| 368 |
+
else:
|
| 369 |
+
label_frames_with_action_id(bbox_frames, action_id)
|
| 370 |
+
|
| 371 |
+
create_video_from_np(bbox_frames, video_path=f"{bbox_out_path_root}/(2)video_2dbboxes_{scene_name}")
|
| 372 |
+
|
| 373 |
+
bbox_mask_frames = [False] * CLIP_LENGTH
|
| 374 |
+
if bbox_mask_idx is not None:
|
| 375 |
+
bbox_mask_frames[bbox_mask_idx:] = [True] * (len(bbox_mask_frames) - bbox_mask_idx)
|
| 376 |
+
if condition_on_last_bbox:
|
| 377 |
+
bbox_mask_frames[-1] = False
|
| 378 |
+
|
| 379 |
+
generate_video_ctrlv(
|
| 380 |
+
sample,
|
| 381 |
+
pipeline,
|
| 382 |
+
video_path=out_video_path,
|
| 383 |
+
json_path=out_json_path,
|
| 384 |
+
bbox_mask_frames=bbox_mask_frames,
|
| 385 |
+
action_type=action_type,
|
| 386 |
+
use_factor_guidance=use_factor_guidance,
|
| 387 |
+
guidance=guidance,
|
| 388 |
+
video_path2=out_video_path2
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
video_counter += 1
|
| 392 |
+
|
| 393 |
+
print("DONE")
|
| 394 |
+
|
src/eval/video_dataset.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import Dataset
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
class VideoDataset(Dataset):
|
| 8 |
+
def __init__(self, video_root, num_frames=25, downsample_int=1, transform=None):
|
| 9 |
+
"""
|
| 10 |
+
Args:
|
| 11 |
+
video_root (str): Directory with all the video files
|
| 12 |
+
num_frames (int): Number of frames to extract from each video
|
| 13 |
+
downsample_int (int): Interval between frames to extract
|
| 14 |
+
transform (callable, optional): Optional transform to be applied on frames
|
| 15 |
+
"""
|
| 16 |
+
self.video_root = video_root
|
| 17 |
+
self.num_frames = num_frames
|
| 18 |
+
self.downsample_int = downsample_int
|
| 19 |
+
self.transform = transform
|
| 20 |
+
|
| 21 |
+
# Get list of video files
|
| 22 |
+
self.video_files = []
|
| 23 |
+
gen_videos = os.path.join(video_root, "gen_videos") if os.path.exists(os.path.join(video_root, "gen_videos")) else video_root
|
| 24 |
+
for fname in os.listdir(gen_videos):
|
| 25 |
+
if fname.endswith('.mp4'):
|
| 26 |
+
self.video_files.append(os.path.join(gen_videos, fname))
|
| 27 |
+
|
| 28 |
+
self.video_files.sort()
|
| 29 |
+
|
| 30 |
+
def __len__(self):
|
| 31 |
+
return len(self.video_files)
|
| 32 |
+
|
| 33 |
+
def get_frames_mp4(self, video_path):
|
| 34 |
+
"""Extract frames from video file"""
|
| 35 |
+
cap = cv2.VideoCapture(video_path)
|
| 36 |
+
if not cap.isOpened():
|
| 37 |
+
raise ValueError(f"Could not open video file: {video_path}")
|
| 38 |
+
|
| 39 |
+
frames = []
|
| 40 |
+
frame_count = 0
|
| 41 |
+
|
| 42 |
+
while True:
|
| 43 |
+
ret, frame = cap.read()
|
| 44 |
+
if not ret:
|
| 45 |
+
break
|
| 46 |
+
|
| 47 |
+
frame = cv2.resize(frame, (512, 320))
|
| 48 |
+
|
| 49 |
+
if frame_count % self.downsample_int == 0:
|
| 50 |
+
frames.append(frame)
|
| 51 |
+
|
| 52 |
+
frame_count += 1
|
| 53 |
+
|
| 54 |
+
if len(frames) >= self.num_frames:
|
| 55 |
+
break
|
| 56 |
+
|
| 57 |
+
cap.release()
|
| 58 |
+
|
| 59 |
+
if len(frames) < self.num_frames:
|
| 60 |
+
# Pad with last frame if we don't have enough frames
|
| 61 |
+
last_frame = frames[-1] if frames else np.zeros((320, 512, 3), dtype=np.uint8)
|
| 62 |
+
while len(frames) < self.num_frames:
|
| 63 |
+
frames.append(last_frame)
|
| 64 |
+
|
| 65 |
+
return np.array(frames[:self.num_frames])
|
| 66 |
+
|
| 67 |
+
def __getitem__(self, idx):
|
| 68 |
+
video_path = self.video_files[idx]
|
| 69 |
+
frames = self.get_frames_mp4(video_path)
|
| 70 |
+
|
| 71 |
+
# Convert to torch tensor and normalize
|
| 72 |
+
frames = torch.from_numpy(frames).float()
|
| 73 |
+
frames = frames.permute(0, 3, 1, 2) # Change from (T, H, W, C) to (T, C, H, W)
|
| 74 |
+
frames = frames / (255/2.0) - 1.0 # Normalize to [-1, 1]
|
| 75 |
+
|
| 76 |
+
if self.transform:
|
| 77 |
+
frames = self.transform(frames)
|
| 78 |
+
|
| 79 |
+
return frames, []
|
src/eval/video_quality_metrics_fvd_gt_rand.py
ADDED
|
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import scipy.linalg
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import math
|
| 7 |
+
import cv2
|
| 8 |
+
import json
|
| 9 |
+
import random
|
| 10 |
+
import os
|
| 11 |
+
import argparse
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import io
|
| 16 |
+
import re
|
| 17 |
+
import requests
|
| 18 |
+
import html
|
| 19 |
+
import hashlib
|
| 20 |
+
import urllib
|
| 21 |
+
import urllib.request
|
| 22 |
+
import uuid
|
| 23 |
+
|
| 24 |
+
from distutils.util import strtobool
|
| 25 |
+
from typing import Any, List, Tuple, Union, Dict
|
| 26 |
+
|
| 27 |
+
from src.datasets.dataset_utils import get_dataloader
|
| 28 |
+
from src.utils import get_samples
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_frames_from_path_list(path_list):
|
| 32 |
+
frames = []
|
| 33 |
+
for path in path_list:
|
| 34 |
+
img = cv2.imread(path)
|
| 35 |
+
img = cv2.resize(img, [512, 320])
|
| 36 |
+
frames.append(img)
|
| 37 |
+
return np.array(frames)
|
| 38 |
+
|
| 39 |
+
def get_frames_mp4(video_path: str, frame_interval: int = 1) -> None:
|
| 40 |
+
|
| 41 |
+
# Open the video file
|
| 42 |
+
cap = cv2.VideoCapture(video_path)
|
| 43 |
+
if not cap.isOpened():
|
| 44 |
+
raise ValueError(f"Could not open video file: {video_path}")
|
| 45 |
+
|
| 46 |
+
frame_count = 0
|
| 47 |
+
saved_count = 0
|
| 48 |
+
|
| 49 |
+
frames = []
|
| 50 |
+
while True:
|
| 51 |
+
ret, frame = cap.read()
|
| 52 |
+
if not ret:
|
| 53 |
+
break
|
| 54 |
+
|
| 55 |
+
frame = cv2.resize(frame, (512, 320))
|
| 56 |
+
|
| 57 |
+
# Save frame if it's the right interval
|
| 58 |
+
if frame_count % frame_interval == 0:
|
| 59 |
+
frames.append(frame)
|
| 60 |
+
saved_count += 1
|
| 61 |
+
|
| 62 |
+
frame_count += 1
|
| 63 |
+
|
| 64 |
+
cap.release()
|
| 65 |
+
return np.array(frames)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def load_json(filename):
|
| 69 |
+
if os.path.exists(filename):
|
| 70 |
+
with open(filename, "r") as f:
|
| 71 |
+
return json.load(f)
|
| 72 |
+
print(filename, "not found")
|
| 73 |
+
return []
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any:
|
| 77 |
+
"""Download the given URL and return a binary-mode file object to access the data."""
|
| 78 |
+
assert num_attempts >= 1
|
| 79 |
+
|
| 80 |
+
# Doesn't look like an URL scheme so interpret it as a local filename.
|
| 81 |
+
if not re.match('^[a-z]+://', url):
|
| 82 |
+
return url if return_filename else open(url, "rb")
|
| 83 |
+
|
| 84 |
+
# Handle file URLs. This code handles unusual file:// patterns that
|
| 85 |
+
# arise on Windows:
|
| 86 |
+
#
|
| 87 |
+
# file:///c:/foo.txt
|
| 88 |
+
#
|
| 89 |
+
# which would translate to a local '/c:/foo.txt' filename that's
|
| 90 |
+
# invalid. Drop the forward slash for such pathnames.
|
| 91 |
+
#
|
| 92 |
+
# If you touch this code path, you should test it on both Linux and
|
| 93 |
+
# Windows.
|
| 94 |
+
#
|
| 95 |
+
# Some internet resources suggest using urllib.request.url2pathname() but
|
| 96 |
+
# but that converts forward slashes to backslashes and this causes
|
| 97 |
+
# its own set of problems.
|
| 98 |
+
if url.startswith('file://'):
|
| 99 |
+
filename = urllib.parse.urlparse(url).path
|
| 100 |
+
if re.match(r'^/[a-zA-Z]:', filename):
|
| 101 |
+
filename = filename[1:]
|
| 102 |
+
return filename if return_filename else open(filename, "rb")
|
| 103 |
+
|
| 104 |
+
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
| 105 |
+
|
| 106 |
+
# Download.
|
| 107 |
+
url_name = None
|
| 108 |
+
url_data = None
|
| 109 |
+
with requests.Session() as session:
|
| 110 |
+
if verbose:
|
| 111 |
+
print("Downloading %s ..." % url, end="", flush=True)
|
| 112 |
+
for attempts_left in reversed(range(num_attempts)):
|
| 113 |
+
try:
|
| 114 |
+
with session.get(url) as res:
|
| 115 |
+
res.raise_for_status()
|
| 116 |
+
if len(res.content) == 0:
|
| 117 |
+
raise IOError("No data received")
|
| 118 |
+
|
| 119 |
+
if len(res.content) < 8192:
|
| 120 |
+
content_str = res.content.decode("utf-8")
|
| 121 |
+
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
| 122 |
+
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
|
| 123 |
+
if len(links) == 1:
|
| 124 |
+
url = requests.compat.urljoin(url, links[0])
|
| 125 |
+
raise IOError("Google Drive virus checker nag")
|
| 126 |
+
if "Google Drive - Quota exceeded" in content_str:
|
| 127 |
+
raise IOError("Google Drive download quota exceeded -- please try again later")
|
| 128 |
+
|
| 129 |
+
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
|
| 130 |
+
url_name = match[1] if match else url
|
| 131 |
+
url_data = res.content
|
| 132 |
+
if verbose:
|
| 133 |
+
print(" done")
|
| 134 |
+
break
|
| 135 |
+
except KeyboardInterrupt:
|
| 136 |
+
raise
|
| 137 |
+
except:
|
| 138 |
+
if not attempts_left:
|
| 139 |
+
if verbose:
|
| 140 |
+
print(" failed")
|
| 141 |
+
raise
|
| 142 |
+
if verbose:
|
| 143 |
+
print(".", end="", flush=True)
|
| 144 |
+
|
| 145 |
+
# Return data as file object.
|
| 146 |
+
assert not return_filename
|
| 147 |
+
return io.BytesIO(url_data)
|
| 148 |
+
|
| 149 |
+
"""
|
| 150 |
+
Modified from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py
|
| 151 |
+
"""
|
| 152 |
+
class FVD:
|
| 153 |
+
def __init__(self, device,
|
| 154 |
+
detector_url='https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1',
|
| 155 |
+
rescale=False, resize=False, return_features=True):
|
| 156 |
+
|
| 157 |
+
self.device = device
|
| 158 |
+
self.detector_kwargs = dict(rescale=False, resize=False, return_features=True)
|
| 159 |
+
|
| 160 |
+
with open_url(detector_url, verbose=False) as f:
|
| 161 |
+
self.detector = torch.jit.load(f).eval().to(device)
|
| 162 |
+
|
| 163 |
+
# Initialize ground truth statistics
|
| 164 |
+
self.mu_real = None
|
| 165 |
+
self.sigma_real = None
|
| 166 |
+
|
| 167 |
+
def to_device(self, device):
|
| 168 |
+
self.device = device
|
| 169 |
+
self.detector = self.detector.to(self.device)
|
| 170 |
+
|
| 171 |
+
def _compute_stats(self, feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 172 |
+
mu = feats.mean(axis=0) # [d]
|
| 173 |
+
sigma = np.cov(feats, rowvar=False) # [d, d]
|
| 174 |
+
return mu, sigma
|
| 175 |
+
|
| 176 |
+
def save_gt_stats(self, save_path: str):
|
| 177 |
+
"""Save ground truth statistics to a file."""
|
| 178 |
+
if self.mu_real is None or self.sigma_real is None:
|
| 179 |
+
raise ValueError("Ground truth statistics not computed yet")
|
| 180 |
+
|
| 181 |
+
stats = {
|
| 182 |
+
'mu_real': self.mu_real,
|
| 183 |
+
'sigma_real': self.sigma_real
|
| 184 |
+
}
|
| 185 |
+
np.savez(save_path, **stats)
|
| 186 |
+
|
| 187 |
+
def load_gt_stats(self, load_path: str):
|
| 188 |
+
"""Load ground truth statistics from a file."""
|
| 189 |
+
stats = np.load(load_path)
|
| 190 |
+
self.mu_real = stats['mu_real']
|
| 191 |
+
self.sigma_real = stats['sigma_real']
|
| 192 |
+
|
| 193 |
+
def preprocess_videos(self, videos, resolution=224, sequence_length=None):
|
| 194 |
+
|
| 195 |
+
b, t, c, h, w = videos.shape
|
| 196 |
+
|
| 197 |
+
# temporal crop
|
| 198 |
+
if sequence_length is not None:
|
| 199 |
+
assert sequence_length <= t
|
| 200 |
+
videos = videos[:, :sequence_length, ::]
|
| 201 |
+
|
| 202 |
+
# b*t x c x h x w
|
| 203 |
+
videos = videos.reshape(-1, c, h, w)
|
| 204 |
+
if c == 1:
|
| 205 |
+
videos = torch.cat([videos, videos, videos], 1)
|
| 206 |
+
c = 3
|
| 207 |
+
|
| 208 |
+
# scale shorter side to resolution
|
| 209 |
+
scale = resolution / min(h, w)
|
| 210 |
+
# import pdb; pdb.set_trace()
|
| 211 |
+
if h < w:
|
| 212 |
+
target_size = (resolution, math.ceil(w * scale))
|
| 213 |
+
else:
|
| 214 |
+
target_size = (math.ceil(h * scale), resolution)
|
| 215 |
+
|
| 216 |
+
videos = F.interpolate(videos, size=target_size).clamp(min=-1, max=1)
|
| 217 |
+
|
| 218 |
+
# center crop
|
| 219 |
+
_, c, h, w = videos.shape
|
| 220 |
+
|
| 221 |
+
h_start = (h - resolution) // 2
|
| 222 |
+
w_start = (w - resolution) // 2
|
| 223 |
+
videos = videos[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
|
| 224 |
+
|
| 225 |
+
# b, c, t, w, h
|
| 226 |
+
videos = videos.reshape(b, t, c, resolution, resolution).permute(0, 2, 1, 3, 4)
|
| 227 |
+
|
| 228 |
+
return videos.contiguous()
|
| 229 |
+
|
| 230 |
+
@torch.no_grad()
|
| 231 |
+
def evaluate(self, video_fake, video_real=None, res=224, use_saved_stats=False, save_stats_path=None):
|
| 232 |
+
"""Evaluate FVD score.
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
video_fake: Generated videos
|
| 236 |
+
video_real: Ground truth videos (optional if use_saved_stats=True)
|
| 237 |
+
res: Resolution for preprocessing
|
| 238 |
+
use_saved_stats: Whether to use saved ground truth statistics
|
| 239 |
+
"""
|
| 240 |
+
video_fake = self.preprocess_videos(video_fake, resolution=res)
|
| 241 |
+
feats_fake = self.detector(video_fake, **self.detector_kwargs).cpu().numpy()
|
| 242 |
+
|
| 243 |
+
if use_saved_stats:
|
| 244 |
+
if self.mu_real is None or self.sigma_real is None:
|
| 245 |
+
raise ValueError("Ground truth statistics not loaded. Call load_gt_stats() first.")
|
| 246 |
+
mu_real = self.mu_real
|
| 247 |
+
sigma_real = self.sigma_real
|
| 248 |
+
else:
|
| 249 |
+
if video_real is None:
|
| 250 |
+
raise ValueError("video_real must be provided when use_saved_stats=False")
|
| 251 |
+
video_real = self.preprocess_videos(video_real, resolution=res)
|
| 252 |
+
feats_real = self.detector(video_real, **self.detector_kwargs).cpu().numpy()
|
| 253 |
+
mu_real, sigma_real = self._compute_stats(feats_real)
|
| 254 |
+
# Save the computed statistics
|
| 255 |
+
self.mu_real = mu_real
|
| 256 |
+
self.sigma_real = sigma_real
|
| 257 |
+
if save_stats_path is not None:
|
| 258 |
+
self.save_gt_stats(save_stats_path)
|
| 259 |
+
|
| 260 |
+
mu_gen, sigma_gen = self._compute_stats(feats_fake)
|
| 261 |
+
|
| 262 |
+
m = np.square(mu_gen - mu_real).sum()
|
| 263 |
+
s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False)
|
| 264 |
+
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
|
| 265 |
+
return fid
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def collect_fvd_stats(data_root, samples=200, downsample_int=1, num_frames=25, save_path=None, action_type=None):
|
| 269 |
+
"""Collect and save ground truth statistics for FVD evaluation."""
|
| 270 |
+
|
| 271 |
+
if save_path is None:
|
| 272 |
+
save_path = os.path.join(data_root, "gt_fvd_stats.npz")
|
| 273 |
+
|
| 274 |
+
# Set up category filtering if specified
|
| 275 |
+
specific_categories = None
|
| 276 |
+
force_clip_type = None
|
| 277 |
+
if action_type is not None:
|
| 278 |
+
if action_type == 0:
|
| 279 |
+
force_clip_type = "normal"
|
| 280 |
+
print("Collecting normal samples only")
|
| 281 |
+
else:
|
| 282 |
+
classes_by_action_type = {
|
| 283 |
+
1: [61, 62, 13, 14, 15, 16, 17, 18],
|
| 284 |
+
2: list(range(1, 12 + 1)),
|
| 285 |
+
3: [37, 39, 41, 42, 44] + list(range(19, 36 + 1)) + list(range(52, 60 + 1)),
|
| 286 |
+
4: [38, 40, 43, 45, 46, 47, 48, 49, 50, 51]
|
| 287 |
+
}
|
| 288 |
+
specific_categories = classes_by_action_type[action_type]
|
| 289 |
+
force_clip_type = "crash"
|
| 290 |
+
print("Collecting crash samples from categories:", specific_categories)
|
| 291 |
+
|
| 292 |
+
# Create dataset and dataloader
|
| 293 |
+
dataset_name = "mmau"
|
| 294 |
+
train_set = True
|
| 295 |
+
val_dataset, val_loader = get_dataloader(data_root, dataset_name,
|
| 296 |
+
if_train=train_set, clip_length=num_frames,
|
| 297 |
+
batch_size=1, num_workers=0, shuffle=True,
|
| 298 |
+
image_height=320, image_width=512,
|
| 299 |
+
non_overlapping_clips=True,
|
| 300 |
+
specific_categories=specific_categories,
|
| 301 |
+
force_clip_type=force_clip_type)
|
| 302 |
+
|
| 303 |
+
# Collect video paths
|
| 304 |
+
gt_videos = []
|
| 305 |
+
for sample in tqdm(val_loader, desc="Collecting samples", total=samples):
|
| 306 |
+
vid_path = os.path.dirname(sample["image_paths"][0][0])
|
| 307 |
+
gt_videos.append(vid_path)
|
| 308 |
+
if len(gt_videos) >= samples:
|
| 309 |
+
break
|
| 310 |
+
|
| 311 |
+
random.shuffle(gt_videos)
|
| 312 |
+
|
| 313 |
+
num_found_samples = len(gt_videos)
|
| 314 |
+
print(f"Found {num_found_samples} ground truth video directories")
|
| 315 |
+
|
| 316 |
+
# Initialize array for all videos
|
| 317 |
+
all_videos = torch.zeros((num_found_samples, num_frames, 3, 320, 512), device="cuda")
|
| 318 |
+
|
| 319 |
+
# Load and process videos
|
| 320 |
+
valid = 0
|
| 321 |
+
for idx, video_path in tqdm(enumerate(gt_videos), desc="Processing videos", total=num_found_samples):
|
| 322 |
+
if valid == num_found_samples:
|
| 323 |
+
break
|
| 324 |
+
|
| 325 |
+
# Get list of jpg files in directory
|
| 326 |
+
frame_files = sorted([f for f in os.listdir(video_path) if f.endswith('.jpg')])
|
| 327 |
+
|
| 328 |
+
if len(frame_files) < num_frames:
|
| 329 |
+
print(f"Skipping {video_path.split('/')[-1]}, insufficient frames: {len(frame_files)}")
|
| 330 |
+
continue
|
| 331 |
+
|
| 332 |
+
# Load frames
|
| 333 |
+
frames = []
|
| 334 |
+
for frame_file in frame_files[0:num_frames:downsample_int]:
|
| 335 |
+
frame_path = os.path.join(video_path, frame_file)
|
| 336 |
+
img = cv2.imread(frame_path)
|
| 337 |
+
img = cv2.resize(img, (512, 320))
|
| 338 |
+
frames.append(img)
|
| 339 |
+
|
| 340 |
+
frames = torch.tensor(np.array(frames), device="cuda")
|
| 341 |
+
|
| 342 |
+
# Process frames
|
| 343 |
+
frames = frames.unsqueeze(0).permute(0, 1, 4, 2, 3)
|
| 344 |
+
all_videos[valid] = frames[:, :num_frames, ::]
|
| 345 |
+
valid += 1
|
| 346 |
+
|
| 347 |
+
if valid == 0:
|
| 348 |
+
raise ValueError("No valid videos found")
|
| 349 |
+
|
| 350 |
+
# Convert to torch tensor and normalize
|
| 351 |
+
all_videos = all_videos.float()
|
| 352 |
+
all_videos.div_(255/2.0).sub_(1.0)
|
| 353 |
+
|
| 354 |
+
# Initialize FVD and compute statistics
|
| 355 |
+
with torch.no_grad():
|
| 356 |
+
fvd = FVD(device='cuda')
|
| 357 |
+
video_real = fvd.preprocess_videos(all_videos)
|
| 358 |
+
feats_real = fvd.detector(video_real, **fvd.detector_kwargs).cpu().numpy()
|
| 359 |
+
mu_real, sigma_real = fvd._compute_stats(feats_real)
|
| 360 |
+
|
| 361 |
+
# Save statistics
|
| 362 |
+
stats = {
|
| 363 |
+
'mu_real': mu_real,
|
| 364 |
+
'sigma_real': sigma_real,
|
| 365 |
+
'num_videos': valid,
|
| 366 |
+
'num_frames': num_frames,
|
| 367 |
+
'resolution': 320
|
| 368 |
+
}
|
| 369 |
+
np.savez(save_path, **stats)
|
| 370 |
+
print(f"Saved ground truth statistics to {save_path}")
|
| 371 |
+
|
| 372 |
+
# Clean up
|
| 373 |
+
del fvd, all_videos, video_real, feats_real
|
| 374 |
+
torch.cuda.empty_cache()
|
| 375 |
+
|
| 376 |
+
return save_path
|
| 377 |
+
|
| 378 |
+
def evaluate_vids(vid_root, samples=200, downsample_int=1, num_frames=25, gt_stats=None, shuffle=False):
|
| 379 |
+
"""Evaluate FVD score for generated videos using pre-computed ground truth statistics."""
|
| 380 |
+
|
| 381 |
+
# Initialize FVD and load ground truth statistics
|
| 382 |
+
fvd = FVD(device='cuda')
|
| 383 |
+
if gt_stats is not None:
|
| 384 |
+
fvd.load_gt_stats(gt_stats)
|
| 385 |
+
|
| 386 |
+
# Collect generated video paths
|
| 387 |
+
f_gen_vid = []
|
| 388 |
+
gen_videos = os.path.join(vid_root, "gen_videos") if os.path.exists(os.path.join(vid_root, "gen_videos")) else vid_root
|
| 389 |
+
for fname in os.listdir(gen_videos):
|
| 390 |
+
f_gen_vid.append(fname)
|
| 391 |
+
|
| 392 |
+
print(f"Number of generated videos: {len(f_gen_vid)}")
|
| 393 |
+
|
| 394 |
+
if not shuffle:
|
| 395 |
+
f_gen_vid.sort()
|
| 396 |
+
else:
|
| 397 |
+
random.shuffle(f_gen_vid)
|
| 398 |
+
|
| 399 |
+
f_gen_vid = f_gen_vid[:samples]
|
| 400 |
+
|
| 401 |
+
# Initialize array for all videos
|
| 402 |
+
all_gen = np.zeros((samples, num_frames, 3, 320, 512))
|
| 403 |
+
|
| 404 |
+
# Load and process videos
|
| 405 |
+
valid = 0
|
| 406 |
+
for idx, fgen in tqdm(enumerate(f_gen_vid)):
|
| 407 |
+
if valid == samples:
|
| 408 |
+
break
|
| 409 |
+
|
| 410 |
+
gen_vid_path = os.path.join(gen_videos, fgen)
|
| 411 |
+
gen_vid = get_frames_mp4(gen_vid_path, frame_interval=downsample_int)
|
| 412 |
+
|
| 413 |
+
if gen_vid.shape[0] < num_frames:
|
| 414 |
+
print("Skipping, wrong size:", gen_vid.shape[0])
|
| 415 |
+
continue
|
| 416 |
+
|
| 417 |
+
gen_vid = np.expand_dims(gen_vid, 0).transpose(0, 1, 4, 2, 3)
|
| 418 |
+
all_gen[valid] = gen_vid[:, :num_frames, ::]
|
| 419 |
+
valid += 1
|
| 420 |
+
|
| 421 |
+
# Convert to torch tensor and normalize
|
| 422 |
+
all_gen = torch.from_numpy(all_gen).cuda().float()
|
| 423 |
+
all_gen /= 255/2.0
|
| 424 |
+
all_gen -= 1.0
|
| 425 |
+
|
| 426 |
+
# Compute FVD score
|
| 427 |
+
fvd_score = fvd.evaluate(all_gen, video_real=None, use_saved_stats=True)
|
| 428 |
+
del fvd
|
| 429 |
+
|
| 430 |
+
print(f'FVD Score: {fvd_score}')
|
| 431 |
+
|
| 432 |
+
if __name__ == '__main__':
|
| 433 |
+
parser = argparse.ArgumentParser(description='Evaluate FVD score using pre-computed ground truth statistics')
|
| 434 |
+
parser.add_argument('--vid_root', type=str, required=True,
|
| 435 |
+
help='Root directory containing generated videos')
|
| 436 |
+
parser.add_argument('--samples', type=int, default=200,
|
| 437 |
+
help='Number of samples to evaluate (default: 200)')
|
| 438 |
+
parser.add_argument('--num_frames', type=int, default=25,
|
| 439 |
+
help='Number of frames per video (default: 25)')
|
| 440 |
+
parser.add_argument('--downsample_int', type=int, default=1,
|
| 441 |
+
help='Downsample interval for frames (default: 1)')
|
| 442 |
+
parser.add_argument('--gt_stats', type=str, default=None,
|
| 443 |
+
help='Path to ground truth statistics file (optional)')
|
| 444 |
+
parser.add_argument('--shuffle', action='store_true',
|
| 445 |
+
help='Shuffle videos before evaluation')
|
| 446 |
+
parser.add_argument('--collect_stats', action='store_true',
|
| 447 |
+
help='Collect and save ground truth statistics')
|
| 448 |
+
parser.add_argument('--data_root', type=str, required=True,
|
| 449 |
+
help='Root directory for datasets')
|
| 450 |
+
parser.add_argument('--action_type', type=int, default=None,
|
| 451 |
+
help='Action type to filter videos (0: normal, 1-4: crash types)')
|
| 452 |
+
args = parser.parse_args()
|
| 453 |
+
|
| 454 |
+
if args.collect_stats:
|
| 455 |
+
stats_path = collect_fvd_stats(args.data_root, args.samples, args.downsample_int, args.num_frames, args.gt_stats, args.action_type)
|
| 456 |
+
args.gt_stats = stats_path
|
| 457 |
+
|
| 458 |
+
evaluate_vids(args.vid_root, args.samples, args.downsample_int, args.num_frames, args.gt_stats, args.shuffle)
|
src/eval/video_quality_metrics_fvd_pair.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import scipy.linalg
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import math
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
import cv2
|
| 9 |
+
import json
|
| 10 |
+
import argparse
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import lpips
|
| 13 |
+
from skimage.metrics import structural_similarity as ssim
|
| 14 |
+
from skimage.metrics import peak_signal_noise_ratio as psnr
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
Copy-pasted from Copy-pasted from https://github.com/NVlabs/stylegan2-ada-pytorch
|
| 18 |
+
"""
|
| 19 |
+
import ctypes
|
| 20 |
+
import fnmatch
|
| 21 |
+
import importlib
|
| 22 |
+
import inspect
|
| 23 |
+
import numpy as np
|
| 24 |
+
import os
|
| 25 |
+
import shutil
|
| 26 |
+
import sys
|
| 27 |
+
import types
|
| 28 |
+
import io
|
| 29 |
+
import pickle
|
| 30 |
+
import re
|
| 31 |
+
import requests
|
| 32 |
+
import html
|
| 33 |
+
import hashlib
|
| 34 |
+
import glob
|
| 35 |
+
import tempfile
|
| 36 |
+
import urllib
|
| 37 |
+
import urllib.request
|
| 38 |
+
import uuid
|
| 39 |
+
from tqdm import tqdm
|
| 40 |
+
|
| 41 |
+
from distutils.util import strtobool
|
| 42 |
+
from typing import Any, List, Tuple, Union, Dict
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any:
|
| 46 |
+
"""Download the given URL and return a binary-mode file object to access the data."""
|
| 47 |
+
assert num_attempts >= 1
|
| 48 |
+
|
| 49 |
+
# Doesn't look like an URL scheme so interpret it as a local filename.
|
| 50 |
+
if not re.match('^[a-z]+://', url):
|
| 51 |
+
return url if return_filename else open(url, "rb")
|
| 52 |
+
|
| 53 |
+
# Handle file URLs. This code handles unusual file:// patterns that
|
| 54 |
+
# arise on Windows:
|
| 55 |
+
#
|
| 56 |
+
# file:///c:/foo.txt
|
| 57 |
+
#
|
| 58 |
+
# which would translate to a local '/c:/foo.txt' filename that's
|
| 59 |
+
# invalid. Drop the forward slash for such pathnames.
|
| 60 |
+
#
|
| 61 |
+
# If you touch this code path, you should test it on both Linux and
|
| 62 |
+
# Windows.
|
| 63 |
+
#
|
| 64 |
+
# Some internet resources suggest using urllib.request.url2pathname() but
|
| 65 |
+
# but that converts forward slashes to backslashes and this causes
|
| 66 |
+
# its own set of problems.
|
| 67 |
+
if url.startswith('file://'):
|
| 68 |
+
filename = urllib.parse.urlparse(url).path
|
| 69 |
+
if re.match(r'^/[a-zA-Z]:', filename):
|
| 70 |
+
filename = filename[1:]
|
| 71 |
+
return filename if return_filename else open(filename, "rb")
|
| 72 |
+
|
| 73 |
+
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
| 74 |
+
|
| 75 |
+
# Download.
|
| 76 |
+
url_name = None
|
| 77 |
+
url_data = None
|
| 78 |
+
with requests.Session() as session:
|
| 79 |
+
if verbose:
|
| 80 |
+
print("Downloading %s ..." % url, end="", flush=True)
|
| 81 |
+
for attempts_left in reversed(range(num_attempts)):
|
| 82 |
+
try:
|
| 83 |
+
with session.get(url) as res:
|
| 84 |
+
res.raise_for_status()
|
| 85 |
+
if len(res.content) == 0:
|
| 86 |
+
raise IOError("No data received")
|
| 87 |
+
|
| 88 |
+
if len(res.content) < 8192:
|
| 89 |
+
content_str = res.content.decode("utf-8")
|
| 90 |
+
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
| 91 |
+
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
|
| 92 |
+
if len(links) == 1:
|
| 93 |
+
url = requests.compat.urljoin(url, links[0])
|
| 94 |
+
raise IOError("Google Drive virus checker nag")
|
| 95 |
+
if "Google Drive - Quota exceeded" in content_str:
|
| 96 |
+
raise IOError("Google Drive download quota exceeded -- please try again later")
|
| 97 |
+
|
| 98 |
+
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
|
| 99 |
+
url_name = match[1] if match else url
|
| 100 |
+
url_data = res.content
|
| 101 |
+
if verbose:
|
| 102 |
+
print(" done")
|
| 103 |
+
break
|
| 104 |
+
except KeyboardInterrupt:
|
| 105 |
+
raise
|
| 106 |
+
except:
|
| 107 |
+
if not attempts_left:
|
| 108 |
+
if verbose:
|
| 109 |
+
print(" failed")
|
| 110 |
+
raise
|
| 111 |
+
if verbose:
|
| 112 |
+
print(".", end="", flush=True)
|
| 113 |
+
|
| 114 |
+
# Return data as file object.
|
| 115 |
+
assert not return_filename
|
| 116 |
+
return io.BytesIO(url_data)
|
| 117 |
+
|
| 118 |
+
"""
|
| 119 |
+
Modified from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py
|
| 120 |
+
"""
|
| 121 |
+
class FVD:
|
| 122 |
+
def __init__(self, device,
|
| 123 |
+
detector_url='https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1',
|
| 124 |
+
rescale=False, resize=False, return_features=True):
|
| 125 |
+
|
| 126 |
+
self.device = device
|
| 127 |
+
self.detector_kwargs = dict(rescale=False, resize=False, return_features=True)
|
| 128 |
+
|
| 129 |
+
with open_url(detector_url, verbose=False) as f:
|
| 130 |
+
self.detector = torch.jit.load(f).eval().to(device)
|
| 131 |
+
|
| 132 |
+
def to_device(self, device):
|
| 133 |
+
self.device = device
|
| 134 |
+
self.detector = self.detector.to(self.device)
|
| 135 |
+
|
| 136 |
+
def _compute_stats(self, feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 137 |
+
mu = feats.mean(axis=0) # [d]
|
| 138 |
+
sigma = np.cov(feats, rowvar=False) # [d, d]
|
| 139 |
+
return mu, sigma
|
| 140 |
+
|
| 141 |
+
def preprocess_videos(self, videos, resolution=224, sequence_length=None):
|
| 142 |
+
|
| 143 |
+
b, t, c, h, w = videos.shape
|
| 144 |
+
|
| 145 |
+
# temporal crop
|
| 146 |
+
if sequence_length is not None:
|
| 147 |
+
assert sequence_length <= t
|
| 148 |
+
videos = videos[:, :sequence_length, ::]
|
| 149 |
+
|
| 150 |
+
# b*t x c x h x w
|
| 151 |
+
videos = videos.reshape(-1, c, h, w)
|
| 152 |
+
if c == 1:
|
| 153 |
+
videos = torch.cat([videos, videos, videos], 1)
|
| 154 |
+
c = 3
|
| 155 |
+
|
| 156 |
+
# scale shorter side to resolution
|
| 157 |
+
scale = resolution / min(h, w)
|
| 158 |
+
# import pdb; pdb.set_trace()
|
| 159 |
+
if h < w:
|
| 160 |
+
target_size = (resolution, math.ceil(w * scale))
|
| 161 |
+
else:
|
| 162 |
+
target_size = (math.ceil(h * scale), resolution)
|
| 163 |
+
|
| 164 |
+
videos = F.interpolate(videos, size=target_size).clamp(min=-1, max=1)
|
| 165 |
+
|
| 166 |
+
# center crop
|
| 167 |
+
_, c, h, w = videos.shape
|
| 168 |
+
|
| 169 |
+
h_start = (h - resolution) // 2
|
| 170 |
+
w_start = (w - resolution) // 2
|
| 171 |
+
videos = videos[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
|
| 172 |
+
|
| 173 |
+
# b, c, t, w, h
|
| 174 |
+
videos = videos.reshape(b, t, c, resolution, resolution).permute(0, 2, 1, 3, 4)
|
| 175 |
+
|
| 176 |
+
return videos.contiguous()
|
| 177 |
+
|
| 178 |
+
@torch.no_grad()
|
| 179 |
+
def evaluate(self, video_fake, video_real, res=224):
|
| 180 |
+
|
| 181 |
+
video_fake = self.preprocess_videos(video_fake,resolution=res)
|
| 182 |
+
video_real = self.preprocess_videos(video_real,resolution=res)
|
| 183 |
+
feats_fake = self.detector(video_fake, **self.detector_kwargs).cpu().numpy()
|
| 184 |
+
feats_real = self.detector(video_real, **self.detector_kwargs).cpu().numpy()
|
| 185 |
+
|
| 186 |
+
mu_gen, sigma_gen = self._compute_stats(feats_fake)
|
| 187 |
+
mu_real, sigma_real = self._compute_stats(feats_real)
|
| 188 |
+
|
| 189 |
+
m = np.square(mu_gen - mu_real).sum()
|
| 190 |
+
s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
|
| 191 |
+
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
|
| 192 |
+
return fid
|
| 193 |
+
|
| 194 |
+
def evaluate_vids(vid_root, samples=200, downsample=False, num_frames=25):
|
| 195 |
+
"""Evaluate video quality metrics between generated and ground truth videos."""
|
| 196 |
+
|
| 197 |
+
# Collect video paths
|
| 198 |
+
vid_name_to_gt_frames = {}
|
| 199 |
+
gt_videos_refs = os.path.join(vid_root, "gt_frames")
|
| 200 |
+
for fname in os.listdir(gt_videos_refs):
|
| 201 |
+
vid_name = fname.strip("gt_frames_").split(".")[0]
|
| 202 |
+
vid_name_to_gt_frames[vid_name] = fname
|
| 203 |
+
|
| 204 |
+
f_gen_vid = []
|
| 205 |
+
gen_videos = os.path.join(vid_root, "gen_videos")
|
| 206 |
+
for fname in os.listdir(gen_videos):
|
| 207 |
+
f_gen_vid.append(fname)
|
| 208 |
+
vid_name = fname.strip("genvid_").split(".")[0]
|
| 209 |
+
assert vid_name_to_gt_frames.get(vid_name) is not None, f"{fname} has no matching gt frames"
|
| 210 |
+
|
| 211 |
+
print(f"Number of generated videos: {len(f_gen_vid)}")
|
| 212 |
+
|
| 213 |
+
# Initialize arrays for all videos
|
| 214 |
+
all_gt = np.zeros((samples, num_frames, 3, 320, 512))
|
| 215 |
+
all_gen = np.zeros((samples, num_frames, 3, 320, 512))
|
| 216 |
+
|
| 217 |
+
# Load and process videos
|
| 218 |
+
valid = 0
|
| 219 |
+
for idx, fgen in tqdm(enumerate(f_gen_vid), desc="Collecting video frames"):
|
| 220 |
+
if valid == samples:
|
| 221 |
+
break
|
| 222 |
+
|
| 223 |
+
vid_name = fgen.strip("genvid_").split(".")[0]
|
| 224 |
+
fgt = vid_name_to_gt_frames[vid_name]
|
| 225 |
+
|
| 226 |
+
gen_vid_path = os.path.join(gen_videos, fgen)
|
| 227 |
+
gen_vid = get_frames_mp4(gen_vid_path)
|
| 228 |
+
|
| 229 |
+
with open(os.path.join(gt_videos_refs, fgt)) as gt_json:
|
| 230 |
+
gt_vid = get_frames_from_path_list(json.load(gt_json))
|
| 231 |
+
|
| 232 |
+
if gt_vid.shape[0] < num_frames or gen_vid.shape[0] < num_frames:
|
| 233 |
+
print("Skipping, wrong size:", gt_vid.shape[0], gen_vid.shape[0])
|
| 234 |
+
continue
|
| 235 |
+
|
| 236 |
+
gt_vid = np.expand_dims(gt_vid, 0).transpose(0, 1, 4, 2, 3)
|
| 237 |
+
gen_vid = np.expand_dims(gen_vid, 0).transpose(0, 1, 4, 2, 3)
|
| 238 |
+
|
| 239 |
+
all_gt[valid] = gt_vid[:, :num_frames, ::]
|
| 240 |
+
all_gen[valid] = gen_vid[:, :num_frames, ::]
|
| 241 |
+
valid += 1
|
| 242 |
+
|
| 243 |
+
# Convert to torch tensors and normalize
|
| 244 |
+
all_gt = torch.from_numpy(all_gt).cuda().float()
|
| 245 |
+
all_gt /= 255/2.0
|
| 246 |
+
all_gt -= 1.0
|
| 247 |
+
all_gen = torch.from_numpy(all_gen).cuda().float()
|
| 248 |
+
all_gen /= 255/2.0
|
| 249 |
+
all_gen -= 1.0
|
| 250 |
+
|
| 251 |
+
# Compute FVD score
|
| 252 |
+
fvd = FVD(device='cuda')
|
| 253 |
+
fvd_score = fvd.evaluate(all_gt, all_gen)
|
| 254 |
+
del fvd
|
| 255 |
+
|
| 256 |
+
# Compute LPIPS score
|
| 257 |
+
loss_fn_alex = lpips.LPIPS(net='alex').cuda()
|
| 258 |
+
lpips_score = 0
|
| 259 |
+
for idx in range(all_gen.shape[0]):
|
| 260 |
+
lpips_score += loss_fn_alex(all_gt[idx], all_gen[idx])/all_gen.shape[0]
|
| 261 |
+
lpips_score = lpips_score.mean().item()
|
| 262 |
+
del loss_fn_alex
|
| 263 |
+
|
| 264 |
+
# Compute SSIM and PSNR scores
|
| 265 |
+
all_gen = all_gen.detach().cpu().numpy()
|
| 266 |
+
all_gt = all_gt.detach().cpu().numpy()
|
| 267 |
+
|
| 268 |
+
ssim_score_vid = np.zeros(samples)
|
| 269 |
+
ssim_score_image = np.zeros((samples, num_frames))
|
| 270 |
+
psnr_score_vid = np.zeros(samples)
|
| 271 |
+
psnr_score_image = np.zeros((samples, num_frames))
|
| 272 |
+
psnr_score_all = psnr(all_gt, all_gen)
|
| 273 |
+
|
| 274 |
+
for vid_idx in tqdm(range(all_gen.shape[0]), desc="Computing SSIM and PSNR"):
|
| 275 |
+
for f_idx in range(all_gen.shape[1]):
|
| 276 |
+
img_gt = all_gt[vid_idx, f_idx]
|
| 277 |
+
img_gen = all_gen[vid_idx, f_idx]
|
| 278 |
+
data_range = max(img_gt.max(), img_gen.max()) - min(img_gt.min(), img_gen.min())
|
| 279 |
+
ssim_score_image[vid_idx, f_idx] = ssim(img_gt, img_gen, channel_axis=0, data_range=data_range, gaussian_weights=True, sigma=1.5)
|
| 280 |
+
psnr_score_image[vid_idx, f_idx] = psnr(img_gt, img_gen, data_range=data_range)
|
| 281 |
+
|
| 282 |
+
vid_gt = all_gt[vid_idx]
|
| 283 |
+
vid_gen = all_gen[vid_idx]
|
| 284 |
+
data_range = max(vid_gt.max(), vid_gen.max()) - min(vid_gt.min(), vid_gen.min())
|
| 285 |
+
ssim_score_vid[vid_idx] = ssim(vid_gt, vid_gen, channel_axis=1, data_range=data_range, gaussian_weights=True, sigma=1.5)
|
| 286 |
+
psnr_score_vid[vid_idx] = psnr(vid_gt, vid_gen, data_range=data_range)
|
| 287 |
+
|
| 288 |
+
ssim_score_image_error = np.sqrt(((ssim_score_image - ssim_score_image.mean())**2).sum()/200)
|
| 289 |
+
psnr_score_image_error = np.sqrt(((psnr_score_image - psnr_score_image.mean())**2).sum()/200)
|
| 290 |
+
|
| 291 |
+
# Print results
|
| 292 |
+
print(f'FVD Score: {fvd_score}')
|
| 293 |
+
print(f'LPIPS Score: {lpips_score}')
|
| 294 |
+
print(f'SSIM Score (per image): {ssim_score_image.mean()}')
|
| 295 |
+
print(f'SSIM Score Error: {ssim_score_image_error}')
|
| 296 |
+
print(f'PSNR Score (per image): {psnr_score_image.mean()}')
|
| 297 |
+
print(f'PSNR Score Error: {psnr_score_image_error}')
|
| 298 |
+
|
| 299 |
+
# Print copy-friendly format
|
| 300 |
+
print("\nCopy friendly format:")
|
| 301 |
+
print(f"{fvd_score}, {lpips_score}, {ssim_score_image.mean()}, {psnr_score_image.mean()}")
|
| 302 |
+
|
| 303 |
+
if __name__ == '__main__':
|
| 304 |
+
parser = argparse.ArgumentParser(description='Evaluate video quality metrics between generated and ground truth videos')
|
| 305 |
+
parser.add_argument('--vid_root', type=str, required=True,
|
| 306 |
+
help='Root directory containing generated and ground truth videos')
|
| 307 |
+
parser.add_argument('--samples', type=int, default=200,
|
| 308 |
+
help='Number of samples to evaluate (default: 200)')
|
| 309 |
+
parser.add_argument('--num_frames', type=int, default=25,
|
| 310 |
+
help='Number of frames per video (default: 25)')
|
| 311 |
+
parser.add_argument('--downsample', action='store_true',
|
| 312 |
+
help='Downsample videos during evaluation')
|
| 313 |
+
args = parser.parse_args()
|
| 314 |
+
|
| 315 |
+
evaluate_vids(args.vid_root, args.samples, args.downsample, args.num_frames)
|
| 316 |
+
|
| 317 |
+
def get_frames_from_path_list(path_list):
|
| 318 |
+
frames = []
|
| 319 |
+
for path in path_list:
|
| 320 |
+
img = cv2.imread(path)
|
| 321 |
+
img = cv2.resize(img, [512, 320])
|
| 322 |
+
frames.append(img)
|
| 323 |
+
return np.array(frames)
|
| 324 |
+
|
| 325 |
+
def get_frames_mp4(video_path: str, frame_interval: int = 1) -> None:
|
| 326 |
+
|
| 327 |
+
# Open the video file
|
| 328 |
+
cap = cv2.VideoCapture(video_path)
|
| 329 |
+
if not cap.isOpened():
|
| 330 |
+
raise ValueError(f"Could not open video file: {video_path}")
|
| 331 |
+
|
| 332 |
+
frame_count = 0
|
| 333 |
+
saved_count = 0
|
| 334 |
+
|
| 335 |
+
frames = []
|
| 336 |
+
while True:
|
| 337 |
+
ret, frame = cap.read()
|
| 338 |
+
if not ret:
|
| 339 |
+
break
|
| 340 |
+
|
| 341 |
+
# Save frame if it's the right interval
|
| 342 |
+
if frame_count % frame_interval == 0:
|
| 343 |
+
frames.append(frame)
|
| 344 |
+
saved_count += 1
|
| 345 |
+
|
| 346 |
+
frame_count += 1
|
| 347 |
+
|
| 348 |
+
cap.release()
|
| 349 |
+
return np.array(frames)
|
src/eval/video_quality_metrics_jedi_gt_rand.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import cv2
|
| 4 |
+
import json
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
import argparse
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
|
| 11 |
+
from videojedi import JEDiMetric
|
| 12 |
+
from .video_dataset import VideoDataset
|
| 13 |
+
from src.datasets.dataset_utils import get_dataloader
|
| 14 |
+
|
| 15 |
+
def custom_collate(batch):
|
| 16 |
+
"""Custom collate function for DataLoader to handle video clips."""
|
| 17 |
+
videos, targets = [], []
|
| 18 |
+
for sample in batch:
|
| 19 |
+
clips = sample["clips"]
|
| 20 |
+
videos.append(clips)
|
| 21 |
+
return torch.utils.data.dataloader.default_collate(videos), targets
|
| 22 |
+
|
| 23 |
+
def evaluate_vids(vid_root, samples=200, downsample_int=1, num_frames=25, gt_samples=500, test_feature_path=None, action_type=None, shuffle=False):
|
| 24 |
+
"""Evaluate JEDi metric between generated and ground truth videos."""
|
| 25 |
+
|
| 26 |
+
# Initialize JEDi metric
|
| 27 |
+
jedi = JEDiMetric(feature_path=vid_root,
|
| 28 |
+
test_feature_path=test_feature_path,
|
| 29 |
+
model_dir="/path/to/Models")
|
| 30 |
+
|
| 31 |
+
# Create dataset and dataloader for generated videos
|
| 32 |
+
gen_dataset = VideoDataset(vid_root, num_frames=num_frames, downsample_int=downsample_int)
|
| 33 |
+
gen_loader = DataLoader(gen_dataset, batch_size=1, shuffle=shuffle, num_workers=4)
|
| 34 |
+
|
| 35 |
+
# Set up category filtering if specified
|
| 36 |
+
specific_categories = None
|
| 37 |
+
force_clip_type = None
|
| 38 |
+
if action_type is not None:
|
| 39 |
+
if action_type == 0:
|
| 40 |
+
force_clip_type = "normal"
|
| 41 |
+
print("Collecting normal samples only")
|
| 42 |
+
else:
|
| 43 |
+
classes_by_action_type = {
|
| 44 |
+
1: [61, 62, 13, 14, 15, 16, 17, 18],
|
| 45 |
+
2: list(range(1, 12 + 1)),
|
| 46 |
+
3: [37, 39, 41, 42, 44] + list(range(19, 36 + 1)) + list(range(52, 60 + 1)),
|
| 47 |
+
4: [38, 40, 43, 45, 46, 47, 48, 49, 50, 51]
|
| 48 |
+
}
|
| 49 |
+
specific_categories = classes_by_action_type[action_type]
|
| 50 |
+
force_clip_type = "crash"
|
| 51 |
+
print("Collecting crash samples from categories:", specific_categories)
|
| 52 |
+
|
| 53 |
+
# Create dataset and dataloader for ground truth videos
|
| 54 |
+
dataset_name = "mmau"
|
| 55 |
+
train_set = True
|
| 56 |
+
val_dataset, _ = get_dataloader("path/to/Datasets", dataset_name,
|
| 57 |
+
if_train=train_set, clip_length=num_frames,
|
| 58 |
+
batch_size=1, num_workers=0, shuffle=True,
|
| 59 |
+
image_height=320, image_width=512,
|
| 60 |
+
non_overlapping_clips=True,
|
| 61 |
+
specific_categories=specific_categories,
|
| 62 |
+
force_clip_type=force_clip_type)
|
| 63 |
+
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True, collate_fn=custom_collate)
|
| 64 |
+
|
| 65 |
+
# Compute JEDi metric
|
| 66 |
+
jedi.load_features(train_loader=gen_loader, test_loader=val_loader,
|
| 67 |
+
num_samples=samples, num_test_samples=gt_samples)
|
| 68 |
+
jedi_metric = jedi.compute_metric()
|
| 69 |
+
print(f"JEDi Metric: {jedi_metric}")
|
| 70 |
+
|
| 71 |
+
if __name__ == '__main__':
|
| 72 |
+
parser = argparse.ArgumentParser(description='Evaluate JEDi metric between generated and ground truth videos')
|
| 73 |
+
parser.add_argument('--vid_root', type=str, required=True,
|
| 74 |
+
help='Root directory containing generated videos')
|
| 75 |
+
parser.add_argument('--samples', type=int, default=200,
|
| 76 |
+
help='Number of samples to evaluate (default: 200)')
|
| 77 |
+
parser.add_argument('--gt_samples', type=int, default=500,
|
| 78 |
+
help='Number of ground truth samples to use (default: 500)')
|
| 79 |
+
parser.add_argument('--num_frames', type=int, default=25,
|
| 80 |
+
help='Number of frames per video (default: 25)')
|
| 81 |
+
parser.add_argument('--downsample_int', type=int, default=1,
|
| 82 |
+
help='Downsample interval for frames (default: 1)')
|
| 83 |
+
parser.add_argument('--test_feature_path', type=str, default=None,
|
| 84 |
+
help='Path to test features (optional)')
|
| 85 |
+
parser.add_argument('--action_type', type=int, default=None,
|
| 86 |
+
help='Action type to filter videos (0: normal, 1-4: crash types)')
|
| 87 |
+
parser.add_argument('--shuffle', action='store_true',
|
| 88 |
+
help='Shuffle videos before evaluation')
|
| 89 |
+
args = parser.parse_args()
|
| 90 |
+
|
| 91 |
+
evaluate_vids(args.vid_root, args.samples, args.downsample_int, args.num_frames, args.gt_samples, args.test_feature_path, args.action_type, args.shuffle)
|
src/eval/video_quality_metrics_jedi_pair.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
|
| 6 |
+
from videojedi import JEDiMetric
|
| 7 |
+
from .video_dataset import VideoDataset
|
| 8 |
+
from src.datasets.dataset_utils import get_dataloader
|
| 9 |
+
|
| 10 |
+
def custom_collate(batch):
|
| 11 |
+
videos, targets = [], []
|
| 12 |
+
for sample in batch:
|
| 13 |
+
clips = sample["clips"]
|
| 14 |
+
videos.append(clips)
|
| 15 |
+
return torch.utils.data.dataloader.default_collate(videos), targets
|
| 16 |
+
|
| 17 |
+
def evaluate_vids(vid_root, samples=200, downsample_int=1, num_frames=25, gt_samples=500, test_feature_path=None, action_type=None):
|
| 18 |
+
"""Evaluate JEDi metric between generated and ground truth videos."""
|
| 19 |
+
|
| 20 |
+
# Initialize JEDi metric
|
| 21 |
+
jedi = JEDiMetric(feature_path=vid_root,
|
| 22 |
+
test_feature_path=test_feature_path,
|
| 23 |
+
model_dir="/path/to/Models")
|
| 24 |
+
|
| 25 |
+
# Create dataset and dataloader for generated videos
|
| 26 |
+
gen_dataset = VideoDataset(vid_root, num_frames=num_frames, downsample_int=downsample_int)
|
| 27 |
+
gen_loader = DataLoader(gen_dataset, batch_size=1, shuffle=False, num_workers=4)
|
| 28 |
+
|
| 29 |
+
# Set up category filtering if specified
|
| 30 |
+
specific_categories = None
|
| 31 |
+
force_clip_type = None
|
| 32 |
+
if action_type is not None:
|
| 33 |
+
if action_type == 0:
|
| 34 |
+
force_clip_type = "normal"
|
| 35 |
+
print("Collecting normal samples only")
|
| 36 |
+
else:
|
| 37 |
+
classes_by_action_type = {
|
| 38 |
+
1: [61, 62, 13, 14, 15, 16, 17, 18],
|
| 39 |
+
2: list(range(1, 12 + 1)),
|
| 40 |
+
3: [37, 39, 41, 42, 44] + list(range(19, 36 + 1)) + list(range(52, 60 + 1)),
|
| 41 |
+
4: [38, 40, 43, 45, 46, 47, 48, 49, 50, 51]
|
| 42 |
+
}
|
| 43 |
+
specific_categories = classes_by_action_type[action_type]
|
| 44 |
+
force_clip_type = "crash"
|
| 45 |
+
print("Collecting crash samples from categories:", specific_categories)
|
| 46 |
+
|
| 47 |
+
# Get specific samples to evaluate
|
| 48 |
+
specific_samples = []
|
| 49 |
+
gen_videos = os.path.join(vid_root, "gen_videos") if os.path.exists(f"{vid_root}/gen_videos") else vid_root
|
| 50 |
+
for fname in os.listdir(gen_videos):
|
| 51 |
+
vid_name = fname.strip("genvid_").split(".")[0]
|
| 52 |
+
gt_vid_name = "_".join(vid_name.split("_")[1:])
|
| 53 |
+
specific_samples.append(gt_vid_name)
|
| 54 |
+
|
| 55 |
+
# Create dataset and dataloader for ground truth videos
|
| 56 |
+
dataset_name = "mmau"
|
| 57 |
+
train_set = False
|
| 58 |
+
val_dataset, _ = get_dataloader("/path/to/Datasets", dataset_name,
|
| 59 |
+
if_train=train_set, clip_length=num_frames,
|
| 60 |
+
batch_size=1, num_workers=0, shuffle=True,
|
| 61 |
+
image_height=320, image_width=512,
|
| 62 |
+
non_overlapping_clips=True,
|
| 63 |
+
specific_categories=specific_categories,
|
| 64 |
+
force_clip_type=force_clip_type,
|
| 65 |
+
specific_samples=specific_samples)
|
| 66 |
+
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True, collate_fn=custom_collate)
|
| 67 |
+
|
| 68 |
+
# Compute JEDi metric
|
| 69 |
+
jedi.load_features(train_loader=gen_loader, test_loader=val_loader,
|
| 70 |
+
num_samples=samples, num_test_samples=gt_samples)
|
| 71 |
+
jedi_metric = jedi.compute_metric()
|
| 72 |
+
print(f"JEDi Metric: {jedi_metric}")
|
| 73 |
+
|
| 74 |
+
if __name__ == '__main__':
|
| 75 |
+
parser = argparse.ArgumentParser(description='Evaluate JEDi metric between generated and ground truth videos')
|
| 76 |
+
parser.add_argument('--vid_root', type=str, required=True,
|
| 77 |
+
help='Root directory containing generated videos')
|
| 78 |
+
parser.add_argument('--samples', type=int, default=200,
|
| 79 |
+
help='Number of samples to evaluate (default: 200)')
|
| 80 |
+
parser.add_argument('--gt_samples', type=int, default=500,
|
| 81 |
+
help='Number of ground truth samples to use (default: 500)')
|
| 82 |
+
parser.add_argument('--num_frames', type=int, default=25,
|
| 83 |
+
help='Number of frames per video (default: 25)')
|
| 84 |
+
parser.add_argument('--downsample_int', type=int, default=1,
|
| 85 |
+
help='Downsample interval for frames (default: 1)')
|
| 86 |
+
parser.add_argument('--test_feature_path', type=str, default=None,
|
| 87 |
+
help='Path to test features (optional)')
|
| 88 |
+
parser.add_argument('--action_type', type=int, default=None,
|
| 89 |
+
help='Action type to filter videos (0: normal, 1-4: crash types)')
|
| 90 |
+
args = parser.parse_args()
|
| 91 |
+
|
| 92 |
+
evaluate_vids(args.vid_root, args.samples, args.downsample_int, args.num_frames, args.gt_samples, args.test_feature_path, args.action_type)
|
src/models/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.models.controlnet import ControlNetModel
|
| 2 |
+
from src.models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
src/models/controlnet.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 8 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 9 |
+
from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block
|
| 10 |
+
from diffusers.loaders import FromOriginalControlNetMixin
|
| 11 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
| 12 |
+
from diffusers.utils import logging
|
| 13 |
+
from diffusers.models import ControlNetModel as ControlNetModel_original
|
| 14 |
+
from diffusers.models.controlnet import ControlNetOutput, zero_module
|
| 15 |
+
|
| 16 |
+
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
| 17 |
+
|
| 18 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 19 |
+
|
| 20 |
+
class ControlNetModel(ControlNetModel_original): # (ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
| 21 |
+
|
| 22 |
+
r"""
|
| 23 |
+
A controlnet for conditional Spatio-Temporal UNet model.
|
| 24 |
+
|
| 25 |
+
Parameters:
|
| 26 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
| 27 |
+
Height and width of input/output sample.
|
| 28 |
+
in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
|
| 29 |
+
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
| 30 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
|
| 31 |
+
The tuple of downsample blocks to use.
|
| 32 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
| 33 |
+
The tuple of output channels for each block.
|
| 34 |
+
addition_time_embed_dim: (`int`, defaults to 256):
|
| 35 |
+
Dimension to to encode the additional time ids.
|
| 36 |
+
projection_class_embeddings_input_dim (`int`, defaults to 768):
|
| 37 |
+
The dimension of the projection of encoded `added_time_ids`.
|
| 38 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
| 39 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
| 40 |
+
The dimension of the cross attention features.
|
| 41 |
+
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
|
| 42 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
| 43 |
+
[`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
|
| 44 |
+
[`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
|
| 45 |
+
num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
|
| 46 |
+
The number of attention heads.
|
| 47 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 48 |
+
action_dim: (`int`, defaults to 256):
|
| 49 |
+
Dimension of the action features.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
_supports_gradient_checkpointing = True
|
| 53 |
+
|
| 54 |
+
@register_to_config
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
sample_size: Optional[int] = None,
|
| 58 |
+
in_channels: int = 8,
|
| 59 |
+
down_block_types: Tuple[str] = (
|
| 60 |
+
"CrossAttnDownBlockSpatioTemporal",
|
| 61 |
+
"CrossAttnDownBlockSpatioTemporal",
|
| 62 |
+
"CrossAttnDownBlockSpatioTemporal",
|
| 63 |
+
"DownBlockSpatioTemporal",
|
| 64 |
+
),
|
| 65 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
| 66 |
+
addition_time_embed_dim: int = 256,
|
| 67 |
+
projection_class_embeddings_input_dim: int = 768,
|
| 68 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
| 69 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
| 70 |
+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
| 71 |
+
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
|
| 72 |
+
num_frames: int = 25,
|
| 73 |
+
action_dim: int = 5, # Dimension of the action features
|
| 74 |
+
bbox_embedding_shape: Tuple[int] = (4, 128, 128),
|
| 75 |
+
):
|
| 76 |
+
# calling the super class constructors without calling ControlNetModel_original's
|
| 77 |
+
ModelMixin.__init__(self)
|
| 78 |
+
ConfigMixin.__init__(self)
|
| 79 |
+
FromOriginalControlNetMixin.__init__(self)
|
| 80 |
+
|
| 81 |
+
self.sample_size = sample_size
|
| 82 |
+
|
| 83 |
+
# Check inputs
|
| 84 |
+
if len(block_out_channels) != len(down_block_types):
|
| 85 |
+
raise ValueError(
|
| 86 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
| 90 |
+
raise ValueError(
|
| 91 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
| 95 |
+
raise ValueError(
|
| 96 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
| 100 |
+
raise ValueError(
|
| 101 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# input
|
| 105 |
+
self.conv_in = nn.Conv2d(
|
| 106 |
+
in_channels,
|
| 107 |
+
block_out_channels[0],
|
| 108 |
+
kernel_size=3,
|
| 109 |
+
padding=1,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# time
|
| 113 |
+
time_embed_dim = block_out_channels[0] * 4
|
| 114 |
+
|
| 115 |
+
self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
|
| 116 |
+
timestep_input_dim = block_out_channels[0]
|
| 117 |
+
|
| 118 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
| 119 |
+
|
| 120 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
|
| 121 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
| 122 |
+
|
| 123 |
+
# Action projection layer
|
| 124 |
+
hidden_action_dim = 256
|
| 125 |
+
self.action_embedding = nn.Embedding(action_dim, hidden_action_dim)
|
| 126 |
+
self.action_proj = nn.Linear(hidden_action_dim, cross_attention_dim)
|
| 127 |
+
|
| 128 |
+
# Learnable null embedding for bbox masking
|
| 129 |
+
self.bbox_null_embedding = nn.Parameter(torch.randn(bbox_embedding_shape))
|
| 130 |
+
|
| 131 |
+
self.down_blocks = nn.ModuleList([])
|
| 132 |
+
self.controlnet_down_blocks = nn.ModuleList([])
|
| 133 |
+
|
| 134 |
+
if isinstance(num_attention_heads, int):
|
| 135 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
| 136 |
+
|
| 137 |
+
if isinstance(cross_attention_dim, int):
|
| 138 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
| 139 |
+
|
| 140 |
+
if isinstance(layers_per_block, int):
|
| 141 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
| 142 |
+
|
| 143 |
+
if isinstance(transformer_layers_per_block, int):
|
| 144 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
| 145 |
+
|
| 146 |
+
blocks_time_embed_dim = time_embed_dim
|
| 147 |
+
|
| 148 |
+
self.control_conv_in = nn.Conv2d(
|
| 149 |
+
in_channels//2,
|
| 150 |
+
block_out_channels[0],
|
| 151 |
+
kernel_size=3,
|
| 152 |
+
padding=1,
|
| 153 |
+
)
|
| 154 |
+
# # Initialize the re-zero parameter
|
| 155 |
+
# self.rz_weight = nn.Parameter(torch.Tensor([0]))
|
| 156 |
+
|
| 157 |
+
# down
|
| 158 |
+
output_channel = block_out_channels[0]
|
| 159 |
+
|
| 160 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
| 161 |
+
controlnet_block = zero_module(controlnet_block)
|
| 162 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
| 163 |
+
|
| 164 |
+
for i, down_block_type in enumerate(down_block_types):
|
| 165 |
+
input_channel = output_channel
|
| 166 |
+
output_channel = block_out_channels[i]
|
| 167 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 168 |
+
|
| 169 |
+
down_block = get_down_block(
|
| 170 |
+
down_block_type,
|
| 171 |
+
num_layers=layers_per_block[i],
|
| 172 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
| 173 |
+
in_channels=input_channel,
|
| 174 |
+
out_channels=output_channel,
|
| 175 |
+
temb_channels=blocks_time_embed_dim,
|
| 176 |
+
add_downsample=not is_final_block,
|
| 177 |
+
resnet_eps=1e-5,
|
| 178 |
+
cross_attention_dim=cross_attention_dim[i],
|
| 179 |
+
num_attention_heads=num_attention_heads[i],
|
| 180 |
+
resnet_act_fn="silu",
|
| 181 |
+
)
|
| 182 |
+
self.down_blocks.append(down_block)
|
| 183 |
+
|
| 184 |
+
for _ in range(layers_per_block[i]):
|
| 185 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
| 186 |
+
controlnet_block = zero_module(controlnet_block)
|
| 187 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
| 188 |
+
|
| 189 |
+
if not is_final_block:
|
| 190 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
| 191 |
+
controlnet_block = zero_module(controlnet_block)
|
| 192 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
| 193 |
+
|
| 194 |
+
# mid
|
| 195 |
+
controlnet_block = nn.Conv2d(block_out_channels[-1], block_out_channels[-1], kernel_size=1)
|
| 196 |
+
controlnet_block = zero_module(controlnet_block)
|
| 197 |
+
self.controlnet_mid_block = controlnet_block
|
| 198 |
+
self.mid_block = UNetMidBlockSpatioTemporal(
|
| 199 |
+
block_out_channels[-1],
|
| 200 |
+
temb_channels=blocks_time_embed_dim,
|
| 201 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
| 202 |
+
cross_attention_dim=cross_attention_dim[-1],
|
| 203 |
+
num_attention_heads=num_attention_heads[-1],
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# count how many layers upsample the images
|
| 207 |
+
self.num_upsamplers = 0
|
| 208 |
+
|
| 209 |
+
@classmethod
|
| 210 |
+
def from_unet(cls,
|
| 211 |
+
unet: UNetSpatioTemporalConditionModel,
|
| 212 |
+
load_weights_from_unet: bool = True,
|
| 213 |
+
action_dim: int = 5,
|
| 214 |
+
bbox_embedding_shape: Tuple[int] = (4, 128, 128)):
|
| 215 |
+
|
| 216 |
+
ctrlnet = cls(
|
| 217 |
+
in_channels=unet.config.in_channels,
|
| 218 |
+
down_block_types=unet.config.down_block_types,
|
| 219 |
+
block_out_channels=unet.config.block_out_channels,
|
| 220 |
+
addition_time_embed_dim=unet.config.addition_time_embed_dim,
|
| 221 |
+
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
| 222 |
+
layers_per_block=unet.config.layers_per_block,
|
| 223 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
| 224 |
+
transformer_layers_per_block=unet.config.transformer_layers_per_block,
|
| 225 |
+
num_attention_heads=unet.config.num_attention_heads,
|
| 226 |
+
num_frames=unet.config.num_frames,
|
| 227 |
+
action_dim=action_dim,
|
| 228 |
+
bbox_embedding_shape=bbox_embedding_shape,
|
| 229 |
+
)
|
| 230 |
+
unet_keys = set(unet.state_dict().keys())
|
| 231 |
+
ctrl_keys = set(ctrlnet.state_dict().keys())
|
| 232 |
+
intersection_keys = ctrl_keys.intersection(unet_keys)
|
| 233 |
+
for key in ctrl_keys:
|
| 234 |
+
if key in intersection_keys:
|
| 235 |
+
if load_weights_from_unet:
|
| 236 |
+
ctrlnet.state_dict()[key].copy_(unet.state_dict()[key])
|
| 237 |
+
# else:
|
| 238 |
+
# logger.warning(f"Key {key} not found in UNet model, initializing it randomly.")
|
| 239 |
+
|
| 240 |
+
return ctrlnet
|
| 241 |
+
|
| 242 |
+
def forward(
|
| 243 |
+
self,
|
| 244 |
+
sample: torch.FloatTensor,
|
| 245 |
+
timestep: Union[torch.Tensor, float, int],
|
| 246 |
+
encoder_hidden_states: torch.Tensor,
|
| 247 |
+
added_time_ids: torch.Tensor,
|
| 248 |
+
control_cond: torch.FloatTensor = None,
|
| 249 |
+
action_type: torch.LongTensor = None,
|
| 250 |
+
conditioning_scale: float = 1.0,
|
| 251 |
+
return_dict: bool = True,
|
| 252 |
+
) -> Union[ControlNetOutput, Tuple]:
|
| 253 |
+
r"""
|
| 254 |
+
This approach effectively integrates the forward method of the UNetSpatioTemporalConditionModel with the forward
|
| 255 |
+
method of the ControlNetModel
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
sample (`torch.FloatTensor`):
|
| 259 |
+
The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
|
| 260 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
| 261 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
| 262 |
+
The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
|
| 263 |
+
added_time_ids: (`torch.FloatTensor`):
|
| 264 |
+
The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
|
| 265 |
+
embeddings and added to the time embeddings.
|
| 266 |
+
controlnet_cond (`torch.FloatTensor`):
|
| 267 |
+
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
| 268 |
+
action_type (`torch.LongTensor`):
|
| 269 |
+
The action type with shape `(batch_size)`.
|
| 270 |
+
conditioning_scale (`float`, defaults to `1.0`):
|
| 271 |
+
The scale factor for ControlNet outputs.
|
| 272 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 273 |
+
Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead
|
| 274 |
+
of a plain tuple.
|
| 275 |
+
Returns:
|
| 276 |
+
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
| 277 |
+
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
| 278 |
+
returned where the first element is the sample tensor.
|
| 279 |
+
"""
|
| 280 |
+
# 1. time
|
| 281 |
+
timesteps = timestep
|
| 282 |
+
if len(timesteps.shape) == 0:
|
| 283 |
+
timesteps = timesteps[None].to(sample.device)
|
| 284 |
+
|
| 285 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 286 |
+
batch_size, num_frames = sample.shape[:2]
|
| 287 |
+
timesteps = timesteps.expand(batch_size)
|
| 288 |
+
|
| 289 |
+
t_emb = self.time_proj(timesteps)
|
| 290 |
+
|
| 291 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
| 292 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 293 |
+
# there might be better ways to encapsulate this.
|
| 294 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
| 295 |
+
|
| 296 |
+
emb = self.time_embedding(t_emb)
|
| 297 |
+
|
| 298 |
+
time_embeds = self.add_time_proj(added_time_ids.flatten())
|
| 299 |
+
time_embeds = time_embeds.reshape((batch_size, -1))
|
| 300 |
+
time_embeds = time_embeds.to(emb.dtype)
|
| 301 |
+
aug_emb = self.add_embedding(time_embeds)
|
| 302 |
+
emb = emb + aug_emb
|
| 303 |
+
|
| 304 |
+
# Flatten the batch and frames dimensions
|
| 305 |
+
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
|
| 306 |
+
sample = sample.flatten(0, 1)
|
| 307 |
+
# control_cond: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
|
| 308 |
+
control_cond = control_cond.flatten(0, 1)
|
| 309 |
+
# Repeat the embeddings num_video_frames times
|
| 310 |
+
# emb: [batch, channels] -> [batch * frames, channels]
|
| 311 |
+
emb = emb.repeat_interleave(num_frames, dim=0)
|
| 312 |
+
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
|
| 313 |
+
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
|
| 314 |
+
|
| 315 |
+
# Process action features if provided
|
| 316 |
+
if action_type is not None:
|
| 317 |
+
|
| 318 |
+
# Embed action features
|
| 319 |
+
action_type = action_type.to(encoder_hidden_states.device, dtype=torch.long)
|
| 320 |
+
|
| 321 |
+
# Flatten action features to match the batch*frames dimension
|
| 322 |
+
# action_features: [batch, action_dim] -> [batch * frames, action_dim]
|
| 323 |
+
if action_type.dim() == 1:
|
| 324 |
+
action_type = action_type.unsqueeze(0)
|
| 325 |
+
action_type = action_type.repeat_interleave(num_frames, dim=0)
|
| 326 |
+
|
| 327 |
+
# Project action features to match the embedding dimension
|
| 328 |
+
action_features = self.action_embedding(action_type)
|
| 329 |
+
action_emb = self.action_proj(action_features)
|
| 330 |
+
|
| 331 |
+
# Add action embeddings to the encoder_hidden_states
|
| 332 |
+
# Make sure not to add action embeddings to masked hidden states
|
| 333 |
+
is_masked_cond = (encoder_hidden_states == 0).all(dim=2).unsqueeze(-1)
|
| 334 |
+
encoder_hidden_states = torch.where(is_masked_cond, encoder_hidden_states, encoder_hidden_states + action_emb)
|
| 335 |
+
|
| 336 |
+
# 2. pre-process
|
| 337 |
+
sample = self.conv_in(sample)
|
| 338 |
+
control_cond = self.control_conv_in(control_cond)
|
| 339 |
+
sample = sample + control_cond# * self.rz_weight
|
| 340 |
+
|
| 341 |
+
image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
|
| 342 |
+
|
| 343 |
+
down_block_res_samples = (sample,)
|
| 344 |
+
for downsample_block in self.down_blocks:
|
| 345 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
| 346 |
+
sample, res_samples = downsample_block(
|
| 347 |
+
hidden_states=sample,
|
| 348 |
+
temb=emb,
|
| 349 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 350 |
+
image_only_indicator=image_only_indicator,
|
| 351 |
+
)
|
| 352 |
+
else:
|
| 353 |
+
sample, res_samples = downsample_block(
|
| 354 |
+
hidden_states=sample,
|
| 355 |
+
temb=emb,
|
| 356 |
+
image_only_indicator=image_only_indicator,
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
down_block_res_samples += res_samples
|
| 360 |
+
|
| 361 |
+
# 4. mid
|
| 362 |
+
sample = self.mid_block(
|
| 363 |
+
hidden_states=sample,
|
| 364 |
+
temb=emb,
|
| 365 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 366 |
+
image_only_indicator=image_only_indicator,
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
# 5. Control net blocks
|
| 370 |
+
|
| 371 |
+
controlnet_down_block_res_samples = ()
|
| 372 |
+
|
| 373 |
+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
| 374 |
+
down_block_res_sample = controlnet_block(down_block_res_sample)
|
| 375 |
+
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
| 376 |
+
|
| 377 |
+
down_block_res_samples = controlnet_down_block_res_samples
|
| 378 |
+
|
| 379 |
+
mid_block_res_sample = self.controlnet_mid_block(sample)
|
| 380 |
+
|
| 381 |
+
# 6. scaling
|
| 382 |
+
|
| 383 |
+
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
| 384 |
+
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
| 385 |
+
|
| 386 |
+
if not return_dict:
|
| 387 |
+
return (down_block_res_samples, mid_block_res_sample)
|
| 388 |
+
|
| 389 |
+
return ControlNetOutput(
|
| 390 |
+
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
| 391 |
+
)
|
src/models/unet_spatio_temporal_condition.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from diffusers.loaders import PeftAdapterMixin
|
| 3 |
+
from diffusers import ModelMixin
|
| 4 |
+
|
| 5 |
+
from diffusers import UNetSpatioTemporalConditionModel as UNetSpatioTemporalConditionModel_orig
|
| 6 |
+
import torch
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 9 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 10 |
+
from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
|
| 11 |
+
|
| 12 |
+
# NOTE: Only added ModelMixin to make it compatible with some older version of diffusers when using from_pretrained
|
| 13 |
+
class UNetSpatioTemporalConditionModel(UNetSpatioTemporalConditionModel_orig, PeftAdapterMixin, ModelMixin):
|
| 14 |
+
|
| 15 |
+
def enable_grad(self, temporal_transformer_block=True, all=False):
|
| 16 |
+
parameters_list = []
|
| 17 |
+
for name, param in self.named_parameters():
|
| 18 |
+
if bool('temporal_transformer_block' in name and temporal_transformer_block) or all:
|
| 19 |
+
parameters_list.append(param)
|
| 20 |
+
param.requires_grad = True
|
| 21 |
+
else:
|
| 22 |
+
param.requires_grad = False
|
| 23 |
+
return parameters_list
|
| 24 |
+
|
| 25 |
+
def get_parameters_with_grad(self):
|
| 26 |
+
return [param for param in self.parameters() if param.requires_grad]
|
| 27 |
+
|
| 28 |
+
def forward(
|
| 29 |
+
self,
|
| 30 |
+
sample: torch.FloatTensor,
|
| 31 |
+
timestep: Union[torch.Tensor, float, int],
|
| 32 |
+
encoder_hidden_states: torch.Tensor,
|
| 33 |
+
added_time_ids: torch.Tensor,
|
| 34 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
| 35 |
+
mid_block_additional_residuals: Optional[torch.Tensor] = None,
|
| 36 |
+
return_dict: bool = True,
|
| 37 |
+
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
|
| 38 |
+
r"""
|
| 39 |
+
The [`UNetSpatioTemporalConditionModel`] forward method.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
sample (`torch.FloatTensor`):
|
| 43 |
+
The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
|
| 44 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
| 45 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
| 46 |
+
The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
|
| 47 |
+
added_time_ids: (`torch.FloatTensor`):
|
| 48 |
+
The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
|
| 49 |
+
embeddings and added to the time embeddings.
|
| 50 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 51 |
+
Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead
|
| 52 |
+
of a plain tuple.
|
| 53 |
+
Returns:
|
| 54 |
+
[`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
|
| 55 |
+
If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
|
| 56 |
+
returned, otherwise a `tuple` is returned where the first element is the sample tensor.
|
| 57 |
+
"""
|
| 58 |
+
is_controlnet = mid_block_additional_residuals is not None and down_block_additional_residuals is not None
|
| 59 |
+
|
| 60 |
+
# 1. time
|
| 61 |
+
timesteps = timestep
|
| 62 |
+
if len(timesteps.shape) == 0:
|
| 63 |
+
timesteps = timesteps[None].to(sample.device)
|
| 64 |
+
|
| 65 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 66 |
+
batch_size, num_frames = sample.shape[:2]
|
| 67 |
+
timesteps = timesteps.expand(batch_size)
|
| 68 |
+
|
| 69 |
+
t_emb = self.time_proj(timesteps)
|
| 70 |
+
|
| 71 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
| 72 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 73 |
+
# there might be better ways to encapsulate this.
|
| 74 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
| 75 |
+
|
| 76 |
+
emb = self.time_embedding(t_emb)
|
| 77 |
+
|
| 78 |
+
time_embeds = self.add_time_proj(added_time_ids.flatten())
|
| 79 |
+
time_embeds = time_embeds.reshape((batch_size, -1))
|
| 80 |
+
time_embeds = time_embeds.to(emb.dtype)
|
| 81 |
+
aug_emb = self.add_embedding(time_embeds)
|
| 82 |
+
emb = emb + aug_emb
|
| 83 |
+
|
| 84 |
+
# Flatten the batch and frames dimensions
|
| 85 |
+
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
|
| 86 |
+
sample = sample.flatten(0, 1)
|
| 87 |
+
# Repeat the embeddings num_video_frames times
|
| 88 |
+
# emb: [batch, channels] -> [batch * frames, channels]
|
| 89 |
+
emb = emb.repeat_interleave(num_frames, dim=0)
|
| 90 |
+
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
|
| 91 |
+
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
|
| 92 |
+
|
| 93 |
+
# 2. pre-process
|
| 94 |
+
sample = self.conv_in(sample)
|
| 95 |
+
|
| 96 |
+
image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
|
| 97 |
+
|
| 98 |
+
down_block_res_samples = (sample,)
|
| 99 |
+
for downsample_block in self.down_blocks:
|
| 100 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
| 101 |
+
sample, res_samples = downsample_block(
|
| 102 |
+
hidden_states=sample,
|
| 103 |
+
temb=emb,
|
| 104 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 105 |
+
image_only_indicator=image_only_indicator,
|
| 106 |
+
)
|
| 107 |
+
else:
|
| 108 |
+
sample, res_samples = downsample_block(
|
| 109 |
+
hidden_states=sample,
|
| 110 |
+
temb=emb,
|
| 111 |
+
image_only_indicator=image_only_indicator,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
down_block_res_samples += res_samples
|
| 115 |
+
|
| 116 |
+
if is_controlnet:
|
| 117 |
+
new_down_block_res_samples = ()
|
| 118 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
| 119 |
+
down_block_res_samples, down_block_additional_residuals
|
| 120 |
+
):
|
| 121 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
| 122 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
| 123 |
+
|
| 124 |
+
down_block_res_samples = new_down_block_res_samples
|
| 125 |
+
|
| 126 |
+
# 4. mid
|
| 127 |
+
sample = self.mid_block(
|
| 128 |
+
hidden_states=sample,
|
| 129 |
+
temb=emb,
|
| 130 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 131 |
+
image_only_indicator=image_only_indicator,
|
| 132 |
+
)
|
| 133 |
+
if is_controlnet:
|
| 134 |
+
sample = sample + mid_block_additional_residuals
|
| 135 |
+
|
| 136 |
+
# 5. up
|
| 137 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
| 138 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
| 139 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
| 140 |
+
|
| 141 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
| 142 |
+
sample = upsample_block(
|
| 143 |
+
hidden_states=sample,
|
| 144 |
+
temb=emb,
|
| 145 |
+
res_hidden_states_tuple=res_samples,
|
| 146 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 147 |
+
image_only_indicator=image_only_indicator,
|
| 148 |
+
)
|
| 149 |
+
else:
|
| 150 |
+
sample = upsample_block(
|
| 151 |
+
hidden_states=sample,
|
| 152 |
+
temb=emb,
|
| 153 |
+
res_hidden_states_tuple=res_samples,
|
| 154 |
+
image_only_indicator=image_only_indicator,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
# 6. post-process
|
| 158 |
+
sample = self.conv_norm_out(sample)
|
| 159 |
+
sample = self.conv_act(sample)
|
| 160 |
+
sample = self.conv_out(sample)
|
| 161 |
+
|
| 162 |
+
# 7. Reshape back to original shape
|
| 163 |
+
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
|
| 164 |
+
|
| 165 |
+
if not return_dict:
|
| 166 |
+
return (sample,)
|
| 167 |
+
|
| 168 |
+
return UNetSpatioTemporalConditionOutput(sample=sample)
|
| 169 |
+
|
src/pipelines/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .pipeline_video_control import StableVideoControlPipeline
|
| 2 |
+
from .pipeline_video_diffusion import VideoDiffusionPipeline
|
| 3 |
+
from .pipeline_video_control_nullmodel import StableVideoControlNullModelPipeline
|
| 4 |
+
from .pipeline_video_control_factor_guidance import StableVideoControlFactorGuidancePipeline
|
src/pipelines/pipeline_video_control.py
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Callable, Dict, List, Optional, Union
|
| 3 |
+
import PIL.Image
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
|
| 8 |
+
tensor2vid,
|
| 9 |
+
StableVideoDiffusionPipelineOutput,
|
| 10 |
+
_append_dims,
|
| 11 |
+
EXAMPLE_DOC_STRING
|
| 12 |
+
)
|
| 13 |
+
from diffusers import StableVideoDiffusionPipeline as StableVideoDiffusionPipeline_original
|
| 14 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 15 |
+
from diffusers.utils import logging, replace_example_docstring
|
| 16 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 17 |
+
|
| 18 |
+
from src.models import UNetSpatioTemporalConditionModel, ControlNetModel
|
| 19 |
+
from diffusers.models import AutoencoderKLTemporalDecoder
|
| 20 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
| 21 |
+
from diffusers import EulerDiscreteScheduler
|
| 22 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 25 |
+
|
| 26 |
+
class StableVideoControlPipeline(StableVideoDiffusionPipeline_original):
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
vae: AutoencoderKLTemporalDecoder,
|
| 31 |
+
image_encoder: CLIPVisionModelWithProjection,
|
| 32 |
+
unet: UNetSpatioTemporalConditionModel,
|
| 33 |
+
controlnet: ControlNetModel,
|
| 34 |
+
scheduler: EulerDiscreteScheduler,
|
| 35 |
+
feature_extractor: CLIPImageProcessor,
|
| 36 |
+
null_model: UNetSpatioTemporalConditionModel = None
|
| 37 |
+
):
|
| 38 |
+
# calling the super class constructors without calling StableVideoDiffusionPipeline_original's
|
| 39 |
+
DiffusionPipeline.__init__(self)
|
| 40 |
+
|
| 41 |
+
self.register_modules(
|
| 42 |
+
vae=vae,
|
| 43 |
+
image_encoder=image_encoder,
|
| 44 |
+
controlnet=controlnet,
|
| 45 |
+
unet=unet,
|
| 46 |
+
scheduler=scheduler,
|
| 47 |
+
feature_extractor=feature_extractor,
|
| 48 |
+
null_model=null_model,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 52 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 53 |
+
|
| 54 |
+
def check_inputs(self, image, cond_images, height, width):
|
| 55 |
+
if (
|
| 56 |
+
not isinstance(image, torch.Tensor)
|
| 57 |
+
and not isinstance(image, PIL.Image.Image)
|
| 58 |
+
and not isinstance(image, list)
|
| 59 |
+
):
|
| 60 |
+
raise ValueError(
|
| 61 |
+
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
| 62 |
+
f" {type(image)}"
|
| 63 |
+
)
|
| 64 |
+
if not isinstance(cond_images, torch.Tensor):
|
| 65 |
+
raise ValueError(
|
| 66 |
+
"`cond_images` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
| 67 |
+
f" {type(cond_images)}"
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 71 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _encode_vae_condition(
|
| 75 |
+
self,
|
| 76 |
+
cond_image: torch.tensor,
|
| 77 |
+
device: Union[str, torch.device],
|
| 78 |
+
num_videos_per_prompt: int,
|
| 79 |
+
do_classifier_free_guidance: bool,
|
| 80 |
+
bbox_mask_frames: List[bool] = None
|
| 81 |
+
):
|
| 82 |
+
video_length = cond_image.shape[1]
|
| 83 |
+
cond_image = cond_image.to(device=device)
|
| 84 |
+
cond_image = cond_image.to(dtype=self.vae.dtype)
|
| 85 |
+
|
| 86 |
+
if cond_image.shape[2] == 3:
|
| 87 |
+
cond_image = rearrange(cond_image, "b f c h w -> (b f) c h w")
|
| 88 |
+
cond_em = self.vae.encode(cond_image).latent_dist.mode()
|
| 89 |
+
cond_em = rearrange(cond_em, "(b f) c h w -> b f c h w", f=video_length)
|
| 90 |
+
else:
|
| 91 |
+
assert cond_image.shape[2] == 4, "The input tensor should have 3 or 4 channels. 3 for frames and 4 for latents."
|
| 92 |
+
cond_em = cond_image
|
| 93 |
+
|
| 94 |
+
# duplicate cond_em for each generation per prompt, using mps friendly method
|
| 95 |
+
cond_em = cond_em.repeat(num_videos_per_prompt, 1, 1, 1, 1)
|
| 96 |
+
|
| 97 |
+
# Bbox conditioning masking during inference (requiring the model to predict behaviour instead)
|
| 98 |
+
if bbox_mask_frames is not None:
|
| 99 |
+
mask_cond = torch.tensor(bbox_mask_frames, device=cond_em.device).view(num_videos_per_prompt, video_length, 1, 1, 1)
|
| 100 |
+
null_embedding = self.controlnet.bbox_null_embedding.repeat(num_videos_per_prompt, video_length, 1, 1, 1)
|
| 101 |
+
cond_em = torch.where(mask_cond, null_embedding, cond_em)
|
| 102 |
+
|
| 103 |
+
if do_classifier_free_guidance:
|
| 104 |
+
# negative_cond_em = torch.zeros_like(cond_em)
|
| 105 |
+
negative_cond_em = self.controlnet.bbox_null_embedding.repeat(num_videos_per_prompt, video_length, 1, 1, 1)
|
| 106 |
+
|
| 107 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 108 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 109 |
+
# to avoid doing two forward passes
|
| 110 |
+
cond_em = torch.cat([negative_cond_em, cond_em])
|
| 111 |
+
|
| 112 |
+
return cond_em
|
| 113 |
+
|
| 114 |
+
@property
|
| 115 |
+
def do_classifier_free_guidance(self):
|
| 116 |
+
# Don't do the normal CFG when using null model. The null model will take care of computing the unconditional noise
|
| 117 |
+
if self.null_model is not None:
|
| 118 |
+
return False
|
| 119 |
+
if isinstance(self.guidance_scale, (int, float)):
|
| 120 |
+
return self.guidance_scale > 1
|
| 121 |
+
return self.guidance_scale.max() > 1
|
| 122 |
+
|
| 123 |
+
@torch.no_grad()
|
| 124 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 125 |
+
def __call__(
|
| 126 |
+
self,
|
| 127 |
+
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
|
| 128 |
+
cond_images: torch.FloatTensor = None,
|
| 129 |
+
bbox_mask_frames: List[bool] = None,
|
| 130 |
+
action_type: torch.FloatTensor = None,
|
| 131 |
+
height: int = 576,
|
| 132 |
+
width: int = 1024,
|
| 133 |
+
num_frames: Optional[int] = None,
|
| 134 |
+
num_inference_steps: int = 25,
|
| 135 |
+
min_guidance_scale: float = 1.0,
|
| 136 |
+
max_guidance_scale: float = 3.0,
|
| 137 |
+
control_condition_scale: float=1.0,
|
| 138 |
+
fps: int = 7,
|
| 139 |
+
motion_bucket_id: int = 127,
|
| 140 |
+
noise_aug_strength: float = 0.02,
|
| 141 |
+
decode_chunk_size: Optional[int] = None,
|
| 142 |
+
num_videos_per_prompt: Optional[int] = 1,
|
| 143 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 144 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 145 |
+
output_type: Optional[str] = "pil",
|
| 146 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 147 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 148 |
+
return_dict: bool = True,
|
| 149 |
+
):
|
| 150 |
+
r"""
|
| 151 |
+
The call function to the pipeline for generation.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
| 155 |
+
Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
|
| 156 |
+
1]`.
|
| 157 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 158 |
+
The height in pixels of the generated image.
|
| 159 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 160 |
+
The width in pixels of the generated image.
|
| 161 |
+
num_frames (`int`, *optional*):
|
| 162 |
+
The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
|
| 163 |
+
`stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
|
| 164 |
+
num_inference_steps (`int`, *optional*, defaults to 25):
|
| 165 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
|
| 166 |
+
expense of slower inference. This parameter is modulated by `strength`.
|
| 167 |
+
min_guidance_scale (`float`, *optional*, defaults to 1.0):
|
| 168 |
+
The minimum guidance scale. Used for the classifier free guidance with first frame.
|
| 169 |
+
max_guidance_scale (`float`, *optional*, defaults to 3.0):
|
| 170 |
+
The maximum guidance scale. Used for the classifier free guidance with last frame.
|
| 171 |
+
fps (`int`, *optional*, defaults to 7):
|
| 172 |
+
Frames per second. The rate at which the generated images shall be exported to a video after
|
| 173 |
+
generation. Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
|
| 174 |
+
motion_bucket_id (`int`, *optional*, defaults to 127):
|
| 175 |
+
Used for conditioning the amount of motion for the generation. The higher the number the more motion
|
| 176 |
+
will be in the video.
|
| 177 |
+
noise_aug_strength (`float`, *optional*, defaults to 0.02):
|
| 178 |
+
The amount of noise added to the init image, the higher it is the less the video will look like the
|
| 179 |
+
init image. Increase it for more motion.
|
| 180 |
+
action_type (`torch.FloatTensor`, *optional*, defaults to None):
|
| 181 |
+
The action type to condition the generation. These features are used by the ControlNet
|
| 182 |
+
to influence the generation process. The features should be of shape `[batch_size, 1]`.
|
| 183 |
+
decode_chunk_size (`int`, *optional*):
|
| 184 |
+
The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
|
| 185 |
+
expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
|
| 186 |
+
For lower memory usage, reduce `decode_chunk_size`.
|
| 187 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 188 |
+
The number of videos to generate per prompt.
|
| 189 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 190 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 191 |
+
generation deterministic.
|
| 192 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 193 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
|
| 194 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 195 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 196 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 197 |
+
The output format of the generated image. Choose between `pil`, `np` or `pt`.
|
| 198 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 199 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
| 200 |
+
with the following arguments:
|
| 201 |
+
`callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
|
| 202 |
+
`callback_kwargs` will include a list of all tensors as specified by
|
| 203 |
+
`callback_on_step_end_tensor_inputs`.
|
| 204 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 205 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 206 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 207 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 208 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 209 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 210 |
+
plain tuple.
|
| 211 |
+
|
| 212 |
+
Examples:
|
| 213 |
+
|
| 214 |
+
Returns:
|
| 215 |
+
[`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
|
| 216 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
|
| 217 |
+
returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`)
|
| 218 |
+
is returned.
|
| 219 |
+
"""
|
| 220 |
+
# 0. Default height and width to unet
|
| 221 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 222 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 223 |
+
|
| 224 |
+
num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
|
| 225 |
+
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
|
| 226 |
+
|
| 227 |
+
# 1. Check inputs. Raise error if not correct
|
| 228 |
+
self.check_inputs(image, cond_images, height, width)
|
| 229 |
+
|
| 230 |
+
# 2. Define call parameters
|
| 231 |
+
if isinstance(image, PIL.Image.Image):
|
| 232 |
+
batch_size = 1
|
| 233 |
+
elif isinstance(image, list):
|
| 234 |
+
batch_size = len(image)
|
| 235 |
+
else:
|
| 236 |
+
batch_size = image.shape[0]
|
| 237 |
+
device = self._execution_device
|
| 238 |
+
vae_device = self.vae.device
|
| 239 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 240 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 241 |
+
# corresponds to doing no classifier free guidance.
|
| 242 |
+
self._guidance_scale = max_guidance_scale
|
| 243 |
+
|
| 244 |
+
# 3. Encode input image
|
| 245 |
+
image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
|
| 246 |
+
|
| 247 |
+
# NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
|
| 248 |
+
# See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
|
| 249 |
+
fps = fps - 1
|
| 250 |
+
|
| 251 |
+
# 4. Encode input image using VAE
|
| 252 |
+
image = self.image_processor.preprocess(image, height=height, width=width).to(device)
|
| 253 |
+
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
|
| 254 |
+
image = image + noise_aug_strength * noise
|
| 255 |
+
|
| 256 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
| 257 |
+
if needs_upcasting:
|
| 258 |
+
self.vae.to(dtype=torch.float32)
|
| 259 |
+
|
| 260 |
+
image_latents = self._encode_vae_image(
|
| 261 |
+
image,
|
| 262 |
+
device=device,
|
| 263 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 264 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 265 |
+
)
|
| 266 |
+
image_latents = image_latents.to(image_embeddings.dtype)
|
| 267 |
+
|
| 268 |
+
# Repeat the image latents for each frame so we can concatenate them with the noise
|
| 269 |
+
# image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
|
| 270 |
+
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
|
| 271 |
+
# 5. Get Added Time IDs
|
| 272 |
+
added_time_ids = self._get_add_time_ids(
|
| 273 |
+
fps,
|
| 274 |
+
motion_bucket_id,
|
| 275 |
+
noise_aug_strength,
|
| 276 |
+
image_embeddings.dtype,
|
| 277 |
+
batch_size,
|
| 278 |
+
num_videos_per_prompt,
|
| 279 |
+
self.do_classifier_free_guidance,
|
| 280 |
+
)
|
| 281 |
+
added_time_ids = added_time_ids.to(device)
|
| 282 |
+
|
| 283 |
+
# 6. Prepare timesteps
|
| 284 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 285 |
+
timesteps = self.scheduler.timesteps
|
| 286 |
+
|
| 287 |
+
# 7a. Prepare latent variables
|
| 288 |
+
num_channels_latents = self.unet.config.out_channels*2
|
| 289 |
+
latents = self.prepare_latents(
|
| 290 |
+
batch_size * num_videos_per_prompt,
|
| 291 |
+
num_frames,
|
| 292 |
+
num_channels_latents,
|
| 293 |
+
height,
|
| 294 |
+
width,
|
| 295 |
+
image_embeddings.dtype,
|
| 296 |
+
device,
|
| 297 |
+
generator,
|
| 298 |
+
latents,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# 7b. Prepare control latent embeds
|
| 302 |
+
if not cond_images is None:
|
| 303 |
+
cond_em = self._encode_vae_condition(cond_images,
|
| 304 |
+
device,
|
| 305 |
+
num_videos_per_prompt,
|
| 306 |
+
self.do_classifier_free_guidance,
|
| 307 |
+
bbox_mask_frames=bbox_mask_frames)
|
| 308 |
+
cond_em = cond_em.to(image_embeddings.dtype)
|
| 309 |
+
else:
|
| 310 |
+
cond_em = None
|
| 311 |
+
|
| 312 |
+
# 7c. Prepare action features
|
| 313 |
+
if not action_type is None:
|
| 314 |
+
if self.do_classifier_free_guidance:
|
| 315 |
+
action_type = torch.cat([torch.zeros_like(action_type).unsqueeze(0), action_type.unsqueeze(0)])
|
| 316 |
+
else:
|
| 317 |
+
action_type = None
|
| 318 |
+
|
| 319 |
+
# 8. Prepare guidance scale
|
| 320 |
+
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
|
| 321 |
+
guidance_scale = guidance_scale.to(device, latents.dtype)
|
| 322 |
+
guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
|
| 323 |
+
guidance_scale = _append_dims(guidance_scale, latents.ndim)
|
| 324 |
+
|
| 325 |
+
self._guidance_scale = guidance_scale
|
| 326 |
+
|
| 327 |
+
# 9. Denoising loop
|
| 328 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 329 |
+
self._num_timesteps = len(timesteps)
|
| 330 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 331 |
+
for i, t in enumerate(timesteps):
|
| 332 |
+
# expand the latents if we are doing classifier free guidance
|
| 333 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 334 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 335 |
+
|
| 336 |
+
# Concatenate image_latents over channels dimension
|
| 337 |
+
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
| 338 |
+
down_block_additional_residuals, mid_block_additional_residuals = self.controlnet(
|
| 339 |
+
latent_model_input,
|
| 340 |
+
timestep=t,
|
| 341 |
+
encoder_hidden_states=image_embeddings,
|
| 342 |
+
added_time_ids=added_time_ids,
|
| 343 |
+
control_cond=cond_em,
|
| 344 |
+
action_type=action_type,
|
| 345 |
+
conditioning_scale=control_condition_scale,
|
| 346 |
+
return_dict=False,
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
# predict the noise residual
|
| 350 |
+
noise_pred = self.unet(
|
| 351 |
+
sample=latent_model_input,
|
| 352 |
+
timestep=t,
|
| 353 |
+
encoder_hidden_states=image_embeddings,
|
| 354 |
+
added_time_ids=added_time_ids,
|
| 355 |
+
down_block_additional_residuals=down_block_additional_residuals,
|
| 356 |
+
mid_block_additional_residuals=mid_block_additional_residuals,
|
| 357 |
+
return_dict=False,
|
| 358 |
+
)[0]
|
| 359 |
+
|
| 360 |
+
# Predict unconditional noise
|
| 361 |
+
if self.null_model is not None:
|
| 362 |
+
t = time.time()
|
| 363 |
+
noise_pred_uncond = self.null_model(
|
| 364 |
+
latent_model_input,
|
| 365 |
+
t,
|
| 366 |
+
encoder_hidden_states=image_embeddings,
|
| 367 |
+
added_time_ids=added_time_ids,
|
| 368 |
+
return_dict=False,
|
| 369 |
+
)[0]
|
| 370 |
+
print(f"Computed uncond noise in: {time.time()-t:.4f}s")
|
| 371 |
+
|
| 372 |
+
# perform guidance
|
| 373 |
+
if self.do_classifier_free_guidance:
|
| 374 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
| 375 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 376 |
+
elif self.null_model is not None:
|
| 377 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred - noise_pred_uncond)
|
| 378 |
+
|
| 379 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 380 |
+
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
| 381 |
+
|
| 382 |
+
if callback_on_step_end is not None:
|
| 383 |
+
callback_kwargs = {}
|
| 384 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 385 |
+
callback_kwargs[k] = locals()[k]
|
| 386 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 387 |
+
|
| 388 |
+
latents = callback_outputs.pop("latents", latents)
|
| 389 |
+
|
| 390 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 391 |
+
progress_bar.update()
|
| 392 |
+
|
| 393 |
+
if not output_type == "latent":
|
| 394 |
+
frames = self.decode_latents(latents, num_frames, decode_chunk_size)
|
| 395 |
+
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
|
| 396 |
+
else:
|
| 397 |
+
frames = latents
|
| 398 |
+
|
| 399 |
+
# cast back to fp16 if needed
|
| 400 |
+
if needs_upcasting:
|
| 401 |
+
self.vae.to(dtype=torch.float16)
|
| 402 |
+
|
| 403 |
+
self.maybe_free_model_hooks()
|
| 404 |
+
|
| 405 |
+
if not return_dict:
|
| 406 |
+
return frames
|
| 407 |
+
|
| 408 |
+
return StableVideoDiffusionPipelineOutput(frames=frames)
|
src/pipelines/pipeline_video_control_factor_guidance.py
ADDED
|
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Callable, Dict, List, Optional, Union
|
| 3 |
+
import PIL.Image
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
|
| 8 |
+
tensor2vid,
|
| 9 |
+
StableVideoDiffusionPipelineOutput,
|
| 10 |
+
_append_dims,
|
| 11 |
+
EXAMPLE_DOC_STRING
|
| 12 |
+
)
|
| 13 |
+
from diffusers import StableVideoDiffusionPipeline as StableVideoDiffusionPipeline_original
|
| 14 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 15 |
+
from diffusers.utils import logging, replace_example_docstring
|
| 16 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 17 |
+
|
| 18 |
+
from src.models import UNetSpatioTemporalConditionModel, ControlNetModel
|
| 19 |
+
from diffusers.models import AutoencoderKLTemporalDecoder
|
| 20 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
| 21 |
+
from diffusers import EulerDiscreteScheduler
|
| 22 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
PipelineImageInput = Union[
|
| 28 |
+
PIL.Image.Image,
|
| 29 |
+
np.ndarray,
|
| 30 |
+
torch.FloatTensor,
|
| 31 |
+
List[PIL.Image.Image],
|
| 32 |
+
List[np.ndarray],
|
| 33 |
+
List[torch.FloatTensor],
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
class StableVideoControlFactorGuidancePipeline(StableVideoDiffusionPipeline_original):
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
vae: AutoencoderKLTemporalDecoder,
|
| 41 |
+
image_encoder: CLIPVisionModelWithProjection,
|
| 42 |
+
unet: UNetSpatioTemporalConditionModel,
|
| 43 |
+
controlnet: ControlNetModel,
|
| 44 |
+
scheduler: EulerDiscreteScheduler,
|
| 45 |
+
feature_extractor: CLIPImageProcessor,
|
| 46 |
+
null_model: UNetSpatioTemporalConditionModel,
|
| 47 |
+
):
|
| 48 |
+
# calling the super class constructors without calling StableVideoDiffusionPipeline_original's
|
| 49 |
+
DiffusionPipeline.__init__(self)
|
| 50 |
+
|
| 51 |
+
self.register_modules(
|
| 52 |
+
vae=vae,
|
| 53 |
+
image_encoder=image_encoder,
|
| 54 |
+
controlnet=controlnet,
|
| 55 |
+
unet=unet,
|
| 56 |
+
scheduler=scheduler,
|
| 57 |
+
feature_extractor=feature_extractor,
|
| 58 |
+
null_model=null_model,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 62 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 63 |
+
|
| 64 |
+
def check_inputs(self, image, cond_images, height, width):
|
| 65 |
+
if (
|
| 66 |
+
not isinstance(image, torch.Tensor)
|
| 67 |
+
and not isinstance(image, PIL.Image.Image)
|
| 68 |
+
and not isinstance(image, list)
|
| 69 |
+
):
|
| 70 |
+
raise ValueError(
|
| 71 |
+
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
| 72 |
+
f" {type(image)}"
|
| 73 |
+
)
|
| 74 |
+
if not isinstance(cond_images, torch.Tensor):
|
| 75 |
+
raise ValueError(
|
| 76 |
+
"`cond_images` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
| 77 |
+
f" {type(cond_images)}"
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 81 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 82 |
+
|
| 83 |
+
def _get_add_time_ids(
|
| 84 |
+
self,
|
| 85 |
+
fps: int,
|
| 86 |
+
motion_bucket_id: int,
|
| 87 |
+
noise_aug_strength: float,
|
| 88 |
+
dtype: torch.dtype,
|
| 89 |
+
batch_size: int,
|
| 90 |
+
num_videos_per_prompt: int,
|
| 91 |
+
):
|
| 92 |
+
add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
|
| 93 |
+
|
| 94 |
+
passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
|
| 95 |
+
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
| 96 |
+
|
| 97 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
| 98 |
+
raise ValueError(
|
| 99 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
| 103 |
+
add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
|
| 104 |
+
|
| 105 |
+
return add_time_ids
|
| 106 |
+
|
| 107 |
+
def _encode_image(
|
| 108 |
+
self,
|
| 109 |
+
image: PipelineImageInput,
|
| 110 |
+
device: Union[str, torch.device],
|
| 111 |
+
num_videos_per_prompt: int,
|
| 112 |
+
) -> torch.FloatTensor:
|
| 113 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 114 |
+
|
| 115 |
+
if not isinstance(image, torch.Tensor):
|
| 116 |
+
image = self.image_processor.pil_to_numpy(image)
|
| 117 |
+
image = self.image_processor.numpy_to_pt(image)
|
| 118 |
+
|
| 119 |
+
# We normalize the image before resizing to match with the original implementation.
|
| 120 |
+
# Then we unnormalize it after resizing.
|
| 121 |
+
image = image * 2.0 - 1.0
|
| 122 |
+
image = _resize_with_antialiasing(image, (224, 224))
|
| 123 |
+
image = (image + 1.0) / 2.0
|
| 124 |
+
|
| 125 |
+
# Normalize the image with for CLIP input
|
| 126 |
+
image = self.feature_extractor(
|
| 127 |
+
images=image,
|
| 128 |
+
do_normalize=True,
|
| 129 |
+
do_center_crop=False,
|
| 130 |
+
do_resize=False,
|
| 131 |
+
do_rescale=False,
|
| 132 |
+
return_tensors="pt",
|
| 133 |
+
).pixel_values
|
| 134 |
+
|
| 135 |
+
image = image.to(device=device, dtype=dtype)
|
| 136 |
+
image_embeddings = self.image_encoder(image).image_embeds
|
| 137 |
+
image_embeddings = image_embeddings.unsqueeze(1)
|
| 138 |
+
|
| 139 |
+
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
| 140 |
+
bs_embed, seq_len, _ = image_embeddings.shape
|
| 141 |
+
image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
|
| 142 |
+
image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
| 143 |
+
|
| 144 |
+
null_image_embeddings = torch.zeros_like(image_embeddings)
|
| 145 |
+
|
| 146 |
+
return image_embeddings, null_image_embeddings
|
| 147 |
+
|
| 148 |
+
def _encode_vae_image(
|
| 149 |
+
self,
|
| 150 |
+
image: torch.Tensor,
|
| 151 |
+
device: Union[str, torch.device],
|
| 152 |
+
num_videos_per_prompt: int,
|
| 153 |
+
):
|
| 154 |
+
image = image.to(device=device)
|
| 155 |
+
image_latents = self.vae.encode(image).latent_dist.mode()
|
| 156 |
+
|
| 157 |
+
# duplicate image_latents for each generation per prompt, using mps friendly method
|
| 158 |
+
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
|
| 159 |
+
null_image_latents = torch.zeros_like(image_latents)
|
| 160 |
+
|
| 161 |
+
return image_latents, null_image_latents
|
| 162 |
+
|
| 163 |
+
def _encode_vae_condition(
|
| 164 |
+
self,
|
| 165 |
+
cond_image: torch.tensor,
|
| 166 |
+
device: Union[str, torch.device],
|
| 167 |
+
num_videos_per_prompt: int,
|
| 168 |
+
bbox_mask_frames: List[bool] = None
|
| 169 |
+
):
|
| 170 |
+
video_length = cond_image.shape[1]
|
| 171 |
+
cond_image = cond_image.to(device=device)
|
| 172 |
+
cond_image = cond_image.to(dtype=self.vae.dtype)
|
| 173 |
+
|
| 174 |
+
if cond_image.shape[2] == 3:
|
| 175 |
+
cond_image = rearrange(cond_image, "b f c h w -> (b f) c h w")
|
| 176 |
+
cond_em = self.vae.encode(cond_image).latent_dist.mode()
|
| 177 |
+
cond_em = rearrange(cond_em, "(b f) c h w -> b f c h w", f=video_length)
|
| 178 |
+
else:
|
| 179 |
+
assert cond_image.shape[2] == 4, "The input tensor should have 3 or 4 channels. 3 for frames and 4 for latents."
|
| 180 |
+
cond_em = cond_image
|
| 181 |
+
|
| 182 |
+
# duplicate cond_em for each generation per prompt, using mps friendly method
|
| 183 |
+
cond_em = cond_em.repeat(num_videos_per_prompt, 1, 1, 1, 1)
|
| 184 |
+
|
| 185 |
+
# Bbox conditioning masking during inference (requiring the model to predict behaviour instead)
|
| 186 |
+
if bbox_mask_frames is not None:
|
| 187 |
+
mask_cond = torch.tensor(bbox_mask_frames, device=cond_em.device).view(num_videos_per_prompt, video_length, 1, 1, 1)
|
| 188 |
+
null_embedding = self.controlnet.bbox_null_embedding.repeat(num_videos_per_prompt, video_length, 1, 1, 1)
|
| 189 |
+
cond_em = torch.where(mask_cond, null_embedding, cond_em)
|
| 190 |
+
|
| 191 |
+
null_cond_em = self.controlnet.bbox_null_embedding.repeat(num_videos_per_prompt, video_length, 1, 1, 1)
|
| 192 |
+
|
| 193 |
+
return cond_em, null_cond_em
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
@torch.no_grad()
|
| 197 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 198 |
+
def __call__(
|
| 199 |
+
self,
|
| 200 |
+
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
|
| 201 |
+
cond_images: torch.FloatTensor = None,
|
| 202 |
+
bbox_mask_frames: List[bool] = None,
|
| 203 |
+
action_type: torch.FloatTensor = None,
|
| 204 |
+
height: int = 576,
|
| 205 |
+
width: int = 1024,
|
| 206 |
+
num_frames: Optional[int] = None,
|
| 207 |
+
num_inference_steps: int = 25,
|
| 208 |
+
min_guidance_scale_img: float = 1.0,
|
| 209 |
+
max_guidance_scale_img: float = 3.0,
|
| 210 |
+
min_guidance_scale_action: float = 1.0,
|
| 211 |
+
max_guidance_scale_action: float = 3.0,
|
| 212 |
+
min_guidance_scale_bbox: float = 1.0,
|
| 213 |
+
max_guidance_scale_bbox: float = 3.0,
|
| 214 |
+
control_condition_scale: float=1.0,
|
| 215 |
+
fps: int = 7,
|
| 216 |
+
motion_bucket_id: int = 127,
|
| 217 |
+
noise_aug_strength: float = 0.02,
|
| 218 |
+
decode_chunk_size: Optional[int] = None,
|
| 219 |
+
num_videos_per_prompt: Optional[int] = 1,
|
| 220 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 221 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 222 |
+
output_type: Optional[str] = "pil",
|
| 223 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 224 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 225 |
+
return_dict: bool = True,
|
| 226 |
+
):
|
| 227 |
+
r"""
|
| 228 |
+
The call function to the pipeline for generation.
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
| 232 |
+
Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
|
| 233 |
+
1]`.
|
| 234 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 235 |
+
The height in pixels of the generated image.
|
| 236 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 237 |
+
The width in pixels of the generated image.
|
| 238 |
+
num_frames (`int`, *optional*):
|
| 239 |
+
The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
|
| 240 |
+
`stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
|
| 241 |
+
num_inference_steps (`int`, *optional*, defaults to 25):
|
| 242 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
|
| 243 |
+
expense of slower inference. This parameter is modulated by `strength`.
|
| 244 |
+
min_guidance_scale (`float`, *optional*, defaults to 1.0):
|
| 245 |
+
The minimum guidance scale. Used for the classifier free guidance with first frame.
|
| 246 |
+
max_guidance_scale (`float`, *optional*, defaults to 3.0):
|
| 247 |
+
The maximum guidance scale. Used for the classifier free guidance with last frame.
|
| 248 |
+
fps (`int`, *optional*, defaults to 7):
|
| 249 |
+
Frames per second. The rate at which the generated images shall be exported to a video after
|
| 250 |
+
generation. Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
|
| 251 |
+
motion_bucket_id (`int`, *optional*, defaults to 127):
|
| 252 |
+
Used for conditioning the amount of motion for the generation. The higher the number the more motion
|
| 253 |
+
will be in the video.
|
| 254 |
+
noise_aug_strength (`float`, *optional*, defaults to 0.02):
|
| 255 |
+
The amount of noise added to the init image, the higher it is the less the video will look like the
|
| 256 |
+
init image. Increase it for more motion.
|
| 257 |
+
action_type (`torch.FloatTensor`, *optional*, defaults to None):
|
| 258 |
+
The action type to condition the generation. These features are used by the ControlNet
|
| 259 |
+
to influence the generation process. The features should be of shape `[batch_size, 1]`.
|
| 260 |
+
decode_chunk_size (`int`, *optional*):
|
| 261 |
+
The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
|
| 262 |
+
expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
|
| 263 |
+
For lower memory usage, reduce `decode_chunk_size`.
|
| 264 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 265 |
+
The number of videos to generate per prompt.
|
| 266 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 267 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 268 |
+
generation deterministic.
|
| 269 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 270 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
|
| 271 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 272 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 273 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 274 |
+
The output format of the generated image. Choose between `pil`, `np` or `pt`.
|
| 275 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 276 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
| 277 |
+
with the following arguments:
|
| 278 |
+
`callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
|
| 279 |
+
`callback_kwargs` will include a list of all tensors as specified by
|
| 280 |
+
`callback_on_step_end_tensor_inputs`.
|
| 281 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 282 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 283 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 284 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 285 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 286 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 287 |
+
plain tuple.
|
| 288 |
+
|
| 289 |
+
Examples:
|
| 290 |
+
|
| 291 |
+
Returns:
|
| 292 |
+
[`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
|
| 293 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
|
| 294 |
+
returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`)
|
| 295 |
+
is returned.
|
| 296 |
+
"""
|
| 297 |
+
# 0. Default height and width to unet
|
| 298 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 299 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 300 |
+
|
| 301 |
+
num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
|
| 302 |
+
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
|
| 303 |
+
|
| 304 |
+
# 1. Check inputs. Raise error if not correct
|
| 305 |
+
self.check_inputs(image, cond_images, height, width)
|
| 306 |
+
|
| 307 |
+
# 2. Define call parameters
|
| 308 |
+
if isinstance(image, PIL.Image.Image):
|
| 309 |
+
batch_size = 1
|
| 310 |
+
elif isinstance(image, list):
|
| 311 |
+
batch_size = len(image)
|
| 312 |
+
else:
|
| 313 |
+
batch_size = image.shape[0]
|
| 314 |
+
device = self._execution_device
|
| 315 |
+
vae_device = self.vae.device
|
| 316 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 317 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 318 |
+
# corresponds to doing no classifier free guidance.
|
| 319 |
+
# self._guidance_scale = max_guidance_scale
|
| 320 |
+
|
| 321 |
+
# 3. Encode input image
|
| 322 |
+
image_embeddings, null_image_embeddings = self._encode_image(image, device, num_videos_per_prompt)
|
| 323 |
+
|
| 324 |
+
# NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
|
| 325 |
+
# See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
|
| 326 |
+
fps = fps - 1
|
| 327 |
+
|
| 328 |
+
# 4. Encode input image using VAE
|
| 329 |
+
image = self.image_processor.preprocess(image, height=height, width=width).to(device)
|
| 330 |
+
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
|
| 331 |
+
image = image + noise_aug_strength * noise
|
| 332 |
+
|
| 333 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
| 334 |
+
if needs_upcasting:
|
| 335 |
+
self.vae.to(dtype=torch.float32)
|
| 336 |
+
|
| 337 |
+
image_latents, null_image_latents = self._encode_vae_image(
|
| 338 |
+
image,
|
| 339 |
+
device=device,
|
| 340 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 341 |
+
)
|
| 342 |
+
image_latents = image_latents.to(image_embeddings.dtype)
|
| 343 |
+
null_image_latents = null_image_latents.to(image_embeddings.dtype)
|
| 344 |
+
|
| 345 |
+
# Repeat the image latents for each frame so we can concatenate them with the noise
|
| 346 |
+
# image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
|
| 347 |
+
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
|
| 348 |
+
null_image_latents = null_image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
|
| 349 |
+
|
| 350 |
+
# 5. Get Added Time IDs
|
| 351 |
+
added_time_ids = self._get_add_time_ids(
|
| 352 |
+
fps,
|
| 353 |
+
motion_bucket_id,
|
| 354 |
+
noise_aug_strength,
|
| 355 |
+
image_embeddings.dtype,
|
| 356 |
+
batch_size,
|
| 357 |
+
num_videos_per_prompt,
|
| 358 |
+
)
|
| 359 |
+
added_time_ids = added_time_ids.to(device)
|
| 360 |
+
|
| 361 |
+
# TODO: reshape time ids for factor guidance
|
| 362 |
+
|
| 363 |
+
# 6. Prepare timesteps
|
| 364 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 365 |
+
timesteps = self.scheduler.timesteps
|
| 366 |
+
|
| 367 |
+
# 7a. Prepare latent variables
|
| 368 |
+
num_channels_latents = self.unet.config.out_channels*2
|
| 369 |
+
latents = self.prepare_latents(
|
| 370 |
+
batch_size * num_videos_per_prompt,
|
| 371 |
+
num_frames,
|
| 372 |
+
num_channels_latents,
|
| 373 |
+
height,
|
| 374 |
+
width,
|
| 375 |
+
image_embeddings.dtype,
|
| 376 |
+
device,
|
| 377 |
+
generator,
|
| 378 |
+
latents,
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
# 7b. Prepare control latent embeds
|
| 382 |
+
if not cond_images is None:
|
| 383 |
+
cond_em, null_cond_em = self._encode_vae_condition(cond_images,
|
| 384 |
+
device,
|
| 385 |
+
num_videos_per_prompt,
|
| 386 |
+
bbox_mask_frames=bbox_mask_frames)
|
| 387 |
+
cond_em = cond_em.to(image_embeddings.dtype)
|
| 388 |
+
null_cond_em = null_cond_em.to(image_embeddings.dtype)
|
| 389 |
+
else:
|
| 390 |
+
cond_em = None
|
| 391 |
+
null_cond_em = None
|
| 392 |
+
|
| 393 |
+
# 7c. Prepare action features
|
| 394 |
+
if action_type is not None:
|
| 395 |
+
action_type, null_action_type = action_type.unsqueeze(0), torch.zeros_like(action_type).unsqueeze(0)
|
| 396 |
+
else:
|
| 397 |
+
action_type = None
|
| 398 |
+
null_action_type = None
|
| 399 |
+
|
| 400 |
+
# 8. Prepare guidance scales
|
| 401 |
+
guidance_scale_img = torch.linspace(min_guidance_scale_img, max_guidance_scale_img, num_frames).unsqueeze(0)
|
| 402 |
+
guidance_scale_img = guidance_scale_img.to(device, latents.dtype)
|
| 403 |
+
guidance_scale_img = guidance_scale_img.repeat(batch_size * num_videos_per_prompt, 1)
|
| 404 |
+
guidance_scale_img = _append_dims(guidance_scale_img, latents.ndim)
|
| 405 |
+
|
| 406 |
+
guidance_scale_action = torch.linspace(min_guidance_scale_action, max_guidance_scale_action, num_frames).unsqueeze(0)
|
| 407 |
+
guidance_scale_action = guidance_scale_action.to(device, latents.dtype)
|
| 408 |
+
guidance_scale_action = guidance_scale_action.repeat(batch_size * num_videos_per_prompt, 1)
|
| 409 |
+
guidance_scale_action = _append_dims(guidance_scale_action, latents.ndim)
|
| 410 |
+
|
| 411 |
+
guidance_scale_bbox = torch.linspace(min_guidance_scale_bbox, max_guidance_scale_bbox, num_frames).unsqueeze(0)
|
| 412 |
+
guidance_scale_bbox = guidance_scale_bbox.to(device, latents.dtype)
|
| 413 |
+
guidance_scale_bbox = guidance_scale_bbox.repeat(batch_size * num_videos_per_prompt, 1)
|
| 414 |
+
guidance_scale_bbox = _append_dims(guidance_scale_bbox, latents.ndim)
|
| 415 |
+
|
| 416 |
+
# Build the tensors to batch the different levels of conditioning (used for the factored CFG)
|
| 417 |
+
# [image_and_bbox_embeddings, image_and_bbox_and_action_embeddings]
|
| 418 |
+
image_embeddings = torch.cat([image_embeddings, image_embeddings])
|
| 419 |
+
image_latents = torch.cat([image_latents, image_latents])
|
| 420 |
+
cond_em = torch.cat([cond_em, cond_em])
|
| 421 |
+
action_type = torch.cat([null_action_type, action_type])
|
| 422 |
+
added_time_ids = torch.cat([added_time_ids] * 2)
|
| 423 |
+
latents = torch.cat([latents] * 2)
|
| 424 |
+
|
| 425 |
+
# 9. Denoising loop
|
| 426 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 427 |
+
self._num_timesteps = len(timesteps)
|
| 428 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 429 |
+
for i, t in enumerate(timesteps):
|
| 430 |
+
latent_model_input = latents
|
| 431 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 432 |
+
|
| 433 |
+
# print(latent_model_input.shape, image_latents.shape)
|
| 434 |
+
|
| 435 |
+
# Concatenate image_latents over channels dimension
|
| 436 |
+
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
| 437 |
+
down_block_additional_residuals, mid_block_additional_residuals = self.controlnet(
|
| 438 |
+
latent_model_input,
|
| 439 |
+
timestep=t,
|
| 440 |
+
encoder_hidden_states=image_embeddings,
|
| 441 |
+
added_time_ids=added_time_ids,
|
| 442 |
+
control_cond=cond_em,
|
| 443 |
+
action_type=action_type,
|
| 444 |
+
conditioning_scale=control_condition_scale,
|
| 445 |
+
return_dict=False,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
# predict the noise residual
|
| 449 |
+
noise_pred = self.unet(
|
| 450 |
+
sample=latent_model_input,
|
| 451 |
+
timestep=t,
|
| 452 |
+
encoder_hidden_states=image_embeddings,
|
| 453 |
+
added_time_ids=added_time_ids,
|
| 454 |
+
down_block_additional_residuals=down_block_additional_residuals,
|
| 455 |
+
mid_block_additional_residuals=mid_block_additional_residuals,
|
| 456 |
+
return_dict=False,
|
| 457 |
+
)[0]
|
| 458 |
+
|
| 459 |
+
# Predict unconditional noise
|
| 460 |
+
noise_pred_cond_img = self.null_model(
|
| 461 |
+
latent_model_input,
|
| 462 |
+
t,
|
| 463 |
+
encoder_hidden_states=image_embeddings,
|
| 464 |
+
added_time_ids=added_time_ids,
|
| 465 |
+
return_dict=False,
|
| 466 |
+
)[0]
|
| 467 |
+
|
| 468 |
+
# Perform factored CFG
|
| 469 |
+
# NOTE: Currently discarding the unconditional noise prediction from the finetuned model
|
| 470 |
+
noise_pred_cond_img_bbox, noise_pred_cond_all = noise_pred.chunk(2)
|
| 471 |
+
|
| 472 |
+
# NOTE: `noise_pred_uncond` is technically the same as `noise_pred_cond_img` since they both condition on the image.
|
| 473 |
+
# Therefore we could probably remove `noise_pred_cond_img` and get similar performances and faster inference
|
| 474 |
+
noise_pred = noise_pred_cond_img \
|
| 475 |
+
+ guidance_scale_bbox * (noise_pred_cond_img_bbox - noise_pred_cond_img) \
|
| 476 |
+
+ guidance_scale_action * (noise_pred_cond_all - noise_pred_cond_img_bbox)
|
| 477 |
+
|
| 478 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 479 |
+
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
| 480 |
+
# print("latents", latents.shape)
|
| 481 |
+
|
| 482 |
+
if callback_on_step_end is not None:
|
| 483 |
+
callback_kwargs = {}
|
| 484 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 485 |
+
callback_kwargs[k] = locals()[k]
|
| 486 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 487 |
+
|
| 488 |
+
latents = callback_outputs.pop("latents", latents)
|
| 489 |
+
|
| 490 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 491 |
+
progress_bar.update()
|
| 492 |
+
|
| 493 |
+
if not output_type == "latent":
|
| 494 |
+
frames = self.decode_latents(latents, num_frames, decode_chunk_size)
|
| 495 |
+
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
|
| 496 |
+
else:
|
| 497 |
+
frames = latents
|
| 498 |
+
|
| 499 |
+
# cast back to fp16 if needed
|
| 500 |
+
if needs_upcasting:
|
| 501 |
+
self.vae.to(dtype=torch.float16)
|
| 502 |
+
|
| 503 |
+
self.maybe_free_model_hooks()
|
| 504 |
+
|
| 505 |
+
if not return_dict:
|
| 506 |
+
return frames
|
| 507 |
+
|
| 508 |
+
return StableVideoDiffusionPipelineOutput(frames=frames)
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
# TODO: Some helper functions from Stable Video Diffusion that we could move elsewhere
|
| 512 |
+
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
|
| 513 |
+
h, w = input.shape[-2:]
|
| 514 |
+
factors = (h / size[0], w / size[1])
|
| 515 |
+
|
| 516 |
+
# First, we have to determine sigma
|
| 517 |
+
# Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
|
| 518 |
+
sigmas = (
|
| 519 |
+
max((factors[0] - 1.0) / 2.0, 0.001),
|
| 520 |
+
max((factors[1] - 1.0) / 2.0, 0.001),
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
# Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
|
| 524 |
+
# https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
|
| 525 |
+
# But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
|
| 526 |
+
ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
|
| 527 |
+
|
| 528 |
+
# Make sure it is odd
|
| 529 |
+
if (ks[0] % 2) == 0:
|
| 530 |
+
ks = ks[0] + 1, ks[1]
|
| 531 |
+
|
| 532 |
+
if (ks[1] % 2) == 0:
|
| 533 |
+
ks = ks[0], ks[1] + 1
|
| 534 |
+
|
| 535 |
+
input = _gaussian_blur2d(input, ks, sigmas)
|
| 536 |
+
|
| 537 |
+
output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
|
| 538 |
+
return output
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def _compute_padding(kernel_size):
|
| 542 |
+
"""Compute padding tuple."""
|
| 543 |
+
# 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
|
| 544 |
+
# https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
|
| 545 |
+
if len(kernel_size) < 2:
|
| 546 |
+
raise AssertionError(kernel_size)
|
| 547 |
+
computed = [k - 1 for k in kernel_size]
|
| 548 |
+
|
| 549 |
+
# for even kernels we need to do asymmetric padding :(
|
| 550 |
+
out_padding = 2 * len(kernel_size) * [0]
|
| 551 |
+
|
| 552 |
+
for i in range(len(kernel_size)):
|
| 553 |
+
computed_tmp = computed[-(i + 1)]
|
| 554 |
+
|
| 555 |
+
pad_front = computed_tmp // 2
|
| 556 |
+
pad_rear = computed_tmp - pad_front
|
| 557 |
+
|
| 558 |
+
out_padding[2 * i + 0] = pad_front
|
| 559 |
+
out_padding[2 * i + 1] = pad_rear
|
| 560 |
+
|
| 561 |
+
return out_padding
|
| 562 |
+
|
| 563 |
+
def _filter2d(input, kernel):
|
| 564 |
+
# prepare kernel
|
| 565 |
+
b, c, h, w = input.shape
|
| 566 |
+
tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
|
| 567 |
+
|
| 568 |
+
tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
|
| 569 |
+
|
| 570 |
+
height, width = tmp_kernel.shape[-2:]
|
| 571 |
+
|
| 572 |
+
padding_shape: list[int] = _compute_padding([height, width])
|
| 573 |
+
input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
|
| 574 |
+
|
| 575 |
+
# kernel and input tensor reshape to align element-wise or batch-wise params
|
| 576 |
+
tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
|
| 577 |
+
input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
|
| 578 |
+
|
| 579 |
+
# convolve the tensor with the kernel.
|
| 580 |
+
output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
|
| 581 |
+
|
| 582 |
+
out = output.view(b, c, h, w)
|
| 583 |
+
return out
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def _gaussian(window_size: int, sigma):
|
| 587 |
+
if isinstance(sigma, float):
|
| 588 |
+
sigma = torch.tensor([[sigma]])
|
| 589 |
+
|
| 590 |
+
batch_size = sigma.shape[0]
|
| 591 |
+
|
| 592 |
+
x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
|
| 593 |
+
|
| 594 |
+
if window_size % 2 == 0:
|
| 595 |
+
x = x + 0.5
|
| 596 |
+
|
| 597 |
+
gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
|
| 598 |
+
|
| 599 |
+
return gauss / gauss.sum(-1, keepdim=True)
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
def _gaussian_blur2d(input, kernel_size, sigma):
|
| 603 |
+
if isinstance(sigma, tuple):
|
| 604 |
+
sigma = torch.tensor([sigma], dtype=input.dtype)
|
| 605 |
+
else:
|
| 606 |
+
sigma = sigma.to(dtype=input.dtype)
|
| 607 |
+
|
| 608 |
+
ky, kx = int(kernel_size[0]), int(kernel_size[1])
|
| 609 |
+
bs = sigma.shape[0]
|
| 610 |
+
kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
|
| 611 |
+
kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
|
| 612 |
+
out_x = _filter2d(input, kernel_x[..., None, :])
|
| 613 |
+
out = _filter2d(out_x, kernel_y[..., None])
|
| 614 |
+
|
| 615 |
+
return out
|
src/pipelines/pipeline_video_control_nullmodel.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Callable, Dict, List, Optional, Union
|
| 3 |
+
import PIL.Image
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
|
| 6 |
+
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
|
| 7 |
+
tensor2vid,
|
| 8 |
+
StableVideoDiffusionPipelineOutput,
|
| 9 |
+
_append_dims,
|
| 10 |
+
EXAMPLE_DOC_STRING
|
| 11 |
+
)
|
| 12 |
+
from diffusers import StableVideoDiffusionPipeline as StableVideoDiffusionPipeline_original
|
| 13 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 14 |
+
from diffusers.utils import logging, replace_example_docstring
|
| 15 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 16 |
+
|
| 17 |
+
from src.models import UNetSpatioTemporalConditionModel, ControlNetModel
|
| 18 |
+
from diffusers.models import AutoencoderKLTemporalDecoder
|
| 19 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
| 20 |
+
from diffusers import EulerDiscreteScheduler
|
| 21 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 24 |
+
|
| 25 |
+
class StableVideoControlNullModelPipeline(StableVideoDiffusionPipeline_original):
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
vae: AutoencoderKLTemporalDecoder,
|
| 30 |
+
image_encoder: CLIPVisionModelWithProjection,
|
| 31 |
+
unet: UNetSpatioTemporalConditionModel,
|
| 32 |
+
controlnet: ControlNetModel,
|
| 33 |
+
scheduler: EulerDiscreteScheduler,
|
| 34 |
+
feature_extractor: CLIPImageProcessor,
|
| 35 |
+
null_model: UNetSpatioTemporalConditionModel,
|
| 36 |
+
):
|
| 37 |
+
# calling the super class constructors without calling StableVideoDiffusionPipeline_original's
|
| 38 |
+
DiffusionPipeline.__init__(self)
|
| 39 |
+
|
| 40 |
+
self.register_modules(
|
| 41 |
+
vae=vae,
|
| 42 |
+
image_encoder=image_encoder,
|
| 43 |
+
controlnet=controlnet,
|
| 44 |
+
unet=unet,
|
| 45 |
+
scheduler=scheduler,
|
| 46 |
+
feature_extractor=feature_extractor,
|
| 47 |
+
null_model=null_model,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 51 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 52 |
+
|
| 53 |
+
def check_inputs(self, image, cond_images, height, width):
|
| 54 |
+
if (
|
| 55 |
+
not isinstance(image, torch.Tensor)
|
| 56 |
+
and not isinstance(image, PIL.Image.Image)
|
| 57 |
+
and not isinstance(image, list)
|
| 58 |
+
):
|
| 59 |
+
raise ValueError(
|
| 60 |
+
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
| 61 |
+
f" {type(image)}"
|
| 62 |
+
)
|
| 63 |
+
if not isinstance(cond_images, torch.Tensor):
|
| 64 |
+
raise ValueError(
|
| 65 |
+
"`cond_images` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
| 66 |
+
f" {type(cond_images)}"
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 70 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _encode_vae_condition(
|
| 74 |
+
self,
|
| 75 |
+
cond_image: torch.tensor,
|
| 76 |
+
device: Union[str, torch.device],
|
| 77 |
+
num_videos_per_prompt: int,
|
| 78 |
+
do_classifier_free_guidance: bool,
|
| 79 |
+
bbox_mask_frames: List[bool] = None
|
| 80 |
+
):
|
| 81 |
+
video_length = cond_image.shape[1]
|
| 82 |
+
cond_image = cond_image.to(device=device)
|
| 83 |
+
cond_image = cond_image.to(dtype=self.vae.dtype)
|
| 84 |
+
|
| 85 |
+
if cond_image.shape[2] == 3:
|
| 86 |
+
cond_image = rearrange(cond_image, "b f c h w -> (b f) c h w")
|
| 87 |
+
cond_em = self.vae.encode(cond_image).latent_dist.mode()
|
| 88 |
+
cond_em = rearrange(cond_em, "(b f) c h w -> b f c h w", f=video_length)
|
| 89 |
+
else:
|
| 90 |
+
assert cond_image.shape[2] == 4, "The input tensor should have 3 or 4 channels. 3 for frames and 4 for latents."
|
| 91 |
+
cond_em = cond_image
|
| 92 |
+
|
| 93 |
+
# duplicate cond_em for each generation per prompt, using mps friendly method
|
| 94 |
+
cond_em = cond_em.repeat(num_videos_per_prompt, 1, 1, 1, 1)
|
| 95 |
+
|
| 96 |
+
# Bbox conditioning masking during inference (requiring the model to predict behaviour instead)
|
| 97 |
+
if bbox_mask_frames is not None:
|
| 98 |
+
mask_cond = torch.tensor(bbox_mask_frames, device=cond_em.device).view(num_videos_per_prompt, video_length, 1, 1, 1)
|
| 99 |
+
null_embedding = self.controlnet.bbox_null_embedding.repeat(num_videos_per_prompt, video_length, 1, 1, 1)
|
| 100 |
+
cond_em = torch.where(mask_cond, null_embedding, cond_em)
|
| 101 |
+
|
| 102 |
+
if do_classifier_free_guidance:
|
| 103 |
+
# negative_cond_em = torch.zeros_like(cond_em)
|
| 104 |
+
negative_cond_em = self.controlnet.bbox_null_embedding.repeat(num_videos_per_prompt, video_length, 1, 1, 1)
|
| 105 |
+
|
| 106 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 107 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 108 |
+
# to avoid doing two forward passes
|
| 109 |
+
cond_em = torch.cat([negative_cond_em, cond_em])
|
| 110 |
+
|
| 111 |
+
return cond_em
|
| 112 |
+
|
| 113 |
+
@property
|
| 114 |
+
def do_classifier_free_guidance(self):
|
| 115 |
+
return False
|
| 116 |
+
# if isinstance(self.guidance_scale, (int, float)):
|
| 117 |
+
# return self.guidance_scale > 1
|
| 118 |
+
# return self.guidance_scale.max() > 1
|
| 119 |
+
|
| 120 |
+
@torch.no_grad()
|
| 121 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 122 |
+
def __call__(
|
| 123 |
+
self,
|
| 124 |
+
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
|
| 125 |
+
cond_images: torch.FloatTensor = None,
|
| 126 |
+
bbox_mask_frames: List[bool] = None,
|
| 127 |
+
action_type: torch.FloatTensor = None,
|
| 128 |
+
height: int = 576,
|
| 129 |
+
width: int = 1024,
|
| 130 |
+
num_frames: Optional[int] = None,
|
| 131 |
+
num_inference_steps: int = 25,
|
| 132 |
+
min_guidance_scale: float = 1.0,
|
| 133 |
+
max_guidance_scale: float = 3.0,
|
| 134 |
+
control_condition_scale: float=1.0,
|
| 135 |
+
fps: int = 7,
|
| 136 |
+
motion_bucket_id: int = 127,
|
| 137 |
+
noise_aug_strength: float = 0.02,
|
| 138 |
+
decode_chunk_size: Optional[int] = None,
|
| 139 |
+
num_videos_per_prompt: Optional[int] = 1,
|
| 140 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 141 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 142 |
+
output_type: Optional[str] = "pil",
|
| 143 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 144 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 145 |
+
return_dict: bool = True,
|
| 146 |
+
):
|
| 147 |
+
r"""
|
| 148 |
+
The call function to the pipeline for generation.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
| 152 |
+
Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
|
| 153 |
+
1]`.
|
| 154 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 155 |
+
The height in pixels of the generated image.
|
| 156 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 157 |
+
The width in pixels of the generated image.
|
| 158 |
+
num_frames (`int`, *optional*):
|
| 159 |
+
The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
|
| 160 |
+
`stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
|
| 161 |
+
num_inference_steps (`int`, *optional*, defaults to 25):
|
| 162 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
|
| 163 |
+
expense of slower inference. This parameter is modulated by `strength`.
|
| 164 |
+
min_guidance_scale (`float`, *optional*, defaults to 1.0):
|
| 165 |
+
The minimum guidance scale. Used for the classifier free guidance with first frame.
|
| 166 |
+
max_guidance_scale (`float`, *optional*, defaults to 3.0):
|
| 167 |
+
The maximum guidance scale. Used for the classifier free guidance with last frame.
|
| 168 |
+
fps (`int`, *optional*, defaults to 7):
|
| 169 |
+
Frames per second. The rate at which the generated images shall be exported to a video after
|
| 170 |
+
generation. Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
|
| 171 |
+
motion_bucket_id (`int`, *optional*, defaults to 127):
|
| 172 |
+
Used for conditioning the amount of motion for the generation. The higher the number the more motion
|
| 173 |
+
will be in the video.
|
| 174 |
+
noise_aug_strength (`float`, *optional*, defaults to 0.02):
|
| 175 |
+
The amount of noise added to the init image, the higher it is the less the video will look like the
|
| 176 |
+
init image. Increase it for more motion.
|
| 177 |
+
action_type (`torch.FloatTensor`, *optional*, defaults to None):
|
| 178 |
+
The action type to condition the generation. These features are used by the ControlNet
|
| 179 |
+
to influence the generation process. The features should be of shape `[batch_size, 1]`.
|
| 180 |
+
decode_chunk_size (`int`, *optional*):
|
| 181 |
+
The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
|
| 182 |
+
expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
|
| 183 |
+
For lower memory usage, reduce `decode_chunk_size`.
|
| 184 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 185 |
+
The number of videos to generate per prompt.
|
| 186 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 187 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 188 |
+
generation deterministic.
|
| 189 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 190 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
|
| 191 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 192 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 193 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 194 |
+
The output format of the generated image. Choose between `pil`, `np` or `pt`.
|
| 195 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 196 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
| 197 |
+
with the following arguments:
|
| 198 |
+
`callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
|
| 199 |
+
`callback_kwargs` will include a list of all tensors as specified by
|
| 200 |
+
`callback_on_step_end_tensor_inputs`.
|
| 201 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 202 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 203 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 204 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 205 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 206 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 207 |
+
plain tuple.
|
| 208 |
+
|
| 209 |
+
Examples:
|
| 210 |
+
|
| 211 |
+
Returns:
|
| 212 |
+
[`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
|
| 213 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
|
| 214 |
+
returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`)
|
| 215 |
+
is returned.
|
| 216 |
+
"""
|
| 217 |
+
# 0. Default height and width to unet
|
| 218 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 219 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 220 |
+
|
| 221 |
+
num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
|
| 222 |
+
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
|
| 223 |
+
|
| 224 |
+
# 1. Check inputs. Raise error if not correct
|
| 225 |
+
self.check_inputs(image, cond_images, height, width)
|
| 226 |
+
|
| 227 |
+
# 2. Define call parameters
|
| 228 |
+
if isinstance(image, PIL.Image.Image):
|
| 229 |
+
batch_size = 1
|
| 230 |
+
elif isinstance(image, list):
|
| 231 |
+
batch_size = len(image)
|
| 232 |
+
else:
|
| 233 |
+
batch_size = image.shape[0]
|
| 234 |
+
device = self._execution_device
|
| 235 |
+
vae_device = self.vae.device
|
| 236 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 237 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 238 |
+
# corresponds to doing no classifier free guidance.
|
| 239 |
+
self._guidance_scale = max_guidance_scale
|
| 240 |
+
|
| 241 |
+
# 3. Encode input image
|
| 242 |
+
image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
|
| 243 |
+
|
| 244 |
+
# NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
|
| 245 |
+
# See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
|
| 246 |
+
fps = fps - 1
|
| 247 |
+
|
| 248 |
+
# 4. Encode input image using VAE
|
| 249 |
+
image = self.image_processor.preprocess(image, height=height, width=width).to(device)
|
| 250 |
+
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
|
| 251 |
+
image = image + noise_aug_strength * noise
|
| 252 |
+
|
| 253 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
| 254 |
+
if needs_upcasting:
|
| 255 |
+
self.vae.to(dtype=torch.float32)
|
| 256 |
+
|
| 257 |
+
image_latents = self._encode_vae_image(
|
| 258 |
+
image,
|
| 259 |
+
device=device,
|
| 260 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 261 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 262 |
+
)
|
| 263 |
+
image_latents = image_latents.to(image_embeddings.dtype)
|
| 264 |
+
|
| 265 |
+
# Repeat the image latents for each frame so we can concatenate them with the noise
|
| 266 |
+
# image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
|
| 267 |
+
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
|
| 268 |
+
# 5. Get Added Time IDs
|
| 269 |
+
added_time_ids = self._get_add_time_ids(
|
| 270 |
+
fps,
|
| 271 |
+
motion_bucket_id,
|
| 272 |
+
noise_aug_strength,
|
| 273 |
+
image_embeddings.dtype,
|
| 274 |
+
batch_size,
|
| 275 |
+
num_videos_per_prompt,
|
| 276 |
+
self.do_classifier_free_guidance,
|
| 277 |
+
)
|
| 278 |
+
added_time_ids = added_time_ids.to(device)
|
| 279 |
+
|
| 280 |
+
# 6. Prepare timesteps
|
| 281 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 282 |
+
timesteps = self.scheduler.timesteps
|
| 283 |
+
|
| 284 |
+
# 7a. Prepare latent variables
|
| 285 |
+
num_channels_latents = self.unet.config.out_channels*2
|
| 286 |
+
latents = self.prepare_latents(
|
| 287 |
+
batch_size * num_videos_per_prompt,
|
| 288 |
+
num_frames,
|
| 289 |
+
num_channels_latents,
|
| 290 |
+
height,
|
| 291 |
+
width,
|
| 292 |
+
image_embeddings.dtype,
|
| 293 |
+
device,
|
| 294 |
+
generator,
|
| 295 |
+
latents,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
# 7b. Prepare control latent embeds
|
| 299 |
+
if not cond_images is None:
|
| 300 |
+
cond_em = self._encode_vae_condition(cond_images,
|
| 301 |
+
device,
|
| 302 |
+
num_videos_per_prompt,
|
| 303 |
+
self.do_classifier_free_guidance,
|
| 304 |
+
bbox_mask_frames=bbox_mask_frames)
|
| 305 |
+
cond_em = cond_em.to(image_embeddings.dtype)
|
| 306 |
+
else:
|
| 307 |
+
cond_em = None
|
| 308 |
+
|
| 309 |
+
# 7c. Prepare action features
|
| 310 |
+
if not action_type is None:
|
| 311 |
+
if self.do_classifier_free_guidance:
|
| 312 |
+
action_type = torch.cat([torch.zeros_like(action_type).unsqueeze(0), action_type.unsqueeze(0)])
|
| 313 |
+
else:
|
| 314 |
+
action_type = None
|
| 315 |
+
|
| 316 |
+
# 8. Prepare guidance scale
|
| 317 |
+
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
|
| 318 |
+
guidance_scale = guidance_scale.to(device, latents.dtype)
|
| 319 |
+
guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
|
| 320 |
+
guidance_scale = _append_dims(guidance_scale, latents.ndim)
|
| 321 |
+
|
| 322 |
+
self._guidance_scale = guidance_scale
|
| 323 |
+
|
| 324 |
+
# 9. Denoising loop
|
| 325 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 326 |
+
self._num_timesteps = len(timesteps)
|
| 327 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 328 |
+
for i, t in enumerate(timesteps):
|
| 329 |
+
# expand the latents if we are doing classifier free guidance
|
| 330 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 331 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 332 |
+
|
| 333 |
+
# print(latent_model_input.shape, image_latents.shape, self.do_classifier_free_guidance)
|
| 334 |
+
|
| 335 |
+
# Concatenate image_latents over channels dimension
|
| 336 |
+
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
| 337 |
+
# latent_model_input_null_model = latent_model_input.clone().detach()
|
| 338 |
+
down_block_additional_residuals, mid_block_additional_residuals = self.controlnet(
|
| 339 |
+
latent_model_input,
|
| 340 |
+
timestep=t,
|
| 341 |
+
encoder_hidden_states=image_embeddings,
|
| 342 |
+
added_time_ids=added_time_ids,
|
| 343 |
+
control_cond=cond_em,
|
| 344 |
+
action_type=action_type,
|
| 345 |
+
conditioning_scale=control_condition_scale,
|
| 346 |
+
return_dict=False,
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
# predict the noise residual
|
| 350 |
+
noise_pred = self.unet(
|
| 351 |
+
sample=latent_model_input,
|
| 352 |
+
timestep=t,
|
| 353 |
+
encoder_hidden_states=image_embeddings,
|
| 354 |
+
added_time_ids=added_time_ids,
|
| 355 |
+
down_block_additional_residuals=down_block_additional_residuals,
|
| 356 |
+
mid_block_additional_residuals=mid_block_additional_residuals,
|
| 357 |
+
return_dict=False,
|
| 358 |
+
)[0]
|
| 359 |
+
|
| 360 |
+
# Predict unconditional noise
|
| 361 |
+
noise_pred_uncond = self.null_model(
|
| 362 |
+
latent_model_input,
|
| 363 |
+
t,
|
| 364 |
+
encoder_hidden_states=image_embeddings,
|
| 365 |
+
added_time_ids=added_time_ids,
|
| 366 |
+
return_dict=False,
|
| 367 |
+
)[0]
|
| 368 |
+
|
| 369 |
+
# perform guidance
|
| 370 |
+
if self.do_classifier_free_guidance:
|
| 371 |
+
_, noise_pred_cond = noise_pred.chunk(2) # NOTE: Currently discarding the unconditional noise prediction from the finetuned model
|
| 372 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 373 |
+
else:
|
| 374 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred - noise_pred_uncond)
|
| 375 |
+
|
| 376 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 377 |
+
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
| 378 |
+
# print("latents", latents.shape)
|
| 379 |
+
|
| 380 |
+
if callback_on_step_end is not None:
|
| 381 |
+
callback_kwargs = {}
|
| 382 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 383 |
+
callback_kwargs[k] = locals()[k]
|
| 384 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 385 |
+
|
| 386 |
+
latents = callback_outputs.pop("latents", latents)
|
| 387 |
+
|
| 388 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 389 |
+
progress_bar.update()
|
| 390 |
+
|
| 391 |
+
if not output_type == "latent":
|
| 392 |
+
frames = self.decode_latents(latents, num_frames, decode_chunk_size)
|
| 393 |
+
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
|
| 394 |
+
else:
|
| 395 |
+
frames = latents
|
| 396 |
+
|
| 397 |
+
# cast back to fp16 if needed
|
| 398 |
+
if needs_upcasting:
|
| 399 |
+
self.vae.to(dtype=torch.float16)
|
| 400 |
+
|
| 401 |
+
self.maybe_free_model_hooks()
|
| 402 |
+
|
| 403 |
+
if not return_dict:
|
| 404 |
+
return frames
|
| 405 |
+
|
| 406 |
+
return StableVideoDiffusionPipelineOutput(frames=frames)
|
src/pipelines/pipeline_video_diffusion.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from diffusers import StableVideoDiffusionPipeline as StableVideoDiffusionPipeline_original
|
| 2 |
+
import torch
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 5 |
+
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
|
| 6 |
+
from typing import Callable, Dict, List, Tuple, Optional, Union
|
| 7 |
+
import PIL.Image
|
| 8 |
+
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
|
| 9 |
+
tensor2vid,
|
| 10 |
+
StableVideoDiffusionPipelineOutput,
|
| 11 |
+
_append_dims,
|
| 12 |
+
EXAMPLE_DOC_STRING
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 16 |
+
|
| 17 |
+
class VideoDiffusionPipeline(StableVideoDiffusionPipeline_original):
|
| 18 |
+
|
| 19 |
+
def _encode_vae_condition(
|
| 20 |
+
self,
|
| 21 |
+
cond_image: torch.tensor,
|
| 22 |
+
device: Union[str, torch.device],
|
| 23 |
+
num_videos_per_prompt: int,
|
| 24 |
+
do_classifier_free_guidance: bool,
|
| 25 |
+
):
|
| 26 |
+
video_length = cond_image.shape[1]
|
| 27 |
+
cond_image = cond_image.to(device=device)
|
| 28 |
+
cond_image = cond_image.to(dtype=self.vae.dtype)
|
| 29 |
+
cond_image = rearrange(cond_image, "b f c h w -> (b f) c h w")
|
| 30 |
+
cond_em = self.vae.encode(cond_image).latent_dist.mode()
|
| 31 |
+
cond_em = rearrange(cond_em, "(b f) c h w -> b f c h w", f=video_length)
|
| 32 |
+
|
| 33 |
+
# duplicate cond_em for each generation per prompt, using mps friendly method
|
| 34 |
+
cond_em = cond_em.repeat(num_videos_per_prompt, 1, 1, 1, 1)
|
| 35 |
+
|
| 36 |
+
if do_classifier_free_guidance:
|
| 37 |
+
negative_cond_em = torch.zeros_like(cond_em)
|
| 38 |
+
|
| 39 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 40 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 41 |
+
# to avoid doing two forward passes
|
| 42 |
+
cond_em = torch.cat([negative_cond_em, cond_em])
|
| 43 |
+
|
| 44 |
+
return cond_em
|
| 45 |
+
|
| 46 |
+
def decode_latent_to_video(self, latents,
|
| 47 |
+
decode_chunk_size: Optional[int] = None,
|
| 48 |
+
num_frames: Optional[int] = None,
|
| 49 |
+
output_type: Optional[str] = "pil",):
|
| 50 |
+
frames = self.decode_latents(latents, num_frames, decode_chunk_size)
|
| 51 |
+
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
|
| 52 |
+
|
| 53 |
+
@torch.no_grad()
|
| 54 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 55 |
+
def __call__(
|
| 56 |
+
self,
|
| 57 |
+
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
|
| 58 |
+
bbox_images: Optional[torch.FloatTensor] = None,
|
| 59 |
+
bbox_conditions: Optional[Dict[str, Union[torch.FloatTensor, List[Union[float, int]]]]] = None,
|
| 60 |
+
original_size: Optional[Tuple[int]] = (1242, 375),
|
| 61 |
+
height: int = 576,
|
| 62 |
+
width: int = 1024,
|
| 63 |
+
num_frames: Optional[int] = None,
|
| 64 |
+
num_inference_steps: int = 25,
|
| 65 |
+
min_guidance_scale: float = 1.0,
|
| 66 |
+
max_guidance_scale: float = 3.0,
|
| 67 |
+
fps: int = 7,
|
| 68 |
+
motion_bucket_id: int = 127,
|
| 69 |
+
noise_aug_strength: float = 0.02,
|
| 70 |
+
decode_chunk_size: Optional[int] = None,
|
| 71 |
+
num_videos_per_prompt: Optional[int] = 1,
|
| 72 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 73 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 74 |
+
output_type: Optional[str] = "pil",
|
| 75 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 76 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 77 |
+
return_dict: bool = True,
|
| 78 |
+
num_cond_bbox_frames: int=3,
|
| 79 |
+
):
|
| 80 |
+
r"""
|
| 81 |
+
The call function to the pipeline for generation.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
| 85 |
+
Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
|
| 86 |
+
1]`.
|
| 87 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 88 |
+
The height in pixels of the generated image.
|
| 89 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 90 |
+
The width in pixels of the generated image.
|
| 91 |
+
num_frames (`int`, *optional*):
|
| 92 |
+
The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
|
| 93 |
+
`stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
|
| 94 |
+
num_inference_steps (`int`, *optional*, defaults to 25):
|
| 95 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
|
| 96 |
+
expense of slower inference. This parameter is modulated by `strength`.
|
| 97 |
+
min_guidance_scale (`float`, *optional*, defaults to 1.0):
|
| 98 |
+
The minimum guidance scale. Used for the classifier free guidance with first frame.
|
| 99 |
+
max_guidance_scale (`float`, *optional*, defaults to 3.0):
|
| 100 |
+
The maximum guidance scale. Used for the classifier free guidance with last frame.
|
| 101 |
+
fps (`int`, *optional*, defaults to 7):
|
| 102 |
+
Frames per second. The rate at which the generated images shall be exported to a video after
|
| 103 |
+
generation. Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
|
| 104 |
+
motion_bucket_id (`int`, *optional*, defaults to 127):
|
| 105 |
+
Used for conditioning the amount of motion for the generation. The higher the number the more motion
|
| 106 |
+
will be in the video.
|
| 107 |
+
noise_aug_strength (`float`, *optional*, defaults to 0.02):
|
| 108 |
+
The amount of noise added to the init image, the higher it is the less the video will look like the
|
| 109 |
+
init image. Increase it for more motion.
|
| 110 |
+
decode_chunk_size (`int`, *optional*):
|
| 111 |
+
The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
|
| 112 |
+
expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
|
| 113 |
+
For lower memory usage, reduce `decode_chunk_size`.
|
| 114 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 115 |
+
The number of videos to generate per prompt.
|
| 116 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 117 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 118 |
+
generation deterministic.
|
| 119 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 120 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
|
| 121 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 122 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 123 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 124 |
+
The output format of the generated image. Choose between `pil`, `np` or `pt`.
|
| 125 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 126 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
| 127 |
+
with the following arguments:
|
| 128 |
+
`callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
|
| 129 |
+
`callback_kwargs` will include a list of all tensors as specified by
|
| 130 |
+
`callback_on_step_end_tensor_inputs`.
|
| 131 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 132 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 133 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 134 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 135 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 136 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 137 |
+
plain tuple.
|
| 138 |
+
|
| 139 |
+
Examples:
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
[`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
|
| 143 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
|
| 144 |
+
returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`)
|
| 145 |
+
is returned.
|
| 146 |
+
"""
|
| 147 |
+
# 0. Default height and width to unet
|
| 148 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 149 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 150 |
+
|
| 151 |
+
num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
|
| 152 |
+
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
|
| 153 |
+
|
| 154 |
+
# 1. Check inputs. Raise error if not correct
|
| 155 |
+
self.check_inputs(image, height, width)
|
| 156 |
+
|
| 157 |
+
# 2. Define call parameters
|
| 158 |
+
if isinstance(image, PIL.Image.Image):
|
| 159 |
+
batch_size = 1
|
| 160 |
+
elif isinstance(image, list):
|
| 161 |
+
batch_size = len(image)
|
| 162 |
+
else:
|
| 163 |
+
batch_size = image.shape[0]
|
| 164 |
+
device = self._execution_device
|
| 165 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 166 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 167 |
+
# corresponds to doing no classifier free guidance.
|
| 168 |
+
self._guidance_scale = max_guidance_scale
|
| 169 |
+
|
| 170 |
+
# 3. Encode input image
|
| 171 |
+
image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
|
| 172 |
+
|
| 173 |
+
# NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
|
| 174 |
+
# See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
|
| 175 |
+
fps = fps - 1
|
| 176 |
+
|
| 177 |
+
# 4. Encode input image using VAE
|
| 178 |
+
image = self.image_processor.preprocess(image, height=height, width=width).to(device)
|
| 179 |
+
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
|
| 180 |
+
image = image + noise_aug_strength * noise
|
| 181 |
+
|
| 182 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
| 183 |
+
if needs_upcasting:
|
| 184 |
+
self.vae.to(dtype=torch.float32)
|
| 185 |
+
|
| 186 |
+
image_latents = self._encode_vae_image(
|
| 187 |
+
image,
|
| 188 |
+
device=device,
|
| 189 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 190 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 191 |
+
)
|
| 192 |
+
image_latents = image_latents.to(image_embeddings.dtype)
|
| 193 |
+
|
| 194 |
+
# Repeat the image latents for each frame so we can concatenate them with the noise
|
| 195 |
+
# image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
|
| 196 |
+
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
|
| 197 |
+
|
| 198 |
+
# 7b. Prepare control latent embeds
|
| 199 |
+
if not bbox_images is None:
|
| 200 |
+
cond_latents = self._encode_vae_condition(bbox_images,
|
| 201 |
+
device,
|
| 202 |
+
num_videos_per_prompt,
|
| 203 |
+
self.do_classifier_free_guidance)
|
| 204 |
+
image_latents[:,0:num_cond_bbox_frames,::] = cond_latents[:,0:num_cond_bbox_frames,::]
|
| 205 |
+
image_latents[:,-1,::]=cond_latents[:,-1,::]
|
| 206 |
+
|
| 207 |
+
# 5. Get Added Time IDs
|
| 208 |
+
added_time_ids = self._get_add_time_ids(
|
| 209 |
+
fps,
|
| 210 |
+
motion_bucket_id,
|
| 211 |
+
noise_aug_strength,
|
| 212 |
+
image_embeddings.dtype,
|
| 213 |
+
batch_size,
|
| 214 |
+
num_videos_per_prompt,
|
| 215 |
+
self.do_classifier_free_guidance,
|
| 216 |
+
)
|
| 217 |
+
added_time_ids = added_time_ids.to(device)
|
| 218 |
+
|
| 219 |
+
# 6. Prepare timesteps
|
| 220 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 221 |
+
timesteps = self.scheduler.timesteps
|
| 222 |
+
|
| 223 |
+
# 7. Prepare latent variables
|
| 224 |
+
num_channels_latents = self.unet.config.out_channels*2
|
| 225 |
+
latents = self.prepare_latents(
|
| 226 |
+
batch_size * num_videos_per_prompt,
|
| 227 |
+
num_frames,
|
| 228 |
+
num_channels_latents,
|
| 229 |
+
height,
|
| 230 |
+
width,
|
| 231 |
+
image_embeddings.dtype,
|
| 232 |
+
device,
|
| 233 |
+
generator,
|
| 234 |
+
latents,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# 8. Prepare guidance scale
|
| 238 |
+
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
|
| 239 |
+
guidance_scale = guidance_scale.to(device, latents.dtype)
|
| 240 |
+
guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
|
| 241 |
+
guidance_scale = _append_dims(guidance_scale, latents.ndim)
|
| 242 |
+
|
| 243 |
+
self._guidance_scale = guidance_scale
|
| 244 |
+
|
| 245 |
+
# 9. Denoising loop
|
| 246 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 247 |
+
self._num_timesteps = len(timesteps)
|
| 248 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 249 |
+
for i, t in enumerate(timesteps):
|
| 250 |
+
# expand the latents if we are doing classifier free guidance
|
| 251 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 252 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 253 |
+
|
| 254 |
+
# Concatenate image_latents over channels dimension
|
| 255 |
+
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
| 256 |
+
|
| 257 |
+
# predict the noise residual
|
| 258 |
+
noise_pred = self.unet(
|
| 259 |
+
latent_model_input,
|
| 260 |
+
t,
|
| 261 |
+
encoder_hidden_states=image_embeddings,
|
| 262 |
+
added_time_ids=added_time_ids,
|
| 263 |
+
return_dict=False,
|
| 264 |
+
)[0]
|
| 265 |
+
|
| 266 |
+
# perform guidance
|
| 267 |
+
if self.do_classifier_free_guidance:
|
| 268 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
| 269 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 270 |
+
|
| 271 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 272 |
+
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
| 273 |
+
|
| 274 |
+
if callback_on_step_end is not None:
|
| 275 |
+
callback_kwargs = {}
|
| 276 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 277 |
+
callback_kwargs[k] = locals()[k]
|
| 278 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 279 |
+
|
| 280 |
+
latents = callback_outputs.pop("latents", latents)
|
| 281 |
+
|
| 282 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 283 |
+
progress_bar.update()
|
| 284 |
+
|
| 285 |
+
if not output_type == "latent":
|
| 286 |
+
frames = self.decode_latents(latents, num_frames, decode_chunk_size)
|
| 287 |
+
frames = torch.clamp(frames, -1, 1)
|
| 288 |
+
# not sure why these codes were here
|
| 289 |
+
# for i in range(frames.shape[2]):
|
| 290 |
+
# frame = frames[:, :, i]
|
| 291 |
+
# if frame.min() > -0.9:
|
| 292 |
+
# frames[:,:,i] = torch.zeros_like(frame)
|
| 293 |
+
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
|
| 294 |
+
else:
|
| 295 |
+
frames = latents
|
| 296 |
+
|
| 297 |
+
# cast back to fp16 if needed
|
| 298 |
+
if needs_upcasting:
|
| 299 |
+
self.vae.to(dtype=torch.float16)
|
| 300 |
+
self.maybe_free_model_hooks()
|
| 301 |
+
|
| 302 |
+
if not return_dict:
|
| 303 |
+
return frames
|
| 304 |
+
|
| 305 |
+
return StableVideoDiffusionPipelineOutput(frames=frames)
|
src/preprocess/README.md
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Video Dataset Processing Tools
|
| 2 |
+
|
| 3 |
+
This directory contains tools for processing and filtering video datasets. There are two main types of tools:
|
| 4 |
+
|
| 5 |
+
1. Dataset Preprocessing Tools (`preprocess_*.py`)
|
| 6 |
+
2. Dataset Filtering Tool (`filter_dataset_tool.py`)
|
| 7 |
+
|
| 8 |
+
## Dataset Preprocessing Tools
|
| 9 |
+
|
| 10 |
+
These scripts process raw video datasets (DADA2000, CAP, and Russia Car Crash) by:
|
| 11 |
+
- Extracting frames at specified FPS
|
| 12 |
+
- Cropping frames to desired dimensions
|
| 13 |
+
- Generating object detection labels
|
| 14 |
+
- Creating train/val splits
|
| 15 |
+
|
| 16 |
+
### Usage
|
| 17 |
+
|
| 18 |
+
Basic usage with default settings:
|
| 19 |
+
```bash
|
| 20 |
+
# For DADA2000 dataset
|
| 21 |
+
python preprocess_dada_dataset.py
|
| 22 |
+
|
| 23 |
+
# For CAP dataset
|
| 24 |
+
python preprocess_cap_dataset.py
|
| 25 |
+
|
| 26 |
+
# For Russia Car Crash dataset
|
| 27 |
+
python preprocess_russia_dataset.py
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
Advanced usage with custom settings:
|
| 31 |
+
```bash
|
| 32 |
+
python preprocess_dada_dataset.py \
|
| 33 |
+
--dataset_root /path/to/datasets \
|
| 34 |
+
--dataset_dir /path/to/raw/dataset \
|
| 35 |
+
--out_directory /path/to/output \
|
| 36 |
+
--out_fps 15 \
|
| 37 |
+
--skip_extraction \
|
| 38 |
+
--skip_labels \
|
| 39 |
+
--skip_split
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
### Common Arguments
|
| 43 |
+
- `--dataset_root`: Root directory for datasets
|
| 44 |
+
- `--dataset_dir`: Directory containing the raw dataset
|
| 45 |
+
- `--out_directory`: Output directory (defaults to {dataset_root}/dataset_name)
|
| 46 |
+
- `--skip_extraction`: Skip frame extraction step
|
| 47 |
+
- `--skip_labels`: Skip label generation step
|
| 48 |
+
- `--skip_split`: Skip train/val split step
|
| 49 |
+
|
| 50 |
+
### Dataset-Specific Arguments
|
| 51 |
+
- DADA2000:
|
| 52 |
+
- `--out_fps`: Output frames per second (default: 12)
|
| 53 |
+
- CAP:
|
| 54 |
+
- `--reverse`: Process samples in reverse order
|
| 55 |
+
- Russia:
|
| 56 |
+
- `--process_train`: Process training set (default is validation set only)
|
| 57 |
+
|
| 58 |
+
## Dataset Filtering Tool
|
| 59 |
+
|
| 60 |
+
A tool for manually reviewing and filtering video datasets. It provides an interactive interface to review video frames and mark them as high quality or rejected. The tool can also automatically detect upscaled videos and scene changes to help with the filtering process.
|
| 61 |
+
|
| 62 |
+
### Features
|
| 63 |
+
- Interactive video frame review with keyboard controls
|
| 64 |
+
- Automatic detection of upscaled videos
|
| 65 |
+
- Scene change detection
|
| 66 |
+
- Caching support for faster processing
|
| 67 |
+
- Support for both single-category and multi-category datasets
|
| 68 |
+
|
| 69 |
+
### Usage
|
| 70 |
+
|
| 71 |
+
Basic usage with default settings:
|
| 72 |
+
```bash
|
| 73 |
+
python filter_dataset_tool.py --dataset_name my_dataset
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
Advanced usage with all features enabled:
|
| 77 |
+
```bash
|
| 78 |
+
python filter_dataset_tool.py \
|
| 79 |
+
--dataset_name my_dataset \
|
| 80 |
+
--start_idx 0 \
|
| 81 |
+
--data_dir ./custom/path/to/images \
|
| 82 |
+
--output_root ./custom/output/path \
|
| 83 |
+
--use_cache
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### Keyboard Controls
|
| 87 |
+
- `w`: Next frame
|
| 88 |
+
- `s`: Previous frame
|
| 89 |
+
- `d`: Next video
|
| 90 |
+
- `a`: Previous video
|
| 91 |
+
- `r`: Reject video
|
| 92 |
+
- `h`: Mark as high quality
|
| 93 |
+
- `p`: Increase playback speed
|
| 94 |
+
- `l`: Decrease playback speed
|
| 95 |
+
- `ESC`: Exit
|
| 96 |
+
|
| 97 |
+
### Command Line Arguments
|
| 98 |
+
- `--dataset_name`: Name of the dataset directory (required)
|
| 99 |
+
- `--start_idx`: Starting index for video review (default: 0)
|
| 100 |
+
- `--data_dir`: Custom data directory path (default: ./{dataset_name}/images)
|
| 101 |
+
- `--output_root`: Custom output root directory (default: ./{dataset_name})
|
| 102 |
+
- `--disable_sort_by_upsample`: Disable sorting by upsampling factor
|
| 103 |
+
- `--disable_check_scene_changes`: Disable scene change detection
|
| 104 |
+
- `--single_category`: Process videos from a single category directory
|
| 105 |
+
- `--use_cache`: Use cache to speed up processing
|
src/preprocess/filter_dataset_tool.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
from time import time
|
| 6 |
+
import glob
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import scenedetect as sd
|
| 9 |
+
import argparse
|
| 10 |
+
|
| 11 |
+
# Load existing JSON data if available
|
| 12 |
+
def load_json(filename):
|
| 13 |
+
if os.path.exists(filename):
|
| 14 |
+
with open(filename, "r") as f:
|
| 15 |
+
return json.load(f)
|
| 16 |
+
return []
|
| 17 |
+
|
| 18 |
+
def save_json(filename, data):
|
| 19 |
+
with open(filename, "w") as f:
|
| 20 |
+
json.dump(data, f, indent=4)
|
| 21 |
+
|
| 22 |
+
def estimate_upsizing_factor(image_path):
|
| 23 |
+
"""Estimate how much an image was upsized before being resized to 720x1280"""
|
| 24 |
+
|
| 25 |
+
# Load image in grayscale
|
| 26 |
+
img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
|
| 27 |
+
|
| 28 |
+
if img is None:
|
| 29 |
+
print(f"Error loading image: {image_path}")
|
| 30 |
+
return None
|
| 31 |
+
|
| 32 |
+
# Compute the 2D Fourier Transform
|
| 33 |
+
f = np.fft.fft2(img)
|
| 34 |
+
fshift = np.fft.fftshift(f) # Center the low frequencies
|
| 35 |
+
magnitude_spectrum = np.abs(fshift)
|
| 36 |
+
|
| 37 |
+
# Compute high-frequency energy
|
| 38 |
+
h, w = img.shape
|
| 39 |
+
cx, cy = w // 2, h // 2 # Center of the image
|
| 40 |
+
radius = min(cx, cy) // 4 # Define a region for high frequencies
|
| 41 |
+
|
| 42 |
+
# Mask low frequencies (keep only high frequencies)
|
| 43 |
+
mask = np.zeros((h, w), np.uint8)
|
| 44 |
+
cv2.circle(mask, (cx, cy), radius, 1, thickness=-1)
|
| 45 |
+
high_freq_energy = np.sum(magnitude_spectrum * (1 - mask))
|
| 46 |
+
|
| 47 |
+
# Normalize energy by image size
|
| 48 |
+
energy_score = high_freq_energy / (h * w)
|
| 49 |
+
|
| 50 |
+
# Estimate how much the image was upscaled
|
| 51 |
+
upsize_factor = 1 / (1 + energy_score) # Inverse relation: lower energy → more upscaling
|
| 52 |
+
|
| 53 |
+
return upsize_factor
|
| 54 |
+
|
| 55 |
+
def check_upsample(video_paths, output_root, use_cache=True):
|
| 56 |
+
t = time()
|
| 57 |
+
|
| 58 |
+
if use_cache:
|
| 59 |
+
cache_file = f"{output_root}/upsample_scores.json"
|
| 60 |
+
cached_data = load_json(cache_file)
|
| 61 |
+
|
| 62 |
+
results = {}
|
| 63 |
+
num_frames = 5
|
| 64 |
+
for src_images in tqdm(video_paths, desc="Computing upscale"):
|
| 65 |
+
vid_name = src_images.split('/')[-1]
|
| 66 |
+
if use_cache and vid_name in cached_data:
|
| 67 |
+
results[src_images] = cached_data[vid_name]
|
| 68 |
+
continue
|
| 69 |
+
|
| 70 |
+
all_images = sorted(glob.glob(f"{src_images}/*.jpg"))
|
| 71 |
+
|
| 72 |
+
if len(all_images) < 5:
|
| 73 |
+
continue
|
| 74 |
+
|
| 75 |
+
frame_indices = np.linspace(0, len(all_images) - 1, num_frames).astype(int)
|
| 76 |
+
|
| 77 |
+
vid_scores = []
|
| 78 |
+
for frame_idx in frame_indices:
|
| 79 |
+
image_path = all_images[frame_idx]
|
| 80 |
+
|
| 81 |
+
upsize_factor = estimate_upsizing_factor(image_path)
|
| 82 |
+
# print(image_dir, upsize_factor)
|
| 83 |
+
vid_scores.append(upsize_factor)
|
| 84 |
+
|
| 85 |
+
results[src_images] = np.median(vid_scores).item()
|
| 86 |
+
|
| 87 |
+
sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)
|
| 88 |
+
sorted_results = {k: v for k, v in sorted_results}
|
| 89 |
+
|
| 90 |
+
if use_cache:
|
| 91 |
+
sorted_vids_by_names = {k.split('/')[-1]: v for k, v in sorted_results.items()}
|
| 92 |
+
save_json(cache_file, sorted_vids_by_names)
|
| 93 |
+
|
| 94 |
+
# print(f"Done in {time()-t:.2f}s")
|
| 95 |
+
return sorted_results
|
| 96 |
+
|
| 97 |
+
def detect_scenes(image_folder, threshold=27.0):
|
| 98 |
+
"""Detects scene changes in a folder of images using PySceneDetect."""
|
| 99 |
+
image_files = [os.path.join(image_folder, f) for f in sorted(os.listdir(image_folder)) if f.lower().endswith(('.jpg', '.jpeg'))]
|
| 100 |
+
detector = sd.detectors.ContentDetector(threshold=threshold)
|
| 101 |
+
scene_list = []
|
| 102 |
+
prev_frame = None
|
| 103 |
+
frame_num = 0
|
| 104 |
+
|
| 105 |
+
for image_idx in range(0, len(image_files), 2): # Skip frames to go faster
|
| 106 |
+
image_file = image_files[image_idx]
|
| 107 |
+
frame = cv2.imread(image_file)
|
| 108 |
+
if frame is None:
|
| 109 |
+
continue
|
| 110 |
+
|
| 111 |
+
frame_num += 1
|
| 112 |
+
if prev_frame is not None:
|
| 113 |
+
if detector.process_frame(frame_num, frame):
|
| 114 |
+
scene_list.append(frame_num)
|
| 115 |
+
|
| 116 |
+
prev_frame = frame
|
| 117 |
+
|
| 118 |
+
return scene_list
|
| 119 |
+
|
| 120 |
+
def scan_scene_changes(video_paths, output_root, use_cache=True):
|
| 121 |
+
|
| 122 |
+
if use_cache:
|
| 123 |
+
cache_file = f"{output_root}/scene_changes.json"
|
| 124 |
+
cached_data = load_json(cache_file)
|
| 125 |
+
|
| 126 |
+
all_scene_change_vids = []
|
| 127 |
+
scene_changes_by_vid_name = {}
|
| 128 |
+
for folder_path in tqdm(video_paths, desc="Detecting scene changes"):
|
| 129 |
+
|
| 130 |
+
vid_name = folder_path.split('/')[-1]
|
| 131 |
+
if use_cache and vid_name in cached_data:
|
| 132 |
+
scene_changes = cached_data[vid_name]
|
| 133 |
+
else:
|
| 134 |
+
scene_changes = detect_scenes(folder_path)
|
| 135 |
+
|
| 136 |
+
scene_changes_by_vid_name[vid_name] = scene_changes
|
| 137 |
+
|
| 138 |
+
if len(scene_changes) > 0:
|
| 139 |
+
# print(f"{folder_path.split('/')[-1]} scene changes:", scene_changes)
|
| 140 |
+
all_scene_change_vids.append(folder_path)
|
| 141 |
+
|
| 142 |
+
if use_cache:
|
| 143 |
+
save_json(cache_file, scene_changes_by_vid_name)
|
| 144 |
+
|
| 145 |
+
print("Scene change vids:", len(all_scene_change_vids))
|
| 146 |
+
return all_scene_change_vids
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def sort_tool(video_folders, rejected_file, highquality_file, start_idx=0):
|
| 150 |
+
|
| 151 |
+
rejected_videos = load_json(rejected_file)
|
| 152 |
+
highquality_videos = load_json(highquality_file)
|
| 153 |
+
|
| 154 |
+
rejected_videos_count = 0
|
| 155 |
+
for video_path in video_folders:
|
| 156 |
+
video_name = video_path.split("/")[-1]
|
| 157 |
+
if video_name in rejected_videos:
|
| 158 |
+
rejected_videos_count += 1
|
| 159 |
+
print(f"{rejected_videos_count}/{len(video_folders)} videos already rejected in this set")
|
| 160 |
+
|
| 161 |
+
video_idx = start_idx
|
| 162 |
+
frame_idx = 0
|
| 163 |
+
fps = 12
|
| 164 |
+
last_action_next = True
|
| 165 |
+
|
| 166 |
+
while True:
|
| 167 |
+
video_path = video_folders[video_idx]
|
| 168 |
+
video_name = video_path.split("/")[-1]
|
| 169 |
+
image_files = sorted([f for f in os.listdir(video_path) if f.endswith(".jpg")])
|
| 170 |
+
|
| 171 |
+
if not image_files:
|
| 172 |
+
print(f"No images found in {video_name}")
|
| 173 |
+
if last_action_next:
|
| 174 |
+
video_idx = (video_idx + 1) % len(video_folders)
|
| 175 |
+
else:
|
| 176 |
+
video_idx = (video_idx - 1) % len(video_folders)
|
| 177 |
+
continue
|
| 178 |
+
|
| 179 |
+
if video_name in rejected_videos or video_name in highquality_videos:
|
| 180 |
+
print(f"{video_name} already filtered")
|
| 181 |
+
if last_action_next:
|
| 182 |
+
video_idx = (video_idx + 1) % len(video_folders)
|
| 183 |
+
else:
|
| 184 |
+
video_idx = (video_idx - 1) % len(video_folders)
|
| 185 |
+
continue
|
| 186 |
+
|
| 187 |
+
frame_idx = 0
|
| 188 |
+
playing = True
|
| 189 |
+
paused = False
|
| 190 |
+
|
| 191 |
+
while playing:
|
| 192 |
+
frame_path = os.path.join(video_path, image_files[frame_idx])
|
| 193 |
+
frame = cv2.imread(frame_path)
|
| 194 |
+
|
| 195 |
+
if frame is None:
|
| 196 |
+
print(f"Failed to load {frame_path}")
|
| 197 |
+
continue
|
| 198 |
+
|
| 199 |
+
display_text = f"Video: {video_name} ({video_idx}/{len(video_folders)}) | Frame: {frame_idx + 1}/{len(image_files)} | fps: {fps}"
|
| 200 |
+
cv2.putText(frame, display_text, (20, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
|
| 201 |
+
cv2.imshow("Video Reviewer", frame)
|
| 202 |
+
|
| 203 |
+
key = cv2.waitKey(int(1000 / fps)) # 12 FPS
|
| 204 |
+
|
| 205 |
+
if key == ord('w'): # Next frame
|
| 206 |
+
frame_idx = min(len(image_files)-1, frame_idx + 1)
|
| 207 |
+
paused = True
|
| 208 |
+
elif key == ord('s'): # Previous frame
|
| 209 |
+
frame_idx = max(0, frame_idx - 1)
|
| 210 |
+
paused = True
|
| 211 |
+
elif key == ord('d'): # Next video
|
| 212 |
+
video_idx = (video_idx + 1) % len(video_folders)
|
| 213 |
+
last_action_next = True
|
| 214 |
+
break
|
| 215 |
+
elif key == ord('a'): # Previous video
|
| 216 |
+
video_idx = (video_idx - 1) % len(video_folders)
|
| 217 |
+
last_action_next = False
|
| 218 |
+
break
|
| 219 |
+
elif key == ord('r'): # Reject video
|
| 220 |
+
if video_name not in rejected_videos:
|
| 221 |
+
rejected_videos.append(video_name)
|
| 222 |
+
save_json(rejected_file, rejected_videos)
|
| 223 |
+
print(f"Rejected: {video_name}")
|
| 224 |
+
video_idx = (video_idx + 1) % len(video_folders)
|
| 225 |
+
break
|
| 226 |
+
elif key == ord('h'): # Mark as high quality
|
| 227 |
+
if video_name not in highquality_videos:
|
| 228 |
+
highquality_videos.append(video_name)
|
| 229 |
+
save_json(highquality_file, highquality_videos)
|
| 230 |
+
print(f"High Quality: {video_name}")
|
| 231 |
+
video_idx = (video_idx + 1) % len(video_folders)
|
| 232 |
+
break
|
| 233 |
+
elif key == ord('p'): # Increase fps
|
| 234 |
+
fps += 1
|
| 235 |
+
elif key == ord('l'): # Lower fps
|
| 236 |
+
fps = max(1, fps - 1)
|
| 237 |
+
elif key == 27: # ESC to exit
|
| 238 |
+
playing = False
|
| 239 |
+
break
|
| 240 |
+
|
| 241 |
+
if not paused:
|
| 242 |
+
frame_idx = (frame_idx + 1) % len(image_files)
|
| 243 |
+
|
| 244 |
+
if key == 27:
|
| 245 |
+
print(f"Last video: {video_name} ({video_idx})")
|
| 246 |
+
break
|
| 247 |
+
|
| 248 |
+
cv2.destroyAllWindows()
|
| 249 |
+
|
| 250 |
+
def collect_all_videos(data_dir, single_category=False):
|
| 251 |
+
all_video_paths = []
|
| 252 |
+
if single_category:
|
| 253 |
+
all_video_paths = sorted(glob.glob(f"{data_dir}/*"))
|
| 254 |
+
else:
|
| 255 |
+
for category in sorted(os.listdir(data_dir)):
|
| 256 |
+
category_path = os.path.join(data_dir, category)
|
| 257 |
+
if os.path.isdir(category_path):
|
| 258 |
+
all_video_paths.extend(sorted(glob.glob(f"{category_path}/*")))
|
| 259 |
+
return all_video_paths
|
| 260 |
+
|
| 261 |
+
if __name__ == "__main__":
|
| 262 |
+
parser = argparse.ArgumentParser(description='Filter and sort video dataset')
|
| 263 |
+
parser.add_argument('--dataset_name', type=str, required=True,
|
| 264 |
+
help='Name of the dataset directory')
|
| 265 |
+
parser.add_argument('--start_idx', type=int, default=0,
|
| 266 |
+
help='Starting index for video review')
|
| 267 |
+
parser.add_argument('--data_dir', type=str, default=None,
|
| 268 |
+
help='Custom data directory path (defaults to ./{dataset_name}/images)')
|
| 269 |
+
parser.add_argument('--output_root', type=str, default=None,
|
| 270 |
+
help='Custom output root directory (defaults to ./{dataset_name})')
|
| 271 |
+
parser.add_argument('--disable_sort_by_upsample', action='store_true',
|
| 272 |
+
help='Disable sorting videos by upsampling factor')
|
| 273 |
+
parser.add_argument('--disable_check_scene_changes', action='store_true',
|
| 274 |
+
help='Disable checking for scene changes in videos')
|
| 275 |
+
parser.add_argument('--single_category', action='store_true',
|
| 276 |
+
help='Process videos from a single category directory')
|
| 277 |
+
parser.add_argument('--use_cache', action='store_true',
|
| 278 |
+
help='Use cache to speed up processing')
|
| 279 |
+
args = parser.parse_args()
|
| 280 |
+
|
| 281 |
+
# Set default paths if not specified
|
| 282 |
+
if args.data_dir is None:
|
| 283 |
+
args.data_dir = f"./{args.dataset_name}/images"
|
| 284 |
+
if args.output_root is None:
|
| 285 |
+
args.output_root = f"./{args.dataset_name}"
|
| 286 |
+
|
| 287 |
+
# Output JSON files
|
| 288 |
+
rejected_file = f"{args.output_root}/rejected.json"
|
| 289 |
+
auto_low_quality = f"{args.output_root}/auto_low_quality.json"
|
| 290 |
+
highquality_file = f"{args.output_root}/highquality.json"
|
| 291 |
+
|
| 292 |
+
all_video_paths = collect_all_videos(args.data_dir, args.single_category)
|
| 293 |
+
|
| 294 |
+
if not args.disable_sort_by_upsample:
|
| 295 |
+
sorted_vids = check_upsample(all_video_paths, args.output_root, use_cache=args.use_cache)
|
| 296 |
+
video_folders = list(sorted_vids.keys())
|
| 297 |
+
|
| 298 |
+
# Save the worst to file
|
| 299 |
+
# auto_reject_vids = [v.split('/')[-1] for v in video_folders[:2000]]
|
| 300 |
+
# save_json(auto_low_quality, auto_reject_vids)
|
| 301 |
+
else:
|
| 302 |
+
video_folders = all_video_paths
|
| 303 |
+
|
| 304 |
+
# Prepend scene change samples
|
| 305 |
+
if not args.disable_check_scene_changes:
|
| 306 |
+
new_video_folders = []
|
| 307 |
+
scene_change_vids = scan_scene_changes(all_video_paths, args.output_root, use_cache=args.use_cache)
|
| 308 |
+
new_video_folders.extend(scene_change_vids)
|
| 309 |
+
for vid_name in video_folders:
|
| 310 |
+
if vid_name not in new_video_folders:
|
| 311 |
+
new_video_folders.append(vid_name)
|
| 312 |
+
video_folders = new_video_folders
|
| 313 |
+
|
| 314 |
+
# Start tool
|
| 315 |
+
sort_tool(video_folders, rejected_file, highquality_file, start_idx=args.start_idx)
|
src/preprocess/preprocess_cap_dataset.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import json
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from glob import glob
|
| 6 |
+
import argparse
|
| 7 |
+
|
| 8 |
+
from yolo_sam import YoloSamProcessor
|
| 9 |
+
|
| 10 |
+
def load_json(filename):
|
| 11 |
+
if os.path.exists(filename):
|
| 12 |
+
with open(filename, "r") as f:
|
| 13 |
+
return json.load(f)
|
| 14 |
+
print(filename, "not found")
|
| 15 |
+
return []
|
| 16 |
+
|
| 17 |
+
def create_video(sample, video_path):
|
| 18 |
+
video_filename = f"{video_path}.mp4"
|
| 19 |
+
FPS = 12
|
| 20 |
+
frame_size = (1056, 660)#(512, 320)
|
| 21 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 22 |
+
video_writer_out = cv2.VideoWriter(video_filename, fourcc, FPS, frame_size)
|
| 23 |
+
|
| 24 |
+
for img in sample:
|
| 25 |
+
video_writer_out.write(img)
|
| 26 |
+
|
| 27 |
+
video_writer_out.release()
|
| 28 |
+
print(f"Video saved: {video_filename}")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def crop_images(images_dir_path, output_dir_path, crop_extents=None):
|
| 32 |
+
"""
|
| 33 |
+
Crop frames
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
all_images = sorted(glob(f"{images_dir_path}/*.jpg"))
|
| 37 |
+
total_frames = len(all_images)
|
| 38 |
+
sample_image = cv2.imread(all_images[0])
|
| 39 |
+
sample_name = str(int(images_dir_path.split('/')[-2])).zfill(5)
|
| 40 |
+
sample_category = images_dir_path.split('/')[-3]
|
| 41 |
+
out_vid_name = f"{sample_category}_{sample_name}"
|
| 42 |
+
src_height, src_width = sample_image.shape[:2]
|
| 43 |
+
|
| 44 |
+
if crop_extents:
|
| 45 |
+
src_height, src_width = (src_height + crop_extents[1]) - crop_extents[0], (src_width + crop_extents[3]) - crop_extents[2]
|
| 46 |
+
|
| 47 |
+
# print(f"Source images '{out_vid_name}': {src_height}x{src_width}")
|
| 48 |
+
|
| 49 |
+
image_output_folder = os.path.join(output_dir_path, out_vid_name)
|
| 50 |
+
os.makedirs(image_output_folder, exist_ok=True)
|
| 51 |
+
|
| 52 |
+
out_frame_count = 0
|
| 53 |
+
# sample_test = []
|
| 54 |
+
for frame_idx in range(total_frames):
|
| 55 |
+
|
| 56 |
+
frame_path = all_images[frame_idx]
|
| 57 |
+
frame = cv2.imread(frame_path)
|
| 58 |
+
|
| 59 |
+
# Crop frame
|
| 60 |
+
if crop_extents:
|
| 61 |
+
frame = frame[crop_extents[0]:crop_extents[1], crop_extents[2]:crop_extents[3]]
|
| 62 |
+
|
| 63 |
+
# Save frame
|
| 64 |
+
out_image_name = f"{out_vid_name}_{str(frame_idx).zfill(4)}.jpg"
|
| 65 |
+
out_image_path = os.path.join(image_output_folder, out_image_name)
|
| 66 |
+
cv2.imwrite(out_image_path, frame)
|
| 67 |
+
|
| 68 |
+
# sample_test.append(frame)
|
| 69 |
+
|
| 70 |
+
out_frame_count += 1
|
| 71 |
+
|
| 72 |
+
print(f"Done '{out_vid_name}': {src_height}x{src_width}, {out_frame_count} frames")
|
| 73 |
+
|
| 74 |
+
# create_video(sample_test, "path/to/sample_test_vid", fps=6)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def extract_frames(dataset_dir, out_directory, crop_extents=None, specific_videos=None):
|
| 78 |
+
|
| 79 |
+
# NOTE: We are excluding all crashes that involve visible humans (pedestrians, cyclists, motorbikes...)
|
| 80 |
+
video_types_to_exclude = [1, 2, 3, 4, 5, 6, 37, 38, 45, 46, 47, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60]
|
| 81 |
+
|
| 82 |
+
for category_dir in sorted(os.listdir(dataset_dir)):
|
| 83 |
+
category_dir_path = os.path.join(dataset_dir, category_dir)
|
| 84 |
+
if not os.path.isdir(category_dir_path):
|
| 85 |
+
continue
|
| 86 |
+
|
| 87 |
+
# Let's filter the videos we want right away
|
| 88 |
+
vid_type = int(category_dir)
|
| 89 |
+
if int(vid_type) in video_types_to_exclude:
|
| 90 |
+
continue
|
| 91 |
+
|
| 92 |
+
for vid_name in tqdm(sorted(os.listdir(category_dir_path))):
|
| 93 |
+
if specific_videos is not None and vid_name not in specific_videos:
|
| 94 |
+
continue
|
| 95 |
+
|
| 96 |
+
images_dir_path = os.path.join(category_dir_path, vid_name, "images")
|
| 97 |
+
out_path = os.path.join(out_directory, "images", category_dir)
|
| 98 |
+
crop_images(images_dir_path, out_path, crop_extents=crop_extents)
|
| 99 |
+
|
| 100 |
+
print("Extraction complete.")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def generate_labels(out_directory, vid_names=None, subdir="", in_directory=None, reverse_order=False):
|
| 104 |
+
label_output_folder = os.path.join(out_directory, "labels", subdir)
|
| 105 |
+
os.makedirs(label_output_folder, exist_ok=True)
|
| 106 |
+
|
| 107 |
+
# Checkpoint paths
|
| 108 |
+
yolo_ckpt = "yolov8x.pt" # Will auto download with utltralytics
|
| 109 |
+
sam2_ckpt = "/network/scratch/x/xuolga/sam2/checkpoints/sam2.1_hiera_base_plus.pt"
|
| 110 |
+
sam2_cfg = "./configs/sam2.1/sam2.1_hiera_b+.yaml"
|
| 111 |
+
yolo_sam = YoloSamProcessor(yolo_ckpt, sam2_ckpt, sam2_cfg)
|
| 112 |
+
|
| 113 |
+
samples_run = 0
|
| 114 |
+
|
| 115 |
+
src_directory = in_directory if in_directory is not None else out_directory
|
| 116 |
+
video_dir_root = os.path.join(src_directory, "images", subdir)
|
| 117 |
+
for category in sorted(os.listdir(video_dir_root), reverse=reverse_order):
|
| 118 |
+
category_root = os.path.join(video_dir_root, category)
|
| 119 |
+
for video_name in tqdm(sorted(os.listdir(category_root), reverse=reverse_order)):
|
| 120 |
+
if vid_names is not None and video_name not in vid_names:
|
| 121 |
+
continue
|
| 122 |
+
|
| 123 |
+
video_dir = os.path.join(category_root, video_name)
|
| 124 |
+
if len(os.listdir(video_dir)) == 0:
|
| 125 |
+
print("Empty video dir:", video_dir)
|
| 126 |
+
continue
|
| 127 |
+
|
| 128 |
+
# Skip if label file already exists
|
| 129 |
+
out_label_path = os.path.join(label_output_folder, f"{video_name}.json")
|
| 130 |
+
if os.path.exists(out_label_path):
|
| 131 |
+
print(f"Skipping {video_name} - label file already exists")
|
| 132 |
+
continue
|
| 133 |
+
|
| 134 |
+
if len(os.listdir(video_dir)) > 300:
|
| 135 |
+
print(f"SKIPPING LONG VIDEO {video_name}")
|
| 136 |
+
continue
|
| 137 |
+
|
| 138 |
+
print(f"Computing bboxes for {video_name}...")
|
| 139 |
+
video_data = yolo_sam(video_dir, rel_bbox=True)
|
| 140 |
+
|
| 141 |
+
# Add metadata
|
| 142 |
+
vid_type = int(video_name.split('_')[0])
|
| 143 |
+
ego_involved = vid_type < 19 or vid_type == 61
|
| 144 |
+
final_out_data = {
|
| 145 |
+
"video_source": f"{video_name}.mp4",
|
| 146 |
+
"metadata": {
|
| 147 |
+
"ego_involved": ego_involved,
|
| 148 |
+
"accident_type": vid_type
|
| 149 |
+
},
|
| 150 |
+
"data": video_data
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
with open(out_label_path, 'w') as json_file:
|
| 154 |
+
json.dump(final_out_data, json_file, indent=1)
|
| 155 |
+
|
| 156 |
+
print("Saved label:", out_label_path)
|
| 157 |
+
|
| 158 |
+
samples_run += 1
|
| 159 |
+
if samples_run > 50:
|
| 160 |
+
print("Resetting Yolo_Sam in case of memory leak")
|
| 161 |
+
del yolo_sam
|
| 162 |
+
yolo_sam = YoloSamProcessor(yolo_ckpt, sam2_ckpt, sam2_cfg)
|
| 163 |
+
samples_run = 0
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def make_train_val_split(out_directory):
|
| 167 |
+
image_folder = os.path.join(out_directory, "images")
|
| 168 |
+
label_folder = os.path.join(out_directory, "labels")
|
| 169 |
+
|
| 170 |
+
all_image_folders = os.listdir(image_folder)
|
| 171 |
+
split_idx = int(len(all_image_folders) * 0.9)
|
| 172 |
+
|
| 173 |
+
train_split = all_image_folders[:split_idx]
|
| 174 |
+
val_split = all_image_folders[split_idx:]
|
| 175 |
+
|
| 176 |
+
os.makedirs(os.path.join(image_folder, "train"), exist_ok=True)
|
| 177 |
+
os.makedirs(os.path.join(image_folder, "val"), exist_ok=True)
|
| 178 |
+
os.makedirs(os.path.join(label_folder, "train"), exist_ok=True)
|
| 179 |
+
os.makedirs(os.path.join(label_folder, "val"), exist_ok=True)
|
| 180 |
+
|
| 181 |
+
for filename in train_split:
|
| 182 |
+
os.rename(os.path.join(image_folder, filename), os.path.join(image_folder, "train", filename))
|
| 183 |
+
os.rename(os.path.join(label_folder, f"{filename}.json"), os.path.join(label_folder, "train", f"{filename}.json"))
|
| 184 |
+
|
| 185 |
+
for filename in val_split:
|
| 186 |
+
os.rename(os.path.join(image_folder, filename), os.path.join(image_folder, "val", filename))
|
| 187 |
+
os.rename(os.path.join(label_folder, f"{filename}.json"), os.path.join(label_folder, "val", f"{filename}.json"))
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
if __name__ == "__main__":
|
| 191 |
+
parser = argparse.ArgumentParser(description='Process CAP dataset')
|
| 192 |
+
parser.add_argument('--dataset_root', type=str, required=True,
|
| 193 |
+
help='Root directory for datasets')
|
| 194 |
+
parser.add_argument('--dataset_dir', type=str, default="/network/scratch/l/luis.lara/dev/MM-AU/CAP-DATA",
|
| 195 |
+
help='Directory containing the CAP dataset')
|
| 196 |
+
parser.add_argument('--out_directory', type=str, default=None,
|
| 197 |
+
help='Output directory (defaults to {dataset_root}/cap_images_12fps)')
|
| 198 |
+
parser.add_argument('--reverse', action='store_true',
|
| 199 |
+
help='Process samples in reverse order')
|
| 200 |
+
parser.add_argument('--skip_extraction', action='store_true',
|
| 201 |
+
help='Skip frame extraction step')
|
| 202 |
+
parser.add_argument('--skip_labels', action='store_true',
|
| 203 |
+
help='Skip label generation step')
|
| 204 |
+
parser.add_argument('--skip_split', action='store_true',
|
| 205 |
+
help='Skip train/val split step')
|
| 206 |
+
args = parser.parse_args()
|
| 207 |
+
|
| 208 |
+
# Set default output directory if not specified
|
| 209 |
+
if args.out_directory is None:
|
| 210 |
+
args.out_directory = os.path.join(args.dataset_root, "cap_images_12fps")
|
| 211 |
+
|
| 212 |
+
# Extract frames from videos
|
| 213 |
+
if not args.skip_extraction:
|
| 214 |
+
cap_crop_extents = [40, -40, 128, -128] # Custom crop for CAP dataset (get ratio right)
|
| 215 |
+
extract_frames(args.dataset_dir, args.out_directory, crop_extents=cap_crop_extents, specific_videos=None)
|
| 216 |
+
|
| 217 |
+
# Create labels (run bbox detector)
|
| 218 |
+
if not args.skip_labels:
|
| 219 |
+
in_directory = os.path.join(args.dataset_root, "cap_images_12fps")
|
| 220 |
+
generate_labels(args.out_directory, vid_names=None, reverse_order=args.reverse)
|
| 221 |
+
|
| 222 |
+
# Split into train and val sets
|
| 223 |
+
if not args.skip_split:
|
| 224 |
+
make_train_val_split(args.out_directory)
|
src/preprocess/preprocess_dada_dataset.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import json
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import argparse
|
| 6 |
+
|
| 7 |
+
from yolo_sam import YoloSamProcessor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def create_video(sample, video_path):
|
| 11 |
+
video_filename = f"{video_path}.mp4"
|
| 12 |
+
FPS = 12
|
| 13 |
+
frame_size = (1056, 660)#(512, 320)
|
| 14 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 15 |
+
video_writer_out = cv2.VideoWriter(video_filename, fourcc, FPS, frame_size)
|
| 16 |
+
|
| 17 |
+
for img in sample:
|
| 18 |
+
video_writer_out.write(img)
|
| 19 |
+
|
| 20 |
+
video_writer_out.release()
|
| 21 |
+
print(f"Video saved: {video_filename}")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def downsample_and_crop_vid(video_path, output_dir, out_fps=12, crop_extents=None):
|
| 25 |
+
"""
|
| 26 |
+
Downsample fps and crop frames
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
# Load video
|
| 30 |
+
cap = cv2.VideoCapture(video_path)
|
| 31 |
+
org_fps = int(cap.get(cv2.CAP_PROP_FPS))
|
| 32 |
+
src_width, src_height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 33 |
+
original_sample_name = video_path.split('/')[-1].split('.')[0]
|
| 34 |
+
category = original_sample_name.split("_")[0]
|
| 35 |
+
vid_num = '90'+str(int(original_sample_name.split("_")[-1])).zfill(3) # NOTE: We prepend a '90' to differentiate between DADA and CAP samples
|
| 36 |
+
sample_name = f"{category}_{vid_num}"
|
| 37 |
+
if crop_extents:
|
| 38 |
+
src_width, src_height = (src_width + crop_extents[3]) - crop_extents[2], (src_height + crop_extents[1]) - crop_extents[0]
|
| 39 |
+
|
| 40 |
+
print(f"Source video '{sample_name}': {src_width}x{src_height}, fps={org_fps}")
|
| 41 |
+
|
| 42 |
+
total_frames = 0
|
| 43 |
+
target_period = 1/out_fps
|
| 44 |
+
last_frame_time = target_period
|
| 45 |
+
out_frame_count = 0
|
| 46 |
+
|
| 47 |
+
image_output_folder = os.path.join(output_dir, "images", category, sample_name)
|
| 48 |
+
os.makedirs(image_output_folder, exist_ok=True)
|
| 49 |
+
|
| 50 |
+
# sample_test = []
|
| 51 |
+
while cap.isOpened():
|
| 52 |
+
success, frame = cap.read()
|
| 53 |
+
|
| 54 |
+
if not success:
|
| 55 |
+
break
|
| 56 |
+
|
| 57 |
+
# Extract frames according to desired fps
|
| 58 |
+
if last_frame_time >= target_period:
|
| 59 |
+
out_frame_count += 1
|
| 60 |
+
last_frame_time = (last_frame_time - target_period)
|
| 61 |
+
|
| 62 |
+
# Crop frame
|
| 63 |
+
if crop_extents:
|
| 64 |
+
frame = frame[:, crop_extents[2] : crop_extents[3]]
|
| 65 |
+
|
| 66 |
+
# Save frame
|
| 67 |
+
out_image_name = f"{sample_name}_{str(total_frames).zfill(4)}.jpg"
|
| 68 |
+
out_image_path = os.path.join(image_output_folder, out_image_name)
|
| 69 |
+
cv2.imwrite(out_image_path, frame)
|
| 70 |
+
|
| 71 |
+
# sample_test.append(frame)
|
| 72 |
+
|
| 73 |
+
total_frames += 1
|
| 74 |
+
last_frame_time += 1/org_fps
|
| 75 |
+
|
| 76 |
+
print(f"Done '{sample_name}': {out_frame_count} frames, fps: {out_frame_count / (total_frames*1/org_fps)}")
|
| 77 |
+
cap.release()
|
| 78 |
+
# create_video(sample_test, "/path/to/sample_test_vid")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def extract_frames(dataset_dir, out_directory, crop_extents=None, out_fps=12):
|
| 82 |
+
dataset_video_dir = os.path.join(dataset_dir)
|
| 83 |
+
|
| 84 |
+
# NOTE: We are excluding all crashes that involve visible humans (pedestrians, cyclists, motorbikes...)
|
| 85 |
+
video_types_to_exclude = [1, 2, 3, 4, 5, 6, 37, 38, 45, 46, 47, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60]
|
| 86 |
+
|
| 87 |
+
for filename in tqdm(os.listdir(dataset_video_dir)):
|
| 88 |
+
if filename.split('.')[-1] != "mp4":
|
| 89 |
+
continue
|
| 90 |
+
|
| 91 |
+
# Let's filter the videos we want right away
|
| 92 |
+
vid_type = filename.split('_')[0]
|
| 93 |
+
if int(vid_type) in video_types_to_exclude:
|
| 94 |
+
continue
|
| 95 |
+
|
| 96 |
+
video_path = os.path.join(dataset_video_dir, filename)
|
| 97 |
+
downsample_and_crop_vid(video_path, out_directory, out_fps=out_fps, crop_extents=crop_extents)
|
| 98 |
+
print("Extraction complete.")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def generate_labels(out_directory, vid_names=None, subdir=""):
|
| 102 |
+
label_output_folder = os.path.join(out_directory, "labels", subdir)
|
| 103 |
+
os.makedirs(label_output_folder, exist_ok=True)
|
| 104 |
+
|
| 105 |
+
# Checkpoint paths
|
| 106 |
+
yolo_ckpt = "yolov8x.pt" # Will auto download with utltralytics
|
| 107 |
+
sam2_ckpt = "/network/scratch/x/xuolga/sam2/checkpoints/sam2.1_hiera_base_plus.pt"
|
| 108 |
+
sam2_cfg = "./configs/sam2.1/sam2.1_hiera_b+.yaml"
|
| 109 |
+
yolo_sam = YoloSamProcessor(yolo_ckpt, sam2_ckpt, sam2_cfg)
|
| 110 |
+
|
| 111 |
+
samples_run = 0
|
| 112 |
+
|
| 113 |
+
video_dir_root = os.path.join(out_directory, "images", subdir)
|
| 114 |
+
for category in sorted(os.listdir(video_dir_root), reverse=True):
|
| 115 |
+
category_root = os.path.join(video_dir_root, category)
|
| 116 |
+
for video_name in tqdm(os.listdir(category_root)):
|
| 117 |
+
if vid_names is not None and video_name not in vid_names:
|
| 118 |
+
continue
|
| 119 |
+
|
| 120 |
+
video_dir = os.path.join(category_root, video_name)
|
| 121 |
+
if len(os.listdir(video_dir)) == 0:
|
| 122 |
+
print("Empty video dir:", video_dir)
|
| 123 |
+
continue
|
| 124 |
+
|
| 125 |
+
if len(os.listdir(video_dir)) > 300:
|
| 126 |
+
print(f"SKIPPING LONG VIDEO {video_name}")
|
| 127 |
+
continue
|
| 128 |
+
|
| 129 |
+
# Skip if label file already exists
|
| 130 |
+
out_label_path = os.path.join(label_output_folder, f"{video_name}.json")
|
| 131 |
+
if os.path.exists(out_label_path):
|
| 132 |
+
print(f"Skipping {video_name} - label file already exists")
|
| 133 |
+
continue
|
| 134 |
+
|
| 135 |
+
video_data = yolo_sam(video_dir, rel_bbox=True)
|
| 136 |
+
|
| 137 |
+
if video_data is None:
|
| 138 |
+
print("COMPUTED VIDEO DATA IS NULL for video:", video_dir)
|
| 139 |
+
|
| 140 |
+
# Add metadata
|
| 141 |
+
vid_type = int(video_name.split('_')[0])
|
| 142 |
+
ego_involved = vid_type < 19 or vid_type == 61
|
| 143 |
+
final_out_data = {
|
| 144 |
+
"video_source": f"{video_name}.mp4",
|
| 145 |
+
"metadata": {
|
| 146 |
+
"ego_involved": ego_involved,
|
| 147 |
+
"accident_type": vid_type
|
| 148 |
+
},
|
| 149 |
+
"data": video_data
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
with open(out_label_path, 'w') as json_file:
|
| 153 |
+
json.dump(final_out_data, json_file, indent=1)
|
| 154 |
+
|
| 155 |
+
print("Saved label:", out_label_path)
|
| 156 |
+
|
| 157 |
+
samples_run += 1
|
| 158 |
+
if samples_run > 50:
|
| 159 |
+
print("Resetting Yolo_Sam in case of memory leak")
|
| 160 |
+
del yolo_sam
|
| 161 |
+
yolo_sam = YoloSamProcessor(yolo_ckpt, sam2_ckpt, sam2_cfg)
|
| 162 |
+
samples_run = 0
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def make_train_val_split(out_directory):
|
| 166 |
+
image_folder = os.path.join(out_directory, "images")
|
| 167 |
+
label_folder = os.path.join(out_directory, "labels")
|
| 168 |
+
|
| 169 |
+
all_image_folders = os.listdir(image_folder)
|
| 170 |
+
split_idx = int(len(all_image_folders) * 0.9)
|
| 171 |
+
|
| 172 |
+
train_split = all_image_folders[:split_idx]
|
| 173 |
+
val_split = all_image_folders[split_idx:]
|
| 174 |
+
|
| 175 |
+
os.makedirs(os.path.join(image_folder, "train"), exist_ok=True)
|
| 176 |
+
os.makedirs(os.path.join(image_folder, "val"), exist_ok=True)
|
| 177 |
+
os.makedirs(os.path.join(label_folder, "train"), exist_ok=True)
|
| 178 |
+
os.makedirs(os.path.join(label_folder, "val"), exist_ok=True)
|
| 179 |
+
|
| 180 |
+
for filename in train_split:
|
| 181 |
+
os.rename(os.path.join(image_folder, filename), os.path.join(image_folder, "train", filename))
|
| 182 |
+
os.rename(os.path.join(label_folder, f"{filename}.json"), os.path.join(label_folder, "train", f"{filename}.json"))
|
| 183 |
+
|
| 184 |
+
for filename in val_split:
|
| 185 |
+
os.rename(os.path.join(image_folder, filename), os.path.join(image_folder, "val", filename))
|
| 186 |
+
os.rename(os.path.join(label_folder, f"{filename}.json"), os.path.join(label_folder, "val", f"{filename}.json"))
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
if __name__ == "__main__":
|
| 190 |
+
parser = argparse.ArgumentParser(description='Process DADA2000 dataset')
|
| 191 |
+
parser.add_argument('--dataset_root', type=str, required=True,
|
| 192 |
+
help='Root directory for datasets')
|
| 193 |
+
parser.add_argument('--dataset_dir', type=str, required=True,
|
| 194 |
+
help='Directory containing the DADA2000 dataset')
|
| 195 |
+
parser.add_argument('--out_directory', type=str, default=None,
|
| 196 |
+
help='Output directory (defaults to {dataset_root}/dada2000_images_12fps)')
|
| 197 |
+
parser.add_argument('--skip_extraction', action='store_true',
|
| 198 |
+
help='Skip frame extraction step')
|
| 199 |
+
parser.add_argument('--skip_labels', action='store_true',
|
| 200 |
+
help='Skip label generation step')
|
| 201 |
+
parser.add_argument('--skip_split', action='store_true',
|
| 202 |
+
help='Skip train/val split step')
|
| 203 |
+
parser.add_argument('--out_fps', type=int, default=12,
|
| 204 |
+
help='Output frames per second (default: 12)')
|
| 205 |
+
args = parser.parse_args()
|
| 206 |
+
|
| 207 |
+
# Set default output directory if not specified
|
| 208 |
+
if args.out_directory is None:
|
| 209 |
+
args.out_directory = os.path.join(args.dataset_root, "dada2000_images_12fps")
|
| 210 |
+
|
| 211 |
+
# Extract frames from videos
|
| 212 |
+
if not args.skip_extraction:
|
| 213 |
+
dada_crop_extents = [0, -0, 264, -264] # Custom crop for DADA2000 dataset (get ratio right)
|
| 214 |
+
extract_frames(args.dataset_dir, args.out_directory, crop_extents=dada_crop_extents, out_fps=args.out_fps)
|
| 215 |
+
|
| 216 |
+
# Create labels (run bbox detector)
|
| 217 |
+
if not args.skip_labels:
|
| 218 |
+
generate_labels(args.out_directory, vid_names=None)
|
| 219 |
+
|
| 220 |
+
# Split into train and val sets
|
| 221 |
+
if not args.skip_split:
|
| 222 |
+
make_train_val_split(args.out_directory)
|
src/preprocess/preprocess_russia_dataset.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import json
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import argparse
|
| 6 |
+
|
| 7 |
+
from yolo_sam import YoloSamProcessor
|
| 8 |
+
|
| 9 |
+
def downsample_and_crop_vid(video_path, output_dir, out_fps=7, crop_extents=None):
|
| 10 |
+
"""
|
| 11 |
+
Downsample fps and crop frames
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
# Load video
|
| 15 |
+
cap = cv2.VideoCapture(video_path)
|
| 16 |
+
org_fps = int(cap.get(cv2.CAP_PROP_FPS))
|
| 17 |
+
src_width, src_height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 18 |
+
sample_name = video_path.split('/')[-1].split('.')[0]
|
| 19 |
+
if crop_extents:
|
| 20 |
+
src_width, src_height = (src_width + crop_extents[3]) - crop_extents[2], (src_height + crop_extents[1]) - crop_extents[0]
|
| 21 |
+
|
| 22 |
+
print(f"Source video '{sample_name}': {src_width}x{src_height}, fps={org_fps}")
|
| 23 |
+
|
| 24 |
+
total_frames = 0
|
| 25 |
+
target_period = 1/out_fps
|
| 26 |
+
last_frame_time = target_period
|
| 27 |
+
out_frame_count = 0
|
| 28 |
+
|
| 29 |
+
image_output_folder = os.path.join(output_dir, "images", sample_name)
|
| 30 |
+
os.makedirs(image_output_folder, exist_ok=True)
|
| 31 |
+
|
| 32 |
+
while cap.isOpened():
|
| 33 |
+
success, frame = cap.read()
|
| 34 |
+
|
| 35 |
+
if not success:
|
| 36 |
+
break
|
| 37 |
+
|
| 38 |
+
# Extract frames according to desired fps
|
| 39 |
+
if last_frame_time >= target_period:
|
| 40 |
+
out_frame_count += 1
|
| 41 |
+
last_frame_time = (last_frame_time - target_period)
|
| 42 |
+
|
| 43 |
+
# Crop frame
|
| 44 |
+
if crop_extents:
|
| 45 |
+
frame = frame[crop_extents[0] : crop_extents[1], crop_extents[2] : crop_extents[3]]
|
| 46 |
+
|
| 47 |
+
# Save frame
|
| 48 |
+
out_image_name = f"{sample_name}_{str(total_frames).zfill(4)}.jpg"
|
| 49 |
+
out_image_path = os.path.join(image_output_folder, out_image_name)
|
| 50 |
+
cv2.imwrite(out_image_path, frame)
|
| 51 |
+
|
| 52 |
+
total_frames += 1
|
| 53 |
+
last_frame_time += 1/org_fps
|
| 54 |
+
|
| 55 |
+
print(f"Done '{sample_name}': {out_frame_count} frames, fps: {out_frame_count / (total_frames*1/org_fps)}")
|
| 56 |
+
cap.release()
|
| 57 |
+
|
| 58 |
+
def extract_frames(dataset_dir, out_directory, crop_extents=None):
|
| 59 |
+
dataset_video_dir = os.path.join(dataset_dir, "video")
|
| 60 |
+
for filename in tqdm(os.listdir(dataset_video_dir)):
|
| 61 |
+
video_path = os.path.join(dataset_video_dir, filename)
|
| 62 |
+
fps = 7
|
| 63 |
+
downsample_and_crop_vid(video_path, out_directory, out_fps=fps, crop_extents=crop_extents)
|
| 64 |
+
print("Extraction complete.")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def generate_labels(dataset_dir, out_directory, video_subdir=''):
|
| 68 |
+
label_output_folder = os.path.join(out_directory, "labels", video_subdir)
|
| 69 |
+
os.makedirs(label_output_folder, exist_ok=True)
|
| 70 |
+
|
| 71 |
+
# Checkpoint paths
|
| 72 |
+
yolo_ckpt = "yolov8x.pt" # Will auto download with utltralytics
|
| 73 |
+
sam2_ckpt = "/network/scratch/x/xuolga/sam2/checkpoints/sam2.1_hiera_base_plus.pt"
|
| 74 |
+
sam2_cfg = "./configs/sam2.1/sam2.1_hiera_b+.yaml"
|
| 75 |
+
yolo_sam = YoloSamProcessor(yolo_ckpt, sam2_ckpt, sam2_cfg)
|
| 76 |
+
|
| 77 |
+
video_dir_root = os.path.join(out_directory, "images", video_subdir)
|
| 78 |
+
# for video_name in tqdm(os.listdir(video_dir_root)):
|
| 79 |
+
for video_name in tqdm(["w10_138", "w10_94", "w1_10", "w1_46", "w2_79", "w3_17", "w6_14", "w6_44", "w6_78", "w6_94", "w7_1", "w7_14"]):
|
| 80 |
+
video_dir = os.path.join(video_dir_root, video_name)
|
| 81 |
+
if len(os.listdir(video_dir)) == 0:
|
| 82 |
+
print("Empty video dir:", video_dir)
|
| 83 |
+
continue
|
| 84 |
+
|
| 85 |
+
video_data = yolo_sam(video_dir, rel_bbox=True)
|
| 86 |
+
|
| 87 |
+
# Add metadata
|
| 88 |
+
org_dataset_labels = os.path.join(dataset_dir, "label", "json")
|
| 89 |
+
orig_label_path = os.path.join(org_dataset_labels, f"{video_name}.json")
|
| 90 |
+
with open(orig_label_path, 'r') as json_file:
|
| 91 |
+
metadata = json.load(json_file)[0]['meta_data']
|
| 92 |
+
|
| 93 |
+
final_out_data = {
|
| 94 |
+
"video_source": f"{video_name}.mp4",
|
| 95 |
+
"metadata": metadata,
|
| 96 |
+
"data": video_data
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
out_label_path = os.path.join(label_output_folder, f"{video_name}.json")
|
| 100 |
+
with open(out_label_path, 'w') as json_file:
|
| 101 |
+
json.dump(final_out_data, json_file, indent=1)
|
| 102 |
+
|
| 103 |
+
print("Saved label:", out_label_path)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def make_train_val_split(out_directory):
|
| 107 |
+
image_folder = os.path.join(out_directory, "images")
|
| 108 |
+
label_folder = os.path.join(out_directory, "labels")
|
| 109 |
+
|
| 110 |
+
all_image_folders = os.listdir(image_folder)
|
| 111 |
+
split_idx = int(len(all_image_folders) * 0.9)
|
| 112 |
+
|
| 113 |
+
train_split = all_image_folders[:split_idx]
|
| 114 |
+
val_split = all_image_folders[split_idx:]
|
| 115 |
+
|
| 116 |
+
os.makedirs(os.path.join(image_folder, "train"), exist_ok=True)
|
| 117 |
+
os.makedirs(os.path.join(image_folder, "val"), exist_ok=True)
|
| 118 |
+
os.makedirs(os.path.join(label_folder, "train"), exist_ok=True)
|
| 119 |
+
os.makedirs(os.path.join(label_folder, "val"), exist_ok=True)
|
| 120 |
+
|
| 121 |
+
for filename in train_split:
|
| 122 |
+
os.rename(os.path.join(image_folder, filename), os.path.join(image_folder, "train", filename))
|
| 123 |
+
os.rename(os.path.join(label_folder, f"{filename}.json"), os.path.join(label_folder, "train", f"{filename}.json"))
|
| 124 |
+
|
| 125 |
+
for filename in val_split:
|
| 126 |
+
os.rename(os.path.join(image_folder, filename), os.path.join(image_folder, "val", filename))
|
| 127 |
+
os.rename(os.path.join(label_folder, f"{filename}.json"), os.path.join(label_folder, "val", f"{filename}.json"))
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
if __name__ == "__main__":
|
| 131 |
+
parser = argparse.ArgumentParser(description='Process Russia Car Crash dataset')
|
| 132 |
+
parser.add_argument('--dataset_root', type=str, required=True,
|
| 133 |
+
help='Root directory for datasets')
|
| 134 |
+
parser.add_argument('--dataset_dir', type=str, required=True,
|
| 135 |
+
help='Directory containing the Russia Car Crash dataset')
|
| 136 |
+
parser.add_argument('--out_directory', type=str, default=None,
|
| 137 |
+
help='Output directory (defaults to {dataset_root}/preprocess_russia_crash)')
|
| 138 |
+
parser.add_argument('--skip_extraction', action='store_true',
|
| 139 |
+
help='Skip frame extraction step')
|
| 140 |
+
parser.add_argument('--skip_labels', action='store_true',
|
| 141 |
+
help='Skip label generation step')
|
| 142 |
+
parser.add_argument('--skip_split', action='store_true',
|
| 143 |
+
help='Skip train/val split step')
|
| 144 |
+
parser.add_argument('--process_train', action='store_true',
|
| 145 |
+
help='Process training set (default is validation set only)')
|
| 146 |
+
args = parser.parse_args()
|
| 147 |
+
|
| 148 |
+
# Set default output directory if not specified
|
| 149 |
+
if args.out_directory is None:
|
| 150 |
+
args.out_directory = os.path.join(args.dataset_root, "preprocess_russia_crash")
|
| 151 |
+
|
| 152 |
+
# Custom crop for Russia dataset (hide largest watermarks)
|
| 153 |
+
src_height, src_width = 986, 555
|
| 154 |
+
russia_crop_extents = [int(0.032*src_height), -int(0.198*src_height), int(0.115*src_width), -int(0.115*src_width)]
|
| 155 |
+
|
| 156 |
+
# Extract frames from videos
|
| 157 |
+
if not args.skip_extraction:
|
| 158 |
+
extract_frames(args.dataset_dir, args.out_directory, crop_extents=russia_crop_extents)
|
| 159 |
+
|
| 160 |
+
# Create labels (run bbox detector)
|
| 161 |
+
if not args.skip_labels:
|
| 162 |
+
generate_labels(args.dataset_dir, args.out_directory, video_subdir='val')
|
| 163 |
+
if args.process_train:
|
| 164 |
+
generate_labels(args.dataset_dir, args.out_directory, video_subdir='train')
|
| 165 |
+
|
| 166 |
+
# Split into train and val sets
|
| 167 |
+
if not args.skip_split:
|
| 168 |
+
make_train_val_split(args.out_directory)
|
src/preprocess/yolo_sam.py
ADDED
|
@@ -0,0 +1,584 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import torch
|
| 6 |
+
import random
|
| 7 |
+
from itertools import combinations
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import json
|
| 10 |
+
import cv2
|
| 11 |
+
|
| 12 |
+
from ultralytics import YOLO
|
| 13 |
+
from sam2.build_sam import build_sam2_video_predictor
|
| 14 |
+
|
| 15 |
+
NUM_LOOK_BACK_FRAMES = 3
|
| 16 |
+
FPS = 12
|
| 17 |
+
|
| 18 |
+
CLASSES_TO_KEEP = { # Using YOLO ids
|
| 19 |
+
0: 'person',
|
| 20 |
+
1: 'bicycle',
|
| 21 |
+
2: 'car',
|
| 22 |
+
3: 'motorcycle',
|
| 23 |
+
5: 'bus',
|
| 24 |
+
6: 'train',
|
| 25 |
+
7: 'truck',
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def create_video_from_images(images_dir, output_video, out_fps, start_frame=None, end_frame=None):
|
| 30 |
+
|
| 31 |
+
images = sorted(os.listdir(images_dir))
|
| 32 |
+
|
| 33 |
+
img0_path = os.path.join(images_dir, images[0])
|
| 34 |
+
img0 = cv2.imread(img0_path)
|
| 35 |
+
height, width, _ = img0.shape
|
| 36 |
+
|
| 37 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 38 |
+
out = cv2.VideoWriter(output_video, fourcc, out_fps, (width, height))
|
| 39 |
+
|
| 40 |
+
for idx, frame_name in enumerate(images):
|
| 41 |
+
|
| 42 |
+
if start_frame is not None and idx < start_frame:
|
| 43 |
+
continue
|
| 44 |
+
if end_frame is not None and idx >= end_frame:
|
| 45 |
+
continue
|
| 46 |
+
|
| 47 |
+
img = cv2.imread(os.path.join(images_dir, frame_name))
|
| 48 |
+
out.write(img)
|
| 49 |
+
|
| 50 |
+
out.release()
|
| 51 |
+
print("Saved video:", output_video)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def show_mask(mask, ax, obj_id=None, random_color=False, label=None):
|
| 55 |
+
if random_color:
|
| 56 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
| 57 |
+
else:
|
| 58 |
+
cmap = plt.get_cmap("tab10")
|
| 59 |
+
cmap_idx = 0 if obj_id is None else obj_id
|
| 60 |
+
color = np.array([*cmap(cmap_idx)[:3], 0.6])
|
| 61 |
+
h, w = mask.shape[-2:]
|
| 62 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
| 63 |
+
if label is not None:
|
| 64 |
+
text_location = mask.nonzero()
|
| 65 |
+
if len(text_location[0]) > 0:
|
| 66 |
+
rand_point = random.randint(0, len(text_location[0]) - 1)
|
| 67 |
+
ax.text(text_location[2][rand_point], text_location[1][rand_point], label, color=(1, 1, 1))
|
| 68 |
+
ax.imshow(mask_image)
|
| 69 |
+
|
| 70 |
+
def show_box(box, ax, label=None, color=((1, 0.7, 0.7))):
|
| 71 |
+
x0, y0 = box[0], box[1]
|
| 72 |
+
w, h = box[2] - box[0], box[3] - box[1]
|
| 73 |
+
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor=(0, 0, 0, 0), lw=1))
|
| 74 |
+
|
| 75 |
+
if label is not None:
|
| 76 |
+
ax.text(x0 + w // 2, y0 + h // 2, label, color=color)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class TrackedObject:
|
| 80 |
+
def __init__(self, track_id, class_id, bbox, initial_frame_idx):
|
| 81 |
+
self.track_id = track_id
|
| 82 |
+
self.bbox = bbox
|
| 83 |
+
self.class_pred_counts = {class_id: 1}
|
| 84 |
+
self.initial_frame_idx = initial_frame_idx
|
| 85 |
+
|
| 86 |
+
self._top_class = class_id
|
| 87 |
+
|
| 88 |
+
@property
|
| 89 |
+
def class_id(self):
|
| 90 |
+
"""
|
| 91 |
+
The class for the object is whichever class was predicted the most for it
|
| 92 |
+
"""
|
| 93 |
+
if self._top_class is not None:
|
| 94 |
+
return self._top_class
|
| 95 |
+
|
| 96 |
+
top_class = None
|
| 97 |
+
top_count = 0
|
| 98 |
+
for class_id, count in self.class_pred_counts.items():
|
| 99 |
+
if count >= top_count:
|
| 100 |
+
top_count = count
|
| 101 |
+
top_class = class_id
|
| 102 |
+
|
| 103 |
+
self._top_class = top_class
|
| 104 |
+
return top_class
|
| 105 |
+
|
| 106 |
+
def new_pred(self, class_id):
|
| 107 |
+
if class_id not in self.class_pred_counts:
|
| 108 |
+
self.class_pred_counts[class_id] = 0
|
| 109 |
+
self.class_pred_counts[class_id] += 1
|
| 110 |
+
self._top_class = None # Remove cached top_class
|
| 111 |
+
|
| 112 |
+
def __repr__(self):
|
| 113 |
+
return f"id:{self.track_id}, class:{self.class_id}, bbox:{self.bbox}, init_frame:{self.initial_frame_idx}"
|
| 114 |
+
|
| 115 |
+
def __str__(self):
|
| 116 |
+
return f"id:{self.track_id}, class:{self.class_id}, bbox:{self.bbox}, init_frame:{self.initial_frame_idx}"
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class YoloSamProcessor():
|
| 120 |
+
|
| 121 |
+
def __init__(self, yolo_ckpt, sam_ckpt, sam_cfg, gpu_id=0):
|
| 122 |
+
# Load models
|
| 123 |
+
|
| 124 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
| 125 |
+
self.device = f"cuda:{gpu_id}"
|
| 126 |
+
|
| 127 |
+
self.yolo_model = YOLO(yolo_ckpt)
|
| 128 |
+
self.sam_model = build_sam2_video_predictor(sam_cfg, sam_ckpt, device=self.device)
|
| 129 |
+
|
| 130 |
+
def __call__(self, video_dir, rel_bbox=True):
|
| 131 |
+
# Renaming the videos to be compatible with SAM2
|
| 132 |
+
prev_frame_paths = {}
|
| 133 |
+
src_width, src_height = None, None
|
| 134 |
+
self.num_frames = len(os.listdir(video_dir))
|
| 135 |
+
for vid in os.listdir(video_dir):
|
| 136 |
+
new_name = vid.split("_")[-1]
|
| 137 |
+
og_path = os.path.join(video_dir, vid)
|
| 138 |
+
new_path = os.path.join(video_dir, new_name)
|
| 139 |
+
os.rename(og_path, new_path)
|
| 140 |
+
prev_frame_paths[new_path] = og_path
|
| 141 |
+
|
| 142 |
+
if src_width is None:
|
| 143 |
+
img_sample = Image.open(new_path)
|
| 144 |
+
src_width, src_height = img_sample.size
|
| 145 |
+
|
| 146 |
+
self.out_data = None
|
| 147 |
+
try:
|
| 148 |
+
# YOLO model
|
| 149 |
+
self.run_yolo(video_dir)
|
| 150 |
+
self.filter_yolo_preds()
|
| 151 |
+
|
| 152 |
+
# SAM2 model
|
| 153 |
+
self.run_sam(video_dir)
|
| 154 |
+
self.filter_sam_preds()
|
| 155 |
+
|
| 156 |
+
self.extract_final_bboxes()
|
| 157 |
+
|
| 158 |
+
# Format the data
|
| 159 |
+
self.out_data = []
|
| 160 |
+
for frame_idx, frame_data in enumerate(self.final_bboxes):
|
| 161 |
+
frame_path = self.frame_path_by_idx[frame_idx]
|
| 162 |
+
frame_name = prev_frame_paths[frame_path].split('/')[-1]
|
| 163 |
+
self.out_data.append({
|
| 164 |
+
"image_source": frame_name,
|
| 165 |
+
"labels": []
|
| 166 |
+
})
|
| 167 |
+
for obj_id, bbox in frame_data.items():
|
| 168 |
+
|
| 169 |
+
bbox = bbox.tolist()
|
| 170 |
+
if rel_bbox:
|
| 171 |
+
# Save bbox coordinates as a ratio to image size
|
| 172 |
+
formatted_bbox = [bbox[0]/src_width, bbox[1]/src_height, bbox[2]/src_width, bbox[3]/src_height]
|
| 173 |
+
else:
|
| 174 |
+
# Keep in absolute coordinates
|
| 175 |
+
formatted_bbox = bbox
|
| 176 |
+
|
| 177 |
+
tracked_obj = self.initial_bboxes[obj_id]
|
| 178 |
+
|
| 179 |
+
out_label = {
|
| 180 |
+
"track_id": obj_id,
|
| 181 |
+
"name": CLASSES_TO_KEEP[tracked_obj.class_id],
|
| 182 |
+
"class": tracked_obj.class_id,
|
| 183 |
+
"box": formatted_bbox,
|
| 184 |
+
}
|
| 185 |
+
self.out_data[frame_idx]["labels"].append(out_label)
|
| 186 |
+
except Exception as e:
|
| 187 |
+
print("Yolo_Sam processor failed:", e)
|
| 188 |
+
finally:
|
| 189 |
+
# Revert back the names of the files
|
| 190 |
+
print("Restoring names in video dir after exception")
|
| 191 |
+
for new_path, old_path in prev_frame_paths.items():
|
| 192 |
+
os.rename(new_path, old_path)
|
| 193 |
+
|
| 194 |
+
return self.out_data
|
| 195 |
+
|
| 196 |
+
def run_yolo(self, video_dir):
|
| 197 |
+
self.initial_bboxes = {} # Store the first bbox for each track id
|
| 198 |
+
self.id_reassigns = {}
|
| 199 |
+
all_preds_by_track_id = []
|
| 200 |
+
|
| 201 |
+
self.frame_path_by_idx = {}
|
| 202 |
+
|
| 203 |
+
# Reset yolo's tracker for new video
|
| 204 |
+
if self.yolo_model.predictor is not None:
|
| 205 |
+
self.yolo_model.predictor.trackers[0].reset()
|
| 206 |
+
|
| 207 |
+
new_id_counter = 1
|
| 208 |
+
sorted_frames = sorted(os.listdir(video_dir))
|
| 209 |
+
for frame_idx, frame_file in enumerate(sorted_frames):
|
| 210 |
+
frame_path = os.path.join(video_dir, frame_file)
|
| 211 |
+
self.frame_path_by_idx[frame_idx] = frame_path
|
| 212 |
+
img = Image.open(frame_path)
|
| 213 |
+
|
| 214 |
+
yolo_results = self.yolo_model.track(img, persist=True, conf=0.1, verbose=False, device=self.device)
|
| 215 |
+
yolo_boxes = yolo_results[0].boxes
|
| 216 |
+
all_preds_by_track_id.append({})
|
| 217 |
+
|
| 218 |
+
# If the detection is has a new track id, then we record
|
| 219 |
+
if yolo_boxes.is_track:
|
| 220 |
+
for idx in range(len(yolo_boxes)):
|
| 221 |
+
track_id = int(yolo_boxes.id[idx].item())
|
| 222 |
+
bbox = yolo_boxes.xyxy[idx].numpy()
|
| 223 |
+
class_id = int(yolo_boxes.cls[idx].item())
|
| 224 |
+
|
| 225 |
+
tracked_obj = self.initial_bboxes.get(track_id)
|
| 226 |
+
# Check if YOLO is trying to assign a id that was already predicted but was not present in the previous frame
|
| 227 |
+
# This means YOLO lost this obj and is attempting to reassign. We will assign a new id for this as we want to
|
| 228 |
+
# trust SAM with the tracking not YOLO.
|
| 229 |
+
prev_track_id = track_id
|
| 230 |
+
if tracked_obj is not None and frame_idx > 0 and track_id not in all_preds_by_track_id[frame_idx - 1]:
|
| 231 |
+
|
| 232 |
+
# Check if id has been reassigned
|
| 233 |
+
if self.id_reassigns.get(track_id) is not None:
|
| 234 |
+
track_id = self.id_reassigns.get(track_id)
|
| 235 |
+
|
| 236 |
+
if track_id not in all_preds_by_track_id[frame_idx - 1]: # Check again because id might have changed
|
| 237 |
+
# Assign a new track id
|
| 238 |
+
track_id = 100 + new_id_counter
|
| 239 |
+
new_id_counter += 1
|
| 240 |
+
self.id_reassigns[prev_track_id] = track_id
|
| 241 |
+
tracked_obj = None
|
| 242 |
+
# print(f"Frame: {frame_idx} re-assigned id {prev_track_id}->{track_id}")
|
| 243 |
+
|
| 244 |
+
if tracked_obj is None:
|
| 245 |
+
|
| 246 |
+
# Check overlap with existing bboxes to make sure this isn't a double detection
|
| 247 |
+
reject_detection = False
|
| 248 |
+
for idx2 in range(len(yolo_boxes)):
|
| 249 |
+
track_id2 = int(yolo_boxes.id[idx2].item())
|
| 250 |
+
if track_id2 in [track_id, prev_track_id] or self.initial_bboxes.get(track_id2) is None:
|
| 251 |
+
continue
|
| 252 |
+
bbox2 = yolo_boxes.xyxy[idx2].numpy()
|
| 253 |
+
iou = self._bbox_iou(bbox, bbox2)
|
| 254 |
+
if iou >= 0.8:
|
| 255 |
+
reject_detection = True
|
| 256 |
+
# print("Redetection! Frame:", frame_idx, "Iou:", iou, track_id, track_id2)
|
| 257 |
+
break
|
| 258 |
+
|
| 259 |
+
if not reject_detection:
|
| 260 |
+
tracked_obj = TrackedObject(track_id, class_id, bbox, frame_idx)
|
| 261 |
+
self.initial_bboxes[track_id] = tracked_obj
|
| 262 |
+
else:
|
| 263 |
+
tracked_obj.new_pred(class_id)
|
| 264 |
+
|
| 265 |
+
if tracked_obj is not None:
|
| 266 |
+
all_preds_by_track_id[frame_idx][track_id] = TrackedObject(track_id, class_id, bbox, tracked_obj.initial_frame_idx)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def filter_yolo_preds(self):
|
| 270 |
+
# Smooth classes detected + filter out unwanted classes
|
| 271 |
+
self.filtered_objects = []
|
| 272 |
+
self.filtered_objects_by_frame = {}
|
| 273 |
+
for _, tracked_obj in self.initial_bboxes.items():
|
| 274 |
+
if tracked_obj.class_id in CLASSES_TO_KEEP:
|
| 275 |
+
self.filtered_objects.append(tracked_obj)
|
| 276 |
+
|
| 277 |
+
if self.filtered_objects_by_frame.get(tracked_obj.initial_frame_idx) is None:
|
| 278 |
+
self.filtered_objects_by_frame[tracked_obj.initial_frame_idx] = []
|
| 279 |
+
self.filtered_objects_by_frame[tracked_obj.initial_frame_idx].append(tracked_obj)
|
| 280 |
+
|
| 281 |
+
self.initial_frame_idx_by_track_id = {obj.track_id: obj.initial_frame_idx for obj in self.filtered_objects}
|
| 282 |
+
|
| 283 |
+
# print("Filtered objects:", self.filtered_objects)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def run_sam(self, video_dir):
|
| 287 |
+
self.video_segments = {} # video_segments contains the per-frame segmentation results
|
| 288 |
+
self.track_ids_to_reject = {} # {track_id: reject_all_before_frame_idx}
|
| 289 |
+
|
| 290 |
+
if self.filtered_objects is None or len(self.filtered_objects) == 0:
|
| 291 |
+
# There are no objects to track
|
| 292 |
+
return
|
| 293 |
+
|
| 294 |
+
inference_state = self.sam_model.init_state(video_path=video_dir) # NOTE: Kind of annoying that the model requires frames to be named with numbers only...
|
| 295 |
+
|
| 296 |
+
self.sam_model.reset_state(inference_state)
|
| 297 |
+
for obj in self.filtered_objects:
|
| 298 |
+
_, out_obj_ids, out_mask_logits = self.sam_model.add_new_points_or_box(
|
| 299 |
+
inference_state=inference_state,
|
| 300 |
+
frame_idx=obj.initial_frame_idx,
|
| 301 |
+
obj_id=obj.track_id,
|
| 302 |
+
box=obj.bbox,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
def get_last_frame_occurrence(sam_track_ids_per_frame, track_id, current_idx):
|
| 306 |
+
for frame_idx in range(current_idx-1, -1, -1):
|
| 307 |
+
if track_id in sam_track_ids_per_frame.get(frame_idx, []):
|
| 308 |
+
return frame_idx
|
| 309 |
+
return -1
|
| 310 |
+
|
| 311 |
+
# run propagation throughout the video and collect the results in a dict
|
| 312 |
+
sam_track_ids_per_frame = {}
|
| 313 |
+
long_non_existence_track_ids = []
|
| 314 |
+
with torch.cuda.amp.autocast(): # Need this for some reason to fix some casting issues... (BFloat16 and Float16 mismatches)
|
| 315 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in self.sam_model.propagate_in_video(inference_state):
|
| 316 |
+
|
| 317 |
+
sam_tracked_ids = []
|
| 318 |
+
for pred_idx, mask_logits in enumerate(out_mask_logits):
|
| 319 |
+
mask = (mask_logits > 0.0).cpu().numpy()
|
| 320 |
+
track_id = out_obj_ids[pred_idx]
|
| 321 |
+
if mask.sum() > 0 and out_frame_idx > (self.initial_frame_idx_by_track_id[track_id] - NUM_LOOK_BACK_FRAMES):
|
| 322 |
+
sam_tracked_ids.append(track_id)
|
| 323 |
+
sam_track_ids_per_frame[out_frame_idx] = sam_tracked_ids
|
| 324 |
+
|
| 325 |
+
# Compare *new* YOLO preds and make sure they don't overlap with existing SAM preds
|
| 326 |
+
for new_yolo_obj in self.filtered_objects_by_frame.get(out_frame_idx, []):
|
| 327 |
+
yolo_track_id = new_yolo_obj.track_id
|
| 328 |
+
yolo_bbox = new_yolo_obj.bbox
|
| 329 |
+
|
| 330 |
+
for ind, sam_track_id in enumerate(out_obj_ids):
|
| 331 |
+
if (sam_track_id == yolo_track_id) or (out_frame_idx < (self.initial_frame_idx_by_track_id[sam_track_id] - NUM_LOOK_BACK_FRAMES)):
|
| 332 |
+
continue
|
| 333 |
+
|
| 334 |
+
sam_mask = (out_mask_logits[ind] > 0.0).cpu().numpy()
|
| 335 |
+
sam_bbox = self._get_bbox_from_mask(sam_mask)
|
| 336 |
+
if sam_bbox is None:
|
| 337 |
+
continue
|
| 338 |
+
|
| 339 |
+
# Flag the SAM prediction only if this prediction has been lost for many frames and SAM is trying to recover it
|
| 340 |
+
# (in which case we should keep the YOLO pred)
|
| 341 |
+
|
| 342 |
+
last_occurrence_idx = get_last_frame_occurrence(sam_track_ids_per_frame, sam_track_id, out_frame_idx)
|
| 343 |
+
if (out_frame_idx - last_occurrence_idx) >= (FPS * 0.8) and last_occurrence_idx >= 0:
|
| 344 |
+
print(sam_track_id, "long non-existence:", out_frame_idx - last_occurrence_idx, "frames")
|
| 345 |
+
long_non_existence_track_ids.append(sam_track_id)
|
| 346 |
+
|
| 347 |
+
iou = self._bbox_iou(yolo_bbox, sam_bbox)
|
| 348 |
+
if iou > 0.8:
|
| 349 |
+
# Reject the SAM tracked object if it was lost for many frames
|
| 350 |
+
if sam_track_id in long_non_existence_track_ids:
|
| 351 |
+
rejected_track_id = sam_track_id
|
| 352 |
+
self.track_ids_to_reject[rejected_track_id] = self.initial_frame_idx_by_track_id[sam_track_id]
|
| 353 |
+
print(f"Frame {out_frame_idx}. {yolo_track_id} & {sam_track_id} iou: {iou:.2f}. Reject: {rejected_track_id} (all) for long non-existence")
|
| 354 |
+
else:
|
| 355 |
+
# Otherwise, choose the obj with the latest yolo initial frame detection to reject
|
| 356 |
+
yolo_initial_frame = self.initial_frame_idx_by_track_id[yolo_track_id] # This is just the current frame
|
| 357 |
+
sam_initial_frame = self.initial_frame_idx_by_track_id[sam_track_id]
|
| 358 |
+
yolo_error = yolo_initial_frame >= sam_initial_frame
|
| 359 |
+
rejected_track_id = yolo_track_id if yolo_error else sam_track_id
|
| 360 |
+
reject_all_before_frame_idx = self.num_frames if yolo_error else sam_initial_frame
|
| 361 |
+
self.track_ids_to_reject[rejected_track_id] = reject_all_before_frame_idx
|
| 362 |
+
print(f"Frame {out_frame_idx}. {yolo_track_id} & {sam_track_id} iou: {iou:.2f}. Reject: {rejected_track_id} ({'before frame #' + str(reject_all_before_frame_idx) if not yolo_error else 'all'})")
|
| 363 |
+
|
| 364 |
+
self.video_segments[out_frame_idx] = {
|
| 365 |
+
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
| 366 |
+
for i, out_obj_id in enumerate(out_obj_ids)
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def filter_sam_preds(self):
|
| 371 |
+
self.filtered_sam_preds = []
|
| 372 |
+
self.filtered_yolo_preds = []
|
| 373 |
+
self.num_sam_existence_frames_by_track_id = {}
|
| 374 |
+
|
| 375 |
+
def check_reject_pred(obj_id, frame_idx):
|
| 376 |
+
return obj_id in self.track_ids_to_reject and frame_idx < self.track_ids_to_reject[obj_id]
|
| 377 |
+
|
| 378 |
+
for frame_idx in range(self.num_frames):
|
| 379 |
+
self.filtered_sam_preds.append({})
|
| 380 |
+
self.filtered_yolo_preds.append({})
|
| 381 |
+
|
| 382 |
+
if frame_idx not in self.video_segments.keys():
|
| 383 |
+
continue
|
| 384 |
+
|
| 385 |
+
for obj_id, mask in self.video_segments[frame_idx].items():
|
| 386 |
+
# Only keep mask predictions that happen after the initial frame (with a small buffer)
|
| 387 |
+
# (SAM will try to predict them before the prompt frame, and will often get them wrong)
|
| 388 |
+
if (frame_idx >= self.initial_frame_idx_by_track_id[obj_id] - NUM_LOOK_BACK_FRAMES) and not check_reject_pred(obj_id, frame_idx):
|
| 389 |
+
self.filtered_sam_preds[frame_idx][obj_id] = mask
|
| 390 |
+
|
| 391 |
+
if mask.sum() > 0:
|
| 392 |
+
if self.num_sam_existence_frames_by_track_id.get(obj_id) is None:
|
| 393 |
+
self.num_sam_existence_frames_by_track_id[obj_id] = 0
|
| 394 |
+
self.num_sam_existence_frames_by_track_id[obj_id] += 1
|
| 395 |
+
|
| 396 |
+
for obj in self.filtered_objects:
|
| 397 |
+
if obj.initial_frame_idx == frame_idx and not check_reject_pred(obj.track_id, frame_idx):
|
| 398 |
+
self.filtered_yolo_preds[frame_idx][obj.track_id] = obj.bbox
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def extract_final_bboxes(self):
|
| 402 |
+
# Extract the bboxes from the predicted masks
|
| 403 |
+
# Also filter any overlapping. At this stage if there is overlapping masks this is likely a fault on SAM's side
|
| 404 |
+
# (id switching/collecting) and there is not much we can do for this.
|
| 405 |
+
self.final_bboxes = []
|
| 406 |
+
rejected_ids = []
|
| 407 |
+
for frame_idx in range(self.num_frames):
|
| 408 |
+
self.final_bboxes.append({})
|
| 409 |
+
for obj_id, mask in self.filtered_sam_preds[frame_idx].items():
|
| 410 |
+
mask_box = self._get_bbox_from_mask(mask)
|
| 411 |
+
if mask_box is not None:
|
| 412 |
+
self.final_bboxes[frame_idx][obj_id] = mask_box
|
| 413 |
+
|
| 414 |
+
# Compute IOU overlap and eliminate duplicates
|
| 415 |
+
items_to_compare = list(self.final_bboxes[frame_idx].items()) + list(self.final_bboxes[frame_idx].items())
|
| 416 |
+
for (id0, bbox0), (id1, bbox1) in combinations(items_to_compare, 2):
|
| 417 |
+
if id0 == id1:
|
| 418 |
+
continue
|
| 419 |
+
|
| 420 |
+
if id0 not in self.final_bboxes[frame_idx] or id1 not in self.final_bboxes[frame_idx]: # Could've been removed previously
|
| 421 |
+
continue
|
| 422 |
+
|
| 423 |
+
if id0 in rejected_ids:
|
| 424 |
+
del self.final_bboxes[frame_idx][id0]
|
| 425 |
+
continue
|
| 426 |
+
if id1 in rejected_ids:
|
| 427 |
+
del self.final_bboxes[frame_idx][id1]
|
| 428 |
+
continue
|
| 429 |
+
|
| 430 |
+
iou = self._bbox_iou(bbox0, bbox1)
|
| 431 |
+
if iou > 0.8:
|
| 432 |
+
# Rejecting the prediction that exists for the least amount of frames throughout the video
|
| 433 |
+
frame_count0 = self.num_sam_existence_frames_by_track_id[id0]
|
| 434 |
+
frame_count1 = self.num_sam_existence_frames_by_track_id[id1]
|
| 435 |
+
rejected_id = id0 if frame_count0 < frame_count1 else id1
|
| 436 |
+
|
| 437 |
+
del self.final_bboxes[frame_idx][rejected_id]
|
| 438 |
+
rejected_ids.append(rejected_id)
|
| 439 |
+
# print(f"Frame {frame_idx}. {id0} & {id1} iou: {iou}. Rejecting {rejected_id}")
|
| 440 |
+
|
| 441 |
+
def _bbox_iou(self, box1, box2):
|
| 442 |
+
"""
|
| 443 |
+
Compute the Intersection over Union (IoU) between two bounding boxes.
|
| 444 |
+
|
| 445 |
+
Parameters:
|
| 446 |
+
box1: (x1, y1, x2, y2) coordinates of the first box.
|
| 447 |
+
box2: (x1, y1, x2, y2) coordinates of the second box.
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
iou: Intersection over Union value (0 to 1).
|
| 451 |
+
"""
|
| 452 |
+
# Get the coordinates of the intersection rectangle
|
| 453 |
+
x1 = max(box1[0], box2[0])
|
| 454 |
+
y1 = max(box1[1], box2[1])
|
| 455 |
+
x2 = min(box1[2], box2[2])
|
| 456 |
+
y2 = min(box1[3], box2[3])
|
| 457 |
+
|
| 458 |
+
# Compute intersection area
|
| 459 |
+
inter_width = max(0, x2 - x1)
|
| 460 |
+
inter_height = max(0, y2 - y1)
|
| 461 |
+
inter_area = inter_width * inter_height
|
| 462 |
+
|
| 463 |
+
# Compute areas of both bounding boxes
|
| 464 |
+
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
| 465 |
+
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
| 466 |
+
|
| 467 |
+
# Compute union area
|
| 468 |
+
union_area = box1_area + box2_area - inter_area
|
| 469 |
+
|
| 470 |
+
# Compute IoU
|
| 471 |
+
iou = inter_area / union_area if union_area > 0 else 0
|
| 472 |
+
|
| 473 |
+
return iou
|
| 474 |
+
|
| 475 |
+
def _get_bbox_from_mask(self, mask):
|
| 476 |
+
mask_points = mask.nonzero()
|
| 477 |
+
if len(mask_points[0]) == 0:
|
| 478 |
+
return None
|
| 479 |
+
|
| 480 |
+
x0 = min(mask_points[2])
|
| 481 |
+
y0 = min(mask_points[1])
|
| 482 |
+
x1 = max(mask_points[2])
|
| 483 |
+
y1 = max(mask_points[1])
|
| 484 |
+
|
| 485 |
+
return np.array([x0, y0, x1, y1])
|
| 486 |
+
|
| 487 |
+
from collections import defaultdict
|
| 488 |
+
class CVCOLORS:
|
| 489 |
+
RED = (0,0,255)
|
| 490 |
+
GREEN = (0,255,0)
|
| 491 |
+
BLUE = (255,0,0)
|
| 492 |
+
PURPLE = (247,44,200)
|
| 493 |
+
ORANGE = (44,162,247)
|
| 494 |
+
MINT = (239,255,66)
|
| 495 |
+
YELLOW = (2,255,250)
|
| 496 |
+
BROWN = (42,42,165)
|
| 497 |
+
LIME=(51,255,153)
|
| 498 |
+
GRAY=(128, 128, 128)
|
| 499 |
+
LIGHTPINK = (222,209,255)
|
| 500 |
+
LIGHTGREEN = (204,255,204)
|
| 501 |
+
LIGHTBLUE = (255,235,207)
|
| 502 |
+
LIGHTPURPLE = (255,153,204)
|
| 503 |
+
LIGHTRED = (204,204,255)
|
| 504 |
+
WHITE = (255,255,255)
|
| 505 |
+
BLACK = (0,0,0)
|
| 506 |
+
|
| 507 |
+
TRACKID_LOOKUP = defaultdict(lambda: (np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255)))
|
| 508 |
+
TYPE_LOOKUP = [BROWN, BLUE, PURPLE, RED, ORANGE, YELLOW, GREEN, LIGHTPURPLE, LIGHTPINK, LIGHTRED, GRAY]
|
| 509 |
+
REVERT_CHANNEL_F = lambda x: (x[2], x[1], x[0])
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
if __name__ == "__main__":
|
| 513 |
+
|
| 514 |
+
# Checkpoint paths
|
| 515 |
+
yolo_ckpt = "yolov8x.pt" # Will auto download with utltralytics
|
| 516 |
+
|
| 517 |
+
# NOTE: Need to download beforehand from https://github.com/facebookresearch/sam2
|
| 518 |
+
sam2_ckpt = "/network/scratch/x/xuolga/sam2/checkpoints/sam2.1_hiera_base_plus.pt"
|
| 519 |
+
sam2_cfg = "./configs/sam2.1/sam2.1_hiera_b+.yaml"
|
| 520 |
+
|
| 521 |
+
yolo_sam = YoloSamProcessor(yolo_ckpt, sam2_ckpt, sam2_cfg)
|
| 522 |
+
output_dir = f"/path/to/output_dir"
|
| 523 |
+
bboxes_out_dir = os.path.join(output_dir, "bboxes")
|
| 524 |
+
labels_out_dir = os.path.join(output_dir, "json")
|
| 525 |
+
os.makedirs(bboxes_out_dir, exist_ok=True)
|
| 526 |
+
os.makedirs(labels_out_dir, exist_ok=True)
|
| 527 |
+
|
| 528 |
+
video_dir_root = f"/path/to/dada2000_images_12fps/images"
|
| 529 |
+
videos = ['8_90038', '8_90002', '10_90019', '10_90045', '10_90027', '10_90029', '10_90082', '10_90021', '10_90064', '10_90083', '10_90141', '10_90139', '10_90034', '10_90134', '10_90056', '10_90169', '10_90040', '11_90109', '11_90162', '11_90202', '11_90142', '11_90180', '11_90161', '11_90091', '11_90189', '11_90002', '11_90192', '11_90221', '11_90181', '12_90007', '12_90042', '13_90002', '13_90008', '13_90007', '14_90012', '14_90018', '14_90014', '14_90027', '24_90017', '24_90005', '24_90006', '24_90011', '42_90021', '43_90013', '48_90078', '48_90031', '48_90001', '48_90075', '49_90030', '49_90021', '61_90016', '61_90004']
|
| 530 |
+
|
| 531 |
+
for video_name in tqdm(videos):
|
| 532 |
+
cat = video_name.split("_")[0]
|
| 533 |
+
video_dir = os.path.join(video_dir_root, cat, video_name)
|
| 534 |
+
|
| 535 |
+
if len(os.listdir(video_dir)) > 300:
|
| 536 |
+
print("Skipping:", video_name)
|
| 537 |
+
continue
|
| 538 |
+
out_data = yolo_sam(video_dir, rel_bbox=False)
|
| 539 |
+
|
| 540 |
+
# Output format:
|
| 541 |
+
"""
|
| 542 |
+
out_data = [ # List of Dicts, each inner dict represents one frame of the video
|
| 543 |
+
{
|
| 544 |
+
"image_source": img1.jpg,
|
| 545 |
+
"labels":
|
| 546 |
+
[ # List of Dicts, each dict represents one tracked object
|
| 547 |
+
{'track_id': 0, 'name': 'car', 'class':2, 'bbox': [x0, y0, x1, y1]}, # Obj 0
|
| 548 |
+
{...}, # Obj 1
|
| 549 |
+
]
|
| 550 |
+
}, # Frame 0
|
| 551 |
+
|
| 552 |
+
{...}, # Frame 1
|
| 553 |
+
]
|
| 554 |
+
"""
|
| 555 |
+
|
| 556 |
+
# Plot final bboxes and save to file
|
| 557 |
+
out_json_path = os.path.join(labels_out_dir, f"{video_name}.json")
|
| 558 |
+
os.makedirs(os.path.dirname(out_json_path), exist_ok=True)
|
| 559 |
+
with open(out_json_path, 'w') as json_file:
|
| 560 |
+
json.dump(out_data, json_file, indent=1)
|
| 561 |
+
|
| 562 |
+
og_frames = sorted(os.listdir(video_dir))
|
| 563 |
+
out_bbox_path = os.path.join(bboxes_out_dir, video_name)
|
| 564 |
+
os.makedirs(out_bbox_path, exist_ok=True)
|
| 565 |
+
for frame_idx, frame_data in enumerate(out_data):
|
| 566 |
+
plt.figure(figsize=(9, 6))
|
| 567 |
+
plt.axis("off")
|
| 568 |
+
img = Image.open(os.path.join(video_dir, og_frames[frame_idx]))
|
| 569 |
+
plt.imshow(img)
|
| 570 |
+
for obj in frame_data["labels"]:
|
| 571 |
+
color = np.array(CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TYPE_LOOKUP[obj["class"]])) / 255.0
|
| 572 |
+
show_box(obj["box"], plt.gca(), label=str(obj["track_id"]), color=color)
|
| 573 |
+
|
| 574 |
+
frame_id_name = og_frames[frame_idx].split("_")[-1].split(".")[0]
|
| 575 |
+
plt.savefig(os.path.join(out_bbox_path, f"bboxes_frame_{frame_id_name}"))
|
| 576 |
+
|
| 577 |
+
# Save videos of bboxes
|
| 578 |
+
videos_out_dir = os.path.join(output_dir, "videos")
|
| 579 |
+
out_video_path = os.path.join(videos_out_dir, f"{video_name}_with_bboxes.mp4")
|
| 580 |
+
os.makedirs(os.path.dirname(out_video_path), exist_ok=True)
|
| 581 |
+
create_video_from_images(out_bbox_path, out_video_path, out_fps=12)
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
|
src/utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .parser import parse_args
|
| 2 |
+
from .utils import encode_video_image, get_add_time_ids, get_samples, get_model_attr
|
src/utils/parser.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TODO: Clean up unused args
|
| 2 |
+
def parse_args():
|
| 3 |
+
import argparse
|
| 4 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 5 |
+
|
| 6 |
+
parser.add_argument(
|
| 7 |
+
"--project_name",
|
| 8 |
+
type=str,
|
| 9 |
+
default="car_crash",
|
| 10 |
+
help="Name of the project."
|
| 11 |
+
)
|
| 12 |
+
parser.add_argument(
|
| 13 |
+
"--pretrained_model_name_or_path",
|
| 14 |
+
type=str,
|
| 15 |
+
default=None,
|
| 16 |
+
required=True,
|
| 17 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 18 |
+
)
|
| 19 |
+
parser.add_argument(
|
| 20 |
+
"--finetuned_svd_path",
|
| 21 |
+
type=str,
|
| 22 |
+
default=None,
|
| 23 |
+
required=False,
|
| 24 |
+
help="Path to pretrained unet model. Used to override 'pretrained_model_name_or_path', and will default to this if not provided",
|
| 25 |
+
)
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--revision",
|
| 28 |
+
type=str,
|
| 29 |
+
default=None,
|
| 30 |
+
required=False,
|
| 31 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 32 |
+
)
|
| 33 |
+
parser.add_argument(
|
| 34 |
+
"--variant",
|
| 35 |
+
type=str,
|
| 36 |
+
default=None,
|
| 37 |
+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
| 38 |
+
)
|
| 39 |
+
parser.add_argument(
|
| 40 |
+
"--output_dir",
|
| 41 |
+
type=str,
|
| 42 |
+
default="out",
|
| 43 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 44 |
+
)
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"--logging_dir",
|
| 47 |
+
type=str,
|
| 48 |
+
default="logs",
|
| 49 |
+
help=(
|
| 50 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 51 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 52 |
+
),
|
| 53 |
+
)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--eval_dir",
|
| 56 |
+
type=str,
|
| 57 |
+
default="eval",
|
| 58 |
+
help=(
|
| 59 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 60 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 61 |
+
),
|
| 62 |
+
)
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--learning_rate",
|
| 65 |
+
type=float,
|
| 66 |
+
default=1e-4,
|
| 67 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
| 68 |
+
)
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
"--object_net_lr_factor",
|
| 71 |
+
type=float,
|
| 72 |
+
default=1.0,
|
| 73 |
+
help="Factor to scale the learning rate of the object network.",
|
| 74 |
+
)
|
| 75 |
+
parser.add_argument(
|
| 76 |
+
"--gradient_accumulation_steps",
|
| 77 |
+
type=int,
|
| 78 |
+
default=1,
|
| 79 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 80 |
+
)
|
| 81 |
+
parser.add_argument(
|
| 82 |
+
"--scale_lr",
|
| 83 |
+
action="store_true",
|
| 84 |
+
default=False,
|
| 85 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
| 86 |
+
)
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
"--lr_scheduler",
|
| 89 |
+
type=str,
|
| 90 |
+
default="constant",
|
| 91 |
+
help=(
|
| 92 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 93 |
+
' "constant", "constant_with_warmup"]'
|
| 94 |
+
),
|
| 95 |
+
)
|
| 96 |
+
parser.add_argument(
|
| 97 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 98 |
+
)
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"--snr_gamma",
|
| 101 |
+
type=float,
|
| 102 |
+
default=None,
|
| 103 |
+
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
| 104 |
+
"More details here: https://arxiv.org/abs/2303.09556.",
|
| 105 |
+
)
|
| 106 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
| 107 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 108 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
| 109 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
| 110 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 111 |
+
parser.add_argument(
|
| 112 |
+
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
| 113 |
+
)
|
| 114 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
| 115 |
+
parser.add_argument(
|
| 116 |
+
"--max_train_steps",
|
| 117 |
+
type=int,
|
| 118 |
+
default=None,
|
| 119 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 120 |
+
)
|
| 121 |
+
parser.add_argument(
|
| 122 |
+
"--dataloader_num_workers",
|
| 123 |
+
type=int,
|
| 124 |
+
default=0,
|
| 125 |
+
help=(
|
| 126 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
| 127 |
+
),
|
| 128 |
+
)
|
| 129 |
+
parser.add_argument(
|
| 130 |
+
"--mixed_precision",
|
| 131 |
+
type=str,
|
| 132 |
+
default=None,
|
| 133 |
+
choices=["no", "fp16", "bf16"],
|
| 134 |
+
help=(
|
| 135 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 136 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 137 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 138 |
+
),
|
| 139 |
+
)
|
| 140 |
+
parser.add_argument(
|
| 141 |
+
"--rank",
|
| 142 |
+
type=int,
|
| 143 |
+
default=4,
|
| 144 |
+
help=("The dimension of the LoRA update matrices."),
|
| 145 |
+
)
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
"--checkpointing_steps",
|
| 148 |
+
type=int,
|
| 149 |
+
default=500,
|
| 150 |
+
help=(
|
| 151 |
+
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
| 152 |
+
" training using `--resume_from_checkpoint`."
|
| 153 |
+
),
|
| 154 |
+
)
|
| 155 |
+
parser.add_argument(
|
| 156 |
+
"--checkpointing_time",
|
| 157 |
+
type=int,
|
| 158 |
+
default=-1,
|
| 159 |
+
help=(
|
| 160 |
+
"Save a checkpoint of the training state every X seconds. Useful to save checkpoint when using jobs with short timeouts. If <= 0, will not save any checkpoints based on time."
|
| 161 |
+
" training using `--resume_from_checkpoint`."
|
| 162 |
+
),
|
| 163 |
+
)
|
| 164 |
+
parser.add_argument(
|
| 165 |
+
"--checkpoints_total_limit",
|
| 166 |
+
type=int,
|
| 167 |
+
default=None,
|
| 168 |
+
help=("Max number of checkpoints to store."),
|
| 169 |
+
)
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--resume_from_checkpoint",
|
| 172 |
+
type=str,
|
| 173 |
+
default=None,
|
| 174 |
+
help=(
|
| 175 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
| 176 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
| 177 |
+
),
|
| 178 |
+
)
|
| 179 |
+
parser.add_argument(
|
| 180 |
+
"--enable_gradient_checkpointing",
|
| 181 |
+
action="store_true",
|
| 182 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
| 183 |
+
)
|
| 184 |
+
parser.add_argument(
|
| 185 |
+
"--data_root",
|
| 186 |
+
type=str,
|
| 187 |
+
default="./data",
|
| 188 |
+
help="The root directory of the dataset.",
|
| 189 |
+
)
|
| 190 |
+
parser.add_argument(
|
| 191 |
+
"--validation_steps",
|
| 192 |
+
type=int,
|
| 193 |
+
default=50,
|
| 194 |
+
help=(
|
| 195 |
+
"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
|
| 196 |
+
" `args.validation_prompt` multiple times: `args.num_validation_images`."
|
| 197 |
+
),
|
| 198 |
+
)
|
| 199 |
+
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
|
| 200 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
| 201 |
+
parser.add_argument(
|
| 202 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
| 203 |
+
)
|
| 204 |
+
parser.add_argument("--disable_object_condition", action="store_true", help="Whether or not to disable object condition.")
|
| 205 |
+
parser.add_argument("--encoder_hid_dim_type", type=str, default=None, help="The type of unet's input hidden.")
|
| 206 |
+
parser.add_argument(
|
| 207 |
+
"--report_to",
|
| 208 |
+
type=str,
|
| 209 |
+
default="wandb",
|
| 210 |
+
help=(
|
| 211 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 212 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 213 |
+
),
|
| 214 |
+
)
|
| 215 |
+
parser.add_argument(
|
| 216 |
+
"--run_name",
|
| 217 |
+
type=str,
|
| 218 |
+
default=None,
|
| 219 |
+
help="The name of the run log.",
|
| 220 |
+
)
|
| 221 |
+
parser.add_argument(
|
| 222 |
+
"--prediction_type",
|
| 223 |
+
type=str,
|
| 224 |
+
default=None,
|
| 225 |
+
help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
|
| 226 |
+
)
|
| 227 |
+
parser.add_argument(
|
| 228 |
+
"--guidance_scale",
|
| 229 |
+
type=float,
|
| 230 |
+
default=1.0,
|
| 231 |
+
help="(Image only). A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`."
|
| 232 |
+
)
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
"--guidance_rescale",
|
| 235 |
+
type=float,
|
| 236 |
+
default=0.0,
|
| 237 |
+
help="(Image only). Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR."
|
| 238 |
+
)
|
| 239 |
+
parser.add_argument(
|
| 240 |
+
"--min_guidance_scale",
|
| 241 |
+
type=float,
|
| 242 |
+
default=1.0,
|
| 243 |
+
help="(Video generation only). The minimum guidance scale. Used for the classifier free guidance with first frame."
|
| 244 |
+
)
|
| 245 |
+
parser.add_argument(
|
| 246 |
+
"--max_guidance_scale",
|
| 247 |
+
type=float,
|
| 248 |
+
default=3.0,
|
| 249 |
+
help="(Video generation only). The maximum guidance scale. Used for the classifier free guidance with last frame."
|
| 250 |
+
)
|
| 251 |
+
parser.add_argument(
|
| 252 |
+
"--conditioning_dropout_prob",
|
| 253 |
+
type=float,
|
| 254 |
+
default=0.1,
|
| 255 |
+
help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
|
| 256 |
+
)
|
| 257 |
+
parser.add_argument(
|
| 258 |
+
"--clip_length",
|
| 259 |
+
type=int,
|
| 260 |
+
default=25,
|
| 261 |
+
help="The number of frames in a clip.",
|
| 262 |
+
)
|
| 263 |
+
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
|
| 264 |
+
parser.add_argument(
|
| 265 |
+
"--non_ema_revision",
|
| 266 |
+
type=str,
|
| 267 |
+
default=None,
|
| 268 |
+
required=False,
|
| 269 |
+
help=(
|
| 270 |
+
"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
|
| 271 |
+
" remote repository specified with --pretrained_model_name_or_path."
|
| 272 |
+
),
|
| 273 |
+
)
|
| 274 |
+
parser.add_argument(
|
| 275 |
+
"--backprop_temporal_blocks_start_iter",
|
| 276 |
+
type=int,
|
| 277 |
+
default=-1,
|
| 278 |
+
help="(Video generation only). The starting iteration of only backpropagating into temporal blocks (if -1, then always backpropagate the entire network).",
|
| 279 |
+
)
|
| 280 |
+
parser.add_argument(
|
| 281 |
+
"--enable_lora",
|
| 282 |
+
action="store_true",
|
| 283 |
+
default=False,
|
| 284 |
+
help="Enable LoRA.",
|
| 285 |
+
)
|
| 286 |
+
parser.add_argument(
|
| 287 |
+
"--add_bbox_frame_conditioning",
|
| 288 |
+
action="store_true",
|
| 289 |
+
default=False,
|
| 290 |
+
help="(Video generation only). Add bbox frame conditioning.",
|
| 291 |
+
)
|
| 292 |
+
parser.add_argument(
|
| 293 |
+
"--bbox_dropout_prob",
|
| 294 |
+
type=float,
|
| 295 |
+
default=0.0,
|
| 296 |
+
help="(Video generation only). Bbox dropout probability. Drops out the bbox conditionings.",
|
| 297 |
+
)
|
| 298 |
+
parser.add_argument(
|
| 299 |
+
"--num_demo_samples",
|
| 300 |
+
type=int,
|
| 301 |
+
default=1,
|
| 302 |
+
help="Number of samples to generate during demo.",
|
| 303 |
+
)
|
| 304 |
+
parser.add_argument(
|
| 305 |
+
"--noise_aug_strength",
|
| 306 |
+
type=float,
|
| 307 |
+
default=0.02,
|
| 308 |
+
help="(Video generation only). Strength of noise augmentation.",
|
| 309 |
+
)
|
| 310 |
+
parser.add_argument(
|
| 311 |
+
"--num_inference_steps",
|
| 312 |
+
type=int,
|
| 313 |
+
default=25,
|
| 314 |
+
help="Number of inference denoising steps.",
|
| 315 |
+
)
|
| 316 |
+
parser.add_argument(
|
| 317 |
+
"--conditioning_scale",
|
| 318 |
+
type=float,
|
| 319 |
+
default=1.0,
|
| 320 |
+
help="(Controlnet only). The scale of conditioning."
|
| 321 |
+
)
|
| 322 |
+
parser.add_argument(
|
| 323 |
+
"--train_H",
|
| 324 |
+
type=int,
|
| 325 |
+
default=None,
|
| 326 |
+
help="For training, the height of the image to use. If None, the default height is 320 for video diffusion and 512 for image diffusion."
|
| 327 |
+
)
|
| 328 |
+
parser.add_argument(
|
| 329 |
+
"--train_W",
|
| 330 |
+
type=int,
|
| 331 |
+
default=None,
|
| 332 |
+
help="For training, the width of the image to use. If None, the default width is 512."
|
| 333 |
+
)
|
| 334 |
+
parser.add_argument(
|
| 335 |
+
"--eval_H",
|
| 336 |
+
type=int,
|
| 337 |
+
default=None,
|
| 338 |
+
help="For evaluation, the height of the image to use. If None, the default height is 320 for video diffusion and 512 for image diffusion."
|
| 339 |
+
)
|
| 340 |
+
parser.add_argument(
|
| 341 |
+
"--generate_bbox",
|
| 342 |
+
action="store_true",
|
| 343 |
+
default=False,
|
| 344 |
+
help="(Controlnet only). Whether to generate bbox."
|
| 345 |
+
)
|
| 346 |
+
parser.add_argument(
|
| 347 |
+
"--predict_bbox",
|
| 348 |
+
action="store_true",
|
| 349 |
+
default=False,
|
| 350 |
+
help="(Video diffusion only). Whether to predict bbox."
|
| 351 |
+
)
|
| 352 |
+
parser.add_argument(
|
| 353 |
+
"--evaluate_only",
|
| 354 |
+
action="store_true",
|
| 355 |
+
default=False,
|
| 356 |
+
help="Whether to only evaluate the model."
|
| 357 |
+
)
|
| 358 |
+
parser.add_argument(
|
| 359 |
+
"--demo_path",
|
| 360 |
+
default=None,
|
| 361 |
+
type=str,
|
| 362 |
+
help="Path where the demo samples are saved."
|
| 363 |
+
)
|
| 364 |
+
parser.add_argument(
|
| 365 |
+
"--pretrained_bbox_model",
|
| 366 |
+
default=None,
|
| 367 |
+
type=str,
|
| 368 |
+
help="Path to the pretrained bbox model."
|
| 369 |
+
)
|
| 370 |
+
parser.add_argument(
|
| 371 |
+
"--if_last_frame_trajectory",
|
| 372 |
+
action="store_true",
|
| 373 |
+
default=False,
|
| 374 |
+
help="Whether to use the last frame as the trajectory."
|
| 375 |
+
)
|
| 376 |
+
parser.add_argument(
|
| 377 |
+
"--fps",
|
| 378 |
+
type=int,
|
| 379 |
+
default=None,
|
| 380 |
+
help="FPS of the video."
|
| 381 |
+
)
|
| 382 |
+
parser.add_argument(
|
| 383 |
+
"--num_cond_bbox_frames",
|
| 384 |
+
type=int,
|
| 385 |
+
default=3,
|
| 386 |
+
help="Number of conditioning bbox frames."
|
| 387 |
+
)
|
| 388 |
+
parser.add_argument(
|
| 389 |
+
"--wandb_entity",
|
| 390 |
+
type=str,
|
| 391 |
+
default="",
|
| 392 |
+
help="Wandb entity",
|
| 393 |
+
)
|
| 394 |
+
parser.add_argument(
|
| 395 |
+
"--non_overlapping_clips",
|
| 396 |
+
action="store_true",
|
| 397 |
+
default=False,
|
| 398 |
+
help="Load clips that do not overlap in dataset",
|
| 399 |
+
)
|
| 400 |
+
parser.add_argument(
|
| 401 |
+
"--disable_wandb",
|
| 402 |
+
action="store_true",
|
| 403 |
+
default=False,
|
| 404 |
+
help="Disable wandb logging.",
|
| 405 |
+
)
|
| 406 |
+
parser.add_argument(
|
| 407 |
+
"--empty_cuda_cache",
|
| 408 |
+
action="store_true",
|
| 409 |
+
default=False,
|
| 410 |
+
help="Periodically call `torch.cuda.empty_cache` every training loop. Helps to reduce chance of memory leakage and OOM error",
|
| 411 |
+
)
|
| 412 |
+
parser.add_argument(
|
| 413 |
+
"--bbox_masking_prob",
|
| 414 |
+
type=float,
|
| 415 |
+
default=0.0,
|
| 416 |
+
help="Bbox dropout probability. Drops out the bbox conditionings, will drop certain agents in the scene.",
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
parser.add_argument(
|
| 420 |
+
"--dataset_name",
|
| 421 |
+
nargs="+",
|
| 422 |
+
type=str,
|
| 423 |
+
default="russia_crash",
|
| 424 |
+
choices=["russia_crash", "nuscenes", "dada2000", "mmau", "bdd100k"],
|
| 425 |
+
help=(
|
| 426 |
+
"The name of the Dataset to train on. Can specify a list of dataset names to be merged together."
|
| 427 |
+
),
|
| 428 |
+
)
|
| 429 |
+
parser.add_argument(
|
| 430 |
+
"--use_action_conditioning",
|
| 431 |
+
action="store_true",
|
| 432 |
+
default=False,
|
| 433 |
+
help="Whether to use the action conditioning (ie accident type flags)"
|
| 434 |
+
)
|
| 435 |
+
parser.add_argument(
|
| 436 |
+
"--contiguous_bbox_masking_prob",
|
| 437 |
+
type=float,
|
| 438 |
+
default=0.0,
|
| 439 |
+
help="Prob to mask out bbox conditionings in a contiguous manner (ie, all bboxes after a certain frame). A random frame between [0, N] will be selected and all subsequent frames will be masked",
|
| 440 |
+
)
|
| 441 |
+
parser.add_argument(
|
| 442 |
+
"--contiguous_bbox_masking_start_ratio",
|
| 443 |
+
type=float,
|
| 444 |
+
default=0.0,
|
| 445 |
+
help="Ratio of contiguous bbox frames masking (as defined by --contiguous_bbox_masking_prob) to mask out from the start of the video instead of the end.",
|
| 446 |
+
)
|
| 447 |
+
parser.add_argument(
|
| 448 |
+
"--val_on_first_step",
|
| 449 |
+
action="store_true",
|
| 450 |
+
default=False,
|
| 451 |
+
help="Whether to run validation on the first step (useful for testing)"
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
args = parser.parse_args()
|
| 455 |
+
# default to using the same revision for the non-ema model if not specified
|
| 456 |
+
if args.non_ema_revision is None:
|
| 457 |
+
args.non_ema_revision = args.revision
|
| 458 |
+
|
| 459 |
+
if args.enable_lora:
|
| 460 |
+
args.backprop_temporal_blocks_start_iter = -1
|
| 461 |
+
|
| 462 |
+
if args.evaluate_only:
|
| 463 |
+
assert args.resume_from_checkpoint is not None, "Must provide a checkpoint to evaluate the model."
|
| 464 |
+
|
| 465 |
+
if args.fps is None:
|
| 466 |
+
if args.dataset_name == "bdd100k":
|
| 467 |
+
args.fps = 5
|
| 468 |
+
elif args.dataset_name == "nuscenes":
|
| 469 |
+
args.fps = 7 # NOTE: Must match the fps in nuscenes dataloader (=2 if no interpolation)
|
| 470 |
+
else:
|
| 471 |
+
args.fps = 7
|
| 472 |
+
return args
|