import os
import time

import cv2
import gradio as gr
import numpy as np
import torch
from loguru import logger
from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor
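
# Runtime configuration. "cuda:4" assumes a multi-GPU machine; point TORCH_DEVICE
# at whichever device you have (e.g. "cuda:0" or "cpu").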
CHECKPOINT = "HaithemH/vjepa2-vitl-fpc16-256-ssv2-66K-220cat"
TORCH_DTYPE = torch.float16
TORCH_DEVICE = "cuda:4"
UPDATE_EVERY_N_FRAMES = 16  # run the classifier once every N incoming frames
HF_TOKEN = os.getenv("HF_TOKEN")  # optional; only needed for gated or private checkpoints
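
# Load the fine-tuned classifier and its matching video processor once at startup.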
model = VJEPA2ForVideoClassification.from_pretrained(
    CHECKPOINT, torch_dtype=TORCH_DTYPE, token=HF_TOKEN
)
model = model.to(TORCH_DEVICE)
model.eval()
video_processor = AutoVideoProcessor.from_pretrained(CHECKPOINT, token=HF_TOKEN)
frames_per_clip = model.config.frames_per_clip
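

# Overlay helper: draws the latest prediction on a black banner across the top
# of the frame, word-wrapping the text so long labels fit the frame width.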
def add_text_on_image(image, text):
    image[:70] = 0  # black banner behind the text
    line_spacing = 10
    top_margin = 20
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 1
    color = (255, 255, 255)

    # Greedy word-wrap: extend the current line until it would overflow the frame.
    words = text.split()
    lines = []
    current_line = ""
    img_width = image.shape[1]
    for word in words:
        test_line = current_line + (" " if current_line else "") + word
        (test_width, _), _ = cv2.getTextSize(test_line, font, font_scale, thickness)
        if test_width > img_width - 20:
            lines.append(current_line)
            current_line = word
        else:
            current_line = test_line
    if current_line:
        lines.append(current_line)

    # Render each wrapped line horizontally centered.
    y = top_margin
    for line in lines:
        (line_width, line_height), _ = cv2.getTextSize(line, font, font_scale, thickness)
        x = (img_width - line_width) // 2
        cv2.putText(image, line, (x, y + line_height), font, font_scale, color, thickness, cv2.LINE_AA)
        y += line_height + line_spacing
    return image
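

# Fixed-size FIFO of the most recent webcam frames; `counter` tracks how many
# frames have ever arrived so inference can be throttled.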
class RunningFramesCache:
    def __init__(self, max_frames=16):
        self.max_frames = max_frames
        self._frames = []
        self.counter = 0

    def add_frame(self, frame):
        self.counter += 1
        self._frames.append(frame)
        if len(self._frames) > self.max_frames:
            self._frames.pop(0)

    def get_last_n_frames(self, n):
        return self._frames[-n:]

    def __len__(self):
        return len(self._frames)
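

# Rolling log of the last few timestamped predictions shown in the text panel.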
class RunningResult:
    def __init__(self, max_predictions=4):
        self.predictions = []
        self.max_predictions = max_predictions

    def add_prediction(self, prediction):
        current_time = time.strftime("%H:%M:%S", time.gmtime())
        self.predictions.append((current_time, prediction))
        if len(self.predictions) > self.max_predictions:
            self.predictions.pop(0)

    def get_formatted(self):
        # Show the newest prediction highlighted, followed by the older ones.
        if not self.predictions:
            return "Starting..."
        current, *past = self.predictions[::-1]
        text = f">>> {current[1]}\n\n" + "\n".join(
            f"[{time_str}] {pred}" for time_str, pred in past
        )
        return text

    def get_last(self):
        return self.predictions[-1][1] if self.predictions else "Starting..."
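

# Module-level caches shared by every invocation of the streaming callback.
# Fine for a single-user demo; concurrent sessions would share this state.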
frames_cache = RunningFramesCache(max_frames=frames_per_clip)
results_cache = RunningResult()
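

# Streaming callback: Gradio calls this once per incoming webcam frame. Frames
# are buffered continuously, but the model only runs every UPDATE_EVERY_N_FRAMES
# frames, and only once a full clip of frames_per_clip frames is available.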
def classify_frame(image):
    image = cv2.flip(image, 1)  # mirror the frame for a natural selfie view
    frames_cache.add_frame(image)

    if frames_cache.counter % UPDATE_EVERY_N_FRAMES == 0 and len(frames_cache) >= frames_per_clip:
        frames = frames_cache.get_last_n_frames(frames_per_clip)
        frames = np.array(frames)
        inputs = video_processor(frames, device=TORCH_DEVICE, return_tensors="pt")
        inputs = inputs.to(dtype=TORCH_DTYPE)  # match the model weights' dtype
        with torch.no_grad():
            logits = model(**inputs).logits
        top_idx = logits.argmax(dim=-1).item()
        class_name = model.config.id2label[top_idx]
        logger.info(f"Predicted: {class_name}")
        results_cache.add_prediction(class_name)

    # Always return the most recent prediction drawn on the current frame.
    annotated_image = add_text_on_image(image.copy(), results_cache.get_last())
    return annotated_image, results_cache.get_formatted()
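

# A possible extension (sketch only, not wired into the demo): log the top-5
# classes with probabilities instead of just the argmax, e.g.:
#
#     probs = logits.softmax(dim=-1)[0]
#     top5 = probs.topk(5)
#     for p, idx in zip(top5.values.tolist(), top5.indices.tolist()):
#         logger.info(f"{model.config.id2label[idx]}: {p:.2f}")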


# Live interface: streaming webcam input; annotated frame and rolling
# prediction log as outputs.
demo = gr.Interface(
    fn=classify_frame,
    inputs=gr.Image(sources=["webcam"], streaming=True),
    outputs=[
        gr.Image(label="Live Prediction", type="numpy"),
        gr.TextArea(label="Recent Predictions", lines=10),
    ],
    live=True,
    title="V-JEPA 2: Streaming Video Action Recognition - SSV2",
    description="This demo showcases a version of V-JEPA 2 fine-tuned for real-time video action recognition on Something-Something V2.",
)

if __name__ == "__main__":
    demo.launch(share=True)  # also creates a temporary public share link