cosmopolitan/ctl/set.h

968 lines
25 KiB
C
Raw Normal View History

// -*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
#ifndef CTL_SET_H_
#define CTL_SET_H_
#include "initializer_list.h"
#include "less.h"
#include "pair.h"
namespace ctl {
2024-07-01 10:48:28 +00:00
template<typename Key, typename Compare = ctl::less<Key>>
class set
{
    // Red-black tree node.
    //
    // The node's color is packed into the low bit of `left_` together with
    // the left-child pointer (bit set means red). This relies on node
    // addresses always having a zero low bit, which holds because the node
    // contains pointer-sized members.
    struct rbtree
    {
        uintptr_t left_;
        rbtree* right;
        rbtree* parent;
        Key value;

        // Returns the left child with the color bit masked off.
        rbtree* left() const
        {
            return (rbtree*)(left_ & -2);
        }

        // Replaces the left child, preserving this node's color bit.
        void left(rbtree* val)
        {
            left_ = (uintptr_t)val | (left_ & 1);
        }

        bool is_red() const
        {
            return left_ & 1;
        }

        // Sets this node's color, preserving the left-child pointer.
        void is_red(bool val)
        {
            left_ &= -2;
            left_ |= val;
        }

        // Creates a detached red node holding a copy of `val`.
        rbtree(const Key& val)
          : left_(1), right(nullptr), parent(nullptr), value(val)
        {
        }

        // Creates a detached red node that moves from `val`, so rvalue
        // insert() and emplace() do not pay for a second copy of the key.
        rbtree(Key&& val)
          : left_(1), right(nullptr), parent(nullptr), value(ctl::move(val))
        {
        }
    };

  public:
    using key_type = Key;
    using value_type = Key;
    using size_type = size_t;
    using node_type = rbtree;
    using key_compare = Compare;
    using value_compare = Compare;
    using difference_type = ptrdiff_t;
    using reference = value_type&;
    using pointer = value_type*;
    using const_pointer = const value_type*;
    using const_reference = const value_type&;

    // Bidirectional in-order iterator.
    //
    // end() is represented by a null node. Incrementing end() leaves it at
    // end(). Decrementing end() is not supported and traps, because the
    // iterator does not know the tree's root.
    class iterator
    {
      public:
        using value_type = Key;
        using difference_type = ptrdiff_t;
        using pointer = Key*;
        using reference = Key&;

        iterator() : node_(nullptr)
        {
        }

        reference operator*() const
        {
            return node_->value;
        }

        pointer operator->() const
        {
            return &(node_->value);
        }

        // Advances to the in-order successor: the leftmost node of the right
        // subtree if there is one, otherwise the first ancestor reached by
        // climbing out of a left subtree. Yields end() after the last element.
        iterator& operator++()
        {
            if (node_ == nullptr)
                return *this;
            if (node_->right) {
                node_ = leftmost(node_->right);
            } else {
                node_type* parent = node_->parent;
                while (parent && node_ == parent->right) {
                    node_ = parent;
                    parent = parent->parent;
                }
                node_ = parent;
            }
            return *this;
        }

        // Moves to the in-order predecessor. Stepping back from the first
        // element yields a null iterator; calling this on end() traps.
        iterator& operator--()
        {
            if (node_ == nullptr)
                __builtin_trap();
            if (node_->left()) {
                node_ = rightmost(node_->left());
            } else {
                node_type* parent = node_->parent;
                while (parent && node_ == parent->left()) {
                    node_ = parent;
                    parent = parent->parent;
                }
                node_ = parent;
            }
            return *this;
        }

        iterator operator++(int)
        {
            iterator tmp = *this;
            ++(*this);
            return tmp;
        }

        iterator operator--(int)
        {
            iterator tmp = *this;
            --(*this);
            return tmp;
        }

        bool operator==(const iterator& other) const
        {
            return node_ == other.node_;
        }

        bool operator!=(const iterator& other) const
        {
            return !(*this == other);
        }

      private:
        friend class set;
        node_type* node_;

