# Source code for medusa.detect.yunet
from collections import defaultdict
from pathlib import Path
import numpy as np
import torch
from ..defaults import DEVICE
from ..io import load_inputs
from .base import BaseDetector
class YunetDetector(BaseDetector):
    """This detector is based on Yunet, a face detector based on YOLOv3
    :cite:p:`facedetect-yu`.

    Parameters
    ----------
    det_threshold : float
        Confidence threshold below which detections are discarded
    nms_threshold : float
        Non-maximum-suppression overlap threshold used to merge duplicate boxes
    device : str
        Device to which the output tensors are moved (e.g. 'cpu' or 'cuda')
    **kwargs
        Extra keyword arguments passed to ``cv2.FaceDetectorYN.create``
    """

    def __init__(self, det_threshold=0.5, nms_threshold=0.3, device=DEVICE, **kwargs):
        super().__init__()
        self.det_threshold = det_threshold
        self.nms_threshold = nms_threshold
        self.device = device
        self._model = self._init_model(**kwargs)
        self.to(device).eval()

    def __str__(self):
        return "yunet"

    def _init_model(self, **kwargs):
        """Initializes the underlying ``cv2.FaceDetectorYN`` model from the
        ONNX weights bundled with the package.

        Raises
        ------
        ImportError
            If ``cv2`` (opencv) is not installed
        """
        try:
            import cv2
        except ImportError:
            raise ImportError("cv2 is required for YunetDetector")

        f_in = Path(__file__).parents[1] / "data/models/yunet.onnx"
        # Input size (0, 0) is a placeholder; the actual size is set per batch
        # in ``forward`` via ``setInputSize``
        model = cv2.FaceDetectorYN.create(
            str(f_in), "", (0, 0), self.det_threshold, self.nms_threshold, **kwargs
        )
        return model

    def forward(self, imgs):
        """Runs face detection on a (batch of) image(s).

        Parameters
        ----------
        imgs : str, Path, iterable, array_like
            Anything accepted by ``load_inputs``; loaded as a numpy array of
            shape (batch, height, width, 3) in RGB, channels-last order

        Returns
        -------
        outputs : dict
            Dictionary with keys 'img_idx', 'conf', 'bbox', 'lms' (torch
            tensors, present only when at least one face was detected) and
            'n_img' (int, the batch size)
        """
        imgs = load_inputs(imgs, load_as="numpy", channels_first=False)
        # cv2 expects BGR (not RGB), so reverse the channel axis.
        # BUGFIX: the previous indexing [2, 0, 1] produced (B, R, G), which is
        # neither RGB nor BGR; [2, 1, 0] is the correct RGB -> BGR swap
        imgs = imgs[:, :, :, [2, 1, 0]]
        b, h, w, c = imgs.shape
        self._model.setInputSize((w, h))

        outputs = defaultdict(list)
        # Note to self: cv2 does not support batch prediction
        for i in range(b):
            _, det = self._model.detect(imgs[i, ...])
            if det is not None:
                # ``det`` is (n_faces, 15): bbox (4), 5 landmarks (10), conf (1)
                outputs["img_idx"].extend([i] * det.shape[0])
                outputs["conf"].append(det[:, [-1]].flatten())
                bbox_ = det[:, :4]
                # Convert (x, y, w, h) offsets to true vertex positions
                # (x0, y0, x1, y1) to keep consistent with the
                # scrfd/torchvision bbox definition
                bbox_[:, 2:] = bbox_[:, :2] + bbox_[:, 2:]
                outputs["bbox"].append(bbox_)
                outputs["lms"].append(det[:, 4:-1].reshape((det.shape[0], 5, 2)))

        if outputs.get("conf", None) is not None:
            # At least one face was detected somewhere in the batch; stack the
            # per-image results and move them to the configured device
            outputs["img_idx"] = np.array(outputs["img_idx"])
            outputs["conf"] = np.concatenate(outputs["conf"])
            outputs["bbox"] = np.vstack(outputs["bbox"])
            outputs["lms"] = np.vstack(outputs["lms"])

        for attr, data in outputs.items():
            outputs[attr] = torch.as_tensor(data, device=self.device)

        # n_img is set after the tensor conversion so it stays a plain int
        outputs["n_img"] = b
        return outputs