// -*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*- // vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi #ifndef CTL_SET_H_ #define CTL_SET_H_ #include "initializer_list.h" #include "less.h" #include "pair.h" namespace ctl { template> class set { struct rbtree { uintptr_t left_; rbtree* right; rbtree* parent; Key value; rbtree* left() const { return (rbtree*)(left_ & -2); } void left(rbtree* val) { left_ = (uintptr_t)val | (left_ & 1); } bool is_red() const { return left_ & 1; } void is_red(bool val) { left_ &= -2; left_ |= val; } rbtree(const Key& val) : left_(1), right(nullptr), parent(nullptr), value(val) { } }; public: using key_type = Key; using value_type = Key; using size_type = size_t; using node_type = rbtree; using key_compare = Compare; using value_compare = Compare; using difference_type = ptrdiff_t; using reference = value_type&; using pointer = value_type*; using const_pointer = const value_type*; using const_reference = const value_type&; class iterator { public: using value_type = Key; using difference_type = ptrdiff_t; using pointer = Key*; using reference = Key&; iterator() : node_(nullptr) { } reference operator*() const { return node_->value; } pointer operator->() const { return &(node_->value); } iterator& operator++() { if (node_ == nullptr) return *this; if (node_->right) { node_ = leftmost(node_->right); } else { node_type* parent = node_->parent; while (parent && node_ == parent->right) { node_ = parent; parent = parent->parent; } node_ = parent; } return *this; } iterator& operator--() { if (node_ == nullptr) __builtin_trap(); if (node_->left()) { node_ = rightmost(node_->left()); } else { node_type* parent = node_->parent; for (;;) { if (parent == nullptr) break; if (node_ == parent->left()) { node_ = parent; parent = parent->parent; } else { break; } } node_ = parent; } return *this; } iterator operator++(int) { iterator tmp = *this; ++(*this); return tmp; } iterator operator--(int) { iterator tmp = *this; --(*this); return tmp; } bool operator==(const iterator& other) const { return node_ == other.node_; } bool operator!=(const iterator& other) const { return !(*this == other); } private: friend class set; node_type* node_; explicit iterator(node_type* node) : node_(node) { } }; class reverse_iterator { public: using value_type = Key; using difference_type = ptrdiff_t; using pointer = Key*; using reference = Key&; reverse_iterator() : node_(nullptr) { } reference operator*() const { return node_->value; } pointer operator->() const { return &(node_->value); } reverse_iterator& operator++() { if (node_->left()) { node_ = rightmost(node_->left()); } else { node_type* parent = node_->parent; while (parent && node_ == parent->left()) { node_ = parent; parent = parent->parent; } node_ = parent; } return *this; } reverse_iterator operator++(int) { reverse_iterator tmp = *this; ++(*this); return tmp; } reverse_iterator& operator--() { if (node_->right) { node_ = leftmost(node_->right); } else { node_type* parent = node_->parent; while (parent && node_ == parent->right) { node_ = parent; parent = parent->parent; } node_ = parent; } return *this; } reverse_iterator operator--(int) { reverse_iterator tmp = *this; --(*this); return tmp; } bool operator==(const reverse_iterator& other) const { return node_ == other.node_; } bool operator!=(const reverse_iterator& other) const { return !(*this == other); } iterator base() const { if (node_ == nullptr) return iterator(leftmost(root_)); reverse_iterator tmp = *this; ++tmp; return iterator(tmp.node_); } private: friend class set; node_type* node_; node_type* root_; explicit reverse_iterator(node_type* node, node_type* root) : node_(node), root_(root) { } }; using const_iterator = iterator; using const_reverse_iterator = reverse_iterator; set() : root_(nullptr), size_(0), comp_(Compare()) { } explicit set(const Compare& comp) : root_(nullptr), size_(0), comp_(comp) { } template set(InputIt first, InputIt last, const Compare& comp = Compare()) : root_(nullptr), size_(0), comp_(comp) { for (; first != last; ++first) insert(*first); } set(const set& other) : root_(nullptr), size_(0) { if (other.root_) { root_ = copier(other.root_); size_ = other.size_; } } set(set&& other) noexcept : root_(other.root_), size_(other.size_) { other.root_ = nullptr; other.size_ = 0; } set(std::initializer_list init, const Compare& comp = Compare()) : root_(nullptr), size_(0), comp_(comp) { for (const auto& value : init) insert(value); } ~set() { clear(); } set& operator=(const set& other) { if (this != &other) { clear(); if (other.root_) { root_ = copier(other.root_); size_ = other.size_; } } return *this; } set& operator=(set&& other) noexcept { if (this != &other) { clear(); root_ = other.root_; size_ = other.size_; other.root_ = nullptr; other.size_ = 0; } return *this; } bool empty() const noexcept { return size_ == 0; } size_type size() const noexcept { return size_; } iterator begin() noexcept { if (root_) return iterator(leftmost(root_)); return iterator(nullptr); } const_iterator begin() const noexcept { if (root_) return const_iterator(leftmost(root_)); return const_iterator(nullptr); } const_iterator cbegin() const noexcept { return begin(); } reverse_iterator rbegin() { return reverse_iterator(rightmost(root_), root_); } const_reverse_iterator rbegin() const { return const_reverse_iterator(rightmost(root_), root_); } const_reverse_iterator crbegin() const { return const_reverse_iterator(rightmost(root_), root_); } iterator end() noexcept { return iterator(nullptr); } const_iterator end() const noexcept { return const_iterator(nullptr); } const_iterator cend() const noexcept { return end(); } reverse_iterator rend() { return reverse_iterator(nullptr, root_); } const_reverse_iterator rend() const { return const_reverse_iterator(nullptr, root_); } const_reverse_iterator crend() const { return const_reverse_iterator(nullptr, root_); } void clear() noexcept { clearer(root_); root_ = nullptr; size_ = 0; } ctl::pair insert(value_type&& value) { return insert_node(new node_type(ctl::move(value))); } ctl::pair insert(const value_type& value) { return insert_node(new node_type(value)); } iterator insert(const_iterator hint, const value_type& value) { return insert(value).first; } iterator insert(const_iterator hint, value_type&& value) { return insert(ctl::move(value)).first; } template void insert(InputIt first, InputIt last) { for (; first != last; ++first) insert(*first); } void insert(std::initializer_list ilist) { for (const auto& value : ilist) insert(value); } template ctl::pair emplace(Args&&... args) { value_type value(ctl::forward(args)...); return insert(ctl::move(value)); } iterator erase(iterator pos) { node_type* node = pos.node_; iterator res = ++pos; erase_node(node); return res; } size_type erase(const key_type& key) { node_type* node = get_element(key); if (node == nullptr) return 0; erase_node(node); return 1; } iterator erase(const_iterator first, const_iterator last) { while (first != last) first = erase(first); return iterator(const_cast(last.node_)); } void swap(set& other) noexcept { ctl::swap(root_, other.root_); ctl::swap(size_, other.size_); } ctl::pair equal_range(const key_type& key) { return { iterator(get_floor(key)), iterator(get_ceiling(key)) }; } ctl::pair equal_range( const key_type& key) const { return { const_iterator(get_floor(key)), const_iterator(get_ceiling(key)) }; } iterator lower_bound(const key_type& key) { return iterator(get_floor(key)); } const_iterator lower_bound(const key_type& key) const { return const_iterator(get_floor(key)); } iterator upper_bound(const key_type& key) { return iterator(get_ceiling(key)); } const_iterator upper_bound(const key_type& key) const { return const_iterator(get_ceiling(key)); } size_type count(const key_type& key) const { return !!get_element(key); } iterator find(const key_type& key) { return iterator(get_element(key)); } const_iterator find(const key_type& key) const { return const_iterator(get_element(key)); } template iterator emplace_hint(const_iterator hint, Args&&... args) { return emplace(ctl::forward(args)...).first; } void check() const { size_type count = 0; if (root_ != nullptr) { if (root_->is_red()) // ILLEGAL TREE: root node must be black __builtin_trap(); int black_height = -1; checker(root_, nullptr, 0, black_height); count = tally(root_); } if (count != size_) // ILLEGAL TREE: unexpected number of nodes __builtin_trap(); } private: static node_type* leftmost(node_type* node) noexcept { while (node && node->left()) node = node->left(); return node; } static node_type* rightmost(node_type* node) noexcept { while (node && node->right) node = node->right; return node; } static optimizesize void clearer(node_type* node) noexcept { node_type* right; for (; node; node = right) { right = node->right; clearer(node->left()); delete node; } } static optimizesize node_type* copier(const node_type* node) { if (node == nullptr) return nullptr; node_type* new_node = new node_type(node->value); new_node->left(copier(node->left())); new_node->right = copier(node->right); if (new_node->left()) new_node->left()->parent = new_node; if (new_node->right) new_node->right->parent = new_node; return new_node; } static optimizesize size_type tally(const node_type* node) { if (node == nullptr) return 0; return 1 + tally(node->left()) + tally(node->right); } template node_type* get_element(const K& key) const { node_type* current = root_; while (current != nullptr) { if (comp_(key, current->value)) { current = current->left(); } else if (comp_(current->value, key)) { current = current->right; } else { return current; } } return nullptr; } template node_type* get_floor(const K& key) const { node_type* result = nullptr; node_type* current = root_; while (current != nullptr) { if (!comp_(current->value, key)) { result = current; current = current->left(); } else { current = current->right; } } return result; } template node_type* get_ceiling(const K& key) const { node_type* result = nullptr; node_type* current = root_; while (current != nullptr) { if (comp_(key, current->value)) { result = current; current = current->left(); } else { current = current->right; } } return result; } optimizesize ctl::pair insert_node(node_type* node) { if (root_ == nullptr) { root_ = node; root_->is_red(false); size_++; return { iterator(root_), true }; } node_type* current = root_; node_type* parent = nullptr; while (current != nullptr) { parent = current; if (comp_(node->value, current->value)) { current = current->left(); } else if (comp_(current->value, node->value)) { current = current->right; } else { delete node; // already exists return { iterator(current), false }; } } if (comp_(node->value, parent->value)) { parent->left(node); } else { parent->right = node; } node->parent = parent; size_++; rebalance_after_insert(node); return { iterator(node), true }; } optimizesize void erase_node(node_type* node) { node_type* y = node; node_type* x = nullptr; node_type* x_parent = nullptr; bool y_original_color = y->is_red(); if (node->left() == nullptr) { x = node->right; transplant(node, node->right); x_parent = node->parent; } else if (node->right == nullptr) { x = node->left(); transplant(node, node->left()); x_parent = node->parent; } else { y = leftmost(node->right); y_original_color = y->is_red(); x = y->right; if (y->parent == node) { if (x) x->parent = y; x_parent = y; } else { transplant(y, y->right); y->right = node->right; y->right->parent = y; x_parent = y->parent; } transplant(node, y); y->left(node->left()); y->left()->parent = y; y->is_red(node->is_red()); } if (!y_original_color) rebalance_after_erase(x, x_parent); delete node; --size_; } optimizesize void left_rotate(node_type* x) { node_type* y = x->right; x->right = y->left(); if (y->left() != nullptr) y->left()->parent = x; y->parent = x->parent; if (x->parent == nullptr) { root_ = y; } else if (x == x->parent->left()) { x->parent->left(y); } else { x->parent->right = y; } y->left(x); x->parent = y; } optimizesize void right_rotate(node_type* y) { node_type* x = y->left(); y->left(x->right); if (x->right != nullptr) x->right->parent = y; x->parent = y->parent; if (y->parent == nullptr) { root_ = x; } else if (y == y->parent->right) { y->parent->right = x; } else { y->parent->left(x); } x->right = y; y->parent = x; } optimizesize void transplant(node_type* u, node_type* v) { if (u->parent == nullptr) { root_ = v; } else if (u == u->parent->left()) { u->parent->left(v); } else { u->parent->right = v; } if (v != nullptr) v->parent = u->parent; } optimizesize void checker(const node_type* node, const node_type* parent, int black_count, int& black_height) const { if (node == nullptr) { // Leaf nodes are considered black if (black_height == -1) { black_height = black_count; } else if (black_count != black_height) { // ILLEGAL TREE: Black height mismatch __builtin_trap(); } return; } if (node->parent != parent) // ILLEGAL TREE: Parent link is incorrect __builtin_trap(); if (parent) { if (parent->left() == node && !comp_(node->value, parent->value)) // ILLEGAL TREE: Binary search property violated on left child __builtin_trap(); if (parent->right == node && !comp_(parent->value, node->value)) // ILLEGAL TREE: Binary search property violated on right child __builtin_trap(); } if (!node->is_red()) { black_count++; } else if (parent != nullptr && parent->is_red()) { // ILLEGAL TREE: Red node has red child __builtin_trap(); } checker(node->left(), node, black_count, black_height); checker(node->right, node, black_count, black_height); } optimizesize void rebalance_after_insert(node_type* node) { node->is_red(true); while (node != root_ && node->parent->is_red()) { if (node->parent == node->parent->parent->left()) { node_type* uncle = node->parent->parent->right; if (uncle && uncle->is_red()) { node->parent->is_red(false); uncle->is_red(false); node->parent->parent->is_red(true); node = node->parent->parent; } else { if (node == node->parent->right) { node = node->parent; left_rotate(node); } node->parent->is_red(false); node->parent->parent->is_red(true); right_rotate(node->parent->parent); } } else { node_type* uncle = node->parent->parent->left(); if (uncle && uncle->is_red()) { node->parent->is_red(false); uncle->is_red(false); node->parent->parent->is_red(true); node = node->parent->parent; } else { if (node == node->parent->left()) { node = node->parent; right_rotate(node); } node->parent->is_red(false); node->parent->parent->is_red(true); left_rotate(node->parent->parent); } } } root_->is_red(false); } optimizesize void rebalance_after_erase(node_type* node, node_type* parent) { while (node != root_ && (node == nullptr || !node->is_red())) { if (node == parent->left()) { node_type* sibling = parent->right; if (sibling->is_red()) { sibling->is_red(false); parent->is_red(true); left_rotate(parent); sibling = parent->right; } if ((sibling->left() == nullptr || !sibling->left()->is_red()) && (sibling->right == nullptr || !sibling->right->is_red())) { sibling->is_red(true); node = parent; parent = node->parent; } else { if (sibling->right == nullptr || !sibling->right->is_red()) { sibling->left()->is_red(false); sibling->is_red(true); right_rotate(sibling); sibling = parent->right; } sibling->is_red(parent->is_red()); parent->is_red(false); sibling->right->is_red(false); left_rotate(parent); node = root_; break; } } else { node_type* sibling = parent->left(); if (sibling->is_red()) { sibling->is_red(false); parent->is_red(true); right_rotate(parent); sibling = parent->left(); } if ((sibling->right == nullptr || !sibling->right->is_red()) && (sibling->left() == nullptr || !sibling->left()->is_red())) { sibling->is_red(true); node = parent; parent = node->parent; } else { if (sibling->left() == nullptr || !sibling->left()->is_red()) { sibling->right->is_red(false); sibling->is_red(true); left_rotate(sibling); sibling = parent->left(); } sibling->is_red(parent->is_red()); parent->is_red(false); sibling->left()->is_red(false); right_rotate(parent); node = root_; break; } } } if (node != nullptr) node->is_red(false); } node_type* root_; size_type size_; Compare comp_; }; template> bool operator==(const set& lhs, const set& rhs) { if (lhs.size() != rhs.size()) return false; auto i = lhs.cbegin(); auto j = rhs.cbegin(); for (; i != lhs.end(); ++i, ++j) { if (!(*i == *j)) return false; } return true; } template> bool operator<(const set& lhs, const set& rhs) { auto i = lhs.cbegin(); auto j = rhs.cbegin(); for (; i != lhs.end() && j != rhs.end(); ++i, ++j) { if (Compare()(*i, *j)) return true; if (Compare()(*j, *i)) return false; } return i == lhs.end() && j != rhs.end(); } template> bool operator!=(const set& lhs, const set& rhs) { return !(lhs == rhs); } template> bool operator<=(const set& lhs, const set& rhs) { return !(rhs < lhs); } template> bool operator>(const set& lhs, const set& rhs) { return rhs < lhs; } template> bool operator>=(const set& lhs, const set& rhs) { return !(lhs < rhs); } template> void swap(set& lhs, set& rhs) noexcept; } // namespace ctl #endif // CTL_SET_H_