program(1.0) [buildInfo = dict, tensor>({{"coremlc-component-MIL", "3405.2.1"}, {"coremlc-version", "3404.23.1"}, {"coremltools-component-torch", "2.5.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.3.0"}})] { func main(tensor logits) { tensor var_3 = const()[name = tensor("op_3"), val = tensor([-1])]; tensor flattened_cast_fp16 = reshape(shape = var_3, x = logits)[name = tensor("flattened_cast_fp16")]; tensor token_logits_begin_0 = const()[name = tensor("token_logits_begin_0"), val = tensor([0])]; tensor token_logits_end_0 = const()[name = tensor("token_logits_end_0"), val = tensor([1025])]; tensor token_logits_end_mask_0 = const()[name = tensor("token_logits_end_mask_0"), val = tensor([false])]; tensor token_logits_cast_fp16 = slice_by_index(begin = token_logits_begin_0, end = token_logits_end_0, end_mask = token_logits_end_mask_0, x = flattened_cast_fp16)[name = tensor("token_logits_cast_fp16")]; tensor duration_logits_begin_0 = const()[name = tensor("duration_logits_begin_0"), val = tensor([1025])]; tensor duration_logits_end_0 = const()[name = tensor("duration_logits_end_0"), val = tensor([1])]; tensor duration_logits_end_mask_0 = const()[name = tensor("duration_logits_end_mask_0"), val = tensor([true])]; tensor duration_logits_cast_fp16 = slice_by_index(begin = duration_logits_begin_0, end = duration_logits_end_0, end_mask = duration_logits_end_mask_0, x = flattened_cast_fp16)[name = tensor("duration_logits_cast_fp16")]; tensor var_17_axis_0 = const()[name = tensor("op_17_axis_0"), val = tensor(0)]; tensor var_17_keep_dims_0 = const()[name = tensor("op_17_keep_dims_0"), val = tensor(true)]; tensor var_17 = reduce_argmax(axis = var_17_axis_0, keep_dims = var_17_keep_dims_0, x = token_logits_cast_fp16)[name = tensor("op_17_cast_fp16")]; tensor reduce_max_0_axes_0 = const()[name = tensor("reduce_max_0_axes_0"), val = tensor([0])]; tensor reduce_max_0_keep_dims_0 = const()[name = tensor("reduce_max_0_keep_dims_0"), val = tensor(true)]; tensor reduce_max_0 = reduce_max(axes = reduce_max_0_axes_0, keep_dims = reduce_max_0_keep_dims_0, x = token_logits_cast_fp16)[name = tensor("reduce_max_0_cast_fp16")]; tensor var_24_axis_0 = const()[name = tensor("op_24_axis_0"), val = tensor(0)]; tensor var_24_keep_dims_0 = const()[name = tensor("op_24_keep_dims_0"), val = tensor(true)]; tensor var_24 = reduce_argmax(axis = var_24_axis_0, keep_dims = var_24_keep_dims_0, x = duration_logits_cast_fp16)[name = tensor("op_24_cast_fp16")]; } -> (var_17, reduce_max_0, var_24); }