sameernotes committed on
Commit
1e8ac17
·
verified ·
1 Parent(s): 3472658

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +291 -83
app.py CHANGED
@@ -1,13 +1,17 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException
2
- from fastapi.responses import FileResponse, JSONResponse
3
- from pydantic import BaseModel
 
 
 
 
 
4
  import cv2
5
  import numpy as np
6
  import tensorflow as tf
7
  import pickle
8
  import matplotlib.pyplot as plt
9
  import matplotlib.font_manager as fm
10
- # import py_text_scan
11
  import os
12
  import io
13
  import sys
@@ -18,33 +22,199 @@ import uvicorn
18
  import shutil
19
  from pathlib import Path
20
  import py_text_scan
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  app = FastAPI(
23
  title="Hindi OCR API",
24
- description="API for Hindi OCR and word detection",
25
  version="1.0.0"
26
  )
27
 
28
- # URLs for the model and encoder hosted on Hugging Face
29
  MODEL_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/hindi_ocr_model.keras"
30
  ENCODER_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/label_encoder.pkl"
31
  FONT_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/NotoSansDevanagari-Regular.ttf"
 
 
 
32
 
33
- # Paths for local storage
34
- MODEL_PATH = os.path.join(tempfile.gettempdir(), "hindi_ocr_model.keras")
35
- ENCODER_PATH = os.path.join(tempfile.gettempdir(), "label_encoder.pkl")
36
- FONT_PATH = os.path.join(tempfile.gettempdir(), "NotoSansDevanagari-Regular.ttf")
37
-
38
- # Use a temporary directory for outputs
39
- OUTPUT_DIR = tempfile.mkdtemp()
40
-
41
- # Download model and encoder
42
  def download_file(url, dest):
43
- response = requests.get(url)
44
- with open(dest, 'wb') as f:
45
- f.write(response.content)
 
 
 
 
 
46
 
47
- # Load the model and encoder
48
  def load_model():
49
  if not os.path.exists(MODEL_PATH):
50
  return None
@@ -52,54 +222,56 @@ def load_model():
52
 
53
  def load_label_encoder():
54
  if not os.path.exists(ENCODER_PATH):
55
- return None
56
  with open(ENCODER_PATH, 'rb') as f:
57
  return pickle.load(f)
58
-
59
- # Set up global variables
60
  model = None
61
  label_encoder = None
 
62
 
63
- # Download required files on startup
64
  @app.on_event("startup")
65
  async def startup_event():
66
- # Download models and font if not already present
67
- if not os.path.exists(MODEL_PATH):
68
- download_file(MODEL_URL, MODEL_PATH)
69
- if not os.path.exists(ENCODER_PATH):
70
- download_file(ENCODER_URL, ENCODER_PATH)
71
- if not os.path.exists(FONT_PATH):
72
- download_file(FONT_URL, FONT_PATH)
73
-
74
- # Load the custom font if available
75
  if os.path.exists(FONT_PATH):
76
  fm.fontManager.addfont(FONT_PATH)
77
  plt.rcParams['font.family'] = 'Noto Sans Devanagari'
78
-
79
- # Initialize global variables
80
- global model, label_encoder
81
  model = load_model()
82
  label_encoder = load_label_encoder()
83
 
84
- # Word detection function
 
 
 
 
 
 
 
 
 
 
 
 
85
  def detect_words(image):
86
  _, binary = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
87
  kernel = np.ones((3,3), np.uint8)
88
  dilated = cv2.dilate(binary, kernel, iterations=2)
89
  contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
90
-
91
  word_img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
92
  word_count = 0
93
-
94
  for contour in contours:
95
  x, y, w, h = cv2.boundingRect(contour)
96
  if w > 10 and h > 10:
97
  cv2.rectangle(word_img, (x, y), (x+w, y+h), (0, 255, 0), 2)
98
  word_count += 1
99
-
100
  return word_img, word_count
101
 
