import numpy as np
import triton_python_backend_utils as pb_utils
from omnicloudmask import predict_from_array
import rasterio
from rasterio.io import MemoryFile
from rasterio.enums import Resampling

class TritonPythonModel:
    def initialize(self, args):
        """
        Initialize the model. This function is called once when the model is loaded.
        """
        # You can load models or initialize resources here if needed.
        # Ensure rasterio is installed in the Python backend environment.
        print('Initialized Cloud Detection model with JP2 input')
    def execute(self, requests):
        """
        Process inference requests.
        """
        responses = []
        # Every request must contain three JP2 byte strings (Red, Green, NIR).
        for request in requests:
            # Get the input tensor containing the byte arrays
            input_tensor = pb_utils.get_input_tensor_by_name(request, "input_jp2_bytes")
            # as_numpy() for TYPE_STRING gives an ndarray of Python bytes objects
            jp2_bytes_list = input_tensor.as_numpy()
            if len(jp2_bytes_list) != 3:
                # Send an error response if the input shape is incorrect
                error = pb_utils.TritonError(
                    f"Expected 3 JP2 byte strings, received {len(jp2_bytes_list)}"
                )
                response = pb_utils.InferenceResponse(output_tensors=[], error=error)
                responses.append(response)
                continue  # Skip to the next request

            # Assume order: Red, Green, NIR based on client logic
            red_bytes = jp2_bytes_list[0]
            green_bytes = jp2_bytes_list[1]
            nir_bytes = jp2_bytes_list[2]

            try:
                # Process JP2 bytes using rasterio in memory
                with MemoryFile(red_bytes) as memfile_red:
                    with memfile_red.open() as src_red:
                        red_data = src_red.read(1).astype(np.float32)
                        target_height = src_red.height
                        target_width = src_red.width

                with MemoryFile(green_bytes) as memfile_green:
                    with memfile_green.open() as src_green:
                        # Resample green only if it does not already match the red band
                        # dimensions (B03 usually matches B04).
                        if src_green.height != target_height or src_green.width != target_width:
                            green_data = src_green.read(
                                1,
                                out_shape=(1, target_height, target_width),
                                resampling=Resampling.bilinear
                            ).astype(np.float32)
                        else:
                            green_data = src_green.read(1).astype(np.float32)

                with MemoryFile(nir_bytes) as memfile_nir:
                    with memfile_nir.open() as src_nir:
                        # Resample NIR (B8A) to match Red/Green (B04/B03) resolution
                        nir_data = src_nir.read(
                            1,  # Read the first band
                            out_shape=(1, target_height, target_width),
                            resampling=Resampling.bilinear
                        ).astype(np.float32)

                # Stack bands in CHW format (Red, Green, NIR), the channel order
                # expected by predict_from_array
                input_array = np.stack([red_data, green_data, nir_data], axis=0)

                # Perform inference using the original function
                pred_mask = predict_from_array(input_array)

                # Create output tensor
                output_tensor = pb_utils.Tensor(
                    "output_mask",
                    pred_mask.astype(np.uint8)
                )
                response = pb_utils.InferenceResponse([output_tensor])
            except Exception as e:
                # Handle errors during processing (e.g., invalid JP2 data)
                error = pb_utils.TritonError(f"Error processing JP2 data: {str(e)}")
                response = pb_utils.InferenceResponse(output_tensors=[], error=error)

            responses.append(response)

        # Return a list of responses
        return responses

    def finalize(self):
        """
        Called when the model is unloaded. Perform any necessary cleanup.
        """
        print('Finalizing Cloud Detection model')
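
# ---------------------------------------------------------------------------
# Illustrative sketches only; not part of the backend itself.
#
# This model.py assumes a config.pbtxt roughly along these lines: a single
# TYPE_STRING input carrying the three JP2 byte strings and a TYPE_UINT8 mask
# output. The model name and the output dims below are assumptions, not taken
# from the source:
#
#   name: "cloud_detection"
#   backend: "python"
#   max_batch_size: 0
#   input [
#     {
#       name: "input_jp2_bytes"
#       data_type: TYPE_STRING
#       dims: [ 3 ]
#     }
#   ]
#   output [
#     {
#       name: "output_mask"
#       data_type: TYPE_UINT8
#       dims: [ -1, -1 ]
#     }
#   ]
#
# A minimal client-side request matching the Red/Green/NIR ordering expected
# above might look like this (the model name "cloud_detection" is hypothetical;
# red_bytes, green_bytes, nir_bytes are the raw JP2 file contents):
#
#   import numpy as np
#   import tritonclient.http as httpclient
#
#   client = httpclient.InferenceServerClient("localhost:8000")
#   bands = np.array([red_bytes, green_bytes, nir_bytes], dtype=np.object_)
#   infer_input = httpclient.InferInput("input_jp2_bytes", bands.shape, "BYTES")
#   infer_input.set_data_from_numpy(bands)
#   result = client.infer("cloud_detection", inputs=[infer_input])
#   mask = result.as_numpy("output_mask")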