cosmopolitan/ctl/set.h
Justine Tunney c9152b6f14
Release Cosmopolitan v3.8.0
This change switches c++ exception handling from sjlj to standard dwarf.
It's needed because clang for aarch64 doesn't support sjlj. It turns out
that libunwind had a bare-metal configuration that made this easy to do.

This change gets the new experimental cosmocc -mclang flag working well
enough that it can now be used to build all of llamafile, and builds finish
3x faster (lower build latency) without trading away any performance.

The int_fast16_t and int_fast32_t types are now always defined as 32-bit
in the interest of having more abi consistency between cosmocc -mgcc and
-mclang mode.
2024-08-30 20:14:07 -07:00

967 lines
25 KiB
C++

// -*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
#ifndef CTL_SET_H_
#define CTL_SET_H_
#include "initializer_list.h"
#include "less.h"
#include "pair.h"
namespace ctl {
/// Ordered container of unique keys, implemented as a red-black tree.
///
/// Two keys a and b are treated as equivalent when neither comp(a, b) nor
/// comp(b, a) holds; inserting a key equivalent to one already present is
/// rejected. The comparator is carried along by copy, move and swap.
///
/// NOTE(review): const_iterator is an alias of iterator and dereferencing
/// yields Key&, so callers must never mutate a stored key in a way that
/// changes its ordering, or the tree's search invariant breaks.
template<typename Key, typename Compare = ctl::less<Key>>
class set
{
    // Tree node. The node color is packed into the low bit of the left
    // child pointer: 1 means red, 0 means black. This assumes nodes are at
    // least 2-byte aligned so that bit is never part of a real address.
    struct rbtree
    {
        uintptr_t left_;
        rbtree* right;
        rbtree* parent;
        Key value;

        // Left child with the color bit masked off.
        rbtree* left() const
        {
            return (rbtree*)(left_ & -2);
        }

        // Replaces the left child while preserving the color bit.
        void left(rbtree* val)
        {
            left_ = (uintptr_t)val | (left_ & 1);
        }

        bool is_red() const
        {
            return left_ & 1;
        }

        // Sets the color bit while preserving the left child pointer.
        void is_red(bool val)
        {
            left_ &= -2;
            left_ |= val;
        }

        // New nodes start red and unlinked.
        rbtree(const Key& val)
          : left_(1), right(nullptr), parent(nullptr), value(val)
        {
        }

        // Move-constructs the key so that insert(value_type&&) and emplace()
        // do not make an extra copy.
        rbtree(Key&& val)
          : left_(1), right(nullptr), parent(nullptr), value(ctl::move(val))
        {
        }
    };

  public:
    using key_type = Key;
    using value_type = Key;
    using size_type = size_t;
    using node_type = rbtree;
    using key_compare = Compare;
    using value_compare = Compare;
    using difference_type = ptrdiff_t;
    using reference = value_type&;
    using pointer = value_type*;
    using const_pointer = const value_type*;
    using const_reference = const value_type&;

    /// In-order iterator. end() is represented by a null node. Because the
    /// iterator does not know the root, decrementing end() traps.
    class iterator
    {
      public:
        using value_type = Key;
        using difference_type = ptrdiff_t;
        using pointer = Key*;
        using reference = Key&;

        iterator() : node_(nullptr)
        {
        }

        reference operator*() const
        {
            return node_->value;
        }

        pointer operator->() const
        {
            return &(node_->value);
        }

        /// Moves to the in-order successor (end() after the last element).
        /// Incrementing end() leaves it unchanged.
        iterator& operator++()
        {
            if (node_ == nullptr)
                return *this;
            if (node_->right) {
                node_ = leftmost(node_->right);
            } else {
                // Climb until we arrive from a left subtree.
                node_type* parent = node_->parent;
                while (parent && node_ == parent->right) {
                    node_ = parent;
                    parent = parent->parent;
                }
                node_ = parent;
            }
            return *this;
        }