102
- # Sakshi OCR output capture
103
  def run_py_text_scan(image_path):
104
  buffer = io.StringIO()
105
  old_stdout = sys.stdout
@@ -110,35 +282,23 @@ def run_py_text_scan(image_path):
110
  sys.stdout = old_stdout
111
  return buffer.getvalue()
112
 
113
- # File storage for session
114
- session_files = {}
115
-
116
- # Main OCR processing function
117
  def process_image(image_array):
118
- # Convert image array to grayscale
119
  img = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
120
-
121
- # Word detection
122
  word_detected_img, word_count = detect_words(img)
123
  word_detection_path = tempfile.NamedTemporaryFile(delete=False, suffix=".png").name
124
  cv2.imwrite(word_detection_path, word_detected_img)
125
-
126
- # Store the file path in our session dict
127
  session_files['word_detection'] = word_detection_path
128
-
129
- # First OCR model prediction
130
  pred_path = None
131
  try:
132
  img_resized = cv2.resize(img, (128, 32))
133
  img_norm = img_resized / 255.0
134
- img_input = img_norm[np.newaxis, ..., np.newaxis] # Shape: (1, 32, 128, 1)
135
-
136
  if model is not None and label_encoder is not None:
137
  pred = model.predict(img_input)
138
  pred_label_idx = np.argmax(pred)
139
  pred_label = label_encoder.inverse_transform([pred_label_idx])[0]
140
-
141
- # Create plot with prediction
142
  fig, ax = plt.subplots()
143
  ax.imshow(img, cmap='gray')
144
  ax.set_title(f"Predicted: {pred_label}", fontsize=12)
@@ -146,20 +306,16 @@ def process_image(image_array):
146
  pred_path = tempfile.NamedTemporaryFile(delete=False, suffix=".png").name
147
  plt.savefig(pred_path)
148
  plt.close()
149
-
150
- # Store the file path in our session dict
151
  session_files['prediction'] = pred_path
152
  else:
153
  pred_label = "Model or encoder not loaded"
154
  except Exception as e:
155
  pred_label = f"Error: {str(e)}"
156
-
157
- # Sakshi OCR processing
158
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
159
  cv2.imwrite(tmp_file.name, img)
160
  sakshi_output = run_py_text_scan(tmp_file.name)
161
  os.unlink(tmp_file.name)
162
-
163
  return {
164
  "sakshi_output": sakshi_output,
165
  "word_detection_path": word_detection_path if 'word_detection' in session_files else None,
@@ -168,18 +324,38 @@ def process_image(image_array):
168
  "prediction_label": pred_label
169
  }
170
 
171
- class OCRResponse(BaseModel):
172
- sakshi_output: str
173
- word_count: int
174
- prediction_label: str
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  @app.post("/process/", response_model=OCRResponse)
177
- async def process(file: UploadFile = File(...)):
178
- # Check if the file is an image
179
  if not file.content_type.startswith("image/"):
180
  raise HTTPException(status_code=400, detail="File must be an image")
181
-
182
- # Clean up previous session files
183
  for key, filepath in session_files.items():
184
  if os.path.exists(filepath):
185
  try:
@@ -187,19 +363,14 @@ async def process(file: UploadFile = File(...)):
187
  except:
188
  pass
189
  session_files.clear()
190
-
191
- # Create a temporary file to save the uploaded image
192
  temp_file = tempfile.NamedTemporaryFile(delete=False)
193
  try:
194
- # Save the uploaded file
195
  with temp_file as f:
196
  shutil.copyfileobj(file.file, f)
197
-
198
- # Open and process the image
199
  image = Image.open(temp_file.name)
200
  image_array = np.array(image)
201
  result = process_image(image_array)
