import numpy as np
import triton_python_backend_utils as pb_utils
from omnicloudmask import predict_from_array
import rasterio
from rasterio.io import MemoryFile
from rasterio.enums import Resampling

# Name of the request tensor carrying the encoded band images.
_INPUT_NAME = "input_jp2_bytes"
# Name of the response tensor carrying the predicted mask.
_OUTPUT_NAME = "output_mask"
# Number of encoded images each request must carry, in order: Red, Green, NIR.
_NUM_BANDS = 3


class TritonPythonModel:
    """Triton Python-backend model running OmniCloudMask on JP2 band bytes.

    Each request carries three encoded single-band images (Red, Green, NIR,
    in that order) in the ``input_jp2_bytes`` tensor. They are decoded in
    memory with rasterio. Green and NIR are resampled bilinearly onto the
    Red band's pixel dimensions when their sizes differ. The stacked
    ``(3, H, W)`` float32 array is passed to ``predict_from_array``, and the
    result is returned as the uint8 ``output_mask`` tensor.
    """

    def initialize(self, args):
        """Called once by Triton when the model is loaded.

        Nothing is preloaded here. rasterio (with JP2 driver support) must be
        installed in the Python backend environment.
        """
        print('Initialized Cloud Detection model with JP2 input')

    @staticmethod
    def _error_response(message):
        """Build an InferenceResponse that reports ``message`` as an error."""
        return pb_utils.InferenceResponse(
            output_tensors=[], error=pb_utils.TritonError(message)
        )

    @staticmethod
    def _read_band(encoded_bytes, out_shape=None):
        """Decode band 1 of an in-memory raster file as a 2-D float32 array.

        Args:
            encoded_bytes: Raw file contents (JP2 expected) to open via
                rasterio's MemoryFile.
            out_shape: Optional ``(height, width)``. When given and it differs
                from the raster's native size, the band is resampled
                bilinearly to that size. Otherwise it is read at native size.

        Returns:
            A 2-D ``np.float32`` array.
        """
        with MemoryFile(encoded_bytes) as memfile, memfile.open() as src:
            native_shape = (src.height, src.width)
            if out_shape is None or tuple(out_shape) == native_shape:
                return src.read(1).astype(np.float32)
            # Only pixel dimensions are matched here; the bands are assumed to
            # cover the same geographic extent (e.g. Sentinel-2 B04/B03/B8A of
            # one tile) -- NOTE(review): no CRS/transform check is performed.
            return src.read(
                1,
                out_shape=(1, *out_shape),
                resampling=Resampling.bilinear,
            ).astype(np.float32)

    def _decode_bands(self, encoded_bands):
        """Decode Red, Green, NIR bytes into a ``(3, H, W)`` float32 stack.

        The Red band defines the target grid; Green and NIR are resampled to
        it if needed. The channel order (Red, Green, NIR) is the order passed
        on to ``predict_from_array``.
        """
        red_bytes, green_bytes, nir_bytes = encoded_bands
        red_data = self._read_band(red_bytes)
        target_shape = red_data.shape  # (height, width) of the Red band
        green_data = self._read_band(green_bytes, target_shape)
        nir_data = self._read_band(nir_bytes, target_shape)
        return np.stack([red_data, green_data, nir_data], axis=0)

    def _process_request(self, request):
        """Turn one Triton request into exactly one InferenceResponse.

        Failures are reported through the response's ``error`` field rather
        than raised. That way one bad request does not abort the whole batch.
        """
        input_tensor = pb_utils.get_input_tensor_by_name(request, _INPUT_NAME)
        if input_tensor is None:
            return self._error_response(
                f"Missing required input tensor '{_INPUT_NAME}'"
            )

        # as_numpy() on a TYPE_STRING tensor yields an ndarray of bytes
        # objects. Flatten it so that both a [3] and a batched [1, 3] layout
        # are accepted.
        encoded_bands = input_tensor.as_numpy().reshape(-1)
        if len(encoded_bands) != _NUM_BANDS:
            return self._error_response(
                f"Expected 3 JP2 byte strings, received {len(encoded_bands)}"
            )

        try:
            input_array = self._decode_bands(encoded_bands)
            pred_mask = predict_from_array(input_array)
            # Assumes the predicted mask holds small non-negative class
            # indices that fit in uint8 -- TODO confirm against omnicloudmask.
            output_tensor = pb_utils.Tensor(
                _OUTPUT_NAME, np.asarray(pred_mask).astype(np.uint8)
            )
        except Exception as e:
            # Covers e.g. undecodable JP2 data or a model failure.
            return self._error_response(f"Error processing JP2 data: {e}")

        return pb_utils.InferenceResponse(output_tensors=[output_tensor])

    def execute(self, requests):
        """Process a batch of inference requests.

        Returns:
            A list with one InferenceResponse per request, in request order,
            as required by the Triton Python backend.
        """
        return [self._process_request(request) for request in requests]

    def finalize(self):
        """Called once by Triton when the model is unloaded."""
        print('Finalizing Cloud Detection model')