# Source code for cadl.pixelrnn

"""Basic PixelRNN i.e. CharRNN style, none of the fancy ones (i.e. Row, Diag, BiDiag).
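Copyright 2017 Parag K. Mital.  See also NOTICE.md.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Attributes
----------
B : int
    Batch size used for training.
C : int
    Number of image channels.
ckpt_name : str
    Path prefix used when saving and restoring checkpoints.
H : int
    Image height in pixels.
n_epochs : int
    Number of passes over the training data.
n_units : int
    Number of recurrent (GRU) units per layer.
W : int
    Image width in pixels.
"""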
import os
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from cadl import dataset_utils as dsu

# Training hyper-parameters read by train_tiny_imagenet.
ckpt_name = 'pixelrnn.ckpt'  # checkpoint path prefix handed to tf.train.Saver
n_epochs = 10                # epochs requested from the input pipeline
n_units = 100                # recurrent units per GRU layer
# Batch size, then image height / width / channels.
B, H, W, C = 50, 32, 32, 3

def build_pixel_rnn_basic_model(B=50, H=32, W=32, C=32, n_units=100, n_layers=2):
    """Build a basic (CharRNN-style) PixelRNN graph.

    Each image is flattened into a sequence of ``H * W * C`` integer values.
    Each value is one-hot encoded over 256 levels and fed one step at a time
    through a (possibly stacked) GRU. A linear layer maps every hidden state
    to 256 logits that predict the *next* value of the same image.

    Parameters
    ----------
    B : int, optional
        Batch size. The input placeholder has a static batch dimension.
    H : int, optional
        Image height.
    W : int, optional
        Image width.
    C : int, optional
        Number of channels. NOTE(review): the default of 32 differs from the
        module-level ``C = 3``; pass ``C`` explicitly for RGB images.
    n_units : int, optional
        Number of units in each GRU layer.
    n_layers : int, optional
        Number of stacked GRU layers.

    Returns
    -------
    dict
        'X' : float32 placeholder of shape [B, H, W, C]. Values are cast to
            uint8 before one-hot encoding, so they should lie in [0, 255].
        'keep_prob' : scalar output-dropout keep probability for the GRU.
            Defaults to 1.0 (no dropout) when not fed.
        'recon' : logits. When ``H * W * C > 1`` these are the predictions
            made at every position except the last of each image, shaped
            [B * (H * W * C - 1), 256]. Otherwise all logits, shaped [B, 256].
        'cost' : mean softmax cross-entropy of next-value prediction, or
            None when each image holds a single value.
        'initial_state' : RNN zero state (feedable).
        'final_state' : RNN state after the last step.
    """
    n_steps = H * W * C

    # Input to the network, a batch of images.
    X = tf.placeholder(tf.float32, shape=[B, H, W, C], name='X')

    # A default lets callers that never feed it (inference, a plain training
    # step) still run. Feed a value < 1 to enable dropout.
    keep_prob = tf.placeholder_with_default(1.0, shape=[], name='keep_prob')

    # Flatten each image to one sequence: B x n_steps.
    X_2d = tf.reshape(X, [-1, n_steps])
    # One-hot encode every value: B x n_steps x 256.
    X_onehot = tf.one_hot(tf.cast(X_2d, tf.uint8), depth=256, axis=2)

    # One B x 256 tensor per time step, as static_rnn expects.
    pixels = [
        tf.squeeze(p, axis=1) for p in tf.split(X_onehot, n_steps, axis=1)
    ]

    # Each layer needs its own cell instance. Reusing one instance would try
    # to share a single weight set between layers whose inputs differ in size
    # (256 for the first layer, n_units for the others).
    if n_layers > 1:
        cells = tf.contrib.rnn.MultiRNNCell(
            [tf.contrib.rnn.GRUCell(n_units) for _ in range(n_layers)],
            state_is_tuple=True)
    else:
        cells = tf.contrib.rnn.GRUCell(n_units)
    initial_state = cells.zero_state(tf.shape(X)[0], tf.float32)
    cells = tf.contrib.rnn.DropoutWrapper(cells, output_keep_prob=keep_prob)

    # One hidden-state tensor (B x n_units) per time step.
    hs, final_state = tf.contrib.rnn.static_rnn(
        cells, pixels, initial_state=initial_state)

    # B x n_steps x n_units, then 2-D so it can feed the linear layer.
    stacked = tf.stack(hs, axis=1)
    logits = slim.linear(
        tf.reshape(stacked, [-1, n_units]), 256, scope='linear')

    if n_steps > 1:
        # Shift *within* each image: the output at step t is scored against
        # the value at step t + 1 of the same image. This never pairs the
        # last value of one image with the first value of the next.
        logits_3d = tf.reshape(logits, [-1, n_steps, 256])
        prediction = tf.reshape(logits_3d[:, :-1, :], [-1, 256])
        targets = tf.reshape(X_onehot[:, 1:, :], [-1, 256])
        loss = tf.nn.softmax_cross_entropy_with_logits(
            labels=targets, logits=prediction)
        cost = tf.reduce_mean(loss)
    else:
        # A single value per image has no "next" value to predict.
        prediction = logits
        cost = None

    return {
        'X': X,
        'keep_prob': keep_prob,
        'recon': prediction,
        'cost': cost,
        'initial_state': initial_state,
        'final_state': final_state
    }
def infer(sess, net, H, W, C, pixel_value=128, state=None):
    """Greedily synthesise ``H * W * C`` values, one step at a time.

    Starting from ``pixel_value``, each step feeds the previously generated
    value through the RNN and takes the argmax of the returned logits as the
    next value.

    Parameters
    ----------
    sess : session-like
        Object whose ``run(fetches, feed_dict=...)`` evaluates the graph,
        presumably a ``tf.Session`` holding trained variables.
    net : dict
        Graph returned by ``build_pixel_rnn_basic_model``. Each step feeds a
        [1, 1, 1, 1] array to ``net['X']``, so the graph must have been built
        with ``B=H=W=C=1``.
    H, W, C : int
        Height, width and channels of the image to synthesise. Only their
        product (the number of values) is used.
    pixel_value : int, optional
        Seed value, returned as the first element.
    state : optional
        RNN state to start from. Defaults to ``net['initial_state']``
        evaluated in ``sess``.

    Returns
    -------
    list
        ``H * W * C`` values: the seed followed by one argmax per step.
    """
    synthesis = [pixel_value]
    if state is None:
        state = sess.run(net['initial_state'])
    current = pixel_value
    for _ in range(H * W * C - 1):
        # Autoregressive: feed back the value produced at the previous step
        # rather than the seed.
        X = np.reshape(current, [1, 1, 1, 1])
        logits, state = sess.run(
            [net['recon'], net['final_state']],
            feed_dict={net['X']: X, net['initial_state']: state})
        current = np.argmax(logits)
        synthesis.append(current)
    return synthesis
def train_tiny_imagenet():
    """Train the basic PixelRNN on random crops of Tiny ImageNet.

    Uses the module-level hyper-parameters (``B``, ``H``, ``W``, ``C``,
    ``n_units``, ``n_epochs``, ``ckpt_name``). Resumes from the latest
    checkpoint recorded next to ``ckpt_name`` if one exists. Saves variables
    (without the meta graph) every 100 batches and prints the cost of every
    batch.
    """
    # Build with the module constants so the placeholder shape matches the
    # [H, W, C] crops produced by the input pipeline below.
    net = build_pixel_rnn_basic_model(B=B, H=H, W=W, C=C, n_units=n_units)

    # Build the optimizer (this will take a while!).
    optimizer = tf.train.AdamOptimizer(
        learning_rate=0.001).minimize(net['cost'])

    # Load a list of files for Tiny ImageNet, downloading if necessary.
    imagenet_files = dsu.tiny_imagenet_load()

    # Create a threaded image pipeline that loads, shuffles, crops and resizes.
    batch = dsu.create_input_pipeline(
        imagenet_files,
        batch_size=B,
        n_epochs=n_epochs,
        shape=[64, 64, 3],
        crop_shape=[H, W, C],
        crop_factor=0.5,
        n_threads=8)

    sess = tf.Session()
    saver = tf.train.Saver()
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    # This will handle our threaded image pipeline.
    coord = tf.train.Coordinator()

    # Ensure no more changes to the graph.
    tf.get_default_graph().finalize()

    # Start the queues that feed the image pipeline.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # saver.save(..., global_step=...) writes "<ckpt_name>-<step>" files plus
    # a "checkpoint" index in the same directory. A plain "<ckpt_name>.index"
    # is never written, so look up the latest checkpoint instead, and only
    # restore when one exists (latest_checkpoint returns None otherwise).
    ckpt_dir = os.path.dirname(ckpt_name) or '.'
    latest = tf.train.latest_checkpoint(ckpt_dir)
    if latest:
        saver.restore(sess, latest)

    batch_i = 0
    save_step = 100
    try:
        # The pipeline is created with n_epochs, so the end of the data is
        # expected to surface as OutOfRangeError (handled below).
        while not coord.should_stop():
            batch_i += 1
            batch_xs = sess.run(batch)
            train_cost = sess.run(
                [net['cost'], optimizer],
                feed_dict={net['X']: batch_xs})[0]
            print(batch_i, train_cost)
            if batch_i % save_step == 0:
                # Don't write the meta graph: the code can recreate it, and
                # doing so is slow because the unrolled graph is very deep.
                saver.save(
                    sess,
                    ckpt_name,
                    global_step=batch_i,
                    write_meta_graph=False)
    except tf.errors.OutOfRangeError:
        print('Done.')
    finally:
        # Tell all pipeline threads to shut down and wait for them.
        coord.request_stop()
        coord.join(threads)
        # Clean up the session.
        sess.close()
if __name__ == '__main__':
    # Executed only when this file is run as a script, not on import.
    train_tiny_imagenet()