import os
import time

import cv2
import gradio as gr
import numpy as np
import torch
from loguru import logger
from transformers import AutoVideoProcessor, VJEPA2ForVideoClassification

# Config
CHECKPOINT = "HaithemH/vjepa2-vitl-fpc16-256-ssv2-66K-220cat"
TORCH_DTYPE = torch.float16
TORCH_DEVICE = "cuda:4"  # Change if needed
UPDATE_EVERY_N_FRAMES = 16
HF_TOKEN = os.getenv("HF_TOKEN")  # Optional; only needed for gated/private checkpoints

# Load model & processor (keep the model dtype consistent with TORCH_DTYPE,
# since the inputs are cast to TORCH_DTYPE before inference)
model = VJEPA2ForVideoClassification.from_pretrained(
    CHECKPOINT, torch_dtype=TORCH_DTYPE, token=HF_TOKEN
)
model = model.to(TORCH_DEVICE)
video_processor = AutoVideoProcessor.from_pretrained(CHECKPOINT, token=HF_TOKEN)
frames_per_clip = model.config.frames_per_clip


def add_text_on_image(image, text):
    """Draw `text` centered on a black banner at the top of the frame, wrapping lines to fit."""
    image[:70] = 0  # black banner behind the text
    line_spacing = 10
    top_margin = 20
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 1
    color = (255, 255, 255)

    # Greedy word wrap: extend the current line until it would overflow the frame width
    words = text.split()
    lines = []
    current_line = ""
    img_width = image.shape[1]
    for word in words:
        test_line = current_line + (" " if current_line else "") + word
        (test_width, _), _ = cv2.getTextSize(test_line, font, font_scale, thickness)
        if test_width > img_width - 20:
            lines.append(current_line)
            current_line = word
        else:
            current_line = test_line
    if current_line:
        lines.append(current_line)

    # Render each wrapped line, horizontally centered
    y = top_margin
    for line in lines:
        (line_width, line_height), _ = cv2.getTextSize(line, font, font_scale, thickness)
        x = (img_width - line_width) // 2
        cv2.putText(image, line, (x, y + line_height), font, font_scale, color, thickness, cv2.LINE_AA)
        y += line_height + line_spacing

    return image


class RunningFramesCache:
    """Rolling buffer holding the most recent webcam frames."""

    def __init__(self, max_frames=16):
        self.max_frames = max_frames
        self._frames = []
        self.counter = 0  # total number of frames seen so far

    def add_frame(self, frame):
        self.counter += 1
        self._frames.append(frame)
        if len(self._frames) > self.max_frames:
            self._frames.pop(0)

    def get_last_n_frames(self, n):
        return self._frames[-n:]

    def __len__(self):
        return len(self._frames)


class RunningResult:
    """Keeps the last few predictions with timestamps for display."""

    def __init__(self, max_predictions=4):
        self.predictions = []
        self.max_predictions = max_predictions

    def add_prediction(self, prediction):
        current_time = time.strftime("%H:%M:%S", time.gmtime(time.time()))
        self.predictions.append((current_time, prediction))
        if len(self.predictions) > self.max_predictions:
            self.predictions.pop(0)

    def get_formatted(self):
        if not self.predictions:
            return "Starting..."
        current, *past = self.predictions[::-1]
        text = f">>> {current[1]}\n\n" + "\n".join(
            f"[{time_str}] {pred}" for time_str, pred in past
        )
        return text

    def get_last(self):
        return self.predictions[-1][1] if self.predictions else "Starting..."
# Shared state
frames_cache = RunningFramesCache(max_frames=frames_per_clip)
results_cache = RunningResult()


def classify_frame(image):
    """Buffer the incoming frame and, every UPDATE_EVERY_N_FRAMES frames, classify the latest clip."""
    image = cv2.flip(image, 1)  # mirror webcam
    frames_cache.add_frame(image)

    # Run the model only every N-th frame, and only once a full clip has been buffered
    if frames_cache.counter % UPDATE_EVERY_N_FRAMES == 0 and len(frames_cache) >= frames_per_clip:
        frames = frames_cache.get_last_n_frames(frames_per_clip)
        frames = np.array(frames)
        inputs = video_processor(frames, device=TORCH_DEVICE, return_tensors="pt")
        inputs = inputs.to(dtype=TORCH_DTYPE)
        with torch.no_grad():
            logits = model(**inputs).logits
        top_idx = logits.argmax(dim=-1).item()
        class_name = model.config.id2label[top_idx]
        logger.info(f"Predicted: {class_name}")
        results_cache.add_prediction(class_name)

    annotated_image = add_text_on_image(image.copy(), results_cache.get_last())
    return annotated_image, results_cache.get_formatted()


# Gradio UI
demo = gr.Interface(
    fn=classify_frame,
    inputs=gr.Image(sources=["webcam"], streaming=True),
    outputs=[
        gr.Image(label="Live Prediction", type="numpy"),
        gr.TextArea(label="Recent Predictions", lines=10),
    ],
    live=True,
    title="V-JEPA 2: Streaming Video Action Recognition - SSV2",
    description=(
        "This demo showcases a specialized version of V-JEPA 2, fine-tuned on "
        "Something-Something-v2 for real-time action recognition from a webcam stream."
    ),
)

if __name__ == "__main__":
    demo.launch(share=True)
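
# How to run: a minimal sketch, assuming this script is saved as app.py and that a CUDA
# device matching TORCH_DEVICE is available (adjust TORCH_DEVICE above otherwise):
#
#   pip install torch transformers gradio opencv-python numpy loguru
#   python app.py
#
# With share=True, Gradio prints a temporary public URL in addition to the local one.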