"""Gradio demo for Ctrl-Crash: controllable car-crash video generation.

Pipeline per request:
  1. Take a dashcam video (or a single image) from the UI.
  2. Extract a few guidance frames and detect vehicles on them (YOLO + SAM2).
  3. Render the detected boxes as a conditioning clip and run the Ctrl-Crash
     pipeline to generate an MP4.

NOTE(review): this file was recovered from a whitespace-collapsed source, so
the nesting of the UI layout at the bottom is a best-effort reconstruction.
"""

# NOTE(review): `spaces` is kept as the very first import, as in the original
# file; the ZeroGPU documentation asks for it to be imported before torch.
import spaces

import importlib
import json
import os
import shutil
import site
import subprocess
import sys
import tempfile
import time
import uuid
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
from PIL import Image, ImageDraw
import torch
import gradio as gr


def _rescan_site_packages() -> None:
    """Re-process every site-packages dir so new .pth/.egg-link files are seen."""
    for sitedir in site.getsitepackages():
        site.addsitedir(sitedir)
    # Clear caches so importlib will pick up newly installed modules.
    importlib.invalidate_caches()


_rescan_site_packages()


def sh(cmd):
    """Run a command and raise ``CalledProcessError`` on a non-zero exit code.

    A string is executed through the shell (original behaviour); a list/tuple
    is executed directly as an argv vector, without a shell.
    """
    subprocess.check_call(cmd, shell=isinstance(cmd, str))


# Install SAM2 at start-up. `sys.executable -m pip` guarantees the package is
# installed into the interpreter that is running this script.
sh([
    sys.executable, "-m", "pip", "install",
    "git+https://github.com/facebookresearch/sam2.git",
])
# Tell Python to re-scan site-packages now that the package exists.
_rescan_site_packages()

from huggingface_hub import hf_hub_download, snapshot_download
import cv2
from src.utils.yolo_sam import YoloSamProcessor  # imported after the sam2 install above

sam_2_ckpt = hf_hub_download(
    repo_id="facebook/sam2.1-hiera-base-plus",
    repo_type="model",
    filename="sam2.1_hiera_base_plus.pt",
)

from diffusers.utils import logging

# Enable tqdm as the progress bar
logging.set_verbosity_info()
logging.enable_progress_bar()

from src.eval.generate_samples import (
    load_ctrlv_pipelines,
    generate_video_ctrlv,
)

# Root directory for all per-session artefacts (frames, GIFs, generated videos).
os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CLIP_LENGTH = 25
W, H = 512, 320  # the code calls pipeline with width=512, height=320

# --------- action types ----------
# Display labels mapped to integer IDs
ACTION_TYPES = {
    "No Crash": 0,
    "Ego-Only crash": 1,
    "Ego-and-Vehicle crash": 2,
    "Vehicle-Only crash": 3,
    "Vehicle-and-Vehicle crash": 4,
}


