# Source code for cadl.pixelcnn

"""Conditional Gated Pixel CNN.
Thanks to many reference implementations

Copyright 2017 Parag K. Mital.  See also

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import tensorflow as tf
import numpy as np
import os
from cadl import dataset_utils as dsu

def gated_conv2d(X,
                 K_h,
                 K_w,
                 K_c,
                 strides=(1, 1, 1, 1),
                 padding='SAME',
                 mask=None,
                 cond_h=None,
                 vertical_h=None):
    """Masked 2D convolution followed by a tanh/sigmoid gate.

    A single convolution produces `2 * K_c` feature maps.  Optional
    contributions from another stack (`vertical_h`) and from a conditioning
    vector (`cond_h`) are added to these pre-activations, which are then split
    in half along the channel axis and combined as `tanh(h_f) * sigmoid(h_g)`.

    Parameters
    ----------
    X : tf.Tensor
        Input to the convolution; its last axis is used as the number of
        input channels.
    K_h : int
        Kernel height.
    K_w : int
        Kernel width.
    K_c : int
        Number of output channels of the gated activation.  The convolution
        itself produces `2 * K_c` channels.
    strides : sequence of int, optional
        Strides passed to `tf.nn.conv2d` (also used for the 1x1 convolution
        applied to `vertical_h`).
    padding : str, optional
        Padding mode passed to `tf.nn.conv2d`.
    mask : array-like, optional
        If given, multiplied element-wise (with broadcasting) into the kernel
        before convolving.
    cond_h : tf.Tensor, optional
        2-D conditioning tensor (one row per example).  It is linearly
        projected to `2 * K_c` values and added at every spatial position.
    vertical_h : tf.Tensor, optional
        Pre-activations (with `2 * K_c` channels) from another call to this
        function; passed through a 1x1 convolution and added to this layer's
        pre-activations.

    Returns
    -------
    y : tf.Tensor
        Gated activation with `K_c` channels.
    h : tf.Tensor
        Pre-activations with `2 * K_c` channels, before gating.
    """
    # Accept any sequence for strides but always hand a list to TensorFlow.
    strides = list(strides)

    with tf.variable_scope('masked_cnn'):
        W = tf.get_variable(
            name='W',
            shape=[K_h, K_w, X.shape[-1].value, K_c * 2],
            initializer=tf.contrib.layers.xavier_initializer_conv2d())
        b = tf.get_variable(
            name='b', shape=[K_c * 2], initializer=tf.zeros_initializer())
        if mask is not None:
            W = tf.multiply(mask, W)
        # Initial convolution with masked kernel
        h = tf.nn.bias_add(
            tf.nn.conv2d(X, W, strides=strides, padding=padding), b)

    # Combine the vertical stack's pre-activations into our hidden embedding
    # before applying the split nonlinearities.  Check Figure 2 of the paper
    # for details.
    if vertical_h is not None:
        with tf.variable_scope('vtoh'):
            W_vtoh = tf.get_variable(
                name='W',
                shape=[1, 1, K_c * 2, K_c * 2],
                initializer=tf.contrib.layers.xavier_initializer_conv2d())
            b_vtoh = tf.get_variable(
                name='b', shape=[K_c * 2], initializer=tf.zeros_initializer())
            h = tf.add(h,
                       tf.nn.bias_add(
                           tf.nn.conv2d(
                               vertical_h,
                               W_vtoh,
                               strides=strides,
                               padding=padding), b_vtoh))

    # Condition on some given data
    if cond_h is not None:
        with tf.variable_scope('conditioning'):
            # Project to 2 * K_c so both halves of the gate (tanh and sigmoid)
            # receive a conditioning term, matching the channel count of `h`.
            V = tf.get_variable(
                name='V',
                shape=[cond_h.shape[1].value, K_c * 2],
                initializer=tf.contrib.layers.xavier_initializer_conv2d())
            b_cond = tf.get_variable(
                name='b', shape=[K_c * 2], initializer=tf.zeros_initializer())
            cond_proj = tf.nn.bias_add(tf.matmul(cond_h, V), b_cond)
            # Shape [batch, 1, 1, 2 * K_c] broadcasts over the spatial axes.
            # The previous target shape `tf.shape(X)[0:3] + [K_c]` added K_c
            # element-wise instead of concatenating and could never match.
            h = tf.add(
                h, tf.reshape(cond_proj, [-1, 1, 1, K_c * 2]), name='h')

    with tf.variable_scope('gated_cnn'):
        # Finally slice and apply gated multiplier
        h_f = tf.slice(h, [0, 0, 0, 0], [-1, -1, -1, K_c])
        h_g = tf.slice(h, [0, 0, 0, K_c], [-1, -1, -1, K_c])
        y = tf.multiply(tf.nn.tanh(h_f), tf.sigmoid(h_g))

    return y, h
def _pixelcnn_mask(K_h, K_w, hide_center):
    """Return a (K_h, K_w, 1, 1) float32 kernel mask of ones and zeros.

    Every row below the center row is zeroed.  On the center row, columns
    from the center onwards are zeroed when `hide_center` is true, otherwise
    only the columns strictly after the center.
    """
    center_row = K_h // 2
    center_col = K_w // 2
    first_hidden_col = center_col if hide_center else center_col + 1

    mask = np.ones((K_h, K_w, 1, 1), dtype=np.float32)
    mask[center_row + 1:, :, :, :] = 0.0
    mask[center_row, first_hidden_col:, :, :] = 0.0
    return mask


def _conv_1x1(inputs, n_in, n_out):
    """Apply a 1x1 convolution plus bias using variables `W` and `b` created
    in the current variable scope (bias initialized to ones)."""
    weights = tf.get_variable(
        name='W',
        shape=[1, 1, n_in, n_out],
        initializer=tf.contrib.layers.xavier_initializer_conv2d())
    bias = tf.get_variable(
        name='b', shape=[n_out], initializer=tf.ones_initializer())
    convolved = tf.nn.conv2d(
        inputs, weights, strides=[1, 1, 1, 1], padding='SAME')
    return tf.nn.bias_add(convolved, bias)


def build_conditional_pixel_cnn_model(B=None,
                                      H=32,
                                      W=32,
                                      C=3,
                                      n_conditionals=None):
    """Conditional Gated Pixel CNN Model.

    From the paper
    --------------
        van den Oord, A., Kalchbrenner, N., Vinyals, O., Espeholt, L., Graves,
        A., & Kavukcuoglu, K. (2016). Conditional Image Generation with
        PixelCNN Decoders.

    Implements most of the paper, except for the autoencoder, triplet loss of
    face embeddings, and pad/crop/shift ops for convolution (as it is not as
    clear IMO from a pedagogical point of view).

    Parameters
    ----------
    B : None, optional
        Not used when building the graph; the batch axis of the input
        placeholder is left unspecified.
    H : int, optional
        Image height of the input placeholder.
    W : int, optional
        Image width of the input placeholder.
    C : int, optional
        Number of image channels of the input placeholder.
    n_conditionals : None, optional
        Currently ignored: it is reset to None at the top of the function, so
        no conditioning placeholder is ever created.

    Returns
    -------
    dict
        Graph endpoints under the keys 'cost' (mean cross entropy), 'X'
        (uint8 input placeholder), 'preds' (softmax over 256 values),
        'sampled_preds' (one multinomial sample per logit row) and
        'summaries' (merged summary op).
    """
    # NOTE(review): this discards the caller's value, so the conditioning
    # branch below is never taken.
    n_conditionals = None

    X = tf.placeholder(name='X', dtype=tf.uint8, shape=[None, H, W, C])
    X_ = (tf.cast(X, tf.float32) - 127.5) / 2.0

    n_layers = 10
    D = 256
    fmaps = 64

    # First layer uses a 7-wide kernel, the remaining layers a 3-wide one.
    # Widths are scaled by the channel count.
    K_hs = [7] + [3] * (n_layers - 1)
    K_ws = [7 * C] + [3 * C] * (n_layers - 1)
    K_cs = [fmaps] * n_layers

    cond_h = None
    if n_conditionals is not None:
        cond_h = tf.placeholder(
            name='cond_h', dtype=tf.float32, shape=[None, n_conditionals])

    vertical_X = X_
    horizontal_X = X_
    for layer_i, (K_h, K_w, K_c) in enumerate(zip(K_hs, K_ws, K_cs)):
        # Two kinds of mask: type (a) for the first layer, which also hides
        # the center position, and type (b) for every other layer.  See
        # Figure 2 of Pixel Recurrent Neural Networks for more info.
        mask = _pixelcnn_mask(K_h, K_w, hide_center=(layer_i == 0))

        with tf.variable_scope('layer_{}'.format(layer_i)):
            # Vertical stack
            with tf.variable_scope('vertical'):
                vertical_Y, vertical_h = gated_conv2d(
                    vertical_X, K_h, K_w, K_c, mask=mask, cond_h=cond_h)

            # Horizontal stack
            with tf.variable_scope('horizontal'):
                # Gated convolution over the center row of the mask, with the
                # vertical stack's pre-activations mixed in.
                horizontal_Y, horizontal_h = gated_conv2d(
                    horizontal_X,
                    1,
                    K_w,
                    K_c,
                    mask=mask[K_h // 2, :, :, :],
                    vertical_h=vertical_h,
                    cond_h=cond_h)

                # 1x1 convolution mapping K_c channels to D channels
                with tf.variable_scope('1x1'):
                    horizontal_Y = _conv_1x1(horizontal_Y, K_c, D)

                # Residual connection (skipped for the first layer)
                if layer_i > 0:
                    with tf.variable_scope('residual'):
                        horizontal_Y = tf.add(horizontal_X, horizontal_Y)

            vertical_X = vertical_Y
            horizontal_X = horizontal_Y

    # Output head: 1x1 conv + ReLU, then a 1x1 conv to D logits per channel.
    Y = horizontal_X
    with tf.variable_scope('output/1x1_1'):
        Y = tf.nn.relu(_conv_1x1(Y, D, D))
    with tf.variable_scope('output/1x1_2'):
        Y = _conv_1x1(Y, D, D * C)
        # One row of D logits per (example, row, column, channel) value.
        Y = tf.reshape(Y, [-1, D])

    labels = tf.cast(tf.reshape(X, [-1]), tf.int32)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=Y, labels=labels)
    cost = tf.reduce_mean(loss)
    preds = tf.nn.softmax(Y)
    sampled_preds = tf.multinomial(Y, num_samples=1)

    tf.summary.image('actual', X)
    most_likely = tf.cast(tf.argmax(Y, axis=1), tf.uint8)
    tf.summary.image('preds', tf.reshape(most_likely, (-1, H, W, C)))
    tf.summary.histogram('loss', loss)
    tf.summary.scalar('cost', cost)
    summaries = tf.summary.merge_all()

    return {
        'cost': cost,
        'X': X,
        'preds': preds,
        'sampled_preds': sampled_preds,
        'summaries': summaries
    }
def train_tiny_imagenet(ckpt_path='pixelcnn',
                        n_epochs=1000,
                        save_step=100,
                        write_step=25,
                        B=32,
                        H=64,
                        W=64,
                        C=3):
    """Train the gated Pixel CNN on Tiny ImageNet.

    Builds the model and an Adam optimizer, streams batches from a queued
    input pipeline, prints the cost for every batch, and periodically writes
    summaries and checkpoints to `ckpt_path`.  Training resumes from an
    existing checkpoint in `ckpt_path` when one is found.

    Parameters
    ----------
    ckpt_path : str, optional
        Directory used for both summaries and checkpoints.
    n_epochs : int, optional
        Number of epochs requested from the input pipeline.
    save_step : int, optional
        Save a checkpoint every `save_step` batches.
    write_step : int, optional
        Write summaries every `write_step` batches.
    B : int, optional
        Batch size.
    H : int, optional
        Crop height, and height of the model's input.
    W : int, optional
        Crop width, and width of the model's input.
    C : int, optional
        Number of image channels.
    """
    ckpt_name = os.path.join(ckpt_path, 'pixelcnn.ckpt')

    with tf.Graph().as_default(), tf.Session() as sess:
        # Not actually conditioning on anything here just using the gated cnn
        # model
        net = build_conditional_pixel_cnn_model(B=B, H=H, W=W, C=C)

        # build the optimizer (this will take a while!)
        optimizer = tf.train.AdamOptimizer(
            learning_rate=0.001).minimize(net['cost'])

        # Load a list of files for tiny imagenet, downloading if necessary
        imagenet_files = dsu.tiny_imagenet_load()

        # Create a threaded image pipeline which will load/shuffle/crop/resize
        batch = dsu.create_input_pipeline(
            imagenet_files[0],
            batch_size=B,
            n_epochs=n_epochs,
            shape=[64, 64, 3],
            crop_shape=[H, W, C],
            crop_factor=1.0,
            n_threads=8)

        saver = tf.train.Saver()
        writer = tf.summary.FileWriter(ckpt_path)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        # This will handle our threaded image pipeline
        coord = tf.train.Coordinator()

        # Ensure no more changes to graph
        tf.get_default_graph().finalize()

        # Start up the queues for handling the image pipeline
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Resume from a previous run if possible.  `latest_checkpoint` returns
        # None when nothing has been saved yet, which `restore` rejects, so a
        # fresh run must skip the restore.
        if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
            saver.restore(sess, ckpt_name)
        latest_ckpt = tf.train.latest_checkpoint(ckpt_path)
        if latest_ckpt is not None:
            saver.restore(sess, latest_ckpt)

        # NOTE: `epoch_i` is never advanced, so in practice the loop ends when
        # the input pipeline raises OutOfRangeError or the coordinator stops.
        epoch_i = 0
        batch_i = 0
        try:
            while not coord.should_stop() and epoch_i < n_epochs:
                batch_i += 1
                batch_xs = sess.run(batch)
                train_cost = sess.run(
                    [net['cost'], optimizer],
                    feed_dict={net['X']: batch_xs})[0]
                print(batch_i, train_cost)

                if batch_i % write_step == 0:
                    summary = sess.run(
                        net['summaries'], feed_dict={net['X']: batch_xs})
                    writer.add_summary(summary, batch_i)

                if batch_i % save_step == 0:
                    # Save the variables to disk.  The meta graph is written
                    # as well; it could be skipped since the code can rebuild
                    # the graph, and writing it is slow for a graph this deep.
                    saver.save(
                        sess,
                        ckpt_name,
                        global_step=batch_i,
                        write_meta_graph=True)
        except tf.errors.OutOfRangeError:
            print('Done.')
        finally:
            # One of the threads has issued an exception.  So let's tell all
            # the threads to shutdown.
            coord.request_stop()
            # Wait until all threads have finished.
            coord.join(threads)
            # Flush any pending summaries to disk.
            writer.close()
def generate():
    """Synthesize the bottom half of a Tiny ImageNet image with a trained model.

    Restores the latest checkpoint from the 'pixelcnn' directory, loads one
    image, zeroes its bottom half, then re-samples that half one value at a
    time (row by row, column by column, channel by channel), feeding the
    partially completed image back into the network at every step.  The
    original and the synthesized image are drawn side by side with
    matplotlib.

    Raises
    ------
    ValueError
        If no checkpoint can be found in the checkpoint directory.
    """
    # Parameters for generation
    ckpt_path = 'pixelcnn'
    B = None
    H = 64
    W = 64
    C = 3

    with tf.Graph().as_default(), tf.Session() as sess:
        # Not actually conditioning on anything here just using the gated cnn
        # model
        net = build_conditional_pixel_cnn_model(B=B, H=H, W=W, C=C)

        # Load a list of files for tiny imagenet, downloading if necessary
        imagenet_files = dsu.tiny_imagenet_load()

        saver = tf.train.Saver()
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        latest_ckpt = tf.train.latest_checkpoint(ckpt_path)
        if latest_ckpt is None:
            # Fail with an actionable message instead of passing None on to
            # `saver.restore`.
            raise ValueError(
                'No checkpoint found in {!r}; train the model first.'.format(
                    ckpt_path))
        saver.restore(sess, latest_ckpt)

        import matplotlib.pyplot as plt
        img = plt.imread(imagenet_files[0][1000])
        # NOTE: `scipy.misc.imresize` was removed in SciPy 1.3, so this import
        # requires an older SciPy release.
        from scipy.misc import imresize
        og_img = imresize(img, (H, W))
        img = og_img.copy()

        # Zero out bottom half of image and let's try to synthesize it
        img[H // 2:, :, :] = 0
        for h_i in range(H // 2, H):
            for w_i in range(W):
                for c_i in range(C):
                    print(h_i, w_i, c_i, end='\r')
                    # Sample every value given the image so far, but keep only
                    # the one at the position currently being generated.
                    preds = sess.run(
                        net['sampled_preds'],
                        feed_dict={net['X']: img[np.newaxis]})
                    sampled = preds.reshape((1, H, W, C)).astype(np.uint8)
                    img[h_i, w_i, c_i] = sampled[0, h_i, w_i, c_i]

        fig, axs = plt.subplots(1, 2)
        axs[0].imshow(og_img)
        axs[1].imshow(img)
# Run training with the default settings when executed as a script.
if __name__ == '__main__':
    train_tiny_imagenet()