Spaces:
Sleeping
Sleeping
Annas Dev
commited on
Commit
·
868f784
1
Parent(s):
92c1964
add bit model
Browse files
src/similarity/model_implements/bit.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow_hub as hub
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
class BigTransfer:
|
| 5 |
+
|
| 6 |
+
def __init__(self):
|
| 7 |
+
self.module = hub.KerasLayer("https://tfhub.dev/google/bit/m-r50x1/1")
|
| 8 |
+
|
| 9 |
+
def extract_feature(self, imgs):
|
| 10 |
+
features = []
|
| 11 |
+
for img in imgs:
|
| 12 |
+
features.append(np.squeeze(self.module(img)))
|
| 13 |
+
return features
|
src/similarity/similarity.py
CHANGED
|
@@ -3,17 +3,18 @@ from src.util import image as image_util
|
|
| 3 |
from src.util import matrix
|
| 4 |
from .model_implements.mobilenet_v3 import ModelnetV3
|
| 5 |
from .model_implements.vit_base import VitBase
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
class Similarity:
|
| 9 |
def get_models(self):
|
| 10 |
return [
|
| 11 |
model.SimilarityModel(name= 'Mobilenet V3', image_size= 224, model_cls = ModelnetV3()),
|
|
|
|
| 12 |
model.SimilarityModel(name= 'Vision Transformer', image_size= 224, model_cls = VitBase(), image_input_type='pil'),
|
| 13 |
]
|
| 14 |
|
| 15 |
def check_similarity(self, img_urls, model):
|
| 16 |
-
# model = self.get_models()[model_idx]
|
| 17 |
imgs = []
|
| 18 |
for url in img_urls:
|
| 19 |
if url == "": continue
|
|
|
|
| 3 |
from src.util import matrix
|
| 4 |
from .model_implements.mobilenet_v3 import ModelnetV3
|
| 5 |
from .model_implements.vit_base import VitBase
|
| 6 |
+
from .model_implements.bit import BigTransfer
|
| 7 |
|
| 8 |
|
| 9 |
class Similarity:
|
| 10 |
def get_models(self):
|
| 11 |
return [
|
| 12 |
model.SimilarityModel(name= 'Mobilenet V3', image_size= 224, model_cls = ModelnetV3()),
|
| 13 |
+
model.SimilarityModel(name= 'Big Transfer (BiT)', image_size= 224, model_cls = BigTransfer()),
|
| 14 |
model.SimilarityModel(name= 'Vision Transformer', image_size= 224, model_cls = VitBase(), image_input_type='pil'),
|
| 15 |
]
|
| 16 |
|
| 17 |
def check_similarity(self, img_urls, model):
|
|
|
|
| 18 |
imgs = []
|
| 19 |
for url in img_urls:
|
| 20 |
if url == "": continue
|