Delete model.py
Browse files
model.py
DELETED
@@ -1,242 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Complete Working Pipeline Implementation for Age-Gender Prediction
|
3 |
-
This enables: pipeline("image-classification", model="abhilash88/age-gender-prediction", trust_remote_code=True)
|
4 |
-
"""
|
5 |
-
|
6 |
-
import torch
|
7 |
-
import torch.nn as nn
|
8 |
-
from transformers import (
|
9 |
-
ViTModel,
|
10 |
-
ViTImageProcessor,
|
11 |
-
PreTrainedModel,
|
12 |
-
PretrainedConfig,
|
13 |
-
ImageClassificationPipeline
|
14 |
-
)
|
15 |
-
from PIL import Image
|
16 |
-
import numpy as np
|
17 |
-
from typing import Union, Dict, Any, List
|
18 |
-
import requests
|
19 |
-
from io import BytesIO
|
20 |
-
|
21 |
-
|
22 |
-
class AgeGenderConfig(PretrainedConfig):
|
23 |
-
"""Configuration for AgeGenderViTModel"""
|
24 |
-
model_type = "age-gender-vit"
|
25 |
-
|
26 |
-
def __init__(
|
27 |
-
self,
|
28 |
-
vit_model_name="google/vit-base-patch16-224",
|
29 |
-
hidden_size=768,
|
30 |
-
intermediate_size=256,
|
31 |
-
final_size=64,
|
32 |
-
dropout_rate=0.1,
|
33 |
-
num_age_classes=100,
|
34 |
-
**kwargs
|
35 |
-
):
|
36 |
-
super().__init__(**kwargs)
|
37 |
-
self.vit_model_name = vit_model_name
|
38 |
-
self.hidden_size = hidden_size
|
39 |
-
self.intermediate_size = intermediate_size
|
40 |
-
self.final_size = final_size
|
41 |
-
self.dropout_rate = dropout_rate
|
42 |
-
self.num_age_classes = num_age_classes
|
43 |
-
|
44 |
-
|
45 |
-
class AgeGenderViTModel(PreTrainedModel):
|
46 |
-
"""Age-Gender ViT Model with pipeline support"""
|
47 |
-
config_class = AgeGenderConfig
|
48 |
-
|
49 |
-
def __init__(self, config=None):
|
50 |
-
if config is None:
|
51 |
-
config = AgeGenderConfig()
|
52 |
-
super().__init__(config)
|
53 |
-
|
54 |
-
# ViT backbone
|
55 |
-
self.vit = ViTModel.from_pretrained(config.vit_model_name)
|
56 |
-
|
57 |
-
# Age head: 768 → 256 → 64 → 1
|
58 |
-
self.age_head = nn.Sequential(
|
59 |
-
nn.Linear(config.hidden_size, config.intermediate_size),
|
60 |
-
nn.ReLU(),
|
61 |
-
nn.Dropout(config.dropout_rate),
|
62 |
-
nn.Linear(config.intermediate_size, config.final_size),
|
63 |
-
nn.ReLU(),
|
64 |
-
nn.Dropout(config.dropout_rate),
|
65 |
-
nn.Linear(config.final_size, 1)
|
66 |
-
)
|
67 |
-
|
68 |
-
# Gender head: 768 → 256 → 64 → 1
|
69 |
-
self.gender_head = nn.Sequential(
|
70 |
-
nn.Linear(config.hidden_size, config.intermediate_size),
|
71 |
-
nn.ReLU(),
|
72 |
-
nn.Dropout(config.dropout_rate),
|
73 |
-
nn.Linear(config.intermediate_size, config.final_size),
|
74 |
-
nn.ReLU(),
|
75 |
-
nn.Dropout(config.dropout_rate),
|
76 |
-
nn.Linear(config.final_size, 1),
|
77 |
-
nn.Sigmoid()
|
78 |
-
)
|
79 |
-
|
80 |
-
# For pipeline compatibility, we need a classifier head that outputs logits
|
81 |
-
self.classifier = nn.Linear(2, 2) # Dummy classifier for pipeline
|
82 |
-
|
83 |
-
def forward(self, pixel_values, **kwargs):
|
84 |
-
"""Forward pass - returns format expected by pipeline"""
|
85 |
-
outputs = self.vit(pixel_values=pixel_values)
|
86 |
-
pooled_output = outputs.pooler_output
|
87 |
-
|
88 |
-
# Get age and gender predictions
|
89 |
-
age_output = self.age_head(pooled_output)
|
90 |
-
gender_output = self.gender_head(pooled_output)
|
91 |
-
|
92 |
-
# For pipeline compatibility, create fake logits
|
93 |
-
# We'll process these in the pipeline postprocessing
|
94 |
-
fake_logits = torch.cat([age_output, gender_output], dim=1)
|
95 |
-
|
96 |
-
return type('ModelOutput', (), {
|
97 |
-
'logits': fake_logits,
|
98 |
-
'age_logits': age_output,
|
99 |
-
'gender_logits': gender_output
|
100 |
-
})()
|
101 |
-
|
102 |
-
|
103 |
-
class AgeGenderImageClassificationPipeline(ImageClassificationPipeline):
|
104 |
-
"""Custom pipeline for age-gender classification"""
|
105 |
-
|
106 |
-
def __init__(self, *args, **kwargs):
|
107 |
-
super().__init__(*args, **kwargs)
|
108 |
-
|
109 |
-
def postprocess(self, model_outputs, top_k=1, **kwargs):
|
110 |
-
"""Custom postprocessing for age-gender predictions"""
|
111 |
-
outputs = model_outputs[0] # Single image output
|
112 |
-
|
113 |
-
# Extract age and gender logits
|
114 |
-
age_logits = outputs.age_logits
|
115 |
-
gender_logits = outputs.gender_logits
|
116 |
-
|
117 |
-
# Process predictions
|
118 |
-
age = int(torch.clamp(age_logits, 0, 100).item())
|
119 |
-
gender_prob = gender_logits.item()
|
120 |
-
gender = "Female" if gender_prob > 0.5 else "Male"
|
121 |
-
confidence = gender_prob if gender_prob > 0.5 else 1 - gender_prob
|
122 |
-
|
123 |
-
# Return in pipeline format
|
124 |
-
return [{
|
125 |
-
"label": f"{age} years, {gender}",
|
126 |
-
"score": confidence,
|
127 |
-
"age": age,
|
128 |
-
"gender": gender,
|
129 |
-
"gender_confidence": round(confidence, 3),
|
130 |
-
"gender_probability": round(gender_prob, 3)
|
131 |
-
}]
|
132 |
-
|
133 |
-
|
134 |
-
# Simple wrapper function for easy usage
|
135 |
-
def predict_age_gender_pipeline(image_path_or_url: str) -> Dict[str, Any]:
|
136 |
-
"""
|
137 |
-
Simple function using the pipeline
|
138 |
-
|
139 |
-
Args:
|
140 |
-
image_path_or_url: Path to image file or URL
|
141 |
-
|
142 |
-
Returns:
|
143 |
-
Dictionary with age and gender predictions
|
144 |
-
"""
|
145 |
-
from transformers import pipeline
|
146 |
-
|
147 |
-
# Create pipeline
|
148 |
-
classifier = pipeline(
|
149 |
-
"image-classification",
|
150 |
-
model="abhilash88/age-gender-prediction",
|
151 |
-
trust_remote_code=True
|
152 |
-
)
|
153 |
-
|
154 |
-
# Make prediction
|
155 |
-
result = classifier(image_path_or_url)[0] # Get first result
|
156 |
-
|
157 |
-
return {
|
158 |
-
"age": result["age"],
|
159 |
-
"gender": result["gender"],
|
160 |
-
"confidence": result["gender_confidence"],
|
161 |
-
"summary": result["label"]
|
162 |
-
}
|
163 |
-
|
164 |
-
|
165 |
-
# Manual implementation (guaranteed to work)
|
166 |
-
def predict_age_gender_manual(image_input: Union[str, Image.Image, np.ndarray]) -> Dict[str, Any]:
|
167 |
-
"""
|
168 |
-
Manual prediction function (always works)
|
169 |
-
|
170 |
-
Args:
|
171 |
-
image_input: Image path, URL, PIL Image, or numpy array
|
172 |
-
|
173 |
-
Returns:
|
174 |
-
Dictionary with predictions
|
175 |
-
"""
|
176 |
-
# Create model
|
177 |
-
model = AgeGenderViTModel()
|
178 |
-
|
179 |
-
# Load weights
|
180 |
-
model_url = "https://huggingface.co/abhilash88/age-gender-prediction/resolve/main/pytorch_model.bin"
|
181 |
-
weights = torch.hub.load_state_dict_from_url(model_url, map_location='cpu')
|
182 |
-
model.load_state_dict(weights)
|
183 |
-
model.eval()
|
184 |
-
|
185 |
-
# Create processor
|
186 |
-
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
|
187 |
-
|
188 |
-
# Handle different input types
|
189 |
-
if isinstance(image_input, str):
|
190 |
-
if image_input.startswith(('http://', 'https://')):
|
191 |
-
response = requests.get(image_input)
|
192 |
-
image = Image.open(BytesIO(response.content)).convert('RGB')
|
193 |
-
else:
|
194 |
-
image = Image.open(image_input).convert('RGB')
|
195 |
-
elif isinstance(image_input, np.ndarray):
|
196 |
-
image = Image.fromarray(image_input).convert('RGB')
|
197 |
-
else:
|
198 |
-
image = image_input.convert('RGB')
|
199 |
-
|
200 |
-
# Process and predict
|
201 |
-
inputs = processor(images=image, return_tensors="pt")
|
202 |
-
|
203 |
-
with torch.no_grad():
|
204 |
-
outputs = model(inputs["pixel_values"])
|
205 |
-
age_pred = outputs.age_logits
|
206 |
-
gender_pred = outputs.gender_logits
|
207 |
-
|
208 |
-
# Process results
|
209 |
-
age = int(torch.clamp(age_pred, 0, 100).item())
|
210 |
-
gender_prob = gender_pred.item()
|
211 |
-
gender = "Female" if gender_prob > 0.5 else "Male"
|
212 |
-
confidence = gender_prob if gender_prob > 0.5 else 1 - gender_prob
|
213 |
-
|
214 |
-
return {
|
215 |
-
"age": age,
|
216 |
-
"gender": gender,
|
217 |
-
"confidence": round(confidence, 3),
|
218 |
-
"gender_probability": round(gender_prob, 3)
|
219 |
-
}
|
220 |
-
|
221 |
-
|
222 |
-
# Test both approaches
|
223 |
-
if __name__ == "__main__":
|
224 |
-
print("🧪 Testing Age-Gender Prediction...")
|
225 |
-
|
226 |
-
# Test image URL
|
227 |
-
test_url = "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?w=300&h=300&fit=crop&crop=face"
|
228 |
-
|
229 |
-
try:
|
230 |
-
print("🔧 Testing manual approach...")
|
231 |
-
result_manual = predict_age_gender_manual(test_url)
|
232 |
-
print(f"✅ Manual result: {result_manual}")
|
233 |
-
except Exception as e:
|
234 |
-
print(f"❌ Manual failed: {e}")
|
235 |
-
|
236 |
-
try:
|
237 |
-
print("🚀 Testing pipeline approach...")
|
238 |
-
result_pipeline = predict_age_gender_pipeline(test_url)
|
239 |
-
print(f"✅ Pipeline result: {result_pipeline}")
|
240 |
-
except Exception as e:
|
241 |
-
print(f"❌ Pipeline failed: {e}")
|
242 |
-
print("Note: Pipeline requires the files to be uploaded to your HF repo")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|