Shashank486 committed
Commit 5fb95d7 · verified · 1 Parent(s): 782954b

create app.py


gradio
open_clip_torch
torch
datasets
torchvision
Pillow
numpy
transformers  # only needed if transformer models are used directly, in addition to open_clip
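If this list is intended as the Space's requirements.txt, a quick local check that every runtime dependency of app.py imports cleanly might look like the following (a minimal sketch, not part of this commit):

# Hypothetical smoke test: confirm each package app.py relies on is importable.
import gradio
import open_clip
import torch
import datasets
import torchvision
import PIL
import numpy

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("gradio", gradio.__version__, "| datasets", datasets.__version__)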

Files changed (1)
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
+ import torch
+ import open_clip
+ import gradio as gr
+ from datasets import load_dataset
+ from torchvision import transforms
+ from PIL import Image
+ import numpy as np
+
+ # Load the fashion product images dataset
+ dataset = load_dataset("ceyda/fashion-products-small", split="train")
+
+ # Load the CLIP model (ViT-B/32 QuickGELU variant, OpenAI weights)
+ model = open_clip.create_model("ViT-B-32-quickgelu", pretrained="openai")
+
+ # Evaluation-time image preprocessing matched to the model's input size
+ preprocess = open_clip.image_transform(model.visual.image_size, is_train=False)
+
+ # Load tokenizer
+ tokenizer = open_clip.get_tokenizer("ViT-B-32")
+
+ # Move model to GPU if available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+
+ # Function to compute image embeddings
+ def get_image_embedding(image):
+     image = preprocess(image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         image_features = model.encode_image(image)
+     return image_features / image_features.norm(dim=-1, keepdim=True)
+
+ # Function to compute text embeddings
+ def get_text_embedding(text):
+     text_inputs = tokenizer([text]).to(device)
+     with torch.no_grad():
+         text_features = model.encode_text(text_inputs)
+     return text_features / text_features.norm(dim=-1, keepdim=True)
+
+ # Precompute embeddings for a subset of the dataset
+ image_embeddings = []
+ image_paths = []
+ for item in dataset.select(range(1000)):  # Limit to 1,000 images for speed
+     image = item["image"]
+     image_embeddings.append(get_image_embedding(image))
+     image_paths.append(image)
+
+ # Stack all embeddings into a tensor
+ image_embeddings = torch.cat(image_embeddings, dim=0)
+
+ # Function to search for similar images based on text
+ def search_similar_image(query_text):
+     text_embedding = get_text_embedding(query_text)
+     similarities = (image_embeddings @ text_embedding.T).squeeze(1).cpu().numpy()
+
+     # Get top 20 matches
+     best_match_idxs = np.argsort(similarities)[-20:][::-1]
+
+     return [image_paths[i] for i in best_match_idxs]
+
+ # Function to search for similar images based on an uploaded image
+ def search_similar_by_image(uploaded_image):
+     query_embedding = get_image_embedding(uploaded_image)
+     similarities = (image_embeddings @ query_embedding.T).squeeze(1).cpu().numpy()
+
+     # Get top 20 matches
+     best_match_idxs = np.argsort(similarities)[-20:][::-1]
+
+     return [image_paths[i] for i in best_match_idxs]
+
+ # Gradio UI
+ with gr.Blocks() as demo:
+     gr.Markdown("## 🛍️ Visual Search for Fashion Products")
+     gr.Markdown("Search using **text** or **upload an image** to find similar items.")
+
+     with gr.Row():
+         query_input = gr.Textbox(label="Search by Text", placeholder="e.g., red sneakers")
+         search_button = gr.Button("Search by Text")
+
+     with gr.Row():
+         image_input = gr.Image(type="pil", label="Upload an Image")
+         image_search_button = gr.Button("Search by Image")
+
+     output_gallery = gr.Gallery(label="Similar Items", columns=4, height=500)
+
+     search_button.click(search_similar_image, inputs=[query_input], outputs=[output_gallery])
+     image_search_button.click(search_similar_by_image, inputs=[image_input], outputs=[output_gallery])
+
+ demo.launch(share=True)
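Because both get_image_embedding and get_text_embedding L2-normalize their outputs, the matrix product inside search_similar_image and search_similar_by_image is a cosine-similarity ranking over unit vectors. A standalone sketch of just that ranking step, using random tensors as stand-ins for the precomputed 512-dimensional ViT-B/32 embeddings (illustration only, not part of the commit):

import numpy as np
import torch

# Stand-ins for the precomputed gallery embeddings and one query embedding,
# already L2-normalized as in app.py.
gallery = torch.nn.functional.normalize(torch.randn(1000, 512), dim=-1)
query = torch.nn.functional.normalize(torch.randn(1, 512), dim=-1)

# Same logic as search_similar_image: a dot product of unit vectors is cosine similarity.
similarities = (gallery @ query.T).squeeze(1).numpy()
top20 = np.argsort(similarities)[-20:][::-1]  # indices of the 20 closest gallery items
print(top20[:5], similarities[top20[:5]])

For larger galleries, encoding images in batches (stacking several preprocessed images per encode_image call) rather than one at a time would speed up the precompute loop considerably.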