# ID3 (Iterative Dichotomiser 3) Algorithm - How does it work?

Prerequisites:

What is a Decision Tree?
A decision tree is made up of a series of nodes and branches. A node is where we split the data based on a variable, and a branch is one side of the split. We keep splitting the data based on variables. As we do this, the tree accumulates more levels. When we make a split, some rows will go to the right, and some will go to the left. As we build the tree deeper and deeper, each node will “receive” fewer and fewer rows. The nodes at the bottom of the tree, where we decide to stop splitting, are called terminal nodes, or leaves. When we do our splits, we aren’t doing them randomly; we have an objective. Our goal is to ensure that we can make a prediction on future data. In order to do this, all rows in each leaf must have only one value for our target column.

How do we split?
We’ll use entropy [precisely, Information Gain] to figure out which variables we should split nodes on. Post-split, we’ll have two data sets, each containing the rows from one branch of the split. Entropy refers to disorder. The more “mixed together” 1s and 0s are, the higher the entropy. A data set consisting entirely of 1s in the target column would have zero entropy. Entropy, which is not to be confused with entropy from physics, comes from information theory.

A key concept in information theory is the notion of a bit of information. One bit of information is one unit of information. We can represent a bit of information as a binary number because it has either the value 1 or 0. Suppose there’s an equal probability of tomorrow being sunny (1) or not sunny (0). If I tell you that it will be sunny, I’ve given you one bit of information. We can also think of entropy in terms of information. If we flip a coin where both sides are heads, we know upfront that the result will be heads. We gain no new information by flipping the coin, so entropy is 0. On the other hand, if the coin has a heads side and a tails side, there’s a 50% probability that it will land on either. Thus, flipping the coin gives us one bit of information – which side the coin landed on.

How do we calculate Entropy?

We iterate through each unique value `x_i` in a single column (in our case, `high_income`) and compute the probability of that value occurring in the data, `P(x_i)`. Entropy is then `-Σ P(x_i) * log_b(P(x_i))`, summed over all unique values, where `b` is the base of the logarithm. We commonly use the value 2 for `b`, but we can also set it to 10 or another value. We use 2 because it allows the result to be interpreted in bits of information.
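As a quick sketch of this formula (the helper name `entropy` and the coin examples below are illustrative, not part of the mission code):

```python
import math
from collections import Counter

def entropy(values, base=2):
    """Shannon entropy of a sequence of labels; base 2 gives bits."""
    counts = Counter(values)
    total = len(values)
    result = 0.0
    for count in counts.values():
        p = count / total               # P(x_i)
        result -= p * math.log(p, base)
    return result

# A fair coin carries one bit of information:
print(entropy([0, 1]))        # 1.0
# A target column that is all 1s has no disorder at all:
print(entropy([1, 1, 1, 1]))  # 0.0
```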

What is Information Gain?

We’ll need a way to go from computing entropy to figuring out which variable to split on. We can do this using information gain, which tells us which split will reduce entropy the most. We’re computing information gain (`IG`) for a given target variable (`T`), as well as a given variable we want to split on (`A`). To compute it, we first calculate the entropy for `T`. Then, for each unique value `v` in the variable `A`, we compute the number of rows in which `A` takes on the value `v`, and divide it by the total number of rows. Next, we multiply the results by the entropy of the rows where `A` is `v`. We add all of these subset entropies together, then subtract from the overall entropy to get information gain. If the result is positive, we’ve lowered entropy with our split. The higher the result is, the more we’ve lowered entropy.

To simplify the calculation of information gain and make splits simpler, we won’t do it for each unique value. We’ll find the median for the variable we’re splitting on instead. Any rows where the value of the variable is below the median will go to the left branch, and the rest of the rows will go to the right branch. To compute information gain, we’ll only have to compute entropies for two subsets.
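The median-split version of information gain can be sketched with plain Python lists (the `rows` dict representation here is illustrative; the actual mission code uses pandas):

```python
import math
from statistics import median

def entropy(values, base=2):
    """Shannon entropy of a sequence of labels."""
    values = list(values)
    probs = [values.count(v) / len(values) for v in set(values)]
    return -sum(p * math.log(p, base) for p in probs)

def information_gain(rows, split_col, target_col):
    """Information gain of a median split on split_col."""
    original = entropy(r[target_col] for r in rows)
    m = median(r[split_col] for r in rows)
    left = [r for r in rows if r[split_col] <= m]
    right = [r for r in rows if r[split_col] > m]
    # Weight each subset's entropy by its share of the rows, then subtract
    to_subtract = sum(
        len(subset) / len(rows) * entropy(r[target_col] for r in subset)
        for subset in (left, right)
    )
    return original - to_subtract

# The thread's toy data set:
rows = [
    {"high_income": 0, "age": 20, "marital_status": 0},
    {"high_income": 0, "age": 60, "marital_status": 2},
    {"high_income": 0, "age": 40, "marital_status": 1},
    {"high_income": 1, "age": 25, "marital_status": 1},
    {"high_income": 1, "age": 35, "marital_status": 2},
    {"high_income": 1, "age": 55, "marital_status": 1},
]
print(information_gain(rows, "age", "high_income"))  # positive: the split helps
```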

Let’s visualize the process
We will use the following code:

```python
def id3(data, target, columns, tree):
    print(data)
    unique_targets = pandas.unique(data[target])
    nodes.append(len(nodes) + 1)
    tree["number"] = nodes[-1]

    if len(unique_targets) == 1:
        if 0 in unique_targets:
            tree["label"] = 0
        elif 1 in unique_targets:
            tree["label"] = 1
        return

    best_column = find_best_column(data, target, columns)
    column_median = data[best_column].median()

    tree["column"] = best_column
    tree["median"] = column_median

    left_split = data[data[best_column] <= column_median]
    right_split = data[data[best_column] > column_median]
    split_dict = [["left", left_split], ["right", right_split]]

    for name, split in split_dict:
        print(name)
        tree[name] = {}
        id3(split, target, columns, tree[name])

id3(data, "high_income", ["age", "marital_status"], tree)
```

Right now our data looks like the table below. After we pass it to our function, the data will be split on the `age` or `marital_status` column (whichever is best). Since this is a recursive algorithm, it will keep calling itself for each split, left and right. So in this visualization, we will keep track of the best column, the median value, and the left and right splits.

|   | high_income | age | marital_status |
|---|-------------|-----|----------------|
| 0 | 0           | 20  | 0              |
| 1 | 0           | 60  | 2              |
| 2 | 0           | 40  | 1              |
| 3 | 1           | 25  | 1              |
| 4 | 1           | 35  | 2              |
| 5 | 1           | 55  | 1              |


I don’t understand. I ran this code and it doesn’t even work; it just says name ‘tree’ is not defined.

Also, how is it that at the end of this function, there’s no return statement?

What was this function supposed to do without a return statement?

Hi @evayansitan,

The above post was primarily created to help students understand these missions thoroughly.

So in the above post, I only explained the working of the `id3()` function and excluded the rest of the code.

If you are interested, Here is the full code:

```python
import math
import numpy
import pandas

def calc_entropy(column):
    """
    Calculate entropy given a pandas series, list, or numpy array.
    """
    # Compute the counts of each unique value in the column
    counts = numpy.bincount(column)
    # Divide by the total column length to get a probability
    probabilities = counts / len(column)

    # Initialize the entropy to 0
    entropy = 0
    # Loop through the probabilities, and add each one to the total entropy
    for prob in probabilities:
        if prob > 0:
            entropy += prob * math.log(prob, 2)
            # Base 2 gives the result in bits. Another base (e.g. math.log(prob)
            # for base e) only rescales every entropy by the same factor, so the
            # column chosen for each split stays the same.
    return -entropy

data = pandas.DataFrame([
    [0, 20, 0],
    [0, 60, 2],
    [0, 40, 1],
    [1, 25, 1],
    [1, 35, 2],
    [1, 55, 1]
])

data.columns = ["high_income", "age", "marital_status"]

def calc_information_gain(data, split_name, target_name):
    """
    Calculate information gain given a data set, column to split on, and target.
    """
    # Calculate the original entropy
    original_entropy = calc_entropy(data[target_name])

    # Find the median of the column we're splitting
    column = data[split_name]
    median = column.median()

    # Make two subsets of the data, based on the median
    left_split = data[column <= median]
    right_split = data[column > median]

    # Loop through the splits and calculate the subset entropies
    to_subtract = 0
    for subset in [left_split, right_split]:
        prob = (subset.shape[0] / data.shape[0])
        to_subtract += prob * calc_entropy(subset[target_name])

    # Return information gain
    return original_entropy - to_subtract

def find_best_column(data, target_name, columns):
    information_gains = []
    # Loop through and compute information gains
    for col in columns:
        information_gain = calc_information_gain(data, col, target_name)
        information_gains.append(information_gain)

    # Find the name of the column with the highest gain
    highest_gain_index = information_gains.index(max(information_gains))
    highest_gain = columns[highest_gain_index]
    return highest_gain

tree = {}
nodes = []
spaces = 20

def id3(data, target, columns, tree):
    print(data)
    unique_targets = pandas.unique(data[target])
    nodes.append(len(nodes) + 1)
    tree["number"] = nodes[-1]

    if len(unique_targets) == 1:
        if 0 in unique_targets:
            tree["label"] = 0
        elif 1 in unique_targets:
            tree["label"] = 1
        return

    best_column = find_best_column(data, target, columns)
    column_median = data[best_column].median()

    tree["column"] = best_column
    tree["median"] = column_median

    left_split = data[data[best_column] <= column_median]
    right_split = data[data[best_column] > column_median]
    split_dict = [["left", left_split], ["right", right_split]]

    for name, split in split_dict:
        print(name)
        tree[name] = {}
        id3(split, target, columns, tree[name])

id3(data, "high_income", ["age", "marital_status"], tree)
```

It is not necessary for a function to have a return statement; a function can simply perform a set of actions, like adding two numbers and printing their sum. In the above case, we do have a return statement, but it is not for returning any value. The `id3()` function is recursive: it keeps calling itself, so we use an empty `return` to terminate a branch of the recursion once its stopping condition is met.

The function is supposed to create a decision tree (a dictionary named `tree`) using which we can predict the `high_income` (first column in the `data` variable) column using `age` and `marital_status` as input.

```python
def predict(tree, row):
    if "label" in tree:
        return tree["label"]

    column = tree["column"]
    median = tree["median"]
    if row[column] <= median:
        return predict(tree["left"], row)
    else:
        return predict(tree["right"], row)

print(predict(tree, data.iloc[0]))  # Prediction: 0 (Low Income)
```

Hope this helps.

Best,
Sahil

my code:

```python
data = pd.DataFrame([
    [0, 20, 0],
    [0, 60, 2],
    [0, 40, 1],
    [1, 25, 1],
    [1, 35, 2],
    [1, 55, 1]
])
# Assign column names to the data
data.columns = ["high_income", "age", "marital_status"]

# Call the function on our data to set the counters properly
id3(data, "high_income", ["age", "marital_status"])
label_1s = []
label_0s = []

def id3(data, target, columns):
    unique_targets = pd.unique(data[target])

    if len(unique_targets) == 1:
        if 0 in unique_targets:
            label_0s.append(0)
        elif 1 in unique_targets:
            label_1s.append(1)
        return

    best_column = find_best_column(data, target, columns)
    column_median = data[best_column].median()

    left_split = data[data[best_column] <= column_median]
    right_split = data[data[best_column] > column_median]

    for split in [left_split, right_split]:
        id3(split, target, columns)

id3(data, "high_income", ["age", "marital_status"])
```

my result: RecursionError: maximum recursion depth exceeded while calling a Python object

my expectation: an error-free result

my notebook: Decision Tree.ipynb (30.7 KB)



That’s weird. When I ran your code, I received a function not defined error. It was fixed when I removed this code:

```python
# Call the function on our data to set the counters properly
id3(data, "high_income", ["age", "marital_status"])
```

Here is the .ipynb file:
3SS5RDED0DZnYMWKqYaIrQYRONA.ipynb (27.8 KB)

Previously, I was receiving a recursion depth error. I modified my code to say:

`id3(split, target, data.columns, tree[name])`

This fixed the error, but the decision tree that gets output does not look correct.


The recursion issue and incorrect tree output have been resolved. The error was in the `find_best_column` function.


hi @tfu8

Can you please share what was wrong with the `find_best_column` function?

Thanks.

Hi @Sahil

I downloaded the attached Jupyter notebook file and commented out only two pieces of code: the `!wget...` part and the recursion-limit setting.

I neither get an error nor an output in the attached file, but in my own file I get a `RecursionError: maximum recursion depth exceeded while calling a Python object` error.
Please scroll to the last code cells.

Hi @Rucha,

I haven’t checked the functioning of your `id3` function, but based on a comparison, the only difference between @photonsinnovate's code and your code is the following:

```python
def find_best_column(data, target_name, columns):
    information_gains = []

    for each_col in columns:
        information_gains.append(calc_information_gain(preprocessed_income, each_col, "high_income"))

    highest_gain = columns[information_gains.index(max(information_gains))]
    return highest_gain
```

On this line:
`information_gains.append(calc_information_gain(preprocessed_income, each_col, "high_income"))`

You are passing `preprocessed_income` instead of `data`.

Once you change it to `data`, you can see that your project produces the same output as @photonsinnovate's (to check the result, view `label_1s` and `label_0s`).

However, I have observed a flaw in the Dataquest solution code. If a node contains multiple rows that cannot be split further, the function will keep working on the same data forever, because the only condition for terminating is that the number of `unique_targets` equals 1. If that count never reaches 1 and no further split is possible, the recursion runs forever.
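One way to guard against this (an illustrative sketch, not the official Dataquest fix) is to stop and assign the majority label whenever no median split can actually separate the rows. To stay self-contained, the sketch below uses plain lists of dicts instead of pandas, and it simplifies column choice by taking the first splittable column rather than the highest-gain one:

```python
from statistics import median
from collections import Counter

def id3_with_guard(rows, target, columns, tree):
    labels = [r[target] for r in rows]
    # Usual stopping condition: the node is pure
    if len(set(labels)) == 1:
        tree["label"] = labels[0]
        return
    # Try to find a column whose median split actually separates the rows
    # (for simplicity: first splittable column, not the highest-gain one)
    for col in columns:
        m = median(r[col] for r in rows)
        left = [r for r in rows if r[col] <= m]
        right = [r for r in rows if r[col] > m]
        if left and right:
            tree["column"], tree["median"] = col, m
            tree["left"], tree["right"] = {}, {}
            id3_with_guard(left, target, columns, tree["left"])
            id3_with_guard(right, target, columns, tree["right"])
            return
    # No column separates these rows: label with the majority class
    # instead of recursing forever
    tree["label"] = Counter(labels).most_common(1)[0][0]

# Rows identical in every feature but with different targets would make
# the original id3() recurse forever; this version stops:
stuck = [{"age": 30, "high_income": 0},
         {"age": 30, "high_income": 0},
         {"age": 30, "high_income": 1}]
tree = {}
id3_with_guard(stuck, "high_income", ["age"], tree)
print(tree)  # {'label': 0}
```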

I will explore this further next week and log it as a bug if necessary.

Best,
Sahil


Hi @Sahil

Good catch! And for me, it’s yet again a learning moment.

Thank you for taking it up. I guess I will redo the mission meanwhile.


Hi @Rucha,

I have validated the error in the solution code. The solution code is not designed to handle cases where multiple rows in a node cannot be split further; hence, it results in infinite recursion, as you can see here:

I have logged this issue.

Best,
Sahil