tfplot

Wrapper functions

tfplot.autowrap(*args, **kwargs)

Wrap a function as a TensorFlow operation, similar to tfplot.wrap() (either as a decorator or with a normal function call), but with additional features such as automatically creating matplotlib figures.

  • matplotlib (fig, ax) objects are automatically created and injected if plot_func has a keyword argument named fig and/or ax. In that case, there is no need to call tfplot.subplots() manually to create the matplotlib figure/axes objects. If you need to create fig and ax manually, consider using tfplot.wrap() instead.
  • Return values of plot_func are handled automatically. If it returns nothing (None) and fig was injected automatically, the injected figure is drawn; if it returns an Axes, the associated Figure is used.

Example

>>> @tfplot.autowrap(figsize=(3, 3))
... def plot_imshow(img, *, fig, ax):
...     ax.imshow(img)
...
>>> plot_imshow(an_image_tensor)
Tensor("plot_imshow:0", shape=(?, ?, 4), dtype=uint8)

Parameters:
  • plot_func – A python function or callable to wrap. See the documentation of tfplot.plot() for details. Additionally, if this function has a parameter named fig and/or ax, new instances of Figure and/or AxesSubplot will be created and passed.
  • batch – If True, all the tensors passed as arguments will be assumed to be batched. Default value is False.
  • name – A default name for the operation (optional). If not given, the name of plot_func will be used.
  • figsize – The figure size for the figure to be created.
  • tight_layout – If True, the resulting figure will have no margins around the axes. Equivalent to calling fig.subplots_adjust(0, 0, 1, 1).
  • kwargs_default – Optional keyword arguments that will be passed by default to plot_func when executed inside a TensorFlow graph.
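
As an illustration of the fig/ax injection and return-value rules described above, here is a minimal sketch using only matplotlib, without TensorFlow. The helper name call_with_injection is hypothetical and is not part of tfplot; it assumes that a figure is created whenever fig or ax is requested, and that tight_layout is applied via fig.subplots_adjust(0, 0, 1, 1) as documented. It is a sketch of the behavior, not tfplot's actual implementation.

```python
import inspect

import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure


def call_with_injection(plot_func, *args, figsize=None, tight_layout=False, **kwargs):
    """Hypothetical helper (not tfplot API): call plot_func, injecting fig/ax
    when its signature asks for them, and normalize the result to a Figure."""
    params = inspect.signature(plot_func).parameters
    fig = None
    if "fig" in params or "ax" in params:
        fig = Figure(figsize=figsize)
        ax = fig.add_subplot(1, 1, 1)
        if tight_layout:
            fig.subplots_adjust(0, 0, 1, 1)  # no margins around the axes
        if "fig" in params:
            kwargs["fig"] = fig
        if "ax" in params:
            kwargs["ax"] = ax

    result = plot_func(*args, **kwargs)
    if result is None and fig is not None:
        return fig  # nothing returned: use the injected figure
    if isinstance(result, Axes):
        return result.figure  # an Axes was returned: use its parent Figure
    if isinstance(result, Figure):
        return result
    raise TypeError("plot_func must return None (with injection), an Axes, or a Figure")


# Usage, mirroring the plot_imshow example above with a numpy image.
def plot_imshow(img, *, fig, ax):
    ax.imshow(img)


fig = call_with_injection(plot_imshow, np.zeros((8, 8)), figsize=(3, 3))
```

Because plot_imshow returns None, the injected 3x3-inch figure is the result.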
tfplot.wrap(*args, **kwargs)

Wrap a plot function as a TensorFlow operation. It returns a python function that creates a TensorFlow plot operation, applying the given arguments as its input. It can also be used as a decorator.

For example:

>>> @tfplot.wrap
... def plot_imshow(img):
...     fig, ax = tfplot.subplots()
...     ax.imshow(img)
...     return fig
...
>>> plot_imshow(an_image_tensor)
Tensor("plot_imshow:0", shape=(?, ?, 4), dtype=uint8)

Or, if plot_func is a python function that takes numpy arrays as input, draws a plot, and returns a matplotlib Figure, we can wrap this function as a Tensor factory:

>>> tf_plot = tfplot.wrap(plot_func, name="MyPlot", batch=True)
>>> # x, y = get_batch_inputs(batch_size=4, ...)
>>> plot_x = tf_plot(x)
Tensor("MyPlot:0", shape=(4, ?, ?, 4), dtype=uint8)
>>> plot_y = tf_plot(y)
Tensor("MyPlot_1:0", shape=(4, ?, ?, 4), dtype=uint8)
Parameters:
  • plot_func – A python function or callable to wrap. See the documentation of tfplot.plot() for details.
  • batch – If True, all the tensors passed as arguments will be assumed to be batched. Default value is False.
  • name – A default name for the operation (optional). If not given, the name of plot_func will be used.
  • kwargs – Optional keyword arguments that will be passed by default to plot_func when executed inside a TensorFlow graph.
Returns:

A python function that will create a TensorFlow plot operation, passing the provided arguments.
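
Since wrap() can be applied directly as @tfplot.wrap or called as tfplot.wrap(plot_func, name=..., batch=...), the following standard-library sketch shows one common way to write such a dual-use wrapper. The function below is a hypothetical stand-in, not tfplot's code: instead of building a TensorFlow op it returns a dict describing what would be passed to plot_func, and it assumes that keyword arguments given at call time override the defaults given at wrap time.

```python
import functools


def wrap(plot_func=None, *, name=None, batch=False, **kwargs_default):
    """Hypothetical stand-in (not tfplot's code) for the decorator-or-call pattern:
    usable as @wrap, as @wrap(name=...), or as wrap(func, name=...)."""
    if plot_func is None:
        # Only options were given, e.g. @wrap(name="MyPlot"): return a decorator.
        return functools.partial(wrap, name=name, batch=batch, **kwargs_default)

    op_name = name or plot_func.__name__  # default op name: the function's name

    @functools.wraps(plot_func)
    def factory(*args, **kwargs):
        merged = {**kwargs_default, **kwargs}  # assumed: call-time kwargs win
        # A real implementation would create a TensorFlow op here; this sketch
        # just returns a description of what would be passed to plot_func.
        return {"name": op_name, "batch": batch, "args": args, "kwargs": merged}

    return factory


@wrap
def plot_imshow(img):
    pass


def plot_func(x, color="k"):
    pass


tf_plot = wrap(plot_func, name="MyPlot", batch=True, color="r")
```

Here plot_imshow("img")["name"] is "plot_imshow" (the default name), while every call to tf_plot uses the name "MyPlot" and passes color="r" unless overridden.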

tfplot.wrap_axesplot(axesplot_func, _sentinel=None, batch=False, name=None, figsize=None, tight_layout=False, **kwargs)

DEPRECATED: Use tfplot.autowrap() instead. Will be removed in the next version.

Wrap an axesplot function as a TensorFlow operation. It returns a python function that creates a TensorFlow plot operation, applying the given arguments as its input.

An axesplot function axesplot_func can be either:

  • an unbound method of the matplotlib Axes (or AxesSubplot) class, such as Axes.scatter() or Axes.text(); or
  • a simple python function that takes a named argument ax, of type Axes or AxesSubplot, on which the plot will be drawn. A good example of this family is seaborn.heatmap(ax=...).
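
To make the two accepted forms concrete, here is a matplotlib-only sketch of how a fresh single-subplot figure could be handed to either kind of axesplot function. The helper draw_axesplot and the function my_heatmap are hypothetical illustrations, not tfplot or seaborn APIs; the sketch assumes an unbound Axes method can be recognized by looking it up on the Axes class by name.

```python
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure


def draw_axesplot(axesplot_func, *args, figsize=None, **kwargs):
    """Hypothetical helper (not tfplot API): draw axesplot_func onto a new
    figure holding a single AxesSubplot and return that figure."""
    fig = Figure(figsize=figsize)
    ax = fig.add_subplot(1, 1, 1)
    name = getattr(axesplot_func, "__name__", "")
    if getattr(Axes, name, None) is axesplot_func:
        # Unbound Axes method such as Axes.scatter: ax is passed as `self`.
        axesplot_func(ax, *args, **kwargs)
    else:
        # Plain function with an `ax` parameter (the seaborn.heatmap style).
        axesplot_func(*args, ax=ax, **kwargs)
    return fig


def my_heatmap(data, *, ax, cmap="viridis"):
    # Hypothetical stand-in for a seaborn-style plotting function.
    ax.imshow(data, cmap=cmap)


fig1 = draw_axesplot(Axes.scatter, [0, 1, 2], [2, 1, 0], figsize=(4, 4))
fig2 = draw_axesplot(my_heatmap, np.eye(3), cmap="jet")
```

In the first call ax plays the role of self for Axes.scatter; in the second it is passed by keyword, just as seaborn.heatmap(ax=...) expects.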

The resulting function can be used as a Tensor factory. When the created TensorFlow plot op is executed, a new matplotlib figure consisting of a single AxesSubplot is created, and that axes object is passed as an argument to axesplot_func. For example,

>>> import seaborn as sns
>>> tf_heatmap = tfplot.wrap_axesplot(sns.heatmap, name="HeatmapPlot", figsize=(4, 4), cmap='jet')
>>> plot_op = tf_heatmap(attention_map)
Tensor("HeatmapPlot:0", shape=(?, ?, 4), dtype=uint8)

Parameters:
  • axesplot_func – An unbound method of matplotlib Axes or AxesSubplot, or a python function or callable that has an ax parameter specifying the axes to draw on.
  • batch – If True, all the tensors passed as arguments will be assumed to be batched. Default value is False.
  • name – A default name for the operation (optional). If not given, the name of axesplot_func will be used.
  • figsize – The figure size for the figure to be created.
  • tight_layout – If True, the resulting figure will have no margins around the axes. Equivalent to calling fig.subplots_adjust(0, 0, 1, 1).
  • kwargs – Optional keyword arguments that will be passed by default to axesplot_func.
Returns:

A python function that will create a TensorFlow plot operation, passing the provided arguments and a new instance of AxesSubplot to axesplot_func.

Raw Plot Ops

tfplot.plot(plot_func, in_tensors, name='Plot', **kwargs)

Create a TensorFlow op that draws a plot into an image. The resulting image is a 3-D uint8 tensor.

Given a python function plot_func, which takes numpy arrays as its inputs (the evaluations of in_tensors) and returns a matplotlib Figure object as its output, this function wraps it as a TensorFlow op. The returned figure is rendered as an RGBA image upon execution.
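
The rendering step, turning the Figure returned by plot_func into a uint8 array of shape (height, width, 4), can be sketched with matplotlib's Agg canvas. This is only an assumption about how such rendering might be done (it relies on FigureCanvasAgg.buffer_rgba(), available in current matplotlib), not a description of tfplot's internals; figure_to_rgba is a hypothetical name.

```python
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


def figure_to_rgba(fig):
    """Render a Figure into an (height, width, 4) uint8 RGBA array."""
    canvas = FigureCanvasAgg(fig)  # attach a raster (Agg) canvas
    canvas.draw()
    # copy() so the array does not alias the canvas's internal buffer
    return np.asarray(canvas.buffer_rgba()).copy()


def plot_func(values):
    # Follows the plot_func contract: numpy input, returns a *new* Figure.
    fig = Figure(figsize=(2, 1), dpi=50)  # 2x1 inches at 50 dpi -> 100x50 pixels
    fig.add_subplot(1, 1, 1).plot(values)
    return fig


image = figure_to_rgba(plot_func(np.arange(10)))
print(image.shape, image.dtype)  # (50, 100, 4) uint8
```

The pixel size follows from figsize and dpi, which is why the height and width are reported as unknown (?) in the static shape of the resulting Tensor.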

Parameters:
  • plot_func – A python function or callable that accepts numpy ndarray objects as arguments, matching the corresponding tf.Tensor objects in in_tensors. It should return a new instance of matplotlib.figure.Figure that contains the resulting plot image.
  • in_tensors – A list of tf.Tensor objects.
  • name – A name for the operation (optional).
  • kwargs – Additional keyword arguments passed to plot_func (optional).
Returns:

A single uint8 Tensor of shape (?, ?, 4), containing the plot image that plot_func computes.

tfplot.plot_many(plot_func, in_tensors, name='PlotMany', max_outputs=None, **kwargs)

A batch version of plot(). Create a TensorFlow op that draws a plot for each element of the batch. The resulting images are returned as a 4-D uint8 Tensor of shape [batch_size, height, width, 4].

Parameters:
  • plot_func – A python function or callable that accepts numpy ndarray objects as arguments, matching the corresponding tf.Tensor objects in in_tensors. It should return a new instance of matplotlib.figure.Figure that contains the resulting plot image. The generated figures for all plots should have the same shape (height, width).
  • in_tensors – A list of tf.Tensor objects.
  • name – A name for the operation (optional).
  • max_outputs – Max number of batch elements to generate plots for (optional).
  • kwargs – Additional keyword arguments passed to plot_func (optional).
Returns:

A single uint8 Tensor of shape (B, ?, ?, 4), containing the B plot images, each computed by plot_func, where B is the smaller of batch_size (the number of batch elements in each tensor of in_tensors) and max_outputs.
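
The batching behavior can be mimicked outside TensorFlow: slice the i-th element from every input, plot it, render it, and stack the images. The sketch below uses hypothetical names (plot_many_numpy, line_plot) and plain numpy/matplotlib; it is not tfplot's implementation, but it shows why every figure must have the same (height, width) and how max_outputs caps B.

```python
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


def render(fig):
    # Rasterize a Figure into an (H, W, 4) uint8 RGBA array via the Agg canvas.
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    return np.asarray(canvas.buffer_rgba()).copy()


def plot_many_numpy(plot_func, in_arrays, max_outputs=None, **kwargs):
    """Hypothetical numpy-only analogue of plot_many (not tfplot's code):
    draws one plot per batch element and stacks the images."""
    batch_size = len(in_arrays[0])
    b = batch_size if max_outputs is None else min(batch_size, max_outputs)
    images = []
    for i in range(b):
        # Pass the i-th element of every input array to plot_func.
        fig = plot_func(*[arr[i] for arr in in_arrays], **kwargs)
        images.append(render(fig))
    # np.stack needs equal shapes, hence the same-(height, width) requirement.
    return np.stack(images)  # (B, H, W, 4), uint8


def line_plot(y, color="k"):
    fig = Figure(figsize=(2, 1), dpi=50)
    fig.add_subplot(1, 1, 1).plot(y, color=color)
    return fig


batch = np.arange(40).reshape(4, 10)
out = plot_many_numpy(line_plot, [batch], max_outputs=3, color="r")
print(out.shape)  # (3, 50, 100, 4)
```

With a batch of 4 and max_outputs=3, only three images are produced, matching B = min(batch_size, max_outputs).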