        /// Moves to the in-order predecessor (end() before the first
        /// element). Decrementing end() traps.
        iterator& operator--()
        {
            if (node_ == nullptr)
                __builtin_trap();
            if (node_->left()) {
                node_ = rightmost(node_->left());
            } else {
                // Climb until we arrive from a right subtree.
                node_type* parent = node_->parent;
                while (parent && node_ == parent->left()) {
                    node_ = parent;
                    parent = parent->parent;
                }
                node_ = parent;
            }
            return *this;
        }

        iterator operator++(int)
        {
            iterator tmp = *this;
            ++(*this);
            return tmp;
        }

        iterator operator--(int)
        {
            iterator tmp = *this;
            --(*this);
            return tmp;
        }

        bool operator==(const iterator& other) const
        {
            return node_ == other.node_;
        }

        bool operator!=(const iterator& other) const
        {
            return !(*this == other);
        }

      private:
        friend class set;
        node_type* node_;

        explicit iterator(node_type* node) : node_(node)
        {
        }
    };

    /// Iterator that walks the set from the largest key to the smallest.
    /// node_ points directly at the current element; rend() holds a null
    /// node. The root is remembered so that base() can produce begin().
    class reverse_iterator
    {
      public:
        using value_type = Key;
        using difference_type = ptrdiff_t;
        using pointer = Key*;
        using reference = Key&;

        reverse_iterator() : node_(nullptr), root_(nullptr)
        {
        }

        reference operator*() const
        {
            return node_->value;
        }

        pointer operator->() const
        {
            return &(node_->value);
        }

        /// Steps toward smaller keys (the forward predecessor).
        reverse_iterator& operator++()
        {
            if (node_->left()) {
                node_ = rightmost(node_->left());
            } else {
                node_type* parent = node_->parent;
                while (parent && node_ == parent->left()) {
                    node_ = parent;
                    parent = parent->parent;
                }
                node_ = parent;
            }
            return *this;
        }

        reverse_iterator operator++(int)
        {
            reverse_iterator tmp = *this;
            ++(*this);
            return tmp;
        }

        /// Steps toward larger keys (the forward successor).
        reverse_iterator& operator--()
        {
            if (node_->right) {
                node_ = leftmost(node_->right);
            } else {
                node_type* parent = node_->parent;
                while (parent && node_ == parent->right) {
                    node_ = parent;
                    parent = parent->parent;
                }
                node_ = parent;
            }
            return *this;
        }

        reverse_iterator operator--(int)
        {
            reverse_iterator tmp = *this;
            --(*this);
            return tmp;
        }

        bool operator==(const reverse_iterator& other) const
        {
            return node_ == other.node_;
        }

        bool operator!=(const reverse_iterator& other) const
        {
            return !(*this == other);
        }

        /// Returns the forward iterator one past the current element, in
        /// forward order, matching std::reverse_iterator::base(). So
        /// rbegin().base() == end() and rend().base() == begin().
        iterator base() const
        {
            if (node_ == nullptr)
                return iterator(leftmost(root_));
            // The forward successor is one step back in reverse order.
            reverse_iterator tmp = *this;
            --tmp;
            return iterator(tmp.node_);
        }

      private:
        friend class set;
        node_type* node_;
        node_type* root_;

        explicit reverse_iterator(node_type* node, node_type* root)
          : node_(node), root_(root)
        {
        }
    };

    using const_iterator = iterator;
    using const_reverse_iterator = reverse_iterator;

    set() : root_(nullptr), size_(0), comp_(Compare())
    {
    }

    explicit set(const Compare& comp) : root_(nullptr), size_(0), comp_(comp)
    {
    }

    template<class InputIt>
    set(InputIt first, InputIt last, const Compare& comp = Compare())
      : root_(nullptr), size_(0), comp_(comp)
    {
        for (; first != last; ++first)
            insert(*first);
    }