# --------- helpers ----------
def resize_pad_to(img: Image.Image, w: int, h: int) -> Image.Image:
    """Letterbox ``img`` to ``(w, h)`` preserving aspect ratio.

    The image is converted to RGB, scaled to fit inside ``(w, h)`` and pasted
    centred on a black canvas.
    """
    img = img.convert("RGB")
    r = min(w / img.width, h / img.height)
    # Clamp to [1, target]: a very thin image would otherwise round down to a
    # zero-pixel side, which makes `resize` raise.
    nw = min(max(int(img.width * r), 1), w)
    nh = min(max(int(img.height * r), 1), h)
    img_resized = img.resize((nw, nh), Image.BICUBIC)
    canvas = Image.new("RGB", (w, h), (0, 0, 0))
    canvas.paste(img_resized, ((w - nw) // 2, (h - nh) // 2))
    return canvas


def _frame_index_from_name(name: str) -> int:
    """Parse the frame number out of a file name such as '00003.jpeg'.

    Falls back to the digits contained in the stem, and to 0 when the stem
    contains no digits at all.
    """
    stem = os.path.splitext(os.path.basename(name))[0]
    try:
        return int(stem)
    except ValueError:
        digits = "".join(ch for ch in stem if ch.isdecimal())
        return int(digits) if digits else 0


def _draw_box_outline(
    draw: ImageDraw.ImageDraw,
    box,
    size: Tuple[int, int],
    color: Tuple[int, int, int],
    thickness: int,
) -> None:
    """Draw one normalized ``(x1, y1, x2, y2)`` box as a ``thickness``-px outline."""
    w, h = size
    xa, ya = int(round(box[0] * w)), int(round(box[1] * h))
    xb, yb = int(round(box[2] * w)), int(round(box[3] * h))
    # Order the corners so an inverted box cannot make `rectangle` raise.
    x1, x2 = min(xa, xb), max(xa, xb)
    y1, y2 = min(ya, yb), max(ya, yb)
    # Thickness is emulated by drawing progressively larger 1-px rectangles.
    for t in range(thickness):
        draw.rectangle([x1 - t, y1 - t, x2 + t, y2 + t], outline=color)


def draw_boxes_frame(
    size: Tuple[int, int],
    boxes: List[Tuple[float, float, float, float]],  # normalized
    colors: Optional[List[Tuple[int, int, int]]] = None,
    thickness: int = 3,
    bg: Tuple[int, int, int] = (0, 0, 0),  # black background
) -> np.ndarray:
    """Return a CHW uint8 frame with boxes drawn. Expects normalized boxes in [0,1].

    Args:
        size: ``(width, height)`` of the frame in pixels.
        boxes: normalized ``(x1, y1, x2, y2)`` boxes; may be empty.
        colors: one RGB colour per box; defaults to white for every box.
        thickness: outline thickness in pixels.
        bg: RGB background colour.
    """
    im = Image.new("RGB", size, bg)
    if boxes:
        if colors is None:
            colors = [(255, 255, 255)] * len(boxes)  # white boxes
        d = ImageDraw.Draw(im)
        for b, c in zip(boxes, colors):
            _draw_box_outline(d, b, size, c, thickness)
    return np.asarray(im).transpose(2, 0, 1)


def build_bbox_video(
    per_frame_boxes: Dict[int, List[Tuple[float, float, float, float]]],  # normalized
    size: Tuple[int, int] = (W, H),
    T: int = CLIP_LENGTH,
):
    """Render a box-only clip of ``T`` frames.

    per_frame_boxes: {frame_index: [(x1,y1,x2,y2), ...]} where frame_index is 0..T-1.
    No interpolation: frames without an entry get no boxes.

    Returns:
        ``(frames_np, frames_t)``: the uint8 array ``[T, 3, H, W]`` and the same
        frames as a float32 tensor scaled to [0, 1].
    """
    frames_np = np.stack(
        [draw_boxes_frame(size, per_frame_boxes.get(t, [])) for t in range(T)],
        axis=0,
    )  # [T, 3, H, W] uint8
    frames_t = torch.from_numpy(frames_np.astype(np.float32) / 255.0)
    return frames_np, frames_t


def make_sample(
    init_image_path: str,
    keyframes_boxes: Dict[int, List[Tuple[float, float, float, float]]],
    action_id: int,
):
    """Assemble the sample dict that is handed to ``generate_video_ctrlv``."""
    # Context manager closes the file handle; `resize_pad_to` returns a copy.
    with Image.open(init_image_path) as raw:
        init_img = resize_pad_to(raw, W, H)

    bbox_np, bbox_t = build_bbox_video(keyframes_boxes, (W, H))
    action_type = torch.tensor([action_id], dtype=torch.long)

    return {
        "image_init": init_img,  # PIL.Image
        "bbox_images": bbox_t,  # [T,3,H,W] float32 in [0,1]
        "bbox_images_np": bbox_np,  # for optional viz
        "action_type": action_type.squeeze(0),  # scalar tensor
        "image_paths": [[f"frame_{i:02d}.png"] for i in range(CLIP_LENGTH)],
        # Placeholder all-white "ground truth" clip.
        "gt_clip_np": np.full((CLIP_LENGTH, 3, H, W), 255, dtype=np.uint8),
        "vid_name": "single_sample",
    }


g_use_factor_guidance = True

pipe = load_ctrlv_pipelines(
    "AnthonyGosselin/Ctrl-Crash",
    use_null_model=True,
    use_factor_guidance=g_use_factor_guidance,
)


def run_yolo_sam_on_image(
    image_dir,
    yolo_ckpt="yolov8x.pt",
    sam2_ckpt="sam2.1_hiera_base_plus.pt",
    sam2_cfg="./configs/sam2.1/sam2.1_hiera_b+.yaml",
    rel_bbox=True,
):
    """Detect vehicles on every frame found in ``image_dir``.

    Returns:
        dict[int, list[[x1,y1,x2,y2], ...]] for frames found in image_dir.
        Keys are integer frame ids parsed from '.jpeg' or similar filenames.
        Frames without any "car"/"truck" detection are omitted.
    """
    yolo_sam = YoloSamProcessor(yolo_ckpt, sam2_ckpt, sam2_cfg)
    result = yolo_sam(image_dir, rel_bbox=rel_bbox)

    vehicles = {"car", "truck"}
    output_dict: Dict[int, List[List[float]]] = {}

    # YoloSamProcessor may return a dict (single image) or list (multiple)
    iter_images = [result] if isinstance(result, dict) else (result or [])

    for img_res in iter_images:
        # Accept names like '00000.jpeg', '00000.jpg', '0.png'
        frame = _frame_index_from_name(img_res.get("image_source", "0"))

        frame_boxes: List[List[float]] = []
        for lab in img_res.get("labels", []):
            if lab.get("name") not in vehicles:
                continue
            box = lab.get("box")
            if isinstance(box, (list, tuple)) and len(box) == 4:
                frame_boxes.append([float(v) for v in box])

        if frame_boxes:
            output_dict[frame] = frame_boxes

    return output_dict


# --------- inference ----------
def run_one(
    init_image_path: str,
    keyframes_boxes: Dict[int, List[Tuple[float, float, float, float]]],
    action_id: int,
    out_dir: str,
    guidance: Tuple[float, float] = (1.0, 3.0),
) -> str:
    """Generate one video into ``out_dir`` and return the expected MP4 path."""
    os.makedirs(out_dir, exist_ok=True)

    sample = make_sample(init_image_path, keyframes_boxes, action_id)

    # mask all frames, then unmask only those with boxes
    bbox_mask_frames = [False] * CLIP_LENGTH
    for idx in keyframes_boxes:
        if 0 <= idx < CLIP_LENGTH:
            bbox_mask_frames[idx] = True

    video_prefix = os.path.join(out_dir, "genvid_minimal")
    generate_video_ctrlv(
        sample,
        pipeline=pipe,
        video_path=video_prefix,
        json_path=os.path.join(out_dir, "gt_frames_minimal.json"),
        bbox_mask_frames=bbox_mask_frames,
        action_type=sample["action_type"].unsqueeze(0),
        use_factor_guidance=g_use_factor_guidance,
        guidance=list(guidance),
        video_path2=None,
    )
    return f"{video_prefix}.mp4"


def parse_keyframes_json(text: str) -> Dict[int, List[Tuple[float, float, float, float]]]:
    """Parse a JSON object mapping frame index -> list of ``[x1,y1,x2,y2]`` boxes.

    Raises:
        ValueError: on empty input, a non-object top level, or a box that does
            not have exactly four values.
    """
    if not text.strip():
        raise ValueError("Provide a JSON mapping of frame_index → list of [x1,y1,x2,y2].")
    data = json.loads(text)
    if not isinstance(data, dict):
        raise ValueError("Provide a JSON mapping of frame_index → list of [x1,y1,x2,y2].")

    out = {}
    for k, v in data.items():
        boxes = []
        for b in v:
            if len(b) != 4:
                raise ValueError("Each box must be [x1,y1,x2,y2].")
            boxes.append(tuple(float(x) for x in b))
        out[int(k)] = boxes
    return out


def extract_frames(video_file, target_frames=CLIP_LENGTH, num_frames=3, session_id_dir=None):
    """Extract up to ``num_frames`` frames from ``video_file`` as JPEG files.

    Frame indices are chosen as if the clip were resampled to ``target_frames``
    frames; only the first ``num_frames`` of those are written, named
    ``00000.jpeg``, ``00001.jpeg``, ...

    Args:
        video_file: path of the input video.
        target_frames: length of the notional resampled clip.
        num_frames: how many frames to actually write.
        session_id_dir: parent directory for the frames; a temporary directory
            is created when omitted.

    Returns:
        ``(video_frames_dir, first_frame_path)``.

    Raises:
        RuntimeError: if the video cannot be opened, a frame cannot be written,
            or no frame at all could be extracted.
    """
    if session_id_dir is None:
        # The default used to crash in os.path.join(None, ...).
        session_id_dir = tempfile.mkdtemp(prefix="ctrl_crash_")
    os.makedirs(session_id_dir, exist_ok=True)
    # mkdtemp guarantees a fresh directory, so stale frames from another run
    # can never be picked up by the detector.
    video_frames_dir = tempfile.mkdtemp(prefix="video_frames_", dir=session_id_dir)

    first_frame_path = None
    cap = cv2.VideoCapture(video_file)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video: {video_file}")

        orig_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        final_frames = int(target_frames)

        if final_frames >= orig_frame_count and orig_frame_count > 0:
            frame_indices = list(range(orig_frame_count))
            # pad by repeating the last index if needed
            frame_indices += [frame_indices[-1]] * (final_frames - orig_frame_count)
            frame_indices = frame_indices[:final_frames]
        else:
            # Uniform sampling to exactly final_frames
            interval = max(orig_frame_count / max(final_frames, 1), 1)
            last_index = max(orig_frame_count - 1, 0)
            frame_indices = [min(int(i * interval), last_index) for i in range(final_frames)]

        frame_indices = frame_indices[:max(int(num_frames), 0)]

        for i, frame_idx in enumerate(frame_indices):
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if not ret or frame is None:
                # fallback: write a duplicate of the first good frame as a
                # placeholder; skip the slot when there is none yet
                if first_frame_path is None:
                    continue
                frame = cv2.imread(first_frame_path)
                if frame is None:
                    continue

            output_path = os.path.join(video_frames_dir, f"{i:05d}.jpeg")
            if not cv2.imwrite(output_path, frame):
                raise RuntimeError(f"Failed to write {output_path}")
            if first_frame_path is None:
                first_frame_path = output_path
    finally:
        # Release the capture on every path, including the error ones.
        cap.release()

    if first_frame_path is None:
        raise RuntimeError("No frames could be extracted from the video.")
    return video_frames_dir, first_frame_path


def _new_run_dir(session_id: Optional[str]) -> str:
    """Create and return a fresh per-run directory under the session directory."""
    session_dir = os.path.join(
        os.environ["PROCESSED_RESULTS"], session_id or uuid.uuid4().hex
    )
    os.makedirs(session_dir, exist_ok=True)
    return tempfile.mkdtemp(prefix="run_", dir=session_dir)


def _action_id_for(action_label: str) -> int:
    """Map a dropdown label to its action id, raising ``gr.Error`` if unknown."""
    if action_label not in ACTION_TYPES:
        raise gr.Error("Invalid action type selection.")
    return ACTION_TYPES[action_label]


def _boxes_as_keyframes(detections) -> Dict[int, List[Tuple[float, float, float, float]]]:
    """Convert detector output to ``{int frame: [(x1, y1, x2, y2), ...]}``."""
    return {
        int(frame): [tuple(float(v) for v in box) for box in boxes]
        for frame, boxes in detections.items()
    }


def _generate_video(first_frame, keyframes, action_id, run_dir, guidance_min, guidance_max) -> str:
    """Run the model for one request and make sure the MP4 was produced."""
    mp4_path = run_one(
        init_image_path=first_frame,
        keyframes_boxes=keyframes,
        action_id=action_id,
        out_dir=os.path.join(run_dir, "video_out"),
        guidance=(guidance_min, guidance_max),
    )
    if not os.path.exists(mp4_path):
        raise gr.Error("Video file was not created. Check model path and logs.")
    return mp4_path


@spaces.GPU(duration=130)
def ctrl_generate_from_video(
    vid_path,
    action_label: str,
    num_frames: int = 3,
    guidance_min: float = 1.0,
    guidance_max: float = 3.0,
    session_id=None,
):
    """Generate a crash video conditioned on boxes detected in a dashcam video.

    Args:
        vid_path: path of the uploaded video.
        action_label: one of the keys of ``ACTION_TYPES``.
        num_frames: number of leading frames used for box guidance.
        guidance_min, guidance_max: guidance range forwarded to the pipeline.
        session_id: groups the artefacts of one browser session so that
            ``cleanup`` can delete them; a random id is used when omitted.

    Returns:
        ``(mp4_path, gif_path)``: the generated video and a GIF preview of the
        detected boxes.

    Raises:
        gr.Error: on missing input, unknown action label or missing output.
    """
    # Validate cheap inputs first, before touching disk or running detection.
    if vid_path is None:
        raise gr.Error("Please upload a video.")
    action_id = _action_id_for(action_label)

    run_dir = _new_run_dir(session_id)
    img_dir, first_frame = extract_frames(
        vid_path, num_frames=num_frames, session_id_dir=run_dir
    )

    keyframes_out = run_yolo_sam_on_image(img_dir, sam2_ckpt=sam_2_ckpt)

    gif_path = os.path.join(run_dir, "boxes.gif")
    save_boxes_gif(img_dir, keyframes_out, gif_path, size=(W, H), fps=6)

    mp4_path = _generate_video(
        first_frame,
        _boxes_as_keyframes(keyframes_out),
        action_id,
        run_dir,
        guidance_min,
        guidance_max,
    )
    return mp4_path, gif_path


@spaces.GPU(duration=130)
def ctrl_generate_from_image(
    img_path,
    action_label: str,
    num_frames: int = 3,
    guidance_min: float = 1.0,
    guidance_max: float = 3.0,
    session_id=None,
):
    """Generate a crash video from a single dashcam image.

    The image is duplicated into ``num_frames`` guidance frames (clamped to
    ``1..CLIP_LENGTH``) so the rest of the pipeline matches the video path.

    Returns:
        Path of the generated MP4.

    Raises:
        gr.Error: on missing input, unknown action label or missing output.
    """
    if img_path is None:
        raise gr.Error("Please upload an Image.")
    action_id = _action_id_for(action_label)

    run_dir = _new_run_dir(session_id)
    img_dir = tempfile.mkdtemp(prefix="video_frames_", dir=run_dir)

    # JPEG cannot store alpha/palette modes, so normalise to RGB before saving.
    with Image.open(img_path) as raw_image:
        rgb_image = raw_image.convert("RGB")

    # One distinct, 0-indexed file per guidance frame (same naming scheme as
    # `extract_frames`). The previous code wrote "00001.jpeg" three times.
    n_copies = min(max(int(num_frames), 1), CLIP_LENGTH)
    frame_paths = [os.path.join(img_dir, f"{i:05d}.jpeg") for i in range(n_copies)]
    for frame_path in frame_paths:
        rgb_image.save(frame_path)

    keyframes_out = run_yolo_sam_on_image(img_dir, sam2_ckpt=sam_2_ckpt)

    return _generate_video(
        frame_paths[0],
        _boxes_as_keyframes(keyframes_out),
        action_id,
        run_dir,
        guidance_min,
        guidance_max,
    )


css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
/* editable vs locked, reusing theme variables that adapt to dark/light */
.stateful textarea:not(:disabled):not([readonly]) {
    color: var(--color-text) !important; /* accent in both modes */
}
.stateful textarea:disabled,
.stateful textarea[readonly]{
    color: var(--body-text-color-subdued) !important; /* subdued in both modes */
}
"""


def cleanup(request: gr.Request):
    """Delete everything that was written for the session that just ended."""
    sid = request.session_hash
    if sid:
        d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
        shutil.rmtree(d1, ignore_errors=True)


def start_session(request: gr.Request):
    """Return the session hash so it can be stored in the page's ``gr.State``."""
    return request.session_hash


def _run_video_click(vid_path, action_label, num_frames, session_id=None):
    """Button handler: forward the session hash so `cleanup` finds the files."""
    return ctrl_generate_from_video(
        vid_path, action_label, num_frames, session_id=session_id
    )


def _run_image_click(img_path, action_label, num_frames, session_id=None):
    """Button handler: forward the session hash so `cleanup` finds the files."""
    return ctrl_generate_from_image(
        img_path, action_label, num_frames, session_id=session_id
    )


def save_boxes_gif(
    img_dir: str,
    per_frame_boxes: Dict[int, List[Tuple[float, float, float, float]]],
    out_path: str,
    size: Tuple[int, int] = (W, H),
    fps: int = 6,
    box_color: Tuple[int, int, int] = (255, 0, 0),
    thickness: int = 3,
):
    """
    Draw normalized boxes on the extracted frames and save as an animated GIF.

    Boxes are looked up by the frame number parsed from each file name, which
    is the same key `run_yolo_sam_on_image` uses, so they stay aligned even if
    some frame numbers are missing from ``img_dir``.

    Returns:
        ``out_path``.
    """
    W_, H_ = size
    frames: List[Image.Image] = []

    # sort frames by numeric stem (00000, 00001, ...)
    fpaths = sorted(
        Path(img_dir).glob("*.*"),
        key=lambda p: _frame_index_from_name(p.name),
    )

    for p in fpaths:
        try:
            with Image.open(str(p)) as raw:
                im = resize_pad_to(raw, W_, H_)
        except (OSError, ValueError):
            # Best effort: skip files that are not readable images.
            continue

        draw = ImageDraw.Draw(im)
        for b in per_frame_boxes.get(_frame_index_from_name(p.name), []):
            _draw_box_outline(draw, b, size, box_color, thickness)
        frames.append(im)

    if not frames:
        # fallback: single black frame
        frames = [Image.new("RGB", size, (0, 0, 0))]

    # duration is per-frame ms
    duration_ms = int(1000 / max(fps, 1))
    frames[0].save(
        out_path,
        save_all=True,
        append_images=frames[1:],
        format="GIF",
        duration=duration_ms,
        loop=0,
        optimize=False,
        disposal=2,
    )
    return out_path


# NOTE(review): widget nesting below was reconstructed from a flattened source;
# confirm the layout (in particular where the dropdown and the GIF preview sit).
with gr.Blocks(css=css) as demo:

    session_state = gr.State()
    demo.load(start_session, outputs=[session_state])

    with gr.Column(elem_id="col-container"):
        gr.HTML(
            """

Ctrl-Crash – Controllable Diffusion for Realistic Car Crashes

[Project]

Improving traffic safety requires realistic and controllable accident simulations.

To tackle the problem, Ctrl-Crash, a controllable car crash video generation model that conditions on signals such as bounding boxes, crash types, and an initial image frame

HF Space by: GitHub Repo
"""
        )
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Tab("Video"):
                    video_in = gr.Video(label="Dashcam Footage")
                    bbx_frames = gr.Slider(1, 25, value=3, step=1, label="Guidance Frames")
                    gr.Text(value="⌚ Zero GPU Required: ~130s (2.1 mins)", show_label=False)
                    run_video_btn = gr.Button("🚙💥🚗 Bish Bash Bosh", variant="primary")

                with gr.Tab("Image"):
                    image_in = gr.Image(label="Dashcam Image", type="filepath")
                    gr.Text(value="⌚ Zero GPU Required: ~130s (2.1 mins)", show_label=False)
                    run_image_btn = gr.Button("🚙💥🚗 Bish Bash Bosh", variant="primary")

                action_dropdown = gr.Dropdown(
                    label="Crash Type",
                    choices=list(ACTION_TYPES.keys()),
                    value="Ego-and-Vehicle crash",
                )
                boxes_gif = gr.Image(label="Detected Boxes (GIF)", type="filepath", interactive=False)

            with gr.Column(scale=1):
                video_out = gr.Video(label="Output video")

        # Examples call the GPU function directly (no session id), so their
        # artefacts land in a random directory, as before.
        cached_examples = gr.Examples(
            examples=[
                [
                    "examples/gt_video_4.mp4",
                    "Ego-and-Vehicle crash",
                    25,
                ],
                [
                    "examples/gt_video_1.mp4",
                    "Ego-and-Vehicle crash",
                    22,
                ],
                [
                    "examples/gt_video_3.mp4",
                    "Ego-and-Vehicle crash",
                    25,
                ],
                [
                    "examples/gt_video_1.mp4",
                    "No Crash",
                    9,
                ],
                [
                    "examples/gt_video_4.mp4",
                    "Ego-Only crash",
                    3,
                ],
            ],
            label="Cached Examples",
            inputs=[video_in, action_dropdown, bbx_frames],
            outputs=[video_out, boxes_gif],
            fn=ctrl_generate_from_video,
            cache_examples=True,
        )

    # NOTE(review): both listeners request api_name="generate" (unchanged from
    # the original); confirm how the installed Gradio version resolves the clash.
    run_video_btn.click(
        fn=_run_video_click,
        inputs=[
            video_in,
            action_dropdown,
            bbx_frames,
            session_state,
        ],
        outputs=[video_out, boxes_gif],
        api_name="generate",
    )

    run_image_btn.click(
        fn=_run_image_click,
        inputs=[
            image_in,
            action_dropdown,
            bbx_frames,
            session_state,
        ],
        outputs=[video_out],
        api_name="generate",
    )

if __name__ == "__main__":
    demo.unload(cleanup)
    demo.queue()
    demo.launch(share=True, ssr_mode=False)