import numpy as np
from PIL import Image
import axengine as ort  # ONNX Runtime-style API for running Axera .axmodel files
import torch
from transformers import (
    AutoImageProcessor,
    AutoTokenizer,
)


def determine_max_value(image):
    """Bucket the image's 16x16-patch count into the nearest supported
    max_num_patches value for the image processor."""
    w, h = image.size
    max_val = (w // 16) * (h // 16)
    if max_val > 784:
        return 1024
    elif max_val > 576:
        return 784
    elif max_val > 256:
        return 576
    elif max_val > 128:
        return 256
    else:
        return 128
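
# Worked example of the bucketing above: a 448x448 image yields
# (448 // 16) * (448 // 16) = 784 patches, which is not > 784 and so maps to
# the 784 bucket; a 512x512 image (32 * 32 = 1024 patches) maps to 1024.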

    
if __name__ == "__main__":
    image_path = "bedroom.jpg"
    model_root = "/root/wangjian/hf_cache/fg-clip2-base"

    image_encoder_path = "image_encoder.axmodel"
    text_encoder_path = "text_encoder.axmodel"

    # axengine sessions load the compiled .axmodel graphs for on-device
    # inference; the API mirrors onnxruntime's InferenceSession.
    image_encoder = ort.InferenceSession(image_encoder_path)
    text_encoder = ort.InferenceSession(text_encoder_path)
    
    image = Image.open(image_path).convert("RGB")
    
    image_processor = AutoImageProcessor.from_pretrained(model_root)
    tokenizer = AutoTokenizer.from_pretrained(model_root)

    # Pick the max_num_patches bucket that matches this image's resolution.
    image_input = image_processor(
        images=image,
        max_num_patches=determine_max_value(image),
        return_tensors="pt",
    )
    # FG-CLIP2 is bilingual, so the captions are kept in Chinese as model
    # input. English glosses:
    #   1. A minimalist bedroom corner: a black metal clothes rack holding
    #      several beige and white garments, two pairs of light-colored shoes
    #      on the shelf below, a potted green plant beside it, and a bed with
    #      white sheets and gray pillows visible on the left.
    #   2. The same scene, but with red and blue garments and two pairs of
    #      black high heels.
    #   3. The same scene, but with two pairs of sneakers and a potted cactus.
    #   4. A busy street market with stalls full of fruit, high-rise buildings
    #      in the background, and people shopping amid the bustle.
    captions = [
        "一个简约风格的卧室角落,黑色金属衣架上挂着多件米色和白色的衣物,下方架子放着两双浅色鞋子,旁边是一盆绿植,左侧可见一张铺有白色床单和灰色枕头的床。",
        "一个简约风格的卧室角落,黑色金属衣架上挂着多件红色和蓝色的衣物,下方架子放着两双黑色高跟鞋,旁边是一盆绿植,左侧可见一张铺有白色床单和灰色枕头的床。",
        "一个简约风格的卧室角落,黑色金属衣架上挂着多件米色和白色的衣物,下方架子放着两双运动鞋,旁边是一盆仙人掌,左侧可见一张铺有白色床单和灰色枕头的床。",
        "一个繁忙的街头市场,摊位上摆满水果,背景是高楼大厦,人们在喧闹中购物。"
    ]
    # Lowercasing is a no-op for Chinese text but normalizes any Latin text;
    # it is kept here from the original pipeline.
    captions = [caption.lower() for caption in captions]

    # Pad every caption to a fixed length so the compiled text encoder always
    # sees a static [1, 196] input shape.
    caption_input = tokenizer(
        captions,
        padding="max_length",
        max_length=196,
        truncation=True,
        return_tensors="pt",
    )

    # Encode the image: preprocessed pixel patches plus a patch-level
    # attention mask (needed because max_num_patches varies per image).
    image_feature = image_encoder.run(None, {
        "pixel_values": image_input["pixel_values"].numpy().astype(np.float32),
        "pixel_attention_mask": image_input["pixel_attention_mask"].numpy().astype(np.int32),
    })[0]

    # Encode each caption separately: the compiled text encoder most likely
    # has a fixed batch size of 1, so features are concatenated afterwards.
    text_feature = []
    for c in caption_input["input_ids"]:
        tmp_text_feature = text_encoder.run(None, {
            "input_ids": c[None].numpy().astype(np.int32),
        })[0]
        text_feature.append(tmp_text_feature)
    text_feature = np.concatenate(text_feature, axis=0)
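
    # Assumption (not verified here): the exported encoders emit L2-normalized
    # embeddings, which the SigLIP-style scoring below requires. If they do
    # not, normalize explicitly:
    # image_feature /= np.linalg.norm(image_feature, axis=-1, keepdims=True)
    # text_feature /= np.linalg.norm(text_feature, axis=-1, keepdims=True)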
    
    # SigLIP-style scoring: cosine similarity scaled by exp(logit_scale) and
    # shifted by logit_bias (hard-coded here from the model's learned values).
    logits_per_image = image_feature @ text_feature.T
    logit_scale, logit_bias = 4.75, -16.75
    logits_per_image = logits_per_image * np.exp(logit_scale) + logit_bias

    print("Logits per image:", torch.from_numpy(logits_per_image).softmax(dim=-1))
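
    # For reference, the same softmax in plain NumPy (a minimal sketch that
    # removes the torch dependency; stabilized by subtracting the row max):
    def np_softmax(x, axis=-1):
        x = x - x.max(axis=axis, keepdims=True)
        e = np.exp(x)
        return e / e.sum(axis=axis, keepdims=True)

    # Should match the torch result above up to float32 rounding:
    # print("Logits per image:", np_softmax(logits_per_image))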