Using TensorFlow's Dataset API

TensorFlow's Dataset API (tf.data, in core since version 1.4) makes creating input pipelines much easier. If your data is an iterable, in one of the common formats (files in a folder, CSV, numpy arrays) or in TFRecords, life is going to be much easier, and from_generator() is perhaps the easiest way to get an arbitrary dataset into TensorFlow.

Usage pattern (a minimal end-to-end sketch follows the list):

  1. Create the 'raw' dataset using one of:
    • tf.data.Dataset.from_generator() for some function with a yield
    • tf.data.TFRecordDataset() for reading from TFRecords
    • tf.data.Dataset.from_tensor_slices() for numpy arrays. (Sparse version available too)
    • tf.data.TextLineDataset() for line-based text files such as CSVs
  2. Apply transforms, if desired, with dataset.map(...)
  3. Randomize order with .shuffle(buffer_size=n)
  4. Set the number of epochs with .repeat(n). (Pass no argument to repeat forever.)
  5. Set batch size with .batch(n)
  6. Obtain an iterator with iterator = dataset.make_one_shot_iterator()
  7. Elements can now be obtained with: x, y = iterator.get_next()
  8. Build graph using x, y directly. No placeholders needed!
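
Here's a minimal sketch of the whole pattern, using from_tensor_slices() on a toy numpy array (the array shapes and contents below are made up purely for illustration):

import numpy as np
import tensorflow as tf

# toy data: 100 examples of 3 floats each, with integer class labels
xs = np.random.randn(100, 3).astype(np.float32)
ys = np.random.randint(0, 10, size=100).astype(np.int64)

dataset = tf.data.Dataset.from_tensor_slices((xs, ys))
dataset = dataset.shuffle(buffer_size=100).repeat().batch(32)

iterator = dataset.make_one_shot_iterator()
x, y = iterator.get_next()  # use x and y directly when building the graph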

Migrating from TFRecords and QueueRunners

If you have been using TFRecords and QueueRunners, switching over to the new Dataset API is painless.

Your original input pipeline probably contains something like this:

filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, example = reader.read(filename_queue)
fmt = ...
features = tf.parse_single_example(example, features=fmt)
x = features['data']
y = features['label']

In the new API, we put all the parsing and preprocessing needed for each example into a single function. This is then applied to the dataset using .map().

def parse_func(example):
    fmt = { <key1>: tf.FixedLenFeature(<shape>, <dtype>, <default_value (optional)>),
            <key2>: tf.VarLenFeature(<dtype>),
            ...
          }
    parsed = tf.parse_single_example(example, fmt)
    return parsed[<key1>], ...

Full basic example with TFRecord

data_raw = tf.data.TFRecordDataset(filename)  # or a list of filenames

def _parse_func(example):
    # example is a scalar string Tensor; parse it with parse_single_example
    example_fmt = {'x': tf.FixedLenFeature((), tf.string, ''),
                   'y': tf.FixedLenFeature((), tf.int64, 0),
                   }
    parsed = tf.parse_single_example(example, example_fmt)
    # parsed is a dictionary of tensors
    # Can do further processing of the tensors now, or simply return them.
    # Here we assume x was serialized as the raw bytes of a 784-long float32
    # vector; adjust the decoding to however you wrote the file.
    x = tf.reshape(tf.decode_raw(parsed['x'], tf.float32), [784])
    return x, parsed['y']

data = data_raw.map(_parse_func)

# Make it shuffled, repeated and batched
data = data.repeat().shuffle(buffer_size=BATCH_SIZE*10).batch(BATCH_SIZE)

iterator = data.make_one_shot_iterator()
x, y = iterator.get_next()

# Build the graph; note how x and y are used directly
net = tf.layers.dense(x, 512, activation=tf.nn.relu)
net = tf.layers.dense(net, 512, activation=tf.nn.relu)
pred = tf.layers.dense(net, 10)

# y holds integer class labels, so use the sparse version of the loss
loss = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=pred)

train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for ii in range(MAX_ITER):
        _, curr_loss = sess.run([train_op, loss])
        print('Iter: {}, Loss: {}'.format(ii, curr_loss))

On parsing (compressed) images

Images are parsed as FixedLenFeature, even if they are compressed and of different sizes (byte lengths). This is because FixedLen here refers to the length of the tensor, not the number of bytes in it. A variable-sized encoded image is still a single-element Tensor holding a BytesList.
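
For instance, here is a sketch of a parse function that decodes JPEG images (the keys 'image' and 'label' and the 224x224 target size are assumptions for illustration, not part of the example above):

def _parse_image_func(example):
    fmt = {'image': tf.FixedLenFeature((), tf.string, ''),
           'label': tf.FixedLenFeature((), tf.int64, 0)}
    parsed = tf.parse_single_example(example, fmt)
    # the scalar string Tensor holds the compressed bytes; decode them to pixels
    image = tf.image.decode_jpeg(parsed['image'], channels=3)
    # variable-sized images must be resized to a common shape before batching
    image = tf.image.resize_images(image, [224, 224])
    return image, parsed['label']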

From (almost) anything else with generators

I find that the best part of the Dataset API is from_generator(). As long as you know how to iterate through your examples, you should be able to wrap them in the Dataset API without much difficulty.

Example with HDF5 + generators

Here's a toy example of reading an HDF5 file with keys x and y. (Confusingly, h5py's documentation also calls these "datasets".)

import h5py
import tensorflow as tf
import numpy as np

in_file = h5py.File('data.h5', 'r')
x_in = in_file.get('x')
y_in = in_file.get('y')

def gen():
    for x, y in zip(x_in, y_in):
        yield x, y

# Let's assume that x is a 3-vector of floats and y is an int
d = tf.data.Dataset.from_generator(gen,
                                   output_types=(tf.float32, tf.int32),
                                   output_shapes=([3], [])
                                   )

That's it! The HDF5 file (or whichever esoteric reader you might have) is now wrapped in the Dataset API, with all the batching, pipelined reading and shuffling available to you!
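
As before, shuffling, batching and iteration can now be chained straight onto d (the buffer and batch sizes here are arbitrary):

d = d.shuffle(buffer_size=1000).repeat().batch(32)
iterator = d.make_one_shot_iterator()
x, y = iterator.get_next()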


TODO / planned updates:

  • [x] Notes on parse function for TFRecord
  • [ ] Fancy initializers
  • [ ] Using with graphs built with placeholders
  • [ ] Boilerplate file gist?