sameernotes committed on
Commit
d4898e7
·
verified ·
1 Parent(s): c2d88be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -77
app.py CHANGED
@@ -1,11 +1,8 @@
1
- # app.py (Complete, for Hugging Face Spaces)
2
-
3
  from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, status, Request
4
  from fastapi.responses import FileResponse, JSONResponse, HTMLResponse
5
  from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
6
- from fastapi.templating import Jinja2Templates
7
  from pydantic import BaseModel, EmailStr, Field
8
- from typing import List, Optional # Import Optional from typing
9
  import cv2
10
  import numpy as np
11
  import tensorflow as tf
@@ -51,7 +48,7 @@ class FeedbackModel(Base):
51
  comment = Column(Text)
52
  created_at = Column(DateTime, default=datetime.datetime.utcnow)
53
 
54
- Base.metadata.create_all(bind=engine) # Create tables
55
 
56
  # --- Pydantic Schemas ---
57
  class UserBase(BaseModel):
@@ -92,14 +89,13 @@ class Token(BaseModel):
92
  token_type: str
93
 
94
  class TokenData(BaseModel):
95
- username: Optional[str] = None # Use Optional[str] instead of str | None
96
 
97
  class OCRResponse(BaseModel):
98
  sakshi_output: str
99
  word_count: int
100
  prediction_label: str
101
 
102
-
103
  # --- Password Hashing ---
104
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
105
 
@@ -114,7 +110,6 @@ def get_db():
114
  db.close()
115
 
116
  async def get_current_user(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)):
117
- # Correctly retrieve user by matching token with username.
118
  user = get_user_by_username(db, username=token)
119
  if not user:
