import gc

import ffmpegio
import torch
from transformers import MobileViTImageProcessor, MobileViTForSemanticSegmentation

from config import FPS_DIV, MAX_LENGTH, BATCH_SIZE, MODEL_PATH
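# A minimal sketch of the expected config.py (all values below are assumptions;
# MAX_LENGTH = 90 matches the hard-coded sequence length the original
# VideoModel.forward reshaped to):
# FPS_DIV = 2                 # keep every FPS_DIV-th frame
# MAX_LENGTH = 90             # maximum number of frames fed to the LSTM
# BATCH_SIZE = 8              # frames per MobileViT forward pass
# MODEL_PATH = 'model.pth'    # hypothetical path to the pickled VideoModel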
class PreprocessModel(torch.nn.Module):
    device = 'cpu'

    def __init__(self):
        super().__init__()
        self.feature_extractor = MobileViTImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
        self.mobile_vit = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")
        # Halve the spatial resolution of the segmentation logits
        self.convs = torch.nn.Sequential(
            torch.nn.MaxPool2d(2, 2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mobile_vit(x).logits  # per-pixel class logits, 21 channels
        x = self.convs(x)
        return x
    def read_video(self, path: str) -> torch.Tensor:
        """
        Reads a video and returns a tensor of segmentation features.
        """
        _, video = ffmpegio.video.read(path, t=1.0)
        video = video[::FPS_DIV][:MAX_LENGTH]  # subsample frames and cap the length
        out_seg_video = []
        for i in range(0, video.shape[0], BATCH_SIZE):
            frames = [video[j] for j in range(i, min(i + BATCH_SIZE, video.shape[0]))]
            frames = self.feature_extractor(images=frames, return_tensors='pt')['pixel_values']
            out = self.forward(frames.to(self.device)).detach().to('cpu')
            out_seg_video.append(out)
            # Release per-batch buffers to keep memory usage bounded
            del frames, out
            gc.collect()
            if self.device == 'cuda':
                torch.cuda.empty_cache()
        return torch.cat(out_seg_video)
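# Usage sketch (hypothetical path; shapes assume the processor's default
# 512x512 input, which yields 32x32 logits and 16x16 pooled features):
# >>> pre = PreprocessModel()
# >>> feats = pre.read_video('video.mp4')  # -> (frames, 21, 16, 16)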
class VideoModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        p = 0.5
        # Encodes the preview frame into the initial LSTM hidden state (256-dim)
        self.pic_cnn = torch.nn.Sequential(
            torch.nn.Conv2d(21, 128, (2, 2), stride=2),
            torch.nn.BatchNorm2d(128),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(128, 256, (2, 2), stride=2),
            torch.nn.BatchNorm2d(256),
            torch.nn.Dropout2d(p),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(256, 256, (4, 4), stride=2),
            torch.nn.BatchNorm2d(256),
            torch.nn.Dropout2d(p),
            torch.nn.Flatten()
        )
        # Encodes each 21x16x16 frame into a 2048-dim feature for the LSTM
        self.vid_cnn = torch.nn.Sequential(
            torch.nn.Conv2d(21, 128, (2, 2), stride=2),
            torch.nn.BatchNorm2d(128),
            torch.nn.Tanh(),
            torch.nn.Conv2d(128, 256, (2, 2), stride=2),
            torch.nn.BatchNorm2d(256),
            torch.nn.Dropout2d(p),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(256, 512, (2, 2), stride=2),
            torch.nn.BatchNorm2d(512),
            torch.nn.Dropout2d(p),
            torch.nn.Flatten()
        )
        self.lstm = torch.nn.LSTM(2048, 256, 1, batch_first=True, bidirectional=True)
        self.fc1 = torch.nn.Linear(256 * 2, 1024)
        self.fc_norm = torch.nn.BatchNorm1d(256 * 2)
        self.tanh = torch.nn.Tanh()
        self.fc2 = torch.nn.Linear(1024, 2)
        self.sigmoid = torch.nn.Sigmoid()
        self.dropout = torch.nn.Dropout(p)

        # Xavier init for conv and linear layers
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.Conv3d)):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
            elif isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
    def forward(self, video: torch.Tensor) -> torch.Tensor:
        """
        Uses the preview as the LSTM's initial hidden state and the video
        frames as the sequence: video[0] is the preview, video[1:] are frames.

        :param video: torch.Tensor, shape = (frames + 1, 21, 16, 16)
        """
        frames = video.shape[0]
        # Left-pad the frame axis to MAX_LENGTH + 1 entries; note the padding
        # is prepended, so with fewer frames the preview slot at index 0
        # holds a zero frame
        video = torch.nn.functional.pad(video, (0, 0, 0, 0, 0, 0, MAX_LENGTH + 1 - frames, 0))
        video = video.unsqueeze(0)
        _batch_size = video.shape[0]
        _preview = video[:, 0, :, :]
        _video = video[:, 1:, :, :]
        h0 = self.pic_cnn(_preview).unsqueeze(0)
        # The bidirectional LSTM expects h0 of shape (2, batch, 256); zero-pad
        # the backward direction
        h0 = torch.nn.functional.pad(h0, (0, 0, 0, 0, 0, 1))
        c0 = torch.zeros_like(h0)
        _video = self.vid_cnn(_video.reshape(-1, 21, 16, 16))
        _video = _video.reshape(_batch_size, MAX_LENGTH, -1)  # was a hard-coded 90
        context, _ = self.lstm(_video, (h0, c0))
        out = self.fc_norm(context[:, -1])
        out = self.tanh(self.fc1(out))
        out = self.dropout(out)
        out = self.sigmoid(self.fc2(out))
        return out
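# Smoke-test sketch (assumes MAX_LENGTH = 90; the random input mimics the
# output of PreprocessModel.read_video for a 42-frame clip):
# >>> model = VideoModel().eval()
# >>> with torch.no_grad():
# ...     probs = model(torch.randn(42, 21, 16, 16))  # -> (1, 2)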
# @st.cache_resource
class TikTokAnalytics(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.preprocessing_model = PreprocessModel()
        self.predict_model = torch.load(MODEL_PATH, map_location=self.preprocessing_model.device)
        self.preprocessing_model.eval()
        self.predict_model.eval()

    def forward(self, path: str) -> torch.Tensor:
        """
        Runs preprocessing, then prediction.

        :param path: path to the video file
        :return: prediction tensor
        """
        tensor = self.preprocessing_model.read_video(path)
        predict = self.predict_model(tensor)
        return predict
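# Note: the forward pass above runs with autograd enabled; at inference time
# callers may want to disable it (a sketch, with a hypothetical path):
# >>> analytics = TikTokAnalytics()
# >>> with torch.no_grad():
# ...     probs = analytics('video.mp4')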
# if __name__ == '__main__':
#     model = TikTokAnalytics()
#     model = model(
#         '/Users/victorbarbarich/PycharmProjects/nueramic/vktrbr-video-tiktok/data/videos/video-6930454291186502917.mp4')
#     print(model)