Decision Trees

Hasan Shahriar
8 min read · Mar 27, 2024

Decision Trees (DTs) are a non-parametric supervised learning method used for classification and regression. The goal is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data features. A tree can be seen as a piecewise constant approximation.

Concepts

To build a decision tree from a dataset, the following steps are needed:

  • Start with all examples at the root node
  • Calculate information gain for splitting on all possible features, and pick the one with the highest information gain
  • Split the dataset according to the selected feature, and create left and right branches of the tree
  • Keep repeating the splitting process until the stopping criteria are met

In this tutorial, we’ll implement the following functions, which will let us split a node into left and right branches using the feature with the highest information gain. The functions are:

  • Calculate the entropy at a node
  • Split the dataset at a node into left and right branches based on a given feature
  • Calculate the information gain from splitting on a given feature
  • Choose the feature that maximizes information gain.

We’ll then use the helper functions we’ve implemented to build a decision tree by repeating the splitting process until the stopping criteria are met; a high-level sketch of this loop follows.
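
Here is a minimal sketch of that recursive loop, written in terms of the helper functions listed above; the concrete build_tree_recursive function at the end of this post fills in the details (branch names, printing and depth bookkeeping).

# Minimal sketch of the recursive splitting loop; the full version is
# implemented as build_tree_recursive at the end of this post.
def build_tree_sketch(X, y, node_indices, current_depth, max_depth):
    if current_depth == max_depth:                       # stopping criterion reached
        return                                           # this node becomes a leaf
    best_feature = get_best_split(X, y, node_indices)    # feature with highest information gain
    left_indices, right_indices = split_dataset(X, node_indices, best_feature)
    build_tree_sketch(X, y, left_indices, current_depth + 1, max_depth)
    build_tree_sketch(X, y, right_indices, current_depth + 1, max_depth)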

Problem Statement and Dataset

Suppose someone is starting a company that grows and sells wild mushrooms.

  • Since not all mushrooms are edible, one would like to be able to tell whether a given mushroom is edible or poisonous based on its physical attributes
  • We have some existing data that we can use for this task.

We have 10 examples of mushrooms. For each example, we have

Three features

  • Cap Color (Brown or Red),
  • Stalk Shape (Tapering (as in \/) or Enlarging (as in /\)), and
  • Solitary (Yes or No)

And Label

  • Edible (1 indicating yes or 0 indicating poisonous)

Implementation

First, import the required packages, then load the data and check its shape. The features are one-hot encoded: a value of 1 means Brown cap (vs. Red), Tapering stalk shape (vs. Enlarging), and Solitary = Yes, respectively, while the label is 1 for edible and 0 for poisonous.

import numpy as np
import matplotlib.pyplot as plt

X_train = np.array([[1,1,1],[1,0,1],[1,0,0],[1,0,0],[1,1,1],[0,1,1],[0,0,0],[1,0,1],[0,1,0],[1,0,0]])
y_train = np.array([1,1,0,0,1,0,0,1,1,0])

print("First few elements of X_train:\n", X_train[:5])
print("Type of X_train:",type(X_train))

First few elements of X_train:
[[1 1 1]
[1 0 1]
[1 0 0]
[1 0 0]
[1 1 1]]
Type of X_train: <class 'numpy.ndarray'>

print ('The shape of X_train is:', X_train.shape)
print ('The shape of y_train is: ', y_train.shape)
print ('Number of training examples (m):', len(X_train))

The shape of X_train is: (10, 3)
The shape of y_train is: (10,)
Number of training examples (m): 10

Calculate Entropy

First, we’ll write a helper function called compute_entropy that computes the entropy (measure of impurity) at a node.

  • The function takes in a numpy array (y) that indicates whether the examples in that node are edible (1) or poisonous (0)

The compute_entropy() function below will:

  • Compute p_1, which is the fraction of examples that are edible (i.e. have value = 1 in y)
  • The entropy is then calculated as

H(p_1) = -p_1 log_2(p_1) - (1 - p_1) log_2(1 - p_1)

Note:

  • The log is calculated with base 2
  • For implementation purposes, 0log_2(0)=0. That is, if p_1 = 0 or p_1 = 1, set the entropy to 0
  • Make sure to check that the data at a node is not empty (i.e. len(y) != 0). Return 0 if it is.

def compute_entropy(y):
    """
    Computes the entropy (measure of impurity) of the examples at a node

    Args:
        y (ndarray): Numpy array indicating whether each example at a node is
            edible (`1`) or poisonous (`0`)

    Returns:
        entropy (float): Entropy at that node
    """
    entropy = 0.

    if len(y) != 0:
        # p1 is the fraction of edible examples (label 1) at this node
        p1 = len(y[y == 1]) / len(y)
        # A pure node (p1 == 0 or p1 == 1) has zero entropy
        if p1 != 0 and p1 != 1:
            entropy = -p1 * np.log2(p1) - (1 - p1) * np.log2(1 - p1)

    return entropy
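
For instance, at the root node 5 of the 10 training examples are edible, so p_1 = 0.5 and the entropy is exactly 1:

# Entropy at the root node: half of the examples are edible, so p1 = 0.5
print("Entropy at root node: ", compute_entropy(y_train))

Entropy at root node:  1.0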

Split the Dataset

Next, we’ll write a helper function called split_dataset that takes in the data at a node and a feature to split on and splits it into left and right branches. Then we'll implement code to calculate how good the split is.

  • The function takes in the training data, the list of indices of data points at that node, along with the feature to split on.
  • It splits the data and returns the subset of indices at the left and the right branch.
  • For example, say we’re starting at the root node (so node_indices = [0,1,2,3,4,5,6,7,8,9]), and we chose to split on feature 0, which is whether or not the example has a brown cap.
  • The output of the function is then, left_indices = [0,1,2,3,4,7,9] (data points with brown cap) and right_indices = [5,6,8] (data points without a brown cap)

For each index in node_indices

  • If the value of X at that index for that feature is 1, add the index to left_indices
  • If the value of X at that index for that feature is 0, add the index to right_indices

def split_dataset(X, node_indices, feature):
    """
    Splits the data at the given node into
    left and right branches

    Args:
        X (ndarray): Data matrix of shape (n_samples, n_features)
        node_indices (list): List containing the active indices, i.e. the samples being considered at this step
        feature (int): Index of feature to split on

    Returns:
        left_indices (list): Indices with feature value == 1
        right_indices (list): Indices with feature value == 0
    """
    left_indices = []
    right_indices = []

    for i in node_indices:
        if X[i][feature] == 1:
            left_indices.append(i)
        else:
            right_indices.append(i)

    return left_indices, right_indices

Now, let’s try splitting the dataset at the root node (which contains all 10 examples) on feature 0 (Brown Cap), as discussed above.


root_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# The dataset only has three features, so this value can be 0 (Brown Cap), 1 (Tapering Stalk Shape) or 2 (Solitary)
feature = 0

left_indices, right_indices = split_dataset(X_train, root_indices, feature)

print("Left indices: ", left_indices)
print("Right indices: ", right_indices)

Left indices: [0, 1, 2, 3, 4, 7, 9]
Right indices: [5, 6, 8]

We can also split just a subset of the examples at a node. For example:

root_indices_subset = [0, 2, 4, 6, 8]
left_indices, right_indices = split_dataset(X_train, root_indices_subset, feature)

print("Left indices: ", left_indices)
print("Right indices: ", right_indices)

Left indices: [0, 2, 4]
Right indices: [6, 8]

Calculate Information Gain

Next, we’ll write a function called information_gain that takes in the training data, the indices at a node and a feature to split on and returns the information gain from the split.

The information gain can be computed through the following equation:

Information Gain = H(p_1^node) - (w^left * H(p_1^left) + w^right * H(p_1^right))

where

  • H(p_1^node) is the entropy at the node
  • H(p_1^left) and H(p_1^right) are the entropies at the left and the right branches resulting from the split
  • w^left and w^right are the proportions of examples at the left and right branches, respectively
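
For example, at the root node 5 of the 10 examples are edible, so H(p_1^node) = 1. Splitting on Solitary (feature 2) sends 5 examples to each branch (w^left = w^right = 0.5), with 4/5 of the left branch and 1/5 of the right branch edible, so both branch entropies are about 0.72. The information gain is therefore 1 - 0.72 ≈ 0.28, which matches the value we’ll compute below.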

Note:

  • we can use the compute_entropy() function that we implemented above to calculate the entropy.

def compute_information_gain(X, y, node_indices, feature):
    """
    Compute the information gain of splitting the node on a given feature

    Args:
        X (ndarray): Data matrix of shape (n_samples, n_features)
        y (array like): list or ndarray with n_samples containing the target variable
        node_indices (ndarray): List containing the active indices, i.e. the samples being considered in this step
        feature (int): Index of feature to split on

    Returns:
        information_gain (float): Information gain from splitting on the feature
    """
    # Split the dataset at the given feature
    left_indices, right_indices = split_dataset(X, node_indices, feature)

    X_node, y_node = X[node_indices], y[node_indices]
    X_left, y_left = X[left_indices], y[left_indices]
    X_right, y_right = X[right_indices], y[right_indices]

    # Entropy at the node and at each branch
    node_entropy = compute_entropy(y_node)
    left_entropy = compute_entropy(y_left)
    right_entropy = compute_entropy(y_right)

    # Proportion of examples that end up in each branch
    w_left = len(X_left) / len(X_node)
    w_right = len(X_right) / len(X_node)

    # Information gain = entropy at the node minus the weighted entropy of the branches
    weighted_entropy = w_left * left_entropy + w_right * right_entropy
    information_gain = node_entropy - weighted_entropy

    return information_gain

We can now check our implementation using the following code and calculate what the information gain would be from splitting on each of the features.

info_gain0 = compute_information_gain(X_train, y_train, root_indices, feature=0)
print("Information Gain from splitting the root on brown cap: ", info_gain0)

info_gain1 = compute_information_gain(X_train, y_train, root_indices, feature=1)
print("Information Gain from splitting the root on tapering stalk shape: ", info_gain1)

info_gain2 = compute_information_gain(X_train, y_train, root_indices, feature=2)
print("Information Gain from splitting the root on solitary: ", info_gain2)

Information Gain from splitting the root on brown cap: 0.034851554559677034
Information Gain from splitting the root on tapering stalk shape: 0.12451124978365313
Information Gain from splitting the root on solitary: 0.2780719051126377

Splitting on “Solitary” (feature = 2) at the root node gives the maximum information gain. Therefore, it’s the best feature to split on at the root node.

Get the Best Split

Now let’s write a function get_best_split() to get the best feature to split on, by computing the information gain from each feature as we did above and returning the feature that gives the maximum information gain.

  • The function takes in the training data, along with the indices of the data points at that node.
  • The output of the function is the feature that gives the maximum information gain.
  • We can use the compute_information_gain() function to iterate through the features and calculate the information gain for each feature.

def get_best_split(X, y, node_indices):
    """
    Returns the optimal feature to split the node data on

    Args:
        X (ndarray): Data matrix of shape (n_samples, n_features)
        y (array like): list or ndarray with n_samples containing the target variable
        node_indices (ndarray): List containing the active indices, i.e. the samples being considered in this step

    Returns:
        best_feature (int): The index of the best feature to split on (-1 if no feature gives a positive information gain)
    """
    num_features = X.shape[1]

    best_feature = -1
    max_info_gain = 0

    # Try every feature and keep the one with the highest information gain
    for feature in range(num_features):
        info_gain = compute_information_gain(X, y, node_indices, feature)
        if info_gain > max_info_gain:
            max_info_gain = info_gain
            best_feature = feature

    return best_feature

best_feature = get_best_split(X_train, y_train, root_indices)
print("Best feature to split on: %d" % best_feature)

Best feature to split on: 2

Building the Tree

We will use the functions we implemented above to generate a decision tree by successively picking the best feature to split on until we reach the stopping criteria (a maximum depth of 2).

tree = []

def build_tree_recursive(X, y, node_indices, branch_name, max_depth, current_depth):
    """
    Build a tree using the recursive algorithm that splits the dataset into 2 subgroups at each node.
    This function just prints the tree.

    Args:
        X (ndarray): Data matrix of shape (n_samples, n_features)
        y (array like): list or ndarray with n_samples containing the target variable
        node_indices (ndarray): List containing the active indices, i.e. the samples being considered in this step
        branch_name (string): Name of the branch. ['Root', 'Left', 'Right']
        max_depth (int): Max depth of the resulting tree.
        current_depth (int): Current depth. Parameter used during recursive call.
    """
    # Maximum depth reached - stop splitting and report a leaf node
    if current_depth == max_depth:
        formatting = " " * current_depth + "-" * current_depth
        print(formatting, "%s leaf node with indices" % branch_name, node_indices)
        return

    # Otherwise, get the best feature to split on at this node
    best_feature = get_best_split(X, y, node_indices)

    formatting = "-" * current_depth
    print("%s Depth %d, %s: Split on feature: %d" % (formatting, current_depth, branch_name, best_feature))

    # Split the dataset at the best feature
    left_indices, right_indices = split_dataset(X, node_indices, best_feature)
    tree.append((left_indices, right_indices, best_feature))

    # Continue splitting the left and the right child, incrementing the current depth
    build_tree_recursive(X, y, left_indices, "Left", max_depth, current_depth + 1)
    build_tree_recursive(X, y, right_indices, "Right", max_depth, current_depth + 1)

build_tree_recursive(X_train, y_train, root_indices, "Root", max_depth=2, current_depth=0)

Depth 0, Root: Split on feature: 2
- Depth 1, Left: Split on feature: 0
  -- Left leaf node with indices [0, 1, 4, 7]
  -- Right leaf node with indices [5]
- Depth 1, Right: Split on feature: 1
  -- Left leaf node with indices [8]
  -- Right leaf node with indices [2, 3, 6, 9]
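
As an optional sanity check, a library implementation should agree with ours: assuming scikit-learn is available, a DecisionTreeClassifier trained with the entropy criterion and the same depth limit should also split the root on Solitary. A minimal sketch:

# Optional cross-check with scikit-learn (assumes scikit-learn is installed).
from sklearn.tree import DecisionTreeClassifier, export_text

clf = DecisionTreeClassifier(criterion="entropy", max_depth=2, random_state=0)
clf.fit(X_train, y_train)

# Print the learned tree with human-readable feature names
print(export_text(clf, feature_names=["Brown Cap", "Tapering Stalk Shape", "Solitary"]))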

A complete implementation of this can be found in this repo on GitHub.