    /// Deep copy that preserves the source's shape, colors and comparator.
    set(const set& other) : root_(nullptr), size_(0), comp_(other.comp_)
    {
        if (other.root_) {
            root_ = copier(other.root_);
            size_ = other.size_;
        }
    }

    /// Steals other's tree; other is left empty and keeps its comparator.
    set(set&& other) noexcept
      : root_(other.root_), size_(other.size_), comp_(other.comp_)
    {
        other.root_ = nullptr;
        other.size_ = 0;
    }

    set(std::initializer_list<value_type> init, const Compare& comp = Compare())
      : root_(nullptr), size_(0), comp_(comp)
    {
        for (const auto& value : init)
            insert(value);
    }

    ~set()
    {
        clear();
    }

    /// Copy-and-swap. If copying throws, *this is left untouched.
    set& operator=(const set& other)
    {
        if (this != &other) {
            set copy(other);
            swap(copy);
        }
        return *this;
    }

    set& operator=(set&& other) noexcept
    {
        if (this != &other) {
            clear();
            root_ = other.root_;
            size_ = other.size_;
            comp_ = other.comp_;
            other.root_ = nullptr;
            other.size_ = 0;
        }
        return *this;
    }

    bool empty() const noexcept
    {
        return size_ == 0;
    }

    size_type size() const noexcept
    {
        return size_;
    }

    iterator begin() noexcept
    {
        if (root_)
            return iterator(leftmost(root_));
        return iterator(nullptr);
    }

    const_iterator begin() const noexcept
    {
        if (root_)
            return const_iterator(leftmost(root_));
        return const_iterator(nullptr);
    }

    const_iterator cbegin() const noexcept
    {
        return begin();
    }

    reverse_iterator rbegin()
    {
        return reverse_iterator(rightmost(root_), root_);
    }

    const_reverse_iterator rbegin() const
    {
        return const_reverse_iterator(rightmost(root_), root_);
    }

    const_reverse_iterator crbegin() const
    {
        return const_reverse_iterator(rightmost(root_), root_);
    }

    iterator end() noexcept
    {
        return iterator(nullptr);
    }

    const_iterator end() const noexcept
    {
        return const_iterator(nullptr);
    }

    const_iterator cend() const noexcept
    {
        return end();
    }

    reverse_iterator rend()
    {
        return reverse_iterator(nullptr, root_);
    }

    const_reverse_iterator rend() const
    {
        return const_reverse_iterator(nullptr, root_);
    }

    const_reverse_iterator crend() const
    {
        return const_reverse_iterator(nullptr, root_);
    }

    void clear() noexcept
    {
        clearer(root_);
        root_ = nullptr;
        size_ = 0;
    }

    /// Inserts value unless an equivalent key exists. Returns the position
    /// of the (new or existing) element and whether insertion happened.
    ctl::pair<iterator, bool> insert(value_type&& value)
    {
        return insert_node(new node_type(ctl::move(value)));
    }

    ctl::pair<iterator, bool> insert(const value_type& value)
    {
        return insert_node(new node_type(value));
    }

    // The hint is accepted for API compatibility but not used.
    iterator insert(const_iterator hint, const value_type& value)
    {
        return insert(value).first;
    }

    iterator insert(const_iterator hint, value_type&& value)
    {
        return insert(ctl::move(value)).first;
    }

    template<class InputIt>
    void insert(InputIt first, InputIt last)
    {
        for (; first != last; ++first)
            insert(*first);
    }

    void insert(std::initializer_list<value_type> ilist)
    {
        for (const auto& value : ilist)
            insert(value);
    }

    /// Builds a key from args, then inserts it. The key is constructed even
    /// if an equivalent key already exists.
    template<class... Args>
    ctl::pair<iterator, bool> emplace(Args&&... args)
    {
        value_type value(ctl::forward<Args>(args)...);
        return insert(ctl::move(value));
    }

