"""
HuggingFace Inference Endpoints Handler for DSE-Qwen2-2B-MRL-V1 Encoder

Loads a Qwen2-VL model and processor from a local directory and turns a text
query (optionally paired with an image) into an L2-normalized embedding whose
leading ``embedding_dimension`` features are kept.
"""

import base64
import io
import json
import os
import traceback
from typing import Dict, List, Any, Optional, Tuple

import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, Qwen2VLProcessor

# Compatibility patch for torch.compiler.is_compiling.
# The original note dates this function to torch 2.2.0; on builds that lack it,
# install a stub that reports "not compiling" so callers can use it
# unconditionally. getattr() also tolerates torch builds that have no
# torch.compiler namespace at all.
_torch_compiler = getattr(torch, "compiler", None)
if _torch_compiler is not None and not hasattr(_torch_compiler, "is_compiling"):
    _torch_compiler.is_compiling = lambda: False


class EndpointHandler:
    """Custom Inference Endpoints handler wrapping the DSE-Qwen2-2B-MRL-V1 encoder."""

    # Default and accepted range for the requested embedding size.
    DEFAULT_EMBEDDING_DIMENSION = 1536
    MIN_EMBEDDING_DIMENSION = 128
    MAX_EMBEDDING_DIMENSION = 4096

    def __init__(self, path=""):
        """
        Initialize the DSE-Qwen2-2B-MRL-V1 encoder model and processor

        Args:
            path (str): Path to the model weights directory; an empty string
                means the current directory.

        Raises:
            Exception: whatever the processor/model loaders raise when every
                loading strategy fails.
        """
        print("Initializing DSE-Qwen2-2B-MRL-V1 Encoder Handler")
        print(f"Model path: {path}")

        # Model configuration
        self.model_path = path or "."
        # Pixel budget forwarded to the processor, in units of 28*28 pixels.
        self.min_pixels = 1 * 28 * 28
        self.max_pixels = 2560 * 28 * 28
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        self.model = None
        self.processor = None
        self._load_processor()
        self._load_model()

        # Left padding keeps the last real token at position -1 of every row,
        # which is where get_embedding reads the representation from.
        if self.processor is not None and hasattr(self.processor, 'tokenizer'):
            self.processor.tokenizer.padding_side = "left"
        if self.model is not None and hasattr(self.model, 'padding_side'):
            self.model.padding_side = "left"

        print("Handler initialization completed successfully")

    def _load_processor(self):
        """Load the processor, trying Qwen2VLProcessor first and AutoProcessor second.

        Raises:
            Exception: re-raised from AutoProcessor when both attempts fail.
        """
        print("Loading processor...")
        processor_kwargs = {
            'min_pixels': self.min_pixels,
            'max_pixels': self.max_pixels,
            'trust_remote_code': True,
            'local_files_only': True,
        }

        try:
            self.processor = Qwen2VLProcessor.from_pretrained(self.model_path, **processor_kwargs)
        except Exception as e:
            print(f"Could not load with Qwen2VLProcessor: {str(e)[:200]}")
        else:
            print("Processor loaded successfully with Qwen2VLProcessor")
            return

        try:
            self.processor = AutoProcessor.from_pretrained(self.model_path, **processor_kwargs)
        except Exception as e2:
            print(f"Failed to load processor: {e2}")
            raise
        print("Processor loaded successfully with AutoProcessor")

    def _load_model(self):
        """Load the model, falling back through attention implementations.

        Tries Flash Attention 2, then SDPA, then eager attention, and finally
        lets transformers pick its default.

        Raises:
            Exception: re-raised from the final default-attention attempt.
        """
        print("Loading model...")
        model_kwargs = {
            'torch_dtype': torch.bfloat16,
            'device_map': "auto",
            'low_cpu_mem_usage': True,
            'trust_remote_code': True,
            'local_files_only': True,
        }

        # Attention implementations in order of preference.
        attention_implementations = [
            ("flash_attention_2", "Flash Attention 2"),
            ("sdpa", "SDPA (Scaled Dot Product Attention)"),
            ("eager", "Eager (standard) attention"),
        ]
        for attn_impl, attn_name in attention_implementations:
            try:
                self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                    self.model_path, attn_implementation=attn_impl, **model_kwargs
                ).eval()
            except Exception as e:
                print(f"Could not load with {attn_name}: {str(e)[:200]}")
                continue
            print(f"Model loaded successfully with {attn_name}")
            return

        # Final fallback: let transformers choose the attention implementation.
        try:
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                self.model_path, **model_kwargs
            ).eval()
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise
        print("Model loaded successfully with default attention")

    def get_embedding(self, last_hidden_state: torch.Tensor, dimension: int) -> torch.Tensor:
        """Extract normalized embeddings from the last hidden state.

        Takes the hidden state at the final sequence position of each row,
        keeps its first ``dimension`` features (the truncation implied by
        "MRL" in the model name), and L2-normalizes the result.

        Args:
            last_hidden_state: presumably (batch, seq_len, hidden); it is
                indexed as ``[:, -1]`` below.
            dimension: number of leading features to keep; a value larger than
                the hidden size keeps the whole vector.

        Returns:
            For a 3-D input, a tensor of shape (batch, min(dimension, hidden)).
        """
        reps = last_hidden_state[:, -1]
        reps = torch.nn.functional.normalize(reps[:, :dimension], p=2, dim=-1)
        return reps

    def decode_base64_image(self, base64_string: str) -> Optional[Image.Image]:
        """Decode a base64 string (optionally a ``data:`` URL) into an RGB PIL Image.

        Returns:
            The decoded image, or ``None`` if the payload cannot be decoded or
            opened (the error is printed).
        """
        try:
            # Strip a data-URL header such as "data:image/png;base64," if present.
            if base64_string.startswith('data:'):
                base64_string = base64_string.split(',', 1)[1]
            image_bytes = base64.b64decode(base64_string)
            return Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except Exception as e:
            print(f"Error decoding base64 image: {e}")
            return None

    def _dummy_image_content(self) -> Dict[str, Any]:
        """Return a tiny white placeholder image entry (the model expects visual input)."""
        return {
            'type': 'image',
            'image': Image.new('RGB', (28, 28), color='white'),
            'resized_height': 1,
            'resized_width': 1,
        }

    def _build_image_content(self, image_input: Any) -> Tuple[List[Dict[str, Any]], Optional[str]]:
        """Build the image part of the message content.

        Returns:
            ``(content, warning)``. ``content`` holds exactly one image entry.
            ``warning`` is ``None`` unless a supplied image could not be used and
            a placeholder was substituted.
        """
        if image_input is None:
            # Use small dummy image as model expects visual input
            return [self._dummy_image_content()], None

        if isinstance(image_input, str):
            decoded = self.decode_base64_image(image_input)
            if decoded is not None:
                return [{'type': 'image', 'image': decoded}], None
            print("Failed to decode image, using dummy image")
            return (
                [self._dummy_image_content()],
                "Failed to decode image; the embedding was computed without it.",
            )

        if isinstance(image_input, Image.Image):
            return [{'type': 'image', 'image': image_input.convert('RGB')}], None

        print(f"Unsupported image input type: {type(image_input)}")
        return (
            [self._dummy_image_content()],
            f"Unsupported image input type '{type(image_input).__name__}'; "
            "the embedding was computed without it.",
        )

    def _resolve_embedding_dimension(self, requested: Any) -> int:
        """Return ``requested`` if it is an int within the allowed range, else the default."""
        if (
            isinstance(requested, int)
            and self.MIN_EMBEDDING_DIMENSION <= requested <= self.MAX_EMBEDDING_DIMENSION
        ):
            return requested
        return self.DEFAULT_EMBEDDING_DIMENSION

    def _encode(self, content: List[Dict[str, Any]], embedding_dimension: int) -> List[float]:
        """Run one forward pass over a single user message and return its embedding."""
        print("Formatting messages...")
        messages = [{'role': 'user', 'content': content}]

        print("Applying chat template...")
        # "<|endoftext|>" is appended so it is the final token; its hidden state
        # (position -1) is what get_embedding turns into the embedding.
        query_text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        ) + "<|endoftext|>"

        print("Processing vision inputs...")
        image_inputs, video_inputs = process_vision_info(messages)

        print("Tokenizing and preparing inputs...")
        inputs = self.processor(
            text=[query_text],
            images=image_inputs,
            videos=video_inputs,
            padding='longest',
            return_tensors='pt',
        )

        print("Moving inputs to device...")
        device = next(self.model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        print("Preparing for generation...")
        # cache_position spans the whole prompt: one fresh pass, no KV cache.
        cache_position = torch.arange(0, inputs['input_ids'].shape[1], device=device)
        inputs = self.model.prepare_inputs_for_generation(
            **inputs, cache_position=cache_position, use_cache=False
        )

        print("Generating embeddings...")
        with torch.no_grad():
            output = self.model(**inputs, return_dict=True, output_hidden_states=True)

        print("Extracting embeddings...")
        embedding = self.get_embedding(output.hidden_states[-1], embedding_dimension)
        # Single-item batch: take row 0; upcast (model runs in bfloat16) before
        # converting to plain Python floats.
        return embedding[0].cpu().float().tolist()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle inference requests

        Expected input format:
        {
            "inputs": {
                "query": "text to encode",
                "image": "base64_encoded_image_optional",
                "embedding_dimension": 1536,
                "return_format": "json"
            }
        }

        Or simplified format:
        {
            "query": "text to encode",
            "image": "base64_encoded_image_optional",
            "embedding_dimension": 1536,
            "return_format": "json"
        }

        ``embedding_dimension`` values that are not ints in [128, 4096] fall
        back to 1536. A request without an image gets a placeholder image.
        When a supplied image cannot be used, a placeholder is substituted and
        the result carries a ``"warning"`` entry.

        Returns:
            List[Dict[str, Any]]: JSON serializable response. It is a
            one-element list holding either the embedding payload or an
            ``"error"`` message. With ``return_format == "json"``, ``"dimension"``
            is the length of the returned embedding, which is capped by the
            model's hidden size.
        """
        try:
            print("Processing inference request...")

            # Handle nested inputs format or direct format
            if isinstance(data, dict) and isinstance(data.get("inputs"), dict):
                request_data = data["inputs"]
            else:
                request_data = data
            if not isinstance(request_data, dict):
                return [{"error": "Request body must be a JSON object."}]

            query = request_data.get("query", "")
            if not query:
                return [{
                    "error": "Input is missing the 'query' key. Please include a query."
                }]
            if not isinstance(query, str):
                return [{"error": f"'query' must be a string, got {type(query).__name__}."}]

            image_input = request_data.get("image", None)
            embedding_dimension = self._resolve_embedding_dimension(
                request_data.get("embedding_dimension", self.DEFAULT_EMBEDDING_DIMENSION)
            )
            return_format = request_data.get("return_format", "json")

            print(f"Processing query: {query[:50]}...")
            print(f"Embedding dimension: {embedding_dimension}")

            print("Preparing image input...")
            content, image_warning = self._build_image_content(image_input)
            print("Image input prepared")

            # Add text query after the image entry.
            content.append({'type': 'text', 'text': f'Query: {query}'})

            embedding_list = self._encode(content, embedding_dimension)

            print("Preparing output...")
            if return_format == "json":
                result = {
                    "embedding": embedding_list,
                    # Report the actual length: a request above the model's
                    # hidden size yields the full (shorter) vector.
                    "dimension": len(embedding_list),
                    "query": query,
                }
            else:
                result = {"embedding": embedding_list}
            if image_warning is not None:
                result["warning"] = image_warning

            print("Inference completed successfully")
            return [result]

        except Exception as e:
            print(f"Error in inference: {str(e)}")
            traceback.print_exc()
            return [{"error": str(e)}]