120
  raise HTTPException(
@@ -124,19 +119,16 @@ async def get_current_user(db: Session = Depends(get_db), token: str = Depends(o
124
  )
125
  return user
126
 
127
-
128
- async def get_current_active_user(current_user: UserModel = Depends(get_current_user)): #Use UserModel
129
  if not current_user.is_active:
130
  raise HTTPException(status_code=400, detail="Inactive user")
131
  return current_user
132
 
133
- async def get_current_admin_user(current_user: UserModel = Depends(get_current_active_user)): #Use UserModel
134
  if not current_user.is_admin:
135
  raise HTTPException(status_code=403, detail="Not an administrator")
136
  return current_user
137
 
138
-
139
-
140
  # --- CRUD Operations ---
141
  def get_user(db: Session, user_id: int):
142
  return db.query(UserModel).filter(UserModel.id == user_id).first()
@@ -188,8 +180,6 @@ def create_feedback(db: Session, feedback: FeedbackCreate):
188
  def get_feedback(db: Session, skip: int = 0, limit: int = 100):
189
  return db.query(FeedbackModel).order_by(FeedbackModel.created_at.desc()).offset(skip).limit(limit).all()
190
 
191
-
192
-
193
  # --- FastAPI App Setup ---
194
  app = FastAPI(
195
  title="Hindi OCR API",
@@ -201,11 +191,10 @@ app = FastAPI(
201
  MODEL_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/hindi_ocr_model.keras"
202
  ENCODER_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/label_encoder.pkl"
203
  FONT_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/NotoSansDevanagari-Regular.ttf"
204
- MODEL_PATH = "hindi_ocr_model.keras" # Local paths after download
205
  ENCODER_PATH = "label_encoder.pkl"
206
  FONT_PATH = "NotoSansDevanagari-Regular.ttf"
207
 
208
- # --- Download Helper ---
209
  def download_file(url, dest):
210
  if not os.path.exists(dest):
211
  print(f"Downloading {dest}...")
@@ -216,7 +205,6 @@ def download_file(url, dest):
216
  f.write(chunk)
217
  print(f"Downloaded {dest}")
218
 
219
- # --- Model Loading ---
220
  def load_model():
221
  if not os.path.exists(MODEL_PATH):
222
  return None
@@ -227,12 +215,11 @@ def load_label_encoder():
227
  return None
228
  with open(ENCODER_PATH, 'rb') as f:
229
  return pickle.load(f)
230
- # --- Global Variables ---
231
  model = None
232
  label_encoder = None
233
- session_files = {} # For storing temporary file paths
234
 
235
- # --- Startup Event ---
236
  @app.on_event("startup")
237
  async def startup_event():
238
  global model, label_encoder
@@ -246,19 +233,15 @@ async def startup_event():
246
  model = load_model()
247
  label_encoder = load_label_encoder()
248
 
249
- # Create an admin user if one doesn't exist
250
  db = SessionLocal()
251
  if not get_user_by_username(db, "admin"):
252
- admin_user = UserCreate(username="admin", email="[email protected]", password="adminpassword") #Change the password here
253
  create_user(db, admin_user)
254
  admin = get_user_by_username(db, "admin")
255
- admin.is_admin = True # Make this user an admin
256
  db.commit()
257
  db.close()
258
 
259
-
260
-
261
- # --- Word Detection ---
262
  def detect_words(image):
263
  _, binary = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
264
  kernel = np.ones((3,3), np.uint8)
@@ -273,7 +256,6 @@ def detect_words(image):
273
  word_count += 1
274
  return word_img, word_count
275
 
276
- # --- Sakshi OCR ---
277
  def run_py_text_scan(image_path):
278
  buffer = io.StringIO()
279
  old_stdout = sys.stdout
@@ -284,7 +266,6 @@ def run_py_text_scan(image_path):
284
  sys.stdout = old_stdout
285
  return buffer.getvalue()
286
 
287
- # --- Image Processing ---
288
  def process_image(image_array):
289
  img = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
290
  word_detected_img, word_count = detect_words(img)
@@ -326,9 +307,6 @@ def process_image(image_array):
326
  "prediction_label": pred_label
327
  }
328
 
329
- # --- API Endpoints ---
330
-
331
- # Authentication Endpoints
332
  @app.post("/token", response_model=Token)
333
  async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
334
  user = get_user_by_username(db, form_data.username)
@@ -338,7 +316,6 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(
338
  detail="Incorrect username or password",
339
  headers={"WWW-Authenticate": "Bearer"},
340
  )
341
- # Use username as the access token (for simplicity in this example)
342
  access_token = user.username
343
  return {"access_token": access_token, "token_type": "bearer"}
344
 
@@ -350,13 +327,11 @@ async def signup(user: UserCreate, db: Session = Depends(get_db)):
350
  db_user_email = get_user_by_email(db, email=user.email)
351
  if db_user_email:
352
  raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered")
353
- created = create_user(db=db, user=user)
354
  return created
355
 
356
-
357
- # OCR Endpoint
358
  @app.post("/process/", response_model=OCRResponse)
359
- async def process(file: UploadFile = File(...), current_user: UserModel = Depends(get_current_active_user)):
360
  if not file.content_type.startswith("image/"):
361
  raise HTTPException(status_code=400, detail="File must be an image")
362
 
@@ -397,50 +372,26 @@ async def get_prediction(current_user: UserModel = Depends(get_current_active_us
397
  raise HTTPException(status_code=404, detail="Prediction image not found")
398
  return FileResponse(session_files['prediction'])
399
 
400
- # Feedback Endpoint
 
401
  @app.post("/feedback/", response_model=FeedbackResponse)
402
- async def create_feedback_route(feedback: FeedbackCreate, current_user: UserModel = Depends(get_current_active_user),db: Session = Depends(get_db)):
403
  return create_feedback(db=db, feedback=feedback)
404
 
405
- # Admin Endpoints
406
- @app.get("/admin/users/", response_model=List[UserResponse])
407
- async def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_admin_user)):
408
- users = get_users(db, skip=skip, limit=limit)
409
- return users
410
-
411
- @app.get("/admin/users/{user_id}", response_model=UserResponse)
412
- async def read_user(user_id: int, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_admin_user)):
413
- db_user = get_user(db, user_id=user_id)
414
- if db_user is None:
415
- raise HTTPException(status_code=404, detail="User not found")
416
- return db_user
417
-
418
- @app.put("/admin/users/{user_id}", response_model=UserResponse)
419
- async def update_user_route(user_id: int, user: UserUpdate, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_admin_user)):
420
- updated_user = update_user(db=db, user_id=user_id, user=user)
421
- if updated_user is None:
422
- raise HTTPException(status_code=404, detail="User not found")
423
- return updated_user
424
-
425
- @app.delete("/admin/users/{user_id}", response_model=dict)
426
- async def delete_user_route(user_id: int, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_admin_user)):
427
- if delete_user(db=db, user_id=user_id):
428
- return {"message": "User deleted successfully"}
429
- else:
430
- raise HTTPException(status_code=404, detail="User not found")
431
-
432
-
433
- @app.get("/admin/feedback/", response_model=List[FeedbackResponse])
434
- async def read_feedback(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_admin_user)):
435
- feedback = get_feedback(db, skip=skip, limit=limit)
436
- return feedback
437
 
438
- # Basic Root Endpoint
439
- @app.get("/")
440
- async def root():
441
- return {"message": "Hindi OCR API with Authentication and Admin. See /docs for API details."}
 
