# Source code for tfplot.contrib

'''Some predefined plot functions.'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .wrapper import autowrap


__all__ = (
    'probmap',
    'probmap_simple',
    'batch',
)


@autowrap
def probmap(x, cmap='jet', colorbar=True, vmin=None, vmax=None,
            axis=True, ax=None):
    '''
    Draw ``x`` as a color heatmap. The resulting op will be a RGBA image
    Tensor.

    Args:
      x: A 2-D image-like tensor to draw.
      cmap: Matplotlib colormap. Defaults to 'jet'.
      colorbar: If True (default), a colorbar is attached to the figure.
      vmin: A scalar, lower bound of the displayed value range.
        Passed straight to ``matplotlib.axes.Axes.imshow``.
      vmax: A scalar, upper bound of the displayed value range.
        Passed straight to ``matplotlib.axes.Axes.imshow``.
      axis: If True (default), the x-axis and y-axis are shown.
      ax: The axes object to draw on. It must not be None; judging by the
        assertion message it is expected to be supplied by ``autowrap``
        rather than by the caller.

    Returns:
      A `uint8` `Tensor` of shape ``(?, ?, 4)`` containing the resulting
      plot. (This describes the op obtained through ``autowrap``; the
      undecorated body only draws on ``ax`` and returns nothing.)
    '''
    assert ax is not None, "autowrap did not set ax"

    heatmap = ax.imshow(x, cmap=cmap, vmin=vmin, vmax=vmax)
    fig = ax.figure

    if colorbar:
        fig.colorbar(heatmap)

    if axis:
        # Axes are visible: always let matplotlib compute the layout.
        fig.tight_layout()
        return

    ax.axis('off')
    if colorbar:
        # Still need room for the colorbar next to the image.
        fig.tight_layout()
    else:
        # Nothing but the image: set all figure margins to the extremes
        # (left=0, bottom=0, right=1, top=1).
        fig.subplots_adjust(0, 0, 1, 1)
def probmap_simple(x, **kwargs):
    '''
    Draw ``x`` as a color heatmap showing only the image content. The
    resulting op will be a RGBA image Tensor.

    This is ``probmap`` with `colorbar` and `axis` turned off and a
    `figsize` of ``(3, 3)``; each of these is only a default and can be
    overridden through ``kwargs``. All remaining keyword arguments are
    forwarded unchanged. See the documentation of ``probmap`` for the
    available arguments.
    '''
    # Defaults that distinguish this variant from plain ``probmap``.
    # ``figsize`` is not in probmap's own signature, so it can only travel
    # as a keyword argument.
    simple_defaults = (
        ('colorbar', False),
        ('axis', False),
        ('figsize', (3, 3)),
    )

    forwarded = {}
    for option, fallback in simple_defaults:
        # A caller-supplied value wins over the default.
        forwarded[option] = kwargs.pop(option, fallback)
    # The three options were popped above, so this cannot overwrite them.
    forwarded.update(kwargs)

    return probmap(x, **forwarded)
def batch(func):
    '''
    Turn an autowrapped plot function (... -> RGBA tf.Tensor) into one that
    works in a batch manner.

    The original function is taken from ``func.__unwrapped__`` and wrapped
    again with ``autowrap(..., batch=True)``.

    Example:

      >>> p
      Tensor("p:0", shape=(batch_size, 16, 16, 4), dtype=uint8)
      >>> tfplot.contrib.batch(tfplot.contrib.probmap)(p)
      Tensor("probmap/PlotImages:0", shape=(batch_size, ?, ?, 4), dtype=uint8)

    Args:
      func: A function carrying an ``__unwrapped__`` attribute.

    Returns:
      The result of ``autowrap(func.__unwrapped__, batch=True)``.

    Raises:
      ValueError: If ``func`` has no ``__unwrapped__`` attribute.
    '''
    if hasattr(func, '__unwrapped__'):
        # Re-wrap the underlying plain function, this time in batch mode.
        return autowrap(func.__unwrapped__, batch=True)

    raise ValueError("The given function is not wrapped with tfplot.autowrap()!")