MatthewTL committed on
Commit f9f18ae · 1 Parent(s): c433a92

Add ViTPose inference API

Files changed (2)
  1. inference.py +56 -0
  2. requirements.txt +3 -0
inference.py ADDED
@@ -0,0 +1,56 @@
+ from transformers import (
+     VitPoseForPoseEstimation,
+     AutoProcessor,
+     RTDetrForObjectDetection,
+ )
+ from PIL import Image
+ import torch
+
+ # load both models once at import time
+ det_proc = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
+ det_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365").eval()
+
+ pose_proc = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")
+ pose_model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple").eval()
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ det_model.to(device)
+ pose_model.to(device)
+
+ # Hugging Face will call this function
+ def predict(inputs: dict) -> dict:
+     """
+     inputs: {"image": PIL.Image}
+     returns: {"poses": list of keypoint results, one dict per detected person}
+     """
+     image = inputs["image"]
+
+     # detect people
+     det_inputs = det_proc(images=image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         det_outputs = det_model(**det_inputs)
+     results = det_proc.post_process_object_detection(
+         det_outputs,
+         threshold=0.5,
+         target_sizes=[(image.height, image.width)],
+     )
+     # keep only the "person" class (label 0 in this checkpoint)
+     person_boxes = results[0]["boxes"][results[0]["labels"] == 0]
+
+     # bail out early if nobody was detected
+     if len(person_boxes) == 0:
+         return {"poses": []}
+
+     # VitPose expects boxes in COCO (x, y, w, h) format on the CPU;
+     # the detector returns (x1, y1, x2, y2) corners
+     person_boxes = person_boxes.cpu().numpy()
+     person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
+     person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
+
+     # run pose estimation on the detected boxes
+     pose_inputs = pose_proc(image, boxes=[person_boxes], return_tensors="pt").to(device)
+     with torch.no_grad():
+         pose_outputs = pose_model(**pose_inputs)
+     poses = pose_proc.post_process_pose_estimation(pose_outputs, boxes=[person_boxes])
+
+     return {"poses": poses[0]}
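
A quick way to sanity-check the handler is to call predict directly; a minimal sketch, assuming a local test image saved as person.jpg (the filename is hypothetical):

    from PIL import Image
    from inference import predict

    image = Image.open("person.jpg").convert("RGB")  # hypothetical test image
    out = predict({"image": image})
    for pose in out["poses"]:  # one dict per detected person
        print(pose["keypoints"].shape, pose["scores"].shape)

Each entry in poses carries the keypoints and scores tensors produced by post_process_pose_estimation.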
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ transformers>=4.43.0
+ Pillow