Random Forest from scratch code review

I just coded a random forest from scratch using OOP, based on functions from https://machinelearningmastery.com/implement-resampling-methods-scratch-python/

Before I start trying to prune, I'm hoping someone could give some advice on how it can be improved (coding style / functional correctness / usefulness / speed).
For example: am I using bad practices anywhere, too little or too much parameter passing, lack of inheritance, poor tree-terminating conditions, slow data structures, etc.? One tree alone currently takes 3 minutes to fit data of 1000 rows and 3 columns.

import numpy as np


class DecisionTree:
    
    def __init__(self, data, n_features):
        self.data = data
        self.n_features = n_features 
        self.index = None
        self.value = None
        self.prediction = None
        self.left = None        # child subtrees, assigned during fit()
        self.right = None
    
    def test_split(self, dataset, index, value):
        left = right = np.array([]).reshape(0,dataset.shape[1])
        
        for row in dataset:
            if row[index] < value:          # group splitting decision is made on this inequality
                left = np.vstack([left,row])
            else:
                right = np.vstack([right,row])
        
        #print(f'left is {left}')
        #print(f'right is {right}')
        
        return left, right

    # Calculate the Gini index for a split dataset
    def gini_index(self,groups):
        """Gini of separated groups is calculated, best score is the smallest number, 
        gini before split is ignored because it's the same for any split"""

        # count all samples at split point
        n_instances = sum([len(group) for group in groups])

        gini = 0.0
        for group in groups:
            size = len(group)
            # avoid divide by zero
            if size == 0:
                continue
            # score the group based on the proportion of each class
            p = np.unique(group[:,-1],return_counts=True)[1]/size    # [1] to get counts
            score = sum(p**2)
            # weight the group score by its relative size
            gini += (1.0 - score) * (size / n_instances)  
        return gini
    
    def get_split(self,dataset):
        """
        Iterate through a randomly selected subset of predictor columns and
        test each observation value in each column as a splitting point
        """
        best_index, best_value, best_score, best_groups = 999, 999, 999, None      # initialize extreme values to be overwritten
        
        features_selected = np.random.choice(dataset.shape[1]-1,self.n_features,replace=False)
        
        #print(f'data to test:\n {dataset}')
        #print(f'features_selected: {features_selected}')
        
        for index in features_selected:     # test the randomly selected predictors (label col excluded)
            column_values = dataset[:,index]
            
            #print(f'values to test: {column_values}')
            for value in np.unique(column_values):
                groups = self.test_split(dataset, index, value)
                gini = self.gini_index(groups)
                
                #print(f'index is {index}, value is {value}, gini is {gini}')
                if gini < best_score:
                    best_index, best_value, best_score, best_groups = index, value, gini, groups
                    
        #print(f'best index: {best_index}, best_value: {best_value}')
        return best_index, best_value, best_score, best_groups
    
    def to_terminal(self, group):
        outcomes = group[:,-1].astype(int)    # np.bincount requires int, not float
        return np.argmax(np.bincount(outcomes))
            
    
    def fit(self):   # initialize first tree with data passed into class, recurse with data after split
        #print("self.data is:\n ",self.data)
       
        if len(set(self.data[:,-1].flatten())) == 1: # if all classes are the same value
            self.prediction = int(self.data[0][-1]) # take the last value of any row (1st in this case)
            #print(f'self.prediction is: {self.prediction}')
            return self
            
        self.index, self.value, best_gini, best_groups = self.get_split(self.data)
        left, right = best_groups
        
        #print(f"left is \n{left}")
        #print(f"right is \n{right}")
        
        if left.size == 0:     # in cases of no further split possible, left will be empty and all information will flow to right
            self.prediction = self.to_terminal(right)
            return self
        
        if best_gini == 0:
            #print('best gini is 0 --> Terminate')
            dt_left = DecisionTree(left, self.n_features)
            dt_left.prediction = self.to_terminal(left)
            dt_right = DecisionTree(right, self.n_features)
            dt_right.prediction = self.to_terminal(right)
            
            self.left = dt_left
            self.right = dt_right
            
            #print(f'best_gini = 0, self.left is: {self.left.prediction}')
            #print(f'type of self.left prediction is {type(self.left.prediction)}')
            #print(f'best_gini = 0, self.right is: {self.right.prediction}')
            return self
        
        #print(f'recursing left:')
        dt = DecisionTree(left, self.n_features)
        self.left = dt.fit()
           
        #print(f'recursing right:')
        dt = DecisionTree(right, self.n_features)
        self.right = dt.fit()
                
        return self

    def print_tree(self, depth=0):
        if self.prediction is not None:
            print('%s[%s]' % (depth*'    ', self.prediction))
        else:
            print('%s[X%d < %.3f]' % (depth*'    ', self.index, self.value))
            self.left.print_tree(depth+1)
            self.right.print_tree(depth+1)
            
    def predict(self, row):
        if self.prediction is not None:
            return self.prediction
        elif row[self.index] < self.value:
            return self.left.predict(row)
        else:
            return self.right.predict(row)
                

class RandomForest:
    
    def __init__(self, data, n_trees, n_features, ratio = 1):
        self.data = data
        self.n_trees = n_trees
        self.n_features = n_features
        self.ratio = ratio
        self.trees = []
        
    def subsample(self):
        # sample row indices with replacement (np.random.choice defaults to replace=True)
        row_idx = np.random.choice(len(self.data), round(len(self.data) * self.ratio))
        return self.data[row_idx]
        
    def fit(self):
        for i in range(self.n_trees):
            sample = self.subsample()
            dt = DecisionTree(sample, self.n_features)
            tree = dt.fit()
            self.trees.append(tree)
            
    def predict(self, row):
        predictions = [tree.predict(row) for tree in self.trees]
        return max(set(predictions), key=predictions.count)
            
training_data = np.array(
    [[10, 3, 1],
    [6, 3, 0],
    [6, 3, 0],
    [6, 1, 1],
    [6, 1, 1],
])


rf = RandomForest(training_data, n_trees=3, n_features=1, ratio = 1)
rf.fit()
for index, tree in enumerate(rf.trees):
    print('{:*^80}{}'.format('tree',index))
    tree.print_tree()

This line of code:

outcomes = group[:,-1].astype(int)

You should convert group[:,-1] with .astype(int) once, before you do anything else. On every call to self.to_terminal, you are converting the same data to integers again. Converting to integers repeatedly like this is very inefficient.

You only need to convert it once; the repeated conversions are using up valuable time.

Move the conversion to before you build the random forest.
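
A minimal sketch of that advice, assuming the labels live in the last column as in the posted code (the variable names here are illustrative):

import numpy as np

data = np.array([[10, 3, 1], [6, 3, 0], [6, 3, 0], [6, 1, 1], [6, 1, 1]])
labels = data[:, -1].astype(int)   # one conversion for the whole dataset

# a terminal node then only needs the (already integer) labels of its rows:
print(np.argmax(np.bincount(labels)))   # majority class, no per-call astype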

max(set(predictions), key=predictions.count)

You don’t need to use a set on predictions since predictions is already an iterator.

I understand you are trying to reduce predictions to its distinct values by using a set. If that is the case, then:

def predict(self, row):
    predictions = set()
    for tree in self.trees:
        predictions.add(tree.predict(row))
    return max(predictions, key=predictions.count)

With every addition to the set, each value is compared only once to decide whether it is a new unique element of predictions.

For the to_terminal issue, the problem is that floats are being generated from ints every time np.vstack([left,row]) is called in test_split. The problem is not vstack itself but the empty array used for appending, np.array([]).reshape(0,dataset.shape[1]), which turned the dataset into floats. I used np.array([],dtype='int').reshape(0,dataset.shape[1]) and removed astype(int), and I see that the time to run the code has increased from 2.85-3 secs to 3.55-4.2 secs. Do you think vstack has become slower with the dtype='int' specification?

You want to preprocess everything once, before you perform any operations on the data.

If you want the data in integer format, then convert all the data in that column to integers once.

Then you don’t need to convert the data in every single split, since it is already converted.
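
As a sketch of one way to sidestep the vstack/dtype problem entirely (an assumed rewrite, not code from this thread): boolean-mask indexing splits the rows in one vectorized step, and slicing preserves the array's original dtype, so no int-to-float conversion happens in the first place.

import numpy as np

def test_split(dataset, index, value):
    # vectorized splitting decision; slicing keeps dataset's dtype intact
    mask = dataset[:, index] < value
    return dataset[mask], dataset[~mask]

data = np.array([[10, 3, 1], [6, 3, 0], [6, 1, 1]])
left, right = test_split(data, 0, 10)
print(left.dtype)   # still an integer dtype, so to_terminal would not need astype(int)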

I don’t get your statement:

You don’t need to use a set on predictions since predictions is already an iterator.

Do you mean predictions is an iterable, rather than an iterator? My understanding is that a list is not an iterator, but an iterable. Also, what does using a set to find unique values have to do with whether the input to set is an iterator/iterable?

Yes, the purpose of the set was to identify the unique classes in the predictions, to minimize the number of times predictions.count has to be called. I think your code will not generate the correct value, though (besides, a set has no .count method, so max(predictions, key=predictions.count) would raise an AttributeError). Building the set directly means predictions can only take 3 possible values:

  1. {1}
  2. {0}
  3. {0, 1}

For the first 2 cases the prediction of the forest will be correct; that happens when all trees predict 1 or all predict 0. For case 3, there is only a single 1 and a single 0 in the set, so there is no way to know which class is the majority amongst the trees. We must leave the predictions from all the trees alone, count for each class how many votes it received among the predictions, and then select the class with the maximum count.
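
A minimal sketch of that counting approach using collections.Counter (an assumed alternative to max(set(predictions), key=predictions.count), equivalent for majority voting but clearer):

from collections import Counter

predictions = [1, 0, 1]                  # one vote per tree
counts = Counter(predictions)            # Counter({1: 2, 0: 1})
majority = counts.most_common(1)[0][0]   # class with the most votes
print(majority)                          # 1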

Yes, I meant predictions is an iterable. Therefore you can pass it to the max function.

Would you suggest a better way of appending to a NumPy array than the np.array([]).reshape(0,dataset.shape[1]) and np.vstack combination? It’s not a problem with the data here, but with the NumPy methods turning my int data into floats.

Why do you need to append the data at all? Can’t you do it more efficiently by using references to the data instead? By reshaping and stacking, you are creating new objects and using more memory.

Good idea. I will try keeping only one copy of the data in memory and tracking row indices while building the tree.
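
A sketch of how that index-tracking idea could look (an assumed design, not the final code from this thread): keep a single copy of the dataset and pass arrays of row indices down the tree, splitting with a boolean mask instead of stacking copied rows.

import numpy as np

def split_indices(dataset, row_idx, col, value):
    # decide left/right for the referenced rows without copying any data
    mask = dataset[row_idx, col] < value
    return row_idx[mask], row_idx[~mask]

dataset = np.array([[10, 3, 1], [6, 3, 0], [6, 3, 0], [6, 1, 1], [6, 1, 1]])
rows = np.arange(len(dataset))
left_idx, right_idx = split_indices(dataset, rows, 0, 10)
print(left_idx, right_idx)   # row indices flowing to each child node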

Hey @hanqi,

We have a solved feature that lets you mark something as the “correct” answer, which helps future students with the same question quickly find the solution they’re looking for.

Here’s an article on how to mark posts as solved - I don’t want to do this for you until I know the solution/explanation works.

Thanks for the reminder. This wasn’t a problem to be “solved” but more of an open discussion seeking advice from others too; I have marked it anyway.

I stopped replying further because I am tired of the conversation (not only in this thread) often being diverted in another direction, or met with a load of technical details without my question being addressed directly first. Sometimes questions are ignored outright, or met with a counter-question.
The discussion then turns into an ever-growing ball of loose ends.

To put it shortly, I don’t ask many questions anymore because I don’t wish to get into unpleasant discussions that are missing the empathy I’m looking for.

Thank you for your feedback.