    /// Removes the element at pos and returns the iterator following it.
    iterator erase(iterator pos)
    {
        node_type* node = pos.node_;
        iterator res = ++pos;
        erase_node(node);
        return res;
    }

    /// Removes the element equivalent to key, if any. Returns 0 or 1.
    size_type erase(const key_type& key)
    {
        node_type* node = get_element(key);
        if (node == nullptr)
            return 0;
        erase_node(node);
        return 1;
    }

    /// Removes every element in [first, last) and returns last.
    iterator erase(const_iterator first, const_iterator last)
    {
        while (first != last)
            first = erase(first);
        return last;
    }

    /// Exchanges trees, sizes and comparators with other.
    void swap(set& other) noexcept
    {
        ctl::swap(root_, other.root_);
        ctl::swap(size_, other.size_);
        ctl::swap(comp_, other.comp_);
    }

    ctl::pair<iterator, iterator> equal_range(const key_type& key)
    {
        return { iterator(get_floor(key)), iterator(get_ceiling(key)) };
    }

    ctl::pair<const_iterator, const_iterator> equal_range(
      const key_type& key) const
    {
        return { const_iterator(get_floor(key)),
                 const_iterator(get_ceiling(key)) };
    }

    /// First element not ordered before key.
    iterator lower_bound(const key_type& key)
    {
        return iterator(get_floor(key));
    }

    const_iterator lower_bound(const key_type& key) const
    {
        return const_iterator(get_floor(key));
    }

    /// First element ordered after key.
    iterator upper_bound(const key_type& key)
    {
        return iterator(get_ceiling(key));
    }

    const_iterator upper_bound(const key_type& key) const
    {
        return const_iterator(get_ceiling(key));
    }

    size_type count(const key_type& key) const
    {
        return !!get_element(key);
    }

    iterator find(const key_type& key)
    {
        return iterator(get_element(key));
    }

    const_iterator find(const key_type& key) const
    {
        return const_iterator(get_element(key));
    }

    // The hint is accepted for API compatibility but not used.
    template<class... Args>
    iterator emplace_hint(const_iterator hint, Args&&... args)
    {
        return emplace(ctl::forward<Args>(args)...).first;
    }

    /// Debug validator: traps if any red-black or search-tree invariant is
    /// violated, or if size() disagrees with the actual node count.
    void check() const
    {
        size_type count = 0;
        if (root_ != nullptr) {
            if (root_->is_red())
                // ILLEGAL TREE: root node must be black
                __builtin_trap();
            int black_height = -1;
            checker(root_, nullptr, 0, black_height);
            count = tally(root_);
        }
        if (count != size_)
            // ILLEGAL TREE: unexpected number of nodes
            __builtin_trap();
    }

  private:
    // Minimum node of the subtree; null-safe.
    static node_type* leftmost(node_type* node) noexcept
    {
        while (node && node->left())
            node = node->left();
        return node;
    }

    // Maximum node of the subtree; null-safe.
    static node_type* rightmost(node_type* node) noexcept
    {
        while (node && node->right)
            node = node->right;
        return node;
    }

    // Frees a subtree. It recurses on left children and loops on right
    // children, so recursion depth follows the count of left links only.
    static optimizesize void clearer(node_type* node) noexcept
    {
        node_type* right;
        for (; node; node = right) {
            right = node->right;
            clearer(node->left());
            delete node;
        }
    }

    // Deep-copies a subtree, preserving shape and node colors.
    static optimizesize node_type* copier(const node_type* node)
    {
        if (node == nullptr)
            return nullptr;
        node_type* new_node = new node_type(node->value);
        // The node constructor always makes a node red. Carry the source
        // color across, or the copy breaks the red-black invariants.
        new_node->is_red(node->is_red());
        new_node->left(copier(node->left()));
        new_node->right = copier(node->right);
        if (new_node->left())
            new_node->left()->parent = new_node;
        if (new_node->right)
            new_node->right->parent = new_node;
        return new_node;
    }

