from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import warnings
import math
from tqdm import tqdm
from csbdeep.models import BaseConfig
from csbdeep.internals.blocks import unet_block
from csbdeep.utils import _raise, backend_channels_last, axes_check_and_normalize, axes_dict
from csbdeep.utils.tf import keras_import, IS_TF_1, CARETensorBoard, CARETensorBoardImage
from skimage.segmentation import clear_border
from skimage.measure import regionprops
from scipy.ndimage import zoom
from distutils.version import LooseVersion
keras = keras_import()
K = keras_import('backend')
Input, Conv2D, MaxPooling2D = keras_import('layers', 'Input', 'Conv2D', 'MaxPooling2D')
Model = keras_import('models', 'Model')
from .base import StarDistBase, StarDistDataBase, _tf_version_at_least
from ..sample_patches import sample_patches
from ..utils import edt_prob, _normalize_grid, mask_to_categorical
from ..geometry import star_dist, dist_to_coord, polygons_to_label
from ..nms import non_maximum_suppression, non_maximum_suppression_sparse
class StarDistData2D(StarDistDataBase):
def __init__(self, X, Y, batch_size, n_rays, length,
n_classes=None, classes=None,
patch_size=(256,256), b=32, grid=(1,1), shape_completion=False, augmenter=None, foreground_prob=0, **kwargs):
super().__init__(X=X, Y=Y, n_rays=n_rays, grid=grid,
n_classes=n_classes, classes=classes,
batch_size=batch_size, patch_size=patch_size, length=length,
augmenter=augmenter, foreground_prob=foreground_prob, **kwargs)
self.shape_completion = bool(shape_completion)
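# With shape completion, inputs/targets are cropped by b pixels per side below;
# star distances are computed on the uncropped labels first (cf. __getitem__), so
# objects reaching past the crop still contribute their full boundary.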
if self.shape_completion and b > 0:
self.b = slice(b,-b),slice(b,-b)
else:
self.b = slice(None),slice(None)
self.sd_mode = 'opencl' if self.use_gpu else 'cpp'
def __getitem__(self, i):
idx = self.batch(i)
arrays = [sample_patches((self.Y[k],) + self.channels_as_tuple(self.X[k]),
patch_size=self.patch_size, n_samples=1,
valid_inds=self.get_valid_inds(k)) for k in idx]
if self.n_channel is None:
X, Y = list(zip(*[(x[0][self.b],y[0]) for y,x in arrays]))
else:
X, Y = list(zip(*[(np.stack([_x[0] for _x in x],axis=-1)[self.b], y[0]) for y,*x in arrays]))
X, Y = tuple(zip(*tuple(self.augmenter(_x, _y) for _x, _y in zip(X,Y))))
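# object probability target: normalized Euclidean distance transform of each
# labeled region, evaluated on the grid-subsampled raster (self.ss_grid)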
prob = np.stack([edt_prob(lbl[self.b][self.ss_grid[1:3]]) for lbl in Y])
# prob = np.stack([edt_prob(lbl[self.b]) for lbl in Y])
# prob = prob[self.ss_grid]
if self.shape_completion:
Y_cleared = [clear_border(lbl) for lbl in Y]
_dist = np.stack([star_dist(lbl,self.n_rays,mode=self.sd_mode)[self.b+(slice(None),)] for lbl in Y_cleared])
dist = _dist[self.ss_grid]
dist_mask = np.stack([edt_prob(lbl[self.b][self.ss_grid[1:3]]) for lbl in Y_cleared])
else:
# directly subsample with grid
dist = np.stack([star_dist(lbl,self.n_rays,mode=self.sd_mode, grid=self.grid) for lbl in Y])
dist_mask = prob
X = np.stack(X)
if X.ndim == 3: # input image has no channel axis
X = np.expand_dims(X,-1)
prob = np.expand_dims(prob,-1)
dist_mask = np.expand_dims(dist_mask,-1)
# subsample with given grid
# dist_mask = dist_mask[self.ss_grid]
# prob = prob[self.ss_grid]
# append dist_mask to dist as additional channel
# dist_and_mask = np.concatenate([dist,dist_mask],axis=-1)
# faster than concatenate
dist_and_mask = np.empty(dist.shape[:-1]+(self.n_rays+1,), np.float32)
dist_and_mask[...,:-1] = dist
dist_and_mask[...,-1:] = dist_mask
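# illustrative shapes at this point, assuming batch_size=4, patch_size=(256,256),
# grid=(2,2), n_rays=32 and no shape completion:
#   X: (4,256,256,n_channel), prob: (4,128,128,1), dist_and_mask: (4,128,128,33)
# i.e. all targets live on the subsampled grid, and the last channel of
# 'dist_and_mask' carries the foreground mask used to weight the distance loss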
if self.n_classes is None:
return [X], [prob,dist_and_mask]
else:
prob_class = np.stack(tuple((mask_to_categorical(y, self.n_classes, self.classes[k]) for y,k in zip(Y, idx))))
# TODO: investigate downsampling via simple indexing vs. using 'zoom'
# prob_class = prob_class[self.ss_grid]
# 'zoom' might lead to better registered maps (especially if upscaled later)
prob_class = zoom(prob_class, (1,)+tuple(1/g for g in self.grid)+(1,), order=0)
return [X], [prob,dist_and_mask, prob_class]
class Config2D(BaseConfig):
"""Configuration for a :class:`StarDist2D` model.
Parameters
----------
axes : str or None
Axes of the input images.
n_rays : int
Number of radial directions for the star-convex polygon.
Recommended to use a power of 2 (default: 32).
n_channel_in : int
Number of channels of given input image (default: 1).
grid : (int,int)
Subsampling factors (must be powers of 2) for each of the axes.
Model will predict on a subsampled grid for increased efficiency and larger field of view.
n_classes : None or int
Number of object classes to use for multi-class prediction (use None to disable).
backbone : str
Name of the neural network architecture to be used as backbone.
kwargs : dict
Overwrite (or add) configuration attributes (see below).

Attributes
----------
unet_n_depth : int
Number of U-Net resolution levels (down/up-sampling layers).
unet_kernel_size : (int,int)
Convolution kernel size for all (U-Net) convolution layers.
unet_n_filter_base : int
Number of convolution kernels (feature channels) for first U-Net layer.
Doubled after each down-sampling layer.
unet_pool : (int,int)
Maxpooling size for all (U-Net) convolution layers.
net_conv_after_unet : int
Number of filters of the extra convolution layer after U-Net (0 to disable).
unet_* : *
Additional parameters for the U-Net backbone.
train_shape_completion : bool
Train model to predict complete shapes for partially visible objects at image boundary.
train_completion_crop : int
If 'train_shape_completion' is set to True, specify number of pixels to crop at boundary of training patches.
Should be chosen based on (largest) object sizes.
train_patch_size : (int,int)
Size of patches to be cropped from provided training images.
train_background_reg : float
Regularizer to encourage distance predictions on background regions to be 0.
train_foreground_only : float
Fraction (0..1) of patches that are sampled only from regions containing foreground pixels.
train_sample_cache : bool
Activate caching of valid patch regions for all training images (disable to save memory for large datasets)
train_dist_loss : str
Training loss for star-convex polygon distances ('mse' or 'mae').
train_loss_weights : tuple of float
Weights for the losses relating to (probability, distance), with a third weight for class probability if n_classes is set.
train_epochs : int
Number of training epochs.
train_steps_per_epoch : int
Number of parameter update steps per epoch.
train_learning_rate : float
Learning rate for training.
train_batch_size : int
Batch size for training.
train_n_val_patches : int
Number of patches to be extracted from validation images (``None`` = one patch per image).
train_tensorboard : bool
Enable TensorBoard for monitoring training progress.
train_reduce_lr : dict
Parameter :class:`dict` of ReduceLROnPlateau_ callback; set to ``None`` to disable.
use_gpu : bool
Indicate that the data generator should use OpenCL to do computations on the GPU.

.. _ReduceLROnPlateau: https://keras.io/api/callbacks/reduce_lr_on_plateau/
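
Examples
--------
A minimal sketch (illustrative values): a 2-channel input, predictions on a
2x-subsampled grid, and a default attribute overwritten via kwargs.

>>> conf = Config2D(n_rays=32, grid=(2,2), n_channel_in=2, train_patch_size=(512,512))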
"""
def __init__(self, axes='YX', n_rays=32, n_channel_in=1, grid=(1,1), n_classes=None, backbone='unet', **kwargs):
"""See class docstring."""
super().__init__(axes=axes, n_channel_in=n_channel_in, n_channel_out=1+n_rays)
# directly set by parameters
self.n_rays = int(n_rays)
self.grid = _normalize_grid(grid,2)
self.backbone = str(backbone).lower()
self.n_classes = None if n_classes is None else int(n_classes)
# default config (can be overwritten by kwargs below)
if self.backbone == 'unet':
self.unet_n_depth = 3
self.unet_kernel_size = 3,3
self.unet_n_filter_base = 32
self.unet_n_conv_per_depth = 2
self.unet_pool = 2,2
self.unet_activation = 'relu'
self.unet_last_activation = 'relu'
self.unet_batch_norm = False
self.unet_dropout = 0.0
self.unet_prefix = ''
self.net_conv_after_unet = 128
else:
# TODO: resnet backbone for 2D model?
raise ValueError("backbone '%s' not supported." % self.backbone)
# net_mask_shape not needed but kept for legacy reasons
if backend_channels_last():
self.net_input_shape = None,None,self.n_channel_in
self.net_mask_shape = None,None,1
else:
self.net_input_shape = self.n_channel_in,None,None
self.net_mask_shape = 1,None,None
self.train_shape_completion = False
self.train_completion_crop = 32
self.train_patch_size = 256,256
self.train_background_reg = 1e-4
self.train_foreground_only = 0.9
self.train_sample_cache = True
self.train_dist_loss = 'mae'
self.train_loss_weights = (1,0.2) if self.n_classes is None else (1,0.2,1)
self.train_class_weights = (1,1) if self.n_classes is None else (1,)*(self.n_classes+1)
self.train_epochs = 400
self.train_steps_per_epoch = 100
self.train_learning_rate = 0.0003
self.train_batch_size = 4
self.train_n_val_patches = None
self.train_tensorboard = True
# the parameter 'min_delta' was called 'epsilon' for keras<=2.1.5
min_delta_key = 'epsilon' if LooseVersion(keras.__version__)<=LooseVersion('2.1.5') else 'min_delta'
self.train_reduce_lr = {'factor': 0.5, 'patience': 40, min_delta_key: 0}
self.use_gpu = False
# remove derived attributes that shouldn't be overwritten
for k in ('n_dim', 'n_channel_out'):
try: del kwargs[k]
except KeyError: pass
self.update_parameters(False, **kwargs)
# FIXME: put into is_valid()
if not len(self.train_loss_weights) == (2 if self.n_classes is None else 3):
raise ValueError(f"train_loss_weights {self.train_loss_weights} not compatible with n_classes ({self.n_classes}): must be 3 weights if n_classes is not None, otherwise 2")
if not len(self.train_class_weights) == (2 if self.n_classes is None else self.n_classes+1):
raise ValueError(f"train_class_weights {self.train_class_weights} not compatible with n_classes ({self.n_classes}): must be 'n_classes + 1' weights if n_classes is not None, otherwise 2")
class StarDist2D(StarDistBase):
"""StarDist2D model.
Parameters
----------
config : :class:`Config` or None
Will be saved to disk as JSON (``config.json``).
If set to ``None``, will be loaded from disk (must exist).
name : str or None
Model name. Uses a timestamp if set to ``None`` (default).
basedir : str
Directory that contains (or will contain) a folder with the given model name.

Raises
------
FileNotFoundError
If ``config=None`` and config cannot be loaded from disk.
ValueError
Illegal arguments, including invalid configuration.

Attributes
----------
config : :class:`Config`
Configuration, as provided during instantiation.
keras_model : `Keras model <https://keras.io/getting-started/functional-api-guide/>`_
Keras neural network model.
name : str
Model name.
logdir : :class:`pathlib.Path`
Path to model folder (which stores configuration, weights, etc.)
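
Examples
--------
A sketch of typical usage, assuming a normalized 2D image ``img``:

>>> # create a new model (to be trained) ...
>>> model = StarDist2D(Config2D(n_rays=32), name='demo', basedir='models')
>>> # ... or load an existing one from disk, then predict instances
>>> model = StarDist2D(None, name='demo', basedir='models')
>>> labels, details = model.predict_instances(img)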
"""
def __init__(self, config=Config2D(), name=None, basedir='.'):
"""See class docstring."""
super().__init__(config, name=name, basedir=basedir)
def _build(self):
self.config.backbone == 'unet' or _raise(NotImplementedError())
unet_kwargs = {k[len('unet_'):]:v for (k,v) in vars(self.config).items() if k.startswith('unet_')}
input_img = Input(self.config.net_input_shape, name='input')
# maxpool input image to grid size
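# e.g. grid=(2,2) inserts one conv block + 2x2 maxpool in front of the U-Net,
# grid=(4,4) two such rounds; output maps then live on the subsampled grid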
pooled = np.array([1,1])
pooled_img = input_img
while tuple(pooled) != tuple(self.config.grid):
pool = 1 + (np.asarray(self.config.grid) > pooled)
pooled *= pool
for _ in range(self.config.unet_n_conv_per_depth):
pooled_img = Conv2D(self.config.unet_n_filter_base, self.config.unet_kernel_size,
padding='same', activation=self.config.unet_activation)(pooled_img)
pooled_img = MaxPooling2D(pool)(pooled_img)
unet_base = unet_block(**unet_kwargs)(pooled_img)
if self.config.net_conv_after_unet > 0:
unet = Conv2D(self.config.net_conv_after_unet, self.config.unet_kernel_size,
name='features', padding='same', activation=self.config.unet_activation)(unet_base)
else:
unet = unet_base
output_prob = Conv2D( 1, (1,1), name='prob', padding='same', activation='sigmoid')(unet)
output_dist = Conv2D(self.config.n_rays, (1,1), name='dist', padding='same', activation='linear')(unet)
# attach extra classification head when self.n_classes is given
if self._is_multiclass():
if self.config.net_conv_after_unet > 0:
unet_class = Conv2D(self.config.net_conv_after_unet, self.config.unet_kernel_size,
name='features_class', padding='same', activation=self.config.unet_activation)(unet_base)
else:
unet_class = unet_base
output_prob_class = Conv2D(self.config.n_classes+1, (1,1), name='prob_class', padding='same', activation='softmax')(unet_class)
return Model([input_img], [output_prob,output_dist,output_prob_class])
else:
return Model([input_img], [output_prob,output_dist])
def train(self, X, Y, validation_data, classes='auto', augmenter=None, seed=None, epochs=None, steps_per_epoch=None, workers=1):
"""Train the neural network with the given data.
Parameters
----------
X : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
Input images
Y : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
Label masks
classes (optional) : 'auto' or iterable of same length as X
Label id -> class id mapping for each label mask of Y, used if multi-class
prediction is activated (n_classes > 0): either a list of dicts mapping
label id -> class id (1,...,n_classes), or 'auto' to assign all objects to the
first non-background class. Ignored if config.n_classes is None.
validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`) or triple (if multiclass)
Tuple (triple if multiclass) of X,Y,[classes] validation data.
augmenter : None or callable
Function with expected signature ``xt, yt = augmenter(x, y)``
that takes in a single pair of input/label image (x,y) and returns
the transformed images (xt, yt) for the purpose of data augmentation
during training. Not applied to validation images.
Example::

    def simple_augmenter(x, y):
        x = x + 0.05*np.random.normal(0, 1, x.shape)
        return x, y
seed : int
Convenience to set ``np.random.seed(seed)``. (To obtain reproducible validation patches, etc.)
epochs : int
Optional argument to use instead of the value from ``config``.
steps_per_epoch : int
Optional argument to use instead of the value from ``config``.

Returns
-------
``History`` object
See `Keras training history <https://keras.io/models/model/#fit>`_.
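
Example (a sketch, assuming lists of images ``X`` and label masks ``Y``, with
the last two pairs held out for validation)::

    history = model.train(X[:-2], Y[:-2],
                          validation_data=(X[-2:], Y[-2:]),
                          augmenter=simple_augmenter)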
"""
if seed is not None:
# https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
np.random.seed(seed)
if epochs is None:
epochs = self.config.train_epochs
if steps_per_epoch is None:
steps_per_epoch = self.config.train_steps_per_epoch
classes = self._parse_classes_arg(classes, len(X))
if not self._is_multiclass() and classes is not None:
warnings.warn("Ignoring given classes as n_classes is set to None")
isinstance(validation_data,(list,tuple)) or _raise(ValueError())
if self._is_multiclass() and len(validation_data) == 2:
validation_data = tuple(validation_data) + ('auto',)
((len(validation_data) == (3 if self._is_multiclass() else 2))
or _raise(ValueError(f'len(validation_data) = {len(validation_data)}, but should be {3 if self._is_multiclass() else 2}')))
patch_size = self.config.train_patch_size
axes = self.config.axes.replace('C','')
b = self.config.train_completion_crop if self.config.train_shape_completion else 0
div_by = self._axes_div_by(axes)
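# e.g. with unet_n_depth=3, unet_pool=(2,2) and grid=(2,2), div_by = (16,16)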
[(p-2*b) % d == 0 or _raise(ValueError(
"'train_patch_size' - 2*'train_completion_crop' must be divisible by {d} along axis '{a}'".format(a=a,d=d) if self.config.train_shape_completion else
"'train_patch_size' must be divisible by {d} along axis '{a}'".format(a=a,d=d)
)) for p,d,a in zip(patch_size,div_by,axes)]
if not self._model_prepared:
self.prepare_for_training()
data_kwargs = dict (
n_rays = self.config.n_rays,
patch_size = self.config.train_patch_size,
grid = self.config.grid,
shape_completion = self.config.train_shape_completion,
b = self.config.train_completion_crop,
use_gpu = self.config.use_gpu,
foreground_prob = self.config.train_foreground_only,
n_classes = self.config.n_classes,
sample_ind_cache = self.config.train_sample_cache,
)
# generate validation data and store in numpy arrays
n_data_val = len(validation_data[0])
classes_val = self._parse_classes_arg(validation_data[2], n_data_val) if self._is_multiclass() else None
n_take = self.config.train_n_val_patches if self.config.train_n_val_patches is not None else n_data_val
_data_val = StarDistData2D(validation_data[0],validation_data[1], classes=classes_val, batch_size=n_take, length=1, **data_kwargs)
data_val = _data_val[0]
# expose data generator as member for general diagnostics
self.data_train = StarDistData2D(X, Y, classes=classes, batch_size=self.config.train_batch_size,
augmenter=augmenter, length=epochs*steps_per_epoch, **data_kwargs)
if self.config.train_tensorboard:
# show dist for three rays
_n = min(3, self.config.n_rays)
channel = axes_dict(self.config.axes)['C']
output_slices = [[slice(None)]*4,[slice(None)]*4]
output_slices[1][1+channel] = slice(0,(self.config.n_rays//_n)*_n, self.config.n_rays//_n)
if self._is_multiclass():
_n = min(3, self.config.n_classes)
output_slices += [[slice(None)]*4]
output_slices[2][1+channel] = slice(1,1+(self.config.n_classes//_n)*_n, self.config.n_classes//_n)
if IS_TF_1:
for cb in self.callbacks:
if isinstance(cb,CARETensorBoard):
cb.output_slices = output_slices
# target image for dist includes dist_mask and thus has more channels than dist output
cb.output_target_shapes = [None,[None]*4,None]
cb.output_target_shapes[1][1+channel] = data_val[1][1].shape[1+channel]
elif self.basedir is not None and not any(isinstance(cb,CARETensorBoardImage) for cb in self.callbacks):
self.callbacks.append(CARETensorBoardImage(model=self.keras_model, data=data_val, log_dir=str(self.logdir/'logs'/'images'),
n_images=3, prob_out=False, output_slices=output_slices))
fit = self.keras_model.fit_generator if IS_TF_1 else self.keras_model.fit
history = fit(iter(self.data_train), validation_data=data_val,
epochs=epochs, steps_per_epoch=steps_per_epoch,
workers=workers, use_multiprocessing=workers>1,
callbacks=self.callbacks, verbose=1,
# set validation batch size to training batch size (only works for tf >= 2.2)
**(dict(validation_batch_size = self.config.train_batch_size) if _tf_version_at_least("2.2.0") else {}))
self._training_finished()
return history
# def _instances_from_prediction_old(self, img_shape, prob, dist,points = None, prob_class = None, prob_thresh=None, nms_thresh=None, overlap_label = None, **nms_kwargs):
# from stardist.geometry.geom2d import _polygons_to_label_old, _dist_to_coord_old
# from stardist.nms import _non_maximum_suppression_old
# if prob_thresh is None: prob_thresh = self.thresholds.prob
# if nms_thresh is None: nms_thresh = self.thresholds.nms
# if overlap_label is not None: raise NotImplementedError("overlap_label not supported for 2D yet!")
# coord = _dist_to_coord_old(dist, grid=self.config.grid)
# inds = _non_maximum_suppression_old(coord, prob, grid=self.config.grid,
# prob_thresh=prob_thresh, nms_thresh=nms_thresh, **nms_kwargs)
# labels = _polygons_to_label_old(coord, prob, inds, shape=img_shape)
# # sort 'inds' such that ids in 'labels' map to entries in polygon dictionary entries
# inds = inds[np.argsort(prob[inds[:,0],inds[:,1]])]
# # adjust for grid
# points = inds*np.array(self.config.grid)
# res_dict = dict(coord=coord[inds[:,0],inds[:,1]], points=points, prob=prob[inds[:,0],inds[:,1]])
# if prob_class is not None:
# prob_class = np.asarray(prob_class)
# res_dict.update(dict(class_prob = prob_class))
# return labels, res_dict
def _instances_from_prediction(self, img_shape, prob, dist, points=None, prob_class=None, prob_thresh=None, nms_thresh=None, overlap_label=None, return_labels=True, scale=None, **nms_kwargs):
"""
if points is None -> dense prediction
if points is not None -> sparse prediction
if prob_class is None -> single class prediction
if prob_class is not None -> multi class prediction
"""
if prob_thresh is None: prob_thresh = self.thresholds.prob
if nms_thresh is None: nms_thresh = self.thresholds.nms
if overlap_label is not None: raise NotImplementedError("overlap_label not supported for 2D yet!")
# sparse prediction
if points is not None:
points, probi, disti, indsi = non_maximum_suppression_sparse(dist, prob, points, nms_thresh=nms_thresh, **nms_kwargs)
if prob_class is not None:
prob_class = prob_class[indsi]
# dense prediction
else:
points, probi, disti = non_maximum_suppression(dist, prob, grid=self.config.grid,
prob_thresh=prob_thresh, nms_thresh=nms_thresh, **nms_kwargs)
if prob_class is not None:
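# 'points' are in full-resolution pixel coordinates, whereas 'prob_class' lives
# on the grid-subsampled raster -> convert by integer division with the grid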
inds = tuple(p//g for p,g in zip(points.T, self.config.grid))
prob_class = prob_class[inds]
if scale is not None:
# need to undo the scaling given by the scale dict, e.g. scale = dict(X=0.5,Y=0.5):
# 1. re-scale points (origins of polygons)
# 2. re-scale coordinates (computed from distances) of (zero-origin) polygons
if not (isinstance(scale,dict) and 'X' in scale and 'Y' in scale):
raise ValueError("scale must be a dictionary with entries for 'X' and 'Y'")
rescale = (1/scale['Y'],1/scale['X'])
points = points * np.array(rescale).reshape(1,2)
else:
rescale = (1,1)
if return_labels:
labels = polygons_to_label(disti, points, prob=probi, shape=img_shape, scale_dist=rescale)
else:
labels = None
coord = dist_to_coord(disti, points, scale_dist=rescale)
res_dict = dict(coord=coord, points=points, prob=probi)
# multi class prediction
if prob_class is not None:
prob_class = np.asarray(prob_class)
class_id = np.argmax(prob_class, axis=-1)
res_dict.update(dict(class_prob=prob_class, class_id=class_id))
return labels, res_dict
def _axes_div_by(self, query_axes):
self.config.backbone == 'unet' or _raise(NotImplementedError())
query_axes = axes_check_and_normalize(query_axes)
assert len(self.config.unet_pool) == len(self.config.grid)
div_by = dict(zip(
self.config.axes.replace('C',''),
tuple(p**self.config.unet_n_depth * g for p,g in zip(self.config.unet_pool,self.config.grid))
))
return tuple(div_by.get(a,1) for a in query_axes)
# def _axes_tile_overlap(self, query_axes):
# self.config.backbone == 'unet' or _raise(NotImplementedError())
# query_axes = axes_check_and_normalize(query_axes)
# assert len(self.config.unet_pool) == len(self.config.grid) == len(self.config.unet_kernel_size)
# # TODO: compute this properly when any value of grid > 1
# # all(g==1 for g in self.config.grid) or warnings.warn('FIXME')
# overlap = dict(zip(
# self.config.axes.replace('C',''),
# tuple(tile_overlap(self.config.unet_n_depth + int(np.log2(g)), k, p)
# for p,k,g in zip(self.config.unet_pool,self.config.unet_kernel_size,self.config.grid))
# ))
# return tuple(overlap.get(a,0) for a in query_axes)
@property
def _config_class(self):
return Config2D