
TensorFlow Learn

Updated for new TensorFlow Learn API!

These tutorials have focused on the mechanics of TensorFlow, but its real use case is machine learning. TensorFlow has a number of methods for building machine learning models, many of which can be found on the official API page. These functions allow you to build your models from the ground up, including customising aspects such as how the layers of a neural network are built.

In this tutorial, we are going to look at TensorFlow Learn, which is the new name for a package called skflow. TensorFlow Learn (hereafter: Learn) is a machine learning wrapper, based on the scikit-learn API, that lets you perform data mining with ease. What does all that mean? Let’s go through it step by step:

Machine Learning

Machine learning is the idea that you build algorithms that learn from data, in order to perform actions on new data. In this context, it means that we have some input training data and the expected outcome for each input: the training targets. We are going to look at the famous digits dataset, which is a collection of images of hand-drawn numbers. Our input training data is just under 1,800 of these images, and our training targets are the numbers each image shows.

The task is to learn a model that can answer “what number is this?” for input such as this:

This is a classification task, one of the most common applications of data mining. Other types of task include regression and clustering (as well as many others), but we won’t touch on them in this lesson.

If you want to read more about data mining, check out my book Learning Data Mining with Python.

Scikit-Learn API

Scikit-learn is a Python package for data mining and analysis, and it is incredibly popular. This is due to its wide support for different algorithms, its amazing documentation, and its large and active community. Another factor is its consistent interface (its API), which lets people build models that can be trained with scikit-learn’s helper functions and test different models very easily.

Let’s have a look at scikit-learn’s API in practice, but first we need some data. The following code loads a bunch of digit images that can be shown with matplotlib.pyplot:

from sklearn.datasets import load_digits
from matplotlib import pyplot as plt

digits = load_digits()

We can show one of these images using pyplot.imshow. Here I set interpolation='none' to see the data exactly as it is, but if you remove this argument, the digit may be a little easier to make out (also try reducing the figure size).

fig = plt.figure(figsize=(3, 3))

plt.imshow(digits['images'][66], cmap="gray", interpolation='none')
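It can help to print the numbers that imshow is drawing. Each image is an 8x8 NumPy array of pixel intensities, and digits['target'] holds the matching label. This short sketch prints image 66, the same one plotted above, as integers:

```python
import numpy as np
from sklearn.datasets import load_digits

digits = load_digits()
image = digits['images'][66]

# Intensities are stored as floats but hold whole numbers from 0 to 16
print(image.astype(int))
print("label:", digits['target'][66])
```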

In scikit-learn, we can build a simple classifier, train it, and then use it to predict which digit an image shows, using just four lines of code:

from sklearn import svm

classifier = svm.SVC(gamma=0.001)
classifier.fit(digits.data, digits.target)
predicted = classifier.predict(digits.data)

The first line simply imports the Support Vector Machine model, which is a popular machine learning method.

The second line builds a “blank” classifier, with the gamma value set to 0.001.

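The third line uses the data to train the model. In this line (which is the bulk of the “work” for this code), the internal state of the SVM model is adjusted to best suit the training data. We pass digits.data, together with the training targets in digits.target, because digits.data stores each image as a flat array of 64 values, which is the input format this algorithm accepts. The digits.images array used above holds the same pixels, shaped as 8x8 images for display.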

Finally, the last line uses this trained classifier to predict the class of some data, in this case the original dataset again.
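The difference between digits.images and digits.data is easy to check. This sketch uses only load_digits and NumPy. It prints the array shapes and confirms that each row of digits.data is the matching 8x8 image flattened row by row:

```python
import numpy as np
from sklearn.datasets import load_digits

digits = load_digits()

print(digits.images.shape)  # (1797, 8, 8): one 8x8 grid per sample
print(digits.data.shape)    # (1797, 64): the same pixels, one flat row per sample
print(digits.target.shape)  # (1797,): the digit each image shows

# Flattening every image reproduces the rows of digits.data exactly
flattened = digits.images.reshape(len(digits.images), -1)
print(np.array_equal(flattened, digits.data))  # True
```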

To see how accurate this is, we can compute the accuracy using NumPy:

import numpy as np
print(np.mean(digits.target == predicted))

The results are pretty impressive (nearly perfect), but they are misleading. In data mining, you should never evaluate a model on the same data you used to train it. The potential problem is called “overfitting”: the model learns exactly what it needs to for the training data, but it cannot predict well on new, unseen data. To address this, we need to split our data into training and testing sets:

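from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target)
predicted = classifier.fit(X_train, y_train).predict(X_test)
print(np.mean(y_test == predicted))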

The result is still very good, at around 98%, although this dataset is well known in data mining and its features are well documented, so strong scores are expected. Now that we know what we are trying to do, let’s do it in TensorFlow Learn!
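Every scikit-learn classifier shares the same fit/predict interface, so it is easy to compare models and to see overfitting directly. This sketch passes two optional train_test_split arguments: random_state gives a reproducible shuffle (the seed 0 is arbitrary), and stratify keeps the class balance. It then reports training and test accuracy for the SVM above and for KNeighborsClassifier, a second model chosen arbitrarily for illustration. A training score far above the test score is the warning sign of overfitting.

```python
import numpy as np
from sklearn import svm
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target,
    random_state=0,          # arbitrary seed so the split is reproducible
    stratify=digits.target,  # keep each digit's share equal in both splits
)

# Any scikit-learn classifier can be dropped in here unchanged
models = {
    "SVC": svm.SVC(gamma=0.001),
    "1-nearest neighbour": KNeighborsClassifier(n_neighbors=1),
}

results = {}
for name, model in models.items():
    model.fit(X_train, y_train)
    train_acc = np.mean(model.predict(X_train) == y_train)
    test_acc = np.mean(model.predict(X_test) == y_test)
    results[name] = (train_acc, test_acc)
    print(f"{name}: train accuracy {train_acc:.3f}, test accuracy {test_acc:.3f}")
```

The 1-nearest-neighbour model classifies each training image by finding its closest stored example. For a training image, that example is the image itself, so the training accuracy is essentially perfect by construction. Only the test accuracy shows how the model handles new images.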

TensorFlow Learn

The TensorFlow Learn interface is only a small step away from scikit-learn’s interface:

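import tensorflow as tf  # requires TensorFlow 1.x; tf.contrib was removed in TensorFlow 2
from tensorflow.contrib import learn

n_classes = len(set(y_train))
# One real-valued column covering all 64 pixel features
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=X_train.shape[1])]
classifier = learn.LinearClassifier(feature_columns=feature_columns, n_classes=n_classes)
classifier.fit(X_train, y_train, steps=10)

# predict() may return a generator; list() collects the predicted classes
y_pred = list(classifier.predict(X_test))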

The only real changes are the import statement and the model, which comes from a different set of available algorithms. One difference is that the classifier needs to know how many classes it will be predicting over. You can find this with len(set(y_train)), in other words, “how many unique values are in the training targets?”
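The class count can be checked directly. np.unique gives the same count as len(set(...)) and also lists the labels. LinearClassifier is a linear model over the 64 pixel features. If TensorFlow 1.x is not available to you, scikit-learn’s LogisticRegression is a rough linear analogue. It is not the same implementation or optimiser, and it is swapped in here only to illustrate the idea:

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, random_state=0)

# Two equivalent ways to count the classes
n_classes = len(set(y_train))
labels = np.unique(y_train)
print(n_classes, labels)

# Stand-in linear classifier (not learn.LinearClassifier itself);
# max_iter is raised so the solver has room to converge on unscaled pixels
linear = LogisticRegression(max_iter=5000)
linear.fit(X_train, y_train)
accuracy = np.mean(linear.predict(X_test) == y_test)
print(accuracy)
```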

Another difference is that the classifier needs to be told what types of features to expect. For this example, we have real-valued, continuous features. We can simply declare one real-valued column whose dimension is the number of features and pass it as the feature_columns value (it needs to be in a list, though). If you are using categorical features, you’ll need to declare them separately. For more information on this, check out the documentation on TensorFlow Learn’s examples.
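Categorical features need separate treatment because a linear model multiplies each feature by a weight. A category stored as an integer code (say 0 = red, 1 = green, 2 = blue) would wrongly be treated as a quantity, with blue counting as “twice” green. One common fix is one-hot encoding. This plain NumPy sketch uses made-up colour codes to show the idea only. It is not TensorFlow Learn’s feature-column API:

```python
import numpy as np

# Hypothetical categorical feature: colour codes for five samples
colour_codes = np.array([0, 2, 1, 2, 0])  # 0 = red, 1 = green, 2 = blue
n_categories = 3

# Row i of the identity matrix is the one-hot vector for category i,
# so indexing with the codes gives one indicator row per sample
one_hot = np.eye(n_categories)[colour_codes]
print(one_hot)
```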

The results can be evaluated as before, to compute the accuracy, but scikit-learn has classification_report, which offers a much more in-depth look:

from sklearn import metrics
print(metrics.classification_report(y_true=y_test, y_pred=y_pred))

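The result shows the precision and recall for each class, along with the f1-score (f-measure) and the support (the number of test samples in each class), plus averages over all classes. Together, these give a more detailed picture than accuracy alone, especially when classes are imbalanced. For more information, see the Wikipedia page on precision and recall.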
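It is easier to read the report once you have computed these scores by hand for a single class. For class c, precision is the fraction of samples predicted as c that really are c. Recall is the fraction of true c samples that were predicted as c, and the f1-score is the harmonic mean of the two. This sketch uses small made-up label arrays and checks the hand-computed values against metrics.precision_recall_fscore_support:

```python
import numpy as np
from sklearn import metrics

# Made-up labels for illustration
y_true = np.array([0, 0, 1, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 2, 2, 0])

c = 1  # the class to inspect
true_positives = np.sum((y_pred == c) & (y_true == c))
precision = true_positives / np.sum(y_pred == c)  # 2 of 3 predicted 1s are right
recall = true_positives / np.sum(y_true == c)     # 2 of 3 true 1s were found
f1 = 2 * precision * recall / (precision + recall)
print(precision, recall, f1)

# scikit-learn's per-class values; labels are sorted, so index c is class c here
p, r, f, support = metrics.precision_recall_fscore_support(y_true, y_pred)
print(p[c], r[c], f[c], support[c])
```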

That’s the high-level overview of TensorFlow Learn. You can also define custom classifiers, which you’ll see in exercise 3, and combine classifiers into pipelines (support for this is limited, but improving). This package has the potential to become widely used for data mining in industry and academia.
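A pipeline chains preprocessing steps and a final model into a single object that still exposes fit and predict. For comparison, this is what a pipeline looks like on the scikit-learn side. The choice of StandardScaler followed by SVC is arbitrary and used only for illustration:

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, random_state=0)

# Scale each pixel feature, then classify; fit() runs both steps in order
pipeline = make_pipeline(StandardScaler(), SVC())
pipeline.fit(X_train, y_train)

# score() returns the mean accuracy on the given data
accuracy = pipeline.score(X_test, y_test)
print(accuracy)
```

Because the scaler is fitted inside the pipeline, its means and variances come from the training data only. This keeps information from the test set out of training.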


Exercises

1) Change the classifier to DNNClassifier and rerun. Feel free to tell all your friends that you now perform data analysis with deep learning.

2) The default parameters to DNNClassifier are good, but not perfect. Try changing the parameters to get a higher score.

3) Review this example from TensorFlow Learn’s documentation and download the CIFAR-10 dataset. Build a classifier that uses a Convolutional Neural Network to predict what each image shows. You can use this code to load the data:

def load_cifar(file):
    import pickle
    import numpy as np
    with open(file, 'rb') as inf:
        cifar = pickle.load(inf, encoding='latin1')
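    # A CIFAR-10 batch file holds 10000 images, each a flat row stored channels-first
    data = cifar['data'].reshape((10000, 3, 32, 32))
    # Two rollaxis calls move the channel axis last: (N, 32, 32, 3)
    data = np.rollaxis(data, 3, 1)
    data = np.rollaxis(data, 3, 1)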
    y = np.array(cifar['labels'])

    # Just get 2s versus 9s to start
    # Remove these lines when you want to build a big model
    mask = (y == 2) | (y == 9)
    data = data[mask]
    y = y[mask]

    return data, y
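The two np.rollaxis calls in load_cifar convert CIFAR’s channels-first layout (N, 3, 32, 32) into the channels-last layout (N, 32, 32, 3) that tools such as plt.imshow expect. A quick check on random data shows that the pair is equivalent to a single transpose. The random array stands in for the real CIFAR file, which this sketch does not load:

```python
import numpy as np

# Random stand-in for a CIFAR batch: 4 images, 3 channels, 32x32 pixels
rng = np.random.default_rng(0)
data = rng.integers(0, 256, size=(4, 3, 32, 32))

rolled = np.rollaxis(np.rollaxis(data, 3, 1), 3, 1)
transposed = data.transpose(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)

print(rolled.shape)
print(np.array_equal(rolled, transposed))
```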
