Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -30,13 +30,12 @@ ENCODER_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/label_e
|
|
| 30 |
FONT_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/NotoSansDevanagari-Regular.ttf"
|
| 31 |
|
| 32 |
# Paths for local storage
|
| 33 |
-
MODEL_PATH = "hindi_ocr_model.keras"
|
| 34 |
-
ENCODER_PATH = "label_encoder.pkl"
|
| 35 |
-
FONT_PATH = "NotoSansDevanagari-Regular.ttf"
|
| 36 |
-
OUTPUT_DIR = "output"
|
| 37 |
|
| 38 |
-
#
|
| 39 |
-
|
| 40 |
|
| 41 |
# Download model and encoder
|
| 42 |
def download_file(url, dest):
|
|
@@ -56,6 +55,10 @@ def load_label_encoder():
|
|
| 56 |
with open(ENCODER_PATH, 'rb') as f:
|
| 57 |
return pickle.load(f)
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
# Download required files on startup
|
| 60 |
@app.on_event("startup")
|
| 61 |
async def startup_event():
|
|
@@ -106,6 +109,9 @@ def run_sakshi_ocr(image_path):
|
|
| 106 |
sys.stdout = old_stdout
|
| 107 |
return buffer.getvalue()
|
| 108 |
|
|
|
|
|
|
|
|
|
|
| 109 |
# Main OCR processing function
|
| 110 |
def process_image(image_array):
|
| 111 |
# Convert image array to grayscale
|
|
@@ -113,10 +119,14 @@ def process_image(image_array):
|
|
| 113 |
|
| 114 |
# Word detection
|
| 115 |
word_detected_img, word_count = detect_words(img)
|
| 116 |
-
word_detection_path =
|
| 117 |
cv2.imwrite(word_detection_path, word_detected_img)
|
| 118 |
|
|
|
|
|
|
|
|
|
|
| 119 |
# First OCR model prediction
|
|
|
|
| 120 |
try:
|
| 121 |
img_resized = cv2.resize(img, (128, 32))
|
| 122 |
img_norm = img_resized / 255.0
|
|
@@ -132,27 +142,28 @@ def process_image(image_array):
|
|
| 132 |
ax.imshow(img, cmap='gray')
|
| 133 |
ax.set_title(f"Predicted: {pred_label}", fontsize=12)
|
| 134 |
ax.axis('off')
|
| 135 |
-
pred_path =
|
| 136 |
plt.savefig(pred_path)
|
| 137 |
plt.close()
|
|
|
|
|
|
|
|
|
|
| 138 |
else:
|
| 139 |
-
pred_path = None
|
| 140 |
pred_label = "Model or encoder not loaded"
|
| 141 |
except Exception as e:
|
| 142 |
-
pred_path = None
|
| 143 |
pred_label = f"Error: {str(e)}"
|
| 144 |
|
| 145 |
# Sakshi OCR processing
|
| 146 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
|
| 147 |
cv2.imwrite(tmp_file.name, img)
|
| 148 |
sakshi_output = run_sakshi_ocr(tmp_file.name)
|
| 149 |
-
os.
|
| 150 |
|
| 151 |
return {
|
| 152 |
"sakshi_output": sakshi_output,
|
| 153 |
-
"word_detection_path": word_detection_path,
|
| 154 |
"word_count": word_count,
|
| 155 |
-
"prediction_path": pred_path,
|
| 156 |
"prediction_label": pred_label
|
| 157 |
}
|
| 158 |
|
|
@@ -167,6 +178,15 @@ async def process(file: UploadFile = File(...)):
|
|
| 167 |
if not file.content_type.startswith("image/"):
|
| 168 |
raise HTTPException(status_code=400, detail="File must be an image")
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
# Create a temporary file to save the uploaded image
|
| 171 |
temp_file = tempfile.NamedTemporaryFile(delete=False)
|
| 172 |
try:
|
|
@@ -193,18 +213,16 @@ async def process(file: UploadFile = File(...)):
|
|
| 193 |
@app.get("/word-detection/")
|
| 194 |
async def get_word_detection():
|
| 195 |
"""Return the word detection image."""
|
| 196 |
-
|
| 197 |
-
if not word_detection_path.exists():
|
| 198 |
raise HTTPException(status_code=404, detail="Word detection image not found. Process an image first.")
|
| 199 |
-
return FileResponse(
|
| 200 |
|
| 201 |
@app.get("/prediction/")
|
| 202 |
async def get_prediction():
|
| 203 |
"""Return the prediction image."""
|
| 204 |
-
|
| 205 |
-
if not prediction_path.exists():
|
| 206 |
raise HTTPException(status_code=404, detail="Prediction image not found. Process an image first.")
|
| 207 |
-
return FileResponse(
|
| 208 |
|
| 209 |
@app.get("/")
|
| 210 |
async def root():
|
|
|
|
| 30 |
FONT_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/NotoSansDevanagari-Regular.ttf"
|
| 31 |
|
| 32 |
# Paths for local storage
|
| 33 |
+
MODEL_PATH = os.path.join(tempfile.gettempdir(), "hindi_ocr_model.keras")
|
| 34 |
+
ENCODER_PATH = os.path.join(tempfile.gettempdir(), "label_encoder.pkl")
|
| 35 |
+
FONT_PATH = os.path.join(tempfile.gettempdir(), "NotoSansDevanagari-Regular.ttf")
|
|
|
|
| 36 |
|
| 37 |
+
# Use a temporary directory for outputs
|
| 38 |
+
OUTPUT_DIR = tempfile.mkdtemp()
|
| 39 |
|
| 40 |
# Download model and encoder
|
| 41 |
def download_file(url, dest):
|
|
|
|
| 55 |
with open(ENCODER_PATH, 'rb') as f:
|
| 56 |
return pickle.load(f)
|
| 57 |
|
| 58 |
+
# Set up global variables
|
| 59 |
+
model = None
|
| 60 |
+
label_encoder = None
|
| 61 |
+
|
| 62 |
# Download required files on startup
|
| 63 |
@app.on_event("startup")
|
| 64 |
async def startup_event():
|
|
|
|
| 109 |
sys.stdout = old_stdout
|
| 110 |
return buffer.getvalue()
|
| 111 |
|
| 112 |
+
# File storage for session
|
| 113 |
+
session_files = {}
|
| 114 |
+
|
| 115 |
# Main OCR processing function
|
| 116 |
def process_image(image_array):
|
| 117 |
# Convert image array to grayscale
|
|
|
|
| 119 |
|
| 120 |
# Word detection
|
| 121 |
word_detected_img, word_count = detect_words(img)
|
| 122 |
+
word_detection_path = tempfile.NamedTemporaryFile(delete=False, suffix=".png").name
|
| 123 |
cv2.imwrite(word_detection_path, word_detected_img)
|
| 124 |
|
| 125 |
+
# Store the file path in our session dict
|
| 126 |
+
session_files['word_detection'] = word_detection_path
|
| 127 |
+
|
| 128 |
# First OCR model prediction
|
| 129 |
+
pred_path = None
|
| 130 |
try:
|
| 131 |
img_resized = cv2.resize(img, (128, 32))
|
| 132 |
img_norm = img_resized / 255.0
|
|
|
|
| 142 |
ax.imshow(img, cmap='gray')
|
| 143 |
ax.set_title(f"Predicted: {pred_label}", fontsize=12)
|
| 144 |
ax.axis('off')
|
| 145 |
+
pred_path = tempfile.NamedTemporaryFile(delete=False, suffix=".png").name
|
| 146 |
plt.savefig(pred_path)
|
| 147 |
plt.close()
|
| 148 |
+
|
| 149 |
+
# Store the file path in our session dict
|
| 150 |
+
session_files['prediction'] = pred_path
|
| 151 |
else:
|
|
|
|
| 152 |
pred_label = "Model or encoder not loaded"
|
| 153 |
except Exception as e:
|
|
|
|
| 154 |
pred_label = f"Error: {str(e)}"
|
| 155 |
|
| 156 |
# Sakshi OCR processing
|
| 157 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
|
| 158 |
cv2.imwrite(tmp_file.name, img)
|
| 159 |
sakshi_output = run_sakshi_ocr(tmp_file.name)
|
| 160 |
+
os.unlink(tmp_file.name)
|
| 161 |
|
| 162 |
return {
|
| 163 |
"sakshi_output": sakshi_output,
|
| 164 |
+
"word_detection_path": word_detection_path if 'word_detection' in session_files else None,
|
| 165 |
"word_count": word_count,
|
| 166 |
+
"prediction_path": pred_path if 'prediction' in session_files else None,
|
| 167 |
"prediction_label": pred_label
|
| 168 |
}
|
| 169 |
|
|
|
|
| 178 |
if not file.content_type.startswith("image/"):
|
| 179 |
raise HTTPException(status_code=400, detail="File must be an image")
|
| 180 |
|
| 181 |
+
# Clean up previous session files
|
| 182 |
+
for key, filepath in session_files.items():
|
| 183 |
+
if os.path.exists(filepath):
|
| 184 |
+
try:
|
| 185 |
+
os.unlink(filepath)
|
| 186 |
+
except:
|
| 187 |
+
pass
|
| 188 |
+
session_files.clear()
|
| 189 |
+
|
| 190 |
# Create a temporary file to save the uploaded image
|
| 191 |
temp_file = tempfile.NamedTemporaryFile(delete=False)
|
| 192 |
try:
|
|
|
|
| 213 |
@app.get("/word-detection/")
|
| 214 |
async def get_word_detection():
|
| 215 |
"""Return the word detection image."""
|
| 216 |
+
if 'word_detection' not in session_files or not os.path.exists(session_files['word_detection']):
|
|
|
|
| 217 |
raise HTTPException(status_code=404, detail="Word detection image not found. Process an image first.")
|
| 218 |
+
return FileResponse(session_files['word_detection'])
|
| 219 |
|
| 220 |
@app.get("/prediction/")
|
| 221 |
async def get_prediction():
|
| 222 |
"""Return the prediction image."""
|
| 223 |
+
if 'prediction' not in session_files or not os.path.exists(session_files['prediction']):
|
|
|
|
| 224 |
raise HTTPException(status_code=404, detail="Prediction image not found. Process an image first.")
|
| 225 |
+
return FileResponse(session_files['prediction'])
|
| 226 |
|
| 227 |
@app.get("/")
|
| 228 |
async def root():
|