# import streamlit as st
# from streamlit_webrtc import webrtc_streamer
# import torch
# torch.classes.__path__ = []
import sys
import os
from glob import glob

import gradio as gr
from fastrtc import VideoStreamHandler, WebRTC
from PIL import Image
import numpy as np
from time import time
import cv2

import landmark_detection
from mtcnn_facedetection import detect_faces
from selfie_filter import apply_sunglasses, process_video

radius = 2  # radius in pixels of the drawn landmark dots
filter_img = None  # currently selected overlay image, set by change_filter()


def do_facial_landmark_recognition(
    image: np.ndarray, face_boxes: list[landmark_detection.BoundingBox]
):
    # Crop the detected faces, predict their landmarks, and draw each landmark.
    faces = landmark_detection.get_faces(image, face_boxes)
    landmarks_batch = landmark_detection.get_landmarks(faces)
    for landmarks in landmarks_batch:
        for landmark in landmarks:
            image = cv2.circle(image, landmark, radius, (255, 0, 0), -1)
    return image, landmarks_batch


def do_facial_landmark_recognition_with_mtcnn(image: np.ndarray):
    face_boxes = detect_faces(image)
    return do_facial_landmark_recognition(image, face_boxes)


def video_frame_callback_gradio(frame: np.ndarray):
    # Mirror the frame so the stream behaves like a selfie camera.
    flipped = cv2.flip(frame, 1)
    flipped, landmarks_batch = do_facial_landmark_recognition_with_mtcnn(flipped)
    # Apply sunglasses filter
    image = apply_sunglasses(flipped, landmarks_batch, filter_img)
    return image  # , AdditionalOutputs(flipped, flipped)


css = """.my-group {max-width: 600px !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""

image_extensions = [
    "*.jpg",
    "*.jpeg",
    "*.png",
    "*.gif",
    "*.bmp",
    "*.tiff",
    "*.webp",
]

all_image_files = []
for ext in image_extensions:
    pattern = os.path.join("images", "**", ext)  # '**' for recursive search
    image_files = glob(pattern, recursive=True)
    all_image_files.extend(image_files)
all_image_files.sort()

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_classes=["my-column"]):
        gr.HTML("<h1 style='text-align: center'>Live Filter with FaceXFormer</h1>")
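
    # The dropdown below writes the chosen overlay into the module-level
    # `filter_img` global, which both the live WebRTC callback above and the
    # offline video pipeline at the bottom of the page read.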
    with gr.Group(elem_classes=["my-group"]):
        selected_filter = gr.Dropdown(
            choices=all_image_files,
            label="Choose filter",
            value="images/sunglasses_1.png",
        )

        def change_filter(filter_path):
            global filter_img
            # cv2.imread does not raise on a bad path; it returns None, so
            # check the result explicitly instead of relying on an exception.
            filter_img = cv2.imread(filter_path, cv2.IMREAD_UNCHANGED)
            if filter_img is None:
                raise gr.Error(f"Could not open {filter_path}")

        change_filter(selected_filter.value)
        selected_filter.change(
            change_filter, inputs=[selected_filter], show_progress="full"
        )
    with gr.Group(elem_classes=["my-group"]):
        stream = WebRTC(label="Stream", rtc_configuration=None)
        stream.stream(
            fn=VideoStreamHandler(
                video_frame_callback_gradio, fps=12, skip_frames=True
            ),
            inputs=[stream],
            outputs=[stream],
            time_limit=None,
        )
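
    # Assumption, not part of this repo: `rtc_configuration=None` above is fine
    # for localhost. When serving remotely, fastrtc's WebRTC component takes a
    # standard RTCConfiguration dict instead, e.g. with a public STUN server:
    #
    #     rtc_configuration={"iceServers": [{"urls": "stun:stun.l.google.com:19302"}]}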
    with gr.Group(elem_classes=["my-group"]):
        with gr.Column(elem_classes=["my-column"]):
            gr.HTML(
                "<h2 style='text-align: center'>Or just apply the filter to a video</h2>"
            )
            input_video = gr.Video(sources="upload", include_audio=False)
            output_video = gr.Video(interactive=False, include_audio=False)
            submit = gr.Button(variant="primary")
        with gr.Column(elem_classes=["my-column"]):
            submit.click(
                lambda input_path: process_video(input_path, filter_img),
                inputs=[input_video],
                outputs=[output_video],
                show_progress="full",
            )


def test(times=10):
    """Benchmark the MTCNN + landmark pipeline on a single 512x512 image."""
    image = np.array(Image.open("tmp.jpg").resize((512, 512)))
    # faces = ai.get_faces(image)
    start = time()
    frame_times = [None] * times
    for i in range(times):
        before = time()
        do_facial_landmark_recognition_with_mtcnn(image)
        after = time()
        frame_times[i] = after - before
    end = time()
    print(f"Num Images: {times}")
    print(f"Total time: {end - start}")
    print(f"Max frametime: {max(frame_times)}, FPS: {1 / max(frame_times)}")
    print(f"Min frametime: {min(frame_times)}, FPS: {1 / min(frame_times)}")
    avg_frametime = sum(frame_times) / len(frame_times)
    print(f"Avg frametime: {avg_frametime}, FPS: {1 / avg_frametime}")


if __name__ == "__main__":
    # Report the landmark model's per-parameter and total parameter counts.
    no_params = 0
    for name, param in landmark_detection.model.named_parameters(recurse=True):
        no_params += param.numel()
        print(name, param.numel())
    print(no_params)

    # Passing `--test` runs the benchmark; otherwise launch the Gradio demo.
    if "--test" in sys.argv:
        test()
        sys.exit(0)
    else:
        demo.launch()