# ID3 (Iterative Dichotomiser 3) Algorithm - How does it work?

Prerequisites:

What is a Decision Tree?
A decision tree is made up of a series of nodes and branches. A node is where we split the data based on a variable, and a branch is one side of the split. We keep splitting the data based on variables. As we do this, the tree accumulates more levels. When we make a split, some rows will go to the right, and some will go to the left. As we build the tree deeper and deeper, each node will “receive” fewer and fewer rows. The nodes at the bottom of the tree, where we decide to stop splitting, are called terminal nodes, or leaves. When we do our splits, we aren’t doing them randomly; we have an objective. Our goal is to ensure that we can make a prediction on future data. In order to do this, all rows in each leaf must have only one value for our target column.

How do we split?
We’ll use entropy [precisely, Information Gain] to figure out which variables we should split nodes on. Post-split, we’ll have two data sets, each containing the rows from one branch of the split. Entropy refers to disorder. The more “mixed together” 1s and 0s are, the higher the entropy. A data set consisting entirely of 1s in the target column would have zero entropy. Entropy, which is not to be confused with entropy from physics, comes from information theory.

A key concept in information theory is the notion of a bit of information. One bit of information is one unit of information. We can represent a bit of information as a binary number because it has either the value 1 or 0. Suppose there’s an equal probability of tomorrow being sunny (1) or not sunny (0). If I tell you that it will be sunny, I’ve given you one bit of information. We can also think of entropy in terms of information. If we flip a coin where both sides are heads, we know upfront that the result will be heads. We gain no new information by flipping the coin, so entropy is 0. On the other hand, if the coin has a heads side and a tails side, there’s a 50% probability that it will land on either. Thus, flipping the coin gives us one bit of information – which side the coin landed on.

How do we calculate Entropy?

We iterate through each unique value `x_i` in a single column (in our case, `high_income`) and compute the probability of that value occurring in the data, `P(x_i)`. Entropy is then `-Σ P(x_i) * log_b(P(x_i))`, summed over all unique values, where `b` is the base of the logarithm. We commonly use the value 2 for `b`, but we can also set it to 10 or another value. We use 2 because it allows the result to be interpreted in bits of information.
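As a quick sketch of this formula (the helper name `entropy` and the coin examples below are illustrative, not part of the mission code):

```python
import math
from collections import Counter

def entropy(values, base=2):
    """Shannon entropy of a sequence of labels; base 2 gives bits."""
    counts = Counter(values)
    total = len(values)
    result = 0.0
    for count in counts.values():
        p = count / total               # P(x_i)
        result -= p * math.log(p, base)
    return result

# A fair coin carries one bit of information:
print(entropy([0, 1]))        # 1.0
# A target column that is all 1s has no disorder at all:
print(entropy([1, 1, 1, 1]))  # 0.0
```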

What is Information Gain?

We’ll need a way to go from computing entropy to figuring out which variable to split on. We can do this using information gain, which tells us which split will reduce entropy the most. We’re computing information gain (`IG`) for a given target variable (`T`), as well as a given variable we want to split on (`A`). To compute it, we first calculate the entropy for `T`. Then, for each unique value `v` in the variable `A`, we compute the number of rows in which `A` takes on the value `v`, and divide it by the total number of rows. Next, we multiply the results by the entropy of the rows where `A` is `v`. We add all of these subset entropies together, then subtract from the overall entropy to get information gain. If the result is positive, we’ve lowered entropy with our split. The higher the result is, the more we’ve lowered entropy.

To simplify the calculation of information gain and make splits simpler, we won’t do it for each unique value. We’ll find the median for the variable we’re splitting on instead. Any rows where the value of the variable is below the median will go to the left branch, and the rest of the rows will go to the right branch. To compute information gain, we’ll only have to compute entropies for two subsets.
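The median-split version of information gain can be sketched with plain Python lists (the `rows` dict representation here is illustrative; the actual mission code uses pandas):

```python
import math
from statistics import median

def entropy(values, base=2):
    """Shannon entropy of a sequence of labels."""
    values = list(values)
    probs = [values.count(v) / len(values) for v in set(values)]
    return -sum(p * math.log(p, base) for p in probs)

def information_gain(rows, split_col, target_col):
    """Information gain of a median split on split_col."""
    original = entropy(r[target_col] for r in rows)
    m = median(r[split_col] for r in rows)
    left = [r for r in rows if r[split_col] <= m]
    right = [r for r in rows if r[split_col] > m]
    # Weight each subset's entropy by its share of the rows, then subtract
    to_subtract = sum(
        len(subset) / len(rows) * entropy(r[target_col] for r in subset)
        for subset in (left, right)
    )
    return original - to_subtract

# The thread's toy data set:
rows = [
    {"high_income": 0, "age": 20, "marital_status": 0},
    {"high_income": 0, "age": 60, "marital_status": 2},
    {"high_income": 0, "age": 40, "marital_status": 1},
    {"high_income": 1, "age": 25, "marital_status": 1},
    {"high_income": 1, "age": 35, "marital_status": 2},
    {"high_income": 1, "age": 55, "marital_status": 1},
]
print(information_gain(rows, "age", "high_income"))  # positive: the split helps
```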

Let’s visualize the process
We will use the following code:

```python
def id3(data, target, columns, tree):
    print(data)
    unique_targets = pandas.unique(data[target])
    nodes.append(len(nodes) + 1)
    tree["number"] = nodes[-1]

    if len(unique_targets) == 1:
        if 0 in unique_targets:
            tree["label"] = 0
        elif 1 in unique_targets:
            tree["label"] = 1
        return

    best_column = find_best_column(data, target, columns)
    column_median = data[best_column].median()

    tree["column"] = best_column
    tree["median"] = column_median

    left_split = data[data[best_column] <= column_median]
    right_split = data[data[best_column] > column_median]
    split_dict = [["left", left_split], ["right", right_split]]

    for name, split in split_dict:
        print(name)
        tree[name] = {}
        id3(split, target, columns, tree[name])

id3(data, "high_income", ["age", "marital_status"], tree)
```

Right now our data looks like the table below. After we pass it to our function, the data will be split on the `age` or `marital_status` column (whichever is best). Since this is a recursive algorithm, it will keep calling itself for each split, left and right. So in this visualization, we will keep track of the best column, the median value, and the left and right splits.

|   | high_income | age | marital_status |
|---|-------------|-----|----------------|
| 0 | 0           | 20  | 0              |
| 1 | 0           | 60  | 2              |
| 2 | 0           | 40  | 1              |
| 3 | 1           | 25  | 1              |
| 4 | 1           | 35  | 2              |
| 5 | 1           | 55  | 1              |


I don’t understand. I ran this code and it doesn’t even work; it just says name ‘tree’ is not defined.

Also, how is it that at the end of this function, there’s no return statement?

What was this function supposed to do without a return statement?

Hi @evayansitan,

The above post was primarily created to help students understand these missions thoroughly.

So in the above post, I only explained the working of the `id3()` function and excluded the rest of the code.

If you are interested, Here is the full code:

```python
import math
import numpy
import pandas

def calc_entropy(column):
    """
    Calculate entropy given a pandas series, list, or numpy array.
    """
    # Compute the counts of each unique value in the column
    counts = numpy.bincount(column)
    # Divide by the total column length to get a probability
    probabilities = counts / len(column)

    # Initialize the entropy to 0
    entropy = 0
    # Loop through the probabilities, and add each one to the total entropy
    for prob in probabilities:
        if prob > 0:
            entropy += prob * math.log(prob, 2)
            # Base 2 gives the result in bits. Another base (e.g. math.log(prob)
            # for base e) only rescales every entropy by the same factor, so the
            # column chosen for each split stays the same.
    return -entropy

data = pandas.DataFrame([
    [0, 20, 0],
    [0, 60, 2],
    [0, 40, 1],
    [1, 25, 1],
    [1, 35, 2],
    [1, 55, 1]
])

data.columns = ["high_income", "age", "marital_status"]

def calc_information_gain(data, split_name, target_name):
    """
    Calculate information gain given a data set, column to split on, and target.
    """
    # Calculate the original entropy
    original_entropy = calc_entropy(data[target_name])

    # Find the median of the column we're splitting
    column = data[split_name]
    median = column.median()

    # Make two subsets of the data, based on the median
    left_split = data[column <= median]
    right_split = data[column > median]

    # Loop through the splits and calculate the subset entropies
    to_subtract = 0
    for subset in [left_split, right_split]:
        prob = (subset.shape[0] / data.shape[0])
        to_subtract += prob * calc_entropy(subset[target_name])

    # Return information gain
    return original_entropy - to_subtract

def find_best_column(data, target_name, columns):
    information_gains = []
    # Loop through and compute information gains
    for col in columns:
        information_gain = calc_information_gain(data, col, target_name)
        information_gains.append(information_gain)

    # Find the name of the column with the highest gain
    highest_gain_index = information_gains.index(max(information_gains))
    highest_gain = columns[highest_gain_index]
    return highest_gain

tree = {}
nodes = []
spaces = 20

def id3(data, target, columns, tree):
    print(data)
    unique_targets = pandas.unique(data[target])
    nodes.append(len(nodes) + 1)
    tree["number"] = nodes[-1]

    if len(unique_targets) == 1:
        if 0 in unique_targets:
            tree["label"] = 0
        elif 1 in unique_targets:
            tree["label"] = 1
        return

    best_column = find_best_column(data, target, columns)
    column_median = data[best_column].median()

    tree["column"] = best_column
    tree["median"] = column_median

    left_split = data[data[best_column] <= column_median]
    right_split = data[data[best_column] > column_median]
    split_dict = [["left", left_split], ["right", right_split]]

    for name, split in split_dict:
        print(name)
        tree[name] = {}
        id3(split, target, columns, tree[name])

id3(data, "high_income", ["age", "marital_status"], tree)
```

It is not necessary for a function to have a return statement; a function can simply perform a set of actions, like adding two numbers and printing their sum. In the above case, we do have a return statement, but it is not for returning any value. The `id3()` function is recursive: it keeps calling itself, so we use an empty `return` to terminate a branch of the recursion once its stopping condition is met.

The function is supposed to create a decision tree (a dictionary named `tree`) using which we can predict the `high_income` (first column in the `data` variable) column using `age` and `marital_status` as input.

```python
def predict(tree, row):
    if "label" in tree:
        return tree["label"]

    column = tree["column"]
    median = tree["median"]
    if row[column] <= median:
        return predict(tree["left"], row)
    else:
        return predict(tree["right"], row)

print(predict(tree, data.iloc[0]))  # Prediction: 0 (Low Income)
```

Hope this helps.

Best,
Sahil

my code:

```python
data = pd.DataFrame([
    [0, 20, 0],
    [0, 60, 2],
    [0, 40, 1],
    [1, 25, 1],
    [1, 35, 2],
    [1, 55, 1]
])
# Assign column names to the data
data.columns = ["high_income", "age", "marital_status"]

# Call the function on our data to set the counters properly
id3(data, "high_income", ["age", "marital_status"])
label_1s = []
label_0s = []

def id3(data, target, columns):
    unique_targets = pd.unique(data[target])

    if len(unique_targets) == 1:
        if 0 in unique_targets:
            label_0s.append(0)
        elif 1 in unique_targets:
            label_1s.append(1)
        return

    best_column = find_best_column(data, target, columns)
    column_median = data[best_column].median()

    left_split = data[data[best_column] <= column_median]
    right_split = data[data[best_column] > column_median]

    for split in [left_split, right_split]:
        id3(split, target, columns)

id3(data, "high_income", ["age", "marital_status"])
```

my result: RecursionError: maximum recursion depth exceeded while calling a Python object

my expectation: an error-free result

my notebook: Decision Tree.ipynb (30.7 KB)



That’s weird. When I ran your code, I received a function not defined error. It was fixed when I removed this code:

```python
# Call the function on our data to set the counters properly
id3(data, "high_income", ["age", "marital_status"])
```

Here is the .ipynb file:
3SS5RDED0DZnYMWKqYaIrQYRONA.ipynb (27.8 KB)

Previously, I was receiving a recursion depth error. I modified my code to say:

`id3(split, target, data.columns, tree[name])`

This fixed the error, but the decision tree that gets output does not look correct.


The recursion issue and incorrect tree output have been resolved. The error was in the `find_best_column` function.


hi @tfu8

Can you please share what was wrong with the `find_best_column` function?

Thanks.

Hi @Sahil

I downloaded the attached Jupyter notebook file and commented out only two pieces of code: the `!wget...` part and the recursion-limit setting.

I neither get an error nor an output in the attached file, but in my own file I get a `RecursionError: maximum recursion depth exceeded while calling a Python object` error.
Please scroll to the last code cells.

Hi @Rucha,

I haven’t checked the functioning of your `id3` function, but based on a comparison, the only difference between @photonsinnovate's code and your code is the following:

```python
def find_best_column(data, target_name, columns):
    information_gains = []

    for each_col in columns:
        information_gains.append(calc_information_gain(preprocessed_income, each_col, "high_income"))

    highest_gain = columns[information_gains.index(max(information_gains))]
    return highest_gain
```

On this line:
`information_gains.append(calc_information_gain(preprocessed_income, each_col, "high_income"))`

You are passing `preprocessed_income` instead of `data`.

Once you change it to `data`, you can see that your project produces the same output as @photonsinnovate's (to check the result, view `label_1s` and `label_0s`).

However, I have observed a flaw in the Dataquest solution code. If a node contains multiple rows that cannot be split further, the function will keep working on the same data forever, because the only condition for terminating is that the number of `unique_targets` equals 1. If that count never reaches 1 and no further split is possible, the recursion runs forever.
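One way to guard against this (an illustrative sketch, not the official Dataquest fix) is to stop and assign the majority label whenever no median split can actually separate the rows. To stay self-contained, the sketch below uses plain lists of dicts instead of pandas, and it simplifies column choice by taking the first splittable column rather than the highest-gain one:

```python
from statistics import median
from collections import Counter

def id3_with_guard(rows, target, columns, tree):
    labels = [r[target] for r in rows]
    # Usual stopping condition: the node is pure
    if len(set(labels)) == 1:
        tree["label"] = labels[0]
        return
    # Try to find a column whose median split actually separates the rows
    # (for simplicity: first splittable column, not the highest-gain one)
    for col in columns:
        m = median(r[col] for r in rows)
        left = [r for r in rows if r[col] <= m]
        right = [r for r in rows if r[col] > m]
        if left and right:
            tree["column"], tree["median"] = col, m
            tree["left"], tree["right"] = {}, {}
            id3_with_guard(left, target, columns, tree["left"])
            id3_with_guard(right, target, columns, tree["right"])
            return
    # No column separates these rows: label with the majority class
    # instead of recursing forever
    tree["label"] = Counter(labels).most_common(1)[0][0]

# Rows identical in every feature but with different targets would make
# the original id3() recurse forever; this version stops:
stuck = [{"age": 30, "high_income": 0},
         {"age": 30, "high_income": 0},
         {"age": 30, "high_income": 1}]
tree = {}
id3_with_guard(stuck, "high_income", ["age"], tree)
print(tree)  # {'label': 0}
```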

I will explore this further next week and log it as a bug if necessary.

Best,
Sahil


Hi @Sahil

Good catch! And for me, it’s yet again a learning moment.

Thank you for taking it up. I guess I will redo the mission meanwhile.


Hi @Rucha,

I have validated the error in the solution code. The solution code is not designed to handle cases where multiple rows in a node cannot be split further; hence, it results in infinite recursion, as you can see here:

I have logged this issue.

Best,
Sahil