import hashlib
import json
import logging
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F

from .net_s3fd import s3fd

logger = logging.getLogger(__name__)


def decode(loc, priors, variances):
    """Turn regression offsets into corner-form boxes.

    Args:
        loc: tensor with four columns of predicted offsets
            (centre x/y offset, then log-scale width/height).
        priors: tensor with four columns per row: centre x, centre y,
            width, height of the reference box.
        variances: two-element sequence; the first scales the centre
            offsets, the second scales the size offsets.

    Returns:
        Tensor with one row per prior holding (x1, y1, x2, y2).
    """
    boxes = torch.cat((
        priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
        priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    # Convert (cx, cy, w, h) to (x1, y1, x2, y2) in place.
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes


def nms(dets, thresh):
    """Greedy non-maximum suppression.

    Args:
        dets: array whose first five columns are x1, y1, x2, y2, score.
        thresh: boxes whose IoU with an already kept box is greater than
            this value are discarded.

    Returns:
        List of row indices into ``dets`` that survive, ordered by
        decreasing score. Empty list when ``dets`` is empty.
    """
    if len(dets) == 0:
        return []

    x1, y1, x2, y2, scores = (dets[:, 0], dets[:, 1], dets[:, 2],
                              dets[:, 3], dets[:, 4])
    # Areas and overlaps use the inclusive-pixel (+1) convention.
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]

        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[rest] - inter)

        # Indices are relative to ``rest``, hence the +1 shift.
        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]

    return keep


def _as_three_channel_batch(images):
    """Return ``images`` as a 4-D array (N, H, W, 3).

    Accepts a single 2-D grayscale image (H, W), a single image
    (H, W, C) or a batch (N, H, W, C). A fourth channel is dropped and a
    single channel is repeated three times. Channel handling happens
    after the rank has been normalised so that every input form reaches
    the same code path.
    """
    images = np.asarray(images)
    if images.ndim == 2:
        images = images[:, :, np.newaxis]
    if images.ndim == 3:
        images = images[np.newaxis, ...]

    channels = images.shape[-1]
    if channels == 4:
        images = images[..., :3]
    elif channels == 1:
        images = np.repeat(images, 3, axis=-1)
    return images


def _fallback_box(height, width):
    """Return a single centred box (with score 0.99) as a (1, 5) float array.

    The box is inset from every border by one eighth of the shorter side.
    """
    margin = min(width, height) // 8
    return np.array([[margin, margin, width - margin, height - margin, 0.99]])


