These tutorials have focused on the mechanics of TensorFlow, but its main use case is machine learning. TensorFlow has a number of methods for building machine learning models, many of which can be found on the official API page. These functions allow you to build your models from the ground up, including customising aspects such as how layers in a neural network are built.
In this tutorial, we are going to look at TensorFlow Learn, which is the new name for a package called skflow. TensorFlow Learn (hereafter: Learn) is a machine learning wrapper, based on the scikit-learn API, that allows you to perform data mining with ease. What does all that mean? Let's go through it step by step:
Machine learning is the idea that you build algorithms that learn from data, in order to perform actions on new data. Here, that means we have some input training data and the expected outcomes, known as the training targets. We are going to look at the famous digits dataset, which is a collection of images of hand-drawn numbers. Our input training data is a set of these images (just under 1,800 of them), and our training targets are the numbers each image shows.
The task is to learn a model that can answer “what number is this?” for input such as this:
This is a classification task, one of the most common applications of data mining. Other common tasks include regression and clustering (among many others), but we won't touch on them in this lesson.
If you want to read more about data mining, check out my book Learning Data Mining with Python.
Scikit-learn is a Python package for data mining and analysis, and it is incredibly popular. This is due to its wide support for different algorithms, its amazing documentation, and its large and active community. Another factor is its consistent interface (its API): any model that follows it can be trained and evaluated with scikit-learn's helper functions, which makes it very easy to try out different models.
Let’s have a look at scikit-learn’s API in practice, but first we need some data.
The following code loads a bunch of digit images:
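A minimal sketch of this step, assuming the images come from scikit-learn's bundled copy of the digits dataset (sklearn.datasets.load_digits). It provides each image both as an 8x8 array (digits.images) and flattened into 64 values (digits.data), along with the labels (digits.target):

```python
from sklearn import datasets

# load_digits returns a Bunch holding the images, their flattened form and labels
digits = datasets.load_digits()

print(digits.images.shape)  # (1797, 8, 8): one 8x8 array per image
print(digits.data.shape)    # (1797, 64): the same images, each flattened to 64 values
print(digits.target[:10])   # the digit (0-9) each of the first ten images represents
```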
We can show one of these images using matplotlib:
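A sketch of this step with matplotlib's pyplot.imshow, displaying the first image in the dataset. The figure size and the reversed grey colour map (gray_r, which draws a dark digit on a light background) are assumptions rather than the author's exact settings:

```python
import matplotlib.pyplot as plt
from sklearn import datasets

digits = datasets.load_digits()

# A fairly large figure makes the 8x8 pixel grid easy to see
plt.figure(figsize=(4, 4))
# interpolation='none' draws each pixel exactly as stored
plt.imshow(digits.images[0], cmap=plt.cm.gray_r, interpolation='none')
plt.title(f"Label: {digits.target[0]}")
plt.show()
```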
Here I set interpolation='none' to see the data exactly as it is, but if you remove this argument, the digit becomes a little easier to recognise (also try reducing the figure size).
In scikit-learn, we can build a simple classifier, train it, and then use it to predict the number of an image, using just four lines of code:
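Reconstructed from the line-by-line description that follows, a sketch of those four lines, assuming the digits dataset from sklearn.datasets.load_digits (it is loaded again at the top only so the block runs on its own):

```python
from sklearn import datasets
digits = datasets.load_digits()  # the dataset loaded earlier

from sklearn import svm                      # line 1: import the Support Vector Machine module
classifier = svm.SVC(gamma=0.001)            # line 2: a "blank" classifier with gamma set to 0.001
classifier.fit(digits.data, digits.target)   # line 3: adjust the model to the training data
predicted = classifier.predict(digits.data)  # line 4: predict the digit shown in each image
```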
The first line simply imports the Support Vector Machine model, which is a popular machine learning method.
The second line builds a “blank” classifier, with the gamma value set to 0.001.
The third line uses the data to train the model. In this line (which is the bulk of the “work” for this code), the internal state of the SVM model is adjusted to best suit the training data.
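We pass digits.data because it is a flat array (one row of 64 values per image), which is the input this algorithm accepts. The digits.images array used above holds the same pixels arranged as two-dimensional 8x8 arrays, which is the shape we need for plotting.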
Finally, the last line uses this trained classifier to predict the class of some data, in this case the original dataset again.
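A quick check of how the two arrays relate, assuming the dataset is loaded with load_digits: flattening each 8x8 image into one row of 64 values reproduces digits.data exactly, so the classifier sees the same pixels we plotted.

```python
import numpy as np
from sklearn import datasets

digits = datasets.load_digits()

# Flatten each 8x8 image into a single row of 64 pixel values
flattened = digits.images.reshape(len(digits.images), -1)

# This matches digits.data, the array passed to fit() and predict()
print(np.array_equal(flattened, digits.data))  # True
```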
To see how accurate this is, we can compute the accuracy using NumPy:
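A minimal sketch, repeating the training steps so it runs on its own. Comparing the predictions with the true targets gives an array of True/False values, and its mean is the fraction of images predicted correctly:

```python
import numpy as np
from sklearn import datasets, svm

digits = datasets.load_digits()
classifier = svm.SVC(gamma=0.001)
classifier.fit(digits.data, digits.target)
predicted = classifier.predict(digits.data)

# True counts as 1 and False as 0, so the mean is the accuracy
accuracy = np.mean(predicted == digits.target)
print(f"Accuracy: {accuracy:.3f}")
```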
The results are pretty impressive (nearly perfect), but they are a little misleading. In data mining, you should never evaluate your model on the same data you used to train it. The problem this can hide is called "overfitting": the model learns exactly what it needs to for the training data, but is unable to predict well on new, unseen data. To address this, we need to split our data into training and testing sets:
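A sketch of that split using train_test_split, which lives in sklearn.model_selection in current scikit-learn releases. The 25% test fraction and the random_state value are arbitrary choices for illustration, so the exact score may differ slightly from the figure mentioned in the text:

```python
import numpy as np
from sklearn import datasets, svm
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()

# Hold back 25% of the images; the model never sees them during training
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.25, random_state=14)

classifier = svm.SVC(gamma=0.001)
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)

# Accuracy on unseen data is a fairer estimate of real-world performance
print(f"Test accuracy: {np.mean(y_pred == y_test):.3f}")
```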
The result is still very good at around 98%, although this dataset is well known in data mining and its features are well documented. Regardless, now that we know what we are trying to do, let's do it in TensorFlow Learn!
The TensorFlow Learn interface is only a small step away from scikit-learn’s interface:
The only real changes are the import statement and the model, which comes from a different list of available algorithms.
One difference is that the classifier needs to know how many classes it will be predicting over, which can be found using len(set(y_train)), or in other words "how many unique values are in the training targets".
Another difference is that the classifier needs to be told what types of features to expect.
For this example, we have real-valued continuous features, so we can simply state that as the feature_columns value (it needs to be in a list, though).
If you are using categorical features, you'll need to state this separately.
For more information on this, check out the documentation on TensorFlow Learn's example.
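The sketch below only computes the two values the classifier has to be given, using a 75/25 train/test split of the digits data (the split and random_state are arbitrary). It does not call TensorFlow Learn itself, since the tf.contrib module it lived in is not part of TensorFlow 2. In TensorFlow Learn, these values correspond to the classifier's n_classes argument and to the dimension of the real-valued column passed in feature_columns; the variable name n_features is just illustrative:

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.25, random_state=14)

# Number of classes: how many unique values appear in the training targets
n_classes = len(set(y_train))

# Every feature is a real-valued pixel intensity; there are 64 per image
n_features = X_train.shape[1]

print(n_classes, n_features)  # 10 64
```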
The results can be evaluated as before by computing the accuracy, but scikit-learn also has classification_report, which offers a much more in-depth look:
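A sketch of calling sklearn.metrics.classification_report. To keep it runnable without TensorFlow Learn, it scores predictions from the scikit-learn SVM shown earlier on a held-out test set (the 75/25 split and random_state are arbitrary). The call is the same for any array of predicted labels, including those from a TensorFlow Learn classifier:

```python
from sklearn import datasets, svm
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.25, random_state=14)

classifier = svm.SVC(gamma=0.001)
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)

# One row per class (precision, recall, f1-score, support), then averages
report = classification_report(y_test, y_pred)
print(report)
```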
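The report shows the precision, recall and f-measure (f1-score) for each class, along with averages over all classes and the number of test samples in each class (support). These scores are more informative than accuracy alone, especially when some classes are much rarer than others; for more information, see the Wikipedia article on precision and recall.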
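To make the report's columns concrete, a small sketch that computes them by hand for one class, using made-up labels. Precision is TP / (TP + FP), recall is TP / (TP + FN), and the f-measure (F1) is their harmonic mean, 2PR / (P + R). The manual values are then compared with scikit-learn's per-class scores:

```python
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score

# Made-up labels, for illustration only
y_true = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 2, 0])

cls = 1  # compute the scores for class 1
tp = np.sum((y_pred == cls) & (y_true == cls))  # predicted 1, really 1
fp = np.sum((y_pred == cls) & (y_true != cls))  # predicted 1, really another class
fn = np.sum((y_pred != cls) & (y_true == cls))  # really 1, predicted another class

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")
# precision=0.667 recall=1.000 f1=0.800

# average=None returns one score per label, in sorted label order (0, 1, 2)
print(precision_score(y_true, y_pred, average=None)[cls],
      recall_score(y_true, y_pred, average=None)[cls],
      f1_score(y_true, y_pred, average=None)[cls])
```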
That’s the high level overview of TensorFlow Learn. You can define custom classifiers, which you’ll see in exercise 3, and combine classifiers into pipelines (the support for this is small, but improving). This package has the potential to be a widely used package for data mining in industry and academia.
1) Change the classifier to DNNClassifier and rerun (it also needs a hidden_units argument: a list giving the number of units in each hidden layer). Feel free to tell all your friends that you now perform data analysis with deep learning.
2) The default parameters of DNNClassifier are good, but not perfect. Try changing the parameters to get a higher score.
3) Review this example from TensorFlow Learn's documentation and download the CIFAR-10 dataset. Build a classifier that predicts what each image shows using a convolutional neural network. You can use this code to load the data:
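A minimal sketch of such a loader, assuming you downloaded and extracted the Python version of the dataset (cifar-10-python.tar.gz), whose batch files (data_batch_1 to data_batch_5, and test_batch) are pickled dictionaries with b"data" and b"labels" keys. The function name load_cifar_batch is hypothetical, and reordering to height x width x channel is one common choice rather than a requirement:

```python
import pickle

import numpy as np


def load_cifar_batch(path):
    """Load one CIFAR-10 batch file (hypothetical helper).

    Returns images as an (N, 32, 32, 3) uint8 array and labels as an (N,) array.
    """
    with open(path, "rb") as f:
        # The batches were pickled under Python 2, so the dictionary keys are bytes
        batch = pickle.load(f, encoding="bytes")
    # Each row holds 3072 values: 1024 red, then 1024 green, then 1024 blue,
    # each channel being a 32x32 image stored row by row
    data = np.asarray(batch[b"data"], dtype=np.uint8)
    images = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    labels = np.asarray(batch[b"labels"])
    return images, labels


# Example usage (assumed path to the extracted archive):
# X_train, y_train = load_cifar_batch("cifar-10-batches-py/data_batch_1")
```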