tiktokApp / model_processing.py
vktrbr's picture
test
8d5f133
import ffmpegio
import gc
import torch
from transformers import MobileViTImageProcessor, MobileViTForSemanticSegmentation
from config import FPS_DIV, MAX_LENGTH, BATCH_SIZE, MODEL_PATH
class PreprocessModel(torch.nn.Module):
device = 'cpu'
def __init__(self):
super().__init__()
self.feature_extractor = MobileViTImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
self.mobile_vit = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")
self.convs = torch.nn.Sequential(
torch.nn.MaxPool2d(2, 2)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.mobile_vit(x).logits
x = self.convs(x)
return x
def read_video(self, path: str) -> torch.Tensor:
"""
Читает видео и возвращает тензор с фичами
"""
_, video = ffmpegio.video.read(path, t=1.0)
video = video[::FPS_DIV][:MAX_LENGTH]
out_seg_video = []
for i in range(0, video.shape[0], BATCH_SIZE):
frames = [video[j] for j in range(i, min(i + BATCH_SIZE, video.shape[0]))]
frames = self.feature_extractor(images=frames, return_tensors='pt')['pixel_values']
out = self.forward(frames.to(self.device)).detach().to('cpu')
out_seg_video.append(out)
del frames, out
gc.collect()
if self.device == 'cuda':
torch.cuda.empty_cache()
return torch.cat(out_seg_video)
class VideoModel(torch.nn.Module):
def __init__(self):
super().__init__()
p = 0.5
self.pic_cnn = torch.nn.Sequential(
torch.nn.Conv2d(21, 128, (2, 2), stride=2),
torch.nn.BatchNorm2d(128),
torch.nn.LeakyReLU(),
torch.nn.Conv2d(128, 256, (2, 2), stride=2),
torch.nn.BatchNorm2d(256),
torch.nn.Dropout2d(p),
torch.nn.LeakyReLU(),
torch.nn.Conv2d(256, 256, (4, 4), stride=2),
torch.nn.BatchNorm2d(256),
torch.nn.Dropout2d(p),
torch.nn.Flatten()
)
self.vid_cnn = torch.nn.Sequential(
torch.nn.Conv2d(21, 128, (2, 2), stride=2),
torch.nn.BatchNorm2d(128),
torch.nn.Tanh(),
torch.nn.Conv2d(128, 256, (2, 2), stride=2),
torch.nn.BatchNorm2d(256),
torch.nn.Dropout2d(p),
torch.nn.LeakyReLU(),
torch.nn.Conv2d(256, 512, (2, 2), stride=2),
torch.nn.BatchNorm2d(512),
torch.nn.Dropout2d(p),
torch.nn.Flatten()
)
self.lstm = torch.nn.LSTM(2048, 256, 1, batch_first=True, bidirectional=True)
self.fc1 = torch.nn.Linear(256 * 2, 1024)
self.fc_norm = torch.nn.BatchNorm1d(256 * 2)
self.tanh = torch.nn.Tanh()
self.fc2 = torch.nn.Linear(1024, 2)
self.sigmoid = torch.nn.Sigmoid()
self.dropout = torch.nn.Dropout(p)
# xaiver init
for m in self.modules():
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv3d):
torch.nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
torch.nn.init.zeros_(m.bias)
elif isinstance(m, torch.nn.Linear):
torch.nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
torch.nn.init.zeros_(m.bias)
def forward(self, video: torch.Tensor) -> torch.Tensor:
"""
Использует превью как начальное скрытое состояние, а кадры видео как последовательность.
video[0] - превью, video[1] - видео
:param video: torch.Tensor, shape = (batch_size, frames + 1, 1344)
"""
frames = video.shape[0]
video = torch.nn.functional.pad(video, (0, 0, 0, 0, 0, 0, MAX_LENGTH + 1 - frames, 0))
video = video.unsqueeze(0)
_batch_size = video.shape[0]
_preview = video[:, 0, :, :]
_video = video[:, 1:, :, :]
h0 = self.pic_cnn(_preview).unsqueeze(0)
h0 = torch.nn.functional.pad(h0, (0, 0, 0, 0, 0, 1))
c0 = torch.zeros_like(h0)
_video = self.vid_cnn(_video.reshape(-1, 21, 16, 16))
_video = _video.reshape(_batch_size, 90, -1)
context, _ = self.lstm(_video, (h0, c0))
out = self.fc_norm(context[:, -1])
out = self.tanh(self.fc1(out))
out = self.dropout(out)
out = self.sigmoid(self.fc2(out))
return out
# @st.cache_resource
class TikTokAnalytics(torch.nn.Module):
def __init__(self):
super().__init__()
self.preprocessing_model = PreprocessModel()
self.predict_model = torch.load(MODEL_PATH, map_location=self.preprocessing_model.device)
self.preprocessing_model.eval()
self.predict_model.eval()
def forward(self, path: str) -> torch.Tensor:
"""
Вызываем препроцесс, потом предикт
:param path:
:return:
"""
tensor = self.preprocessing_model.read_video(path)
predict = self.predict_model(tensor)
return predict
# if __name__ == '__main__':
# model = TikTokAnalytics()
# model = model(
# '/Users/victorbarbarich/PycharmProjects/nueramic/vktrbr-video-tiktok/data/videos/video-6930454291186502917.mp4')
# print(model)