Showcases of tfplot

This guide gives a quick tour of the tfplot library. Feel free to skip the setup section if you only want to see the showcases.

[5]:
import tfplot
tfplot.__version__
[5]:
'0.3.0.dev0'

Setup: Utilities and Data

In order to see the images generated from the plot ops, we introduce a simple utility function which takes a Tensor as an input and displays the resulting image after executing it in a TensorFlow session.

You can skip this section and jump straight to the showcases.

[6]:
import tensorflow as tf
sess = tf.InteractiveSession()
[7]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


def execute_op_as_image(op):
    """
    Evaluate the given `op` and return the content PNG image as `PIL.Image`.

    - If op is a plot op (e.g. RGBA Tensor) the image or
      a list of images will be returned
    - If op is summary proto (i.e. `op` was a summary op),
      the image content will be extracted from the proto object.
    """
    print("Executing: " + str(op))
    ret = sess.run(op)
    plt.close()

    if isinstance(ret, np.ndarray):
        if len(ret.shape) == 3:
            # single image
            return Image.fromarray(ret)
        elif len(ret.shape) == 4:
            return [Image.fromarray(r) for r in ret]
        else:
            raise ValueError("Invalid rank : %d" % len(ret.shape))

    elif isinstance(ret, (str, bytes)):
        from io import BytesIO
        s = tf.Summary()
        s.ParseFromString(ret)
        ims = []
        for i in range(len(s.value)):
            png_string = s.value[i].image.encoded_image_string
            im = Image.open(BytesIO(png_string))
            ims.append(im)
        plt.close()
        # a single image is returned as-is, several images as a list
        return ims[0] if len(ims) == 1 else ims

    else:
        raise TypeError("Unknown type: " + str(ret))

We also prepare some sample data:

[8]:
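import scipy.misc
import scipy.ndimage


def fake_attention():
    attention = np.zeros([16, 16], dtype=np.float32)
    attention[(11, 8)] = 1.0
    attention[(9, 9)] = 1.0
    # scipy.ndimage.filters is a deprecated alias; call gaussian_filter directly
    attention = scipy.ndimage.gaussian_filter(attention, sigma=1.5)
    return attention

# NOTE: scipy.misc.face() has been removed from recent SciPy releases;
# newer versions provide scipy.datasets.face() instead.
sample_image = scipy.misc.face()
attention_map = fake_attention()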

# display the data
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
axs[0].imshow(sample_image); axs[0].set_title('image')
axs[1].imshow(attention_map, cmap='jet'); axs[1].set_title('attention')
plt.show()
../_images/guide_showcases_12_0.png

Finally, we wrap these NumPy arrays as TensorFlow constants:

[9]:
# the input to plot_op
image_tensor = tf.constant(sample_image, name='image')
attention_tensor = tf.constant(attention_map, name='attention')
print(image_tensor)
print(attention_tensor)
Tensor("image:0", shape=(768, 1024, 3), dtype=uint8)
Tensor("attention:0", shape=(16, 16), dtype=float32)

1. tfplot.autowrap: The Main End-User API

Use tfplot.autowrap to design a custom plot function of your own.

tfplot.autowrap is a decorator that defines a TF op which draws a plot. With it, you can wrap a Python function that returns a matplotlib Figure (or AxesSubplot) into a TensorFlow op, similar to tf.py_func.

[10]:
@tfplot.autowrap
def plot_scatter(x, y):
    # NEVER use plt.XXX, or matplotlib.pyplot.
    # Use tfplot.subplots() instead of plt.subplots() to avoid thread-safety issues.
    fig, ax = tfplot.subplots(figsize=(3, 3))
    ax.scatter(x, y, color='green')
    return fig

x = tf.constant(np.arange(10), dtype=tf.float32)
y = tf.constant(np.arange(10) ** 2, dtype=tf.float32)
execute_op_as_image(plot_scatter(x, y))
Executing: Tensor("plot_scatter:0", shape=(?, ?, 4), dtype=uint8)
[10]:
../_images/guide_showcases_19_1.png
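
To see why the plot function should avoid pyplot, it may help to picture roughly what a plot op has to do once the Tensor values arrive as NumPy arrays: build a figure, rasterize it off-screen, and hand back an RGBA array. The following is only a minimal sketch of that idea using plain matplotlib, not tfplot's actual implementation; the helper name `figure_to_rgba` is hypothetical.

```python
import numpy as np
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg


def figure_to_rgba(fig):
    """Rasterize a matplotlib Figure into an (H, W, 4) uint8 array."""
    # An off-screen Agg canvas; no global pyplot state is involved.
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    # np.array() copies the pixels out of the canvas buffer.
    return np.array(canvas.buffer_rgba())


# Build the figure directly instead of going through plt.subplots().
fig = Figure(figsize=(3, 3), dpi=100)
ax = fig.add_subplot(1, 1, 1)
ax.scatter(np.arange(10), np.arange(10) ** 2, color='green')

image = figure_to_rgba(fig)
print(image.shape, image.dtype)
```

Because the figure is created and drawn without touching pyplot's shared "current figure", several such calls can run concurrently, which is presumably the thread-safety concern the comment in the cell above refers to.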

We can create subplots as well. Note that, besides the Tensor arguments (passed positionally), additional non-Tensor arguments can be passed as keyword arguments (kwargs).

[11]:
@tfplot.autowrap
def plot_image_and_attention(im, att, cmap=None):
    fig, axes = tfplot.subplots(1, 2, figsize=(7, 4))
    fig.suptitle('Image and Heatmap')
    axes[0].imshow(im)
    axes[1].imshow(att, cmap=cmap)
    return fig

op = plot_image_and_attention(sample_image, attention_map, cmap='jet')
execute_op_as_image(op)
Executing: Tensor("plot_image_and_attention:0", shape=(?, ?, 4), dtype=uint8)
[11]:
../_images/guide_showcases_21_1.png

Sometimes it is cumbersome to create the fig and ax instances yourself. If you want them to be created and injected automatically, declare a keyword argument named fig and/or ax in the function signature:

[12]:
@tfplot.autowrap(figsize=(2, 2))
def plot_scatter(x, y, *, ax, color='red'):
    ax.set_title('x^2')
    ax.scatter(x, y, color=color)

x = tf.constant(np.arange(10), dtype=tf.float32)
y = tf.constant(np.arange(10) ** 2, dtype=tf.float32)
execute_op_as_image(plot_scatter(x, y))
Executing: Tensor("plot_scatter_1:0", shape=(?, ?, 4), dtype=uint8)
[12]:
../_images/guide_showcases_23_1.png
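
The injection of `ax` can be pictured as a decorator that inspects the wrapped function's signature and supplies a freshly created Axes when a parameter with that name is present. The sketch below illustrates the idea with plain matplotlib; `inject_ax` is a hypothetical name, and this is an assumption about the general technique rather than how tfplot implements it.

```python
import functools
import inspect

from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg


def inject_ax(figsize=(2, 2)):
    """Hypothetical decorator: create a Figure/Axes pair and pass the Axes as `ax`."""
    def decorator(func):
        # Decide once, at decoration time, whether the function asks for `ax`.
        wants_ax = 'ax' in inspect.signature(func).parameters

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            fig = Figure(figsize=figsize)
            FigureCanvasAgg(fig)  # attach an off-screen canvas to the figure
            if wants_ax:
                kwargs['ax'] = fig.add_subplot(1, 1, 1)
            func(*args, **kwargs)
            # The figure is returned even though `func` itself returns nothing.
            return fig
        return wrapper
    return decorator


@inject_ax(figsize=(2, 2))
def plot_squares(x, y, *, ax, color='red'):
    ax.set_title('x^2')
    ax.scatter(x, y, color=color)


fig = plot_squares(range(10), [v ** 2 for v in range(10)])
```

Making `ax` keyword-only (the bare `*` in the signature) keeps it from being confused with the positional data arguments.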

2. Wrapping Matplotlib’s Axes Methods or Seaborn Plots

You can use tfplot.autowrap (or raw APIs such as tfplot.plot, etc.) to plot anything by writing your own plotting function. Sometimes, however, we want to reuse plot functions that already exist in common libraries such as matplotlib and seaborn.

To do this, you can still use tfplot.autowrap.

Matplotlib

Matplotlib provides a variety of plot methods defined in the class Axes (instances are usually named ax).

[13]:
rs = np.random.RandomState(42)
x = rs.randn(100)
y = 2 * x + rs.randn(100)

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set_title("Created from matplotlib API")
plt.show()
../_images/guide_showcases_27_0.png

We can wrap the Axes.scatter() method as a TensorFlow op as follows:

[14]:
from matplotlib.axes import Axes
tf_scatter = tfplot.autowrap(Axes.scatter, figsize=(4, 4))

plot_op = tf_scatter(x, y)
execute_op_as_image(plot_op)
Executing: Tensor("scatter:0", shape=(?, ?, 4), dtype=uint8)
[14]:
../_images/guide_showcases_29_1.png

Seaborn

Seaborn provides many useful axes-level plot functions that can be used out of the box. Most functions that draw onto an Axes accept an ax=... parameter.

See seaborn’s example gallery for more of the features it provides.

[15]:
import seaborn as sns

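# Compare numerically: a plain string comparison would rank '0.10' below '0.8'.
sns_version = tuple(int(v) for v in sns.__version__.split('.')[:2])
assert sns_version >= (0, 8), \
    'Use seaborn >= v0.8.0, otherwise `import seaborn as sns` will affect the default matplotlib style.'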

barplot: (Discrete) Probability Distribution

[16]:
# https://seaborn.pydata.org/generated/seaborn.barplot.html

y = np.random.RandomState(42).normal(size=[18])
y = np.exp(y) / np.exp(y).sum() # softmax
y = tf.constant(y, dtype=tf.float32)

ATARI_ACTIONS = [
    '⠀', '●', '↑', '→', '←', '↓', '↗', '↖', '↘', '↙',
    '⇑', '⇒', '⇐', '⇓', '⇗', '⇖', '⇘', '⇙' ]
x = tf.constant(ATARI_ACTIONS)

op = tfplot.autowrap(sns.barplot, palette='Blues_d')(x, y)
execute_op_as_image(op)
Executing: Tensor("barplot:0", shape=(?, ?, 4), dtype=uint8)
[16]:
../_images/guide_showcases_33_1.png
[17]:
y = np.random.RandomState(42).normal(size=[3, 18])
y = np.exp(y) / np.exp(y).sum(axis=1).reshape([-1, 1]) # softmax example-wise
y = tf.constant(y, dtype=tf.float32)

ATARI_ACTIONS = [
    '⠀', '●', '↑', '→', '←', '↓', '↗', '↖', '↘', '↙',
    '⇑', '⇒', '⇐', '⇓', '⇗', '⇖', '⇘', '⇙' ]
x = tf.broadcast_to(tf.constant(ATARI_ACTIONS), y.shape)

op = tfplot.autowrap(sns.barplot, palette='Blues_d', batch=True)(x, y)
for im in execute_op_as_image(op):
    display(im)
Executing: Tensor("barplot_1/PlotImages:0", shape=(3, ?, ?, 4), dtype=uint8)
../_images/guide_showcases_34_1.png
../_images/guide_showcases_34_2.png
../_images/guide_showcases_34_3.png
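
Conceptually, the batched output above is one image per example, stacked along a leading axis, which is why the op's shape starts with 3. A rough NumPy/matplotlib sketch of that idea follows; the helper `render_bars` is hypothetical, plain `ax.bar` stands in for seaborn's barplot, and nothing here is meant to mirror tfplot's internals.

```python
import numpy as np
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg


def render_bars(probs, figsize=(4, 3), dpi=50):
    """Draw one bar chart and return it as an (H, W, 4) uint8 array."""
    fig = Figure(figsize=figsize, dpi=dpi)
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(1, 1, 1)
    ax.bar(np.arange(len(probs)), probs)
    canvas.draw()
    return np.array(canvas.buffer_rgba())


# Same fake data as in the cell above: 3 examples, 18 actions each.
logits = np.random.RandomState(42).normal(size=[3, 18])
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # row-wise softmax

# One image per example, stacked into a rank-4 array (batch, H, W, 4).
batch_images = np.stack([render_bars(p) for p in probs])
print(batch_images.shape)
```

Stacking only works because every example is rendered with the same figsize and dpi, so all images share one height and width.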

Heatmap

Let’s wrap seaborn’s heatmap function as a TensorFlow operation, with some additional default kwargs. This is handy for visualizing 2D maps such as attention.

[18]:
# @seealso https://seaborn.pydata.org/examples/heatmap_annotation.html
tf_heatmap = tfplot.autowrap(sns.heatmap, figsize=(9, 6))

op = tf_heatmap(attention_map, cbar=True, annot=True, fmt=".2f")
execute_op_as_image(op)
Executing: Tensor("heatmap:0", shape=(?, ?, 4), dtype=uint8)
[18]:
../_images/guide_showcases_37_1.png

What if we don’t want axes and colorbars, but only the map itself? Compare the result with plain tf.summary.image, which only produces a grayscale image.

[19]:
# draw only the heatmap itself, without axes, tick labels, or colorbar
tf_heatmap = tfplot.autowrap(sns.heatmap, figsize=(4, 4), tight_layout=True,
                             cmap='jet', cbar=False, xticklabels=False, yticklabels=False)

op = tf_heatmap(attention_map, name='HeatmapImage')
execute_op_as_image(op)
Executing: Tensor("HeatmapImage:0", shape=(?, ?, 4), dtype=uint8)
[19]:
../_images/guide_showcases_39_1.png
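
The difference between a grayscale image and a colormapped heatmap comes down to how each scalar value is turned into a pixel. As a small illustration outside TensorFlow, the sketch below converts a 2D float map into both forms using matplotlib's `ScalarMappable`; the toy `attention` array is made up for this example and is not the blurred map used above.

```python
import numpy as np
from matplotlib import cm

# A toy 16x16 map with two peaks (no Gaussian blur, to keep the example small).
attention = np.zeros([16, 16], dtype=np.float32)
attention[11, 8] = 1.0
attention[9, 9] = 1.0

# Grayscale: one intensity channel per pixel.
gray = np.uint8(255 * attention / attention.max())

# Colormapped: values are normalized to [0, 1] and looked up in the 'jet' colormap,
# giving an (H, W, 4) uint8 RGBA image.
rgba = cm.ScalarMappable(cmap='jet').to_rgba(attention, bytes=True)

print(gray.shape, rgba.shape)
```

In the 'jet' colormap, low values map to blue and high values to red, so the peaks stand out far more than in the single-channel version.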

And Many More!

This document has covered only the basic usage of tfplot. There is more to explore:

  • tfplot.contrib: off-the-shelf functions that create useful plot operations in a few lines, without the hassle of writing a function body. See [contrib.ipynb] for a tour of the available APIs.
  • tfplot.plot(), tfplot.plot_many(), etc.: Low-level APIs.
  • tfplot.summary: One-liner APIs for creating TF summary operations.
[21]:
import tfplot.contrib

For example, probmap and probmap_simple create an image Tensor that visualizes a probability map:

[22]:
op = tfplot.contrib.probmap(attention_map, figsize=(4, 3))
execute_op_as_image(op)
Executing: Tensor("probmap:0", shape=(?, ?, 4), dtype=uint8)
[22]:
../_images/guide_showcases_45_1.png
[23]:
op = tfplot.contrib.probmap_simple(attention_map, figsize=(3, 3), vmin=0, vmax=1)
execute_op_as_image(op)
Executing: Tensor("probmap_1:0", shape=(?, ?, 4), dtype=uint8)
[23]:
../_images/guide_showcases_46_1.png

That’s all! Please take a look at the API documentation and more examples if you are interested.