"""Convolutional/Variational autoencoder, including demonstration of
training such a network on MNIST, CelebNet and the film, "Sita Sings The Blues"
using an image pipeline.
"""
"""
Copyright 2017 Parag K. Mital. See also NOTICE.md.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import tensorflow as tf
import numpy as np
import os
from cadl.dataset_utils import create_input_pipeline
from cadl.datasets import CELEB
from cadl import utils
[docs]def encoder(x,
n_hidden=None,
dimensions=[],
filter_sizes=[],
convolutional=False,
activation=tf.nn.relu,
output_activation=tf.nn.sigmoid):
"""Summary
Parameters
----------
x : TYPE
Description
n_hidden : None, optional
Description
dimensions : list, optional
Description
filter_sizes : list, optional
Description
convolutional : bool, optional
Description
activation : TYPE, optional
Description
output_activation : TYPE, optional
Description
Returns
-------
name : TYPE
Description
"""
if convolutional:
x_tensor = utils.to_tensor(x)
else:
x_tensor = tf.reshape(tensor=x, shape=[-1, dimensions[0]])
dimensions = dimensions[1:]
current_input = x_tensor
Ws = []
hs = []
shapes = []
for layer_i, n_output in enumerate(dimensions):
with tf.variable_scope(str(layer_i)):
shapes.append(current_input.get_shape().as_list())
if convolutional:
h, W = utils.conv2d(
x=current_input,
n_output=n_output,
k_h=filter_sizes[layer_i],
k_w=filter_sizes[layer_i],
padding='SAME')
else:
h, W = utils.linear(x=current_input, n_output=n_output)
h = activation(h)
Ws.append(W)
hs.append(h)
current_input = h
shapes.append(h.get_shape().as_list())
with tf.variable_scope('flatten'):
flattened = utils.flatten(current_input)
with tf.variable_scope('hidden'):
if n_hidden:
h, W = utils.linear(flattened, n_hidden, name='linear')
h = activation(h)
else:
h = flattened
return {'z': h, 'Ws': Ws, 'hs': hs, 'shapes': shapes}
[docs]def decoder(z,
shapes,
n_hidden=None,
dimensions=[],
filter_sizes=[],
convolutional=False,
activation=tf.nn.relu,
output_activation=tf.nn.relu):
"""Summary
Parameters
----------
z : TYPE
Description
shapes : TYPE
Description
n_hidden : None, optional
Description
dimensions : list, optional
Description
filter_sizes : list, optional
Description
convolutional : bool, optional
Description
activation : TYPE, optional
Description
output_activation : TYPE, optional
Description
Returns
-------
name : TYPE
Description
"""
with tf.variable_scope('hidden/1'):
if n_hidden:
h = utils.linear(z, n_hidden, name='linear')[0]
h = activation(h)
else:
h = z
with tf.variable_scope('hidden/2'):
dims = shapes[0]
size = dims[1] * dims[2] * dims[3] if convolutional else dims[1]
h = utils.linear(h, size, name='linear')[0]
current_input = activation(h)
if convolutional:
current_input = tf.reshape(
current_input,
tf.stack(
[tf.shape(current_input)[0], dims[1], dims[2], dims[3]]))
Ws = []
hs = []
for layer_i, n_output in enumerate(dimensions[1:]):
with tf.variable_scope('decoder/{}'.format(layer_i)):
if convolutional:
shape = shapes[layer_i + 1]
h, W = utils.deconv2d(
x=current_input,
n_output_h=shape[1],
n_output_w=shape[2],
n_output_ch=shape[3],
n_input_ch=shapes[layer_i][3],
k_h=filter_sizes[layer_i],
k_w=filter_sizes[layer_i])
else:
h, W = utils.linear(x=current_input, n_output=n_output)
if (layer_i + 1) < len(dimensions):
h = activation(h)
else:
h = output_activation(h)
Ws.append(W)
hs.append(h)
current_input = h
z = tf.identity(current_input, name="x_tilde")
return {'x_tilde': current_input, 'Ws': Ws, 'hs': hs}
[docs]def variational_bayes(h, n_code):
"""Summary
Parameters
----------
h : TYPE
Description
n_code : TYPE
Description
Returns
-------
name : TYPE
Description
"""
z_mu = tf.nn.tanh(utils.linear(h, n_code, name='mu')[0])
z_log_sigma = 0.5 * tf.nn.tanh(utils.linear(h, n_code, name='log_sigma')[0])
# Sample from noise distribution p(eps) ~ N(0, 1)
epsilon = tf.random_normal(tf.stack([tf.shape(h)[0], n_code]))
# Sample from posterior
z = tf.add(z_mu, tf.multiply(epsilon, tf.exp(z_log_sigma)), name='z')
# -log(p(z)/q(z|x)), bits by coding.
# variational bound coding costs kl(p(z|x)||q(z|x))
# d_kl(q(z|x)||p(z))
loss_z = -0.5 * tf.reduce_sum(1.0 + 2.0 * z_log_sigma - tf.square(z_mu) -
tf.exp(2.0 * z_log_sigma), 1)
return z, z_mu, z_log_sigma, loss_z
[docs]def discriminator(x,
convolutional=True,
filter_sizes=[5, 5, 5, 5],
activation=tf.nn.relu,
n_filters=[100, 100, 100, 100]):
"""Summary
Parameters
----------
x : TYPE
Description
convolutional : bool, optional
Description
filter_sizes : list, optional
Description
activation : TYPE, optional
Description
n_filters : list, optional
Description
Returns
-------
name : TYPE
Description
"""
encoding = encoder(
x=x,
convolutional=convolutional,
dimensions=n_filters,
filter_sizes=filter_sizes,
activation=activation)
# flatten, then linear to 1 value
res = utils.flatten(encoding['z'], name='flatten')
if res.get_shape().as_list()[-1] > 1:
res = utils.linear(res, 1)[0]
return {
'logits': res,
'probs': tf.nn.sigmoid(res),
'Ws': encoding['Ws'],
'hs': encoding['hs']
}
[docs]def VAE(input_shape=[None, 784],
n_filters=[64, 64, 64],
filter_sizes=[4, 4, 4],
n_hidden=32,
n_code=2,
activation=tf.nn.tanh,
convolutional=False,
variational=False):
"""Summary
Parameters
----------
input_shape : list, optional
Description
n_filters : list, optional
Description
filter_sizes : list, optional
Description
n_hidden : int, optional
Description
n_code : int, optional
Description
activation : TYPE, optional
Description
convolutional : bool, optional
Description
variational : bool, optional
Description
Returns
-------
name : TYPE
Description
"""
# network input / placeholders for train (bn)
x = tf.placeholder(tf.float32, input_shape, 'x')
with tf.variable_scope('encoder'):
encoding = encoder(
x=x,
n_hidden=n_hidden,
convolutional=convolutional,
dimensions=n_filters,
filter_sizes=filter_sizes,
activation=activation)
if variational:
with tf.variable_scope('variational'):
z, z_mu, z_log_sigma, loss_z = variational_bayes(
h=encoding['z'], n_code=n_code)
else:
z = encoding['z']
loss_z = None
shapes = encoding['shapes'].copy()
shapes.reverse()
n_filters = n_filters.copy()
n_filters.reverse()
n_filters += [input_shape[-1]]
with tf.variable_scope('generator'):
decoding = decoder(
z=z,
shapes=shapes,
n_hidden=n_hidden,
dimensions=n_filters,
filter_sizes=filter_sizes,
convolutional=convolutional,
activation=activation)
x_tilde = decoding['x_tilde']
x_flat = utils.flatten(x)
x_tilde_flat = utils.flatten(x_tilde)
# -log(p(x|z))
loss_x = tf.reduce_sum(tf.squared_difference(x_flat, x_tilde_flat), 1)
return {
'loss_x': loss_x,
'loss_z': loss_z,
'x': x,
'z': z,
'Ws': encoding['Ws'],
'hs': decoding['hs'],
'x_tilde': x_tilde
}
[docs]def VAEGAN(input_shape=[None, 784],
n_filters=[64, 64, 64],
filter_sizes=[4, 4, 4],
n_hidden=32,
n_code=2,
activation=tf.nn.tanh,
convolutional=False,
variational=False):
"""Summary
Parameters
----------
input_shape : list, optional
Description
n_filters : list, optional
Description
filter_sizes : list, optional
Description
n_hidden : int, optional
Description
n_code : int, optional
Description
activation : TYPE, optional
Description
convolutional : bool, optional
Description
variational : bool, optional
Description
Returns
-------
name : TYPE
Description
"""
# network input / placeholders for train (bn)
x = tf.placeholder(tf.float32, input_shape, 'x')
z_samp = tf.placeholder(tf.float32, [None, n_code], 'z_samp')
with tf.variable_scope('encoder'):
encoding = encoder(
x=x,
n_hidden=n_hidden,
convolutional=convolutional,
dimensions=n_filters,
filter_sizes=filter_sizes,
activation=activation)
with tf.variable_scope('variational'):
z, z_mu, z_log_sigma, loss_z = variational_bayes(
h=encoding['z'], n_code=n_code)
shapes = encoding['shapes'].copy()
shapes.reverse()
n_filters_decoder = n_filters.copy()
n_filters_decoder.reverse()
n_filters_decoder += [input_shape[-1]]
with tf.variable_scope('generator'):
decoding_actual = decoder(
z=z,
shapes=shapes,
n_hidden=n_hidden,
convolutional=convolutional,
dimensions=n_filters_decoder,
filter_sizes=filter_sizes,
activation=activation)
with tf.variable_scope('generator', reuse=True):
decoding_sampled = decoder(
z=z_samp,
shapes=shapes,
n_hidden=n_hidden,
convolutional=convolutional,
dimensions=n_filters_decoder,
filter_sizes=filter_sizes,
activation=activation)
with tf.variable_scope('discriminator'):
D_real = discriminator(
x,
filter_sizes=filter_sizes,
n_filters=n_filters,
activation=activation)
with tf.variable_scope('discriminator', reuse=True):
D_fake = discriminator(
decoding_actual['x_tilde'],
filter_sizes=filter_sizes,
n_filters=n_filters,
activation=activation)
with tf.variable_scope('discriminator', reuse=True):
D_samp = discriminator(
decoding_sampled['x_tilde'],
filter_sizes=filter_sizes,
n_filters=n_filters,
activation=activation)
with tf.variable_scope('loss'):
# Weights influence of content/style of decoder
gamma = tf.placeholder(tf.float32, name='gamma')
# Discriminator_l Log Likelihood Loss
loss_D_llike = 0
for h_fake, h_real in zip(D_fake['hs'][3:], D_real['hs'][3:]):
loss_D_llike += tf.reduce_sum(0.5 * tf.squared_difference(
utils.flatten(h_fake), utils.flatten(h_real)), 1)
# GAN Loss
eps = 1e-12
loss_real = tf.reduce_sum(tf.log(D_real['probs'] + eps), 1)
loss_fake = tf.reduce_sum(tf.log(1 - D_fake['probs'] + eps), 1)
loss_samp = tf.reduce_sum(tf.log(1 - D_samp['probs'] + eps), 1)
loss_GAN = (loss_real + loss_fake + loss_samp) / 3.0
loss_enc = tf.reduce_mean(loss_z + loss_D_llike)
loss_gen = tf.reduce_mean(gamma * loss_D_llike - loss_GAN)
loss_dis = -tf.reduce_mean(loss_GAN)
return {
'x': x,
'z': z,
'x_tilde': decoding_actual['x_tilde'],
'z_samp': z_samp,
'x_tilde_samp': decoding_sampled['x_tilde'],
'loss_real': loss_real,
'loss_fake': loss_fake,
'loss_samp': loss_samp,
'loss_GAN': loss_GAN,
'loss_D_llike': loss_D_llike,
'loss_enc': loss_enc,
'loss_gen': loss_gen,
'loss_dis': loss_dis,
'gamma': gamma
}
[docs]def train_vaegan(files,
learning_rate=0.00001,
batch_size=64,
n_epochs=250,
n_examples=10,
input_shape=[218, 178, 3],
crop_shape=[64, 64, 3],
crop_factor=0.8,
n_filters=[100, 100, 100, 100],
n_hidden=None,
n_code=128,
convolutional=True,
variational=True,
filter_sizes=[3, 3, 3, 3],
activation=tf.nn.elu,
ckpt_name="vaegan.ckpt"):
"""Summary
Parameters
----------
files : TYPE
Description
learning_rate : float, optional
Description
batch_size : int, optional
Description
n_epochs : int, optional
Description
n_examples : int, optional
Description
input_shape : list, optional
Description
crop_shape : list, optional
Description
crop_factor : float, optional
Description
n_filters : list, optional
Description
n_hidden : int, optional
Description
n_code : int, optional
Description
convolutional : bool, optional
Description
variational : bool, optional
Description
filter_sizes : list, optional
Description
activation : TYPE, optional
Description
ckpt_name : str, optional
Description
No Longer Returned
------------------
name : TYPE
Description
"""
ae = VAEGAN(
input_shape=[None] + crop_shape,
convolutional=convolutional,
variational=variational,
n_filters=n_filters,
n_hidden=n_hidden,
n_code=n_code,
filter_sizes=filter_sizes,
activation=activation)
batch = create_input_pipeline(
files=files,
batch_size=batch_size,
n_epochs=n_epochs,
crop_shape=crop_shape,
crop_factor=crop_factor,
shape=input_shape)
zs = np.random.randn(4, n_code).astype(np.float32)
zs = utils.make_latent_manifold(zs, n_examples)
opt_enc = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
ae['loss_enc'],
var_list=[
var_i for var_i in tf.trainable_variables()
if var_i.name.startswith('encoder')
])
opt_gen = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
ae['loss_gen'],
var_list=[
var_i for var_i in tf.trainable_variables()
if var_i.name.startswith('generator')
])
opt_dis = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
ae['loss_dis'],
var_list=[
var_i for var_i in tf.trainable_variables()
if var_i.name.startswith('discriminator')
])
sess = tf.Session()
saver = tf.train.Saver()
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
sess.run(init_op)
coord = tf.train.Coordinator()
tf.get_default_graph().finalize()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
saver.restore(sess, ckpt_name)
print("VAE model restored.")
t_i = 0
batch_i = 0
epoch_i = 0
equilibrium = 0.693
margin = 0.4
n_files = len(files)
test_xs = sess.run(batch) / 255.0
utils.montage(test_xs, 'test_xs.png')
try:
while not coord.should_stop() and epoch_i < n_epochs:
if batch_i % (n_files // batch_size) == 0:
batch_i = 0
epoch_i += 1
print('---------- EPOCH:', epoch_i)
batch_i += 1
batch_xs = sess.run(batch) / 255.0
batch_zs = np.random.randn(batch_size, n_code).astype(np.float32)
real_cost, fake_cost, _ = sess.run(
[ae['loss_real'], ae['loss_fake'], opt_enc],
feed_dict={ae['x']: batch_xs,
ae['gamma']: 0.5})
real_cost = -np.mean(real_cost)
fake_cost = -np.mean(fake_cost)
print('real:', real_cost, '/ fake:', fake_cost)
gen_update = True
dis_update = True
if real_cost > (equilibrium + margin) or \
fake_cost > (equilibrium + margin):
gen_update = False
if real_cost < (equilibrium - margin) or \
fake_cost < (equilibrium - margin):
dis_update = False
if not (gen_update or dis_update):
gen_update = True
dis_update = True
if gen_update:
sess.run(
opt_gen,
feed_dict={
ae['x']: batch_xs,
ae['z_samp']: batch_zs,
ae['gamma']: 0.5
})
if dis_update:
sess.run(
opt_dis,
feed_dict={
ae['x']: batch_xs,
ae['z_samp']: batch_zs,
ae['gamma']: 0.5
})
if batch_i % 50 == 0:
# Plot example reconstructions from latent layer
recon = sess.run(ae['x_tilde'], feed_dict={ae['z']: zs})
print('recon:', recon.min(), recon.max())
recon = np.clip(recon / recon.max(), 0, 1)
utils.montage(
recon.reshape([-1] + crop_shape),
'imgs/manifold_%08d.png' % t_i)
# Plot example reconstructions
recon = sess.run(ae['x_tilde'], feed_dict={ae['x']: test_xs})
print('recon:', recon.min(), recon.max())
recon = np.clip(recon / recon.max(), 0, 1)
utils.montage(
recon.reshape([-1] + crop_shape),
'imgs/reconstruction_%08d.png' % t_i)
t_i += 1
if batch_i % 100 == 0:
# Save the variables to disk.
save_path = saver.save(
sess,
ckpt_name,
global_step=batch_i,
write_meta_graph=False)
print("Model saved in file: %s" % save_path)
except tf.errors.OutOfRangeError:
print('Done training -- epoch limit reached')
finally:
# One of the threads has issued an exception. So let's tell all the
# threads to shutdown.
coord.request_stop()
# Wait until all threads have finished.
coord.join(threads)
# Clean up the session.
sess.close()
[docs]def test_celeb(n_epochs=100,
filter_sizes=[3, 3, 3, 3],
n_filters=[100, 100, 100, 100],
crop_shape=[100, 100, 3]):
"""Summary
Parameters
----------
n_epochs : int, optional
Description
No Longer Returned
------------------
name : TYPE
Description
"""
files = CELEB()
train_vaegan(
files=files,
batch_size=64,
n_epochs=n_epochs,
crop_shape=crop_shape,
crop_factor=0.8,
input_shape=[218, 178, 3],
convolutional=True,
variational=True,
n_filters=n_filters,
n_hidden=None,
n_code=64,
filter_sizes=filter_sizes,
activation=tf.nn.elu,
ckpt_name='./celeb.ckpt')
[docs]def test_sita(n_epochs=100):
"""Summary
Parameters
----------
n_epochs : int, optional
Description
No Longer Returned
------------------
name : TYPE
Description
"""
if not os.path.exists('sita'):
os.system(
'wget http://ossguy.com/sita/Sita_Sings_the_Blues_640x360_XviD.avi')
os.mkdir('sita')
os.system('ffmpeg -i Sita_Sings_the_Blues_640x360_XviD.avi -r 60 -f' +
' image2 -s 160x90 sita/sita-%08d.jpg')
files = [os.path.join('sita', f) for f in os.listdir('sita')]
train_vaegan(
files=files,
batch_size=64,
n_epochs=n_epochs,
crop_shape=[90, 160, 3],
crop_factor=1.0,
input_shape=[218, 178, 3],
convolutional=True,
variational=True,
n_filters=[100, 100, 100, 100, 100],
n_hidden=250,
n_code=100,
filter_sizes=[3, 3, 3, 3, 2],
activation=tf.nn.elu,
ckpt_name='./sita.ckpt')
if __name__ == '__main__':
test_celeb()