    // Counts the nodes in a subtree (used by check()).
    static optimizesize size_type tally(const node_type* node)
    {
        if (node == nullptr)
            return 0;
        return 1 + tally(node->left()) + tally(node->right);
    }

    // Node equivalent to key, or null.
    template<typename K>
    node_type* get_element(const K& key) const
    {
        node_type* current = root_;
        while (current != nullptr) {
            if (comp_(key, current->value)) {
                current = current->left();
            } else if (comp_(current->value, key)) {
                current = current->right;
            } else {
                return current;
            }
        }
        return nullptr;
    }

    // lower_bound: first node whose value is not ordered before key.
    template<typename K>
    node_type* get_floor(const K& key) const
    {
        node_type* result = nullptr;
        node_type* current = root_;
        while (current != nullptr) {
            if (!comp_(current->value, key)) {
                result = current;
                current = current->left();
            } else {
                current = current->right;
            }
        }
        return result;
    }

    // upper_bound: first node whose value is ordered after key.
    template<typename K>
    node_type* get_ceiling(const K& key) const
    {
        node_type* result = nullptr;
        node_type* current = root_;
        while (current != nullptr) {
            if (comp_(key, current->value)) {
                result = current;
                current = current->left();
            } else {
                current = current->right;
            }
        }
        return result;
    }

    // Links a freshly allocated node into the tree, or deletes it when an
    // equivalent key already exists. Takes ownership of node.
    optimizesize ctl::pair<iterator, bool> insert_node(node_type* node)
    {
        if (root_ == nullptr) {
            root_ = node;
            root_->is_red(false);
            size_++;
            return { iterator(root_), true };
        }
        node_type* current = root_;
        node_type* parent = nullptr;
        while (current != nullptr) {
            parent = current;
            if (comp_(node->value, current->value)) {
                current = current->left();
            } else if (comp_(current->value, node->value)) {
                current = current->right;
            } else {
                delete node; // already exists
                return { iterator(current), false };
            }
        }
        if (comp_(node->value, parent->value)) {
            parent->left(node);
        } else {
            parent->right = node;
        }
        node->parent = parent;
        size_++;
        rebalance_after_insert(node);
        return { iterator(node), true };
    }

    // Unlinks and frees node, restoring red-black invariants if a black
    // node was removed from its position.
    optimizesize void erase_node(node_type* node)
    {
        node_type* y = node;
        node_type* x = nullptr;
        node_type* x_parent = nullptr;
        bool y_original_color = y->is_red();
        if (node->left() == nullptr) {
            x = node->right;
            transplant(node, node->right);
            x_parent = node->parent;
        } else if (node->right == nullptr) {
            x = node->left();
            transplant(node, node->left());
            x_parent = node->parent;
        } else {
            // Two children: splice out the in-order successor y and put it
            // in node's place.
            y = leftmost(node->right);
            y_original_color = y->is_red();
            x = y->right;
            if (y->parent == node) {
                if (x)
                    x->parent = y;
                x_parent = y;
            } else {
                transplant(y, y->right);
                y->right = node->right;
                y->right->parent = y;
                x_parent = y->parent;
            }
            transplant(node, y);
            y->left(node->left());
            y->left()->parent = y;
            y->is_red(node->is_red());
        }
        if (!y_original_color)
            rebalance_after_erase(x, x_parent);
        delete node;
        --size_;
    }

    optimizesize void left_rotate(node_type* x)
    {
        node_type* y = x->right;
        x->right = y->left();
        if (y->left() != nullptr)
            y->left()->parent = x;
        y->parent = x->parent;
        if (x->parent == nullptr) {
            root_ = y;
        } else if (x == x->parent->left()) {
            x->parent->left(y);
        } else {
            x->parent->right = y;
        }
        y->left(x);
        x->parent = y;
    }

