HaithemH committed
Commit feb730b · verified · 1 Parent(s): b1e0874

Upload gradio_demo.py with huggingface_hub

Files changed (1)
  1. gradio_demo.py +138 -0
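
Per the commit message, the file was pushed with the huggingface_hub client. A minimal sketch of that kind of upload (the repo id, local path, and repo type below are placeholders, not details recorded in this commit):

import os
from huggingface_hub import HfApi

api = HfApi(token=os.getenv("HF_TOKEN"))  # token read from the environment
api.upload_file(
    path_or_fileobj="gradio_demo.py",   # local file to push (placeholder path)
    path_in_repo="gradio_demo.py",      # destination path inside the repo
    repo_id="<user>/<repo>",            # placeholder; add repo_type="space" if the target is a Space
    commit_message="Upload gradio_demo.py with huggingface_hub",
)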
gradio_demo.py ADDED
@@ -0,0 +1,138 @@
+ import os
+ import cv2
+ import time
+ import torch
+ import gradio as gr
+ import numpy as np
+ from loguru import logger
+ from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor
+
+ # Config
+ CHECKPOINT = "HaithemH/vjepa2-vitl-fpc16-256-ssv2-66K-220cat"
+ TORCH_DTYPE = torch.float16
+ TORCH_DEVICE = "cuda:4"  # Change if needed
+ UPDATE_EVERY_N_FRAMES = 16
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ # Load model & processor (model dtype must match the dtype the inputs are cast to below)
+ model = VJEPA2ForVideoClassification.from_pretrained(CHECKPOINT, torch_dtype=TORCH_DTYPE, token=HF_TOKEN)
+ model = model.to(TORCH_DEVICE)
+ video_processor = AutoVideoProcessor.from_pretrained(CHECKPOINT, token=HF_TOKEN)
+ frames_per_clip = model.config.frames_per_clip
+
+
+ def add_text_on_image(image, text):
+     """Draw `text` as a word-wrapped, centered banner on a blacked-out strip at the top of `image`."""
+     image[:70] = 0
+     line_spacing = 10
+     top_margin = 20
+     font = cv2.FONT_HERSHEY_SIMPLEX
+     font_scale = 0.5
+     thickness = 1
+     color = (255, 255, 255)
+
+     # Word-wrap the text so each line fits inside the image width
+     words = text.split()
+     lines = []
+     current_line = ""
+     img_width = image.shape[1]
+     for word in words:
+         test_line = current_line + (" " if current_line else "") + word
+         (test_width, _), _ = cv2.getTextSize(test_line, font, font_scale, thickness)
+         if test_width > img_width - 20:
+             lines.append(current_line)
+             current_line = word
+         else:
+             current_line = test_line
+     if current_line:
+         lines.append(current_line)
+
+     # Render each line centered horizontally
+     y = top_margin
+     for line in lines:
+         (line_width, line_height), _ = cv2.getTextSize(line, font, font_scale, thickness)
+         x = (img_width - line_width) // 2
+         cv2.putText(image, line, (x, y + line_height), font, font_scale, color, thickness, cv2.LINE_AA)
+         y += line_height + line_spacing
+     return image
+
+
+ class RunningFramesCache:
+     """Rolling buffer of the most recent webcam frames."""
+
+     def __init__(self, max_frames=16):
+         self.max_frames = max_frames
+         self._frames = []
+         self.counter = 0  # total frames seen, used to throttle inference
+
+     def add_frame(self, frame):
+         self.counter += 1
+         self._frames.append(frame)
+         if len(self._frames) > self.max_frames:
+             self._frames.pop(0)
+
+     def get_last_n_frames(self, n):
+         return self._frames[-n:]
+
+     def __len__(self):
+         return len(self._frames)
+
+
+ class RunningResult:
+     """Short history of timestamped predictions for display."""
+
+     def __init__(self, max_predictions=4):
+         self.predictions = []
+         self.max_predictions = max_predictions
+
+     def add_prediction(self, prediction):
+         current_time = time.strftime("%H:%M:%S", time.gmtime(time.time()))
+         self.predictions.append((current_time, prediction))
+         if len(self.predictions) > self.max_predictions:
+             self.predictions.pop(0)
+
+     def get_formatted(self):
+         if not self.predictions:
+             return "Starting..."
+         current, *past = self.predictions[::-1]
+         text = f">>> {current[1]}\n\n" + "\n".join(
+             f"[{time_str}] {pred}" for time_str, pred in past
+         )
+         return text
+
+     def get_last(self):
+         return self.predictions[-1][1] if self.predictions else "Starting..."
+
+
+ # Shared state
+ frames_cache = RunningFramesCache(max_frames=frames_per_clip)
+ results_cache = RunningResult()
+
+
+ def classify_frame(image):
+     """Per-frame Gradio callback: cache the frame, run the model every N frames, annotate and return."""
+     image = cv2.flip(image, 1)  # mirror webcam
+     frames_cache.add_frame(image)
+
+     # Run inference only every UPDATE_EVERY_N_FRAMES frames, once a full clip is available
+     if frames_cache.counter % UPDATE_EVERY_N_FRAMES == 0 and len(frames_cache) >= frames_per_clip:
+         frames = frames_cache.get_last_n_frames(frames_per_clip)
+         frames = np.array(frames)
+         inputs = video_processor(frames, device=TORCH_DEVICE, return_tensors="pt")
+         inputs = inputs.to(dtype=TORCH_DTYPE)
+         with torch.no_grad():
+             logits = model(**inputs).logits
+         top_idx = logits.argmax(dim=-1).item()
+         class_name = model.config.id2label[top_idx]
+         logger.info(f"Predicted: {class_name}")
+         results_cache.add_prediction(class_name)
+
+     annotated_image = add_text_on_image(image.copy(), results_cache.get_last())
+     return annotated_image, results_cache.get_formatted()
+
+
+ # Gradio UI
+ demo = gr.Interface(
+     fn=classify_frame,
+     inputs=gr.Image(sources=["webcam"], streaming=True),
+     outputs=[
+         gr.Image(label="Live Prediction", type="numpy"),
+         gr.TextArea(label="Recent Predictions", lines=10),
+     ],
+     live=True,
+     title="V-JEPA 2: Streaming Video Action Recognition - SSV2",
+     description="This demo showcases a specialized version of V-JEPA 2, fine-tuned for real-time video action recognition!",
+ )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
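
As a quick sanity check of the inference path without a webcam, the same processor and model objects defined above can be fed a dummy clip of random frames (a sketch; the 256x256 frame size is an arbitrary assumption, since the processor resizes to its own configuration):

# Sketch: exercise the model on a dummy clip of random frames, reusing the objects above.
dummy_frames = np.random.randint(0, 256, size=(frames_per_clip, 256, 256, 3), dtype=np.uint8)
inputs = video_processor(dummy_frames, device=TORCH_DEVICE, return_tensors="pt").to(dtype=TORCH_DTYPE)
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(dim=-1).item()])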