|
|
|
import os.path as osp |
|
|
|
import cv2 |
|
import torch |
|
from mmengine.config import Config |
|
from mmengine.dataset import Compose |
|
from mmdet.apis import init_detector |
|
from mmdet.utils import get_test_pipeline_cfg |
|
|
|
|
|
def inference(model, image, texts, test_pipeline, score_thr=0.3, max_dets=100):
    """Detect objects described by ``texts`` in a single image file.

    Args:
        model: Detector object; its ``test_step(data_batch)`` is called and
            the first returned element must expose ``pred_instances``.
        image: Path of the image file, loaded with ``cv2.imread``.
        texts: Class prompts, indexed as ``texts[label][0]`` to obtain the
            text reported for each predicted label.
        test_pipeline: Callable applied to a dict with the keys ``img``,
            ``img_id`` and ``texts``; its result must provide ``'inputs'``
            and ``'data_samples'``.
        score_thr: Only detections whose score is strictly greater than
            this value are kept.
        max_dets: Upper bound on the number of detections returned; when
            more survive the threshold, the highest-scoring ones are kept.

    Returns:
        A tuple ``(boxes, labels, label_texts, scores)``. ``boxes``,
        ``labels`` and ``scores`` are the ``'bboxes'``, ``'labels'`` and
        ``'scores'`` fields of the predictions after ``.cpu().numpy()``;
        ``label_texts`` is a list holding the prompt text of each label.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot load ``image``.
    """
    image_path = image
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an opaque TypeError on the indexing below.
    if image is None:
        raise FileNotFoundError(
            f"could not read image (missing file or unsupported format): "
            f"{image_path!r}")

    # Reorder the channels from OpenCV's BGR layout to RGB.
    image = image[:, :, [2, 1, 0]]

    data_info = dict(img=image, img_id=0, texts=texts)
    data_info = test_pipeline(data_info)
    # Build a batch of size one: add a leading dimension to the inputs and
    # wrap the data sample in a list.
    data_batch = dict(inputs=data_info['inputs'].unsqueeze(0),
                      data_samples=[data_info['data_samples']])

    with torch.no_grad():
        output = model.test_step(data_batch)[0]
        pred_instances = output.pred_instances

    # Drop low-confidence detections.
    pred_instances = pred_instances[pred_instances.scores.float() > score_thr]

    # Cap the number of detections, keeping the highest-scoring ones.
    if len(pred_instances.scores) > max_dets:
        indices = pred_instances.scores.float().topk(max_dets)[1]
        pred_instances = pred_instances[indices]

    pred_instances = pred_instances.cpu().numpy()
    boxes = pred_instances['bboxes']
    labels = pred_instances['labels']
    scores = pred_instances['scores']
    # Each entry of ``texts`` is itself a sequence; its first item is used
    # as the display text of the class.
    label_texts = [texts[x][0] for x in labels]
    return boxes, labels, label_texts, scores
|
|
|
|
|
if __name__ == "__main__":
    # Demo entry point: build the detector, run it on one sample image and
    # print every detection on its own line.
    checkpoint = "weights/yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain_1280ft-14996a36.pth"
    config_file = "configs/pretrain/yolo_world_v2_x_vlpan_bn_2e-3_100e_4x8gpus_obj365v1_goldg_train_1280ft_lvis_minival.py"

    # Load the config and point it at the work directory and the weights
    # before the detector is created from it.
    cfg = Config.fromfile(config_file)
    cfg.work_dir = osp.join('./work_dirs')
    cfg.load_from = checkpoint
    model = init_detector(cfg, checkpoint=checkpoint, device='cuda:0')

    # ``inference`` hands the pipeline an already-decoded image array, so
    # the first pipeline step is switched to the ndarray loader.
    pipeline_cfg = get_test_pipeline_cfg(cfg=cfg)
    pipeline_cfg[0].type = 'mmdet.LoadImageFromNDArray'
    test_pipeline = Compose(pipeline_cfg)

    # Class prompts (one single-item list per class) and the input image.
    texts = [['person'], ['bus'], [' ']]
    image = "demo/sample_images/bus.jpg"

    print(f"starting to detect: {image}")
    boxes, labels, label_texts, scores = inference(
        model, image, texts, test_pipeline)

    # Format all detections first, then print them under the header.
    report_lines = []
    detections = zip(boxes, labels, label_texts, scores)
    for idx, (box, lbl, lbl_text, score) in enumerate(detections):
        report_lines.append(
            f"obj-{idx}: {box}, label-{lbl}, class-{lbl_text}, score-{score}")

    print("detecting results:")
    for line in report_lines:
        print(line)
|
|