        explicit iterator(node_type* node) : node_(node)
        {
        }
    };

    // Iterates the set from the largest element to the smallest.
    //
    // rend() is represented by a null node. The root is remembered so that
    // base() can map rend() back to begin().
    class reverse_iterator
    {
      public:
        using value_type = Key;
        using difference_type = ptrdiff_t;
        using pointer = Key*;
        using reference = Key&;

        reverse_iterator() : node_(nullptr), root_(nullptr)
        {
        }

        reference operator*() const
        {
            return node_->value;
        }

        pointer operator->() const
        {
            return &(node_->value);
        }

        // Moves toward smaller elements (the in-order predecessor).
        reverse_iterator& operator++()
        {
            if (node_->left()) {
                node_ = rightmost(node_->left());
            } else {
                node_type* parent = node_->parent;
                while (parent && node_ == parent->left()) {
                    node_ = parent;
                    parent = parent->parent;
                }
                node_ = parent;
            }
            return *this;
        }

        reverse_iterator operator++(int)
        {
            reverse_iterator tmp = *this;
            ++(*this);
            return tmp;
        }

        // Moves toward larger elements (the in-order successor).
        reverse_iterator& operator--()
        {
            if (node_->right) {
                node_ = leftmost(node_->right);
            } else {
                node_type* parent = node_->parent;
                while (parent && node_ == parent->right) {
                    node_ = parent;
                    parent = parent->parent;
                }
                node_ = parent;
            }
            return *this;
        }

        reverse_iterator operator--(int)
        {
            reverse_iterator tmp = *this;
            --(*this);
            return tmp;
        }

        bool operator==(const reverse_iterator& other) const
        {
            return node_ == other.node_;
        }

        bool operator!=(const reverse_iterator& other) const
        {
            return !(*this == other);
        }

        // Returns the forward iterator to the element immediately after the
        // one this iterator references, in ascending order.
        //
        // Examples: rbegin().base() is end(), and rend().base() is begin().
        iterator base() const
        {
            if (node_ == nullptr)
                return iterator(leftmost(root_));
            // Step forward (in-order successor), not backward.
            iterator it(node_);
            ++it;
            return it;
        }

      private:
        friend class set;
        node_type* node_;
        node_type* root_;

        explicit reverse_iterator(node_type* node, node_type* root)
          : node_(node), root_(root)
        {
        }
    };

    using const_iterator = iterator;
    using const_reverse_iterator = reverse_iterator;

    set() : root_(nullptr), size_(0), comp_(Compare())
    {
    }

    explicit set(const Compare& comp) : root_(nullptr), size_(0), comp_(comp)
    {
    }

    // Builds a set from [first, last); duplicate keys are dropped.
    template<class InputIt>
    set(InputIt first, InputIt last, const Compare& comp = Compare())
      : root_(nullptr), size_(0), comp_(comp)
    {
        for (; first != last; ++first)
            insert(*first);
    }

    // Deep copy.
    //
    // The comparator is copied along with the nodes, and copier() preserves
    // node colors, so the result is a valid red-black tree ordered exactly
    // like `other`.
    set(const set& other) : root_(nullptr), size_(0), comp_(other.comp_)
    {
        if (other.root_) {
            root_ = copier(other.root_);
            size_ = other.size_;
        }
    }

    // Takes over other's tree, leaving other empty.
    //
    // The comparator is copied rather than moved, so `other` keeps a
    // comparator it can still use if it is refilled.
    set(set&& other) noexcept
      : root_(other.root_), size_(other.size_), comp_(other.comp_)
    {
        other.root_ = nullptr;
        other.size_ = 0;
    }

    set(std::initializer_list<value_type> init, const Compare& comp = Compare())
      : root_(nullptr), size_(0), comp_(comp)
    {
        for (const auto& value : init)
            insert(value);
    }

    ~set()
    {
        clear();
    }

    // Copy-and-swap: if copying `other` throws, *this is left untouched.
    // The old tree is freed when `tmp` goes out of scope.
    set& operator=(const set& other)
    {
        if (this != &other) {
            set tmp(other);
            swap(tmp);
        }
        return *this;
    }