    optimizesize void right_rotate(node_type* y)
    {
        node_type* x = y->left();
        y->left(x->right);
        if (x->right != nullptr)
            x->right->parent = y;
        x->parent = y->parent;
        if (y->parent == nullptr) {
            root_ = x;
        } else if (y == y->parent->right) {
            y->parent->right = x;
        } else {
            y->parent->left(x);
        }
        x->right = y;
        y->parent = x;
    }

    // Replaces subtree u with subtree v in u's parent. u's own links are
    // left unchanged.
    optimizesize void transplant(node_type* u, node_type* v)
    {
        if (u->parent == nullptr) {
            root_ = v;
        } else if (u == u->parent->left()) {
            u->parent->left(v);
        } else {
            u->parent->right = v;
        }
        if (v != nullptr)
            v->parent = u->parent;
    }

    // Recursive invariant checker for check(). It verifies parent links,
    // strict ordering against the parent, no red-red edges, and an equal
    // black height on every root-to-leaf path.
    optimizesize void checker(const node_type* node,
                              const node_type* parent,
                              int black_count,
                              int& black_height) const
    {
        if (node == nullptr) {
            // Leaf nodes are considered black
            if (black_height == -1) {
                black_height = black_count;
            } else if (black_count != black_height) {
                // ILLEGAL TREE: Black height mismatch
                __builtin_trap();
            }
            return;
        }
        if (node->parent != parent)
            // ILLEGAL TREE: Parent link is incorrect
            __builtin_trap();
        if (parent) {
            if (parent->left() == node && !comp_(node->value, parent->value))
                // ILLEGAL TREE: Binary search property violated on left child
                __builtin_trap();
            if (parent->right == node && !comp_(parent->value, node->value))
                // ILLEGAL TREE: Binary search property violated on right child
                __builtin_trap();
        }
        if (!node->is_red()) {
            black_count++;
        } else if (parent != nullptr && parent->is_red()) {
            // ILLEGAL TREE: Red node has red child
            __builtin_trap();
        }
        checker(node->left(), node, black_count, black_height);
        checker(node->right, node, black_count, black_height);
    }

    // Standard red-black insert fix-up: recolor while the uncle is red,
    // otherwise rotate once or twice around the grandparent.
    optimizesize void rebalance_after_insert(node_type* node)
    {
        node->is_red(true);
        while (node != root_ && node->parent->is_red()) {
            if (node->parent == node->parent->parent->left()) {
                node_type* uncle = node->parent->parent->right;
                if (uncle && uncle->is_red()) {
                    node->parent->is_red(false);
                    uncle->is_red(false);
                    node->parent->parent->is_red(true);
                    node = node->parent->parent;
                } else {
                    if (node == node->parent->right) {
                        node = node->parent;
                        left_rotate(node);
                    }
                    node->parent->is_red(false);
                    node->parent->parent->is_red(true);
                    right_rotate(node->parent->parent);
                }
            } else {
                node_type* uncle = node->parent->parent->left();
                if (uncle && uncle->is_red()) {
                    node->parent->is_red(false);
                    uncle->is_red(false);
                    node->parent->parent->is_red(true);
                    node = node->parent->parent;
                } else {
                    if (node == node->parent->left()) {
                        node = node->parent;
                        right_rotate(node);
                    }
                    node->parent->is_red(false);
                    node->parent->parent->is_red(true);
                    left_rotate(node->parent->parent);
                }
            }
        }
        root_->is_red(false);
    }

