#ifndef _RSTL_RED_BLACK_TREE #define _RSTL_RED_BLACK_TREE #include "types.h" #include "rstl/pair.hpp" #include "rstl/rmemory_allocator.hpp" namespace rstl { template < typename P > struct select1st { const P& operator()(const P& it) const { return it; } }; template < typename K, typename V > struct select1st< pair< K, V > > { const K& operator()(const pair< K, V >& it) const { return it.first; } }; template < typename T > struct less { bool operator()(const T& a, const T& b) const { return a < b; } }; enum node_color { kNC_Red, kNC_Black, }; void rbtree_rebalance(void*, void*); void* rbtree_traverse_forward(const void*, void*); void* rbtree_rebalance_for_erase(void* header_void, void* node_void); template < typename T, typename P, int U, typename S = select1st< P >, typename Cmp = less< T >, typename Alloc = rmemory_allocator > class red_black_tree { private: struct node { node* mLeft; node* mRight; node* mParent; node_color mColor; uchar mValue[sizeof(P)]; node(node* left, node* right, node* parent, node_color color, const P& value) : mLeft(left), mRight(right), mParent(parent), mColor(color) { construct(get_value(), value); } ~node() { get_value()->~P(); } P* get_value() { return reinterpret_cast< P* >(&mValue); } const P* get_value() const { return reinterpret_cast< const P* >(&mValue); } node* get_left() { return mLeft; } void set_left(node* n) { mLeft = n; } node* get_right() { return mRight; } void set_right(node* n) { mRight = n; } }; class header { public: header() : mLeftmost(nullptr), mRightmost(nullptr), mRootNode(nullptr) {} void set_root(node* n) { mRootNode = n; } void set_leftmost(node* n) { mLeftmost = n; } void set_rightmost(node* n) { mRightmost = n; } node* get_root() const { return mRootNode; } node* get_leftmost() const { return mLeftmost; } node* get_rightmost() const { return mRightmost; } private: node* mLeftmost; node* mRightmost; node* mRootNode; }; public: struct const_iterator { node* mNode; const header* mHeader; // bool x8_; // TODO why is this bool here? const_iterator(node* node, const header* header, bool b) : mNode(node), mHeader(header)/*, x8_(b)*/ {} const P* operator->() const { return mNode->get_value(); } const P* operator*() const { return mNode->get_value(); } bool operator==(const const_iterator& other) const { return mNode == other.mNode && mHeader == other.mHeader; } bool operator!=(const const_iterator& other) const { // return !(*this == other); return mNode != other.mNode || mHeader != other.mHeader; } const_iterator& operator++() { mNode = static_cast< node* >(rbtree_traverse_forward(static_cast< const void* >(mHeader), static_cast< void* >(mNode))); return *this; } const_iterator operator++(int) { const_iterator result = *this; mNode = static_cast< node* >(rbtree_traverse_forward(static_cast< const void* >(mHeader), static_cast< void* >(mNode))); return result; } }; struct iterator : public const_iterator { iterator(node* node, const header* header, bool b) : const_iterator(node, header, b) {} P* operator->() { return mNode->get_value(); } P* operator*() { return mNode->get_value(); } node* get_node() { return mNode; } }; red_black_tree() : x0_(0), x1_(0), x4_count(0) {} ~red_black_tree() { clear(); } iterator insert_into(node* n, const P& item); iterator insert(const P& item) { return insert_into(x8_header.get_root(), item); } const_iterator begin() const { // TODO return const_iterator(x8_header.get_leftmost(), &x8_header, false); } const_iterator end() const { // TODO return const_iterator(nullptr, &x8_header, false); } const_iterator find(const T& key) const { node* n = x8_header.get_root(); node* needle = nullptr; while (n != nullptr) { if (!x2_cmp(x3_selector(*n->get_value()), key)) { needle = n; n = n->get_left(); } else { n = n->get_right(); } } bool noResult = false; if (needle == nullptr || x2_cmp(key, x3_selector(*needle->get_value()))) { noResult = true; } if (noResult) { needle = nullptr; } return const_iterator(needle, &x8_header, false); } iterator find(const T& key) { node* n = x8_header.get_root(); node* needle = nullptr; while (n != nullptr) { if (!x2_cmp(x3_selector(*n->get_value()), key)) { needle = n; n = n->get_left(); } else { n = n->get_right(); } } bool noResult = false; if (needle == nullptr || x2_cmp(key, x3_selector(*needle->get_value()))) { noResult = true; } if (noResult) { needle = nullptr; } return iterator(needle, &x8_header, false); } iterator erase(iterator it) { node* node = it.get_node(); ++it; free_node(rebalance_for_erase(node)); x4_count--; return it; } void clear() { node* root = x8_header.get_root(); if (root != nullptr) { free_node_and_sub_nodes(root); } x8_header.set_root(nullptr); x8_header.set_leftmost(nullptr); x8_header.set_rightmost(nullptr); x4_count = 0; } private: uchar x0_; uchar x1_; Cmp x2_cmp; S x3_selector; int x4_count; header x8_header; node* create_node(node* left, node* right, node* parent, node_color color, const P& value) { node* n; Alloc::allocate(n, 1); new (n) node(left, right, parent, color, value); return n; } void free_node_and_sub_nodes(node* n) { if (node* left = n->get_left()) { free_node_and_sub_nodes(left); } if (node* right = n->get_right()) { free_node_and_sub_nodes(right); } free_node(n); } void free_node(node* n) { n->~node(); Alloc::deallocate(n); } void rebalance(node* n) { rbtree_rebalance(&x8_header, n); } node* rebalance_for_erase(node* n) { return static_cast<node*>(rbtree_rebalance_for_erase(&x8_header, n)); } }; static bool kUnknownValueNewRoot = true; static bool kUnknownValueEqualKey = false; static bool kUnknownValueNewItem = true; template < typename T, typename P, int U, typename S, typename Cmp, typename Alloc > typename red_black_tree< T, P, U, S, Cmp, Alloc >::iterator red_black_tree< T, P, U, S, Cmp, Alloc >::insert_into(node* n, const P& item) { if (n == nullptr) { x8_header.set_root(create_node(nullptr, nullptr, nullptr, kNC_Red, item)); x4_count += 1; x8_header.set_leftmost(x8_header.get_root()); x8_header.set_rightmost(x8_header.get_root()); return iterator(x8_header.get_root(), &x8_header, kUnknownValueNewRoot); } else { node* newNode = nullptr; while (newNode == nullptr) { bool firstComp = x2_cmp(x3_selector(*n->get_value()), x3_selector(item)); if (!firstComp && !x2_cmp(x3_selector(item), x3_selector(*n->get_value()))) { return iterator(n, &x8_header, kUnknownValueEqualKey); } if (firstComp) { if (n->get_left() == nullptr) { newNode = create_node(nullptr, nullptr, n, kNC_Red, item); n->set_left(newNode); if (n == x8_header.get_leftmost()) { x8_header.set_leftmost(newNode); } } else { n = n->get_left(); } } else { if (n->get_right() == nullptr) { newNode = create_node(nullptr, nullptr, n, kNC_Black, item); n->set_right(newNode); if (n == x8_header.get_rightmost()) { x8_header.set_rightmost(newNode); } } else { n = n->get_right(); } } } x4_count += 1; rebalance(newNode); return iterator(newNode, &x8_header, kUnknownValueNewItem); } } }; // namespace rstl #endif // _RSTL_RED_BLACK_TREE