Source code for tfplot.summary

''' Summary Op utilities. '''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import six

from . import ops
from . import wrapper


def plot(name, plot_func, in_tensors, collections=None, **kwargs):
    """
    Create a TensorFlow op that outputs a `Summary` protocol buffer holding
    a single image summary of the plot drawn by ``plot_func``.

    Basically, it is a one-liner wrapper of ``tfplot.ops.plot()`` and
    ``tf.summary.image()`` calls.

    The generated `Summary` object contains a single image summary value,
    which is the image of the plot drawn.

    Args:
      name: The name of scope for the generated ops and the summary op.
        Will also serve as a series name prefix in TensorBoard.
      plot_func: A python function or callable, specifying the plot operation
        as in :func:`tfplot.plot`. See the documentation at :func:`tfplot.plot`.
      in_tensors: A list of `Tensor` objects, as in :func:`~tfplot.plot`.
      collections: Optional list of ``tf.GraphKeys``. The collections to add
        the summary to. Defaults to ``[tf.GraphKeys.SUMMARIES]``.
      kwargs: Optional keyword arguments passed to :func:`~tfplot.plot`.

    Returns:
      A scalar `Tensor` of type `string`. The serialized `Summary` protocol
      buffer (tensorflow operation).
    """
    with tf.name_scope(name):
        im = ops.plot(plot_func, in_tensors, **kwargs)
        # tf.summary.image takes a batch of images; wrap the single plot
        # image in a leading batch axis of size 1.
        im = tf.expand_dims(im, axis=0)
        summary = tf.summary.image(name="ImageSummary", tensor=im,
                                   collections=collections)
    return summary
def plot_many(name, plot_func, in_tensors,
              max_outputs=3, collections=None, **kwargs):
    """
    Create a TensorFlow op that outputs a `Summary` protocol buffer,
    where plots could be drawn in a batch manner.

    This is a batch version of :func:`tfplot.summary.plot`.
    Specifically, all the input tensors ``in_tensors`` to ``plot_func`` are
    assumed to have the same batch size. Tensors corresponding to a single
    batch element will be passed to ``plot_func`` as input.

    The resulting `Summary` contains multiple (up to ``max_outputs``) image
    summary values, each of which contains a plot rendered by ``plot_func``.

    Args:
      name: The name of scope for the generated ops and the summary op.
        Will also serve as a series name prefix in TensorBoard.
      plot_func: A python function or callable, specifying the plot operation
        as in :func:`tfplot.plot`. See the documentation at :func:`tfplot.plot`.
      in_tensors: A list of `Tensor` objects, the input to ``plot_func``
        but each in a batch.
      max_outputs: Max number of batch elements to generate plots for.
        It is passed both to the plot op and to ``tf.summary.image``.
      collections: Optional list of ``tf.GraphKeys``. The collections to add
        the summary to. Defaults to ``[tf.GraphKeys.SUMMARIES]``.
      kwargs: Optional keyword arguments passed to :func:`~tfplot.plot`.

    Returns:
      A scalar `Tensor` of type `string`. The serialized `Summary` protocol
      buffer (tensorflow operation).
    """
    with tf.name_scope(name=name) as scope:
        # The opened scope string is handed to ops.plot_many as its ``name``.
        # NOTE(review): presumably so its ops land under this same scope;
        # confirm against ops.plot_many.
        im_batch = ops.plot_many(plot_func, in_tensors, name=scope,
                                 max_outputs=max_outputs,
                                 **kwargs)
        summary = tf.summary.image(name="ImageSummary", tensor=im_batch,
                                   max_outputs=max_outputs,
                                   collections=collections)
    return summary
def wrap(plot_func, _sentinel=None,
         batch=False, name=None,
         **kwargs):
    '''
    Wrap a plot function as a TensorFlow summary builder. It will return a
    python function that creates a TensorFlow op which evaluates to
    ``Summary`` protocol buffer with image.

    The resulting function (say ``summary_wrapped``) will have the following
    signature:

    .. code-block:: python

        summary_wrapped(name, tensor, # [more input tensors ...],
                        max_outputs=3, collections=None)

    ``max_outputs`` and ``collections`` are consumed by the summary op
    (``tf.summary.image``). All other keyword arguments are forwarded to the
    wrapped plot op. For a ``batch=True`` wrapper, an explicitly given
    ``max_outputs`` is also forwarded to the batched plot op.

    Examples:

      Given a plot function which returns a matplotlib `Figure`,

      >>> def figure_heatmap(data, cmap='jet'):
      >>>     fig, ax = tfplot.subplots()
      >>>     ax.imshow(data, cmap=cmap)
      >>>     return fig

      we can wrap it as a summary builder function:

      >>> summary_heatmap = tfplot.summary.wrap(figure_heatmap, batch=True)

      Now, when building your computation graph, call it to build summary
      ops like ``tf.summary.image``:

      >>> heatmap_tensor
      <tf.Tensor 'heatmap_tensor:0' shape=(16, 128, 128) dtype=float32>
      >>>
      >>> summary_heatmap("heatmap/original", heatmap_tensor)
      >>> summary_heatmap("heatmap/cmap_gray", heatmap_tensor, cmap='gray')
      >>> summary_heatmap("heatmap/no_default_collections", heatmap_tensor,
      >>>                 collections=[])

    Args:
      plot_func: A python function or callable to wrap. See the documentation
        of :func:`tfplot.plot` for details.
      _sentinel: Internal guard; must not be given. It makes passing
        ``batch``/``name`` positionally an error.
      batch: If True, all the tensors passed as argument will be
        assumed to be batched. Default value is False.
      name: A default name for the plot op (optional). If not given, the
        name of ``plot_func`` will be used.
      kwargs: Optional keyword arguments that will be passed by default to
        :func:`~tfplot.plot`.

    Returns:
      A python function that will create a TensorFlow summary operation,
      passing the provided arguments into plot op.

    Raises:
      RuntimeError: if an extra positional argument is given (``_sentinel``).
    '''
    if _sentinel is not None:
        raise RuntimeError("Invalid call: it can have only one unnamed argument, " +
                           "please pass named arguments for batch, name, etc.")

    factory_fn = wrapper.autowrap(plot_func, batch=batch, name=name, **kwargs)

    def _summary_fn(summary_name, *args, **kwargs_call):
        if not isinstance(summary_name, six.string_types):
            raise TypeError("summary_name should be a string")

        # `max_outputs` and `collections` belong to the summary op, so take
        # them out *before* building the plot op. Previously they were popped
        # only after factory_fn(**kwargs_call) had already received them.
        max_outputs_given = 'max_outputs' in kwargs_call
        max_outputs = kwargs_call.pop('max_outputs', 3)
        collections = kwargs_call.pop('collections', None)
        if batch and max_outputs_given:
            # Preserve the previous behavior of handing an explicit
            # max_outputs to the batched plot op (ops.plot_many accepts it).
            kwargs_call['max_outputs'] = max_outputs

        plot_op = factory_fn(*args, **kwargs_call)
        if not batch:
            # add batch dimension expected by tf.summary.image
            plot_op = tf.expand_dims(plot_op, axis=0)

        return tf.summary.image(summary_name, plot_op,
                                max_outputs=max_outputs,
                                collections=collections,
                                )

    _summary_fn.__name__ = 'summary_fn[%s]' % plot_func
    return _summary_fn
# Public API of this module (also what `from ... import *` exposes).
__all__ = (
    'wrap',
    'plot',
    'plot_many',
)