Source code for cadl.celeb_vaegan

"""Tools for downloading the celeb dataset and model, including preprocessing.
"""
"""
Copyright 2017 Parag K. Mital.  See also NOTICE.md.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
from cadl.utils import download
from skimage.transform import resize as imresize


def celeb_vaegan_download():
    """Fetch the pretrained celeb VAE/GAN model and the CelebA attribute list.

    Both files are retrieved with ``cadl.utils.download``, model first and
    attribute list second.

    Returns
    -------
    model : object
        Whatever ``download`` returns for the model URL. It is used as a
        file path to the serialized graph by ``get_celeb_vaegan_model``.
    labels : object
        Whatever ``download`` returns for the attribute-list URL. It is used
        as a file path to the CelebA attribute text file by
        ``get_celeb_vaegan_model``.
    """
    urls = (
        'https://s3.amazonaws.com/cadl/models/celeb.vaegan.tfmodel',
        'https://s3.amazonaws.com/cadl/celeb-align/list_attr_celeba.txt',
    )
    # Unpacking consumes the generator in order, so the model is
    # downloaded before the labels, as in the original sequence.
    model_path, labels_path = (download(url) for url in urls)
    return model_path, labels_path
def _parse_celeb_attributes(labels_path):
    """Parse the CelebA attribute list file into names and a boolean matrix.

    The layout read here is:

    - line 1 holds the number of entries;
    - line 2 holds the whitespace-separated attribute names;
    - each remaining non-blank line holds an image identifier followed by
      one integer per attribute.

    Parameters
    ----------
    labels_path : str
        Path to the attribute list file.

    Returns
    -------
    labels : list of str
        Attribute names, in file order.
    attributes : np.ndarray
        Boolean array of shape [n_els x n_labels]. An entry is True where
        the file's integer for that attribute is positive.
    """
    # Context manager so the file handle is always released.
    with open(labels_path) as f:
        txt = f.readlines()
    n_els = int(txt[0].strip())
    labels = txt[1].strip().split()
    attributes = np.zeros((n_els, len(labels)), dtype=bool)
    # Skip blank lines (e.g. a stray trailing empty line). Otherwise they
    # would yield an empty row that cannot be broadcast into `attributes`.
    rows = (line for line in txt[2:] if line.strip())
    for i, line in enumerate(rows):
        # First column is the image identifier; the rest are attribute values.
        attributes[i] = np.array(line.split()[1:]).astype(int) > 0
    return labels, attributes


def _load_graph_def(model_path):
    """Read a serialized TensorFlow GraphDef from `model_path`.

    If protobuf parsing fails, a hint about the pure-python protobuf
    implementation is printed and the (possibly empty or incomplete)
    GraphDef is still returned. This keeps the original best-effort
    behavior.
    """
    graph_def = tf.GraphDef()
    with gfile.GFile(model_path, 'rb') as f:
        serialized = f.read()
    try:
        graph_def.ParseFromString(serialized)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # propagate. The space after "=python" was missing, which printed
        # "...=pythonto environment".
        print('try adding PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python ' +
              'to environment. e.g.:\n' +
              'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python ipython\n' +
              'See here for info: ' +
              'https://github.com/tensorflow/tensorflow/issues/582')
    return graph_def


def get_celeb_vaegan_model():
    """Get the pretrained celeb VAE/GAN model and its attribute labels.

    Returns
    -------
    net : dict
        {
            'graph_def': tf.GraphDef
                The graph definition. It may be empty or incomplete if
                parsing failed; a hint is printed in that case.
            'labels': list
                List of attribute names from the CelebA attribute file.
            'attributes': np.ndarray
                Boolean encoding of the attributes per image
                [n_els x n_labels].
            'preprocess': function
                Preprocess function matching the model's training input.
        }
    """
    model_path, labels_path = celeb_vaegan_download()
    labels, attributes = _parse_celeb_attributes(labels_path)
    graph_def = _load_graph_def(model_path)
    return {
        'graph_def': graph_def,
        'labels': labels,
        'attributes': attributes,
        'preprocess': preprocess,
    }
def preprocess(img, crop_factor=0.8):
    """Replicate the preprocessing used when training the VAE/GAN.

    The image is center-cropped to a square (side = shorter of the first
    two axes) and then optionally trimmed further by `crop_factor`. The
    result is resized to 100 x 100 in its first two axes. The model used a
    crop_factor of 0.8 and a crop size of [100, 100, 3].

    Parameters
    ----------
    img : np.ndarray
        Image whose first two axes are rows and columns. Any trailing axes
        (e.g. channels) are carried through unchanged.
    crop_factor : float, optional
        Fraction of the square crop's side to keep, centered. Values of 1.0
        or more disable the extra trim.

    Returns
    -------
    np.ndarray
        The resized crop, 100 x 100 in its first two axes. Resizing uses
        ``preserve_range=False``, so skimage may rescale the value range
        (see the ``skimage.transform.resize`` docs for the input dtype).
    """
    # Center square crop.
    side = np.min(img.shape[:2])
    row0 = (img.shape[0] - side) // 2
    col0 = (img.shape[1] - side) // 2
    cropped = img[row0:row0 + side, col0:col0 + side]

    if crop_factor < 1.0:
        n_rows, n_cols = cropped.shape[:2]
        amt = (1 - crop_factor) / 2
        h, w = int(n_rows * amt), int(n_cols * amt)
        # Use explicit end indices: the previous `cropped[h:-h, w:-w]`
        # produced an empty array whenever h == 0, because -0 == 0.
        cropped = cropped[h:n_rows - h, w:n_cols - w]

    return imresize(cropped, (100, 100), preserve_range=False)