from __future__ import print_function, unicode_literals, absolute_import, division |
import numpy as np |
import sys |
import warnings |
import math |
from tqdm import tqdm |
from collections import namedtuple |
from pathlib import Path |
import threading |
import functools |
import scipy.ndimage as ndi |
import numbers |
from csbdeep.models.base_model import BaseModel |
from csbdeep.utils.tf import export_SavedModel, keras_import, IS_TF_1, CARETensorBoard |
import tensorflow as tf |
K = keras_import('backend') |
Sequence = keras_import('utils', 'Sequence') |
Adam = keras_import('optimizers', 'Adam') |
ReduceLROnPlateau, TensorBoard = keras_import('callbacks', 'ReduceLROnPlateau', 'TensorBoard') |
from csbdeep.utils import _raise, backend_channels_last, axes_check_and_normalize, axes_dict, load_json, save_json |
from csbdeep.internals.predict import tile_iterator, total_n_tiles |
from csbdeep.internals.train import RollingSequence |
from csbdeep.data import Resizer |
from ..sample_patches import get_valid_inds |
from ..nms import _ind_prob_thresh |
from ..utils import _is_power_of_2, _is_floatarray, optimize_threshold |
def generic_masked_loss(mask, loss, weights=1, norm_by_mask=True, reg_weight=0, reg_penalty=K.abs):
    """Build a loss function that weights an element-wise loss by `mask`.

    Parameters
    ----------
    mask : tensor
        Multiplicative weight for the element-wise loss;
        ``1 - mask`` weights the optional regularization term.
    loss : callable
        ``loss(y_true, y_pred)`` returning the element-wise loss.
    weights : scalar or tensor
        Additional multiplicative weight for the element-wise loss.
    norm_by_mask : bool
        If true, the masked loss is divided by ``mean(mask) + epsilon``.
    reg_weight : float
        If greater than 0, ``reg_weight * mean((1 - mask) * reg_penalty(y_pred))``
        is added to the result.
    reg_penalty : callable
        Penalty applied to ``y_pred`` in the regularization term.

    Returns
    -------
    callable
        ``_loss(y_true, y_pred)``; all means are taken over the last axis.
    """
    def _loss(y_true, y_pred):
        masked = K.mean(mask * weights * loss(y_true, y_pred), axis=-1)
        if norm_by_mask:
            denominator = K.mean(mask) + K.epsilon()
        else:
            denominator = 1
        total = masked / denominator
        if reg_weight > 0:
            # penalize predictions outside of the mask
            outside = K.mean((1 - mask) * reg_penalty(y_pred), axis=-1)
            total = total + reg_weight * outside
        return total
    return _loss
def masked_loss(mask, penalty, reg_weight, norm_by_mask):
    """Masked loss whose element-wise term is ``penalty(y_true - y_pred)``.

    All arguments are forwarded to :func:`generic_masked_loss`.
    """
    def residual_penalty(y_true, y_pred):
        return penalty(y_true - y_pred)

    return generic_masked_loss(mask, residual_penalty,
                               reg_weight=reg_weight, norm_by_mask=norm_by_mask)
def masked_loss_mae(mask, reg_weight=0, norm_by_mask=True):
    """Masked loss using the absolute difference (``K.abs``) as penalty."""
    options = dict(reg_weight=reg_weight, norm_by_mask=norm_by_mask)
    return masked_loss(mask, K.abs, **options)
def masked_loss_mse(mask, reg_weight=0, norm_by_mask=True):
    """Masked loss using the squared difference (``K.square``) as penalty."""
    options = dict(reg_weight=reg_weight, norm_by_mask=norm_by_mask)
    return masked_loss(mask, K.square, **options)
def masked_metric_mae(mask):
    """Return a metric function computing the mask-normalized absolute error.

    The returned function is named ``relevant_mae``; no regularization term is used.
    """
    def relevant_mae(y_true, y_pred):
        metric = masked_loss(mask, K.abs, reg_weight=0, norm_by_mask=True)
        return metric(y_true, y_pred)
    return relevant_mae
def masked_metric_mse(mask):
    """Return a metric function computing the mask-normalized squared error.

    The returned function is named ``relevant_mse``; no regularization term is used.
    """
    def relevant_mse(y_true, y_pred):
        metric = masked_loss(mask, K.square, reg_weight=0, norm_by_mask=True)
        return metric(y_true, y_pred)
    return relevant_mse
def kld(y_true, y_pred):
    """Difference between the binary cross-entropy of (y_true, y_pred) and of (y_true, y_true).

    Both inputs are clipped to ``[epsilon, 1]`` first; the result is averaged
    over the last axis.
    """
    eps = K.epsilon()
    target = K.clip(y_true, eps, 1)
    estimate = K.clip(y_pred, eps, 1)
    cross_entropy = K.binary_crossentropy(target, estimate)
    self_entropy = K.binary_crossentropy(target, target)
    return K.mean(cross_entropy - self_entropy, axis=-1)
def masked_loss_iou(mask, reg_weight=0, norm_by_mask=True):
    """Masked loss based on an intersection-over-union style ratio.

    The element-wise loss is ``1 - inter / (union + epsilon)``, where ``inter``
    and ``union`` are channel-wise means of the squared element-wise minimum
    resp. maximum of ``y_true`` and ``y_pred``; ``inter`` is additionally
    multiplied by ``sign(y_pred)``.
    """
    def iou_loss(y_true, y_pred):
        if backend_channels_last():
            axis = -1
        else:
            axis = 1
        lower = K.minimum(y_true, y_pred)
        upper = K.maximum(y_true, y_pred)
        inter = K.mean(K.sign(y_pred) * K.square(lower), axis=axis)
        union = K.mean(K.square(upper), axis=axis)
        ratio = inter / (union + K.epsilon())
        # restore the reduced channel axis
        return 1. - K.expand_dims(ratio, axis)

    return generic_masked_loss(mask, iou_loss, reg_weight=reg_weight, norm_by_mask=norm_by_mask)
def masked_metric_iou(mask, reg_weight=0, norm_by_mask=True):
    """Masked metric based on an intersection-over-union style ratio.

    ``y_pred`` is clamped to be non-negative; the element-wise value is
    ``inter / (union + epsilon)`` with ``inter``/``union`` the channel-wise means
    of the squared element-wise minimum/maximum of ``y_true`` and ``y_pred``.
    """
    def iou_metric(y_true, y_pred):
        if backend_channels_last():
            axis = -1
        else:
            axis = 1
        y_pred = K.maximum(0., y_pred)
        lower = K.minimum(y_true, y_pred)
        upper = K.maximum(y_true, y_pred)
        inter = K.mean(K.square(lower), axis=axis)
        union = K.mean(K.square(upper), axis=axis)
        ratio = inter / (union + K.epsilon())
        # restore the reduced channel axis
        return K.expand_dims(ratio, axis)

    return generic_masked_loss(mask, iou_metric, reg_weight=reg_weight, norm_by_mask=norm_by_mask)