202
-
203
  return OCRResponse(
204
  sakshi_output=result["sakshi_output"],
205
  word_count=result["word_count"],
@@ -208,27 +379,64 @@ async def process(file: UploadFile = File(...)):
208
  except Exception as e:
209
  raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
210
  finally:
211
- # Clean up the temporary file
212
  os.unlink(temp_file.name)
213
 
214
  @app.get("/word-detection/")
215
- async def get_word_detection():
216
- """Return the word detection image."""
217
  if 'word_detection' not in session_files or not os.path.exists(session_files['word_detection']):
218
- raise HTTPException(status_code=404, detail="Word detection image not found. Process an image first.")
219
  return FileResponse(session_files['word_detection'])
220
 
221
  @app.get("/prediction/")
222
- async def get_prediction():
223
- """Return the prediction image."""
224
  if 'prediction' not in session_files or not os.path.exists(session_files['prediction']):
225
- raise HTTPException(status_code=404, detail="Prediction image not found. Process an image first.")
226
  return FileResponse(session_files['prediction'])
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  @app.get("/")
229
  async def root():
230
- return {"message": "Hindi OCR API is running. Use POST /process/ to analyze images."}
 
231
 
232
- # For local testing
233
  if __name__ == "__main__":
234
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ # app.py (Complete, for Hugging Face Spaces)
2
+
3
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, status, Request
4
+ from fastapi.responses import FileResponse, JSONResponse, HTMLResponse
5
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
6
+ from fastapi.templating import Jinja2Templates
7
+ from pydantic import BaseModel, EmailStr, Field
8
+ from typing import List, Optional
9
  import cv2
10
  import numpy as np
11
  import tensorflow as tf
12
  import pickle
13
  import matplotlib.pyplot as plt
14
  import matplotlib.font_manager as fm
 
15
  import os
16
  import io
17
  import sys
 
22
  import shutil
23
  from pathlib import Path
24
  import py_text_scan
25
+ from sqlalchemy import create_engine, Column, Integer, String, Boolean, Text, DateTime
26
+ from sqlalchemy.ext.declarative import declarative_base
27
+ from sqlalchemy.orm import sessionmaker, Session
28
+ from passlib.context import CryptContext
29
+ import datetime
30
+
31
+ # --- Database Setup (SQLite) ---
32
+ DATABASE_URL = "sqlite:///./test.db"
33
+ engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
34
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
35
+ Base = declarative_base()
36
+
37
+ # --- Database Models ---
38
class DBUser(Base):
    """SQLAlchemy ORM model for the ``users`` table.

    Deliberately not named ``User``: the Pydantic response schema defined
    further down is also called ``User``. When both classes shared that
    name the later (Pydantic) definition replaced this one, so every
    ``db.query(User)`` was handed a Pydantic class instead of a mapped table.
    """

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    username = Column(String, unique=True, index=True)
    email = Column(String, unique=True, index=True)
    hashed_password = Column(String)
    is_active = Column(Boolean, default=True)
    is_admin = Column(Boolean, default=False)


class DBFeedback(Base):
    """SQLAlchemy ORM model for the ``feedback`` table.

    Named ``DBFeedback`` for the same reason as ``DBUser``: the Pydantic
    schema ``Feedback`` below would otherwise rebind the name.
    """

    __tablename__ = "feedback"

    id = Column(Integer, primary_key=True, index=True)
    username = Column(String)
    comment = Column(Text)
    created_at = Column(DateTime, default=datetime.datetime.utcnow)


Base.metadata.create_all(bind=engine)  # Create tables


# --- Pydantic Schemas ---
class UserBase(BaseModel):
    """Fields shared by the user request and response schemas."""

    username: str = Field(..., min_length=3, max_length=50)
    email: EmailStr


class UserCreate(UserBase):
    """Payload for creating a user; the only schema that carries a password."""

    password: str = Field(..., min_length=6)


class User(UserBase):
    """Response schema for a user row.

    Does not include ``password``: the ORM row only has ``hashed_password``,
    so a required ``password`` field here could never be populated from it.
    """

    id: int
    is_active: bool
    is_admin: bool

    class Config:
        from_attributes = True


class UserUpdate(BaseModel):
    """Partial update payload; only fields that were explicitly set are applied."""

    username: Optional[str] = None
    email: Optional[EmailStr] = None
    is_active: Optional[bool] = None
    is_admin: Optional[bool] = None


class FeedbackBase(BaseModel):
    """Fields shared by the feedback request and response schemas."""

    username: str
    comment: str


class FeedbackCreate(FeedbackBase):
    """Payload for submitting feedback."""


class Feedback(FeedbackBase):
    """Response schema for a stored feedback row."""

    id: int
    created_at: datetime.datetime

    class Config:
        from_attributes = True


class Token(BaseModel):
    """Response body of the token endpoint."""

    access_token: str
    token_type: str


class TokenData(BaseModel):
    """Data carried by a token."""

    # Optional[...] rather than ``str | None`` to match the rest of the file.
    username: Optional[str] = None


class OCRResponse(BaseModel):
    """Response body of the OCR processing endpoint."""

    sakshi_output: str
    word_count: int
    prediction_label: str


# --- Password Hashing ---
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# --- Authentication ---
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


def get_db():
    """Yield a ``SessionLocal`` session and close it once the caller is done."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


async def get_current_user(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)):
    """Return the user identified by the bearer token, or raise HTTP 401.

    NOTE(review): the token is looked up directly as a username, so the
    bearer value is not a secret and carries no signature or expiry. Anyone
    who knows a username can authenticate as that user; replace with a
    signed, expiring token before exposing this service.
    """
    user = get_user_by_username(db, username=token)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user


async def get_current_active_user(current_user: DBUser = Depends(get_current_user)):
    """Return the current user, or raise HTTP 400 if the account is inactive."""
    if not current_user.is_active:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user


async def get_current_admin_user(current_user: DBUser = Depends(get_current_active_user)):
    """Return the current active user, or raise HTTP 403 if not an admin."""
    if not current_user.is_admin:
        raise HTTPException(status_code=403, detail="Not an administrator")
    return current_user


# --- CRUD Operations ---
def get_user(db: Session, user_id: int) -> Optional[DBUser]:
    """Return the user with primary key ``user_id``, or ``None``."""
    return db.query(DBUser).filter(DBUser.id == user_id).first()


def get_user_by_username(db: Session, username: str) -> Optional[DBUser]:
    """Return the user with the given username, or ``None``."""
    return db.query(DBUser).filter(DBUser.username == username).first()


def get_user_by_email(db: Session, email: str) -> Optional[DBUser]:
    """Return the user with the given email, or ``None``."""
    return db.query(DBUser).filter(DBUser.email == email).first()


def get_users(db: Session, skip: int = 0, limit: int = 100) -> List[DBUser]:
    """Return up to ``limit`` users, skipping the first ``skip`` rows."""
    return db.query(DBUser).offset(skip).limit(limit).all()


def create_user(db: Session, user: UserCreate) -> DBUser:
    """Hash the password, insert a new user row, and return the refreshed row."""
    hashed_password = pwd_context.hash(user.password)
    db_user = DBUser(username=user.username, email=user.email, hashed_password=hashed_password)
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user


def update_user(db: Session, user_id: int, user: UserUpdate) -> Optional[DBUser]:
    """Apply the explicitly-set fields of ``user`` to the row ``user_id``.

    Returns the refreshed row, or ``None`` when no such user exists.
    """
    db_user = get_user(db, user_id)
    if db_user:
        # exclude_unset keeps omitted fields from being overwritten with None.
        for key, value in user.dict(exclude_unset=True).items():
            setattr(db_user, key, value)
        db.commit()
        db.refresh(db_user)
    return db_user


def delete_user(db: Session, user_id: int) -> bool:
    """Delete the user ``user_id``; return ``True`` if a row was deleted."""
    db_user = get_user(db, user_id)
    if db_user:
        db.delete(db_user)
        db.commit()
        return True
    return False


def verify_password(plain_password, hashed_password):
    """Return whether ``plain_password`` matches ``hashed_password``."""
    return pwd_context.verify(plain_password, hashed_password)


def create_feedback(db: Session, feedback: FeedbackCreate) -> DBFeedback:
    """Insert a feedback row built from ``feedback`` and return the refreshed row."""
    db_feedback = DBFeedback(**feedback.dict())
    db.add(db_feedback)
    db.commit()
    db.refresh(db_feedback)
    return db_feedback


def get_feedback(db: Session, skip: int = 0, limit: int = 100) -> List[DBFeedback]:
    """Return feedback rows, newest first, with offset/limit paging."""
    return db.query(DBFeedback).order_by(DBFeedback.created_at.desc()).offset(skip).limit(limit).all()
188
+
189
+
190
+
191
+ # --- FastAPI App Setup ---
192
  app = FastAPI(
193
  title="Hindi OCR API",
194
+ description="API for Hindi OCR, word detection, authentication, and feedback",
195
  version="1.0.0"
196
  )
197
 
198
+ # --- Hugging Face Model and Resource URLs ---
199
  MODEL_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/hindi_ocr_model.keras"
200
  ENCODER_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/label_encoder.pkl"
201
  FONT_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/NotoSansDevanagari-Regular.ttf"
202
+ MODEL_PATH = "hindi_ocr_model.keras" # Local paths after download
203
+ ENCODER_PATH = "label_encoder.pkl"
204
+ FONT_PATH = "NotoSansDevanagari-Regular.ttf"
205
 
206
+ # --- Download Helper ---
 
 
 
 
 
 
 
 
207
  def download_file(url, dest):
208
+ if not os.path.exists(dest):
209
+ print(f"Downloading {dest}...")
210
+ response = requests.get(url, stream=True)
211
+ response.raise_for_status()
212
+ with open(dest, 'wb') as f:
213
+ for chunk in response.iter_content(chunk_size=8192):
214
+ f.write(chunk)
215
+ print(f"Downloaded {dest}")
216
 
217
+ # --- Model Loading ---
218
  def load_model():
219
  if not os.path.exists(MODEL_PATH):
220
  return None
 
222
 
223
  def load_label_encoder():
224
  if not os.path.exists(ENCODER_PATH):
225
+ return None
226
  with open(ENCODER_PATH, 'rb') as f:
227
  return pickle.load(f)
228
+ # --- Global Variables ---
 
229
  model = None
230
  label_encoder = None
231
+ session_files = {} # For storing temporary file paths
232
 
233
+ # --- Startup Event ---
234
  @app.on_event("startup")
235
  async def startup_event():
236
+ global model, label_encoder
237
+ download_file(MODEL_URL, MODEL_PATH)
238
+ download_file(ENCODER_URL, ENCODER_PATH)
239
+ download_file(FONT_URL, FONT_PATH)
240
+
 
 
 
 
241
  if os.path.exists(FONT_PATH):
242
  fm.fontManager.addfont(FONT_PATH)
243
  plt.rcParams['font.family'] = 'Noto Sans Devanagari'
 
 
 
244
  model = load_model()
245
  label_encoder = load_label_encoder()
246
 
247
+ # Create an admin user if one doesn't exist
248
+ db = SessionLocal()
249
+ if not get_user_by_username(db, "admin"):
250
+ admin_user = UserCreate(username="admin", email="[email protected]", password="adminpassword") #Change the password here
251
+ create_user(db, admin_user)
252
+ admin = get_user_by_username(db, "admin")
253
+ admin.is_admin = True # Make this user an admin
254
+ db.commit()
255
+ db.close()
256
+
257
+
258
+
259
+ # --- Word Detection ---
260
  def detect_words(image):
261
  _, binary = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
262
  kernel = np.ones((3,3), np.uint8)
263
  dilated = cv2.dilate(binary, kernel, iterations=2)
264
  contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
265
  word_img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
266
  word_count = 0
 
267
  for contour in contours:
268
  x, y, w, h = cv2.boundingRect(contour)
269
  if w > 10 and h > 10:
270
  cv2.rectangle(word_img, (x, y), (x+w, y+h), (0, 255, 0), 2)
271
  word_count += 1
 
272
  return word_img, word_count
273
 
274
+ # --- Sakshi OCR ---
275
  def run_py_text_scan(image_path):
276
  buffer = io.StringIO()
277
  old_stdout = sys.stdout
 
282
  sys.stdout = old_stdout
283
  return buffer.getvalue()
284
 
285
+ # --- Image Processing ---
 
 
 
286
  def process_image(image_array):
 
287
  img = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
 
 
288
  word_detected_img, word_count = detect_words(img)
289
  word_detection_path = tempfile.NamedTemporaryFile(delete=False, suffix=".png").name
290
  cv2.imwrite(word_detection_path, word_detected_img)
 
 
291
  session_files['word_detection'] = word_detection_path
292
+
 
293
  pred_path = None
294
  try:
295
  img_resized = cv2.resize(img, (128, 32))
296
  img_norm = img_resized / 255.0
297
+ img_input = img_norm[np.newaxis, ..., np.newaxis]
 
298
  if model is not None and label_encoder is not None:
299
  pred = model.predict(img_input)
300
  pred_label_idx = np.argmax(pred)
301
  pred_label = label_encoder.inverse_transform([pred_label_idx])[0]
 
 
302
  fig, ax = plt.subplots()
303
  ax.imshow(img, cmap='gray')
304
  ax.set_title(f"Predicted: {pred_label}", fontsize=12)
 
306
  pred_path = tempfile.NamedTemporaryFile(delete=False, suffix=".png").name
307
  plt.savefig(pred_path)
308
  plt.close()
 
 
309
  session_files['prediction'] = pred_path
310
  else:
311
  pred_label = "Model or encoder not loaded"
312
  except Exception as e:
313
  pred_label = f"Error: {str(e)}"
314
+
 
315
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
316
  cv2.imwrite(tmp_file.name, img)
317
  sakshi_output = run_py_text_scan(tmp_file.name)
318
  os.unlink(tmp_file.name)
 
319
  return {
320
  "sakshi_output": sakshi_output,
321
  "word_detection_path": word_detection_path if 'word_detection' in session_files else None,
 
324
  "prediction_label": pred_label
325
  }
326
 
327
+ # --- API Endpoints ---
 
 
 
328
 
329
+ # Authentication Endpoints
330
@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Check the submitted username/password and return a bearer token.

    Raises HTTP 401 when the user is unknown or the password does not match.

    NOTE(review): the issued token is the bare username. It is neither
    signed nor expiring, so knowing a username is enough to forge a token.
    """
    account = get_user_by_username(db, form_data.username)
    if account and verify_password(form_data.password, account.hashed_password):
        # Use username as the access token (for simplicity in this example)
        return {"access_token": account.username, "token_type": "bearer"}

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
    )