class SFDDetector:
    """Face detector built on the ``s3fd`` network, with a small JSON cache."""

    # ``detect_from_image`` caches results only for paths with these suffixes.
    _CACHED_SUFFIXES = ('male.png', 'female.png')

    def __init__(self, model_path=None, device='cuda'):
        """Load the network weights and put the model in eval mode.

        Args:
            model_path: path of the state dict; it is passed straight to
                ``torch.load``, so ``None`` is not expected to work.
            device: requested device. Whenever CUDA is unavailable the
                detector uses the CPU, whatever was requested.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.net = s3fd()
        state_dict = torch.load(model_path, map_location=self.device)
        self.net.load_state_dict(state_dict)
        self.net.to(self.device)
        self.net.eval()

    def detect_from_batch(self, images):
        """Detect boxes in one image or in a batch of images.

        Args:
            images: numpy array shaped (H, W), (H, W, C) or (N, H, W, C)
                with C in {1, 3, 4}.

        Returns:
            List with one float array per image. Each array has rows
            (x1, y1, x2, y2, score) with rounded coordinates. When every
            image side is at most 256 pixels the network is skipped and a
            centred fallback box is returned for each image; the same
            fallback is used for an image without any detection.
        """
        images = _as_three_channel_batch(images).astype(np.float32)
        images = images - np.array([104, 117, 123])
        images = images.transpose(0, 3, 1, 2)
        images = torch.from_numpy(images).float().to(self.device)

        batch_size = images.shape[0]
        height, width = images.shape[2:4]

        if height <= 256 and width <= 256:
            # One entry per image, kept as float so the score survives.
            return [_fallback_box(height, width) for _ in range(batch_size)]

        with torch.no_grad():
            olist = list(self.net(images))

        # Outputs alternate between class logits and box regressions.
        num_levels = len(olist) // 2
        for level in range(num_levels):
            olist[level * 2] = F.softmax(olist[level * 2], dim=1)
        olist = [oelem.data.cpu() for oelem in olist]

        variances = [0.1, 0.2]
        bboxlists = []
        for batch_idx in range(batch_size):
            bboxlist = []
            for level in range(num_levels):
                ocls, oreg = olist[level * 2], olist[level * 2 + 1]
                stride = 2 ** (level + 2)
                scores = ocls[batch_idx, 1, :, :]
                pos_inds = np.where(scores.numpy() > 0.05)
                for hindex, windex in zip(*pos_inds):
                    axc = stride / 2 + windex * stride
                    ayc = stride / 2 + hindex * stride
                    score = scores[hindex, windex]
                    loc = oreg[batch_idx, :, hindex, windex].contiguous().view(1, 4)
                    priors = torch.Tensor([[axc / 1.0, ayc / 1.0,
                                            stride * 4 / 1.0, stride * 4 / 1.0]])
                    box = decode(loc, priors, variances)[0].numpy() * 1.0
                    bboxlist.append([box[0], box[1], box[2], box[3], score.item()])

            bboxlist = np.array(bboxlist)
            if len(bboxlist) == 0:
                bboxlist = _fallback_box(height, width)

            keep = nms(bboxlist, 0.3)
            bboxlist = bboxlist[keep]
            bboxlist[:, :4] = np.round(bboxlist[:, :4]).astype(np.int32)
            bboxlists.append(bboxlist)

        return bboxlists

    def _get_image_hash(self, image_path):
        """Return the MD5 hex digest of the file, or None if it cannot be read."""
        try:
            with open(image_path, 'rb') as f:
                return hashlib.md5(f.read()).hexdigest()
        except Exception as e:
            logger.error(f"Error al calcular hash de imagen: {str(e)}")
            return None

    def _get_cache_path(self, image_path):
        """Return the JSON cache file path for the image, or None.

        The ``cache`` directory next to this module is created if missing.
        None is returned when the image hash cannot be computed.
        """
        image_hash = self._get_image_hash(image_path)
        if image_hash:
            cache_dir = os.path.join(os.path.dirname(__file__), 'cache')
            os.makedirs(cache_dir, exist_ok=True)
            return os.path.join(cache_dir, f"{image_hash}.json")
        return None

    def _save_to_cache(self, image_path, bboxes):
        """Write the boxes, image path and image mtime to the cache file.

        Failures are logged and otherwise ignored.
        """
        try:
            cache_path = self._get_cache_path(image_path)
            if cache_path:
                with open(cache_path, 'w') as f:
                    json.dump({
                        'bboxes': bboxes.tolist(),
                        'image_path': image_path,
                        'timestamp': os.path.getmtime(image_path)
                    }, f)
                logger.info(f"Resultados guardados en caché: {cache_path}")
        except Exception as e:
            logger.error(f"Error al guardar en caché: {str(e)}")

    def _load_from_cache(self, image_path):
        """Return cached boxes as an int32 array, or None.

        The cache entry is used only when both the stored path and the
        stored modification time match the image. Failures are logged and
        reported as None.
        """
        try:
            cache_path = self._get_cache_path(image_path)
            if cache_path and os.path.exists(cache_path):
                with open(cache_path, 'r') as f:
                    data = json.load(f)
                if data['image_path'] == image_path and \
                        data['timestamp'] == os.path.getmtime(image_path):
                    logger.info(f"Usando resultados de caché: {cache_path}")
                    return np.array(data['bboxes'], dtype=np.int32)
        except Exception as e:
            logger.error(f"Error al cargar de caché: {str(e)}")
        return None

    def detect_from_image(self, image_path):
        """Detect boxes in the image stored at ``image_path``.

        Args:
            image_path: path of the image file.

        Returns:
            int32 array with rows (x1, y1, x2, y2, score). On any error a
            fallback box is returned instead of raising: a box inset by
            10% when the image was loaded, otherwise (0, 0, 100, 100).
        """
        is_default_image = image_path.endswith(self._CACHED_SUFFIXES)

        # Check whether this is one of the default images.
        if is_default_image:
            cached_result = self._load_from_cache(image_path)
            if cached_result is not None:
                return cached_result

        # Bound before the try so the handler can test it even when
        # reading the image itself raises.
        image = None

        # No cache entry: run the normal detection.
        try:
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"No se pudo cargar la imagen: {image_path}")

            result = self.detect_from_batch(image)[0]

            # Make sure the results are integers.
            # NOTE(review): this cast also truncates the score column
            # toward zero; kept because callers may rely on the dtype.
            result = result.astype(np.int32)

            # Store in the cache if this is one of the default images.
            if is_default_image:
                self._save_to_cache(image_path, result)

            return result
        except Exception as e:
            logger.error(f"Error en detect_from_image: {str(e)}")
            if image is not None:
                height, width = image.shape[:2]
                margin_x = int(width * 0.1)
                margin_y = int(height * 0.1)
                return np.array([[margin_x, margin_y, width-margin_x, height-margin_y, 0.99]], dtype=np.int32)
            return np.array([[0, 0, 100, 100, 0.99]], dtype=np.int32)