Baskar2005 commited on
Commit
f64042e
·
verified ·
1 Parent(s): d7ec4de

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -0
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import logging
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from zipfile import ZipFile
7
+ from typing import Any, Dict,List
8
+ from transformers import pipeline
9
+
10
+ class Image_classification:
11
+ def __init__(self):
12
+ self.model=""
13
+
14
+
15
+ def unzip_image_data(self) -> str:
16
+ """
17
+ Unzips an image dataset into a specified directory.
18
+
19
+ Returns:
20
+ str: The path to the directory containing the extracted image files.
21
+ """
22
+ try:
23
+ with ZipFile("image_dataset.zip","r") as extract:
24
+
25
+ directory_path=str("dataset")
26
+ os.mkdir(directory_path)
27
+ extract.extractall(f"{directory_path}")
28
+ return f"{directory_path}"
29
+
30
+ except Exception as e:
31
+ logging.error(f"An error occurred during extraction: {e}")
32
+ return ""
33
+
34
+ def example_images(self) -> List[str]:
35
+ """
36
+ Unzips the image dataset and generates a list of paths to the individual image files and use image for showing example
37
+
38
+ Returns:
39
+ List[str]: A list of file paths to each image in the dataset.
40
+ """
41
+ try:
42
+ image_dataset_folder = self.unzip_image_data()
43
+ image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp']
44
+ image_count = len([name for name in os.listdir(image_dataset_folder) if os.path.isfile(os.path.join(image_dataset_folder, name)) and os.path.splitext(name)[1].lower() in image_extensions])
45
+ example=[]
46
+ for i in range(image_count):
47
+ for name in os.listdir(image_dataset_folder):
48
+ path=(os.path.join(os.path.dirname(image_dataset_folder),os.path.join(image_dataset_folder,name)))
49
+ example.append(path)
50
+ return example
51
+
52
+ except Exception as e:
53
+ logging.error(f"An error occurred in example images: {e}")
54
+ return ""
55
+
56
+ def classify(self, image: Image.Image, model: Any) -> Dict[str, float]:
57
+ """
58
+ Classifies an image using a specified model.
59
+
60
+ Args:
61
+ image (Image.Image): The image to classify.
62
+ model (Any): The model used for classification.
63
+
64
+ Returns:
65
+ Dict[str, float]: A dictionary of classification labels and their corresponding scores.
66
+ """
67
+ try:
68
+ self.model=model
69
+ classifier = pipeline("image-classification", model=self.model)
70
+ result= classifier(image)
71
+ return result
72
+ except Exception as e:
73
+ logging.error(f"An error occurred during image classification: {e}")
74
+ raise
75
+
76
+ def format_the_result(self, image: Image.Image, model: Any) -> Dict[str, float]:
77
+ """
78
+ Formats the classification result by retaining the highest score for each label.
79
+
80
+ Args:
81
+ image (Image.Image): The image to classify.
82
+ model (Any): The model used for classification.
83
+
84
+ Returns:
85
+ Dict[str, float]: A dictionary with unique labels and the highest score for each label.
86
+ """
87
+ try:
88
+ data=self.classify(image,model)
89
+ new_dict = {}
90
+ for item in data:
91
+ label = item['label']
92
+ score = item['score']
93
+
94
+ if label in new_dict:
95
+ if new_dict[label] < score:
96
+ new_dict[label] = score
97
+ else:
98
+ new_dict[label] = score
99
+ return new_dict
100
+ except Exception as e:
101
+ logging.error(f"An error occurred while formatting the results: {e}")
102
+ raise
103
+
104
+ def interface(self):
105
+
106
+ with gr.Blocks(css=""".gradio-container {background: #314755;
107
+ background: -webkit-linear-gradient(to right, #26a0da, #314755);
108
+ background: linear-gradient(to right, #26a0da, #314755);}
109
+ .block svelte-90oupt padded{background:314755;}""") as demo:
110
+
111
+ gr.HTML("""
112
+ <center><h1 style="color:#fff">Image Classification</h1></center>""")
113
+
114
+ exam_img=self.example_images()
115
+ with gr.Row():
116
+ model = gr.Dropdown(["facebook/regnet-x-040","google/vit-large-patch16-384","microsoft/resnet-50",""],label="Choose a model")
117
+ with gr.Row():
118
+ image = gr.Image(type="filepath",sources="upload")
119
+ with gr.Column():
120
+ output=gr.Label()
121
+ with gr.Row():
122
+ button=gr.Button()
123
+ button.click(self.format_the_result,[image,model],output)
124
+ gr.Examples(
125
+ examples=exam_img,
126
+ inputs=[image],
127
+ outputs=output,
128
+ fn=self.format_the_result,
129
+ cache_examples=False,
130
+ )
131
+ demo.launch(debug=True)
132
+
133
+ if __name__=="__main__":
134
+
135
+ image_classification=Image_classification()
136
+ result=image_classification.interface()