sameernotes commited on
Commit
12fbfda
·
verified ·
1 Parent(s): 5f09c62

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +33 -0
  2. app.py +161 -0
  3. models/indian_name_gender_model.pt +3 -0
  4. requirements.txt +18 -0
Dockerfile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ # Install system dependencies for OpenCV
4
+ RUN apt-get update && apt-get install -y \
5
+ libgl1-mesa-glx \
6
+ libglib2.0-0 \
7
+ libsm6 \
8
+ libxrender1 \
9
+ libxext6 \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ WORKDIR /code
13
+
14
+ # Create a non-root user to run the application
15
+ RUN useradd -m appuser
16
+
17
+ # Create directories with appropriate permissions
18
+ RUN mkdir -p /code/output && \
19
+ chown -R appuser:appuser /code
20
+
21
+ COPY --chown=appuser:appuser ./requirements.txt /code/requirements.txt
22
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
23
+
24
+ # For headless matplotlib
25
+ ENV MPLBACKEND=Agg
26
+
27
+ COPY --chown=appuser:appuser . /code/
28
+
29
+ # Switch to the non-root user
30
+ USER appuser
31
+
32
+ # Make sure the app.py file is correctly named
33
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Query
2
+ from pydantic import BaseModel, Field, conlist
3
+ import torch
4
+ import torch.nn as nn
5
+ import os # Import the 'os' module
6
+ from typing import List
7
+
8
+ # --- Model Definition (same as before) ---
9
+ class NameGenderClassifierCNN(nn.Module):
10
+ def __init__(self, vocab_size, embedding_dim, num_filters=64, filter_sizes=[2, 3, 4], dropout=0.5):
11
+ super(NameGenderClassifierCNN, self).__init__()
12
+
13
+ self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
14
+
15
+ self.convs = nn.ModuleList([
16
+ nn.Conv1d(in_channels=embedding_dim, out_channels=num_filters, kernel_size=fs)
17
+ for fs in filter_sizes
18
+ ])
19
+
20
+ self.fc1 = nn.Linear(len(filter_sizes) * num_filters, 100)
21
+ self.fc2 = nn.Linear(100, 1)
22
+ self.dropout = nn.Dropout(dropout)
23
+ self.sigmoid = nn.Sigmoid()
24
+
25
+ def forward(self, x):
26
+ x = self.embedding(x)
27
+ x = x.transpose(1, 2)
28
+ conv_outputs = []
29
+ for conv in self.convs:
30
+ conv_out = torch.relu(conv(x))
31
+ pool_out = torch.max_pool1d(conv_out, conv_out.shape[2])
32
+ conv_outputs.append(pool_out.squeeze(2))
33
+ x = torch.cat(conv_outputs, dim=1)
34
+ x = self.dropout(x)
35
+ x = torch.relu(self.fc1(x))
36
+ x = self.dropout(x)
37
+ x = self.fc2(x)
38
+ return self.sigmoid(x).squeeze()
39
+
40
+
41
+
42
+ # --- Utility Function (same as before, but adapted) ---
43
+
44
+ def tokenize_name(name, char_to_idx, max_length):
45
+ """Tokenizes and pads a name."""
46
+ name = str(name).lower()
47
+ tokens = [char_to_idx.get(char, char_to_idx.get(' ', 1)) for char in name]
48
+
49
+ # Pad or truncate
50
+ if len(tokens) < max_length:
51
+ tokens = tokens + [char_to_idx['<PAD>']] * (max_length - len(tokens))
52
+ else:
53
+ tokens = tokens[:max_length]
54
+
55
+ return tokens
56
+
57
+
58
+ # --- FastAPI Setup ---
59
+
60
+ app = FastAPI(title="Indian Name Gender Prediction API",
61
+ description="Predicts the gender of Indian names using a CNN model.",
62
+ version="1.0")
63
+
64
+ # --- Model Loading (on startup) ---
65
+
66
+ MODEL_PATH = "models/indian_name_gender_model.pt" # Correct path within the space
67
+
68
+
69
+ def load_model():
70
+ """Loads the model, char_to_idx, and max_name_length."""
71
+ try:
72
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
73
+ checkpoint = torch.load(MODEL_PATH, map_location=device)
74
+ char_to_idx = checkpoint['char_to_idx']
75
+ max_name_length = checkpoint['max_name_length']
76
+ config = checkpoint['model_config']
77
+
78
+ model = NameGenderClassifierCNN(
79
+ vocab_size=config['vocab_size'],
80
+ embedding_dim=config['embedding_dim'],
81
+ num_filters=config['num_filters'],
82
+ filter_sizes=config['filter_sizes']
83
+ )
84
+ model.load_state_dict(checkpoint['model_state_dict'])
85
+ model.to(device)
86
+ model.eval() # Set to evaluation mode
87
+ return model, char_to_idx, max_name_length, device
88
+ except Exception as e:
89
+ raise Exception(f"Error loading model: {e}")
90
+
91
+ # Load model at startup
92
+ try:
93
+ model, char_to_idx, max_name_length, device = load_model()
94
+ except Exception as e:
95
+ print(f"Failed to load model: {e}")
96
+ raise # Re-raise the exception to halt startup
97
+
98
+ # --- Pydantic Models (for request/response validation) ---
99
+
100
+ class PredictionRequest(BaseModel):
101
+ names: conlist(str, min_length=1) = Field(..., example=["Aarav", "Anika"])
102
+ threshold: float = Field(0.5, ge=0.0, le=1.0, description="Probability threshold for classifying as male.")
103
+
104
+ class PredictionResponse(BaseModel):
105
+ predictions: List[dict] = Field(..., example=[
106
+ {"name": "Aarav", "predicted_gender": "Male", "male_probability": 0.95, "confidence": 0.95},
107
+ {"name": "Anika", "predicted_gender": "Female", "male_probability": 0.05, "confidence": 0.95}
108
+ ])
109
+
110
+
111
+ # --- Prediction Function ---
112
+
113
+ def predict_gender(name: str, model, char_to_idx, max_length, device, threshold: float = 0.5) -> tuple[str, float, float]:
114
+ """Predicts gender for a single name. Includes threshold."""
115
+ tokenized_name = tokenize_name(name, char_to_idx, max_length)
116
+ input_tensor = torch.tensor([tokenized_name], dtype=torch.long).to(device)
117
+
118
+ with torch.no_grad():
119
+ output = model(input_tensor)
120
+ probability = output.item()
121
+ predicted_gender = 'Male' if probability >= threshold else 'Female'
122
+ confidence = probability if probability >= threshold else 1 - probability
123
+ return predicted_gender, probability, confidence
124
+
125
+ # --- API Endpoints ---
126
+
127
+ @app.get("/", response_model=str)
128
+ async def read_root():
129
+ return "Welcome to the Indian Name Gender Prediction API. Use the /predict endpoint."
130
+
131
+ @app.post("/predict", response_model=PredictionResponse)
132
+ async def predict(request: PredictionRequest):
133
+ """Predicts the gender of one or more Indian names."""
134
+ try:
135
+ predictions = []
136
+ for name in request.names:
137
+ gender, prob, conf = predict_gender(name, model, char_to_idx, max_name_length, device, request.threshold)
138
+ predictions.append({
139
+ "name": name,
140
+ "predicted_gender": gender,
141
+ "male_probability": prob,
142
+ "confidence": conf
143
+ })
144
+ return {"predictions": predictions}
145
+ except Exception as e:
146
+ raise HTTPException(status_code=500, detail=str(e))
147
+
148
+ @app.get("/predict_single")
149
+ async def predict_single(name: str = Query(..., description="The name to predict."),
150
+ threshold: float = Query(0.5, ge=0.0, le=1.0, description="Probability threshold for classifying as male.")):
151
+ """Predicts gender for a *single* name, provided as a query parameter."""
152
+ try:
153
+ gender, prob, conf = predict_gender(name, model, char_to_idx, max_name_length, device, threshold)
154
+ return {
155
+ "name": name,
156
+ "predicted_gender": gender,
157
+ "male_probability": prob,
158
+ "confidence": conf
159
+ }
160
+ except Exception as e:
161
+ raise HTTPException(status_code=500, detail=str(e))
models/indian_name_gender_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8cdfdbc357c5567e1f45fd752459608a8b097cc6bc820a4347bbaeb543c1075
3
+ size 508560
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ tensorflow
4
+ numpy
5
+ pandas
6
+ opencv-python
7
+ matplotlib
8
+ scikit-learn
9
+ python-multipart
10
+ sakshi-ocr
11
+ pydantic
12
+ requests
13
+ google-genai
14
+ py-text-scan
15
+ SQLAlchemy
16
+ passlib
17
+ python-multipart
18
+ pydantic[email]