abhilash88 commited on
Commit
34b38de
·
verified ·
1 Parent(s): 88e0144

Delete model.py

Browse files
Files changed (1) hide show
  1. model.py +0 -242
model.py DELETED
@@ -1,242 +0,0 @@
1
- """
2
- Complete Working Pipeline Implementation for Age-Gender Prediction
3
- This enables: pipeline("image-classification", model="abhilash88/age-gender-prediction", trust_remote_code=True)
4
- """
5
-
6
- import torch
7
- import torch.nn as nn
8
- from transformers import (
9
- ViTModel,
10
- ViTImageProcessor,
11
- PreTrainedModel,
12
- PretrainedConfig,
13
- ImageClassificationPipeline
14
- )
15
- from PIL import Image
16
- import numpy as np
17
- from typing import Union, Dict, Any, List
18
- import requests
19
- from io import BytesIO
20
-
21
-
22
- class AgeGenderConfig(PretrainedConfig):
23
- """Configuration for AgeGenderViTModel"""
24
- model_type = "age-gender-vit"
25
-
26
- def __init__(
27
- self,
28
- vit_model_name="google/vit-base-patch16-224",
29
- hidden_size=768,
30
- intermediate_size=256,
31
- final_size=64,
32
- dropout_rate=0.1,
33
- num_age_classes=100,
34
- **kwargs
35
- ):
36
- super().__init__(**kwargs)
37
- self.vit_model_name = vit_model_name
38
- self.hidden_size = hidden_size
39
- self.intermediate_size = intermediate_size
40
- self.final_size = final_size
41
- self.dropout_rate = dropout_rate
42
- self.num_age_classes = num_age_classes
43
-
44
-
45
- class AgeGenderViTModel(PreTrainedModel):
46
- """Age-Gender ViT Model with pipeline support"""
47
- config_class = AgeGenderConfig
48
-
49
- def __init__(self, config=None):
50
- if config is None:
51
- config = AgeGenderConfig()
52
- super().__init__(config)
53
-
54
- # ViT backbone
55
- self.vit = ViTModel.from_pretrained(config.vit_model_name)
56
-
57
- # Age head: 768 → 256 → 64 → 1
58
- self.age_head = nn.Sequential(
59
- nn.Linear(config.hidden_size, config.intermediate_size),
60
- nn.ReLU(),
61
- nn.Dropout(config.dropout_rate),
62
- nn.Linear(config.intermediate_size, config.final_size),
63
- nn.ReLU(),
64
- nn.Dropout(config.dropout_rate),
65
- nn.Linear(config.final_size, 1)
66
- )
67
-
68
- # Gender head: 768 → 256 → 64 → 1
69
- self.gender_head = nn.Sequential(
70
- nn.Linear(config.hidden_size, config.intermediate_size),
71
- nn.ReLU(),
72
- nn.Dropout(config.dropout_rate),
73
- nn.Linear(config.intermediate_size, config.final_size),
74
- nn.ReLU(),
75
- nn.Dropout(config.dropout_rate),
76
- nn.Linear(config.final_size, 1),
77
- nn.Sigmoid()
78
- )
79
-
80
- # For pipeline compatibility, we need a classifier head that outputs logits
81
- self.classifier = nn.Linear(2, 2) # Dummy classifier for pipeline
82
-
83
- def forward(self, pixel_values, **kwargs):
84
- """Forward pass - returns format expected by pipeline"""
85
- outputs = self.vit(pixel_values=pixel_values)
86
- pooled_output = outputs.pooler_output
87
-
88
- # Get age and gender predictions
89
- age_output = self.age_head(pooled_output)
90
- gender_output = self.gender_head(pooled_output)
91
-
92
- # For pipeline compatibility, create fake logits
93
- # We'll process these in the pipeline postprocessing
94
- fake_logits = torch.cat([age_output, gender_output], dim=1)
95
-
96
- return type('ModelOutput', (), {
97
- 'logits': fake_logits,
98
- 'age_logits': age_output,
99
- 'gender_logits': gender_output
100
- })()
101
-
102
-
103
- class AgeGenderImageClassificationPipeline(ImageClassificationPipeline):
104
- """Custom pipeline for age-gender classification"""
105
-
106
- def __init__(self, *args, **kwargs):
107
- super().__init__(*args, **kwargs)
108
-
109
- def postprocess(self, model_outputs, top_k=1, **kwargs):
110
- """Custom postprocessing for age-gender predictions"""
111
- outputs = model_outputs[0] # Single image output
112
-
113
- # Extract age and gender logits
114
- age_logits = outputs.age_logits
115
- gender_logits = outputs.gender_logits
116
-
117
- # Process predictions
118
- age = int(torch.clamp(age_logits, 0, 100).item())
119
- gender_prob = gender_logits.item()
120
- gender = "Female" if gender_prob > 0.5 else "Male"
121
- confidence = gender_prob if gender_prob > 0.5 else 1 - gender_prob
122
-
123
- # Return in pipeline format
124
- return [{
125
- "label": f"{age} years, {gender}",
126
- "score": confidence,
127
- "age": age,
128
- "gender": gender,
129
- "gender_confidence": round(confidence, 3),
130
- "gender_probability": round(gender_prob, 3)
131
- }]
132
-
133
-
134
- # Simple wrapper function for easy usage
135
- def predict_age_gender_pipeline(image_path_or_url: str) -> Dict[str, Any]:
136
- """
137
- Simple function using the pipeline
138
-
139
- Args:
140
- image_path_or_url: Path to image file or URL
141
-
142
- Returns:
143
- Dictionary with age and gender predictions
144
- """
145
- from transformers import pipeline
146
-
147
- # Create pipeline
148
- classifier = pipeline(
149
- "image-classification",
150
- model="abhilash88/age-gender-prediction",
151
- trust_remote_code=True
152
- )
153
-
154
- # Make prediction
155
- result = classifier(image_path_or_url)[0] # Get first result
156
-
157
- return {
158
- "age": result["age"],
159
- "gender": result["gender"],
160
- "confidence": result["gender_confidence"],
161
- "summary": result["label"]
162
- }
163
-
164
-
165
- # Manual implementation (guaranteed to work)
166
- def predict_age_gender_manual(image_input: Union[str, Image.Image, np.ndarray]) -> Dict[str, Any]:
167
- """
168
- Manual prediction function (always works)
169
-
170
- Args:
171
- image_input: Image path, URL, PIL Image, or numpy array
172
-
173
- Returns:
174
- Dictionary with predictions
175
- """
176
- # Create model
177
- model = AgeGenderViTModel()
178
-
179
- # Load weights
180
- model_url = "https://huggingface.co/abhilash88/age-gender-prediction/resolve/main/pytorch_model.bin"
181
- weights = torch.hub.load_state_dict_from_url(model_url, map_location='cpu')
182
- model.load_state_dict(weights)
183
- model.eval()
184
-
185
- # Create processor
186
- processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
187
-
188
- # Handle different input types
189
- if isinstance(image_input, str):
190
- if image_input.startswith(('http://', 'https://')):
191
- response = requests.get(image_input)
192
- image = Image.open(BytesIO(response.content)).convert('RGB')
193
- else:
194
- image = Image.open(image_input).convert('RGB')
195
- elif isinstance(image_input, np.ndarray):
196
- image = Image.fromarray(image_input).convert('RGB')
197
- else:
198
- image = image_input.convert('RGB')
199
-
200
- # Process and predict
201
- inputs = processor(images=image, return_tensors="pt")
202
-
203
- with torch.no_grad():
204
- outputs = model(inputs["pixel_values"])
205
- age_pred = outputs.age_logits
206
- gender_pred = outputs.gender_logits
207
-
208
- # Process results
209
- age = int(torch.clamp(age_pred, 0, 100).item())
210
- gender_prob = gender_pred.item()
211
- gender = "Female" if gender_prob > 0.5 else "Male"
212
- confidence = gender_prob if gender_prob > 0.5 else 1 - gender_prob
213
-
214
- return {
215
- "age": age,
216
- "gender": gender,
217
- "confidence": round(confidence, 3),
218
- "gender_probability": round(gender_prob, 3)
219
- }
220
-
221
-
222
- # Test both approaches
223
- if __name__ == "__main__":
224
- print("🧪 Testing Age-Gender Prediction...")
225
-
226
- # Test image URL
227
- test_url = "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?w=300&h=300&fit=crop&crop=face"
228
-
229
- try:
230
- print("🔧 Testing manual approach...")
231
- result_manual = predict_age_gender_manual(test_url)
232
- print(f"✅ Manual result: {result_manual}")
233
- except Exception as e:
234
- print(f"❌ Manual failed: {e}")
235
-
236
- try:
237
- print("🚀 Testing pipeline approach...")
238
- result_pipeline = predict_age_gender_pipeline(test_url)
239
- print(f"✅ Pipeline result: {result_pipeline}")
240
- except Exception as e:
241
- print(f"❌ Pipeline failed: {e}")
242
- print("Note: Pipeline requires the files to be uploaded to your HF repo")