from functools import lru_cache
from typing import Dict, List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation

# MediaPipe is optional: when it is missing, face detection is skipped and the
# pipeline falls back to a central crop (see detect_skin_tone).
try:
    import mediapipe as mp
    HAS_MEDIAPIPE = True
except ImportError:
    mp = None
    HAS_MEDIAPIPE = False

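# Processing pipeline, at a glance:
#   1. _ensure_rgb_uint8       - normalize the upload to an RGB uint8 array.
#   2. face localization       - MediaPipe face detection, else a central crop.
#   3. _segment_face_labels    - semantic face parsing on the crop; keep the "skin" classes.
#   4. _compute_skin_color_hex - per-channel median of the masked pixels, formatted as hex.
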
def _ensure_rgb_uint8(image: np.ndarray) -> np.ndarray:
    """Convert an input image array to RGB uint8 format.

    Gradio provides images as numpy arrays in RGB order with dtype uint8 by default,
    but we defensively normalize here in case inputs vary.
    """
    if image is None:
        raise ValueError("No image provided")

    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))
    elif image.dtype != np.uint8:
        # Defensive cast: clip to the displayable range to avoid wraparound.
        image = np.clip(image, 0, 255).astype(np.uint8)

    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image

def _central_crop_bbox(width: int, height: int, frac: float = 0.6) -> Tuple[int, int, int, int]:
    """Return a central crop bounding box (x1, y1, x2, y2) covering `frac` of width/height."""
    frac = float(np.clip(frac, 0.2, 1.0))
    crop_w = int(width * frac)
    crop_h = int(height * frac)
    x1 = (width - crop_w) // 2
    y1 = (height - crop_h) // 2
    x2 = x1 + crop_w
    y2 = y1 + crop_h
    return x1, y1, x2, y2

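# For example, a 1000x800 image with the default frac=0.6 keeps the middle 60%
# of each dimension: a 600x480 crop with bbox (200, 160, 800, 640).
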
def _detect_face_bbox_mediapipe(image_rgb: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
    """Detect a face bounding box using MediaPipe Face Detection and return (x1, y1, x2, y2).

    Returns None if detection fails or mediapipe is unavailable.
    """
    if not HAS_MEDIAPIPE:
        return None
    height, width = image_rgb.shape[:2]
    try:
        with mp.solutions.face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5) as detector:
            results = detector.process(image_rgb)
            detections = results.detections or []
            if not detections:
                return None

            def bbox_area(det):
                bbox = det.location_data.relative_bounding_box
                return max(0.0, bbox.width) * max(0.0, bbox.height)

            best = max(detections, key=bbox_area)
            rb = best.location_data.relative_bounding_box
            x1 = int(np.clip(rb.xmin * width, 0, width - 1))
            y1 = int(np.clip(rb.ymin * height, 0, height - 1))
            x2 = int(np.clip((rb.xmin + rb.width) * width, 0, width))
            y2 = int(np.clip((rb.ymin + rb.height) * height, 0, height))

            # Pad the detection so the forehead, cheeks and jawline fall inside the crop.
            pad_x = int(0.08 * width)
            pad_y = int(0.12 * height)
            x1 = int(np.clip(x1 - pad_x, 0, width - 1))
            y1 = int(np.clip(y1 - pad_y, 0, height - 1))
            x2 = int(np.clip(x2 + pad_x, 0, width))
            y2 = int(np.clip(y2 + pad_y, 0, height))

            if x2 - x1 < 10 or y2 - y1 < 10:
                return None
            return x1, y1, x2, y2
    except Exception:
        return None

def _binary_open_close(mask: np.ndarray, kernel_size: int = 5, iterations: int = 1) -> np.ndarray:
    """Apply morphological open then close to clean the binary mask."""
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=iterations)
    closed = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, kernel, iterations=iterations)
    return closed

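# Opening removes isolated foreground specks smaller than the structuring element
# (by default a 5x5 ellipse) and closing fills similarly sized holes, so stray
# misclassified pixels do not skew the color estimate.
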
@lru_cache(maxsize=1)
def _load_face_parsing_model():
    """Load face-parsing model and processor from the Hugging Face Hub (cached)."""
    model_id = "jonathandinu/face-parsing"
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModelForSemanticSegmentation.from_pretrained(model_id)
    model.eval()
    id2label: Dict[int, str] = model.config.id2label
    label2id: Dict[str, int] = model.config.label2id
    return processor, model, id2label, label2id

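# Note: jonathandinu/face-parsing is a SegFormer checkpoint fine-tuned for face
# parsing; its id2label mapping includes region names such as "skin", "nose" and
# "hair" (exact names may vary between checkpoint revisions, which is why
# _skin_indices_from_id2label matches on substrings rather than fixed ids).
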
def _segment_face_labels(image_rgb: np.ndarray) -> Tuple[np.ndarray, Dict[int, str]]:
    """Run face-parsing segmentation on an RGB crop. Returns (labels HxW int, id2label)."""
    processor, model, id2label, _ = _load_face_parsing_model()
    pil_img = Image.fromarray(image_rgb)
    inputs = processor(images=pil_img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # The segmentation logits come out at a reduced resolution, so upsample them
    # back to the crop size. PIL's .size is (W, H); interpolate expects (H, W).
    upsampled = torch.nn.functional.interpolate(
        logits,
        size=pil_img.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    labels = upsampled.argmax(dim=1)[0].cpu().numpy().astype(np.int32)
    return labels, id2label

def _skin_indices_from_id2label(id2label: Dict[int, str]) -> List[int]:
    """Return label indices whose names contain "skin", falling back to "face"."""
    skin_indices: List[int] = []
    for idx, name in id2label.items():
        name_l = name.lower()
        if "skin" in name_l:
            skin_indices.append(int(idx))

    if not skin_indices:
        for idx, name in id2label.items():
            if "face" in name.lower():
                skin_indices.append(int(idx))
    return skin_indices

def _compute_skin_color_hex(image_rgb: np.ndarray, mask: np.ndarray) -> Tuple[str, np.ndarray]:
    """Compute a robust representative skin color, returned as a hex string plus the raw RGB triplet.

    Uses the median across masked pixels to reduce the influence of highlights/shadows.
    """
    if mask is None or mask.size == 0:
        raise ValueError("Invalid mask for skin color computation")

    mask_bool = mask.astype(bool)
    if not np.any(mask_bool):
        raise ValueError("No skin pixels detected")

    skin_pixels = image_rgb[mask_bool]

    median_color = np.median(skin_pixels, axis=0)
    median_color = np.clip(median_color, 0, 255).astype(np.uint8)

    r, g, b = int(median_color[0]), int(median_color[1]), int(median_color[2])
    hex_code = f"#{r:02X}{g:02X}{b:02X}"
    return hex_code, median_color

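# For example, a per-channel median of (184, 143, 122) formats as "#B88F7A".
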
def _solid_color_image(color_rgb: np.ndarray, size: Tuple[int, int] = (160, 160)) -> np.ndarray:
    """Return a solid (width, height) swatch filled with the given RGB color."""
    swatch = np.zeros((size[1], size[0], 3), dtype=np.uint8)
    swatch[:, :] = color_rgb
    return swatch

def detect_skin_tone(image: np.ndarray) -> Tuple[str, np.ndarray, np.ndarray]:
    """Main pipeline: returns (hex_code, color_swatch_image, debug_mask_overlay).

    - image: input image as a numpy array (H, W, 3), RGB uint8.
    - If no face is detected (or MediaPipe is unavailable), a central crop is used
      instead to avoid sampling background or hands.
    """
    rgb = _ensure_rgb_uint8(image)
    height, width = rgb.shape[:2]

    # Localize the face; fall back to the central region if detection fails.
    face_bbox = _detect_face_bbox_mediapipe(rgb)
    if face_bbox is None:
        face_bbox = _central_crop_bbox(width, height)
    x1, y1, x2, y2 = face_bbox
    central_rgb = rgb[y1:y2, x1:x2]

    # Parse the crop into facial regions and keep only the skin classes.
    labels, id2label = _segment_face_labels(central_rgb)
    skin_indices = _skin_indices_from_id2label(id2label)
    if not skin_indices:
        raise ValueError("Face parsing model did not expose a skin class.")

    skin_mask = np.isin(labels, np.array(skin_indices, dtype=np.int32)).astype(np.uint8) * 255

    # Clean speckles/holes in the mask; keep the original if cleanup empties it.
    cleaned_mask = _binary_open_close(skin_mask)
    if np.any(cleaned_mask):
        skin_mask = cleaned_mask

    hex_code, color_rgb = _compute_skin_color_hex(central_rgb, skin_mask)

    swatch = _solid_color_image(color_rgb)

    # Build a debug overlay: lighten the pixels that were counted as skin.
    full_mask = np.zeros((height, width), dtype=np.uint8)
    full_mask[y1:y2, x1:x2] = skin_mask
    color_mask = cv2.cvtColor(full_mask, cv2.COLOR_GRAY2RGB)
    overlay = cv2.addWeighted(rgb, 0.8, color_mask, 0.2, 0)

    return hex_code, swatch, overlay

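# Programmatic usage (a minimal sketch; "portrait.jpg" is a placeholder path):
#
#     img = np.array(Image.open("portrait.jpg").convert("RGB"))
#     hex_code, swatch, overlay = detect_skin_tone(img)
#     print(hex_code)  # e.g. "#B88F7A"
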
def _hex_html(hex_code: str) -> str:
    """Render the hex code next to a small inline color swatch."""
    style = (
        "display:flex;align-items:center;gap:12px;padding:8px 0;"
    )
    swatch_style = (
        f"width:20px;height:20px;border-radius:4px;background:{hex_code};"
        "border:1px solid #ccc;"
    )
    return (
        f"<div style='{style}'>"
        f"<div style='{swatch_style}'></div>"
        f"<span style='font-family:monospace;font-size:16px'>{hex_code}</span>"
        "</div>"
    )

with gr.Blocks(title="Skin Tone Detector") as demo:
    gr.Markdown(
        """
        ### Skin Tone Hex Detector
        Upload a face image. The app estimates a representative skin tone and returns a HEX color.
        """
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload face image",
                type="numpy",
                image_mode="RGB",
                height=360,
            )
            run_btn = gr.Button("Detect Skin Tone", variant="primary")

        with gr.Column():
            hex_output = gr.HTML(label="HEX Color")
            swatch_output = gr.Image(label="Color Swatch", type="numpy")
            debug_output = gr.Image(label="Mask Overlay", type="numpy")
            gr.Markdown("MediaPipe face detection and a face-parsing model are used to isolate skin pixels.")

    def _run(image: Optional[np.ndarray]):
        if image is None:
            return _hex_html("#000000"), np.zeros((160, 160, 3), dtype=np.uint8), None
        try:
            hex_code, swatch, debug = detect_skin_tone(image)
        except ValueError as exc:
            # Surface pipeline failures (e.g. no skin pixels found) as a UI error.
            raise gr.Error(str(exc))
        return _hex_html(hex_code), swatch, debug

    run_btn.click(_run, inputs=[input_image], outputs=[hex_output, swatch_output, debug_output])
    input_image.change(_run, inputs=[input_image], outputs=[hex_output, swatch_output, debug_output])

if __name__ == "__main__":
    demo.launch()