342
+
343
@app.post("/signup", response_model=User)
async def signup(user: UserCreate = Depends(), db: Session = Depends(get_db)):
    """Register a new user after checking username and email are unused.

    Raises HTTP 400 if either the username or the email is already taken.

    NOTE(review): declaring the model with ``Depends()`` appears to make
    FastAPI read its fields, including the password, from the query string
    rather than the request body -- confirm, and consider a body parameter.
    """
    if get_user_by_username(db, username=user.username):
        raise HTTPException(status_code=400, detail="Username already registered")

    if get_user_by_email(db, email=user.email):
        raise HTTPException(status_code=400, detail="Email already registered")

    return create_user(db=db, user=user)
352
+
353
+ # OCR Endpoint
354
  @app.post("/process/", response_model=OCRResponse)
355
+ async def process(file: UploadFile = File(...), current_user: User = Depends(get_current_active_user)):
 
356
  if not file.content_type.startswith("image/"):
357
  raise HTTPException(status_code=400, detail="File must be an image")
358
+
 
359
  for key, filepath in session_files.items():
360
  if os.path.exists(filepath):
361
  try:
 
363
  except:
364
  pass
365
  session_files.clear()
366
+
 
367
  temp_file = tempfile.NamedTemporaryFile(delete=False)
368
  try:
 
