231-10 Help with BST tree

Hi, I need help figuring out an issue with the code on BST tree (https://app.dataquest.io/m/231/working-with-binary-search-trees/10/range-querying-a-csv).

I copy-pasted the solution into a Jupyter notebook and kept getting error messages. I got an error on “amount_rows = list((r[0], float(r[1])) for r in reader)”, so I converted r[0] to an integer, which got rid of the error. Then I got another error in the rotation code saying that a NoneType object has no key, so I added two if statements to make sure the nodes are not empty before the following comparisons are executed:

        if difference > 1:
            if self.node.right.node:  # added code to avoid errors in Jupyter notebook
                if key > self.node.right.node.key: # Left-right case.
                    self.node.left.left_rotate()
                self.right_rotate()  # Left-left case.

        if difference < -1:
            if self.node.left.node:  # added code
                if key <= self.node.left.node.key: # Right-left case.
                    self.node.left.right_rotate()
                self.left_rotate()

After I did that, the error messages are gone, but the insert function is still running after a long time. Note: to make it clean and easy to debug, I have moved all the functions into one cell under the Node and BST classes (see below). It would be great if someone could tell me why the code doesn’t finish running. It ran much faster on the DQ platform. Thanks in advance.

Here is the whole thing,

class Node:
    def __init__(self, key=None, value=None):
        self.left = None
        self.right = None
        self.key = key
        self.value = value

    def __str__(self):
        return "<Node: {}>".format(self.value)
    
class BST():
    def __init__(self, index=None):
        self.node = None
        self.index = index
    
    def insert_multiple(self, values):
        for value in values:
            self.insert(value)
    
    def depth(self, node):
        if not node:
            return 0
        if not node.left and not node.right:
            return 1
        
        return max(self.depth(node.left.node), self.depth(node.right.node)) + 1
    
    def insert(self, value = None):  # value = list
        key = value
        if self.index:
            key = value[self.index]
        node = Node(key=key, value=value)
        
        if not self.node:
            self.node = node
            self.node.left = BST(index = self.index)
            self.node.right = BST(index = self.index)
            return
        if key > self.node.key:
            if self.node.right:
                self.node.right.insert(value = value)
            else:
                self.node.right.node = node
        else:
            if self.node.left:
                self.node.left.insert(value=value)
            else:
                self.node.left.node = node
                
        difference = self.depth(self.node.left.node) - self.depth(self.node.right.node)
        
        if difference > 1: # Left side case.
            if self.node.right.node:
                if key > self.node.right.node.key: # Left-right case.
                    self.node.left.left_rotate()
                self.right_rotate()  # Left-left case.
            
        if difference < -1: # Right side case.
            if self.node.left.node:
                if key <= self.node.left.node.key: # Right-left case.
                    self.node.left.right_rotate()
                self.left_rotate()
    
    def search(self, key):
        if not self.node:
            return False
        if key == self.node.key:
            return True
        
        result = False
        if self.node.left:
            result = self.node.left.search(key)
        if self.node.right:
            result = self.node.right.search(key)
        return result
    
    def inorder(self, tree):
        if not tree or not tree.node:
            return []
        return (
            self.inorder(tree.node.left) +
            [tree.node.value] +
            self.inorder(tree.node.right)
        )

    def left_rotate(self):
        old_node = self.node
        new_node = self.node.right.node
        if not new_node:
            return
        
        new_right_sub = new_node.left.node
        self.node = new_node
        old_node.right.node = new_right_sub
        new_node.left.node = old_node
    
    def right_rotate(self):
        old_node = self.node
        new_node = self.node.left.node
        if not new_node:
            return
    
    def greater_than(self, key):
        if not self.node:
            return []
        
        values = []
        if self.node.left:
            values += self.node.left.greater_than(key)
        if self.node.right:
            values += self.node.right.greater_than(key)
        if self.node.key > key:
            values.append(self.node.value)
        return values
    
import csv
with open('amounts.csv', 'r') as f:
    reader = csv.reader(f)
    next(reader)
    # query on the second column, the amounts field to find rows
    amount_rows = list((int(float(r[0])), float(r[1])) for r in reader)

bst = BST(index=1)
bst.insert_multiple(amount_rows)
print('finished insert')
csv_query = bst.greater_than(10)
print(csv_query)

Just to make it clear: without any modification, running the original code in Jupyter produced the following error message. I am having a hard time figuring out how to fix this. Any help is appreciated.

<ipython-input-16-680849201e3c> in insert(self, value)
     58         if difference < -1:
     59             # Right-left case.
---> 60             if value <= self.node.left.node.value:
     61                 self.node.left.right_rotate()
     62             self.left_rotate()

AttributeError: 'NoneType' object has no attribute 'value'

Hi @xuehong.liu.pdx,

The right_rotate() method in your BST class is incomplete. It is missing the following lines:

new_left_sub = new_node.right.node
self.node = new_node
old_node.left.node = new_left_sub
new_node.right.node = old_node

After adding them to your code I obtain the same performance as I did using the DQ implementation. It is normal that without them you cannot achieve fast performance because the tree is not being balanced.

Here is the full DQ implementation for reference:

class Node:
    def __init__(self, key=None, value=None):
        self.left = None
        self.right = None
        self.value = value
        self.key = key
    
    def __str__(self):
        return "<Node: {}>".format(self.value)

class BST:
    def __init__(self, index=None):
        self.node = None
        self.index = index
    
    def insert_multiple(self, values):
        for value in values:
            self.insert(value)
    
    def insert(self, value=None):
        key = value
        if self.index:
            key = value[self.index]
        node = Node(key=key, value=value)
        
        if not self.node:
            self.node = node
            self.node.left = BST(index=self.index)
            self.node.right = BST(index=self.index)
            return
        
        if key > self.node.key:
            if self.node.right:
                self.node.right.insert(value=value)
            else:
                self.node.right.node = node
        else:
            if self.node.left:
                self.node.left.insert(value=value)
            else:
                self.node.left.node = node
        
        difference = self.depth(self.node.left.node) - self.depth(self.node.right.node)
        
        if difference > 1:
            if self.node.right.node and key > self.node.right.node.key:
                self.node.left.left_rotate()
            self.right_rotate()
            
        if difference < -1:
            if self.node.left.node and key <= self.node.left.node.key:
                self.node.left.right_rotate()
            self.left_rotate()
    
    def inorder(self, tree):
        if not tree or not tree.node:
            return []
        return (
            self.inorder(tree.node.left) +
            [tree.node.value] +
            self.inorder(tree.node.right)
        )
    
    def search(self, key):
        if not self.node:
            return False
        if key == self.node.key:
            return True
        
        result = False
        if self.node.left:
            result = self.node.left.search(key)
        if self.node.right:
            result = self.node.right.search(key)
        return result

    def depth(self, node):
        if not node:
            return 0
        if not node.left and not node.right:
            return 1
        
        return max(self.depth(node.left.node), self.depth(node.right.node)) + 1
    
    def left_rotate(self):
        old_node = self.node
        new_node = self.node.right.node
        if not new_node:
            return
        
        new_right_sub = new_node.left.node
        self.node = new_node
        old_node.right.node = new_right_sub
        new_node.left.node = old_node
    
    def right_rotate(self):
        old_node = self.node
        new_node = self.node.left.node
        if not new_node:
            return
        
        new_left_sub = new_node.right.node
        self.node = new_node
        old_node.left.node = new_left_sub
        new_node.right.node = old_node

    def is_balanced(self):
        if not self.node:
            return True
        
        left_subtree = self.depth(self.node.left.node)
        right_subtree = self.depth(self.node.right.node)
        
        return abs(left_subtree - right_subtree) < 2
    
    def greater_than(self, key):
        if not self.node:
            return []
        values = []
        if self.node.left:
            values += self.node.left.greater_than(key)
        if self.node.right:
            values += self.node.right.greater_than(key)
        if self.node.key > key:
            values.append(self.node.value)
        return values

Regarding the other error, since I only have the version of the code that you fixed I don’t know what the problem was.

I hope it helps :slight_smile: Let me know if you need anything else.


Hi @Francois,

Thank you so much for responding. I copy-pasted your code into a new cell and ran it. It has been 20 min and the insertion part is still not over. It seems to be running, since the difference variable is changing, but it is taking far too long.

I decided to reduce the amount of data and check the relationship between the number of records inserted and the run time. It turns out to be exponential (see graph below). That might explain why it could not finish in a reasonable time on my computer when I included the whole file. Does this make sense to you? The funny thing is that in the past, when I ran the same code in a Jupyter notebook, it usually ran faster than on the DQ platform. This one is an exception for me.

Could you show how you would figure out the big O notation for insertion? It took a lot longer than searching. Could you also explain why it ran so much faster on the DQ platform? Sorry about all the questions but it will really help.

Best,
Xuehong

[Graph: run time versus number of records inserted]

@xuehong.liu.pdx,

This is very strange. I just ran the code and it takes 4 seconds to insert all the rows. This is the code I ran to insert them:


import csv
with open('amounts.csv', 'r') as f:
    reader = csv.reader(f)
    next(reader)
    amount_rows = list((r[0], float(r[1])) for r in reader)
    
bst = BST(index=1)
import time
start = time.time()
bst.insert_multiple(amount_rows)
end = time.time()
print('finished inserting, time = {} seconds'.format(end - start))
finished inserting, time = 4.002045154571533 seconds

What did you run that took more than 20 minutes?

I am not sure the plot you showed is exponential growth. I would guess it is quadratic. This happens if the i-th insertion takes O(i) time, making the total runtime for N insertions:

1 + 2 + \ldots + N = \frac{N(N+1)}{2} = O(N^2)

This behavior is what we would get when the tree is failing to self-balance because at each new insertion it will visit all previously inserted values.

In general, the time complexity of an insert is the height of the tree. To insert a value we need to make a search to find out where to put it and then go up that path to balance the tree. We do at most one balance operation per node in the path from the root of the tree to the new node and each rotation is O(1) (or should be if the implementation is correct). Therefore the total amount of work is O(H) where H is the height of the tree.

If the balancing is correctly implemented then H = O(\log N) and so, N insertions will take O(N \log N). If the balancing is not working then H = O(N) and so, N insertions will take O(N^2) as mentioned above.
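
To make the two cases concrete, here is a minimal sketch (not the DQ implementation) that counts key comparisons when sorted keys are inserted into a plain BST that never rebalances. The function name `unbalanced_insert_cost` and the list-based node layout are assumptions made for illustration only. With sorted input the tree degenerates into a chain, so each insertion walks past every earlier key:

```python
import math


def unbalanced_insert_cost(keys):
    """Insert keys into a plain, non-balancing BST and count key comparisons.

    Hypothetical helper: each node is a list [key, left, right].
    """
    root = None
    comparisons = 0
    for key in keys:
        if root is None:
            root = [key, None, None]
            continue
        current = root
        while True:
            comparisons += 1
            side = 2 if key > current[0] else 1  # 2 = right child, 1 = left child
            if current[side] is None:
                current[side] = [key, None, None]
                break
            current = current[side]
    return comparisons


n = 2000
sorted_cost = unbalanced_insert_cost(range(n))   # worst case: sorted input
print("comparisons without balancing:", sorted_cost)
print("N(N-1)/2:", n * (n - 1) // 2)
print("N*log2(N), rounded:", round(n * math.log2(n)))
```

For sorted input the count equals N(N-1)/2, which is the quadratic growth described above. A tree that stays balanced would need on the order of N log2(N) comparisons instead.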

Hi @Francois,

This is indeed very strange. Is it possible that my computer is too old? I am using a Dell Inspiron 15 with an Intel Core i7 chip. Still, it took over 5 min on my end to run 5000 records. I didn’t run anything before this. I just restarted the kernel and ran it again, and got the same result. Here is the code I used to generate and graph the data. I am curious to see how it will look when you run it on your computer.

Thank you so much!

import time
start = time.time()
import csv
with open('amounts.csv', 'r') as f:
    reader = csv.reader(f)
    next(reader)
    amount_rows = list((int(float(r[0])), float(r[1])) for r in reader)
    # note: I have to use int(float(r[0])) in place of r[0] otherwise I got an int error

durations_insert=[]
durations_search=[]

bst = BST(index=1)
for i in range(1000, 5000, 500):
    data = amount_rows[0:i]
    bst.insert_multiple(data)
    ins_time = time.time()
    durations_insert.append(ins_time - start)
    csv_query = bst.greater_than(10)
    durations_search.append(time.time() - ins_time)

import matplotlib.pyplot as plt
%matplotlib inline
plt.plot(range(1000, 5000, 500), durations_insert, color='blue')
plt.plot(range(1000, 5000, 500), durations_search, color='red')
plt.xlabel('number of records included')
plt.ylabel('duration for each run')
plt.legend(['insertion', 'searching'])

@xuehong.liu.pdx, I think I solved the mystery!

First of all, no, it is not a problem with your computer. With such a BST, the time needed to add N elements should be almost the same as adding N elements to a list. Here the experiments are with N \leq 5000, which is nothing.

I thought the problem might be the depth as I mentioned before. To investigate that, I augmented the code so that it would keep track of the depth after each insertion. The depth was behaving as expected meaning that the tree is indeed self-balancing correctly.

So I decided to profile the code. I used this profiler, which gives a line-by-line profile. These were the results:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    42                                               @profile
    43                                               def insert(self, value= None):
    44     61809      32172.0      0.5      0.0          key = value
    45     61809      35052.0      0.6      0.0          if self.index:
    46                                                       key = value[self.index]
    47     61809     142795.0      2.3      0.1          node = Node(key=key, value=value)
    48     61809      35993.0      0.6      0.0          if not self.node:
    49      5000       3224.0      0.6      0.0              self.node = node
    50      5000      15031.0      3.0      0.0              self.node.left = BST(index=self.index)
    51      5000      22610.0      4.5      0.0              self.node.right = BST(index=self.index)
    52      5000       2710.0      0.5      0.0              return
    53     56809      41559.0      0.7      0.0          if key > self.node.key:
    54     56809      34086.0      0.6      0.0              if self.node.right:
    55     56809      86026.0      1.5      0.0                  self.node.right.insert(value=value)
    56                                                       else:
    57                                                           self.node.right.node = node
    58                                                   else:
    59                                                       if self.node.left:
    60                                                           self.node.left.insert(value=value)
    61                                                       else:
    62                                                           self.node.left.node = node
    63                                                  
    64     56809  178237923.0   3137.5     99.6          difference = self.depth(self.node.left.node) - self.depth(self.node.right.node)
    65                                                   
    66     56809      74097.0      1.3      0.0          start = time.time()
    67     56809      39913.0      0.7      0.0          if difference > 1: 
    68                                                       if self.node.right.node and key > self.node.right.node.key:
    69                                                           self.node.left.left_rotate()
    70                                                       self.right_rotate() 
    71     56809      34305.0      0.6      0.0          if difference < -1: 
    72      4987       4725.0      0.9      0.0              if self.node.left.node and key <= self.node.left.node.key:
    73                                                           self.node.left.right_rotate()
    74      4987      67538.0     13.5      0.0              self.left_rotate()
    75     56809      37833.0      0.7      0.0          end = time.time()

As you can see, 99.6% of the time is spent computing the depth difference:

    64     56809  178237923.0   3137.5     99.6          difference = self.depth(self.node.left.node) - self.depth(self.node.right.node)

Looking at the code of depth() this is not surprising! The code to compute the depth is O(N) as it needs to inspect the whole tree! Usually, implementations of BSTs will keep the depth information as a node attribute and correct it in O(1) (constant time) when performing rotations.
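
As an illustration of that idea, here is a minimal sketch of an AVL-style insert where each node caches the height of its subtree. It is not the DQ implementation: the names `AVLNode`, `height`, `update_height`, `rotate_left`, `rotate_right` and `insert` are hypothetical, and it uses plain nodes rather than the nested BST objects from the lesson. It also picks the rotation case by comparing cached heights rather than keys.

```python
class AVLNode:
    """Node that caches the height of its own subtree (hypothetical name)."""

    def __init__(self, key, value=None):
        self.key = key
        self.value = value
        self.left = None
        self.right = None
        self.height = 1  # a single node has height 1


def height(node):
    # O(1): read the cached value instead of walking the subtree.
    return node.height if node else 0


def update_height(node):
    node.height = 1 + max(height(node.left), height(node.right))


def rotate_right(old_root):
    new_root = old_root.left
    old_root.left = new_root.right
    new_root.right = old_root
    update_height(old_root)   # old_root is now the lower node, so fix it first
    update_height(new_root)
    return new_root


def rotate_left(old_root):
    new_root = old_root.right
    old_root.right = new_root.left
    new_root.left = old_root
    update_height(old_root)
    update_height(new_root)
    return new_root


def insert(node, key, value=None):
    """Insert key below node and return the (possibly new) subtree root."""
    if node is None:
        return AVLNode(key, value)
    if key > node.key:
        node.right = insert(node.right, key, value)
    else:
        node.left = insert(node.left, key, value)

    update_height(node)
    balance = height(node.left) - height(node.right)

    if balance > 1:  # left side too deep
        if height(node.left.right) > height(node.left.left):
            node.left = rotate_left(node.left)    # left-right case
        return rotate_right(node)
    if balance < -1:  # right side too deep
        if height(node.right.left) > height(node.right.right):
            node.right = rotate_right(node.right)  # right-left case
        return rotate_left(node)
    return node


root = None
for k in range(1000):  # sorted input, the worst case for a plain BST
    root = insert(root, k)
print("height after 1000 sorted inserts:", root.height)
```

Because every height lookup is a single attribute read, the work per insertion is proportional to the length of the path from the root, not to the size of the tree.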

@Francois,

That’s good to know. Thanks. But I am still not clear why it ran so much faster on your machine than mine if it is not my computer.

When I ran the last code that you posted, I got times that were very similar to yours. Maybe we made a slightly different experiment the first time around.

Interesting, @Francois. I used the code you shared to define the Node and BST classes. The ‘amounts.csv’ file I have has 284807 rows, but I only used a maximum of 5000 for testing, which took over 5 min.
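
One possible reason the two experiments differed: the timing loop posted earlier in the thread creates `bst` once, before the loop, and sets `start` once, before the CSV is read. Each pass therefore inserts into a tree that already holds the earlier slices (22,000 rows in total by the last pass), and each recorded duration is cumulative. A minimal sketch of a per-size measurement follows, assuming a fresh tree for every size. The helper `time_fresh_inserts` and the stand-in class `ListTree` are hypothetical and only there so the block runs without the CSV file.

```python
import time


def time_fresh_inserts(make_tree, rows, sizes):
    """Time insert_multiple on a brand-new tree for each size (hypothetical helper)."""
    durations = []
    for size in sizes:
        tree = make_tree()               # fresh, empty tree for every size
        start = time.perf_counter()      # start the clock just before inserting
        tree.insert_multiple(rows[:size])
        durations.append(time.perf_counter() - start)
    return durations


class ListTree:
    """Hypothetical stand-in so the sketch runs without the BST class or the CSV."""

    def __init__(self):
        self.items = []

    def insert_multiple(self, values):
        self.items.extend(values)


rows = [(i, float(i)) for i in range(5000)]   # placeholder data, not amounts.csv
sizes = list(range(1000, 5000, 500))
durations = time_fresh_inserts(ListTree, rows, sizes)
for size, seconds in zip(sizes, durations):
    print(size, seconds)
```

With the classes from this thread, the same helper could be called as `time_fresh_inserts(lambda: BST(index=1), amount_rows, sizes)`.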