    // Frees the current tree, then takes over other's tree and comparator.
    set& operator=(set&& other) noexcept
    {
        if (this != &other) {
            clear();
            root_ = other.root_;
            size_ = other.size_;
            comp_ = other.comp_;
            other.root_ = nullptr;
            other.size_ = 0;
        }
        return *this;
    }

    bool empty() const noexcept
    {
        return size_ == 0;
    }

    size_type size() const noexcept
    {
        return size_;
    }

    // Returns a copy of the comparator that orders the keys.
    key_compare key_comp() const
    {
        return comp_;
    }

    // Same as key_comp(), since a set's values are its keys.
    value_compare value_comp() const
    {
        return comp_;
    }

    iterator begin() noexcept
    {
        if (root_)
            return iterator(leftmost(root_));
        return iterator(nullptr);
    }

    const_iterator begin() const noexcept
    {
        if (root_)
            return const_iterator(leftmost(root_));
        return const_iterator(nullptr);
    }

    const_iterator cbegin() const noexcept
    {
        return begin();
    }

    reverse_iterator rbegin()
    {
        return reverse_iterator(rightmost(root_), root_);
    }

    const_reverse_iterator rbegin() const
    {
        return const_reverse_iterator(rightmost(root_), root_);
    }

    const_reverse_iterator crbegin() const
    {
        return const_reverse_iterator(rightmost(root_), root_);
    }

    iterator end() noexcept
    {
        return iterator(nullptr);
    }

    const_iterator end() const noexcept
    {
        return const_iterator(nullptr);
    }

    const_iterator cend() const noexcept
    {
        return end();
    }

    reverse_iterator rend()
    {
        return reverse_iterator(nullptr, root_);
    }

    const_reverse_iterator rend() const
    {
        return const_reverse_iterator(nullptr, root_);
    }

    const_reverse_iterator crend() const
    {
        return const_reverse_iterator(nullptr, root_);
    }

    // Destroys every element and resets the set to empty.
    void clear() noexcept
    {
        clearer(root_);
        root_ = nullptr;
        size_ = 0;
    }

    // Inserts `value` if no equivalent key exists.
    //
    // Returns the element's position and whether an insertion took place.
    ctl::pair<iterator, bool> insert(value_type&& value)
    {
        return insert_node(new node_type(ctl::move(value)));
    }

    ctl::pair<iterator, bool> insert(const value_type& value)
    {
        return insert_node(new node_type(value));
    }

    // The hint is accepted for interface compatibility but ignored.
    iterator insert(const_iterator hint, const value_type& value)
    {
        return insert(value).first;
    }

    // The hint is accepted for interface compatibility but ignored.
    iterator insert(const_iterator hint, value_type&& value)
    {
        return insert(ctl::move(value)).first;
    }

    template<class InputIt>
    void insert(InputIt first, InputIt last)
    {
        for (; first != last; ++first)
            insert(*first);
    }

    void insert(std::initializer_list<value_type> ilist)
    {
        for (const auto& value : ilist)
            insert(value);
    }

    // Constructs a key from `args`, then inserts it by move.
    template<class... Args>
    ctl::pair<iterator, bool> emplace(Args&&... args)
    {
        value_type value(ctl::forward<Args>(args)...);
        return insert(ctl::move(value));
    }

    // Removes the element at `pos` and returns an iterator to its successor.
    //
    // Erasing only relinks neighboring nodes, so the successor captured
    // here stays valid.
    iterator erase(iterator pos)
    {
        node_type* node = pos.node_;
        iterator res = ++pos;
        erase_node(node);
        return res;
    }

    // Removes the element equivalent to `key`. Returns the number removed
    // (0 or 1).
    size_type erase(const key_type& key)
    {
        node_type* node = get_element(key);
        if (node == nullptr)
            return 0;
        erase_node(node);
        return 1;
    }

    // Removes the elements in [first, last) and returns `last`.
    iterator erase(const_iterator first, const_iterator last)
    {
        while (first != last)
            first = erase(first);
        return last;
    }

