Copyright 2017 Google Inc.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

dSprites - Disentanglement testing Sprites dataset

Description

Procedurally generated 2D shapes dataset. This dataset uses 6 latents, controlling the color, shape, scale, rotation, and X and Y position of a sprite. The color does not vary here; its value is fixed.

All possible combinations of the latents are present.

The ordering of images in the dataset (i.e. shape[0] in all ndarrays) is fixed and meaningful, see below.

We chose the smallest changes in latent values that generated different pixel outputs at our 64x64 resolution after rasterization.

No noise was added, and there is a single image for each latent setting.

Details about the ordering of the dataset

The dataset was generated procedurally, and its order is deterministic. For example, the image at index 0 corresponds to the latents (0, 0, 0, 0, 0, 0).

Then the image at index 1 increases the least significant "bit" of the latent: (0, 0, 0, 0, 0, 1)

This continues until index 32, where the last latent (posY, with 32 values) wraps around and the next one increments: (0, 0, 0, 0, 1, 0).

Hence the dataset is sequentially addressable using a variable base for every "bit" (a mixed-radix number). Using the latents_sizes entry of the dataset metadata makes this conversion trivial; see below.
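As a minimal sketch of this mixed-radix addressing, the snippet below hardcodes the latent sizes reported in the metadata (1, 3, 6, 40, 32, 32) instead of loading the dataset. The names `sizes`, `bases` and `to_index` are illustrative only and are not part of the dataset.

```python
import numpy as np

# Latent sizes as reported in the dataset metadata, hardcoded here so the
# sketch runs without the .npz file (assumption: they match your copy).
sizes = np.array([1, 3, 6, 40, 32, 32])

# The base of each "digit" is the product of the sizes of all digits to its right.
bases = np.append(np.cumprod(sizes[::-1])[::-1][1:], 1)

def to_index(latents):
    """Convert a tuple of latent class indices into a flat dataset index."""
    return int(np.dot(latents, bases))

print(bases)                              # per-latent strides
print(to_index((0, 0, 0, 0, 0, 1)))       # -> 1      (posY advances by one)
print(to_index((0, 0, 0, 0, 1, 0)))       # -> 32     (posX advances by one)
print(to_index((0, 2, 5, 39, 31, 31)))    # -> 737279 (last image)
```

The last index plus one, 737280, is the product of all latent sizes, which matches the claim that every combination of latents is present exactly once.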


In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns

# Change figure aesthetics
%matplotlib inline
sns.set_context('talk', font_scale=1.2, rc={'lines.linewidth': 1.5})

In [2]:
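# Load dataset. The metadata entry is a pickled Python 2 dict, so recent
# NumPy versions need allow_pickle=True, and Python 3 needs encoding='latin1'.
dataset_zip = np.load('dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
                      allow_pickle=True, encoding='latin1')

print('Keys in the dataset:', list(dataset_zip.keys()))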
imgs = dataset_zip['imgs']
latents_values = dataset_zip['latents_values']
latents_classes = dataset_zip['latents_classes']
metadata = dataset_zip['metadata'][()]

print('Metadata: \n', metadata)


Keys in the dataset: ['metadata', 'imgs', 'latents_classes', 'latents_values']
Metadata: 
 {'description': 'Disentanglement test Sprites dataset.Procedurally generated 2D shapes, from 6 disentangled latent factors.This dataset uses 6 latents, controlling the color, shape, scale, rotation and position of a sprite. All possible variations of the latents are present. Ordering along dimension 1 is fixed and can be mapped back to the exact latent values that generated that image.We made sure that the pixel outputs are different. No noise added.', 'latents_sizes': array([ 1,  3,  6, 40, 32, 32]), 'latents_names': ('color', 'shape', 'scale', 'orientation', 'posX', 'posY'), 'date': 'April 2017', 'version': 1, 'title': 'dSprites dataset', 'latents_possible_values': {'posX': array([ 0.        ,  0.03225806,  0.06451613,  0.09677419,  0.12903226,
        0.16129032,  0.19354839,  0.22580645,  0.25806452,  0.29032258,
        0.32258065,  0.35483871,  0.38709677,  0.41935484,  0.4516129 ,
        0.48387097,  0.51612903,  0.5483871 ,  0.58064516,  0.61290323,
        0.64516129,  0.67741935,  0.70967742,  0.74193548,  0.77419355,
        0.80645161,  0.83870968,  0.87096774,  0.90322581,  0.93548387,
        0.96774194,  1.        ]), 'posY': array([ 0.        ,  0.03225806,  0.06451613,  0.09677419,  0.12903226,
        0.16129032,  0.19354839,  0.22580645,  0.25806452,  0.29032258,
        0.32258065,  0.35483871,  0.38709677,  0.41935484,  0.4516129 ,
        0.48387097,  0.51612903,  0.5483871 ,  0.58064516,  0.61290323,
        0.64516129,  0.67741935,  0.70967742,  0.74193548,  0.77419355,
        0.80645161,  0.83870968,  0.87096774,  0.90322581,  0.93548387,
        0.96774194,  1.        ]), 'scale': array([ 0.5,  0.6,  0.7,  0.8,  0.9,  1. ]), 'orientation': array([ 0.        ,  0.16110732,  0.32221463,  0.48332195,  0.64442926,
        0.80553658,  0.96664389,  1.12775121,  1.28885852,  1.44996584,
        1.61107316,  1.77218047,  1.93328779,  2.0943951 ,  2.25550242,
        2.41660973,  2.57771705,  2.73882436,  2.89993168,  3.061039  ,
        3.22214631,  3.38325363,  3.54436094,  3.70546826,  3.86657557,
        4.02768289,  4.1887902 ,  4.34989752,  4.51100484,  4.67211215,
        4.83321947,  4.99432678,  5.1554341 ,  5.31654141,  5.47764873,
        5.63875604,  5.79986336,  5.96097068,  6.12207799,  6.28318531]), 'shape': array([ 1.,  2.,  3.]), 'color': array([ 1.])}, 'author': 'lmatthey@google.com'}

In [0]:
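# Define the number of values per latent and functions to convert to indices
latents_sizes = metadata['latents_sizes']
# Base (stride) of each latent: the product of the sizes of all latents after it
latents_bases = np.concatenate((latents_sizes[::-1].cumprod()[::-1][1:],
                                np.array([1,])))

def latent_to_index(latents):
  # Dot product with the bases turns latent class indices into flat image indices
  return np.dot(latents, latents_bases).astype(int)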


def sample_latent(size=1):
  samples = np.zeros((size, latents_sizes.size))
  for lat_i, lat_size in enumerate(latents_sizes):
    samples[:, lat_i] = np.random.randint(lat_size, size=size)

  return samples
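
The inverse direction, from a flat image index back to latent class indices, is not defined in this notebook. A minimal sketch using NumPy's `np.unravel_index` follows. `index_to_latent` and `latent_to_index_check` are hypothetical helper names, and the sizes are hardcoded from the metadata printed earlier so that the snippet runs on its own.

```python
import numpy as np

# Sizes copied from metadata['latents_sizes'] (hardcoded so this runs standalone).
LATENTS_SIZES = (1, 3, 6, 40, 32, 32)

def index_to_latent(indices):
    """Map flat dataset indices to an (N, 6) array of latent class indices."""
    indices = np.atleast_1d(indices)
    # unravel_index returns one array per latent; stack them as columns.
    return np.stack(np.unravel_index(indices, LATENTS_SIZES), axis=-1)

def latent_to_index_check(latents):
    """Forward mapping via ravel_multi_index, used only to verify the round trip."""
    latents = np.asarray(latents, dtype=int)
    return np.ravel_multi_index(tuple(latents.T), LATENTS_SIZES)

idx = np.array([0, 1, 32, 737279])
lat = index_to_latent(idx)
print(lat)
# Converting back must reproduce the original indices.
assert np.array_equal(latent_to_index_check(lat), idx)
```

With the dataset loaded, the rows returned for a given index should agree with the corresponding rows of latents_classes, which gives a quick sanity check of the ordering.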

In [0]:
# Helper function to show images
def show_images_grid(imgs_, num_images=25):
  ncols = int(np.ceil(num_images**0.5))
  nrows = int(np.ceil(num_images / ncols))
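  # plt.subplots takes (nrows, ncols); figsize is (width, height).
  # squeeze=False keeps axes a 2D array even when num_images is 1.
  _, axes = plt.subplots(nrows, ncols, figsize=(ncols * 3, nrows * 3),
                         squeeze=False)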
  axes = axes.flatten()

  for ax_i, ax in enumerate(axes):
    if ax_i < num_images:
      ax.imshow(imgs_[ax_i], cmap='Greys_r', interpolation='nearest')
      ax.set_xticks([])
      ax.set_yticks([])
    else:
      ax.axis('off')

def show_density(imgs):
  _, ax = plt.subplots()
  ax.imshow(imgs.mean(axis=0), interpolation='nearest', cmap='Greys_r')
  ax.grid(False)
  ax.set_xticks([])
  ax.set_yticks([])

Randomly sampling from the dataset


In [24]:
# Sample latents randomly
latents_sampled = sample_latent(size=5000)

# Select images
indices_sampled = latent_to_index(latents_sampled)
imgs_sampled = imgs[indices_sampled]

# Show images
show_images_grid(imgs_sampled)



In [25]:
# Compute the density of the data to show that no pixel ever goes out of
# the boundary. It also means that the main support of the pixels is in the
# center half.
# Locations cover a square, which makes the aligned X-Y latents more likely
# for models to discover.

show_density(imgs_sampled)


Conditional sampling of the dataset


In [27]:
## Fix the posX latent to its leftmost value (class index 0)
latents_sampled = sample_latent(size=5000)
latents_sampled[:, -2] = 0
indices_sampled = latent_to_index(latents_sampled)
imgs_sampled = imgs[indices_sampled]

# Samples
show_images_grid(imgs_sampled, 9)

# Show the density too to check
show_density(imgs_sampled)



In [29]:
## Fix orientation to class index 5 (about 0.8 rad)
latents_sampled = sample_latent(size=5000)
latents_sampled[:, 3] = 5
indices_sampled = latent_to_index(latents_sampled)
imgs_sampled = imgs[indices_sampled]

# Samples
show_images_grid(imgs_sampled, 9)

# Density should not differ from the density over all orientations
show_density(imgs_sampled)
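
Both conditional examples pin a latent by its column position. As a sketch of how this could be generalized, the hypothetical helper below pins latents by name instead. The names and sizes are copied from the metadata printed earlier and hardcoded so the snippet runs without the dataset file; `sample_latent_fixed` is an illustrative name, not part of the original notebook.

```python
import numpy as np

# Names and sizes copied from the metadata printed earlier (hardcoded so the
# sketch runs without the dataset file).
LATENTS_NAMES = ('color', 'shape', 'scale', 'orientation', 'posX', 'posY')
LATENTS_SIZES = np.array([1, 3, 6, 40, 32, 32])

def sample_latent_fixed(size=1, fixed=None, rng=None):
    """Sample latent class indices uniformly, pinning the latents in `fixed`.

    `fixed` maps a latent name to the class index it should take,
    e.g. {'posX': 0} or {'orientation': 5}.
    """
    rng = np.random.default_rng() if rng is None else rng
    # integers() broadcasts the per-latent upper bounds across the rows.
    samples = rng.integers(0, LATENTS_SIZES, size=(size, LATENTS_SIZES.size))
    for name, value in (fixed or {}).items():
        col = LATENTS_NAMES.index(name)
        if not 0 <= value < LATENTS_SIZES[col]:
            raise ValueError('%s index %d is out of range' % (name, value))
        samples[:, col] = value
    return samples

latents = sample_latent_fixed(size=5, fixed={'orientation': 5, 'posX': 0})
print(latents)
```

The returned array holds integer class indices, so it can be passed to latent_to_index in the same way as the output of sample_latent.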