369
  with temp_file as f:
370
  shutil.copyfileobj(file.file, f)
 
 
371
  image = Image.open(temp_file.name)
372
  image_array = np.array(image)
373
  result = process_image(image_array)
 
374
  return OCRResponse(
375
  sakshi_output=result["sakshi_output"],
376
  word_count=result["word_count"],
 
379
  except Exception as e:
380
  raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
381
  finally:
 
382
  os.unlink(temp_file.name)
383
 
384
  @app.get("/word-detection/")
385
+ async def get_word_detection(current_user: User = Depends(get_current_active_user)):
 
386
  if 'word_detection' not in session_files or not os.path.exists(session_files['word_detection']):
387
+ raise HTTPException(status_code=404, detail="Word detection image not found")
388
  return FileResponse(session_files['word_detection'])
389
 
390
  @app.get("/prediction/")
391
+ async def get_prediction(current_user: User = Depends(get_current_active_user)):
 
392
  if 'prediction' not in session_files or not os.path.exists(session_files['prediction']):
393
+ raise HTTPException(status_code=404, detail="Prediction image not found")
394
  return FileResponse(session_files['prediction'])
395
 
396
+ # Feedback Endpoint
397
@app.post("/feedback/", response_model=Feedback)
async def create_feedback_route(
    feedback: FeedbackCreate,
    current_user: User = Depends(get_current_active_user),
    db: Session = Depends(get_db),
):
    """Store feedback submitted by an authenticated, active user.

    NOTE(review): the stored ``username`` comes from the request payload,
    not from ``current_user``, so a caller can submit feedback under any
    name -- confirm whether that is intended.
    """
    stored = create_feedback(db=db, feedback=feedback)
    return stored
400
+
401
+ # Admin Endpoints
402
@app.get("/admin/users/", response_model=List[User])
async def read_users(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_admin_user),
):
    """List users with offset/limit paging (admin only)."""
    return get_users(db, skip=skip, limit=limit)
