|
import unittest |
|
from unittest.mock import patch, MagicMock |
|
import numpy as np |
|
from pdf2zh.doclayout import ( |
|
OnnxModel, |
|
YoloResult, |
|
YoloBox, |
|
) |
|
|
|
|
|
class TestOnnxModel(unittest.TestCase):
    """Tests for OnnxModel that run without a real ONNX file or runtime."""

    @patch("onnx.load")
    @patch("onnxruntime.InferenceSession")
    def setUp(self, mock_inference_session, mock_onnx_load):
        """Build an OnnxModel with model loading and the runtime session mocked.

        The patches are active only while setUp runs. The model created here
        keeps whatever mock objects it obtained during construction, which is
        what the individual tests interact with afterwards.
        """
        # Fake ONNX model exposing only the metadata entries the tests rely on.
        # Values are strings, as given here; test_stride_property expects the
        # model to expose the stride as the integer 32.
        mock_model = MagicMock()
        mock_model.metadata_props = [
            MagicMock(key="stride", value="32"),
            MagicMock(key="names", value="['class1', 'class2']"),
        ]
        mock_onnx_load.return_value = mock_model

        self.model_path = "fake_model_path.onnx"
        self.model = OnnxModel(self.model_path)

    def test_stride_property(self):
        """The stride from the mocked metadata is exposed as the integer 32."""
        self.assertEqual(self.model.stride, 32)

    def test_resize_and_pad_image(self):
        """A 100x200 image resized/padded for size 1024 comes out as 512x1024."""
        image = np.ones((100, 200, 3), dtype=np.uint8)
        resized_image = self.model.resize_and_pad_image(image, 1024)

        self.assertEqual(resized_image.shape[0], 512)
        self.assertEqual(resized_image.shape[1], 1024)

        # Overall growth relative to the input (resize and padding combined);
        # the result must be strictly larger than the input in both dimensions.
        height_growth = resized_image.shape[0] - image.shape[0]
        width_growth = resized_image.shape[1] - image.shape[1]
        self.assertGreater(height_growth, 0)
        self.assertGreater(width_growth, 0)

    def test_scale_boxes(self):
        """Scaled boxes keep their shape and do not exceed the larger target side."""
        img1_shape = (1024, 1024)
        img0_shape = (500, 300)
        boxes = np.array([[512, 512, 768, 768]])

        scaled_boxes = self.model.scale_boxes(img1_shape, boxes, img0_shape)

        self.assertEqual(scaled_boxes.shape, boxes.shape)
        self.assertTrue(np.all(scaled_boxes <= max(img0_shape)))

    def test_predict(self):
        """predict() wraps the mocked session output in one YoloResult of YoloBoxes."""
        # Seeded generator: the fixture is identical on every run, so the test
        # cannot pass or fail depending on the random draw.
        rng = np.random.default_rng(0)
        mock_output = rng.random((1, 300, 6))
        # NOTE(review): self.model.model is presumably the inference session
        # created in OnnxModel.__init__ (mocked during setUp) - confirm.
        self.model.model.run.return_value = [mock_output]

        image = np.ones((500, 300, 3), dtype=np.uint8)

        results = self.model.predict(image)

        self.assertEqual(len(results), 1)
        self.assertIsInstance(results[0], YoloResult)
        self.assertGreater(len(results[0].boxes), 0)
        self.assertIsInstance(results[0].boxes[0], YoloBox)
|
|
|
|
|
class TestYoloResult(unittest.TestCase):
    """Checks on the YoloResult container built from raw box rows."""

    def test_yolo_result(self):
        """Two raw rows give two boxes, the first more confident, names kept."""
        class_names = ["class1", "class2"]
        raw_rows = [
            [100, 200, 300, 400, 0.9, 0],
            [50, 100, 150, 200, 0.8, 1],
        ]

        result = YoloResult(raw_rows, class_names)

        # Count first, so indexing below only happens on a two-element list.
        self.assertEqual(len(result.boxes), 2)

        leading_conf = result.boxes[0].conf
        trailing_conf = result.boxes[1].conf
        self.assertGreater(leading_conf, trailing_conf)

        self.assertEqual(result.names, class_names)
|
|
|
|
|
class TestYoloBox(unittest.TestCase):
    """Checks that YoloBox exposes the fields of a raw six-value row."""

    def test_yolo_box(self):
        """xyxy is the first four values, conf the fifth, cls the sixth."""
        raw_row = [100, 200, 300, 400, 0.9, 0]
        # Split the fixture into the three parts the box is expected to expose.
        *expected_xyxy, expected_conf, expected_cls = raw_row

        box = YoloBox(raw_row)

        self.assertEqual(box.xyxy, expected_xyxy)
        self.assertEqual(box.conf, expected_conf)
        self.assertEqual(box.cls, expected_cls)
|
|
|
|
|
# Run the suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|