    // Exchanges contents and comparators with `other` in O(1).
    void swap(set& other) noexcept
    {
        ctl::swap(root_, other.root_);
        ctl::swap(size_, other.size_);
        ctl::swap(comp_, other.comp_);
    }

    ctl::pair<iterator, iterator> equal_range(const key_type& key)
    {
        return { iterator(get_floor(key)), iterator(get_ceiling(key)) };
    }

    ctl::pair<const_iterator, const_iterator> equal_range(
      const key_type& key) const
    {
        return { const_iterator(get_floor(key)),
                 const_iterator(get_ceiling(key)) };
    }

    // Returns the first element that is not less than `key`.
    iterator lower_bound(const key_type& key)
    {
        return iterator(get_floor(key));
    }

    const_iterator lower_bound(const key_type& key) const
    {
        return const_iterator(get_floor(key));
    }

    // Returns the first element that is greater than `key`.
    iterator upper_bound(const key_type& key)
    {
        return iterator(get_ceiling(key));
    }

    const_iterator upper_bound(const key_type& key) const
    {
        return const_iterator(get_ceiling(key));
    }

    size_type count(const key_type& key) const
    {
        return !!get_element(key);
    }

    iterator find(const key_type& key)
    {
        return iterator(get_element(key));
    }

    const_iterator find(const key_type& key) const
    {
        return const_iterator(get_element(key));
    }

    // The hint is accepted for interface compatibility but ignored.
    template<class... Args>
    iterator emplace_hint(const_iterator hint, Args&&... args)
    {
        return emplace(ctl::forward<Args>(args)...).first;
    }

    // Debugging aid that traps on any broken invariant.
    //
    // Checks that the root is black, that black heights match, that no red
    // node has a red child, that parent links and binary-search ordering
    // are correct, and that size() matches the node count.
    void check() const
    {
        size_type nodes = 0;
        if (root_ != nullptr) {
            if (root_->is_red())
                // ILLEGAL TREE: root node must be black
                __builtin_trap();
            int black_height = -1;
            checker(root_, nullptr, 0, black_height);
            nodes = tally(root_);
        }
        if (nodes != size_)
            // ILLEGAL TREE: unexpected number of nodes
            __builtin_trap();
    }

  private:
    static node_type* leftmost(node_type* node) noexcept
    {
        while (node && node->left())
            node = node->left();
        return node;
    }

    static node_type* rightmost(node_type* node) noexcept
    {
        while (node && node->right)
            node = node->right;
        return node;
    }

    // Deletes a subtree: recurses on left children and loops down the
    // right spine, so recursion depth follows left links only.
    static optimizesize void clearer(node_type* node) noexcept
    {
        node_type* right;
        for (; node; node = right) {
            right = node->right;
            clearer(node->left());
            delete node;
        }
    }

    // Returns a deep copy of the subtree rooted at `node`.
    //
    // Node colors are copied as well; nodes are constructed red, so
    // skipping that would produce an invalid red-black tree.
    static optimizesize node_type* copier(const node_type* node)
    {
        if (node == nullptr)
            return nullptr;
        node_type* new_node = new node_type(node->value);
        new_node->is_red(node->is_red());
        new_node->left(copier(node->left()));
        new_node->right = copier(node->right);
        if (new_node->left())
            new_node->left()->parent = new_node;
        if (new_node->right)
            new_node->right->parent = new_node;
        return new_node;
    }

    // Counts the nodes in a subtree (used by check()).
    static optimizesize size_type tally(const node_type* node)
    {
        if (node == nullptr)
            return 0;
        return 1 + tally(node->left()) + tally(node->right);
    }

    // Returns the node equivalent to `key`, or nullptr.
    template<typename K>
    node_type* get_element(const K& key) const
    {
        node_type* current = root_;
        while (current != nullptr) {
            if (comp_(key, current->value)) {
                current = current->left();
            } else if (comp_(current->value, key)) {
                current = current->right;
            } else {
                return current;
            }
        }
        return nullptr;
    }