    // Standard red-black erase fix-up. node may be null (an empty slot),
    // which is why its parent is passed separately.
    optimizesize void rebalance_after_erase(node_type* node, node_type* parent)
    {
        while (node != root_ && (node == nullptr || !node->is_red())) {
            if (node == parent->left()) {
                node_type* sibling = parent->right;
                if (sibling->is_red()) {
                    sibling->is_red(false);
                    parent->is_red(true);
                    left_rotate(parent);
                    sibling = parent->right;
                }
                if ((sibling->left() == nullptr ||
                     !sibling->left()->is_red()) &&
                    (sibling->right == nullptr || !sibling->right->is_red())) {
                    sibling->is_red(true);
                    node = parent;
                    parent = node->parent;
                } else {
                    if (sibling->right == nullptr ||
                        !sibling->right->is_red()) {
                        sibling->left()->is_red(false);
                        sibling->is_red(true);
                        right_rotate(sibling);
                        sibling = parent->right;
                    }
                    sibling->is_red(parent->is_red());
                    parent->is_red(false);
                    sibling->right->is_red(false);
                    left_rotate(parent);
                    node = root_;
                    break;
                }
            } else {
                node_type* sibling = parent->left();
                if (sibling->is_red()) {
                    sibling->is_red(false);
                    parent->is_red(true);
                    right_rotate(parent);
                    sibling = parent->left();
                }
                if ((sibling->right == nullptr || !sibling->right->is_red()) &&
                    (sibling->left() == nullptr ||
                     !sibling->left()->is_red())) {
                    sibling->is_red(true);
                    node = parent;
                    parent = node->parent;
                } else {
                    if (sibling->left() == nullptr ||
                        !sibling->left()->is_red()) {
                        sibling->right->is_red(false);
                        sibling->is_red(true);
                        left_rotate(sibling);
                        sibling = parent->left();
                    }
                    sibling->is_red(parent->is_red());
                    parent->is_red(false);
                    sibling->left()->is_red(false);
                    right_rotate(parent);
                    node = root_;
                    break;
                }
            }
        }
        if (node != nullptr)
            node->is_red(false);
    }

    node_type* root_;
    size_type size_;
    Compare comp_;
};
// Two sets are equal when they have the same size and their elements,
// visited in order, compare equal pairwise with operator==.
template<class Key, typename Compare = ctl::less<Key>>
bool
operator==(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
    // Different sizes can never be equal; this also guarantees the
    // right-hand walk below never runs past its end.
    if (lhs.size() != rhs.size())
        return false;
    auto left_it = lhs.cbegin();
    auto right_it = rhs.cbegin();
    while (left_it != lhs.end()) {
        if (!(*left_it == *right_it))
            return false;
        ++left_it;
        ++right_it;
    }
    return true;
}
// Lexicographical comparison of the in-order element sequences, ordering
// elements with a default-constructed Compare.
template<class Key, typename Compare = ctl::less<Key>>
bool
operator<(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
    auto i = lhs.cbegin();
    auto j = rhs.cbegin();
    while (i != lhs.end() && j != rhs.end()) {
        if (Compare()(*i, *j))
            return true;
        if (Compare()(*j, *i))
            return false;
        ++i, ++j;
    }
    // Every shared position was equivalent, so lhs is smaller only if it
    // ran out first while rhs still has elements.
    return j != rhs.end() && i == lhs.end();
}
// Negation of operator==.
template<class Key, typename Compare = ctl::less<Key>>
bool
operator!=(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
    const bool same = lhs == rhs;
    return !same;
}
// lhs <= rhs holds exactly when rhs is not ordered before lhs.
template<class Key, typename Compare = ctl::less<Key>>
bool
operator<=(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
    if (rhs < lhs)
        return false;
    return true;
}
// lhs > rhs is defined as rhs < lhs.
template<class Key, typename Compare = ctl::less<Key>>
bool
operator>(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
    const bool rhs_first = rhs < lhs;
    return rhs_first;
}
// lhs >= rhs holds exactly when lhs is not ordered before rhs.
template<class Key, typename Compare = ctl::less<Key>>
bool
operator>=(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
    if (lhs < rhs)
        return false;
    return true;
}
// Non-member swap for sets; exchanges the contents of lhs and rhs by
// delegating to set::swap.
template<class Key, typename Compare = ctl::less<Key>>
void
swap(set<Key, Compare>& lhs, set<Key, Compare>& rhs) noexcept
{
    lhs.swap(rhs);
}
} // namespace ctl
#endif // CTL_SET_H_