Shak33l-UiRev committed on
Commit
5a29686
·
verified ·
1 Parent(s): 956f2af
Files changed (1) hide show
  1. app.py +114 -28
app.py CHANGED
@@ -15,6 +15,15 @@ import io
15
  import base64
16
  import json
17
  from datetime import datetime
 
 
 
 
 
 
 
 
 
18
 
19
  @st.cache_resource
20
  def load_model(model_name):
@@ -27,9 +36,72 @@ def load_model(model_name):
27
  dict: Dictionary containing model components
28
  """
29
  try:
30
- if model_name == "Donut":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
32
  model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
 
33
  # Configure Donut specific parameters
34
  model.config.decoder_start_token_id = processor.tokenizer.bos_token_id
35
  model.config.pad_token_id = processor.tokenizer.pad_token_id
@@ -42,34 +114,13 @@ def load_model(model_name):
42
  model = LayoutLMv3ForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
43
 
44
  return {'model': model, 'processor': processor}
45
-
46
- elif model_name == "OmniParser":
47
- # Load YOLO model for icon detection
48
- yolo_model = YOLO("microsoft/OmniParser-icon-detection")
49
-
50
- # Load Florence-2 processor and model for captioning
51
- processor = AutoProcessor.from_pretrained(
52
- "microsoft/OmniParser-caption",
53
- trust_remote_code=True
54
- )
55
-
56
- # Load the captioning model
57
- caption_model = AutoModelForCausalLM.from_pretrained(
58
- "microsoft/OmniParser-caption",
59
- trust_remote_code=True
60
- )
61
-
62
- return {
63
- 'yolo': yolo_model,
64
- 'processor': processor,
65
- 'model': caption_model
66
- }
67
 
68
  else:
69
  raise ValueError(f"Unknown model name: {model_name}")
70
 
71
  except Exception as e:
72
  st.error(f"Error loading model {model_name}: {str(e)}")
 
73
  return None
74
 
75
  @spaces.GPU
@@ -357,16 +408,20 @@ if uploaded_file is not None and selected_model:
357
  st.info("Loading model...")
358
 
359
  add_debug(f"Loading {selected_model} model and processor...")
360
- model, processor = load_model(selected_model)
361
 
362
- if model is None or processor is None:
363
  with result_col:
364
  st.error("Failed to load model. Please try again.")
365
  add_debug("Model loading failed!", "error")
366
  else:
367
  add_debug("Model loaded successfully", "success")
368
- add_debug(f"Model device: {next(model.parameters()).device}")
369
- add_debug(f"Model memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f}MB") if torch.cuda.is_available() else None
 
 
 
 
370
 
371
  # Update progress
372
  with result_col:
@@ -379,7 +434,7 @@ if uploaded_file is not None and selected_model:
379
 
380
  # Analyze document
381
  add_debug("Starting document analysis...")
382
- results = analyze_document(image, selected_model, model, processor)
383
  add_debug("Analysis completed", "success")
384
 
385
  # Update progress
@@ -425,6 +480,37 @@ if uploaded_file is not None and selected_model:
425
  add_debug("Traceback available in logs", "warning")
426
 
427
  # Add improved information about usage and limitations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  st.markdown("""
429
  ---
430
  ### Usage Notes:
 
15
  import base64
16
  import json
17
  from datetime import datetime
18
import os
import logging

# One-time logging setup so model-loading failures are captured with
# timestamps in the Space logs as well as surfaced in the Streamlit UI.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
27
 
28
  @st.cache_resource
29
  def load_model(model_name):
 
36
  dict: Dictionary containing model components