    // Returns the first node not less than `key` (lower bound), or nullptr.
    template<typename K>
    node_type* get_floor(const K& key) const
    {
        node_type* result = nullptr;
        node_type* current = root_;
        while (current != nullptr) {
            if (!comp_(current->value, key)) {
                result = current;
                current = current->left();
            } else {
                current = current->right;
            }
        }
        return result;
    }

    // Returns the first node greater than `key` (upper bound), or nullptr.
    template<typename K>
    node_type* get_ceiling(const K& key) const
    {
        node_type* result = nullptr;
        node_type* current = root_;
        while (current != nullptr) {
            if (comp_(key, current->value)) {
                result = current;
                current = current->left();
            } else {
                current = current->right;
            }
        }
        return result;
    }

    // Links a freshly allocated node into the tree and restores the
    // red-black invariants.
    //
    // Takes ownership of `node`. If an equivalent key already exists,
    // `node` is deleted and the existing element is returned with false.
    optimizesize ctl::pair<iterator, bool> insert_node(node_type* node)
    {
        if (root_ == nullptr) {
            root_ = node;
            root_->is_red(false);
            size_++;
            return { iterator(root_), true };
        }
        node_type* current = root_;
        node_type* parent = nullptr;
        bool go_left = false;
        while (current != nullptr) {
            parent = current;
            if (comp_(node->value, current->value)) {
                go_left = true;
                current = current->left();
            } else if (comp_(current->value, node->value)) {
                go_left = false;
                current = current->right;
            } else {
                delete node; // already exists
                return { iterator(current), false };
            }
        }
        // Reuse the direction of the last step instead of comparing again.
        if (go_left) {
            parent->left(node);
        } else {
            parent->right = node;
        }
        node->parent = parent;
        size_++;
        rebalance_after_insert(node);
        return { iterator(node), true };
    }

    // Unlinks and deletes `node`, then restores the red-black invariants.
    //
    // A node with two children is replaced by its in-order successor `y`,
    // which is relinked (not copied), so other iterators stay valid. `x` is
    // the node that moves into the vacated position (possibly null) and
    // `x_parent` is its parent; both are tracked because `x` may be null.
    optimizesize void erase_node(node_type* node)
    {
        node_type* y = node;
        node_type* x = nullptr;
        node_type* x_parent = nullptr;
        bool y_original_color = y->is_red();
        if (node->left() == nullptr) {
            x = node->right;
            transplant(node, node->right);
            x_parent = node->parent;
        } else if (node->right == nullptr) {
            x = node->left();
            transplant(node, node->left());
            x_parent = node->parent;
        } else {
            y = leftmost(node->right);
            y_original_color = y->is_red();
            x = y->right;
            if (y->parent == node) {
                if (x)
                    x->parent = y;
                x_parent = y;
            } else {
                transplant(y, y->right);
                y->right = node->right;
                y->right->parent = y;
                x_parent = y->parent;
            }
            transplant(node, y);
            y->left(node->left());
            y->left()->parent = y;
            y->is_red(node->is_red());
        }
        // Removing a black node shortens one path's black height.
        if (!y_original_color)
            rebalance_after_erase(x, x_parent);
        delete node;
        --size_;
    }

    optimizesize void left_rotate(node_type* x)
    {
        node_type* y = x->right;
        x->right = y->left();
        if (y->left() != nullptr)
            y->left()->parent = x;
        y->parent = x->parent;
        if (x->parent == nullptr) {
            root_ = y;
        } else if (x == x->parent->left()) {
            x->parent->left(y);
        } else {
            x->parent->right = y;
        }
        y->left(x);
        x->parent = y;
    }

    optimizesize void right_rotate(node_type* y)
    {
        node_type* x = y->left();
        y->left(x->right);
        if (x->right != nullptr)
            x->right->parent = y;
        x->parent = y->parent;
        if (y->parent == nullptr) {
            root_ = x;
        } else if (y == y->parent->right) {
            y->parent->right = x;
        } else {
            y->parent->left(x);
        }
        x->right = y;
        y->parent = x;
    }

