|
from __future__ import print_function, unicode_literals, absolute_import, division |
|
|
|
import numpy as np |
|
import warnings |
|
import math |
|
from tqdm import tqdm |
|
|
|
from csbdeep.models import BaseConfig |
|
from csbdeep.internals.blocks import unet_block |
|
from csbdeep.utils import _raise, backend_channels_last, axes_check_and_normalize, axes_dict |
|
from csbdeep.utils.tf import keras_import, IS_TF_1, CARETensorBoard, CARETensorBoardImage |
|
from skimage.segmentation import clear_border |
|
from skimage.measure import regionprops |
|
from scipy.ndimage import zoom |
|
from distutils.version import LooseVersion |
|
|
|
keras = keras_import() |
|
K = keras_import('backend') |
|
Input, Conv2D, MaxPooling2D = keras_import('layers', 'Input', 'Conv2D', 'MaxPooling2D') |
|
Model = keras_import('models', 'Model') |
|
|
|
from .base import StarDistBase, StarDistDataBase, _tf_version_at_least |
|
from ..sample_patches import sample_patches |
|
from ..utils import edt_prob, _normalize_grid, mask_to_categorical |
|
from ..geometry import star_dist, dist_to_coord, polygons_to_label |
|
from ..nms import non_maximum_suppression, non_maximum_suppression_sparse |
|
|
|
|
|
class StarDistData2D(StarDistDataBase):
    """Data generator that produces training/validation batches for :class:`StarDist2D`.

    Indexing the generator with ``i`` returns ``([X], targets)`` where ``targets``
    is ``[prob, dist_and_mask]`` for single-class models, or
    ``[prob, dist_and_mask, prob_class]`` if ``n_classes`` is not None.

    Parameters
    ----------
    X, Y : sequence of arrays
        Input images and corresponding label masks.
    batch_size : int
        Number of patches per batch (forwarded to the base class).
    n_rays : int
        Number of radial directions of the star-convex distances.
    length : int
        Number of batches (forwarded to the base class).
    n_classes : None or int
        Number of object classes (multi-class targets are only produced if not None).
    classes : sequence or None
        Per-image label id -> class id mappings, indexed like ``X``/``Y`` (multi-class only).
    patch_size : tuple of int
        Size of the patches sampled from each image.
    b : int
        Number of pixels cropped at each border of a patch if ``shape_completion`` is enabled.
    grid : tuple of int
        Subsampling factors of the prediction grid.
    shape_completion : bool
        If True, crop ``b`` border pixels from inputs/targets and compute the
        distances on labels from which border-touching objects were removed
        (``clear_border``).
    augmenter : None or callable
        Forwarded to the base class; ``self.augmenter(x, y)`` is applied to every
        sampled (input, label) pair before the targets are computed.
    foreground_prob : float
        Forwarded to the base class.
    """

    def __init__(self, X, Y, batch_size, n_rays, length,
                 n_classes=None, classes=None,
                 patch_size=(256,256), b=32, grid=(1,1), shape_completion=False, augmenter=None, foreground_prob=0, **kwargs):

        super().__init__(X=X, Y=Y, n_rays=n_rays, grid=grid,
                         n_classes=n_classes, classes=classes,
                         batch_size=batch_size, patch_size=patch_size, length=length,
                         augmenter=augmenter, foreground_prob=foreground_prob, **kwargs)

        self.shape_completion = bool(shape_completion)
        if self.shape_completion and b > 0:
            # crop b pixels on both sides of the two spatial axes
            self.b = slice(b,-b),slice(b,-b)
        else:
            # no cropping
            self.b = slice(None),slice(None)

        # backend for star_dist: OpenCL if the GPU should be used, otherwise C++
        # (self.use_gpu is expected to be set by the base class from kwargs)
        self.sd_mode = 'opencl' if self.use_gpu else 'cpp'

    def __getitem__(self, i):
        """Return batch ``i`` as ``([X], targets)`` (see class docstring)."""
        idx = self.batch(i)
        # one random patch per selected image: (label patch, *channel patches)
        arrays = [sample_patches((self.Y[k],) + self.channels_as_tuple(self.X[k]),
                                 patch_size=self.patch_size, n_samples=1,
                                 valid_inds=self.get_valid_inds(k)) for k in idx]

        if self.n_channel is None:
            X, Y = list(zip(*[(x[0][self.b],y[0]) for y,x in arrays]))
        else:
            # stack the individual channel patches into a trailing channel axis
            X, Y = list(zip(*[(np.stack([_x[0] for _x in x],axis=-1)[self.b], y[0]) for y,*x in arrays]))

        X, Y = tuple(zip(*tuple(self.augmenter(_x, _y) for _x, _y in zip(X,Y))))

        # NOTE(review): self.ss_grid comes from the base class; it presumably
        # subsamples (batch, Y, X) by `grid`, so [1:3] are the spatial entries.
        prob = np.stack([edt_prob(lbl[self.b][self.ss_grid[1:3]]) for lbl in Y])

        if self.shape_completion:
            # distances are computed on the full patch with border-touching
            # objects removed, then cropped and subsampled to the grid
            Y_cleared = [clear_border(lbl) for lbl in Y]
            _dist = np.stack([star_dist(lbl,self.n_rays,mode=self.sd_mode)[self.b+(slice(None),)] for lbl in Y_cleared])
            dist = _dist[self.ss_grid]
            dist_mask = np.stack([edt_prob(lbl[self.b][self.ss_grid[1:3]]) for lbl in Y_cleared])
        else:
            dist = np.stack([star_dist(lbl,self.n_rays,mode=self.sd_mode, grid=self.grid) for lbl in Y])
            dist_mask = prob

        X = np.stack(X)
        if X.ndim == 3:
            # single-channel input without channel axis
            X = np.expand_dims(X,-1)
        prob = np.expand_dims(prob,-1)
        dist_mask = np.expand_dims(dist_mask,-1)

        # the distance target carries the mask as an extra last channel
        dist_and_mask = np.empty(dist.shape[:-1]+(self.n_rays+1,), np.float32)
        dist_and_mask[...,:-1] = dist
        dist_and_mask[...,-1:] = dist_mask

        if self.n_classes is None:
            return [X], [prob,dist_and_mask]
        else:
            # crop the labels exactly like `prob`/`dist` (self.b) so that the class
            # target has the same spatial shape as the other targets
            prob_class = np.stack(tuple((mask_to_categorical(y[self.b], self.n_classes, self.classes[k]) for y,k in zip(Y, idx))))
            # nearest-neighbour (order=0) downsampling of the one-hot class map to the grid
            prob_class = zoom(prob_class, (1,)+tuple(1/g for g in self.grid)+(1,), order=0)
            return [X], [prob,dist_and_mask, prob_class]
