// File:     cs2set.h
// Author:   Chuck Stewart
// Purpose:  A binary-tree based implementation of a set class that is
//   similar to std::set.  The tree used is completely hidden behind the
//   user interface.
//
//   Not all of the functionality is implemented here.  Some will be
//   added in lecture and lab.  What's missing are the increment and
//   decrement iterator operations.  These will be discussed in
//   lecture, but not implemented here.
//
//   The only member variables of the cs2set class are a pointer to
//   the root of the tree and the size of the set.  All member
//   functions start from the root pointer.  Most of the work is done
//   in private member functions, which the public member functions
//   call, passing them the root pointer.  The public member
//   functions therefore effectively define the interface, whereas the
//   private functions do the real implementation.

#ifndef cs2set_h_
#define cs2set_h_

#include <iostream>
#include <utility>


//  Auxiliary class for nodes in the tree.

template <class T>
class TreeNode {
public:
  //  Node with no children.  value is value-initialized (e.g. 0 for an
  //  arithmetic T) so it is never left indeterminate.
  TreeNode() : value(), left(0), right(0) {}
  //  Leaf node holding a copy of init.
  TreeNode(const T& init) : value(init), left(0), right(0) {}
  T value;
  TreeNode* left;   // left child, 0 if none
  TreeNode* right;  // right child, 0 if none
};


//  Forward declaration of the templated cs2set class so that the
//  iterator class can name it as a friend.

template <class T> class cs2set;


///////////////////////////////////////////////////////////////////
///////////////         Iterator definitions          /////////////
///////////////////////////////////////////////////////////////////

//  These are incomplete because the iterator increment and
//  decrement operators are a bit beyond the scope of this course.

template <class T>
class tree_iterator {
private:
  //  Node this iterator refers to; 0 represents the end() position.
  TreeNode<T>* ptr_;
public:
  //  A default-constructed iterator compares equal to end().
  tree_iterator() : ptr_(0) {}
  //  Wrap a node pointer (passing 0 yields an end() iterator).
  tree_iterator( TreeNode<T>* p ) : ptr_(p) {}

  //  Copying an iterator only copies the node pointer, so the
  //  compiler-generated copy constructor, copy assignment and
  //  destructor are exactly right and are not written out here.
  //
  //  NOTE(review): cs2set is forward declared before this class so that
  //  it can be named as a friend, but no friend declaration is made
  //  here yet.

  //  operator* gives constant access to the value at the pointer.
  //  Access is read-only because changing a stored value could break
  //  the tree's ordering.  Dereferencing end() dereferences a null
  //  pointer.
  const T& operator*() const
  {
    return ptr_->value;
  }

  //  operator-> gives constant member access (it->member), with the
  //  same restrictions as operator*.
  const T* operator->() const
  {
    return &ptr_->value;
  }

  //  Comparison operators are straightforward: two iterators are equal
  //  exactly when they refer to the same node (or are both end()).
  friend bool operator== ( const tree_iterator& lft, const tree_iterator& rgt )
  { return lft.ptr_ == rgt.ptr_; }

  friend bool operator!= ( const tree_iterator& lft, const tree_iterator& rgt )
  { return lft.ptr_ != rgt.ptr_; }

};



///////////////////////////////////////////////////////////////////
/////////////////       Main cs2set class          ////////////////
///////////////////////////////////////////////////////////////////


template <class T>
class cs2set {
public:
  typedef tree_iterator<T> iterator;

private:
  TreeNode<T>* root_;
  int size_;

public:
  cs2set() : root_(0), size_(0) 
  {}

  cs2set( const cs2set<T>& old ) : size_(old.size_)
  {
    root_ = this -> copy_tree( old.root_ );
  }

  ~cs2set()
  {
    this -> destroy_tree( root_ );
    root_ = 0;
  }

  cs2set& operator=( const cs2set<T>& old )
  {
    if ( old != *this ) {
      this -> destroy_tree( root_ );
      root_ = this ->  copy_tree( (TreeNode<T>*) 0, old.root_ );
      size_ = old.size_;
    }
    return *this;
  }

  int size() const { return size_; }
  
  std::pair< iterator, bool > insert( T const& key_value )
  {
    return insert( key_value, root_ );
  }

  iterator find( const T& key_value )
  {
    return find( key_value, root_ );
  }

  int erase( T const& key_value )
  {
    return erase( key_value, root_ );
  }

  friend std::ostream& operator<< ( std::ostream& ostr, const cs2set<T>& s )
  {
    s.print_in_order( ostr, s.root_ );
    return ostr;
  }

  void print_as_sideways_tree( std::ostream& ostr ) const
  {
    print_as_sideways_tree( ostr, root_, 0 );
  }

  iterator begin() const
  { 
    if ( ! root_ )
      return iterator(0);
    else
      {
        TreeNode<T>* p = root_;
        while ( p->left ) p = p->left;
        return iterator(p);
      }
  }

  iterator end() const { return iterator( 0 ); }


private:

  TreeNode<T>*  copy_tree( TreeNode<T>* old_root )
  {


  }

  void destroy_tree( TreeNode<T>* p )
  {
    if ( p )
      {
	destroy_tree(p->left);
	destroy_tree(p->right);
	delete p;
	p = 0;
      }
  }

  void print_in_order( std::ostream& ostr, const TreeNode<T>* p ) const
  {
    if ( p )
      {
	print_in_order( ostr, p->left );
	ostr << p->value << "\n";
	print_in_order( ostr, p->right );
      }
  }

  std::pair<iterator,bool> insert( const T& key_value, TreeNode<T>*& p )
  {
    if ( !p )
      {
	p = new TreeNode<T>( key_value );
	this -> size_ ++ ;
	return std::pair<iterator,bool>( iterator(p), true );
      }
    else if ( key_value < p->value )
      return insert( key_value, p->left );
    else if ( key_value > p->value )
      return insert( key_value, p->right );
    else
      return std::pair<iterator,bool>( iterator(p), false );
  }

  iterator find( const T& key_value, TreeNode<T>* p )
  {
    if ( !p )
      return this -> end();
    else if ( key_value < p->value )
      return find( key_value, p->left );
    else if ( key_value > p->value )
      return find( key_value, p->right );
    else
      return iterator( p );
  }

  int erase( T const& key_value, TreeNode<T>* & root )
  {
    if ( !root ) return 0;
    
    if ( key_value < root->value )
      return erase( key_value, root->left );
    else if ( key_value > root->value )
      return erase( key_value, root->right );

    if ( !root->left )
      {
	TreeNode<T> * temp = root;
	root = root->right;
	delete temp;
	return 1;
      }
    else if ( !root->right )
      {
	TreeNode<T> * temp = root;
	root = root->left;
	delete temp;
	return 1;
      }

    //  Two children  in-order successor
    TreeNode<T>* leftmost = root->right;
    while ( leftmost->left ) leftmost = leftmost->left;
    root->value = leftmost->value;
    return erase( root->value, root->right );
  }


  void print_as_sideways_tree( std::ostream& ostr, const TreeNode<T>* p,
                               int depth  ) const
  {
    if ( p ) 
      {
        print_as_sideways_tree( ostr, p->right, depth+1 );
        for ( int i=0; i<depth; ++i ) ostr << "    ";
        ostr << p->value << "\n";
        print_as_sideways_tree( ostr, p->left, depth+1 );
      }
  }

};



#endif