37
  """
38
  try:
39
+ if model_name == "OmniParser":
40
+ try:
41
+ # First try loading from HuggingFace Hub with correct repository structure
42
+ yolo_model = YOLO("microsoft/OmniParser/icon_detect") # Updated path
43
+
44
+ processor = AutoProcessor.from_pretrained(
45
+ "microsoft/OmniParser/icon_caption_florence", # Updated path
46
+ trust_remote_code=True
47
+ )
48
+
49
+ caption_model = AutoModelForCausalLM.from_pretrained(
50
+ "microsoft/OmniParser/icon_caption_florence", # Updated path
51
+ trust_remote_code=True,
52
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
53
+ )
54
+
55
+ if torch.cuda.is_available():
56
+ caption_model = caption_model.to("cuda")
57
+
58
+ st.success("Successfully loaded OmniParser models")
59
+ return {
60
+ 'yolo': yolo_model,
61
+ 'processor': processor,
62
+ 'model': caption_model
63
+ }
64
+
65
+ except Exception as e:
66
+ st.error(f"Failed to load OmniParser from HuggingFace Hub: {str(e)}")
67
+
68
+ # Try loading from local weights if available
69
+ weights_path = "weights"
70
+ if os.path.exists(os.path.join(weights_path, "icon_detect/model.safetensors")):
71
+ st.info("Attempting to load from local weights...")
72
+
73
+ yolo_model = YOLO(os.path.join(weights_path, "icon_detect/model.safetensors"))
74
+
75
+ processor = AutoProcessor.from_pretrained(
76
+ os.path.join(weights_path, "icon_caption_florence"),
77
+ trust_remote_code=True,
78
+ local_files_only=True
79
+ )
80
+
81
+ caption_model = AutoModelForCausalLM.from_pretrained(
82
+ os.path.join(weights_path, "icon_caption_florence"),
83
+ trust_remote_code=True,
84
+ local_files_only=True,
85
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
86
+ )
87
+
88
+ if torch.cuda.is_available():
89
+ caption_model = caption_model.to("cuda")
90
+
91
+ st.success("Successfully loaded OmniParser from local weights")
92
+ return {
93
+ 'yolo': yolo_model,
94
+ 'processor': processor,
95
+ 'model': caption_model
96
+ }
97
+ else:
98
+ st.error("Could not find local weights and HuggingFace Hub loading failed")
99
+ raise ValueError("No valid model weights found for OmniParser")
100
+
101
+ elif model_name == "Donut":
102
  processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
103
  model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
104
+
105
  # Configure Donut specific parameters
106
  model.config.decoder_start_token_id = processor.tokenizer.bos_token_id
107
  model.config.pad_token_id = processor.tokenizer.pad_token_id
 
114
  model = LayoutLMv3ForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
115
 
116
  return {'model': model, 'processor': processor}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  else:
119
  raise ValueError(f"Unknown model name: {model_name}")
120
 
121
  except Exception as e:
122
  st.error(f"Error loading model {model_name}: {str(e)}")
123
+ logger.error(f"Error details: {str(e)}", exc_info=True)
124
  return None
125
 
126
  @spaces.GPU
 
408
  st.info("Loading model...")
409
 
410
  add_debug(f"Loading {selected_model} model and processor...")
411
+ models_dict = load_model(selected_model)
412
 
413
+ if models_dict is None:
414
  with result_col:
415
  st.error("Failed to load model. Please try again.")
416
  add_debug("Model loading failed!", "error")
417
  else:
418
  add_debug("Model loaded successfully", "success")
419
# NOTE(review): the original branched on selected_model == "OmniParser",
# but both branches were byte-identical — every loader stores its network
# under the 'model' key, so a single lookup covers all cases.
model_device = next(models_dict['model'].parameters()).device
add_debug(f"Model device: {model_device}")
425
 
426
  # Update progress
427
  with result_col:
 
434
 
435
  # Analyze document
436
  add_debug("Starting document analysis...")
437
+ results = analyze_document(image, selected_model, models_dict)
438
  add_debug("Analysis completed", "success")
439
 
440
  # Update progress
 
480
  add_debug("Traceback available in logs", "warning")
481
 
482
  # Add improved information about usage and limitations
483
def verify_weights_directory(weights_path="weights"):
    """Verify that every file the OmniParser local-weights fallback needs exists.

    Args:
        weights_path (str): Root directory expected to contain the
            ``icon_detect`` and ``icon_caption_florence`` subdirectories.
            Defaults to ``"weights"``, the same path the loader falls back to.

    Returns:
        bool: True when all required files are present; otherwise False,
        after listing each missing file in the Streamlit UI.
    """
    # File path -> human-readable description; these are exactly the
    # artifacts the local-weights fallback expects to find on disk.
    required_files = {
        os.path.join(weights_path, "icon_detect", "model.safetensors"): "YOLO model weights",
        os.path.join(weights_path, "icon_detect", "model.yaml"): "YOLO model config",
        os.path.join(weights_path, "icon_caption_florence", "model.safetensors"): "Florence model weights",
        os.path.join(weights_path, "icon_caption_florence", "config.json"): "Florence model config",
        os.path.join(weights_path, "icon_caption_florence", "generation_config.json"): "Florence generation config",
    }

    missing_files = [
        f"{description} at {file_path}"
        for file_path, description in required_files.items()
        if not os.path.exists(file_path)
    ]

    if missing_files:
        # Surface the exact gaps so the user can fix the directory layout.
        st.warning("Missing required model files:")
        for missing in missing_files:
            st.write(f"- {missing}")
        return False

    return True
506
+
507
# Optional diagnostic: lets the user confirm the OmniParser weight files
# are present on disk before attempting to load the models.
if st.checkbox("Check Model Files"):
    if verify_weights_directory():
        st.success("All required model files are present")
    else:
        st.error("Some model files are missing. Please ensure all required files are in the weights directory")
513
+
514
  st.markdown("""
515
  ---
516
  ### Usage Notes: