shanty2301's picture
Update app.py
5998e32 verified
raw
history blame
5.99 kB
import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageFilter, ImageOps
from torchvision import transforms
from transformers import (
    AutoModelForImageSegmentation,
    DepthProForDepthEstimation,
    DepthProImageProcessorFast,
)
# Use CUDA when it is visible, otherwise fall back to the CPU. The hosting
# Space has no GPU attached (see the warning below), so in practice this
# resolves to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_icon="📸", page_title="Blur App")
st.title("Image blur using segmentation and depth estimation.")
st.subheader("Upload your image and choose a blur style", divider=True)
st.warning("**Note**: The lens blur option takes a long time to process (>5min) since this space isn't linked to a GPU.")
@st.cache_resource(show_spinner="Pushing pixels...")
def load_gblur_model():
    """Load the BiRefNet segmentation model used by the Gaussian blur mode.

    The model is moved to ``device`` and put in inference mode. Because the
    function is wrapped in ``st.cache_resource``, script reruns reuse the same
    instance instead of reloading it.

    Note: ``trust_remote_code=True`` executes modelling code shipped with the
    ``ZhengPeng7/BiRefNet`` repository.
    """
    segmenter = AutoModelForImageSegmentation.from_pretrained(
        'ZhengPeng7/BiRefNet', trust_remote_code=True
    )
    # Module.to() and Module.eval() both return the module itself.
    return segmenter.to(device).eval()
@st.cache_resource(show_spinner="Running with scissors...")
def load_lblur_model():
    """Load the DepthPro depth-estimation model and its image processor.

    Used by the lens blur mode. Cached with ``st.cache_resource`` so reruns
    reuse the loaded objects.

    Returns:
        ``(model, image_processor)``. The model has been moved to ``device``.
    """
    checkpoint = "apple/DepthPro-hf"
    processor = DepthProImageProcessorFast.from_pretrained(checkpoint)
    depth_model = DepthProForDepthEstimation.from_pretrained(checkpoint).to(device)
    return depth_model, processor
# Resolve both cached models when the script starts. Thanks to
# st.cache_resource, later reruns get the same instances back immediately.
gblur_model = load_gblur_model()
(lblur_model, lblur_img_proc) = load_lblur_model()
def gaussian_blur(image, blur_str, image_size=(512, 512)):
    """Blur everything except the segmented subject of *image*.

    BiRefNet produces a per-pixel map (sigmoid of its last output). High
    values mark the region to keep sharp. That map is used as a soft alpha
    matte to composite the original image over a Gaussian-blurred copy.

    Args:
        image: PIL image to process (the app passes RGB images).
        blur_str: Gaussian blur radius applied to the non-subject region.
        image_size: ``(height, width)`` the image is resized to before
            segmentation. Defaults to the previously hard-coded 512x512.

    Returns:
        A new PIL image with the same size and mode as *image*.
    """
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # ImageNet mean / std normalisation expected by the model.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_image = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # The model returns a sequence of predictions; the last one is used.
        preds = gblur_model(input_image)[-1].sigmoid().cpu()

    # Convert the [0, 1] map to an 8-bit "L" image at the original resolution.
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(image.size)
    blurred = image.filter(ImageFilter.GaussianBlur(radius=blur_str))

    # Soft alpha blend. The old np.where(mask, ...) treated every non-zero mask
    # value as fully sharp, which produced hard, haloed edges around the subject.
    return Image.composite(image, blurred, mask)
def _estimate_depth(image):
    """Run DepthPro on *image* and return a float32 depth map of shape (H, W) matching the image."""
    inputs = lblur_img_proc(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = lblur_model(**inputs)
    post_processed = lblur_img_proc.post_process_depth_estimation(
        outputs, target_sizes=[(image.height, image.width)],
    )
    depth = post_processed[0]["predicted_depth"]
    return depth.detach().float().cpu().numpy()


def _blend_blur_levels(image, depth, max_radius, num_levels):
    """Blend progressively blurred copies of *image*, choosing per pixel by normalised *depth*."""
    # Min-max normalise the depth map to [0, 1]. A flat (or NaN) map has no
    # usable range. The old code divided by zero here, and the resulting NaNs
    # crashed the integer indexing below. Such maps now fall back to level 0
    # (no blur).
    d_min = float(depth.min())
    d_max = float(depth.max())
    span = d_max - d_min
    if span > 0:
        depth = (depth - d_min) / span
    else:
        depth = np.zeros_like(depth)

    # Pre-compute every blur level: radius 0 (the original pixels) up to max_radius.
    levels = []
    for i in range(num_levels):
        radius = (i / (num_levels - 1)) * max_radius
        if radius < 0.1:
            levels.append(np.array(image))
        else:
            levels.append(np.array(image.filter(ImageFilter.GaussianBlur(radius))))
    stack = np.stack(levels, axis=0)

    # Linear interpolation between the two blur levels bracketing each pixel's depth.
    rows, cols = np.indices(depth.shape)
    depth_levels = depth * (num_levels - 1)
    low = np.floor(depth_levels).astype(int)
    high = np.minimum(low + 1, num_levels - 1)
    alpha = (depth_levels - low)[..., np.newaxis]
    pixel_low = stack[low, rows, cols, :].astype(np.float32)
    pixel_high = stack[high, rows, cols, :].astype(np.float32)
    output = (1 - alpha) * pixel_low + alpha * pixel_high
    return Image.fromarray(np.clip(output, 0, 255).astype(np.uint8))


def lens_blur(image, blur_str, num_levels=15):
    """Apply a depth-dependent blur to *image*.

    DepthPro estimates a depth map. Pixels with larger predicted depth values
    receive stronger blur, up to ``blur_str`` at the maximum depth. DepthPro
    is documented as predicting metric depth, which presumably means "farther
    is blurrier" — confirm against the model card.

    Args:
        image: PIL image with a channel axis (the app passes RGB).
        blur_str: maximum Gaussian blur radius.
        num_levels: number of discrete blur levels to interpolate between.
            Must be at least 2. Defaults to the previously hard-coded 15.

    Returns:
        A new PIL image the same size as *image*.

    Raises:
        ValueError: if ``num_levels`` is less than 2.
    """
    if num_levels < 2:
        raise ValueError(f"num_levels must be >= 2, got {num_levels}")
    depth = _estimate_depth(image)
    return _blend_blur_levels(image, depth, blur_str, num_levels)
def apply_func(image, option, blur_str=20):
    """Show *image* next to a blurred version of it in two Streamlit columns.

    Args:
        image: PIL image to display and blur.
        option: ``"Gaussian"`` or ``"Gaussian Blur"`` selects the
            segmentation-based blur. Any other value selects the depth-based
            lens blur.
        blur_str: blur strength forwarded to the blur function. The default
            of 20 matches the "Blur Strength" slider's default. Previously
            this argument was never passed, so every call raised TypeError.
    """
    left, right = st.columns(2)
    with left:
        st.image(image, caption="Original Image", use_container_width=True)
    with right:
        # Accept the UI's "Gaussian" label as well as the legacy "Gaussian Blur".
        if option in ("Gaussian", "Gaussian Blur"):
            with st.spinner("Spinning violently around the y-axis..."):
                result = gaussian_blur(image, blur_str)
        else:
            with st.spinner("One mississippi, two mississippi..."):
                result = lens_blur(image, blur_str)
        st.image(result, caption=f"{option} Image", use_container_width=True)
# Upload -> choose blur settings -> show the original and blurred images side by side.
st.header("Upload your image")
up_img = st.file_uploader("Upload image", type=["jpg", "jpeg", "png", "dng", "tiff"])
st.divider()
if up_img:
    try:
        # Honour the EXIF orientation tag so phone photos aren't processed
        # and shown sideways, then normalise to RGB for the blur functions.
        image = ImageOps.exif_transpose(Image.open(up_img)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # An allowed extension doesn't guarantee Pillow can decode the file
        # (e.g. some DNG/TIFF variants); report it instead of a traceback.
        st.error(f"Could not read the uploaded image: {err}")
        st.stop()

    st.header("Set blur settings.")
    with st.form("Blur_form"):
        options = ["Gaussian", "Lens"]
        selection = st.radio(
            "Choose a blur type:",
            options,
            index=None,
            captions=[
                "Uses segmentation",
                "Use depth estimation",
            ],
            horizontal=True,
        )
        blur_str = st.slider("Blur Strength", 5, 50, 20)
        submitted = st.form_submit_button("Apply Blur")
    st.divider()

    disp_left, disp_right = st.columns(2)
    with disp_left:
        st.image(image, caption="Original Image", use_container_width=True)
    with disp_right:
        if submitted and selection in options:
            if selection == "Gaussian":
                with st.spinner("Spinning violently around the y-axis..."):
                    result = gaussian_blur(image, blur_str)
            else:
                with st.spinner("One mississippi, two mississippi..."):
                    result = lens_blur(image, blur_str)
            st.image(result, "Blurred Image", use_container_width=True)
        else:
            st.write("Waiting for you to select a blur type...")