def weighted_categorical_crossentropy(weights, ndim):
    """Build a class-weighted categorical cross-entropy loss.

    Parameters
    ----------
    weights : sequence of float
        One weight per class (length = number of channels of the class output).
    ndim : int
        Number of spatial dimensions (2 or 3); tensors have ``ndim + 2`` axes.

    Returns
    -------
    callable
        ``weighted_cce(y_true, y_pred)``. Entries with ``y_true < 0`` are
        masked out; ``y_pred`` is normalized along the channel axis and clipped
        to ``[epsilon, 1 - epsilon]`` before the logarithm is taken.
    """
    axis = -1 if backend_channels_last() else 1
    shape = [1]*(ndim+2)
    shape[axis] = len(weights)
    # Place the class weights on the channel axis. `reshape` (instead of
    # `broadcast_to`) also works when the channel axis is not the last one:
    # broadcasting a 1-D array aligns it with the trailing axis and fails otherwise.
    weights = np.asarray(weights).reshape(shape)
    weights = K.constant(weights)

    def weighted_cce(y_true, y_pred):
        # ignore entries with negative labels
        mask = K.cast(y_true>=0, K.floatx())
        y_pred = y_pred / K.sum(y_pred+K.epsilon(), axis=axis, keepdims=True)
        y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
        loss = - K.sum(weights*mask*y_true*K.log(y_pred), axis = axis)
        return loss

    return weighted_cce
class StarDistDataBase(RollingSequence):
    """Common base class for StarDist training data sequences.

    Validates the given images and masks, stores the sampling options and
    provides (optionally cached) lookup of valid patch positions.
    """

    def __init__(self, X, Y, n_rays, grid, batch_size, patch_size, length,
                 n_classes=None, classes=None,
                 use_gpu=False, sample_ind_cache=True, maxfilter_patch_size=None, augmenter=None, foreground_prob=0):
        """
        Parameters
        ----------
        X, Y : sequences
            Images and corresponding masks (same length, not empty).
            If `X` is an ndarray/tuple/list, its items are converted to float32.
        n_rays : int
            Stored as ``self.n_rays``.
        grid : iterable of int
            Subsampling step per spatial axis (used to build ``self.ss_grid``).
        batch_size, length
            Forwarded to ``RollingSequence``.
        patch_size : tuple of int
            Patch size with 2 or 3 entries; images must be at least this large in every axis.
        n_classes, classes
            Optional class information; `classes` must have the same length as `X`
            (defaults to a tuple of ``None``).
        use_gpu : bool
            If true, use ``gputools.max_filter`` for maximum filtering, otherwise SciPy.
        sample_ind_cache : bool
            If true, cache the valid patch indices per image.
        maxfilter_patch_size : tuple or None
            Filter size for foreground detection (defaults to `patch_size`).
        augmenter : callable or None
            Augmentation function; ``None`` uses the identity.
        foreground_prob : float
            Probability in [0, 1] of restricting patch sampling to foreground positions.
        """
        super().__init__(data_size=len(X), batch_size=batch_size, length=length, shuffle=True)

        if isinstance(X, (np.ndarray, tuple, list)):
            X = [x.astype(np.float32, copy=False) for x in X]

        len(X)==len(Y) and len(X)>0 or _raise(ValueError("X and Y can't be empty and must have same length"))

        if classes is None:
            classes = (None,)*len(X)
        else:
            n_classes is not None or warnings.warn("Ignoring classes since n_classes is None")

        len(classes)==len(X) or _raise(ValueError("X and classes must have same length"))

        self.n_classes, self.classes = n_classes, classes

        nD = len(patch_size)
        assert nD in (2,3)
        x_ndim = X[0].ndim
        assert x_ndim in (nD,nD+1)

        if isinstance(X, (np.ndarray, tuple, list)) and \
           isinstance(Y, (np.ndarray, tuple, list)):
            all(y.ndim==nD and x.ndim==x_ndim and x.shape[:nD]==y.shape for x,y in zip(X,Y)) or _raise(ValueError("images and masks should have corresponding shapes/dimensions"))
            # Compare per axis: `tuple >= tuple` is lexicographic and would accept
            # e.g. an image of shape (64, 8) for patch_size (32, 32).
            all(all(s >= p for s, p in zip(x.shape[:nD], patch_size)) for x in X) or _raise(ValueError("Some images are too small for given patch_size {patch_size}".format(patch_size=patch_size)))

        if x_ndim == nD:
            self.n_channel = None
        else:
            self.n_channel = X[0].shape[-1]
            if isinstance(X, (np.ndarray, tuple, list)):
                assert all(x.shape[-1]==self.n_channel for x in X)

        assert 0 <= foreground_prob <= 1

        self.X, self.Y = X, Y
        self.n_rays = n_rays
        self.patch_size = patch_size
        # slicing tuple: keep first axis, subsample the remaining axes by `grid`
        self.ss_grid = (slice(None),) + tuple(slice(0, None, g) for g in grid)
        self.grid = tuple(grid)
        self.use_gpu = bool(use_gpu)
        if augmenter is None:
            augmenter = lambda *args: args
        callable(augmenter) or _raise(ValueError("augmenter must be None or callable"))
        self.augmenter = augmenter
        self.foreground_prob = foreground_prob

        if self.use_gpu:
            from gputools import max_filter
            self.max_filter = lambda y, patch_size: max_filter(y.astype(np.float32), patch_size)
        else:
            # import from `scipy.ndimage` directly; the `scipy.ndimage.filters`
            # namespace is deprecated
            from scipy.ndimage import maximum_filter
            self.max_filter = lambda y, patch_size: maximum_filter(y, patch_size, mode='constant')

        self.maxfilter_patch_size = maxfilter_patch_size if maxfilter_patch_size is not None else self.patch_size

        self.sample_ind_cache = sample_ind_cache
        self._ind_cache_fg = {}
        self._ind_cache_all = {}
        self.lock = threading.Lock()

    def get_valid_inds(self, k, foreground_prob=None):
        """Return the valid patch indices for the `k`-th mask.

        With probability `foreground_prob` (default: ``self.foreground_prob``)
        only positions whose max-filtered mask is positive are considered.
        If that yields no positions, all valid positions are used instead.
        Results are cached per image if ``self.sample_ind_cache`` is true.
        """
        if foreground_prob is None:
            foreground_prob = self.foreground_prob
        foreground_only = np.random.uniform() < foreground_prob
        _ind_cache = self._ind_cache_fg if foreground_only else self._ind_cache_all
        if k in _ind_cache:
            inds = _ind_cache[k]
        else:
            patch_filter = (lambda y,p: self.max_filter(y, self.maxfilter_patch_size) > 0) if foreground_only else None
            # module-level `get_valid_inds` (not this method)
            inds = get_valid_inds(self.Y[k], self.patch_size, patch_filter=patch_filter)
            if self.sample_ind_cache:
                with self.lock:
                    _ind_cache[k] = inds
        if foreground_only and len(inds[0])==0:
            # no foreground positions found: fall back to all valid positions
            return self.get_valid_inds(k, foreground_prob=0)
        return inds

    def channels_as_tuple(self, x):
        """Return `x` as a tuple with one entry per channel (``(x,)`` if there is no channel axis)."""
        if self.n_channel is None:
            return (x,)
        else:
            return tuple(x[...,i] for i in range(self.n_channel))