406
+
407
@app.get("/admin/users/{user_id}", response_model=User)
async def read_user(
    user_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_admin_user),
):
    """Return one user by id, or raise HTTP 404 (admin only)."""
    found = get_user(db, user_id=user_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="User not found")
413
+
414
@app.put("/admin/users/{user_id}", response_model=User)
async def update_user_route(
    user_id: int,
    user: UserUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_admin_user),
):
    """Apply a partial update to one user, or raise HTTP 404 (admin only)."""
    result = update_user(db=db, user_id=user_id, user=user)
    if result is not None:
        return result
    raise HTTPException(status_code=404, detail="User not found")
420
+
421
@app.delete("/admin/users/{user_id}", response_model=dict)
async def delete_user_route(
    user_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_admin_user),
):
    """Delete one user by id, or raise HTTP 404 (admin only)."""
    removed = delete_user(db=db, user_id=user_id)
    if not removed:
        raise HTTPException(status_code=404, detail="User not found")
    return {"message": "User deleted successfully"}
427
+
428
+
429
@app.get("/admin/feedback/", response_model=List[Feedback])
async def read_feedback(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_admin_user),
):
    """List stored feedback with offset/limit paging (admin only)."""
    return get_feedback(db, skip=skip, limit=limit)
433
+
434
+ # Basic Root Endpoint
435
  @app.get("/")
436
  async def root():
437
+ return {"message": "Hindi OCR API with Authentication and Admin. See /docs for API details."}
438
+
439
 
440
+ # --- Run with uvicorn (for local testing) ---
441
  if __name__ == "__main__":
442
  uvicorn.run(app, host="0.0.0.0", port=8000)