tfplot
Wrapper functions

tfplot.autowrap(*args, **kwargs)

Wrap a function as a TensorFlow operation, similar to tfplot.wrap() (usable as a decorator or with a normal function call), but with additional features such as automatically creating matplotlib figures:

- (fig, ax) matplotlib objects are automatically created and injected if plot_func has a keyword argument named fig and/or ax. In that case, there is no need to call tfplot.subplots() manually to create the matplotlib figure/axes objects. If you need to create fig and ax manually, consider using tfplot.wrap() instead.
- The return value of plot_func is handled automatically. If it returns nothing (None) but fig was injected, the injected figure is drawn; if it returns an Axes, the associated Figure is used.
Example:

>>> @tfplot.autowrap(figsize=(3, 3))
... def plot_imshow(img, *, fig, ax):
...     ax.imshow(img)
...
>>> plot_imshow(an_image_tensor)
Tensor("plot_imshow:0", shape=(?, ?, 4), dtype=uint8)
Parameters:
- plot_func – A python function or callable to wrap. See the documentation of tfplot.plot() for details. Additionally, if this function has a parameter named fig and/or ax, new instances of Figure and/or AxesSubplot will be created and passed.
- batch – If True, all the tensors passed as arguments are assumed to be batched. Default value is False.
- name – A default name for the operation (optional). If not given, the name of plot_func will be used.
- figsize – The size of the figure to be created.
- tight_layout – If True, the resulting figure will have no margins around the axes. Equivalent to calling fig.subplots_adjust(0, 0, 1, 1).
- kwargs_default – Optional kwargs that will be passed by default to plot_func when executed inside a TensorFlow graph.
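The fig/ax injection and return-value rules described above can be sketched in plain matplotlib, outside any TensorFlow graph. This is a minimal sketch, not tfplot's implementation: the helper name call_with_injected_figure is hypothetical, and it assumes plot_func declares fig and/or ax as parameters visible to inspect.signature. It follows the rule literally: a None return is accepted only when fig was injected.

```python
import inspect

from matplotlib.axes import Axes
from matplotlib.figure import Figure


def call_with_injected_figure(plot_func, *args, figsize=None,
                              tight_layout=False, **kwargs):
    """Hypothetical helper mimicking the fig/ax injection of autowrap."""
    params = inspect.signature(plot_func).parameters
    fig = Figure(figsize=figsize)
    ax = fig.add_subplot(1, 1, 1)
    if tight_layout:
        # No margins around the axes (left, bottom, right, top).
        fig.subplots_adjust(0, 0, 1, 1)

    # Inject only the objects that plot_func actually asks for.
    if "fig" in params:
        kwargs["fig"] = fig
    if "ax" in params:
        kwargs["ax"] = ax

    result = plot_func(*args, **kwargs)

    # Resolve the return value to a Figure.
    if result is None and "fig" in params:
        return fig  # nothing returned, but fig was injected
    if isinstance(result, Axes):
        return result.figure  # an Axes was returned: use its Figure
    if isinstance(result, Figure):
        return result
    raise TypeError(f"cannot derive a Figure from {result!r}")


def plot_line(values, *, fig, ax):
    ax.plot(values)


fig = call_with_injected_figure(plot_line, [3, 1, 2], figsize=(3, 3))
print(type(fig).__name__, len(fig.axes[0].lines))  # Figure 1
```

Using keyword-only parameters (after *) keeps fig and ax out of the positional arguments that carry the evaluated tensors.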
tfplot.wrap(*args, **kwargs)

Wrap a plot function as a TensorFlow operation. It returns a python function that creates a TensorFlow plot operation, applying the given arguments as input. It can also be used as a decorator.
For example:

>>> @tfplot.wrap
... def plot_imshow(img):
...     fig, ax = tfplot.subplots()
...     ax.imshow(img)
...     return fig
...
>>> plot_imshow(an_image_tensor)
Tensor("plot_imshow:0", shape=(?, ?, 4), dtype=uint8)
Or, if plot_func is a python function that takes numpy arrays as input and draws a plot by returning a matplotlib Figure, we can wrap this function as a Tensor factory, such as:

>>> tf_plot = tfplot.wrap(plot_func, name="MyPlot", batch=True)
>>> # x, y = get_batch_inputs(batch_size=4, ...)
>>> plot_x = tf_plot(x)
>>> plot_x
Tensor("MyPlot:0", shape=(4, ?, ?, 4), dtype=uint8)
>>> plot_y = tf_plot(y)
>>> plot_y
Tensor("MyPlot_1:0", shape=(4, ?, ?, 4), dtype=uint8)
Parameters:
- plot_func – A python function or callable to wrap. See the documentation of tfplot.plot() for details.
- batch – If True, all the tensors passed as arguments are assumed to be batched. Default value is False.
- name – A default name for the operation (optional). If not given, the name of plot_func will be used.
- kwargs – Optional kwargs that will be passed by default to plot_func when executed inside a TensorFlow graph.

Returns: A python function that will create a TensorFlow plot operation, passing the provided arguments.
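The two call forms of tfplot.wrap() (bare decorator, or a function call with options), the default op name, and the default kwargs can be illustrated without TensorFlow. In this hypothetical sketch, wrap_sketch and DeferredPlot are illustrative names, not tfplot APIs: DeferredPlot merely stands in for a graph op by holding a name and a thunk that runs plot_func later, and the batch option is left out.

```python
import functools
from typing import Any, Callable, NamedTuple


class DeferredPlot(NamedTuple):
    """Stand-in for a TensorFlow op: a name plus a deferred computation."""
    name: str
    run: Callable[[], Any]


def wrap_sketch(plot_func=None, *, name=None, **default_kwargs):
    """Hypothetical analogue of tfplot.wrap(), usable with or without options."""
    if plot_func is None:
        # Used as @wrap_sketch(name=...): return the actual decorator.
        return lambda func: wrap_sketch(func, name=name, **default_kwargs)

    op_name = name or plot_func.__name__  # default to the function's name

    @functools.wraps(plot_func)
    def factory(*args, **kwargs):
        merged = {**default_kwargs, **kwargs}  # call-site kwargs take priority
        return DeferredPlot(op_name, lambda: plot_func(*args, **merged))

    return factory


@wrap_sketch
def describe(values, unit="px"):
    return f"{len(values)} values in {unit}"


op = describe([1, 2, 3])
print(op.name, "->", op.run())  # describe -> 3 values in px
```

Checking for plot_func is None is a common Python idiom for decorators that accept optional keyword arguments.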
tfplot.wrap_axesplot(axesplot_func, _sentinel=None, batch=False, name=None, figsize=None, tight_layout=False, **kwargs)

DEPRECATED: Use tfplot.autowrap() instead. Will be removed in the next version.

Wrap an axesplot function as a TensorFlow operation. It returns a python function that creates a TensorFlow plot operation, applying the given arguments as input.
An axesplot function axesplot_func can be either:
- an unbound method of the matplotlib Axes (or AxesSubplot) class, such as Axes.scatter() and Axes.text(), or
- a simple python function that takes the named argument ax, of type Axes or AxesSubplot, on which the plot will be drawn. Good examples of this family include seaborn.heatmap(ax=...).
The resulting function can be used as a Tensor factory. When the created TensorFlow plot op is executed, a new matplotlib figure consisting of a single AxesSubplot is created, and that axes object is passed as an argument to axesplot_func. For example:

>>> import seaborn as sns
>>> tf_heatmap = tfplot.wrap_axesplot(sns.heatmap, name="HeatmapPlot", figsize=(4, 4), cmap='jet')
>>> plot_op = tf_heatmap(attention_map)
>>> plot_op
Tensor("HeatmapPlot:0", shape=(?, ?, 4), dtype=uint8)
Parameters:
- axesplot_func – An unbound method of matplotlib Axes or AxesSubplot, or a python function or callable which has the ax parameter for specifying the axes to draw on.
- batch – If True, all the tensors passed as arguments are assumed to be batched. Default value is False.
- name – A default name for the operation (optional). If not given, the name of axesplot_func will be used.
- figsize – The size of the figure to be created.
- tight_layout – If True, the resulting figure will have no margins around the axes. Equivalent to calling fig.subplots_adjust(0, 0, 1, 1).
- kwargs – Optional kwargs that will be passed by default to axesplot_func.

Returns: A python function that will create a TensorFlow plot operation, passing the provided arguments and a new instance of AxesSubplot into axesplot_func.
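The two kinds of axesplot function differ only in how the axes object is passed: an unbound Axes method takes it as its first positional argument (Axes.scatter(ax, x, y) is equivalent to ax.scatter(x, y)), while a function like seaborn.heatmap receives it as the ax keyword. The sketch below uses a hypothetical helper, run_axesplot, that is not part of tfplot. It builds a single-subplot figure and dispatches on that distinction. Detecting an Axes method by looking up the function's __name__ on the Axes class is an assumption of this sketch, and heatmap_like is a hypothetical stand-in for seaborn.heatmap.

```python
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure


def run_axesplot(axesplot_func, *args, figsize=None, tight_layout=False,
                 **kwargs):
    """Hypothetical helper: draw axesplot_func on a fresh single-axes Figure."""
    fig = Figure(figsize=figsize)
    ax = fig.add_subplot(1, 1, 1)
    if tight_layout:
        fig.subplots_adjust(0, 0, 1, 1)

    name = getattr(axesplot_func, "__name__", "")
    if getattr(Axes, name, None) is axesplot_func:
        # Unbound Axes method, e.g. Axes.scatter: ax is the first argument.
        axesplot_func(ax, *args, **kwargs)
    else:
        # Plain function, e.g. seaborn.heatmap: ax is passed by keyword.
        axesplot_func(*args, ax=ax, **kwargs)
    return fig


def heatmap_like(data, *, ax, cmap="viridis"):
    """Hypothetical stand-in for seaborn.heatmap(data, ax=...)."""
    ax.imshow(data, cmap=cmap)


scatter_fig = run_axesplot(Axes.scatter, [1, 2, 3], [3, 1, 2], figsize=(4, 4))
heat_fig = run_axesplot(heatmap_like, np.eye(3), cmap="jet", tight_layout=True)
```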
Raw Plot Ops

tfplot.plot(plot_func, in_tensors, name='Plot', **kwargs)

Create a TensorFlow op which draws a plot in an image. The resulting image is a 3-D uint8 tensor.
Given a python function plot_func, which takes numpy arrays as its inputs (the evaluations of in_tensors) and returns a matplotlib Figure object as its output, wrap this function as a TensorFlow op. The returned figure will be rendered as an RGBA image upon execution.

Parameters:
- plot_func – A python function or callable. It accepts numpy ndarray objects as arguments that match the corresponding tf.Tensor objects in in_tensors. It should return a new instance of matplotlib.figure.Figure, which contains the resulting plot image.
- in_tensors – A list of tf.Tensor objects.
- name – A name for the operation (optional).
- kwargs – Additional keyword arguments passed to plot_func (optional).

Returns: A single uint8 Tensor of shape (?, ?, 4), containing the plot image that plot_func computes.
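One way to render a returned Figure as an RGBA image is with matplotlib's Agg canvas. The sketch below shows how to produce the (height, width, 4) uint8 array that the op's output tensor holds. Both figure_to_rgba and plot_histogram are hypothetical names. The sketch assumes a recent matplotlib in which FigureCanvasAgg.buffer_rgba() returns a buffer that numpy can wrap, and it is not necessarily how tfplot.plot renders internally.

```python
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


def figure_to_rgba(fig):
    """Render a Figure with the Agg backend into an (H, W, 4) uint8 array."""
    canvas = FigureCanvasAgg(fig)  # attaches an off-screen canvas to fig
    canvas.draw()
    # buffer_rgba() exposes the canvas pixels; copy so the array owns its data.
    return np.asarray(canvas.buffer_rgba()).copy()


def plot_histogram(values):
    """An example plot_func: numpy input, a new Figure as output."""
    fig = Figure(figsize=(4, 3), dpi=50)
    ax = fig.add_subplot(1, 1, 1)
    ax.hist(values, bins=10)
    return fig


values = np.random.default_rng(0).normal(size=100)
image = figure_to_rgba(plot_histogram(values))
print(image.shape, image.dtype)  # (150, 200, 4) uint8
```

The pixel size is figsize × dpi, so a 4×3 inch figure at 50 dpi yields a 200×150 image, stored row-major as (height, width, 4).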
tfplot.plot_many(plot_func, in_tensors, name='PlotMany', max_outputs=None, **kwargs)

A batch version of plot. Create a TensorFlow op which draws a plot for each batch element. The resulting images are given in a 4-D uint8 Tensor of shape [batch_size, height, width, 4].

Parameters:
- plot_func – A python function or callable, which accepts numpy ndarray objects as arguments that match the corresponding tf.Tensor objects in in_tensors. It should return a new instance of matplotlib.figure.Figure, which contains the resulting plot image. The generated figure for each plot should have the same shape (height, width).
- in_tensors – A list of tf.Tensor objects.
- name – A name for the operation (optional).
- max_outputs – Max number of batch elements to generate plots for (optional).
- kwargs – Additional keyword arguments passed to plot_func (optional).

Returns: A single uint8 Tensor of shape (B, ?, ?, 4), containing the B plot images, each of which is computed by plot_func, where B equals batch_size (the number of batch elements in each tensor from in_tensors) or max_outputs, whichever is smaller.
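Per-element rendering and the rule B = min(batch_size, max_outputs) can be sketched with numpy alone. In this hypothetical sketch, stack_plots and fake_render are illustrative names, not tfplot APIs. fake_render stands in for plot_func plus figure rendering and simply returns a constant (height, width, 4) uint8 image. np.stack requires all images to share one shape, which mirrors the requirement that every generated figure have the same (height, width).

```python
import numpy as np


def stack_plots(render, *batched_inputs, max_outputs=None):
    """Hypothetical analogue of plot_many: render each batch element, then stack."""
    batch_size = len(batched_inputs[0])
    if any(len(x) != batch_size for x in batched_inputs):
        raise ValueError("all inputs must share the same batch dimension")

    # B is batch_size or max_outputs, whichever is smaller.
    num_outputs = batch_size if max_outputs is None else min(batch_size, max_outputs)

    # Slice element i from every input and render it to one RGBA image.
    images = [render(*(x[i] for x in batched_inputs)) for i in range(num_outputs)]

    # np.stack raises ValueError unless all images have the same shape.
    return np.stack(images).astype(np.uint8)


def fake_render(value):
    """Stand-in for plot_func + rendering: a 2x3 RGBA image filled with value."""
    return np.full((2, 3, 4), value, dtype=np.uint8)


batch = np.arange(5)
images = stack_plots(fake_render, batch, max_outputs=3)
print(images.shape)  # (3, 2, 3, 4)
```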