JOSH GORDON: Hey, everyone. Welcome back. In this episode, we'll write
a decision tree classifier from scratch in pure Python. Here's an outline
of what we'll cover. I'll start by introducing
the data set we'll work with. Next, we'll preview
the completed tree. And then, we'll build it. On the way, we'll cover concepts
like decision tree learning, Gini impurity, and
information gain. And you can find the
code for this episode in the description. And it's available
in two formats, both as a Jupyter notebook
and as a regular Python file. OK, let's get started. For this episode, I've
written a toy data set that includes both numeric
and categorical attributes. And here, our goal will be
to predict the type of fruit, like an apple or a
grape, based on features like color and size. At the end of the
episode, I encourage you to swap out this data
set for one of your own and build a tree for a
problem you care about. Let's look at the format. I've re-drawn it
here for clarity. Each row is an example. And the first two columns
provide features or attributes that describe the data. The last column gives
the label, or the class, we want to predict. And if you like, you
can modify this data set by adding additional
features or more examples, and our program will work
in exactly the same way. Now, this data set is
pretty straightforward, except for one thing. I've written it so it's
not perfectly separable. And by that I mean there's
no way to tell apart the second and fifth examples. They have the same features,
but different labels. And this is so we can see how
our tree handles this case. Towards the end of
the notebook, you'll find testing data
in the same format. Now I've written a few utility
functions that make it easier to work with this data. And below each function,
I've written a small demo to show how it works. And I've repeated this pattern
for every block of code in the notebook. Now to build the tree, we use
the decision tree learning algorithm called CART. And as it happens, there's
a whole family of algorithms used to build trees from data. At their core, they
give you a procedure to decide which questions
to ask and when. CART stands for Classification
and Regression Trees. And here's a preview
of how it works. To begin, we'll add a
root node for the tree. And all nodes receive a
list of rows as input. And the root will receive
the entire training set. Now each node will ask
a true/false question about one of the features. And in response
to this question, we split, or partition,
the data into two subsets. These subsets then become
the input to two child nodes we add to the tree. And the goal of the question
is to unmix the labels as we proceed down. Or in other words, to
produce the purest possible distribution of the
labels at each node. For example, the
input to this node contains only a
single type of label, so we'd say it's
perfectly unmixed. There's no uncertainty
about the type of label. On the other hand, the labels
in this node are still mixed up, so we'd ask another question
to further narrow it down. And the trick to building
an effective tree is to understand which
questions to ask and when. And to do that, we need to
quantify how much a question helps to unmix the labels. And we can quantify the
amount of uncertainty at a single node using a
metric called Gini impurity. And we can quantify
how much a question reduces that uncertainty
using a concept called information gain. We'll use these
to select the best question to ask at each point. And given that question, we'll
recursively build the tree on each of the new nodes. We'll continue dividing
the data until there are no further questions
to ask, at which point we'll add a leaf. To implement this, first
we need to understand what type of questions
we can ask about the data. And second, we
need to understand how to decide which
question to ask when. Now each node takes a
list of rows as input. And to generate a
list of questions we'll iterate over every
value for every feature that appears in those rows. Each of these
becomes a candidate for a threshold we can
use to partition the data. And there will often
be many possibilities. In code we represent
a question by storing a column number
and a column value, or the threshold we'll
use to partition the data. For example, here's how
we'd write a question to test if the color is green. And here's an example
for a numeric attribute to test if the diameter is
greater than or equal to 3. In response to a question, we
divide, or partition, the data into two subsets. The first contains all the rows
for which the question is true. And the second contains
everything else. In code, our partition
function takes a question and a list of rows as input. For example, here's how we
partition the rows based on whether the color is red. Here, true rows contains
all the red examples. And false rows contains
everything else. The best question is the one
that reduces our uncertainty the most. And Gini impurity lets us
quantify how much uncertainty there is at a node. Information gain will
let us quantify how much a question reduces that. Let's work on impurity first. Now this is a metric that
ranges between 0 and 1 where lower values indicate
less uncertainty, or mixing, at a node. It quantifies our chance of
being incorrect if we randomly assign a label from a set
to an example in that set. Here's an example
to make that clear. Imagine we have two bowls
and one contains the examples and the other contains labels. First, we'll randomly draw an
example from the first bowl. Then we'll randomly draw
a label from the second. And now, we'll classify the
example as having that label. And Gini impurity gives us
our chance of being incorrect. In this example, we have
only apples in each bowl. There's no way to
make a mistake. So we say the impurity is zero. On the other hand, given a
bowl with five different types of fruit in equal
proportion, we'd say it has an impurity of 0.8. That's because we have a one out
of five chance of being right if we randomly assign
a label to an example. In code, this method calculates
the impurity of a data set. And I've written
a couple examples below that demonstrate
how it works. You can see the impurity
for the first set is zero because there's no mixing. And here, you can see
the impurity is 0.8. Now information gain will let us
find the question that reduces our uncertainty the most. And it's just a
number that describes how much a question helps to
unmix the labels at a node. Here's the idea. We begin by calculating
the uncertainty of our starting set. Then, for each
question we can ask, we'll try partitioning
the data and calculating the uncertainty of the
child nodes that result. We'll take a weighted
average of their uncertainty because we care more about a
large set with low uncertainty than a small set with high. Then, we'll subtract this
from our starting uncertainty. And that's our information gain. As we go, we'll keep
track of the question that produces the most gain. And that will be the best
one to ask at this node. Let's see how this
looks in code. Here, we'll iterate over
every value for the features. We'll generate a question
for that feature, then partition the data on it. Notice we discard any questions
that fail to produce a split. Then, we'll calculate
our information gain. And inside this
function, you can see we take a weighted average
of the impurity of each set. We see how much this
reduces the uncertainty from our starting set. And we keep track
of the best value. I've written a couple
of demos below as well. OK, with these concepts in hand,
we're ready to build the tree. And to put this all together I
think the most useful thing I can do is walk you
through the algorithm as it builds a tree
for our training data. This uses recursion, so seeing
it in action can be helpful. You can find the code for this
inside the build tree function. When we call build tree
for the first time, it receives the entire
training set as input. And as output it will
return a reference to the root node of our tree. I'll draw a placeholder
for the root here in gray. And here are the rows we're
considering at this node. And to start, that's
the entire training set. Now we find the best
question to ask at this node. And we do that by iterating
over each of these values. We'll split the data and
calculate the information gain for each one. And as we go, we'll keep
track of the question that produces the most gain. Now in this case, there's
a useful question to ask, so the gain will be
greater than zero. And we'll split the data
using that question. And now, we'll use recursion
by calling build tree again to add a node for
the true branch. The rows we're considering
now are the first half of the split. And again, we'll find the best
question to ask for this data. Once more we split and call
the build tree function to add the child node. Now for this data there are
no further questions to ask. So the information
gain will be zero. And this node becomes a leaf. It will predict that
an example is either an apple or a lemon
with 50% confidence because that's the ratio
of the labels in the data. Now we'll continue by
building the false branch. And here, this will
also become a leaf. We'll predict apple
with 100% confidence. Now the previous call
returns, and this node becomes a decision node. In code, that just means
it holds a reference to the question we asked and
the two child nodes that result. And we're nearly done. Now we return to the root node
and build the false branch. There are no further questions
to ask, so this becomes a leaf. And that predicts grape
with 100% confidence. And finally, the root node
also becomes a decision node. And our call to build tree
returns a reference to it. If you scroll down
in the code, you'll see that I've added functions
to classify data and print the tree. And these start with a
reference to the root node, so you can see how it works. OK, hope that was helpful. And you can check out the
code for more details. There's a lot more I have
to say about decision trees, but there's only so much we
can fit into a short time. Here are a couple of topics
that are good to be aware of. And you can check out the
books in the description to learn more. As a next step, I'd
recommend modifying the tree to work with your own data set. And this can be a
fun way to build a simple and interpretable
classifier for use in your projects. Thanks for watching, everyone. And I'll see you next time.
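
As a companion to the impurity discussion in the episode, here is a minimal sketch of Gini impurity in pure Python. The function name `gini` and the fruit labels are illustrative assumptions and may differ from what the notebook uses. It computes one minus the sum of squared label proportions, which equals the chance of being wrong in the two-bowl thought experiment.

```python
from collections import Counter


def gini(labels):
    """Return 1 - sum(p_i ** 2) over the label proportions p_i."""
    counts = Counter(labels)
    total = len(labels)
    return 1.0 - sum((n / total) ** 2 for n in counts.values())


# Illustrative label sets (assumed, chosen to mirror the two cases in the episode).
no_mixing = ["Apple", "Apple", "Apple"]
five_way = ["Apple", "Orange", "Grape", "Grapefruit", "Blueberry"]

print(gini(no_mixing))           # 0.0
print(round(gini(five_way), 2))  # 0.8
```

With five equally likely labels, each proportion is 1/5, so the impurity is 1 - 5 * (1/5)^2 = 0.8, matching the figure quoted in the episode.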
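
The question, partition, and information gain steps described in the episode can be sketched roughly as follows. The toy rows are an assumption modeled on the description (color, diameter, label, with the second and fifth rows sharing features but not labels). The names `Question`, `partition`, `gini`, and `info_gain`, and their argument order, are illustrative choices and may not match the notebook.

```python
from collections import Counter

# Assumed toy rows: [color, diameter, label].
rows = [
    ["Green", 3, "Apple"],
    ["Yellow", 3, "Apple"],
    ["Red", 1, "Grape"],
    ["Red", 1, "Grape"],
    ["Yellow", 3, "Lemon"],
]


class Question:
    """Stores a column number and a value used as the threshold."""

    def __init__(self, column, value):
        self.column = column
        self.value = value

    def match(self, row):
        cell = row[self.column]
        if isinstance(cell, (int, float)):
            return cell >= self.value  # numeric: "is it >= value?"
        return cell == self.value      # categorical: "is it == value?"


def partition(rows, question):
    """Split rows into those that match the question and everything else."""
    true_rows = [r for r in rows if question.match(r)]
    false_rows = [r for r in rows if not question.match(r)]
    return true_rows, false_rows


def gini(rows):
    """Gini impurity of the labels, taken from the last column of each row."""
    counts = Counter(row[-1] for row in rows)
    return 1.0 - sum((n / len(rows)) ** 2 for n in counts.values())


def info_gain(true_rows, false_rows, current_uncertainty):
    """Starting uncertainty minus the weighted impurity of the two children."""
    p = len(true_rows) / (len(true_rows) + len(false_rows))
    return current_uncertainty - p * gini(true_rows) - (1 - p) * gini(false_rows)


true_rows, false_rows = partition(rows, Question(0, "Red"))
gain = info_gain(true_rows, false_rows, gini(rows))
print(true_rows)       # the two red rows
print(round(gain, 3))  # 0.373
```

On these rows the starting impurity is 0.64. Asking "is the color red?" isolates the two grapes and leaves a three-row set with impurity 4/9, so the gain is 0.64 - 0.6 * 4/9, or about 0.373.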
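
For readers who want to see the recursion from the walkthrough in one place, here is a compact sketch under stated assumptions. It uses the same assumed toy rows, represents nodes as plain dictionaries rather than classes, and breaks ties between equally good questions by taking the first one found. All names here are hypothetical, and the notebook's own implementation may be organized differently.

```python
from collections import Counter

# Assumed toy rows: [color, diameter, label].
rows = [
    ["Green", 3, "Apple"],
    ["Yellow", 3, "Apple"],
    ["Red", 1, "Grape"],
    ["Red", 1, "Grape"],
    ["Yellow", 3, "Lemon"],
]


def matches(row, column, value):
    """Numeric cells are tested with >=, categorical cells with ==."""
    cell = row[column]
    if isinstance(cell, (int, float)):
        return cell >= value
    return cell == value


def split(rows, column, value):
    true_rows = [r for r in rows if matches(r, column, value)]
    false_rows = [r for r in rows if not matches(r, column, value)]
    return true_rows, false_rows


def gini(rows):
    counts = Counter(row[-1] for row in rows)
    return 1.0 - sum((n / len(rows)) ** 2 for n in counts.values())


def find_best_split(rows):
    """Try every value of every feature and keep the question with the most gain."""
    best_gain, best_question = 0.0, None
    current = gini(rows)
    for column in range(len(rows[0]) - 1):  # last column is the label
        for value in sorted({row[column] for row in rows}):
            true_rows, false_rows = split(rows, column, value)
            if not true_rows or not false_rows:
                continue  # discard questions that fail to produce a split
            p = len(true_rows) / len(rows)
            gain = current - p * gini(true_rows) - (1 - p) * gini(false_rows)
            if gain > best_gain:
                best_gain, best_question = gain, (column, value)
    return best_gain, best_question


def build_tree(rows):
    """Return a leaf when no question helps, otherwise a decision node."""
    gain, question = find_best_split(rows)
    if question is None:
        return {"leaf": Counter(row[-1] for row in rows)}
    true_rows, false_rows = split(rows, *question)
    return {
        "question": question,
        "true": build_tree(true_rows),
        "false": build_tree(false_rows),
    }


def classify(row, node):
    """Walk from the given node to a leaf and return label proportions."""
    if "leaf" in node:
        total = sum(node["leaf"].values())
        return {label: n / total for label, n in node["leaf"].items()}
    column, value = node["question"]
    return classify(row, node["true" if matches(row, column, value) else "false"])


tree = build_tree(rows)
print(classify(["Yellow", 3], tree))  # {'Apple': 0.5, 'Lemon': 0.5}
print(classify(["Red", 1], tree))     # {'Grape': 1.0}
```

Because the two yellow rows have identical features, no question can separate them, so that branch ends in a leaf that reports both labels at 50% each, as in the walkthrough.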