remiai3 committed (verified)
Commit 93baf52 · Parent(s): 1e410d2

Upload 6 files

Files changed (6)
  1. README.md +22 -0
  2. inference.py +50 -0
  3. main.py +34 -0
  4. online_image.jpg +0 -0
  5. remiai.png +0 -0
  6. requirements.txt +10 -0
README.md ADDED
@@ -0,0 +1,22 @@
+ # Image Classification (CPU/GPU)
+
+ - **Model:** torchvision `mobilenet_v2` pretrained on ImageNet (Apache-2.0)
+ - **Task:** Predict top-5 ImageNet classes for a given image.
+ - **Note:** This repository only provides the resources to run the model on a laptop. We did not develop the model ourselves; it is an open-source model developed by TorchVision / PyTorch, which we use here for experimentation.
+
+ ## Quick start (any project)
+
+ ```bash
+ # 1) Create env
+ python -m venv venv && source venv/bin/activate  # Windows: .\venv\Scripts\activate
+
+ # 2) Install deps
+ pip install -r requirements.txt
+
+ # 3) Run
+ python main.py --help
+ ```
+
+ > Tip: If you have a GPU with CUDA, PyTorch will use it automatically. If not, everything runs on the CPU (slower, but it works).
+
+ ---
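The README's tip about automatic GPU use corresponds to the standard `torch.cuda.is_available()` check used in both scripts below. A quick way to see which device will be picked up before running anything (this snippet is an illustration, not part of the uploaded files):

```python
import torch

# Reports "cuda" when a CUDA-capable GPU is visible to PyTorch, otherwise "cpu".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
```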
inference.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+ from torchvision import models
+ from PIL import Image
+ import urllib.request
+ import os
+
+ # URL for ImageNet class labels
+ IMAGENET_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
+
+ def load_labels():
+     with urllib.request.urlopen(IMAGENET_URL) as f:
+         labels = [s.strip() for s in f.read().decode("utf-8").splitlines()]
+     return labels
+
+ # Device selection
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load model
+ model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT).to(device).eval()
+ preprocess = models.MobileNet_V2_Weights.DEFAULT.transforms()
+
+ # Online image
+ online_image_url = "https://upload.wikimedia.org/wikipedia/commons/9/9a/Pug_600.jpg"
+ online_image_path = "online_image.jpg"
+ urllib.request.urlretrieve(online_image_url, online_image_path)
+
+ # Offline image from same directory
+ offline_image_path = "remiai.png"  # Replace with your actual image filename
+
+ # Function to run inference
+ def classify_image(image_path):
+     img = Image.open(image_path).convert("RGB")
+     x = preprocess(img).unsqueeze(0).to(device)
+     with torch.no_grad():
+         logits = model(x)
+         probs = torch.softmax(logits, dim=-1)[0]
+         top5 = torch.topk(probs, 5)
+
+     labels = load_labels()
+     print(f"Results for: {image_path}")
+     for p, idx in zip(top5.values, top5.indices):
+         print(f"{labels[idx]}: {float(p):.4f}")
+     print()
+
+ # Run inference on both images
+ classify_image(online_image_path)
+ if os.path.exists(offline_image_path):
+     classify_image(offline_image_path)
+ else:
+     print(f"Offline image '{offline_image_path}' not found.")
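Note that `classify_image()` calls `load_labels()` on every invocation, so the label list is re-downloaded for each image. A minimal sketch of a cached variant that keeps a local copy (the `LABELS_CACHE` filename is an assumption, not part of the upload):

```python
import os
import urllib.request

IMAGENET_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
LABELS_CACHE = "imagenet_classes.txt"  # hypothetical local cache file

def load_labels_cached():
    # Download the label list once; reuse the local copy on later calls and runs.
    if not os.path.exists(LABELS_CACHE):
        urllib.request.urlretrieve(IMAGENET_URL, LABELS_CACHE)
    with open(LABELS_CACHE, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]
```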
main.py ADDED
@@ -0,0 +1,34 @@
+ import argparse, json, torch
+ from torchvision import models, transforms
+ from PIL import Image
+ import urllib.request
+
+ IMAGENET_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
+
+ def load_labels():
+     with urllib.request.urlopen(IMAGENET_URL) as f:
+         labels = [s.strip() for s in f.read().decode("utf-8").splitlines()]
+     return labels
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--image", type=str, default=None, help="Path to an image")
+     args = parser.parse_args()
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT).to(device).eval()
+
+     preprocess = models.MobileNet_V2_Weights.DEFAULT.transforms()
+     img = Image.open(args.image).convert("RGB") if args.image else Image.new("RGB", (224, 224), "white")
+     x = preprocess(img).unsqueeze(0).to(device)
+     with torch.no_grad():
+         logits = model(x)
+         probs = torch.softmax(logits, dim=-1)[0]
+         top5 = torch.topk(probs, 5)
+
+     labels = load_labels()
+     for p, idx in zip(top5.values, top5.indices):
+         print(f"{labels[idx]}: {float(p):.4f}")
+
+ if __name__ == "__main__":
+     main()
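Typical usage is `python main.py --image remiai.png`; with no `--image` argument the script classifies a blank white 224×224 image, which only serves as a smoke test. For reuse from other code, the same top-5 logic can be wrapped as a function that returns results instead of printing them. A minimal sketch under the same pretrained-MobileNetV2 assumptions (the helper name `top5_predictions` is hypothetical; it reads class names from `weights.meta["categories"]` instead of downloading the label file):

```python
import torch
from torchvision import models
from PIL import Image

# Same pretrained MobileNetV2 and preprocessing pipeline as in main.py (CPU-only here).
weights = models.MobileNet_V2_Weights.DEFAULT
model = models.mobilenet_v2(weights=weights).eval()
preprocess = weights.transforms()

def top5_predictions(image_path):
    # Class names ship with the torchvision weights metadata.
    labels = weights.meta["categories"]
    img = Image.open(image_path).convert("RGB")
    x = preprocess(img).unsqueeze(0)
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=-1)[0]
    top5 = torch.topk(probs, 5)
    return [(labels[i], float(p)) for p, i in zip(top5.values, top5.indices)]
```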
online_image.jpg ADDED
remiai.png ADDED
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch==2.1.0
+ torchvision==0.16.0
+ torchaudio==2.1.0
+ transformers==4.38.2
+ datasets==2.18.0
+ Pillow==10.2.0
+ numpy==1.26.4
+ tqdm==4.66.2
+ sentencepiece==0.1.99
+ sentence-transformers==2.6.1