class StarDistBase(BaseModel): |
def __init__(self, config, name=None, basedir='.'):
    """Initialize the model and its prediction thresholds.

    If `basedir` is not None, thresholds are read from ``thresholds.json`` in
    the model's log directory; values that are missing or not strictly between
    0 and 1 are replaced by the defaults (``prob=0.5``, ``nms=0.4``).
    """
    super().__init__(config=config, name=name, basedir=basedir)
    defaults = dict(prob=0.5, nms=0.4)
    threshs = dict(prob=None, nms=None)
    if basedir is not None:
        try:
            threshs = load_json(str(self.logdir / 'thresholds.json'))
            print("Loading thresholds from 'thresholds.json'.")
            for key in ('prob', 'nms'):
                value = threshs.get(key)
                if value is None or not (0 < value < 1):
                    print("- Invalid '%s' threshold (%s), using default value." % (key, str(value)))
                    threshs[key] = None
        except FileNotFoundError:
            # only notify if an existing model (with weights) is being loaded
            if config is None and len(tuple(self.logdir.glob('*.h5'))) > 0:
                print("Couldn't load thresholds from 'thresholds.json', using default values. "
                      "(Call 'optimize_thresholds' to change that.)")

    self.thresholds = {
        key: (default if threshs[key] is None else threshs[key])
        for key, default in defaults.items()
    }
    print("Using default values: prob_thresh={prob:g}, nms_thresh={nms:g}.".format(prob=self.thresholds.prob, nms=self.thresholds.nms))
@property
def thresholds(self):
    """Current thresholds as stored in ``self._thresholds``.

    The corresponding setter stores a ``Thresholds`` namedtuple, so values are
    accessible as attributes (e.g. ``thresholds.prob``, ``thresholds.nms``).
    """
    return self._thresholds
def _is_multiclass(self):
    """Return True if ``config.n_classes`` is set (i.e. is not None)."""
    n_classes = self.config.n_classes
    return n_classes is not None
def _parse_classes_arg(self, classes, length):
    """Validate/convert the "classes" argument.

    `classes` can be
      "auto" -> ``None`` if ``config.n_classes`` is None,
                ``(1,)*length`` if ``config.n_classes == 1``,
                otherwise a ValueError is raised
      tuple, list, ndarray -> returned unchanged (must be of the given length)

    Any other value (including other strings) raises a ValueError.
    """
    if isinstance(classes, str):
        if classes != "auto":
            _raise(ValueError(f"classes = '{classes}': only 'auto' supported as string argument for classes"))
        n_classes = self.config.n_classes
        if n_classes is None:
            return None
        if n_classes == 1:
            return (1,)*length
        raise ValueError("using classes = 'auto' for n_classes > 1 not supported")

    if isinstance(classes, (tuple, list, np.ndarray)):
        if not (len(classes) == length):
            _raise(ValueError(f"len(classes) should be {length}!"))
        return classes

    raise ValueError("classes should either be 'auto' or a list of scalars/label dicts")
@thresholds.setter
def thresholds(self, d):
    # Freeze the mapping into a namedtuple so that its values
    # can be accessed as attributes.
    field_names, field_values = d.keys(), d.values()
    thresholds_type = namedtuple('Thresholds', field_names)
    self._thresholds = thresholds_type(*field_values)
def prepare_for_training(self, optimizer=None):
    """Prepare for neural network training.

    Compiles the model and creates
    `Keras Callbacks <https://keras.io/callbacks/>`_ to be used for training.

    Note that this method will be implicitly called once by :func:`train`
    (with default arguments) if not done so explicitly beforehand.

    Parameters
    ----------
    optimizer : obj or None
        Instance of a `Keras Optimizer <https://keras.io/optimizers/>`_ to be used for training.
        If ``None`` (default), uses ``Adam`` with the learning rate specified in ``config``.

    """
    if optimizer is None:
        optimizer = Adam(self.config.train_learning_rate)

    masked_dist_loss = {'mse': masked_loss_mse,
                        'mae': masked_loss_mae,
                        'iou': masked_loss_iou,
                        }[self.config.train_dist_loss]
    prob_loss = 'binary_crossentropy'

    def split_dist_true_mask(dist_true_mask):
        # the target tensor holds the `n_rays` distances followed by the mask
        return tf.split(dist_true_mask, num_or_size_splits=[self.config.n_rays,-1], axis=-1)

    def dist_loss(dist_true_mask, dist_pred):
        dist_true, dist_mask = split_dist_true_mask(dist_true_mask)
        return masked_dist_loss(dist_mask, reg_weight=self.config.train_background_reg)(dist_true, dist_pred)

    def dist_iou_metric(dist_true_mask, dist_pred):
        dist_true, dist_mask = split_dist_true_mask(dist_true_mask)
        return masked_metric_iou(dist_mask, reg_weight=0)(dist_true, dist_pred)

    def relevant_mae(dist_true_mask, dist_pred):
        dist_true, dist_mask = split_dist_true_mask(dist_true_mask)
        return masked_metric_mae(dist_mask)(dist_true, dist_pred)

    def relevant_mse(dist_true_mask, dist_pred):
        dist_true, dist_mask = split_dist_true_mask(dist_true_mask)
        return masked_metric_mse(dist_mask)(dist_true, dist_pred)

    if self._is_multiclass():
        prob_class_loss = weighted_categorical_crossentropy(self.config.train_class_weights, ndim=self.config.n_dim)
        loss = [prob_loss, dist_loss, prob_class_loss]
    else:
        loss = [prob_loss, dist_loss]

    self.keras_model.compile(optimizer, loss = loss,
                             loss_weights = list(self.config.train_loss_weights),
                             metrics = {'prob': kld,
                                        'dist': [relevant_mae, relevant_mse, dist_iou_metric]})

    self.callbacks = []
    if self.basedir is not None:
        self.callbacks += self._checkpoint_callbacks()

        if self.config.train_tensorboard:
            if IS_TF_1:
                self.callbacks.append(CARETensorBoard(log_dir=str(self.logdir), prefix_with_timestamp=False, n_images=3, write_images=True, prob_out=False))
            else:
                self.callbacks.append(TensorBoard(log_dir=str(self.logdir/'logs'), write_graph=False, profile_batch=0))

    if self.config.train_reduce_lr is not None:
        # work on a copy so that the config itself is not modified
        rlrop_params = dict(self.config.train_reduce_lr)
        rlrop_params.setdefault('verbose', True)
        # inserted as the first callback
        self.callbacks.insert(0,ReduceLROnPlateau(**rlrop_params))

    self._model_prepared = True