    // Replaces the subtree rooted at `u` with the one rooted at `v` in u's
    // parent (or as the root). u's own links are left untouched.
    optimizesize void transplant(node_type* u, node_type* v)
    {
        if (u->parent == nullptr) {
            root_ = v;
        } else if (u == u->parent->left()) {
            u->parent->left(v);
        } else {
            u->parent->right = v;
        }
        if (v != nullptr)
            v->parent = u->parent;
    }

    // Recursive worker for check(). `black_height` records the black count
    // of the first null leaf reached; every other leaf must match it.
    optimizesize void checker(const node_type* node,
                              const node_type* parent,
                              int black_count,
                              int& black_height) const
    {
        if (node == nullptr) {
            // Leaf nodes are considered black
            if (black_height == -1) {
                black_height = black_count;
            } else if (black_count != black_height) {
                // ILLEGAL TREE: Black height mismatch
                __builtin_trap();
            }
            return;
        }
        if (node->parent != parent)
            // ILLEGAL TREE: Parent link is incorrect
            __builtin_trap();
        if (parent) {
            if (parent->left() == node && !comp_(node->value, parent->value))
                // ILLEGAL TREE: Binary search property violated on left child
                __builtin_trap();
            if (parent->right == node && !comp_(parent->value, node->value))
                // ILLEGAL TREE: Binary search property violated on right child
                __builtin_trap();
        }
        if (!node->is_red()) {
            black_count++;
        } else if (parent != nullptr && parent->is_red()) {
            // ILLEGAL TREE: Red node has red child
            __builtin_trap();
        }
        checker(node->left(), node, black_count, black_height);
        checker(node->right, node, black_count, black_height);
    }

    // Restores the red-black invariants after `node` is linked in as a
    // leaf: recolors while the uncle is red, otherwise rotates.
    optimizesize void rebalance_after_insert(node_type* node)
    {
        node->is_red(true);
        while (node != root_ && node->parent->is_red()) {
            if (node->parent == node->parent->parent->left()) {
                node_type* uncle = node->parent->parent->right;
                if (uncle && uncle->is_red()) {
                    node->parent->is_red(false);
                    uncle->is_red(false);
                    node->parent->parent->is_red(true);
                    node = node->parent->parent;
                } else {
                    if (node == node->parent->right) {
                        node = node->parent;
                        left_rotate(node);
                    }
                    node->parent->is_red(false);
                    node->parent->parent->is_red(true);
                    right_rotate(node->parent->parent);
                }
            } else {
                node_type* uncle = node->parent->parent->left();
                if (uncle && uncle->is_red()) {
                    node->parent->is_red(false);
                    uncle->is_red(false);
                    node->parent->parent->is_red(true);
                    node = node->parent->parent;
                } else {
                    if (node == node->parent->left()) {
                        node = node->parent;
                        right_rotate(node);
                    }
                    node->parent->is_red(false);
                    node->parent->parent->is_red(true);
                    left_rotate(node->parent->parent);
                }
            }
        }
        root_->is_red(false);
    }

