jordan0811 committed on
Commit
2685b0c
·
verified ·
1 Parent(s): 5b44310

Upload run_axmodel.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run_axmodel.py +71 -0
run_axmodel.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import axengine as ort
4
+ import torch
5
+ import os
6
+ from transformers import (
7
+ AutoImageProcessor,
8
+ AutoTokenizer,
9
+ )
10
+
11
+
12
def determine_max_value(
    image,
    *,
    patch_size: int = 16,
    buckets: tuple = (128, 256, 576, 784, 1024),
) -> int:
    """Pick the patch budget (``max_num_patches``) for an image.

    The image is divided into whole ``patch_size`` x ``patch_size`` tiles
    (partial tiles at the right/bottom edge are dropped by the floor
    division) and the tile count is rounded up to the smallest bucket that
    can hold it. Counts larger than every bucket are clamped to the largest
    bucket.

    Args:
        image: Any object exposing ``size`` as a ``(width, height)`` pair,
            e.g. a PIL image.
        patch_size: Side length of one square patch in pixels. Must be
            positive.
        buckets: Allowed patch budgets in ascending order. Must be non-empty.

    Returns:
        The selected bucket value.

    Raises:
        ValueError: If ``patch_size`` is not positive or ``buckets`` is empty.
    """
    if patch_size <= 0:
        raise ValueError("patch_size must be a positive integer")
    if not buckets:
        raise ValueError("buckets must contain at least one value")

    width, height = image.size
    patch_count = (width // patch_size) * (height // patch_size)

    # First bucket that is large enough wins; with the default buckets this
    # reproduces the original `> 784 / > 576 / > 256 / > 128` ladder exactly.
    for bucket in buckets:
        if patch_count <= bucket:
            return bucket
    return buckets[-1]
25
+
26
+
27
if __name__ == "__main__":
    # Script entry point: score one image against a handful of candidate
    # captions using the compiled image/text encoders and print the result.
    image_path = "bedroom.jpg"
    model_root = "/root/wangjian/hf_cache/fg-clip2-base"

    image_encoder_path = "image_encoder.axmodel"
    text_encoder_path = "text_encoder.axmodel"

    # Sessions are created through axengine (imported as `ort`).
    image_session = ort.InferenceSession(image_encoder_path)
    text_session = ort.InferenceSession(text_encoder_path)

    rgb_image = Image.open(image_path).convert("RGB")

    # Pre-processing objects are loaded from the local model directory.
    image_processor = AutoImageProcessor.from_pretrained(model_root)
    tokenizer = AutoTokenizer.from_pretrained(model_root)

    # The patch budget is derived from the image dimensions.
    patch_budget = determine_max_value(rgb_image)
    image_input = image_processor(
        images=rgb_image,
        max_num_patches=patch_budget,
        return_tensors="pt",
    )

    raw_captions = [
        "一个简约风格的卧室角落,黑色金属衣架上挂着多件米色和白色的衣物,下方架子放着两双浅色鞋子,旁边是一盆绿植,左侧可见一张铺有白色床单和灰色枕头的床。",
        "一个简约风格的卧室角落,黑色金属衣架上挂着多件红色和蓝色的衣物,下方架子放着两双黑色高跟鞋,旁边是一盆绿植,左侧可见一张铺有白色床单和灰色枕头的床。",
        "一个简约风格的卧室角落,黑色金属衣架上挂着多件米色和白色的衣物,下方架子放着两双运动鞋,旁边是一盆仙人掌,左侧可见一张铺有白色床单和灰色枕头的床。",
        "一个繁忙的街头市场,摊位上摆满水果,背景是高楼大厦,人们在喧闹中购物。",
    ]
    captions = list(map(str.lower, raw_captions))

    # Every caption is padded/truncated to exactly 196 tokens.
    caption_input = tokenizer(
        captions,
        padding="max_length",
        max_length=196,
        truncation=True,
        return_tensors="pt",
    )

    # Image branch: float32 pixels plus an int32 attention mask; only the
    # first output of the session is used.
    image_feeds = {
        "pixel_values": image_input["pixel_values"].numpy().astype(np.float32),
        "pixel_attention_mask": image_input["pixel_attention_mask"].numpy().astype(np.int32),
    }
    image_feature = image_session.run(None, image_feeds)[0]

    # Text branch: the encoder is invoked once per caption, each time with a
    # leading axis of length 1 (`ids[None]`).
    # NOTE(review): presumably the compiled text model has a fixed batch
    # size of 1 -- confirm before batching these calls.
    per_caption_features = [
        text_session.run(None, {"input_ids": ids[None].numpy().astype(np.int32)})[0]
        for ids in caption_input["input_ids"]
    ]
    text_feature = np.concatenate(per_caption_features, axis=0)

    similarity = image_feature @ text_feature.T
    # NOTE(review): scale and bias are hard-coded constants here; presumably
    # they were taken from the original checkpoint -- verify.
    logit_scale, logit_bias = 4.75, -16.75
    logits_per_image = similarity * np.exp(logit_scale) + logit_bias

    # The printed values are the softmax over captions, not the raw logits.
    caption_probs = torch.softmax(torch.from_numpy(logits_per_image), dim=-1)
    print("Logits per image:", caption_probs)
71
+