|
|
|
|
|
|
|
class Config2D(BaseConfig):
    """Configuration for a :class:`StarDist2D` model.

    Parameters
    ----------
    axes : str or None
        Axes of the input images.
    n_rays : int
        Number of radial directions for the star-convex polygon.
        Recommended to use a power of 2 (default: 32).
    n_channel_in : int
        Number of channels of given input image (default: 1).
    grid : (int,int)
        Subsampling factors (must be powers of 2) for each of the axes.
        Model will predict on a subsampled grid for increased efficiency and larger field of view.
    n_classes : None or int
        Number of object classes to use for multi-class prediction (use None to disable)
    backbone : str
        Name of the neural network architecture to be used as backbone.
        Currently only ``'unet'`` is supported (other values raise ``ValueError``).
    kwargs : dict
        Overwrite (or add) configuration attributes (see below).


    Attributes
    ----------
    unet_n_depth : int
        Number of U-Net resolution levels (down/up-sampling layers).
    unet_kernel_size : (int,int)
        Convolution kernel size for all (U-Net) convolution layers.
    unet_n_filter_base : int
        Number of convolution kernels (feature channels) for first U-Net layer.
        Doubled after each down-sampling layer.
    unet_pool : (int,int)
        Maxpooling size for all (U-Net) convolution layers.
    net_conv_after_unet : int
        Number of filters of the extra convolution layer after U-Net (0 to disable).
    unet_* : *
        Additional parameters for U-net backbone.
    train_shape_completion : bool
        Train model to predict complete shapes for partially visible objects at image boundary.
    train_completion_crop : int
        If 'train_shape_completion' is set to True, specify number of pixels to crop at boundary of training patches.
        Should be chosen based on (largest) object sizes.
    train_patch_size : (int,int)
        Size of patches to be cropped from provided training images.
    train_background_reg : float
        Regularizer to encourage distance predictions on background regions to be 0.
    train_foreground_only : float
        Fraction (0..1) of patches that will only be sampled from regions that contain foreground pixels.
    train_sample_cache : bool
        Activate caching of valid patch regions for all training images (disable to save memory for large datasets)
    train_dist_loss : str
        Training loss for star-convex polygon distances ('mse' or 'mae').
    train_loss_weights : tuple of float
        Weights for losses relating to (probability, distance), or to
        (probability, distance, class probability) if ``n_classes`` is not None.
    train_class_weights : tuple of float
        Per-class weights (background first, followed by the ``n_classes`` object classes).
        Must have ``n_classes + 1`` entries if ``n_classes`` is not None, otherwise 2.
    train_epochs : int
        Number of training epochs.
    train_steps_per_epoch : int
        Number of parameter update steps per epoch.
    train_learning_rate : float
        Learning rate for training.
    train_batch_size : int
        Batch size for training.
    train_n_val_patches : int
        Number of patches to be extracted from validation images (``None`` = one patch per image).
    train_tensorboard : bool
        Enable TensorBoard for monitoring training progress.
    train_reduce_lr : dict
        Parameter :class:`dict` of ReduceLROnPlateau_ callback; set to ``None`` to disable.
    use_gpu : bool
        Indicate that the data generator should use OpenCL to do computations on the GPU.

    .. _ReduceLROnPlateau: https://keras.io/api/callbacks/reduce_lr_on_plateau/
    """

    def __init__(self, axes='YX', n_rays=32, n_channel_in=1, grid=(1,1), n_classes=None, backbone='unet', **kwargs):
        """See class docstring."""

        # one output channel for the object probability plus one per ray
        super().__init__(axes=axes, n_channel_in=n_channel_in, n_channel_out=1+n_rays)

        # attributes set directly from the constructor arguments
        self.n_rays = int(n_rays)
        self.grid = _normalize_grid(grid,2)
        self.backbone = str(backbone).lower()
        self.n_classes = None if n_classes is None else int(n_classes)

        # backbone defaults (may be overwritten via kwargs below)
        if self.backbone == 'unet':
            self.unet_n_depth = 3
            self.unet_kernel_size = 3,3
            self.unet_n_filter_base = 32
            self.unet_n_conv_per_depth = 2
            self.unet_pool = 2,2
            self.unet_activation = 'relu'
            self.unet_last_activation = 'relu'
            self.unet_batch_norm = False
            self.unet_dropout = 0.0
            self.unet_prefix = ''
            self.net_conv_after_unet = 128
        else:
            raise ValueError("backbone '%s' not supported." % self.backbone)

        # network input/mask shapes follow the Keras image data format
        if backend_channels_last():
            self.net_input_shape = None,None,self.n_channel_in
            self.net_mask_shape = None,None,1
        else:
            self.net_input_shape = self.n_channel_in,None,None
            self.net_mask_shape = 1,None,None

        self.train_shape_completion = False
        self.train_completion_crop = 32
        self.train_patch_size = 256,256
        self.train_background_reg = 1e-4
        self.train_foreground_only = 0.9
        self.train_sample_cache = True

        self.train_dist_loss = 'mae'
        self.train_loss_weights = (1,0.2) if self.n_classes is None else (1,0.2,1)
        self.train_class_weights = (1,1) if self.n_classes is None else (1,)*(self.n_classes+1)
        self.train_epochs = 400
        self.train_steps_per_epoch = 100
        self.train_learning_rate = 0.0003
        self.train_batch_size = 4
        self.train_n_val_patches = None
        self.train_tensorboard = True

        # Keras <= 2.1.5 named the ReduceLROnPlateau threshold 'epsilon' instead of 'min_delta'
        min_delta_key = 'epsilon' if LooseVersion(keras.__version__)<=LooseVersion('2.1.5') else 'min_delta'
        self.train_reduce_lr = {'factor': 0.5, 'patience': 40, min_delta_key: 0}

        self.use_gpu = False

        # derived attributes must not be overwritten by the user-supplied kwargs
        for k in ('n_dim', 'n_channel_out'):
            kwargs.pop(k, None)

        self.update_parameters(False, **kwargs)

        # consistency checks between the (possibly overwritten) weights and n_classes
        if not len(self.train_loss_weights) == (2 if self.n_classes is None else 3):
            raise ValueError(f"train_loss_weights {self.train_loss_weights} not compatible with n_classes ({self.n_classes}): must be 3 weights if n_classes is not None, otherwise 2")

        if not len(self.train_class_weights) == (2 if self.n_classes is None else self.n_classes+1):
            raise ValueError(f"train_class_weights {self.train_class_weights} not compatible with n_classes ({self.n_classes}): must be 'n_classes + 1' weights if n_classes is not None, otherwise 2")
