"""Deep Recurrent Attentive Writer.
"""
"""
Copyright 2017 Parag K. Mital. See also NOTICE.md.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import matplotlib.pyplot as plt
import tensorflow as tf
from cadl.datasets import MNIST, CIFAR10
from cadl.dataset_utils import create_input_pipeline
from cadl import utils, gif
import numpy as np
def linear(x, n_output):
    """Fully-connected (affine) layer: ``x @ w + b``.

    Creates variables ``w`` and ``b`` in the current variable scope, so
    callers must wrap this in a uniquely-named ``tf.variable_scope``.

    Parameters
    ----------
    x : tf.Tensor
        2-D input of shape ``[batch, n_input]``; the second dimension must
        be statically known (read via ``get_shape()``).
    n_output : int
        Number of output units.

    Returns
    -------
    tf.Tensor
        Affine transform of ``x``, shape ``[batch, n_output]``.
    """
    w = tf.get_variable(
        "w", [x.get_shape()[1], n_output],
        initializer=tf.contrib.layers.xavier_initializer())
    b = tf.get_variable(
        "b", [n_output], initializer=tf.constant_initializer(0.0))
    return tf.add(tf.matmul(x, w), b)
def encoder(x, rnn, batch_size, state=None, n_enc=64, reuse=None):
    """Run one step of the recurrent encoder inside the 'encoder' scope.

    Parameters
    ----------
    x : tf.Tensor
        Input to the RNN cell for this timestep, shape ``[batch, n_input]``.
    rnn : tf.contrib.rnn.RNNCell
        The recurrent cell to apply (e.g. a GRUCell).
    batch_size : int
        Number of observations per minibatch; used to build the zero state.
    state : tf.Tensor or None, optional
        Recurrent state from the previous timestep; when ``None`` the
        cell's zero state is used (first timestep).
    n_enc : int, optional
        Unused here; the cell's size is fixed at construction. Kept for
        interface compatibility.
    reuse : bool or None, optional
        Whether to reuse variables in the 'encoder' scope (True after the
        first timestep).

    Returns
    -------
    (h_enc, state) : (tf.Tensor, tf.Tensor)
        The cell output and the new recurrent state.
    """
    with tf.variable_scope('encoder', reuse=reuse):
        if state is None:
            h_enc, state = rnn(x, rnn.zero_state(batch_size, tf.float32))
        else:
            h_enc, state = rnn(x, state)
    return h_enc, state
def variational_layer(h_enc, noise, n_z=2, reuse=None):
    """Sample a latent code from the approximate posterior Q(z | h_enc).

    Uses the reparameterization trick: ``z = mu + sigma * noise`` with
    ``sigma = exp(log_sigma)``.

    Parameters
    ----------
    h_enc : tf.Tensor
        Encoder output for the current timestep, shape ``[batch, n_enc]``.
    noise : tf.Tensor
        External Gaussian noise, shape ``[batch, n_z]``.
    n_z : int, optional
        Dimensionality of the latent code.
    reuse : bool or None, optional
        Whether to reuse the scope's variables (True after the first step).

    Returns
    -------
    (z_t, h_z_mu, h_z_log_sigma) : tuple of tf.Tensor
        The sampled latent code and the parameters of its distribution.
    """
    with tf.variable_scope('variational', reuse=reuse):
        # Equation 1: use the encoder to parameterize the mean of the
        # approximate posterior distribution Q
        with tf.variable_scope('mu', reuse=reuse):
            h_z_mu = linear(h_enc, n_z)
        # Equation 2: similarly for the (log) standard deviation
        with tf.variable_scope('log_sigma', reuse=reuse):
            h_z_log_sigma = linear(h_enc, n_z)
        # sample z_t ~ q(Z_t | h_enc_t) via the reparameterization trick
        z_t = h_z_mu + tf.multiply(tf.exp(h_z_log_sigma), noise)
    # return the sampled value from the latent distribution and its parameters
    return z_t, h_z_mu, h_z_log_sigma
def decoder(z, rnn, batch_size, state=None, n_dec=64, reuse=None):
    """Run one step of the recurrent decoder inside the 'decoder' scope.

    Parameters
    ----------
    z : tf.Tensor
        Sampled latent code for this timestep, shape ``[batch, n_z]``.
    rnn : tf.contrib.rnn.RNNCell
        The recurrent cell to apply (e.g. a GRUCell).
    batch_size : int
        Number of observations per minibatch; used to build the zero state.
    state : tf.Tensor or None, optional
        Recurrent state from the previous timestep; ``None`` means use the
        cell's zero state (first timestep).
    n_dec : int, optional
        Unused here; the cell's size is fixed at construction. Kept for
        interface compatibility.
    reuse : bool or None, optional
        Whether to reuse variables in the 'decoder' scope.

    Returns
    -------
    (h_dec, state) : (tf.Tensor, tf.Tensor)
        The cell output and the new recurrent state.
    """
    with tf.variable_scope('decoder', reuse=reuse):
        if state is None:
            h_dec, state = rnn(z, rnn.zero_state(batch_size, tf.float32))
        else:
            h_dec, state = rnn(z, state)
    return h_dec, state
def create_attention_map(h_dec, reuse=None):
    """Predict the five attention parameters from the decoder output.

    A single linear layer maps ``h_dec`` to the parameters of equation 21
    of the DRAW paper: grid center (x, y), log variance, log stride, and
    log intensity.

    Parameters
    ----------
    h_dec : tf.Tensor
        Decoder output for the current timestep, shape ``[batch, n_dec]``.
    reuse : bool or None, optional
        Whether to reuse the 'attention' scope's variables.

    Returns
    -------
    tuple of tf.Tensor
        ``(g_tilde_x, g_tilde_y, log_sigma, log_delta_tilde, log_gamma)``,
        each of shape ``[batch, 1]``.
    """
    with tf.variable_scope("attention", reuse=reuse):
        p = linear(h_dec, 5)
    g_tilde_x, g_tilde_y, log_sigma, log_delta_tilde, log_gamma = \
        tf.split(p, 5, axis=1)
    return g_tilde_x, g_tilde_y, log_sigma, log_delta_tilde, log_gamma
def create_filterbank(g_x, g_y, log_sigma_sq, log_delta, A=28, B=28, C=1, N=12):
    """Build the N x N grid of Gaussian read/write filters (eqs. 22-26).

    Parameters
    ----------
    g_x : tf.Tensor
        Unnormalized grid center x in [-1, 1], shape ``[batch, 1]``.
    g_y : tf.Tensor
        Unnormalized grid center y in [-1, 1], shape ``[batch, 1]``.
    log_sigma_sq : tf.Tensor
        Log variance of the isotropic Gaussian filters, ``[batch, 1]``.
    log_delta : tf.Tensor
        Log of the unnormalized stride between filter centers, ``[batch, 1]``.
    A : int, optional
        Image height in pixels.
    B : int, optional
        Image width in pixels.
    C : int, optional
        Number of image channels (unused here; kept for a uniform interface).
    N : int, optional
        Number of filters per axis (grid is N x N).

    Returns
    -------
    (F_x, F_y) : (tf.Tensor, tf.Tensor)
        Normalized horizontal and vertical filterbanks of shapes
        ``[batch, N, B]`` and ``[batch, N, A]``.
    """
    with tf.name_scope("filterbank"):
        # Equations 22 and 23: map the centers from [-1, 1] to pixel space.
        g_x = (A + 1) / 2 * (g_x + 1)
        g_y = (B + 1) / 2 * (g_y + 1)
        # The authors suggest a real-valued center and stride: the grid
        # center need not sit on a pixel. Equation 24 gives the stride
        # (note log_delta is converted to delta via the exponential):
        delta = (max(A, B) - 1) / (N - 1) * tf.exp(log_delta)
        # Equations 19 and 20: mean location of the ith/jth filter of the
        # N x N grid, computed independently per axis. Each observation
        # gets N x-positions and N y-positions (N*N grid points total).
        ns = tf.expand_dims(tf.cast(tf.range(N), tf.float32), 0)
        mu_x = tf.reshape(g_x + (ns - N / 2 - 0.5) * delta, [-1, N, 1])
        mu_y = tf.reshape(g_y + (ns - N / 2 - 0.5) * delta, [-1, N, 1])
        # Equations 25 and 26 need 2 * sigma^2; exponentiate the log
        # variance and reshape with singleton dims for broadcasting.
        sigma_sq = tf.reshape(tf.exp(log_sigma_sq), [-1, 1, 1])
        # Pixel coordinate ranges over the entire image:
        xs = tf.reshape(tf.cast(tf.range(B), tf.float32), [1, 1, -1])
        ys = tf.reshape(tf.cast(tf.range(A), tf.float32), [1, 1, -1])
        # Broadcast to get the unnormalized Gaussian filterbanks:
        F_x = tf.exp(-tf.square(xs - mu_x) / (2 * sigma_sq))
        F_y = tf.exp(-tf.square(ys - mu_y) / (2 * sigma_sq))
        # Normalize so each filter's energy sums to 1 across pixel
        # locations, guarding against division by zero with epsilon.
        # F_x[i, a] holds N filters over the B horizontal pixels;
        # F_y[j, b] holds N filters over the A vertical pixels.
        epsilon = 1e-10
        F_x = F_x / tf.maximum(tf.reduce_sum(F_x, 2, keep_dims=True), epsilon)
        F_y = F_y / tf.maximum(tf.reduce_sum(F_y, 2, keep_dims=True), epsilon)
        # return the filterbanks
        return F_x, F_y
def filter_image(x, F_x, F_y, log_gamma, A, B, C, N, inverted=False):
    """Apply the filterbanks to an image (read, eq. 27) or patch (write, eq. 29).

    Parameters
    ----------
    x : tf.Tensor
        Flattened image of shape ``[batch, A*B*C]`` when reading, or
        flattened N x N patch ``[batch, N*N*C]`` when ``inverted=True``.
    F_x : tf.Tensor
        Horizontal filterbank, shape ``[batch, N, B]``.
    F_y : tf.Tensor
        Vertical filterbank, shape ``[batch, N, A]``.
    log_gamma : tf.Tensor
        Log intensity; exponentiated (and inverted for writes), ``[batch, 1]``.
    A : int
        Image height in pixels.
    B : int
        Image width in pixels.
    C : int
        Number of channels.
    N : int
        Filter grid size per axis.
    inverted : bool, optional
        False for the read operation (extract an N x N glimpse); True for
        the write operation (project an N x N patch onto the canvas).

    Returns
    -------
    tf.Tensor
        Flattened filtered result weighted by gamma (or 1/gamma for writes).
    """
    with tf.name_scope("filter"):
        # Transpose one filterbank so the matrix products line up:
        # read transposes F_x (equation 27); write transposes F_y and
        # uses the inverse gamma (equation 29).
        gamma = tf.exp(log_gamma)
        if inverted:
            F_y = tf.transpose(F_y, perm=[0, 2, 1])
            gamma = 1.0 / gamma
            # Left- and right-multiply the N x N patch by the filters.
            if C == 1:
                glimpse = tf.matmul(F_y,
                                    tf.matmul(tf.reshape(x, [-1, N, N]), F_x))
            else:
                x = tf.reshape(x, [-1, N, N, C])
                xs = tf.split(x, C, axis=3)
                glimpses = []
                for x_i in xs:
                    # Squeeze only the channel axis: a bare tf.squeeze
                    # would also drop a batch dimension of size 1.
                    glimpses.append(
                        tf.matmul(F_y,
                                  tf.matmul(tf.squeeze(x_i, axis=[3]), F_x)))
                glimpse = tf.concat(
                    [tf.expand_dims(x_i, -1) for x_i in glimpses], axis=3)
        else:
            F_x = tf.transpose(F_x, perm=[0, 2, 1])
            # Left- and right-multiply the full A x B image by the filters.
            if C == 1:
                glimpse = tf.matmul(F_y,
                                    tf.matmul(tf.reshape(x, [-1, A, B]), F_x))
            else:
                x = tf.reshape(x, [-1, A, B, C])
                xs = tf.split(x, C, axis=3)
                glimpses = []
                for x_i in xs:
                    # Squeeze only the channel axis (see note above).
                    glimpses.append(
                        tf.matmul(F_y,
                                  tf.matmul(tf.squeeze(x_i, axis=[3]), F_x)))
                glimpse = tf.concat(
                    [tf.expand_dims(x_i, -1) for x_i in glimpses], axis=3)
        # Flatten the filtered image to a vector...
        glimpse = tf.reshape(glimpse,
                             [-1, np.prod(glimpse.get_shape().as_list()[1:])])
        # ...and weight the filtered glimpses by gamma.
        return glimpse * tf.reshape(gamma, [-1, 1])
def read(x_t,
         x_hat_t,
         h_dec_t,
         read_n=5,
         A=28,
         B=28,
         C=1,
         use_attention=True,
         reuse=None):
    """Read from the input image, `x`, and reconstruction error image `x_hat`.

    Optionally apply a filterbank w/ `use_attention`.

    Parameters
    ----------
    x_t : tf.Tensor
        Input image to optionally filter, shape ``[batch, A*B*C]``.
    x_hat_t : tf.Tensor
        Reconstruction error to optionally filter, same shape as ``x_t``.
    h_dec_t : tf.Tensor
        Output of the decoder of the network (could also be the encoder but
        the authors suggest to use the decoder instead, see end of
        section 2.1).
    read_n : int, optional
        Size per axis of the read filter grid.
    A : int, optional
        Image height in pixels.
    B : int, optional
        Image width in pixels.
    C : int, optional
        Number of channels.
    use_attention : bool, optional
        When True, filter both images through the attention filterbank;
        when False, pass both through unfiltered.
    reuse : bool or None, optional
        Whether to reuse the 'read' scope's variables.

    Returns
    -------
    tf.Tensor
        Concatenation of the (possibly filtered) image and error image
        along axis 1 (equation 27).
    """
    with tf.variable_scope('read', reuse=reuse):
        if use_attention:
            # Use the decoder's output to create the 5 measures defining
            # the placement and characteristics of the filterbank.
            g_x_tilde, g_y_tilde, \
                log_sigma_sq_tilde, log_delta_tilde, log_gamma_tilde = \
                create_attention_map(h_dec_t, reuse=reuse)
            # Now create the filterbank...
            F_x_tilde, F_y_tilde = create_filterbank(
                g_x_tilde,
                g_y_tilde,
                log_sigma_sq_tilde,
                log_delta_tilde,
                N=read_n,
                A=A,
                B=B,
                C=C)
            # ...apply it to the input image...
            x_t = filter_image(x_t, F_x_tilde, F_y_tilde, log_gamma_tilde, A, B,
                               C, read_n)
            # ...and to the error image.
            x_hat_t = filter_image(x_hat_t, F_x_tilde, F_y_tilde,
                                   log_gamma_tilde, A, B, C, read_n)
    # Equation 27: concat the two patches from the image and the error
    # image. Without attention these are just the unfiltered images.
    return tf.concat([x_t, x_hat_t], axis=1)
def write(h_dec_t, write_n=5, A=28, B=28, C=1, use_attention=True, reuse=None):
    """Produce the canvas update from the decoder output (equations 28-29).

    Parameters
    ----------
    h_dec_t : tf.Tensor
        Decoder output for the current timestep, shape ``[batch, n_dec]``.
    write_n : int, optional
        Size per axis of the write patch/filter grid.
    A : int, optional
        Image height in pixels.
    B : int, optional
        Image width in pixels.
    C : int, optional
        Number of channels.
    use_attention : bool, optional
        When True, project a small write_n x write_n patch onto the canvas
        through the attention filterbank; when False, write the full image.
    reuse : bool or None, optional
        Whether to reuse the 'write' scope's variables.

    Returns
    -------
    tf.Tensor
        Flattened additive canvas update, shape ``[batch, A*B*C]``.
    """
    # Equation 28: as in the read layer, an additional nonlinearity could be
    # added here to enforce the characteristics of the final activation,
    # e.g. a sigmoid for images normalized to [0, 1].
    with tf.variable_scope("write", reuse=reuse):
        if use_attention:
            # Predict the patch to write, then reshape it to a grid.
            w_t = linear(h_dec_t, write_n * write_n * C)
            if C == 1:
                w_t = tf.reshape(w_t, [-1, write_n, write_n])
            else:
                w_t = tf.reshape(w_t, [-1, write_n, write_n, C])
            # Use the decoder's output to create the 5 measures defining
            # the placement and characteristics of the filterbank.
            g_x_hat, g_y_hat, log_sigma_sq_hat, log_delta_hat, log_gamma_hat = \
                create_attention_map(h_dec_t, reuse=reuse)
            # Now create the filterbank...
            F_x_hat, F_y_hat = create_filterbank(
                g_x_hat,
                g_y_hat,
                log_sigma_sq_hat,
                log_delta_hat,
                N=write_n,
                A=A,
                B=B,
                C=C)
            # ...and project the patch onto the canvas, equation 29.
            w_t = filter_image(
                w_t,
                F_x_hat,
                F_y_hat,
                log_gamma_hat,
                A,
                B,
                C,
                write_n,
                inverted=True)
            return w_t
        else:
            return linear(h_dec_t, A * B * C)
def binary_cross_entropy(t, o, eps=1e-10):
    """Elementwise binary cross entropy between targets and outputs.

    Parameters
    ----------
    t : tf.Tensor
        Target values in [0, 1].
    o : tf.Tensor
        Predicted probabilities in [0, 1], same shape as ``t``.
    eps : float, optional
        Small constant added inside the logs to avoid log(0).

    Returns
    -------
    tf.Tensor
        Elementwise cross entropy, same shape as the inputs.
    """
    return -(t * tf.log(o + eps) + (1.0 - t) * tf.log(1.0 - o + eps))
def create_model(
        A=28,  # img_h
        B=28,  # img_w
        C=1,  # img_c
        T=16,
        batch_size=100,
        n_enc=128,
        n_z=32,
        n_dec=128,
        read_n=12,
        write_n=12):
    """Build the full DRAW graph: read / encode / sample / decode / write.

    Parameters
    ----------
    A : int, optional
        Image height in pixels.
    B : int, optional
        Image width in pixels.
    C : int, optional
        Number of image channels.
    T : int, optional
        Number of glimpse timesteps. NOTE(review): the loop below runs
        ``range(1, T)``, i.e. T-1 steps, while the KL term subtracts
        ``T * 0.5`` — confirm intended off-by-one against the paper.
    batch_size : int, optional
        Number of observations per minibatch (fixed in the graph).
    n_enc : int, optional
        Encoder GRU size.
    n_z : int, optional
        Latent code dimensionality.
    n_dec : int, optional
        Decoder GRU size.
    read_n : int, optional
        Read attention grid size per axis.
    write_n : int, optional
        Write attention grid size per axis.

    Returns
    -------
    dict
        Placeholders and tensors: 'x', 'noise', 'loss_x', 'loss_z',
        'canvas' (list of sigmoids of each timestep's canvas), 'cost',
        and 'recon'.
    """
    x = tf.placeholder(tf.float32, shape=[None, A * B * C], name='x')
    noise = tf.placeholder(tf.float32, shape=[None, n_z], name='noise')
    rnn_enc = tf.contrib.rnn.GRUCell(n_enc)
    rnn_dec = tf.contrib.rnn.GRUCell(n_dec)
    enc_state, dec_state = None, None
    canvas = [tf.zeros([batch_size, A * B * C], name='c_0')]
    h_enc_t = tf.zeros([batch_size, n_dec])
    h_dec_t = tf.zeros([batch_size, n_dec])
    reuse = False
    z_mus, z_log_sigmas = [], []
    for t_i in range(1, T):
        # Error image; assumes the input image is normalized between 0 - 1.
        x_hat_t = x - tf.nn.sigmoid(canvas[t_i - 1])
        r_t = read(
            x_t=x,
            x_hat_t=x_hat_t,
            h_dec_t=h_dec_t,
            read_n=read_n,
            A=A,
            B=B,
            C=C,
            use_attention=True,
            reuse=reuse)
        h_enc_t, enc_state = encoder(
            x=tf.concat([r_t, h_dec_t], axis=1),
            rnn=rnn_enc,
            batch_size=batch_size,
            state=enc_state,
            n_enc=n_enc,
            reuse=reuse)
        z_t, z_mu, z_log_sigma = variational_layer(
            h_enc=h_enc_t, noise=noise, n_z=n_z, reuse=reuse)
        z_mus.append(z_mu)
        z_log_sigmas.append(z_log_sigma)
        h_dec_t, dec_state = decoder(
            z=z_t,
            rnn=rnn_dec,
            batch_size=batch_size,
            state=dec_state,
            n_dec=n_dec,
            reuse=reuse)
        w_t = write(
            h_dec_t=h_dec_t,
            write_n=write_n,
            A=A,
            B=B,
            C=C,
            use_attention=True,
            reuse=reuse)
        # The canvas accumulates each timestep's write additively.
        c_t = canvas[-1] + w_t
        canvas.append(c_t)
        # After the first timestep all scopes reuse their variables.
        reuse = True
    x_recon = tf.nn.sigmoid(canvas[-1])
    with tf.variable_scope('loss'):
        # Reconstruction loss (equation 9-style, summed over pixels).
        loss_x = tf.reduce_mean(
            tf.reduce_sum(binary_cross_entropy(x, x_recon), 1))
        # KL divergence between the posterior and the unit Gaussian prior,
        # accumulated over timesteps (equation 11; note sigma^2 =
        # exp(2 * log_sigma) and log sigma^2 = 2 * log_sigma).
        loss_zs = []
        for z_mu, z_log_sigma in zip(z_mus, z_log_sigmas):
            loss_zs.append(
                tf.reduce_sum(
                    tf.square(z_mu) + tf.square(tf.exp(z_log_sigma)) -
                    2 * z_log_sigma, 1))
        loss_z = tf.reduce_mean(0.5 * tf.reduce_sum(loss_zs, 0) - T * 0.5)
        cost = loss_x + loss_z
    return {
        'x': x,
        'loss_x': loss_x,
        'loss_z': loss_z,
        'canvas': [tf.nn.sigmoid(c_i) for c_i in canvas],
        'cost': cost,
        'recon': x_recon,
        'noise': noise
    }
def test_mnist():
    """Train DRAW on MNIST, periodically saving reconstruction GIFs and
    checkpoints to the current directory."""
    A = 28  # img_h
    B = 28  # img_w
    C = 1
    T = 10
    n_enc = 256
    n_z = 100
    n_dec = 256
    read_n = 5
    write_n = 5
    batch_size = 64
    mnist = MNIST(split=[0.8, 0.1, 0.1])
    n_examples = batch_size
    zs = np.random.uniform(-1.0, 1.0, [4, n_z]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)
    # We create a session to use the graph
    g = tf.Graph()
    with tf.Session(graph=g) as sess:
        draw = create_model(
            A=A,
            B=B,
            C=C,
            T=T,
            batch_size=batch_size,
            n_enc=n_enc,
            n_z=n_z,
            n_dec=n_dec,
            read_n=read_n,
            write_n=write_n)
        opt = tf.train.AdamOptimizer(learning_rate=0.0001)
        # Clip each gradient's norm (loop names chosen so the graph
        # variable `g` above is not shadowed).
        grads = opt.compute_gradients(draw['cost'])
        for i, (grad, var) in enumerate(grads):
            if grad is not None:
                grads[i] = (tf.clip_by_norm(grad, 5), var)
        train_op = opt.apply_gradients(grads)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        saver = tf.train.Saver()
        # Fit all training data
        batch_i = 0
        n_epochs = 100
        test_xs = mnist.test.images[:n_examples]
        utils.montage(test_xs.reshape((-1, A, B)), 'test_xs.png')
        for epoch_i in range(n_epochs):
            for batch_xs, _ in mnist.train.next_batch(batch_size):
                noise = np.random.randn(batch_size, n_z)
                lx, lz = sess.run(
                    [draw['loss_x'], draw['loss_z'], train_op],
                    feed_dict={draw['x']: batch_xs,
                               draw['noise']: noise})[0:2]
                print('x:', lx, 'z:', lz)
                if batch_i % 1000 == 0:
                    # Plot example reconstructions
                    recon = sess.run(
                        draw['canvas'],
                        feed_dict={draw['x']: test_xs,
                                   draw['noise']: noise})
                    recon = [utils.montage(r.reshape(-1, A, B)) for r in recon]
                    gif.build_gif(
                        recon,
                        cmap='gray',
                        saveto='manifold_%08d.gif' % batch_i)
                    saver.save(sess, './draw.ckpt', global_step=batch_i)
                batch_i += 1
def train_dataset(ds,
                  A,
                  B,
                  C,
                  T=20,
                  n_enc=512,
                  n_z=200,
                  n_dec=512,
                  read_n=12,
                  write_n=12,
                  batch_size=100,
                  n_epochs=100):
    """Train DRAW on an arbitrary dataset, logging summaries and saving
    reconstruction GIFs/checkpoints under ./draw.

    Parameters
    ----------
    ds : dataset or None
        Dataset object exposing ``train``/``valid``/``test`` splits with
        ``next_batch``/``images`` (as cadl datasets do); when None,
        CIFAR10 is used and A, B, C are overridden to (32, 32, 3).
    A, B, C : int
        Image height, width, and channels. Images are assumed to be
        uint8-scaled (divided by 255 before feeding).
    T : int, optional
        Number of glimpse timesteps.
    n_enc, n_dec : int, optional
        Encoder / decoder GRU sizes.
    n_z : int, optional
        Latent code dimensionality.
    read_n, write_n : int, optional
        Read / write attention grid sizes per axis.
    batch_size : int, optional
        Minibatch size.
    n_epochs : int, optional
        Number of passes over the training split.
    """
    if ds is None:
        ds = CIFAR10(split=[0.8, 0.1, 0.1])
        A, B, C = (32, 32, 3)
    n_examples = batch_size
    zs = np.random.uniform(-1.0, 1.0, [4, n_z]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)
    # We create a session to use the graph
    g = tf.Graph()
    with tf.Session(graph=g) as sess:
        draw = create_model(
            A=A,
            B=B,
            C=C,
            T=T,
            batch_size=batch_size,
            n_enc=n_enc,
            n_z=n_z,
            n_dec=n_dec,
            read_n=read_n,
            write_n=write_n)
        opt = tf.train.AdamOptimizer(learning_rate=0.0001)
        # Clip gradients (loop names chosen so the graph variable `g`
        # above is not shadowed).
        grads = opt.compute_gradients(draw['cost'])
        for i, (grad, var) in enumerate(grads):
            if grad is not None:
                grads[i] = (tf.clip_by_norm(grad, 5), var)
        train_op = opt.apply_gradients(grads)
        # Add summary variables
        tf.summary.scalar(name='cost', tensor=draw['cost'])
        tf.summary.scalar(name='loss_z', tensor=draw['loss_z'])
        tf.summary.scalar(name='loss_x', tensor=draw['loss_x'])
        tf.summary.histogram(
            name='recon_t0_histogram', values=draw['canvas'][0])
        tf.summary.histogram(
            name='recon_t-1_histogram', values=draw['canvas'][-1])
        tf.summary.image(
            name='recon_t0_image',
            tensor=tf.reshape(draw['canvas'][0], (-1, A, B, C)),
            max_outputs=2)
        tf.summary.image(
            name='recon_t-1_image',
            tensor=tf.reshape(draw['canvas'][-1], (-1, A, B, C)),
            max_outputs=2)
        sums = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(logdir='draw/train')
        valid_writer = tf.summary.FileWriter(logdir='draw/valid')
        # Init
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        saver = tf.train.Saver()
        # Fit all training data
        batch_i = 0
        test_xs = ds.test.images[:n_examples] / 255.0
        utils.montage(test_xs.reshape((-1, A, B, C)), 'draw/test_xs.png')
        for epoch_i in range(n_epochs):
            for batch_xs, _ in ds.train.next_batch(batch_size):
                noise = np.random.randn(batch_size, n_z)
                cost, summary = sess.run(
                    [draw['cost'], sums, train_op],
                    feed_dict={
                        draw['x']: batch_xs / 255.0,
                        draw['noise']: noise
                    })[0:2]
                train_writer.add_summary(summary, batch_i)
                print('train cost:', cost)
                if batch_i % 1000 == 0:
                    # Plot example reconstructions
                    recon = sess.run(
                        draw['canvas'],
                        feed_dict={draw['x']: test_xs,
                                   draw['noise']: noise})
                    recon = [
                        utils.montage(r.reshape(-1, A, B, C)) for r in recon
                    ]
                    gif.build_gif(
                        recon,
                        cmap='gray',
                        saveto='draw/manifold_%08d.gif' % batch_i)
                    saver.save(sess, './draw/draw.ckpt', global_step=batch_i)
                batch_i += 1
            # Run validation
            if batch_i % 1000 == 0:
                for batch_xs, _ in ds.valid.next_batch(batch_size):
                    noise = np.random.randn(batch_size, n_z)
                    cost, summary = sess.run(
                        [draw['cost'], sums],
                        feed_dict={
                            draw['x']: batch_xs / 255.0,
                            draw['noise']: noise
                        })[0:2]
                    valid_writer.add_summary(summary, batch_i)
                    print('valid cost:', cost)
                    batch_i += 1