def _predict_setup(self, img, axes, normalizer, n_tiles, show_tile_progress, predict_kwargs):
    """ Shared setup code between `predict` and `predict_sparse`

    Validates `n_tiles`, permutes `img` to the axes of the network
    (``config.axes``), applies the normalizer and resizer and creates two
    helper functions:

    - ``predict_direct(x)``: runs the Keras model on a single image (a batch
      axis is added and removed again) and returns a tuple of outputs.
    - ``tiling_setup()``: returns ``(tile_generator, output_shape, create_empty_output)``
      for tiled prediction (requires ``np.prod(n_tiles) > 1``).

    Returns
    -------
    tuple
        ``(x, axes, axes_net, axes_net_div_by, _permute_axes, resizer, n_tiles,
        grid, grid_dict, channel, predict_direct, tiling_setup)``

    Raises
    ------
    ValueError
        If `n_tiles` is invalid, or the number of channels / the grid do not
        match the configuration.
    """
    if n_tiles is None:
        n_tiles = [1]*img.ndim
    try:
        n_tiles = tuple(n_tiles)
        img.ndim == len(n_tiles) or _raise(TypeError())
    except TypeError:
        # not iterable or of wrong length
        raise ValueError("n_tiles must be an iterable of length %d" % img.ndim)
    all(np.isscalar(t) and 1<=t and int(t)==t for t in n_tiles) or _raise(
        ValueError("all values of n_tiles must be integer values >= 1"))

    n_tiles = tuple(map(int,n_tiles))

    axes = self._normalize_axes(img, axes)
    axes_net = self.config.axes

    _permute_axes = self._make_permute_axes(axes, axes_net)
    # presumably x now has the axes order of the network (axes_net)
    x = _permute_axes(img)

    # index of the channel axis within axes_net
    channel = axes_dict(axes_net)['C']

    self.config.n_channel_in == x.shape[channel] or _raise(ValueError())

    axes_net_div_by = self._axes_div_by(axes_net)

    grid = tuple(self.config.grid)
    # one grid value per non-channel axis
    len(grid) == len(axes_net)-1 or _raise(ValueError())
    grid_dict = dict(zip(axes_net.replace('C',''),grid))

    normalizer = self._check_normalizer_resizer(normalizer, None)[0]
    resizer = StarDistPadAndCropResizer(grid=grid_dict)

    x = normalizer.before(x, axes_net)
    x = resizer.before(x, axes_net, axes_net_div_by)

    if not _is_floatarray(x):
        warnings.warn("Predicting on non-float input... ( forgot to normalize? )")

    def predict_direct(x):
        # add a leading batch axis for the model and strip it from every output
        ys = self.keras_model.predict(x[np.newaxis], **predict_kwargs)
        return tuple(y[0] for y in ys)

    def tiling_setup():
        assert np.prod(n_tiles) > 1
        # all axes except the channel axis may be tiled
        tiling_axes = axes_net.replace('C','')
        x_tiling_axis = tuple(axes_dict(axes_net)[a] for a in tiling_axes)
        axes_net_tile_overlaps = self._axes_tile_overlap(axes_net)
        # permute n_tiles in the same way as img was permuted to obtain x
        _n_tiles = _permute_axes(np.empty(n_tiles,bool)).shape
        (all(_n_tiles[i] == 1 for i in range(x.ndim) if i not in x_tiling_axis) or
            _raise(ValueError("entry of n_tiles > 1 only allowed for axes '%s'" % tiling_axes)))

        # output shape: input shape divided by the grid (channel size filled in later)
        sh = [s//grid_dict.get(a,1) for a,s in zip(axes_net,x.shape)]
        sh[channel] = None
        def create_empty_output(n_channel, dtype=np.float32):
            # note: modifies the shared list `sh` in place
            sh[channel] = n_channel
            return np.empty(sh,dtype)

        # a callable `show_tile_progress` replaces tqdm as progress wrapper
        if callable(show_tile_progress):
            progress, _show_tile_progress = show_tile_progress, True
        else:
            progress, _show_tile_progress = tqdm, show_tile_progress

        n_block_overlaps = [int(np.ceil(overlap/blocksize)) for overlap, blocksize
                            in zip(axes_net_tile_overlaps, axes_net_div_by)]

        num_tiles_used = total_n_tiles(x, _n_tiles, block_sizes=axes_net_div_by, n_block_overlaps=n_block_overlaps)

        tile_generator = progress(tile_iterator(x, _n_tiles, block_sizes=axes_net_div_by, n_block_overlaps=n_block_overlaps),
                                  disable=(not _show_tile_progress), total=num_tiles_used)

        return tile_generator, tuple(sh), create_empty_output

    return x, axes, axes_net, axes_net_div_by, _permute_axes, resizer, n_tiles, grid, grid_dict, channel, predict_direct, tiling_setup
def _predict_generator(self, img, axes=None, normalizer=None, n_tiles=None, show_tile_progress=True, **predict_kwargs):
    """Predict.

    Parameters
    ----------
    img : :class:`numpy.ndarray`
        Input image
    axes : str or None
        Axes of the input ``img``.
        ``None`` denotes that axes of img are the same as denoted in the config.
    normalizer : :class:`csbdeep.data.Normalizer` or None
        (Optional) normalization of input image before prediction.
        Note that the default (``None``) assumes ``img`` to be already normalized.
    n_tiles : iterable or None
        Out of memory (OOM) errors can occur if the input image is too large.
        To avoid this problem, the input image is broken up into (overlapping) tiles
        that are processed independently and re-assembled.
        This parameter denotes a tuple of the number of tiles for every image axis (see ``axes``).
        ``None`` denotes that no tiling should be used.
    show_tile_progress: bool or callable
        If boolean, indicates whether to show progress (via tqdm) during tiled prediction.
        If callable, must be a drop-in replacement for tqdm.
    predict_kwargs: dict
        Keyword arguments for ``predict`` function of Keras model.

    Returns
    -------
    (:class:`numpy.ndarray`, :class:`numpy.ndarray`, [:class:`numpy.ndarray`])
        Returns the tuple (`prob`, `dist`, [`prob_class`]) of per-pixel object probabilities and star-convex polygon/polyhedra distances.
        In multiclass prediction mode, `prob_class` is the probability map for each of the 1+'n_classes' classes (first class is background)

    Notes
    -----
    This is a generator: during tiled prediction it yields ``None`` after each
    processed tile; the last yielded item is the actual result tuple.

    """
    x, axes, axes_net, axes_net_div_by, _permute_axes, resizer, n_tiles, grid, grid_dict, channel, predict_direct, tiling_setup = \
        self._predict_setup(img, axes, normalizer, n_tiles, show_tile_progress, predict_kwargs)

    if np.prod(n_tiles) > 1:
        tile_generator, output_shape, create_empty_output = tiling_setup()

        # pre-allocate the full-size outputs that the tiles are written into
        prob = create_empty_output(1)
        dist = create_empty_output(self.config.n_rays)
        if self._is_multiclass():
            prob_class = create_empty_output(self.config.n_classes+1)
            result = (prob, dist, prob_class)
        else:
            result = (prob, dist)

        for tile, s_src, s_dst in tile_generator:
            result_tile = predict_direct(tile)
            # account for the grid: convert input slices to output slices
            s_src = [slice(s.start//grid_dict.get(a,1),s.stop//grid_dict.get(a,1)) for s,a in zip(s_src,axes_net)]
            s_dst = [slice(s.start//grid_dict.get(a,1),s.stop//grid_dict.get(a,1)) for s,a in zip(s_dst,axes_net)]
            # outputs have a different number of channels than the input: take all of them
            s_src[channel] = slice(None)
            s_dst[channel] = slice(None)
            s_src, s_dst = tuple(s_src), tuple(s_dst)
            for part, part_tile in zip(result, result_tile):
                part[s_dst] = part_tile[s_src]
            # yield None after each processed tile
            yield
    else:
        result = predict_direct(x)

    result = [resizer.after(part, axes_net) for part in result]

    # prob: drop the (singleton) channel axis
    result[0] = np.take(result[0],0,axis=channel)
    # dist: clamp to at least 1e-3 and move the channel axis to the end
    result[1] = np.maximum(1e-3, result[1])
    result[1] = np.moveaxis(result[1],channel,-1)

    if self._is_multiclass():
        # prob_class: move the channel axis to the end
        result[2] = np.moveaxis(result[2],channel,-1)

    # the last yielded item is the actual result
    yield tuple(result)
@functools.wraps(_predict_generator)
def predict(self, *args, **kwargs):
    # Exhaust the generator and return the last item it yielded
    # (None if nothing was yielded); intermediate items are discarded.
    return functools.reduce(lambda _previous, current: current,
                            self._predict_generator(*args, **kwargs), None)
def _predict_sparse_generator(self, img, prob_thresh=None, axes=None, normalizer=None, n_tiles=None, show_tile_progress=True, b=2, **predict_kwargs):
    """ Sparse version of model.predict()

    Only the predictions at positions selected by ``_ind_prob_thresh``
    (using `prob_thresh`, default ``self.thresholds.prob``, and `b`) are kept.

    This is a generator: during tiled prediction it yields ``None`` after each
    processed tile; the last yielded item is the actual result.

    Returns
    -------
    (prob, dist, [prob_class], points)   flat list of probs, dists, (optional prob_class) and points
    """
    if prob_thresh is None:
        prob_thresh = self.thresholds.prob

    x, axes, axes_net, axes_net_div_by, _permute_axes, resizer, n_tiles, grid, grid_dict, channel, predict_direct, tiling_setup = \
        self._predict_setup(img, axes, normalizer, n_tiles, show_tile_progress, predict_kwargs)

    def _prep(prob, dist):
        # drop the channel axis of prob, move the channel axis of dist to the
        # end and clamp distances to at least 1e-3
        prob = np.take(prob,0,axis=channel)
        dist = np.moveaxis(dist,channel,-1)
        dist = np.maximum(1e-3, dist)
        return prob, dist

    multiclass = self._is_multiclass()
    # factors to convert output (grid) coordinates to input pixel coordinates
    grid_scale = np.array(self.config.grid).reshape((1,len(self.config.grid)))

    proba, dista, pointsa, prob_classa = [], [], [], []

    if np.prod(n_tiles) > 1:
        tile_generator, output_shape, _ = tiling_setup()

        sh = list(output_shape)
        # placeholder only: the channel entry is removed again below
        sh[channel] = 1

        for tile, s_src, s_dst in tile_generator:
            results_tile = predict_direct(tile)

            # account for the grid: convert input slices to output slices
            s_src = [slice(s.start//grid_dict.get(a,1),s.stop//grid_dict.get(a,1)) for s,a in zip(s_src,axes_net)]
            s_dst = [slice(s.start//grid_dict.get(a,1),s.stop//grid_dict.get(a,1)) for s,a in zip(s_dst,axes_net)]
            s_src[channel] = slice(None)
            s_dst[channel] = slice(None)
            s_src, s_dst = tuple(s_src), tuple(s_dst)

            prob_tile, dist_tile = results_tile[:2]
            prob_tile, dist_tile = _prep(prob_tile[s_src], dist_tile[s_src])

            # use border value `b` only on tile sides that coincide with the
            # border of the whole output, -1 otherwise
            bs = [(b if s.start==0 else -1, b if s.stop==_sh else -1) for s,_sh in zip(s_dst, sh)]
            bs.pop(channel)
            inds = _ind_prob_thresh(prob_tile, prob_thresh, b=bs)
            proba.extend(prob_tile[inds].copy())
            dista.extend(dist_tile[inds].copy())

            # tile-local coordinates -> coordinates in the whole image
            _points = np.stack(np.where(inds), axis=1)
            offset = [s.start for s in s_dst]
            offset.pop(channel)
            _points = _points + np.array(offset).reshape((1,len(offset)))
            _points = _points * grid_scale
            pointsa.extend(_points)

            if multiclass:
                p = results_tile[2][s_src].copy()
                p = np.moveaxis(p,channel,-1)
                prob_classa.extend(p[inds])

            # yield None after each processed tile
            yield
    else:
        results = predict_direct(x)
        prob, dist = results[:2]
        prob, dist = _prep(prob, dist)
        inds = _ind_prob_thresh(prob, prob_thresh, b=b)
        proba = prob[inds].copy()
        dista = dist[inds].copy()
        _points = np.stack(np.where(inds), axis=1)
        pointsa = _points * grid_scale

        if multiclass:
            p = np.moveaxis(results[2],channel,-1)
            prob_classa = p[inds].copy()

    proba = np.asarray(proba)
    dista = np.asarray(dista).reshape((-1,self.config.n_rays))
    pointsa = np.asarray(pointsa).reshape((-1,self.config.n_dim))

    # keep only the points selected by the resizer
    idx = resizer.filter_points(x.ndim, pointsa, axes_net)
    proba = proba[idx]
    dista = dista[idx]
    pointsa = pointsa[idx]

    # the last yielded item is the actual result
    if multiclass:
        prob_classa = np.asarray(prob_classa).reshape((-1,self.config.n_classes+1))
        prob_classa = prob_classa[idx]
        yield proba, dista, prob_classa, pointsa
    else:
        yield proba, dista, pointsa
@functools.wraps(_predict_sparse_generator)
def predict_sparse(self, *args, **kwargs):
    # Exhaust the generator and return the last item it yielded
    # (None if nothing was yielded); intermediate items are discarded.
    return functools.reduce(lambda _previous, current: current,
                            self._predict_sparse_generator(*args, **kwargs), None)
def _predict_instances_generator(self, img, axes=None, normalizer=None,
                                 sparse=True,
                                 prob_thresh=None, nms_thresh=None,
                                 scale=None,
                                 n_tiles=None, show_tile_progress=True,
                                 verbose=False,
                                 return_labels=True,
                                 predict_kwargs=None, nms_kwargs=None,
                                 overlap_label=None, return_predict=False):
    """Predict instance segmentation from input image.

    Parameters
    ----------
    img : :class:`numpy.ndarray`
        Input image
    axes : str or None
        Axes of the input ``img``.
        ``None`` denotes that axes of img are the same as denoted in the config.
    normalizer : :class:`csbdeep.data.Normalizer` or None
        (Optional) normalization of input image before prediction.
        Note that the default (``None``) assumes ``img`` to be already normalized.
    sparse: bool
        If true, aggregate probabilities/distances sparsely during tiled
        prediction to save memory (recommended).
    prob_thresh : float or None
        Consider only object candidates from pixels with predicted object probability
        above this threshold (also see `optimize_thresholds`).
    nms_thresh : float or None
        Perform non-maximum suppression that considers two objects to be the same
        when their area/surface overlap exceeds this threshold (also see `optimize_thresholds`).
    scale: None or float or iterable
        Scale the input image internally by this factor and rescale the output accordingly.
        All spatial axes (X,Y,Z) will be scaled if a scalar value is provided.
        Alternatively, multiple scale values (compatible with input `axes`) can be used
        for more fine-grained control (scale values for non-spatial axes must be 1).
    n_tiles : iterable or None
        Out of memory (OOM) errors can occur if the input image is too large.
        To avoid this problem, the input image is broken up into (overlapping) tiles
        that are processed independently and re-assembled.
        This parameter denotes a tuple of the number of tiles for every image axis (see ``axes``).
        ``None`` denotes that no tiling should be used.
    show_tile_progress: bool or callable
        If boolean, indicates whether to show progress (via tqdm) during tiled prediction.
        If callable, must be a drop-in replacement for tqdm.
    verbose: bool
        Whether to print some info messages.
    return_labels: bool
        Whether to create a label image, otherwise return None in its place.
    predict_kwargs: dict
        Keyword arguments for ``predict`` function of Keras model.
    nms_kwargs: dict
        Keyword arguments for non-maximum suppression.
        Note that a ``verbose`` entry is added to this dict if not present.
    overlap_label: scalar or None
        if not None, label the regions where polygons overlap with that value
    return_predict: bool
        Also return the outputs of :func:`predict` (in a separate tuple)
        If True, implies sparse = False

    Returns
    -------
    (:class:`numpy.ndarray`, dict), (optional: return tuple of :func:`predict`)
        Returns a tuple of the label instances image and also
        a dictionary with the details (coordinates, etc.) of all remaining polygons/polyhedra.

    Notes
    -----
    This is a generator that reports progress: it yields the string ``'predict'``
    before prediction starts, ``'tile'`` after each processed tile and ``'nms'``
    before the instances are computed. The last yielded item is the actual result.

    """
    if predict_kwargs is None:
        predict_kwargs = {}
    if nms_kwargs is None:
        nms_kwargs = {}

    if return_predict and sparse:
        sparse = False
        warnings.warn("Setting sparse to False because return_predict is True")

    nms_kwargs.setdefault("verbose", verbose)

    _axes = self._normalize_axes(img, axes)
    _axes_net = self.config.axes
    _permute_axes = self._make_permute_axes(_axes, _axes_net)
    # shape of the non-channel axes, determined before any scaling of the image
    _shape_inst = tuple(s for s,a in zip(_permute_axes(img).shape, _axes_net) if a != 'C')

    if scale is not None:
        if isinstance(scale, numbers.Number):
            # a scalar applies to the spatial axes only
            scale = tuple(scale if a in 'XYZ' else 1 for a in _axes)
        scale = tuple(scale)
        len(scale) == len(_axes) or _raise(ValueError(f"scale {scale} must be of length {len(_axes)}, i.e. one value for each of the axes {_axes}"))
        for s,a in zip(scale,_axes):
            s > 0 or _raise(ValueError("scale values must be greater than 0"))
            (s in (1,None) or a in 'XYZ') or warnings.warn(f"replacing scale value {s} for non-spatial axis {a} with 1")
        # force scale 1 for all non-spatial axes
        scale = tuple(s if a in 'XYZ' else 1 for s,a in zip(scale,_axes))
        verbose and print(f"scaling image by factors {scale} for axes {_axes}")
        img = ndi.zoom(img, scale, order=1)

    # indicate that prediction is starting
    yield 'predict'
    res = None
    if sparse:
        for res in self._predict_sparse_generator(img, axes=axes, normalizer=normalizer, n_tiles=n_tiles,
                                                  prob_thresh=prob_thresh, show_tile_progress=show_tile_progress, **predict_kwargs):
            if res is None:
                # a tile has been processed
                yield 'tile'
    else:
        for res in self._predict_generator(img, axes=axes, normalizer=normalizer, n_tiles=n_tiles,
                                           show_tile_progress=show_tile_progress, **predict_kwargs):
            if res is None:
                # a tile has been processed
                yield 'tile'
        # dense prediction has no points: append None in their place
        res = tuple(res) + (None,)

    if self._is_multiclass():
        prob, dist, prob_class, points = res
    else:
        prob, dist, points = res
        prob_class = None

    # indicate that non-maximum suppression is starting
    yield 'nms'
    res_instances = self._instances_from_prediction(_shape_inst, prob, dist,
                                                    points=points,
                                                    prob_class=prob_class,
                                                    prob_thresh=prob_thresh,
                                                    nms_thresh=nms_thresh,
                                                    scale=(None if scale is None else dict(zip(_axes,scale))),
                                                    return_labels=return_labels,
                                                    overlap_label=overlap_label,
                                                    **nms_kwargs)

    # the last yielded item is the actual result
    if return_predict:
        # res[:-1] drops the points entry (None for dense prediction)
        yield res_instances, tuple(res[:-1])
    else:
        yield res_instances
@functools.wraps(_predict_instances_generator)
def predict_instances(self, *args, **kwargs):
    # Exhaust the generator and return the last item it yielded
    # (None if nothing was yielded); the progress markers are discarded.
    return functools.reduce(lambda _previous, current: current,
                            self._predict_instances_generator(*args, **kwargs), None)
def predict_instances_big(self, img, axes, block_size, min_overlap, context=None,
                          labels_out=None, labels_out_dtype=np.int32, show_progress=True, **kwargs):
    """Predict instance segmentation from very large input images.

    Intended to be used when `predict_instances` cannot be used due to memory limitations.
    This function will break the input image into blocks and process them individually
    via `predict_instances` and assemble all the partial results. If used as intended, the result
    should be the same as if `predict_instances` was used directly on the whole image.

    **Important**: The crucial assumption is that all predicted object instances are smaller than
                   the provided `min_overlap`. Also, it must hold that: min_overlap + 2*context < block_size.

    Example
    -------
    >>> img.shape
    (20000, 20000)
    >>> labels, polys = model.predict_instances_big(img, axes='YX', block_size=4096,
                                                    min_overlap=128, context=128, n_tiles=(4,4))

    Parameters
    ----------
    img: :class:`numpy.ndarray` or similar
        Input image
    axes: str
        Axes of the input ``img`` (such as 'YX', 'ZYX', 'YXC', etc.)
    block_size: int or iterable of int
        Process input image in blocks of the provided shape.
        (If a scalar value is given, it is used for all spatial image dimensions.)
    min_overlap: int or iterable of int
        Amount of guaranteed overlap between blocks.
        (If a scalar value is given, it is used for all spatial image dimensions.)
    context: int or iterable of int, or None
        Amount of image context on all sides of a block, which is discarded.
        If None, uses an automatic estimate that should work in many cases.
        (If a scalar value is given, it is used for all spatial image dimensions.)
    labels_out: :class:`numpy.ndarray` or similar, or None, or False
        numpy array or similar (must be of correct shape) to which the label image is written.
        If None, will allocate a numpy array of the correct shape and data type ``labels_out_dtype``.
        If False, will not write the label image (useful if only the dictionary is needed).
    labels_out_dtype: str or dtype
        Data type of returned label image if ``labels_out=None`` (has no effect otherwise).
    show_progress: bool
        Show progress bar for block processing.
    kwargs: dict
        Keyword arguments for ``predict_instances``.
        The arguments ``axes``, ``overlap_label``, ``return_labels`` and ``return_predict``
        (and ``show_tile_progress`` if ``show_progress`` is true) are always overridden.

    Returns
    -------
    (:class:`numpy.ndarray` or None, dict)
        Returns the label image (``None`` if ``labels_out=False``) and a dictionary
        with the details (coordinates, etc.) of the polygons/polyhedra.

    """
    from ..big import _grid_divisible, BlockND, OBJECT_KEYS
    from ..matching import relabel_sequential

    n = img.ndim
    axes = axes_check_and_normalize(axes, length=n)
    grid = self._axes_div_by(axes)
    axes_out = self._axes_out.replace('C','')
    shape_dict = dict(zip(axes,img.shape))
    shape_out = tuple(shape_dict[a] for a in axes_out)

    if context is None:
        context = self._axes_tile_overlap(axes)

    # expand scalar values to one value per image axis
    if np.isscalar(block_size): block_size = n*[block_size]
    if np.isscalar(min_overlap): min_overlap = n*[min_overlap]
    if np.isscalar(context): context = n*[context]
    block_size, min_overlap, context = list(block_size), list(min_overlap), list(context)
    assert n == len(block_size) == len(min_overlap) == len(context)

    if 'C' in axes:
        # the channel axis is never split: a single block without overlap/context
        i = axes_dict(axes)['C']
        block_size[i] = img.shape[i]
        min_overlap[i] = context[i] = 0

    block_size = tuple(_grid_divisible(g, v, name='block_size', verbose=False) for v,g,a in zip(block_size, grid,axes))
    min_overlap = tuple(_grid_divisible(g, v, name='min_overlap', verbose=False) for v,g,a in zip(min_overlap,grid,axes))
    context = tuple(_grid_divisible(g, v, name='context', verbose=False) for v,g,a in zip(context, grid,axes))

    print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)

    for a,c,o in zip(axes,context,self._axes_tile_overlap(axes)):
        if c < o:
            print(f"{a}: context of {c} is small, recommended to use at least {o}", flush=True)

    # create the blocks that cover the image
    blocks = BlockND.cover(img.shape, axes, block_size, min_overlap, context, grid)

    if np.isscalar(labels_out) and bool(labels_out) is False:
        # labels_out=False: do not assemble a label image
        labels_out = None
    else:
        if labels_out is None:
            labels_out = np.zeros(shape_out, dtype=labels_out_dtype)
        else:
            labels_out.shape == shape_out or _raise(ValueError(f"'labels_out' must have shape {shape_out} (axes {axes_out})."))

    polys_all = {}
    label_offset = 1

    # arguments of predict_instances that are always overridden
    kwargs_override = dict(axes=axes, overlap_label=None, return_labels=True, return_predict=False)
    if show_progress:
        kwargs_override['show_tile_progress'] = False
    for k,v in kwargs_override.items():
        if k in kwargs: print(f"changing '{k}' from {kwargs[k]} to {v}", flush=True)
        kwargs[k] = v

    blocks = tqdm(blocks, disable=(not show_progress))
    for block in blocks:
        labels, polys = self.predict_instances(block.read(img, axes=axes), **kwargs)
        labels = block.crop_context(labels, axes=axes_out)
        labels, polys = block.filter_objects(labels, polys, axes=axes_out)
        # relabel starting at label_offset, presumably to keep labels unique across blocks
        labels = relabel_sequential(labels, label_offset)[0]
        if labels_out is not None:
            block.write(labels_out, labels, axes=axes_out)
        # collect the per-block polygon details
        for k,v in polys.items():
            polys_all.setdefault(k,[]).append(v)
        label_offset += len(polys['prob'])
        del labels

    # concatenate per-object entries; for all other keys keep the value of the first block
    polys_all = {k: (np.concatenate(v) if k in OBJECT_KEYS else v[0]) for k,v in polys_all.items()}

    return labels_out, polys_all
def optimize_thresholds(self, X_val, Y_val, nms_threshs=[0.3,0.4,0.5], iou_threshs=[0.3,0.5,0.7], predict_kwargs=None, optimize_kwargs=None, save_to_json=True):
    """Optimize two thresholds (probability, NMS overlap) necessary for predicting object instances.

    Note that the default thresholds yield good results in many cases, but optimizing
    the thresholds for a particular dataset can further improve performance.

    The optimized thresholds are automatically used for all further predictions
    and also written to the model directory.

    See ``utils.optimize_threshold`` for details and possible choices for ``optimize_kwargs``.

    Parameters
    ----------
    X_val : list of ndarray
        (Validation) input images (must be normalized) to use for threshold tuning.
    Y_val : list of ndarray
        (Validation) label images to use for threshold tuning.
    nms_threshs : list of float
        List of overlap thresholds to be considered for NMS.
        For each value in this list, optimization is run to find a corresponding prob_thresh value.
        Must contain at least one value.
    iou_threshs : list of float
        List of intersection over union (IOU) thresholds for which
        the (average) matching performance is considered to tune the thresholds.
    predict_kwargs: dict
        Keyword arguments for ``predict`` function of this class.
        (If not provided, will guess value for `n_tiles` to prevent out of memory errors.)
    optimize_kwargs: dict
        Keyword arguments for ``utils.optimize_threshold`` function.
    save_to_json: bool
        If set (and the model has a base directory), write the optimized
        thresholds to 'thresholds.json' in the model's log directory.

    Returns
    -------
    dict
        The selected thresholds with keys ``'prob'`` and ``'nms'``.

    Raises
    ------
    ValueError
        If ``nms_threshs`` is empty.

    """
    # Materialize once so that arbitrary iterables are accepted and emptiness
    # can be detected *before* running the (expensive) predictions below.
    # Without this check an empty list would leave both thresholds as None
    # and fail later with an unrelated error.
    nms_threshs = list(nms_threshs)
    if len(nms_threshs) == 0:
        raise ValueError("'nms_threshs' must contain at least one value.")

    if predict_kwargs is None:
        predict_kwargs = {}
    if optimize_kwargs is None:
        optimize_kwargs = {}

    def _predict_kwargs(x):
        # respect a user-provided tiling, otherwise guess one per image
        if 'n_tiles' in predict_kwargs:
            return predict_kwargs
        else:
            return {**predict_kwargs, 'n_tiles': self._guess_n_tiles(x), 'show_tile_progress': False}

    # only the first two elements of each prediction are used
    Yhat_val = [self.predict(x, **_predict_kwargs(x))[:2] for x in X_val]

    # keep the (prob, nms) pair with the highest measure over all NMS candidates
    opt_prob_thresh, opt_measure, opt_nms_thresh = None, -np.inf, None
    for _opt_nms_thresh in nms_threshs:
        _opt_prob_thresh, _opt_measure = optimize_threshold(Y_val, Yhat_val, model=self, nms_thresh=_opt_nms_thresh, iou_threshs=iou_threshs, **optimize_kwargs)
        if _opt_measure > opt_measure:
            opt_prob_thresh, opt_measure, opt_nms_thresh = _opt_prob_thresh, _opt_measure, _opt_nms_thresh
    opt_threshs = dict(prob=opt_prob_thresh, nms=opt_nms_thresh)

    self.thresholds = opt_threshs
    # flush stderr before writing the summary to stdout
    print(end='', file=sys.stderr, flush=True)
    print("Using optimized values: prob_thresh={prob:g}, nms_thresh={nms:g}.".format(prob=self.thresholds.prob, nms=self.thresholds.nms))
    if save_to_json and self.basedir is not None:
        print("Saving to 'thresholds.json'.")
        save_json(opt_threshs, str(self.logdir / 'thresholds.json'))
    return opt_threshs
def _guess_n_tiles(self, img): |
axes = self._normalize_axes(img, axes=None) |
shape = list(img.shape) |
if 'C' in axes: |
del shape[axes_dict(axes)['C']] |
b = self.config.train_batch_size**(1.0/self.config.n_dim) |
n_tiles = [int(np.ceil(s/(p*b))) for s,p in zip(shape,self.config.train_patch_size)] |
if 'C' in axes: |
n_tiles.insert(axes_dict(axes)['C'],1) |
return tuple(n_tiles) |
def _normalize_axes(self, img, axes):
    """Return normalized axes for `img`, falling back to the configured axes.

    If `axes` is None, ``config.axes`` is used; its 'C' entry is dropped when
    `img` has exactly one dimension less than the configured axes and the
    model expects a single input channel.
    """
    if axes is not None:
        return axes_check_and_normalize(axes, img.ndim)

    cfg = self.config
    axes = cfg.axes
    assert 'C' in axes
    missing_one_dim = (img.ndim == len(axes) - 1)
    if missing_one_dim and cfg.n_channel_in == 1:
        # image comes without a dedicated channel axis
        axes = axes.replace('C', '')
    return axes_check_and_normalize(axes, img.ndim)
def _compute_receptive_field(self, img_size=None):
    """Estimate the extent of the model response to a single-pixel input.

    The model is run on an all-zero image and on an image with a single 1 at
    the center; the positions where the first output differs are compared to
    the center position.

    Parameters
    ----------
    img_size : int or tuple of int, optional
        Size of the probe image (each entry must be a power of 2).
        Defaults to ``grid * 128`` for 2D and ``grid * 64`` otherwise.

    Returns
    -------
    list of (int, int)
        For each axis, the distance from the center to the lowest and to the
        highest index at which the two outputs differ.
    """
    from scipy.ndimage import zoom

    cfg = self.config
    if img_size is None:
        base = 128 if cfg.n_dim == 2 else 64
        img_size = tuple(g * base for g in cfg.grid)
    if np.isscalar(img_size):
        img_size = (img_size,) * cfg.n_dim
    img_size = tuple(img_size)
    assert all(_is_power_of_2(s) for s in img_size)

    center = tuple(s // 2 for s in img_size)
    impulse = np.zeros((1,) + img_size + (cfg.n_channel_in,), dtype=np.float32)
    blank = np.zeros_like(impulse)
    # set all channels of the center pixel
    impulse[(0,) + center + (slice(None),)] = 1

    response = self.keras_model.predict(impulse)[0][0, ..., 0]
    baseline = self.keras_model.predict(blank)[0][0, ..., 0]

    # ratio of input to output size must match the configured grid
    in_shape = np.array(impulse.shape[1:-1])
    out_shape = np.array(response.shape)
    grid = tuple((in_shape / out_shape).astype(int))
    assert grid == cfg.grid

    # bring both outputs back to input resolution (nearest neighbor)
    response = zoom(response, grid, order=0)
    baseline = zoom(baseline, grid, order=0)

    changed = np.where(np.abs(response - baseline) > 0)
    return [(c - np.min(idx), np.max(idx) - c) for c, idx in zip(center, changed)]
def _axes_tile_overlap(self, query_axes):
    """Return the tile overlap for each axis in `query_axes`.

    The overlap of an axis is the larger of the two values computed for it by
    ``_compute_receptive_field`` (cached in ``self._tile_overlap`` on first
    use). Axes not among the non-channel axes of ``config.axes`` get 0.
    """
    query_axes = axes_check_and_normalize(query_axes)

    # compute once, then reuse
    if not hasattr(self, '_tile_overlap'):
        self._tile_overlap = self._compute_receptive_field()

    spatial_axes = self.config.axes.replace('C', '')
    per_axis = {a: max(rf) for a, rf in zip(spatial_axes, self._tile_overlap)}
    return tuple(per_axis.get(a, 0) for a in query_axes)
def export_TF(self, fname=None, single_output=True, upsample_grid=True):
    """Export model to TensorFlow's SavedModel format that can be used e.g. in the Fiji plugin

    Parameters
    ----------
    fname : str
        Path of the zip file to store the model
        If None, the default path "<modeldir>/TF_SavedModel.zip" is used
    single_output: bool
        If set, concatenates the two model outputs into a single output (note: this is currently mandatory for further use in Fiji)
    upsample_grid: bool
        If set, upsamples the output to the input shape (note: this is currently mandatory for further use in Fiji)

    Returns
    -------
    The Keras model that was exported.

    Raises
    ------
    ValueError
        If `fname` is None and the model has no base directory.
    """
    layer_names = ('Concatenate', 'UpSampling2D', 'UpSampling3D', 'Conv2DTranspose', 'Conv3DTranspose')
    Concatenate, UpSampling2D, UpSampling3D, Conv2DTranspose, Conv3DTranspose = keras_import('layers', *layer_names)
    Model = keras_import('models', 'Model')

    if self.basedir is None and fname is None:
        raise ValueError("Need explicit 'fname', since model directory not available (basedir=None).")

    if self._is_multiclass():
        warnings.warn("multi-class mode not supported yet, removing classification output from exported model")

    grid = self.config.grid
    # only the first two model outputs are exported
    prob, dist = self.keras_model.outputs[0], self.keras_model.outputs[1]
    assert self.config.n_dim in (2,3)

    needs_upsampling = upsample_grid and any(g>1 for g in grid)
    if needs_upsampling:
        if self.config.n_dim == 2:
            conv_transpose, upsampling = Conv2DTranspose, UpSampling2D
        else:
            conv_transpose, upsampling = Conv3DTranspose, UpSampling3D
        # size-1 transposed convolution with all-ones kernel and stride = grid
        prob_upsampler = conv_transpose(1, (1,)*self.config.n_dim,
                                        strides=grid, padding='same',
                                        kernel_initializer='ones', use_bias=False)
        prob = prob_upsampler(prob)
        dist = upsampling(grid)(dist)

    inputs = self.keras_model.inputs[0]
    if single_output:
        outputs = Concatenate()([prob,dist])
    else:
        outputs = [prob,dist]
    csbdeep_model = Model(inputs, outputs)

    if fname is None:
        fname = self.logdir / 'TF_SavedModel.zip'
    else:
        fname = Path(fname)
    export_SavedModel(csbdeep_model, str(fname))
    return csbdeep_model
class StarDistPadAndCropResizer(Resizer):
    """Resizer that pads an input at the end of each axis and crops the result again.

    `grid` maps axis names to the factor by which the array passed to
    ``after`` is smaller than the padded input along that axis.
    """

    def __init__(self, grid, mode='reflect', **kwargs):
        """Store `grid` (a dict), the ``np.pad`` mode and extra ``np.pad`` keyword arguments."""
        assert isinstance(grid, dict)
        self.mode = mode
        self.grid = grid
        self.kwargs = kwargs

    def before(self, x, axes, axes_div_by):
        """Pad `x` at the end of every axis to a multiple of the matching `axes_div_by` entry."""
        # every divisor must itself be a multiple of the grid of its axis
        assert all(
            div_n % self.grid.get(a, 1) == 0
            for a, div_n in zip(axes, axes_div_by)
        )
        axes = axes_check_and_normalize(axes,x.ndim)

        # (before, after) padding per axis; nothing is ever added at the start
        self.pad = {}
        for a, div_n, s in zip(axes, axes_div_by, x.shape):
            self.pad[a] = (0, (div_n - s % div_n) % div_n)

        pad_width = tuple(self.pad[a] for a in axes)
        x_pad = np.pad(x, pad_width, mode=self.mode, **self.kwargs)

        # remember the padded size of all axes except the channel axis
        self.padded_shape = dict(zip(axes,x_pad.shape))
        self.padded_shape.pop('C', None)
        return x_pad

    def after(self, x, axes):
        """Crop from `x` the part that corresponds to the padding added in ``before``."""
        axes = axes_check_and_normalize(axes,x.ndim)

        # axes unknown to before() are assumed to be unpadded and to have grid 1
        expected_shapes = (self.padded_shape.get(a, s) for a, s in zip(axes, x.shape))
        grids = (self.grid.get(a, 1) for a in axes)
        assert all(s_pad == s * g for s, s_pad, g in zip(x.shape, expected_shapes, grids))

        def _crop_slice(axis):
            pad_end = self.pad.get(axis, (0,0))[1]
            g = self.grid.get(axis, 1)
            # padding smaller than one grid cell leaves nothing to crop
            stop = -(math.floor(pad_end/g)) if pad_end >= g else None
            return slice(0, stop)

        crop = tuple(_crop_slice(a) for a in axes)
        return x[crop]

    def filter_points(self, ndim, points, axes):
        """ returns indices of points inside crop region """
        assert points.ndim==2
        axes = axes_check_and_normalize(axes,ndim)
        spatial_axes = [a for a in axes if a.lower() in ('z','y','x')]
        # unpadded extent along each spatial axis
        bounds = np.array(tuple(self.padded_shape[a] - self.pad[a][1] for a in spatial_axes))
        inside = np.all(points < bounds, 1)
        return np.where(inside)
def _tf_version_at_least(version_string="1.0.0"):
    """Return True if the imported TensorFlow version is at least `version_string`."""
    from packaging.version import parse
    installed = parse(tf.__version__)
    required = parse(version_string)
    return installed >= required