|
|
|
|
|
|
|
class StarDist2D(StarDistBase):
    """StarDist2D model.

    Parameters
    ----------
    config : :class:`Config2D` or None
        Will be saved to disk as JSON (``config.json``).
        If set to ``None``, will be loaded from disk (must exist).
        If omitted, a new default :class:`Config2D` is created for this model.
    name : str or None
        Model name. Uses a timestamp if set to ``None`` (default).
    basedir : str
        Directory that contains (or will contain) a folder with the given model name.

    Raises
    ------
    FileNotFoundError
        If ``config=None`` and config cannot be loaded from disk.
    ValueError
        Illegal arguments, including invalid configuration.

    Attributes
    ----------
    config : :class:`Config2D`
        Configuration, as provided during instantiation.
    keras_model : `Keras model <https://keras.io/getting-started/functional-api-guide/>`_
        Keras neural network model.
    name : str
        Model name.
    logdir : :class:`pathlib.Path`
        Path to model folder (which stores configuration, weights, etc.)
    """

    # Marker for an omitted `config` argument. `None` already means "load the
    # config from disk", so a distinct sentinel is needed to tell them apart.
    _NO_CONFIG = object()

    def __init__(self, config=_NO_CONFIG, name=None, basedir='.'):
        """See class docstring."""
        if config is StarDist2D._NO_CONFIG:
            # create a fresh default config per model instead of sharing one
            # mutable Config2D object between all default-constructed models
            config = Config2D()
        super().__init__(config, name=name, basedir=basedir)

    def _build(self):
        """Build the Keras model: optional pooling stem, U-Net, and output heads.

        Returns a model with outputs ``[prob, dist]``, or
        ``[prob, dist, prob_class]`` for multi-class configurations.
        """
        self.config.backbone == 'unet' or _raise(NotImplementedError())
        # all 'unet_*' config attributes become keyword arguments of unet_block
        unet_kwargs = {k[len('unet_'):]:v for (k,v) in vars(self.config).items() if k.startswith('unet_')}

        input_img = Input(self.config.net_input_shape, name='input')

        # downsample the input to the prediction grid: every iteration adds a conv
        # block and a 2x max-pooling along each axis that has not reached its grid factor
        grid = np.asarray(self.config.grid)
        pooled = np.array([1,1])
        pooled_img = input_img
        while tuple(pooled) != tuple(self.config.grid):
            pool = 1 + (grid > pooled)
            if np.all(pool == 1):
                # grid cannot be reached by repeated doubling -> would loop forever
                raise ValueError("grid = {grid} must consist of positive powers of 2".format(grid=tuple(self.config.grid)))
            pooled *= pool
            for _ in range(self.config.unet_n_conv_per_depth):
                pooled_img = Conv2D(self.config.unet_n_filter_base, self.config.unet_kernel_size,
                                    padding='same', activation=self.config.unet_activation)(pooled_img)
            pooled_img = MaxPooling2D(pool)(pooled_img)

        unet_base = unet_block(**unet_kwargs)(pooled_img)

        if self.config.net_conv_after_unet > 0:
            unet = Conv2D(self.config.net_conv_after_unet, self.config.unet_kernel_size,
                          name='features', padding='same', activation=self.config.unet_activation)(unet_base)
        else:
            unet = unet_base

        output_prob = Conv2D(                 1, (1,1), name='prob', padding='same', activation='sigmoid')(unet)
        output_dist = Conv2D(self.config.n_rays, (1,1), name='dist', padding='same', activation='linear')(unet)

        if self._is_multiclass():
            # separate feature layer for the class head
            if self.config.net_conv_after_unet > 0:
                unet_class = Conv2D(self.config.net_conv_after_unet, self.config.unet_kernel_size,
                                    name='features_class', padding='same', activation=self.config.unet_activation)(unet_base)
            else:
                unet_class = unet_base

            # n_classes + 1 channels: background plus one per object class
            output_prob_class = Conv2D(self.config.n_classes+1, (1,1), name='prob_class', padding='same', activation='softmax')(unet_class)
            return Model([input_img], [output_prob,output_dist,output_prob_class])
        else:
            return Model([input_img], [output_prob,output_dist])

    def train(self, X, Y, validation_data, classes='auto', augmenter=None, seed=None, epochs=None, steps_per_epoch=None, workers=1):
        """Train the neural network with the given data.

        Parameters
        ----------
        X : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
            Input images
        Y : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
            Label masks
        validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`) or triple (if multiclass)
            Tuple (triple if multiclass) of X,Y,[classes] validation data.
            For multiclass models a pair (X,Y) is accepted and uses ``classes='auto'``.
        classes : 'auto' or iterable of same length as X (optional)
            label id -> class id mapping for each label mask of Y if multiclass prediction is activated (n_classes > 0)
            list of dicts with label id -> class id (1,...,n_classes)
            'auto' -> all objects will be assigned to the first non-background class,
            or will be ignored if config.n_classes is None
        augmenter : None or callable
            Function with expected signature ``xt, yt = augmenter(x, y)``
            that takes in a single pair of input/label image (x,y) and returns
            the transformed images (xt, yt) for the purpose of data augmentation
            during training. Not applied to validation images.
            Example:
            def simple_augmenter(x,y):
                x = x + 0.05*np.random.normal(0,1,x.shape)
                return x,y
        seed : int
            Convenience to set ``np.random.seed(seed)``. (To obtain reproducible validation patches, etc.)
        epochs : int
            Optional argument to use instead of the value from ``config``.
        steps_per_epoch : int
            Optional argument to use instead of the value from ``config``.
        workers : int
            Number of workers passed to the Keras fit function; multiprocessing
            is enabled if ``workers > 1``.

        Returns
        -------
        ``History`` object
            See `Keras training history <https://keras.io/models/model/#fit>`_.

        Raises
        ------
        ValueError
            If ``validation_data`` has the wrong type or length, or if the training
            patch size is not divisible as required by the network.
        """
        if seed is not None:
            np.random.seed(seed)
        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        classes = self._parse_classes_arg(classes, len(X))

        if not self._is_multiclass() and classes is not None:
            warnings.warn("Ignoring given classes as n_classes is set to None")

        isinstance(validation_data,(list,tuple)) or _raise(ValueError())
        if self._is_multiclass() and len(validation_data) == 2:
            validation_data = tuple(validation_data) + ('auto',)
        ((len(validation_data) == (3 if self._is_multiclass() else 2))
            or _raise(ValueError(f'len(validation_data) = {len(validation_data)}, but should be {3 if self._is_multiclass() else 2}')))

        # the (possibly border-cropped) patch must be divisible by the network's total downsampling
        patch_size = self.config.train_patch_size
        axes = self.config.axes.replace('C','')
        b = self.config.train_completion_crop if self.config.train_shape_completion else 0
        div_by = self._axes_div_by(axes)
        for p,d,a in zip(patch_size,div_by,axes):
            if (p-2*b) % d != 0:
                if self.config.train_shape_completion:
                    msg = "'train_patch_size' - 2*'train_completion_crop' must be divisible by {d} along axis '{a}'".format(a=a,d=d)
                else:
                    msg = "'train_patch_size' must be divisible by {d} along axis '{a}'".format(a=a,d=d)
                raise ValueError(msg)

        if not self._model_prepared:
            self.prepare_for_training()

        # arguments shared by the training and validation data generators
        data_kwargs = dict (
            n_rays           = self.config.n_rays,
            patch_size       = self.config.train_patch_size,
            grid             = self.config.grid,
            shape_completion = self.config.train_shape_completion,
            b                = self.config.train_completion_crop,
            use_gpu          = self.config.use_gpu,
            foreground_prob  = self.config.train_foreground_only,
            n_classes        = self.config.n_classes,
            sample_ind_cache = self.config.train_sample_cache,
        )

        # validation data is sampled once (single batch) and kept fixed during training
        n_data_val = len(validation_data[0])
        classes_val = self._parse_classes_arg(validation_data[2], n_data_val) if self._is_multiclass() else None
        n_take = self.config.train_n_val_patches if self.config.train_n_val_patches is not None else n_data_val
        _data_val = StarDistData2D(validation_data[0],validation_data[1], classes=classes_val, batch_size=n_take, length=1, **data_kwargs)
        data_val = _data_val[0]

        self.data_train = StarDistData2D(X, Y, classes=classes, batch_size=self.config.train_batch_size,
                                         augmenter=augmenter, length=epochs*steps_per_epoch, **data_kwargs)

        if self.config.train_tensorboard:
            # show up to three equally spaced dist channels (and class channels) in TensorBoard
            _n = min(3, self.config.n_rays)
            channel = axes_dict(self.config.axes)['C']
            output_slices = [[slice(None)]*4,[slice(None)]*4]
            output_slices[1][1+channel] = slice(0,(self.config.n_rays//_n)*_n, self.config.n_rays//_n)
            if self._is_multiclass():
                _n = min(3, self.config.n_classes)
                output_slices += [[slice(None)]*4]
                # start at 1 to skip the background channel
                output_slices[2][1+channel] = slice(1,1+(self.config.n_classes//_n)*_n, self.config.n_classes//_n)

            if IS_TF_1:
                for cb in self.callbacks:
                    if isinstance(cb,CARETensorBoard):
                        cb.output_slices = output_slices
                        # the dist target also contains the dist_mask channel and
                        # therefore has more channels than the dist output
                        cb.output_target_shapes = [None,[None]*4,None]
                        cb.output_target_shapes[1][1+channel] = data_val[1][1].shape[1+channel]
            elif self.basedir is not None and not any(isinstance(cb,CARETensorBoardImage) for cb in self.callbacks):
                self.callbacks.append(CARETensorBoardImage(model=self.keras_model, data=data_val, log_dir=str(self.logdir/'logs'/'images'),
                                                           n_images=3, prob_out=False, output_slices=output_slices))

        fit = self.keras_model.fit_generator if IS_TF_1 else self.keras_model.fit
        history = fit(iter(self.data_train), validation_data=data_val,
                      epochs=epochs, steps_per_epoch=steps_per_epoch,
                      workers=workers, use_multiprocessing=workers>1,
                      callbacks=self.callbacks, verbose=1,
                      # use the training batch size for validation (only supported for tf >= 2.2)
                      **(dict(validation_batch_size = self.config.train_batch_size) if _tf_version_at_least("2.2.0") else {}))
        self._training_finished()

        return history

    def _instances_from_prediction(self, img_shape, prob, dist, points=None, prob_class=None, prob_thresh=None, nms_thresh=None, overlap_label=None, return_labels=True, scale=None, **nms_kwargs):
        """Convert network outputs into object instances via non-maximum suppression.

        if points is None     -> dense prediction
        if points is not None -> sparse prediction

        if prob_class is None     -> single class prediction
        if prob_class is not None -> multi class prediction

        Parameters
        ----------
        img_shape : tuple of int
            Shape of the label image to render (passed to ``polygons_to_label``).
        prob, dist : array
            Object probabilities and radial distances (dense maps, or per-point
            values if ``points`` is given).
        points : array or None
            Candidate points for sparse prediction. In this branch
            ``non_maximum_suppression_sparse`` is used and ``prob_thresh`` is not applied.
        prob_class : array or None
            Class probabilities; the entries of the objects kept by NMS are returned.
        prob_thresh, nms_thresh : float or None
            Thresholds; ``None`` falls back to ``self.thresholds``.
        overlap_label : None
            Not supported in 2D; any other value raises ``NotImplementedError``.
        return_labels : bool
            If False, the label image is not rendered and ``None`` is returned for it.
        scale : dict or None
            Dict with entries 'X' and 'Y'; points are multiplied by
            ``(1/scale['Y'], 1/scale['X'])`` and the same factors are passed as
            ``scale_dist`` when creating labels and coordinates.
        nms_kwargs : dict
            Additional keyword arguments for the NMS function.

        Returns
        -------
        (labels, res_dict)
            ``labels`` is the label image (or None). ``res_dict`` contains ``coord``,
            ``points`` and ``prob``, plus ``class_prob`` and ``class_id`` (argmax over
            the last axis of ``class_prob``) for multi-class prediction.
        """
        if prob_thresh is None: prob_thresh = self.thresholds.prob
        if nms_thresh  is None: nms_thresh  = self.thresholds.nms
        if overlap_label is not None: raise NotImplementedError("overlap_label not supported for 2D yet!")

        if points is not None:
            # sparse prediction
            points, probi, disti, indsi = non_maximum_suppression_sparse(dist, prob, points, nms_thresh=nms_thresh, **nms_kwargs)
            if prob_class is not None:
                prob_class = prob_class[indsi]
        else:
            # dense prediction
            points, probi, disti = non_maximum_suppression(dist, prob, grid=self.config.grid,
                                                           prob_thresh=prob_thresh, nms_thresh=nms_thresh, **nms_kwargs)
            if prob_class is not None:
                # map pixel coordinates back to indices of the subsampled grid
                inds = tuple(p//g for p,g in zip(points.T, self.config.grid))
                prob_class = prob_class[inds]

        if scale is not None:
            if not (isinstance(scale,dict) and 'X' in scale and 'Y' in scale):
                raise ValueError("scale must be a dictionary with entries for 'X' and 'Y'")
            rescale = (1/scale['Y'],1/scale['X'])
            points = points * np.array(rescale).reshape(1,2)
        else:
            rescale = (1,1)

        if return_labels:
            labels = polygons_to_label(disti, points, prob=probi, shape=img_shape, scale_dist=rescale)
        else:
            labels = None

        coord = dist_to_coord(disti, points, scale_dist=rescale)
        res_dict = dict(coord=coord, points=points, prob=probi)

        # multi class prediction
        if prob_class is not None:
            prob_class = np.asarray(prob_class)
            class_id = np.argmax(prob_class, axis=-1)
            res_dict.update(dict(class_prob=prob_class, class_id=class_id))

        return labels, res_dict

    def _axes_div_by(self, query_axes):
        """Return, per axis in ``query_axes``, the factor its size must be divisible by.

        For each spatial axis of the config this is ``unet_pool**unet_n_depth * grid``;
        all other axes get 1.

        Raises
        ------
        ValueError
            If ``unet_pool`` and ``grid`` have different lengths.
        """
        self.config.backbone == 'unet' or _raise(NotImplementedError())
        query_axes = axes_check_and_normalize(query_axes)
        # explicit check instead of `assert`, which is removed under `python -O`
        len(self.config.unet_pool) == len(self.config.grid) or _raise(ValueError(
            "unet_pool {} and grid {} must have the same length".format(self.config.unet_pool, self.config.grid)))
        div_by = dict(zip(
            self.config.axes.replace('C',''),
            tuple(p**self.config.unet_n_depth * g for p,g in zip(self.config.unet_pool,self.config.grid))
        ))
        return tuple(div_by.get(a,1) for a in query_axes)

    @property
    def _config_class(self):
        return Config2D
|
|