The BSTree data structure

Agenda

  • API
  • Implementation
    • Addition
    • Search
    • Removal
    • Iteration / Traversal

API

In [1]:
class BSTree:
    """Binary search tree of unique, mutually comparable values.

    This class fixes the public API. `add`, `__contains__`, `__delitem__`
    and `__iter__` are placeholders here; only `__len__`, `pprint` and
    `height` are functional.
    """

    class Node:
        """A single tree node: a value plus optional left/right children."""

        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

    def __init__(self):
        # `size` is maintained by add/__delitem__ so that len() is O(1).
        self.size = 0
        self.root = None

    def add(self, val):
        """Adds `val` to this tree while maintaining BSTree properties."""
        assert(val not in self)
        pass

    def __contains__(self, val):
        """Returns `True` if val is in this tree and `False` otherwise."""
        pass

    def __len__(self):
        """Returns the number of values recorded in `self.size`."""
        return self.size

    def __delitem__(self, val):
        """Removes `val` from this tree while maintaining BSTree properties."""
        assert(val in self)
        pass

    def __iter__(self):
        """Returns an iterator over all the values in the tree, in ascending order."""
        pass

    def pprint(self, width=64):
        """Attempts to pretty-print this tree's contents.

        The tree is printed one level per line. Each slot on level `k` is
        centered in a field of `width // 2**k` characters, and missing nodes
        are drawn as '-' so that every level above the deepest one is padded
        out to a full row.
        """
        # Local import: the file has no shared import block to extend.
        from collections import deque

        height = self.height()
        # Breadth-first queue of (node-or-None, level); deque gives O(1) pops
        # from the front.
        pending = deque([(self.root, 0)])
        prev_level = 0
        pieces = []
        while pending:
            n, level = pending.popleft()
            if prev_level != level:
                # First slot of a new level: start a new output line.
                prev_level = level
                pieces.append('\n')
            if not n:
                # Placeholder slot: keep padding with two placeholder
                # children until the deepest level has been reached.
                if level < height-1:
                    pending.extend([(None, level+1), (None, level+1)])
                label = '-'
            else:
                if n.left or level < height-1:
                    pending.append((n.left, level+1))
                if n.right or level < height-1:
                    pending.append((n.right, level+1))
                label = n.val
            pieces.append('{val:^{width}}'.format(val=label, width=width//2**level))
        print(''.join(pieces))

    def height(self):
        """Returns the height of the longest branch of the tree.

        An empty tree has height 0; a single node has height 1.
        """
        def height_rec(t):
            if not t:
                return 0
            return 1 + max(height_rec(t.left), height_rec(t.right))
        return height_rec(self.root)

Implementation

Addition

In [2]:
class BSTree(BSTree):
    def add(self, val):
        """Insert `val` as a new leaf, mutating the existing tree in place.

        Starting at the root, the search moves left when `val` is smaller
        than the current node's value and right otherwise, and attaches a
        new node at the first empty spot it reaches.
        """
        assert(val not in self)
        leaf = BSTree.Node(val)
        if not self.root:
            # Empty tree: the new node becomes the root.
            self.root = leaf
        else:
            cur = self.root
            while True:
                if val < cur.val:
                    if not cur.left:
                        cur.left = leaf
                        break
                    cur = cur.left
                else:
                    if not cur.right:
                        cur.right = leaf
                        break
                    cur = cur.right
        self.size += 1
In [3]:
# Demo: insert the values 0..4 in a random order and draw the resulting tree.
# The shape depends on the shuffle, so the picture differs from run to run.
import random
t = BSTree()
vals = list(range(5))
random.shuffle(vals)
for x in vals:
    t.add(x)
t.pprint()
                               1                                
               0                               4                
       -               -               2               -        
   -       -       -       -       -       3       -       -    
In [4]:
# Alternatively: build fresh nodes along the insertion path instead of
# modifying the existing ones.

class BSTree(BSTree):
    def add(self, val):
        """Insert `val` by rebuilding the path from the root to the new leaf.

        Every node visited on the way down is replaced by a new `Node` with
        the same value; subtrees off the path are reused unchanged.
        """
        assert(val not in self)

        def inserted(subtree):
            # Empty spot reached: this is where the new leaf belongs.
            if not subtree:
                return BSTree.Node(val)
            if val < subtree.val:
                left, right = inserted(subtree.left), subtree.right
            else:
                left, right = subtree.left, inserted(subtree.right)
            return BSTree.Node(subtree.val, left, right)

        self.root = inserted(self.root)
        self.size += 1
In [5]:
# Demo: same experiment as before (values 0..4 in random order), now run
# against the most recently defined `add`.
import random
t = BSTree()
vals = list(range(5))
random.shuffle(vals)
for x in vals:
    t.add(x)
t.pprint()
                               2                                
               1                               3                
       0               -               -               4        

Search

In [6]:
class BSTree(BSTree):
    def __contains__(self, val):
        """Return True if `val` is stored in the tree, False otherwise.

        Walks down from the root, going left for smaller values and right
        for larger ones, until the value is found or the path runs out.
        """
        cur = self.root
        while cur:
            if val < cur.val:
                cur = cur.left
            elif val > cur.val:
                cur = cur.right
            else:
                # Neither smaller nor larger: this node holds `val`.
                return True
        return False
In [7]:
# Demo: fill a tree with the odd numbers 1, 3, 5, 7, 9 in random order.
import random
t = BSTree()
vals = list(range(1, 10, 2))
random.shuffle(vals)
for x in vals:
    t.add(x)

# Every inserted (odd) value must be found ...
assert(all(x in t for x in range(1, 10, 2)))
# ... and none of the even values 0, 2, ..., 10 may be reported as present.
assert(all(x not in t for x in range(0, 12, 2)))

Removal

In [8]:
class BSTree(BSTree):
    def __delitem__(self, val):
        """Remove `val` from the tree -- simple cases only.

        Handles a node with no children (it is dropped) and a node with
        exactly one child (the child takes its place).

        Raises:
            NotImplementedError: if the node holding `val` has two children.
                The tree and `size` are left untouched in that case.
        """
        assert(val in self)
        def delitem_rec(node):
            if val < node.val:
                node.left = delitem_rec(node.left)
                return node
            elif val > node.val:
                node.right = delitem_rec(node.right)
                return node
            else:
                # only deal with simple cases first
                if node.left and node.right:
                    # Falling through here would return None and silently
                    # discard this node's entire subtree, so refuse instead.
                    raise NotImplementedError(
                        'removing a node with two children is not supported yet')
                # Zero or one child: whichever child exists (or None)
                # replaces the removed node.
                return node.left or node.right

        self.root = delitem_rec(self.root)
        self.size -= 1
In [9]:
# Demo: build a small fixed tree to try removals on, and draw it.
t = BSTree()
for x in [10, 5, 15, 2, 17]:
    t.add(x)
t.pprint()
                               10                               
               5                               15               
       2               -               -               17       
In [10]:
# Remove 2, which has no children in the tree built in the previous cell.
del t[2]
t.pprint()
                               10                               
               5                               15               
       -               -               -               17       
In [11]:
# Rebuild the same tree, then remove 5, whose only child is its left child 2.
t = BSTree()
for x in [10, 5, 15, 2, 17]:
    t.add(x)
del t[5]
t.pprint()
                               10                               
               2                               15               
       -               -               -               17       
In [12]:
# Rebuild the same tree, then remove 15, whose only child is its right child 17.
t = BSTree()
for x in [10, 5, 15, 2, 17]:
    t.add(x)
del t[15]
t.pprint()
                               10                               
               5                               17               
       2               -               -               -        
In [13]:
class BSTree(BSTree):
    def __delitem__(self, val):
        """Remove `val` from the tree, covering every node shape.

        A node with at most one child is replaced by that child (or by
        nothing). A node with two children keeps its position but takes over
        the largest value of its left subtree, and the node that used to
        hold that value is unlinked instead.
        """
        assert(val in self)

        def without_val(node):
            if val < node.val:
                node.left = without_val(node.left)
            elif val > node.val:
                node.right = without_val(node.right)
            elif not (node.left and node.right):
                # Zero or one child: that child (or None) takes node's place.
                return node.left or node.right
            else:
                # Two children: find the rightmost node of the left subtree
                # (`pred`) together with its parent.
                parent, pred = node, node.left
                while pred.right:
                    parent, pred = pred, pred.right
                node.val = pred.val
                # Unlink `pred`; its left subtree moves up into its slot.
                if parent is node:
                    node.left = pred.left
                else:
                    parent.right = pred.left
            return node

        self.root = without_val(self.root)
        self.size -= 1
In [14]:
# Demo: build a larger fixed tree that contains nodes with two children.
t = BSTree()
for x in [10, 5, 2, 7, 9, 8, 1, 15, 12, 18]:
    t.add(x)
t.pprint()
                               10                               
               5                               15               
       2               7               12              18       
   1       -       -       9       -       -       -       -    
 -   -   -   -   -   -   8   -   -   -   -   -   -   -   -   -  
In [15]:
# Remove 15 (children 12 and 18) from the tree built in the previous cell:
# its left child 12 has no right child, so 12 takes its place.
del t[15]
t.pprint()
                               10                               
               5                               12               
       2               7               -               18       
   1       -       -       9       -       -       -       -    
 -   -   -   -   -   -   8   -   -   -   -   -   -   -   -   -  
In [16]:
# Rebuild the tree, then remove 5 (children 2 and 7): its left child 2 has no
# right child, so 2 replaces 5 and 2's own left child 1 moves up.
t = BSTree()
for x in [10, 5, 2, 7, 9, 8, 1, 15, 12, 18]:
    t.add(x)
del t[5]
t.pprint()
                               10                               
               2                               15               
       1               7               12              18       
   -       -       -       9       -       -       -       -    
 -   -   -   -   -   -   8   -   -   -   -   -   -   -   -   -  
In [17]:
# Rebuild the tree, then remove the root 10: the largest value of the left
# subtree is 9 (reached via 5 -> 7 -> 9), so 9 becomes the root value and
# 9's left child 8 takes 9's old position.
t = BSTree()
for x in [10, 5, 2, 7, 9, 8, 1, 15, 12, 18]:
    t.add(x)
del t[10]
t.pprint()
                               9                                
               5                               15               
       2               7               12              18       
   1       -       -       8       -       -       -       -    

Iteration / Traversal

In [28]:
class BSTree(BSTree):
    def __iter__(self):
        """Yield the stored values in ascending order (in-order traversal).

        Uses an explicit stack of nodes whose left subtrees are being
        visited, rather than recursive generators.
        """
        waiting = []
        cur = self.root
        while waiting or cur:
            # Go as far left as possible, remembering each node on the way.
            while cur:
                waiting.append(cur)
                cur = cur.left
            cur = waiting.pop()
            yield cur.val
            # Everything in the right subtree comes after this value.
            cur = cur.right
In [29]:
# Demo: insert 0..19 in random order; iterating over the tree should still
# print the values in ascending order.
import random
t = BSTree()
vals = list(range(20))
random.shuffle(vals)
for x in vals:
    t.add(x)
for x in t:
    print(x)
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
In [30]:
class BSTree(BSTree):
    def __iter__(self):
        """Yield the stored values in ascending order, delegating with `yield from`."""
        def inorder(subtree):
            # An empty subtree contributes no values.
            if not subtree:
                return
            yield from inorder(subtree.left)
            yield subtree.val
            yield from inorder(subtree.right)

        yield from inorder(self.root)
In [31]:
# Demo: repeat the iteration experiment (0..19 inserted in random order)
# against the most recently defined `__iter__`.
import random
t = BSTree()
vals = list(range(20))
random.shuffle(vals)
for x in vals:
    t.add(x)
for x in t:
    print(x)
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19