442
 
 
 
 
443
 
444
- # --- Run with uvicorn (for local testing) ---
445
  if __name__ == "__main__":
446
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, status, Request
2
  from fastapi.responses import FileResponse, JSONResponse, HTMLResponse
3
  from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
 
4
  from pydantic import BaseModel, EmailStr, Field
5
+ from typing import List, Optional
6
  import cv2
7
  import numpy as np
8
  import tensorflow as tf
 
48
  comment = Column(Text)
49
  created_at = Column(DateTime, default=datetime.datetime.utcnow)
50
 
51
+ Base.metadata.create_all(bind=engine)
52
 
53
  # --- Pydantic Schemas ---
54
  class UserBase(BaseModel):
 
89
  token_type: str
90
 
91
class TokenData(BaseModel):
    """Schema holding the username associated with a token, if any."""

    # Defaults to None when no username is supplied.
    username: Optional[str] = None
93
 
94
class OCRResponse(BaseModel):
    """Response schema declared by the ``POST /process/`` endpoint."""

    # NOTE(review): presumably the text captured from the "Sakshi" OCR scan
    # step - confirm against process_image.
    sakshi_output: str
    # Number of words counted in the submitted image.
    word_count: int
    # Label predicted for the image.
    prediction_label: str
98
 
 
99
  # --- Password Hashing ---
100
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
101
 
 
110
  db.close()
111
 
112
  async def get_current_user(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)):
 
113
  user = get_user_by_username(db, username=token)
114
  if not user:
115
  raise HTTPException(
 
119
  )
120
  return user
121
 