    // Restores the red-black invariants after a black node was removed.
    // `node` (possibly null) carries the "extra black" and `parent` is its
    // parent, passed separately because `node` may be null.
    optimizesize void rebalance_after_erase(node_type* node, node_type* parent)
    {
        while (node != root_ && (node == nullptr || !node->is_red())) {
            if (node == parent->left()) {
                node_type* sibling = parent->right;
                if (sibling->is_red()) {
                    sibling->is_red(false);
                    parent->is_red(true);
                    left_rotate(parent);
                    sibling = parent->right;
                }
                if ((sibling->left() == nullptr ||
                     !sibling->left()->is_red()) &&
                    (sibling->right == nullptr || !sibling->right->is_red())) {
                    sibling->is_red(true);
                    node = parent;
                    parent = node->parent;
                } else {
                    if (sibling->right == nullptr ||
                        !sibling->right->is_red()) {
                        sibling->left()->is_red(false);
                        sibling->is_red(true);
                        right_rotate(sibling);
                        sibling = parent->right;
                    }
                    sibling->is_red(parent->is_red());
                    parent->is_red(false);
                    sibling->right->is_red(false);
                    left_rotate(parent);
                    node = root_;
                    break;
                }
            } else {
                node_type* sibling = parent->left();
                if (sibling->is_red()) {
                    sibling->is_red(false);
                    parent->is_red(true);
                    right_rotate(parent);
                    sibling = parent->left();
                }
                if ((sibling->right == nullptr || !sibling->right->is_red()) &&
                    (sibling->left() == nullptr ||
                     !sibling->left()->is_red())) {
                    sibling->is_red(true);
                    node = parent;
                    parent = node->parent;
                } else {
                    if (sibling->left() == nullptr ||
                        !sibling->left()->is_red()) {
                        sibling->right->is_red(false);
                        sibling->is_red(true);
                        left_rotate(sibling);
                        sibling = parent->left();
                    }
                    sibling->is_red(parent->is_red());
                    parent->is_red(false);
                    sibling->left()->is_red(false);
                    right_rotate(parent);
                    node = root_;
                    break;
                }
            }
        }
        if (node != nullptr)
            node->is_red(false);
    }

    node_type* root_;
    size_type size_;
    Compare comp_;
};
2024-07-01 10:48:28 +00:00
template<class Key, typename Compare = ctl::less<Key>>
bool
operator==(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
    // Two sets are equal when they have the same size and their elements
    // compare equal with operator== pairwise, in iteration order. Checking
    // the sizes first also guarantees the walk below never runs rhs past
    // its end.
    if (lhs.size() != rhs.size())
        return false;
    auto a = lhs.cbegin();
    auto b = rhs.cbegin();
    while (a != lhs.end()) {
        if (!(*a == *b))
            return false;
        ++a;
        ++b;
    }
    return true;
}
2024-07-01 10:48:28 +00:00
template<class Key, typename Compare = ctl::less<Key>>
bool
operator<(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
    // Lexicographic comparison using a default-constructed Compare.
    auto a = lhs.cbegin();
    auto b = rhs.cbegin();
    while (a != lhs.end() && b != rhs.end()) {
        // The first differing pair decides the result.
        if (Compare()(*a, *b))
            return true;
        if (Compare()(*b, *a))
            return false;
        ++a;
        ++b;
    }
    // Equal common prefix: lhs is smaller only if it ran out while rhs
    // still has elements.
    return a == lhs.end() && b != rhs.end();
}
2024-07-01 10:48:28 +00:00
template<class Key, typename Compare = ctl::less<Key>>
bool
operator!=(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
    // Defined via operator== so the two operators can never disagree.
    const bool equal = lhs == rhs;
    return !equal;
}
2024-07-01 10:48:28 +00:00
template<class Key, typename Compare = ctl::less<Key>>
bool
operator<=(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
    // lhs <= rhs is exactly "rhs is not strictly less than lhs".
    const bool rhs_less = rhs < lhs;
    return !rhs_less;
}
2024-07-01 10:48:28 +00:00
template<class Key, typename Compare = ctl::less<Key>>
bool
operator>(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
    // Strict greater-than is strict less-than with the operands swapped.
    const bool swapped_less = rhs < lhs;
    return swapped_less;
}
2024-07-01 10:48:28 +00:00
template<class Key, typename Compare = ctl::less<Key>>
bool
operator>=(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
    // lhs >= rhs holds whenever lhs is not strictly less than rhs.
    const bool lhs_less = lhs < rhs;
    return !lhs_less;
}
2024-07-01 10:48:28 +00:00
template<class Key, typename Compare = ctl::less<Key>>
void
swap(set<Key, Compare>& lhs, set<Key, Compare>& rhs) noexcept
{
    // Non-member swap, found by argument-dependent lookup. It previously
    // had no definition, so any call failed at link time; it now forwards
    // to the member function.
    lhs.swap(rhs);
}
} // namespace ctl
#endif // CTL_SET_H_