Introduction to Decision Tree Learning

From Kaggle to classrooms, one of the first lessons in machine learning involves decision trees. They get this attention because they are less math-heavy than many other ML approaches while still providing reasonable accuracy on classification problems.

If you're just getting started with machine learning, decision trees are easy to pick up. In this tutorial, you'll learn:

  • What is a decision tree?
  • How to construct a decision tree
  • How to construct a decision tree using Python

What is a decision tree?

Let’s skip the formal definition and think conceptually about decision trees. Imagine you’re sitting in your office and feeling hungry. You want to go out and eat, but lunch starts at 1 PM. What do you do? Of course, you look at the time and then decide if you can go out. You can think of your logic like this:
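That logic is small enough to write as a single if statement. Here's a minimal sketch, assuming the only question you ask is whether it's already 1 PM; the function name lunch_decision and the "wait" outcome are hypothetical, chosen for illustration:

```python
from datetime import time

# Assumed rule: lunch starts at 1 PM (13:00), so that's the only question asked.
LUNCH_START = time(13, 0)

def lunch_decision(now: time) -> str:
    """Follow the one-question tree from its root to a leaf."""
    if now >= LUNCH_START:      # root node: "Is it 1 PM yet?"
        return "go out to eat"  # leaf: yes
    return "wait"               # leaf: no

print(lunch_decision(time(12, 30)))  # wait
print(lunch_decision(time(13, 5)))   # go out to eat
```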

We just made a decision tree! This is a simple one, but we can build a more complicated one by including more factors like weather and cost. If you want to go to lunch with your friend, Jon Snow, at a place that serves Chinese food, the logic can be summarized in this tree:

This is also a decision tree. You start at the top, follow the paths that describe the current condition, and keep doing that until you reach a decision.

Some notation

Let's shift to the world of computers. Each box we just drew is called a node. The topmost node is called the root, and the nodes at the ends of the branches, which have no children, are called leaf nodes. Think of it as a real-world tree, but inverted.

Each internal (non-leaf) node tests some property (attribute) of our world (dataset), and each branch going out from the node corresponds to a value of that attribute. Leaf nodes hold the final decisions. Given a tree, the process of deciding is:

  1. Start at the root
  2. Observe the value of the attribute tested at the current node
  3. Follow the path that corresponds to the observed value
  4. Repeat until we reach a leaf node, which will give us our decision
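The four steps translate almost directly into a loop. This sketch assumes a tree stored as nested Python dicts (an illustrative representation, not a standard one), and it uses the chocolate tree that we build by hand later in this tutorial:

```python
# Internal nodes name the attribute they test and map each value to a child;
# leaves are plain strings holding the decision.
chocolate_tree = {
    "attribute": "Brand",
    "branches": {
        "Kit Kat": "don't eat",
        "Snickers": {
            "attribute": "Color",
            "branches": {"Red": "eat", "Blue": "don't eat"},
        },
    },
}

def decide(tree, example):
    node = tree                               # 1. start at the root
    while isinstance(node, dict):             # 4. repeat until we reach a leaf
        value = example[node["attribute"]]    # 2. observe the tested attribute's value
        node = node["branches"][value]        # 3. follow the matching path
    return node                               # the leaf holds the decision

print(decide(chocolate_tree, {"Brand": "Snickers", "Color": "Red"}))  # eat
print(decide(chocolate_tree, {"Brand": "Kit Kat", "Color": "Red"}))   # don't eat
```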

How to construct a decision tree?

In practice, you'll rarely need to construct a decision tree from scratch (unless you're a student like me), because libraries like Scikit-Learn do it for you. Nonetheless, it's a good learning experience and you'll learn some interesting concepts along the way.

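One of the best-known algorithms for constructing decision trees is ID3, and it's quite simple. Here's the algorithm in outline, starting with all our examples at the root:

  1. If all examples belong to the same class, return a leaf node labeled with that class
  2. If there are no attributes left to test, return a leaf node labeled with the most common class among the examples
  3. Otherwise, pick the attribute that best classifies the examples and make it the test at the current node
  4. For each value of that attribute, add a branch and build its subtree by running the algorithm again on just the examples with that value, leaving out the attribute we just used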

One detail you'll notice is that at every node it builds, the algorithm has to pick the attribute that best classifies the examples. How will it do that? To understand that, we'll have to dive into a little bit of math. Don't worry, it's not too hard, and if you get stuck, I can answer any questions in the comments.

Information Gain and Entropy

One common, beginner-friendly way to figure out the best attribute is information gain, which is calculated using another quantity called entropy.

Entropy is a concept used in physics and mathematics that refers to the randomness or impurity of a system. In decision tree learning, it measures the impurity of a group of examples: how mixed their classes are.

Let's see an example to make it clear. You have 2 bags full of chocolates. The chocolates can be either red or blue. You decide to measure the entropy of the bags by counting the chocolates of each color. So you sit down and start counting. After 2 minutes, you discover the first bag has 50 chocolates: 25 of them are red and 25 are blue. The second bag also has 50 chocolates, all of them blue.

In this case, the first bag has entropy 1 (the maximum for two classes) because its chocolates are split evenly between the two colors. The second bag has entropy 0 because there is no randomness: every chocolate is blue.

To calculate the entropy of a set of examples S, we use this formula:

$$\text{Entropy}(S) = -\sum_{i=1}^{c} p_i \log_2 p_i$$

Here, c is the total number of classes and p_i is the proportion of examples in S that belong to the ith class. By convention, 0 log2 0 is taken to be 0. Confused? Let's try an example to clarify.

Let's go back to our chocolate bags. We have two classes, red (R) and blue (B). The first bag has 25 red chocolates out of 50, so p_R = 25/50 = 0.5. The same goes for the blue class, so p_B = 0.5. Plugging those values into the entropy equation, we get:

$$\text{Entropy} = -\frac{25}{50}\log_2\frac{25}{50} - \frac{25}{50}\log_2\frac{25}{50}$$

Solving the equation gives:

$$\text{Entropy} = -0.5(-1) - 0.5(-1) = 1$$

If you’d like to verify the result or play with more examples, check Wolfram Alpha.

Go ahead and calculate the entropy for the second bag, which has 50 blue chocolates and 0 red ones. You should get an entropy of 0.
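If you'd rather check these numbers in code, here's a minimal Python sketch of the entropy formula. It takes per-class counts (an input format chosen for this example) and skips empty classes, following the convention that 0 log2 0 = 0:

```python
from math import log2

def entropy(counts):
    """Entropy of a group of examples, given how many fall into each class."""
    total = sum(counts)
    result = 0.0
    for count in counts:
        if count == 0:
            continue              # 0 * log2(0) is taken to be 0
        p = count / total         # p_i: proportion of examples in class i
        result -= p * log2(p)
    return result

print(entropy([25, 25]))  # first bag: 25 red, 25 blue -> 1.0
print(entropy([0, 50]))   # second bag: 0 red, 50 blue -> 0.0
```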

If you understand the concept, excellent! We’ll move to information gain now. If you have any doubts, just leave a comment, and I’ll be happy to answer any questions.

Information Gain

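Information gain is the expected reduction in entropy caused by partitioning all our examples according to a given attribute. Mathematically, it's defined as:

$$\text{Gain}(S, A) = \text{Entropy}(S) - \sum_{v \in \text{Values}(A)} \frac{|S_v|}{|S|}\,\text{Entropy}(S_v)$$

This may seem like a lot, so let's break it down. S refers to the entire set of examples that we have. A is the attribute we want to split on, and Values(A) is the set of values it can take. S_v is the subset of S for which A has the value v. |S| is the number of examples in S and |S_v| is the number of examples in S_v. In other words, we subtract the weighted average entropy of the subsets from the entropy of the whole set.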

Still very complicated, right? Let’s try the measure on an example and see how it works.

Building the Decision Tree

First, let's take our chocolate example and add a few extra details. We already know that the first bag has 25 red chocolates and 25 blue ones. Now, we will also consider the brand of the chocolates. Among the red ones, 15 are Snickers and 10 are Kit Kats. Among the blue ones, 20 are Kit Kats and 5 are Snickers. Let's assume we only want to eat red Snickers. So the 15 red Snickers become positive examples, and everything else (the 10 red Kit Kats, 20 blue Kit Kats, and 5 blue Snickers) are negative examples.

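Now, the entropy of the dataset with respect to our classes (eat/not eat) is:

$$\text{Entropy}(S) = -\frac{15}{50}\log_2\frac{15}{50} - \frac{35}{50}\log_2\frac{35}{50} \approx 0.8813$$

Let's take a look back now. We have 50 chocolates. If we look at the attribute color, we have 25 red and 25 blue ones. If we look at the attribute brand, we have 20 Snickers and 30 Kit Kats.

To build the tree, we need to pick one of these attributes for the root node, and we want the one with the highest information gain. Let's calculate the information gain for each attribute to see the algorithm in action.

Information gain with respect to color would be:

$$\text{Gain}(S, \text{Color}) = \text{Entropy}(S) - \frac{25}{50}\,\text{Entropy}(S_{\text{red}}) - \frac{25}{50}\,\text{Entropy}(S_{\text{blue}})$$

We just calculated the entropy of the chocolates with respect to class, which is 0.8813. Among the red chocolates, we want to eat the 15 Snickers but not the 10 Kit Kats. The entropy for red chocolates is:

$$\text{Entropy}(S_{\text{red}}) = -\frac{15}{25}\log_2\frac{15}{25} - \frac{10}{25}\log_2\frac{10}{25} \approx 0.9710$$

We don't want to eat any blue chocolates, so their entropy is 0.

Our information gain calculation now becomes:

$$\text{Gain}(S, \text{Color}) = 0.8813 - \frac{25}{50}(0.9710) - \frac{25}{50}(0) \approx 0.3958$$

If we split on color, information gain is 0.3958.

Let's look at the brand now. We want to eat 15 out of 20 Snickers (the red ones) but not the other 5. The entropy for Snickers is:

$$\text{Entropy}(S_{\text{Snickers}}) = -\frac{15}{20}\log_2\frac{15}{20} - \frac{5}{20}\log_2\frac{5}{20} \approx 0.8113$$

We don't want to eat Kit Kats at all, so their entropy is 0. Information gain:

$$\text{Gain}(S, \text{Brand}) = 0.8813 - \frac{20}{50}(0.8113) - \frac{30}{50}(0) \approx 0.5568$$

Information gain for the split on brand is 0.5568.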
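To double-check these numbers, here's a small Python sketch of the information gain formula. It works on class counts rather than raw examples, an assumption made to keep the example short; each count list is [eat, don't eat]. The printed values should match the hand calculation up to rounding in the last digit:

```python
from math import log2

def entropy(counts):
    total = sum(counts)
    # Empty classes are skipped, since 0 * log2(0) is taken to be 0.
    return sum(-(c / total) * log2(c / total) for c in counts if c > 0)

def information_gain(parent_counts, partitions):
    """Gain(S, A) = Entropy(S) - sum over v of |S_v|/|S| * Entropy(S_v).

    parent_counts: [eat, don't eat] counts for the whole set S
    partitions: one [eat, don't eat] count list per value v of attribute A
    """
    total = sum(parent_counts)
    weighted = sum(sum(p) / total * entropy(p) for p in partitions)
    return entropy(parent_counts) - weighted

all_chocolates = [15, 35]        # 15 to eat, 35 not to eat
by_color = [[15, 10], [0, 25]]   # red, blue
by_brand = [[15, 5], [0, 30]]    # Snickers, Kit Kat

print(f"Gain(Color) = {information_gain(all_chocolates, by_color):.4f}")
print(f"Gain(Brand) = {information_gain(all_chocolates, by_brand):.4f}")
```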

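Since information gain for brand is larger, we will split on brand first. For the next level, we only have color left. The Kit Kat branch already has entropy 0 (we never eat Kit Kats), so it becomes a leaf. The Snickers branch still mixes both classes, and splitting it on color separates them perfectly without any further calculation. Our decision tree looks like this: Brand is at the root. The Kit Kat branch leads to "don't eat", while the Snickers branch leads to a Color node, where red means "eat" and blue means "don't eat".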
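To see the whole procedure run end to end, here's a compact recursive ID3 sketch in plain Python. The dict-based tree format, the function names, and the way the 50 chocolates are rebuilt from the counts above are all choices made for this example. It also only creates branches for values that actually appear in the examples:

```python
from collections import Counter
from math import log2

def entropy(examples, target):
    counts = Counter(ex[target] for ex in examples)
    total = len(examples)
    return sum(-(n / total) * log2(n / total) for n in counts.values())

def information_gain(examples, attribute, target):
    total = len(examples)
    remainder = 0.0
    for value in {ex[attribute] for ex in examples}:
        subset = [ex for ex in examples if ex[attribute] == value]
        remainder += len(subset) / total * entropy(subset, target)
    return entropy(examples, target) - remainder

def id3(examples, attributes, target):
    labels = Counter(ex[target] for ex in examples)
    if len(labels) == 1:                    # all examples agree: make a leaf
        return next(iter(labels))
    if not attributes:                      # nothing left to test: majority leaf
        return labels.most_common(1)[0][0]
    # Pick the attribute with the highest information gain.
    best = max(attributes, key=lambda a: information_gain(examples, a, target))
    remaining = [a for a in attributes if a != best]
    branches = {}
    for value in sorted({ex[best] for ex in examples}):
        subset = [ex for ex in examples if ex[best] == value]
        branches[value] = id3(subset, remaining, target)
    return {best: branches}

# Rebuild the 50 chocolates from the counts in the text.
groups = [("Red", "Snickers", "eat", 15), ("Red", "Kit Kat", "don't eat", 10),
          ("Blue", "Kit Kat", "don't eat", 20), ("Blue", "Snickers", "don't eat", 5)]
chocolates = [{"Color": c, "Brand": b, "Class": y}
              for c, b, y, n in groups for _ in range(n)]

learned = id3(chocolates, ["Color", "Brand"], "Class")
print(learned)
```

Running it should put Brand at the root, with the Kit Kat branch as a "don't eat" leaf and the Snickers branch split on Color, matching the tree we built by hand.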

Who thought eating chocolates would be this hard?

You should have a solid intuition about how decision trees work now. Again, if you find anything confusing or are feeling lost, feel free to ask any questions.

Implementing a Decision Tree with Python 3

Let’s go ahead and build a decision tree for our chocolates dataset. You can find the code and data on GitHub.

1. Create a new folder anywhere.

2. Download data.csv from GitHub.
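If you can't download data.csv, you can generate an equivalent file from the counts in the chocolate example. This sketch assumes the file has the columns Color, Brand, and Class (the names the later steps use), with text values for Color and Brand and 1/0 for Class; the original file's row order and any extra columns may differ:

```python
import csv

# (color, brand, class, how many): counts taken from the chocolate example.
# Class is 1 for "eat" (red Snickers) and 0 for everything else.
groups = [
    ("Red", "Snickers", 1, 15),
    ("Red", "Kit Kat", 0, 10),
    ("Blue", "Kit Kat", 0, 20),
    ("Blue", "Snickers", 0, 5),
]

with open("data.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["Color", "Brand", "Class"])
    for color, brand, label, count in groups:
        for _ in range(count):
            writer.writerow([color, brand, label])
```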

3. You may need to install SciPy, Scikit-Learn, and Pandas if you don't have them. I recommend using a virtual environment; you can learn more here. To install all three, run these commands from your terminal:
pip install scikit-learn
pip install scipy
pip install pandas

4. Once you have installed them, create a new file, and add these two lines to it:
from pandas import read_csv
from sklearn import tree

5. Load the data using Pandas:
data = read_csv("data.csv")

6. Pandas lets you work with large datasets and has lots of visualization features. It's widely used in Python data pipelines, so it's a good idea to get comfortable with it. You can take a quick look at the loaded data using the head() method in Pandas:
print(data.head())
This will display the first 5 rows of our data.

7. I'm using the Class column to decide whether we want to eat a chocolate or not: 1 means yes and 0 means no.

8. Next, we need to do some pre-processing on our data. Scikit-Learn's decision trees expect numeric input and don't accept text values like "Red" directly, so we will use Pandas to convert our text labels to numbers. Simply add the following two lines:
data['Color'] = data['Color'].map({'Red': 0, 'Blue': 1})
data['Brand'] = data['Brand'].map({'Snickers': 0, 'Kit Kat': 1})

9. We just changed the Color attribute to 0 for Red and 1 for Blue. Similarly, we substituted 0 for Snickers and 1 for Kit Kat in the Brand column.
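One thing to watch for: Series.map returns NaN for any value that isn't a key in the dictionary, so a typo or unexpected spelling in the CSV silently becomes a missing value instead of raising an error. A quick check after the mapping catches this. The small DataFrame here, including the "KitKat" spelling, is made up for illustration:

```python
import pandas as pd

# A tiny made-up frame; the last row uses an unexpected spelling.
data = pd.DataFrame({
    "Color": ["Red", "Blue", "Red"],
    "Brand": ["Snickers", "Kit Kat", "KitKat"],
})

data["Color"] = data["Color"].map({"Red": 0, "Blue": 1})
data["Brand"] = data["Brand"].map({"Snickers": 0, "Kit Kat": 1})

# Any label missing from the mapping dictionary is now NaN.
unmapped = data[data[["Color", "Brand"]].isna().any(axis=1)]
print(unmapped)  # shows row 2, whose Brand became NaN
```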

10. If you use head() to see the dataset, you’ll see that brand and color values have changed to integers:

11. One last thing: it’s a convention to denote our training attributes by X and output class by Y, so we will do that now:
predictors = ['Color', 'Brand']
X = data[predictors]
Y = data.Class

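12. Almost done. We're ready to train our decision tree now. Scikit-Learn's DecisionTreeClassifier uses an optimized version of the CART algorithm rather than ID3, but passing criterion="entropy" makes it choose splits by information gain, just as we did by hand. Add the following two lines to train the tree on our data: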
decisionTreeClassifier = tree.DecisionTreeClassifier(criterion="entropy")
dTree = decisionTreeClassifier.fit(X, Y)

13. Done? Let’s quickly visualize the tree. Add these lines and run the program:
dotData = tree.export_graphviz(dTree, out_file=None)
print(dotData)

14. You’ll see output like this:

15. Copy this and head over to WebGraphviz. Paste the output and click "Generate Graph". You'll see a decision tree similar to the one we made above:

16. This one is a bit harder to understand with all the extra information, but you can see that it split on column 1 (Brand) first and then on column 0 (Color).

Once the tree is trained, making predictions is simple. Let's see whether we want to eat a blue Kit Kat.

Add the following line to the end of the file:

print(dTree.predict([[1, 1]]))

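The output will be [0], which means the classification is "don't eat". If you try a red Snickers (print(dTree.predict([[0, 0]]))), the output will be [1]. Recent versions of Scikit-Learn may also print a warning that X does not have valid feature names, because the model was trained on a DataFrame and is now given a plain list; the predictions are unaffected.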
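For reference, here's the whole program in one place. To keep this sketch self-contained, it builds the same 50 chocolates in memory instead of reading data.csv (assuming the file has Color, Brand, and Class columns), and it prints the tree with tree.export_text, which gives a plain-text view labeled with column names and doesn't need Graphviz:

```python
import pandas as pd
from sklearn import tree

# In-memory stand-in for data.csv, rebuilt from the chocolate counts.
groups = [("Red", "Snickers", 1, 15), ("Red", "Kit Kat", 0, 10),
          ("Blue", "Kit Kat", 0, 20), ("Blue", "Snickers", 0, 5)]
data = pd.DataFrame(
    [(color, brand, label) for color, brand, label, n in groups for _ in range(n)],
    columns=["Color", "Brand", "Class"],
)

# Encode text labels as numbers, as in step 8.
data["Color"] = data["Color"].map({"Red": 0, "Blue": 1})
data["Brand"] = data["Brand"].map({"Snickers": 0, "Kit Kat": 1})

predictors = ["Color", "Brand"]
X = data[predictors]
Y = data.Class

dTree = tree.DecisionTreeClassifier(criterion="entropy").fit(X, Y)

# Plain-text view of the learned splits, labeled with column names.
print(tree.export_text(dTree, feature_names=predictors))

# DataFrames with the training column names avoid the feature-name warning.
blue_kit_kat = pd.DataFrame([[1, 1]], columns=predictors)
red_snickers = pd.DataFrame([[0, 0]], columns=predictors)
print(dTree.predict(blue_kit_kat))  # [0] -> don't eat
print(dTree.predict(red_snickers))  # [1] -> eat
```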

Moving Forward

You just learned how decision trees work and implemented a simple one. Here are further resources that you can use to continue learning:

  1. Decision Trees page on Scikit-Learn. Discusses a bigger dataset and alternative measures for splitting data.
  2. Machine Learning tutorial on Kaggle: A deep tutorial that will teach you how to participate on Kaggle and build a Decision Tree model on housing data.
  3. Saving your Scikit Models: In this tutorial, we retrained the model every time we ran the script. While this is fine for small datasets, it's a much better idea to train the model once and reuse it later. Follow this tutorial to learn how to save your model.
  4. Converting Trained Models to Core ML: If you train your decision tree on another dataset and want to run it on iOS devices, you need to convert the trained model to the Core ML format.


Our team has been at the forefront of Artificial Intelligence and Machine Learning research for more than 15 years and we're using our collective intelligence to help others learn, understand and grow using these new technologies in ethical and sustainable ways.