122
async def get_current_active_user(current_user: UserModel = Depends(get_current_user)):
    """Dependency returning the authenticated user only when the account is active.

    Raises:
        HTTPException: 400 ("Inactive user") when ``current_user.is_active``
            is falsy.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
126
 
127
async def get_current_admin_user(current_user: UserModel = Depends(get_current_active_user)):
    """Dependency returning the active user only when it has the admin flag set.

    Raises:
        HTTPException: 403 ("Not an administrator") when
            ``current_user.is_admin`` is falsy.
    """
    if current_user.is_admin:
        return current_user
    raise HTTPException(status_code=403, detail="Not an administrator")
131
 
 
 
132
  # --- CRUD Operations ---
133
def get_user(db: Session, user_id: int):
    """Return the first ``UserModel`` row whose ``id`` equals *user_id*.

    The result of ``Query.first()`` is returned unchanged, so a missing user
    yields ``None`` rather than raising.
    """
    matching_users = db.query(UserModel).filter(UserModel.id == user_id)
    return matching_users.first()
 
180
def get_feedback(db: Session, skip: int = 0, limit: int = 100):
    """Return one page of ``FeedbackModel`` rows, newest ``created_at`` first.

    Args:
        db: Session used to run the query.
        skip: Number of rows to skip from the start of the ordered result.
        limit: Maximum number of rows to return.
    """
    newest_first = db.query(FeedbackModel).order_by(FeedbackModel.created_at.desc())
    page = newest_first.offset(skip).limit(limit)
    return page.all()
182
 
 
 
183
  # --- FastAPI App Setup ---
184
  app = FastAPI(
185
  title="Hindi OCR API",
 
191
  MODEL_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/hindi_ocr_model.keras"
192
  ENCODER_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/label_encoder.pkl"
193
  FONT_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/NotoSansDevanagari-Regular.ttf"
194
+ MODEL_PATH = "hindi_ocr_model.keras"
195
  ENCODER_PATH = "label_encoder.pkl"
196
  FONT_PATH = "NotoSansDevanagari-Regular.ttf"
197
 
 
198
  def download_file(url, dest):
199
  if not os.path.exists(dest):
200
  print(f"Downloading {dest}...")
 
205
  f.write(chunk)
206
  print(f"Downloaded {dest}")
207
 
 
208
  def load_model():
209
  if not os.path.exists(MODEL_PATH):
210
  return None
 
215
  return None
216
  with open(ENCODER_PATH, 'rb') as f:
217
  return pickle.load(f)
218
+
219
  model = None
220
  label_encoder = None
221
+ session_files = {}
222
 
 
223
  @app.on_event("startup")
224
  async def startup_event():
225
  global model, label_encoder
 
233
  model = load_model()
234
  label_encoder = load_label_encoder()
235
 
 
236
  db = SessionLocal()
237
  if not get_user_by_username(db, "admin"):
238
+ admin_user = UserCreate(username="admin", email="[email protected]", password="adminpassword")
239
  create_user(db, admin_user)
240
  admin = get_user_by_username(db, "admin")
241
+ admin.is_admin = True
242
  db.commit()
243
  db.close()
244
 
 
 
 
245
  def detect_words(image):
246
  _, binary = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
247
  kernel = np.ones((3,3), np.uint8)
 
256
  word_count += 1
257
  return word_img, word_count
258
 
 
259
  def run_py_text_scan(image_path):
260
  buffer = io.StringIO()
261
  old_stdout = sys.stdout
 
266
  sys.stdout = old_stdout
267
  return buffer.getvalue()
268
 
 
269
  def process_image(image_array):
270
  img = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
271
  word_detected_img, word_count = detect_words(img)
 
307
  "prediction_label": pred_label
308
  }
309
 
 
 
 
310
  @app.post("/token", response_model=Token)
311
  async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
312
  user = get_user_by_username(db, form_data.username)
 
316
  detail="Incorrect username or password",
317
  headers={"WWW-Authenticate": "Bearer"},
318
  )
 
319
  access_token = user.username
320
  return {"access_token": access_token, "token_type": "bearer"}
321
 
 
327
  db_user_email = get_user_by_email(db, email=user.email)
328
  if db_user_email:
329
  raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered")
330
+ created = create_user(db=db, user=user)
331
  return created
332
 
 
 
333
  @app.post("/process/", response_model=OCRResponse)
334
+ async def process(file: UploadFile = File(...), current_user: UserModel = Depends(get_current_active_user)):
335
  if not file.content_type.startswith("image/"):
336
  raise HTTPException(status_code=400, detail="File must be an image")
337
 
 
372
  raise HTTPException(status_code=404, detail="Prediction image not found")
373
  return FileResponse(session_files['prediction'])
374
 
375
+ # --- Modified Feedback Endpoint ---
376
+ # No authentication dependency is used here so that anyone can submit feedback.
377
@app.post("/feedback/", response_model=FeedbackResponse)
async def create_feedback_route(feedback: FeedbackCreate, db: Session = Depends(get_db)):
    """Persist a feedback submission and return the stored record.

    The route declares no authentication dependency, so unauthenticated
    callers can submit feedback.
    """
    stored_feedback = create_feedback(db=db, feedback=feedback)
    return stored_feedback
380
 
381
+ # --- Admin Endpoints ---
382
@app.get("/admin/users/", response_model=List[UserResponse])
async def admin_get_users(skip: int = 0, limit: int = 100, current_user: UserModel = Depends(get_current_admin_user), db: Session = Depends(get_db)):
    """List users for administrators, one page at a time.

    ``response_model`` is declared so FastAPI filters each row down to the
    ``UserResponse`` schema instead of serialising the raw ORM objects.
    NOTE(review): without it every column of ``UserModel`` is emitted, which
    presumably includes the stored password hash - confirm against the model.

    Args:
        skip: Number of users to skip.
        limit: Maximum number of users to return.
        current_user: Resolved by ``get_current_admin_user``; only used to
            enforce admin access.
        db: Database session provided by ``get_db``.
    """
    return get_users(db, skip=skip, limit=limit)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
@app.delete("/admin/users/{user_id}")
async def admin_delete_user(user_id: int, current_user: UserModel = Depends(get_current_admin_user), db: Session = Depends(get_db)):
    """Delete the user identified by *user_id*; administrators only.

    Returns:
        ``{"detail": "User deleted successfully"}`` when ``delete_user``
        reports success.

    Raises:
        HTTPException: 404 ("User not found") when ``delete_user`` returns a
            falsy value.
    """
    was_deleted = delete_user(db, user_id)
    if not was_deleted:
        raise HTTPException(status_code=404, detail="User not found")
    return {"detail": "User deleted successfully"}
391
 
392
@app.get("/admin/feedback/", response_model=List[FeedbackResponse])
async def admin_get_feedback(skip: int = 0, limit: int = 100, current_user: UserModel = Depends(get_current_admin_user), db: Session = Depends(get_db)):
    """List submitted feedback for administrators, newest first.

    ``response_model`` is declared so the rows are validated and serialised
    through ``FeedbackResponse`` - the same schema ``POST /feedback/``
    returns - rather than as raw ORM objects.

    Args:
        skip: Number of feedback rows to skip.
        limit: Maximum number of feedback rows to return.
        current_user: Resolved by ``get_current_admin_user``; only used to
            enforce admin access.
        db: Database session provided by ``get_db``.
    """
    return get_feedback(db, skip=skip, limit=limit)
395
 
 
396
if __name__ == "__main__":
    # Direct-execution entry point: serve the app on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")