Introduce ctl::set and ctl::map

We now have a C++ red-black tree implementation that implements standard
template library compatible APIs while compiling 10x faster than libcxx.
It's not as beautiful as the red-black tree implementation in Plinko but
this will get the job done and the test proves it upholds all invariants

This change also restores CheckForMemoryLeaks() support and fixes a real
actual bug I discovered with Doug Lea's dlmalloc_inspect_all() function.
This commit is contained in:
Justine Tunney 2024-06-23 10:08:48 -07:00
parent 388e236360
commit c4c812c154
No known key found for this signature in database
GPG key ID: BE714B4575D6E328
45 changed files with 2358 additions and 135 deletions

64
ctl/equal.h Normal file
View file

@ -0,0 +1,64 @@
// -*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
#ifndef CTL_EQUAL_H_
#define CTL_EQUAL_H_
namespace ctl {
template<class InputIt1, class InputIt2>
bool
equal(InputIt1 first1, InputIt1 last1, InputIt2 first2)
{
for (; first1 != last1; ++first1, ++first2) {
if (!(*first1 == *first2)) {
return false;
}
}
return true;
}
// Overload that takes a predicate
template<class InputIt1, class InputIt2, class BinaryPredicate>
bool
equal(InputIt1 first1, InputIt1 last1, InputIt2 first2, BinaryPredicate pred)
{
for (; first1 != last1; ++first1, ++first2) {
if (!pred(*first1, *first2)) {
return false;
}
}
return true;
}
// Overloads that take two ranges (C++14 and later)
template<class InputIt1, class InputIt2>
bool
equal(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2)
{
for (; first1 != last1 && first2 != last2; ++first1, ++first2) {
if (!(*first1 == *first2)) {
return false;
}
}
return first1 == last1 && first2 == last2;
}
template<class InputIt1, class InputIt2, class BinaryPredicate>
bool
equal(InputIt1 first1,
InputIt1 last1,
InputIt2 first2,
InputIt2 last2,
BinaryPredicate pred)
{
for (; first1 != last1 && first2 != last2; ++first1, ++first2) {
if (!pred(*first1, *first2)) {
return false;
}
}
return first1 == last1 && first2 == last2;
}
} // namespace ctl
#endif /* CTL_EQUAL_H_ */

64
ctl/initializer_list.h Normal file
View file

@ -0,0 +1,64 @@
// -*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
#ifndef _LIBCPP_INITIALIZER_LIST
#define _LIBCPP_INITIALIZER_LIST
namespace std {
template<class T>
class initializer_list
{
constexpr initializer_list(const T* b, size_t s) noexcept
: begin_(b), size_(s)
{
}
public:
typedef T value_type;
typedef const T& reference;
typedef const T& const_reference;
typedef size_t size_type;
typedef const T* iterator;
typedef const T* const_iterator;
constexpr initializer_list() noexcept : begin_(nullptr), size_(0)
{
}
constexpr size_t size() const noexcept
{
return size_;
}
constexpr const T* begin() const noexcept
{
return begin_;
}
constexpr const T* end() const noexcept
{
return begin_ + size_;
}
private:
const T* begin_;
size_t size_;
};
template<class T>
inline constexpr const T*
begin(initializer_list<T> i) noexcept
{
return i.begin();
}
template<class T>
inline constexpr const T*
end(initializer_list<T> i) noexcept
{
return i.end();
}
} // namespace std
#endif // _LIBCPP_INITIALIZER_LIST

38
ctl/less.h Normal file
View file

@ -0,0 +1,38 @@
// -*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
#ifndef CTL_LESS_H_
#define CTL_LESS_H_
#include "utility.h"
namespace ctl {
template<class T = void>
struct less
{
constexpr bool operator()(const T& lhs, const T& rhs) const
{
return lhs < rhs;
}
typedef T first_argument_type;
typedef T second_argument_type;
typedef bool result_type;
};
template<>
struct less<void>
{
template<class T, class U>
constexpr auto operator()(T&& lhs,
U&& rhs) const -> decltype(ctl::forward<T>(lhs) <
ctl::forward<U>(rhs))
{
return ctl::forward<T>(lhs) < ctl::forward<U>(rhs);
}
typedef void is_transparent;
};
} // namespace ctl
#endif /* CTL_LESS_H_ */

352
ctl/map.h Normal file
View file

@ -0,0 +1,352 @@
// -*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
#ifndef CTL_MAP_H_
#define CTL_MAP_H_
#include "set.h"
namespace ctl {
template<typename Key, typename Value, typename Compare = ctl::less<Key>>
class map
{
class EntryCompare
{
public:
explicit EntryCompare(Compare comp = Compare()) : comp_(comp)
{
}
bool operator()(const ctl::pair<const Key, Value>& lhs,
const ctl::pair<const Key, Value>& rhs) const
{
return comp_(lhs.first, rhs.first);
}
private:
Compare comp_;
};
ctl::set<ctl::pair<const Key, Value>, EntryCompare> data_;
public:
using key_type = Key;
using mapped_type = Value;
using value_type = ctl::pair<const Key, Value>;
using size_type = typename ctl::set<value_type, EntryCompare>::size_type;
using difference_type =
typename ctl::set<value_type, EntryCompare>::difference_type;
using key_compare = Compare;
using value_compare = EntryCompare;
using iterator = typename ctl::set<value_type, EntryCompare>::iterator;
using const_iterator =
typename ctl::set<value_type, EntryCompare>::const_iterator;
using reverse_iterator =
typename ctl::set<value_type, EntryCompare>::reverse_iterator;
using const_reverse_iterator =
typename ctl::set<value_type, EntryCompare>::const_reverse_iterator;
map() : data_(EntryCompare())
{
}
explicit map(const Compare& comp) : data_(EntryCompare(comp))
{
}
map(const map& other) = default;
map(map&& other) noexcept = default;
map(std::initializer_list<value_type> init, const Compare& comp = Compare())
: data_(init, EntryCompare(comp))
{
}
template<typename InputIt>
map(InputIt first, InputIt last, const Compare& comp = Compare())
: data_(first, last, EntryCompare(comp))
{
}
map& operator=(const map& other) = default;
map& operator=(map&& other) noexcept = default;
map& operator=(std::initializer_list<value_type> ilist)
{
data_ = ilist;
return *this;
}
iterator begin() noexcept
{
return data_.begin();
}
const_iterator begin() const noexcept
{
return data_.begin();
}
const_iterator cbegin() const noexcept
{
return data_.cbegin();
}
iterator end() noexcept
{
return data_.end();
}
const_iterator end() const noexcept
{
return data_.end();
}
const_iterator cend() const noexcept
{
return data_.cend();
}
reverse_iterator rbegin() noexcept
{
return data_.rbegin();
}
const_reverse_iterator rbegin() const noexcept
{
return data_.rbegin();
}
const_reverse_iterator crbegin() const noexcept
{
return data_.crbegin();
}
reverse_iterator rend() noexcept
{
return data_.rend();
}
const_reverse_iterator rend() const noexcept
{
return data_.rend();
}
const_reverse_iterator crend() const noexcept
{
return data_.crend();
}
bool empty() const noexcept
{
return data_.empty();
}
size_type size() const noexcept
{
return data_.size();
}
size_type max_size() const noexcept
{
return data_.max_size();
}
Value& operator[](const Key& key)
{
return ((data_.insert(make_pair(key, Value()))).first)->second;
}
Value& operator[](Key&& key)
{
return ((data_.insert(make_pair(ctl::move(key), Value()))).first)
->second;
}
Value& at(const Key& key)
{
auto it = find(key);
if (it == end())
__builtin_trap();
return it->second;
}
const Value& at(const Key& key) const
{
auto it = find(key);
if (it == end())
__builtin_trap();
return it->second;
}
ctl::pair<iterator, bool> insert(const value_type& value)
{
return data_.insert(value);
}
ctl::pair<iterator, bool> insert(value_type&& value)
{
return data_.insert(ctl::move(value));
}
template<typename P>
ctl::pair<iterator, bool> insert(P&& value)
{
return data_.insert(value_type(ctl::forward<P>(value)));
}
iterator insert(const_iterator hint, const value_type& value)
{
return data_.insert(hint, value);
}
iterator insert(const_iterator hint, value_type&& value)
{
return data_.insert(hint, ctl::move(value));
}
template<typename P>
iterator insert(const_iterator hint, P&& value)
{
return data_.insert(hint, value_type(ctl::forward<P>(value)));
}
template<typename InputIt>
void insert(InputIt first, InputIt last)
{
data_.insert(first, last);
}
void insert(std::initializer_list<value_type> ilist)
{
data_.insert(ilist);
}
template<typename... Args>
ctl::pair<iterator, bool> emplace(Args&&... args)
{
return data_.emplace(ctl::forward<Args>(args)...);
}
template<typename... Args>
iterator emplace_hint(const_iterator hint, Args&&... args)
{
return data_.emplace_hint(hint, ctl::forward<Args>(args)...);
}
iterator erase(const_iterator pos)
{
return data_.erase(pos);
}
iterator erase(const_iterator first, const_iterator last)
{
return data_.erase(first, last);
}
size_type erase(const Key& key)
{
return data_.erase(make_pair(key, Value()));
}
void swap(map& other) noexcept
{
data_.swap(other.data_);
}
void clear() noexcept
{
data_.clear();
}
iterator find(const Key& key)
{
return data_.find(make_pair(key, Value()));
}
const_iterator find(const Key& key) const
{
return data_.find(make_pair(key, Value()));
}
size_type count(const Key& key) const
{
return data_.count(make_pair(key, Value()));
}
iterator lower_bound(const Key& key)
{
return data_.lower_bound(make_pair(key, Value()));
}
const_iterator lower_bound(const Key& key) const
{
return data_.lower_bound(make_pair(key, Value()));
}
iterator upper_bound(const Key& key)
{
return data_.upper_bound(make_pair(key, Value()));
}
const_iterator upper_bound(const Key& key) const
{
return data_.upper_bound(make_pair(key, Value()));
}
ctl::pair<iterator, iterator> equal_range(const Key& key)
{
return data_.equal_range(make_pair(key, Value()));
}
ctl::pair<const_iterator, const_iterator> equal_range(const Key& key) const
{
return data_.equal_range(make_pair(key, Value()));
}
key_compare key_comp() const
{
return data_.value_comp().comp;
}
value_compare value_comp() const
{
return data_.value_comp();
}
friend bool operator==(const map& lhs, const map& rhs)
{
return lhs.data_ == rhs.data_;
}
friend bool operator!=(const map& lhs, const map& rhs)
{
return !(lhs == rhs);
}
friend bool operator<(const map& lhs, const map& rhs)
{
return lhs.data_ < rhs.data_;
}
friend bool operator<=(const map& lhs, const map& rhs)
{
return !(rhs < lhs);
}
friend bool operator>(const map& lhs, const map& rhs)
{
return rhs < lhs;
}
friend bool operator>=(const map& lhs, const map& rhs)
{
return !(lhs < rhs);
}
friend void swap(map& lhs, map& rhs) noexcept
{
lhs.swap(rhs);
}
};
} // namespace ctl
#endif // CTL_MAP_H_

View file

@ -34,7 +34,10 @@ _ctl_alloc(size_t n, size_t a)
static void*
_ctl_alloc1(size_t n)
{
return _ctl_alloc(n, 1);
void* p;
if (!(p = malloc(n)))
__builtin_trap();
return p;
}
static void*
@ -68,10 +71,21 @@ COSMOPOLITAN_C_END_
now. If you have any brain cells left after reading this comment then go
look at the eight operator delete weak references to free in the below. */
__weak_reference(void* operator new(size_t, ctl::align_val_t), _ctl_alloc);
__weak_reference(void* operator new[](size_t, ctl::align_val_t), _ctl_alloc);
__weak_reference(void* operator new(size_t), _ctl_alloc1);
__weak_reference(void* operator new[](size_t), _ctl_alloc1);
__weak_reference(void*
operator new(size_t, ctl::align_val_t),
_ctl_alloc);
__weak_reference(void*
operator new[](size_t, ctl::align_val_t),
_ctl_alloc);
__weak_reference(void*
operator new(size_t),
_ctl_alloc1);
__weak_reference(void*
operator new[](size_t),
_ctl_alloc1);
// XXX clang-format currently mutilates these for some reason.
// clang-format off

View file

@ -3,6 +3,10 @@
#ifndef COSMOPOLITAN_CTL_NEW_H_
#define COSMOPOLITAN_CTL_NEW_H_
#ifndef TINY
__static_yoink("__demangle");
#endif
namespace ctl {
enum class align_val_t : size_t

150
ctl/pair.h Normal file
View file

@ -0,0 +1,150 @@
// -*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
#ifndef CTL_PAIR_H_
#define CTL_PAIR_H_
#include "utility.h"
namespace ctl {
template<typename T1, typename T2>
struct pair
{
using first_type = T1;
using second_type = T2;
T1 first;
T2 second;
constexpr pair() : first(), second()
{
}
constexpr pair(const T1& x, const T2& y) : first(x), second(y)
{
}
constexpr pair(const pair<T1, T2>& p) : first(p.first), second(p.second)
{
}
template<class U1, class U2>
constexpr pair(U1&& x, U2&& y)
: first(ctl::forward<U1>(x)), second(ctl::forward<U2>(y))
{
}
template<class U1, class U2>
constexpr pair(const pair<U1, U2>& p) : first(p.first), second(p.second)
{
}
constexpr pair(pair<T1, T2>&& p)
: first(ctl::move(p.first)), second(ctl::move(p.second))
{
}
template<class U1, class U2>
constexpr pair(pair<U1, U2>&& p)
: first(ctl::move(p.first)), second(ctl::move(p.second))
{
}
pair& operator=(const pair& other)
{
first = other.first;
second = other.second;
return *this;
}
pair& operator=(pair&& other) noexcept
{
first = ctl::move(other.first);
second = ctl::move(other.second);
return *this;
}
template<class U1, class U2>
pair& operator=(const pair<U1, U2>& other)
{
first = other.first;
second = other.second;
return *this;
}
template<class U1, class U2>
pair& operator=(pair<U1, U2>&& other)
{
first = ctl::move(other.first);
second = ctl::move(other.second);
return *this;
}
void swap(pair& other) noexcept
{
using ctl::swap;
swap(first, other.first);
swap(second, other.second);
}
};
template<class T1, class T2>
constexpr pair<T1, T2>
make_pair(T1&& t, T2&& u)
{
return pair<T1, T2>(ctl::forward<T1>(t), ctl::forward<T2>(u));
}
// Comparison operators
template<class T1, class T2>
constexpr bool
operator==(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs)
{
return lhs.first == rhs.first && lhs.second == rhs.second;
}
template<class T1, class T2>
constexpr bool
operator!=(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs)
{
return !(lhs == rhs);
}
template<class T1, class T2>
constexpr bool
operator<(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs)
{
return lhs.first < rhs.first ||
(!(rhs.first < lhs.first) && lhs.second < rhs.second);
}
template<class T1, class T2>
constexpr bool
operator<=(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs)
{
return !(rhs < lhs);
}
template<class T1, class T2>
constexpr bool
operator>(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs)
{
return rhs < lhs;
}
template<class T1, class T2>
constexpr bool
operator>=(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs)
{
return !(lhs < rhs);
}
template<class T1, class T2>
void
swap(pair<T1, T2>& lhs, pair<T1, T2>& rhs) noexcept(noexcept(lhs.swap(rhs)))
{
lhs.swap(rhs);
}
} // namespace ctl
#endif // CTL_PAIR_H_

945
ctl/set.h Normal file
View file

@ -0,0 +1,945 @@
// -*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
#ifndef CTL_SET_H_
#define CTL_SET_H_
#include "initializer_list.h"
#include "less.h"
#include "pair.h"
namespace ctl {
template<typename Key, typename Compare = less<Key>>
class set
{
struct rbtree
{
rbtree* left;
rbtree* right;
rbtree* parent;
bool is_red;
Key value;
rbtree(const Key& val)
: left(nullptr)
, right(nullptr)
, parent(nullptr)
, is_red(true)
, value(val)
{
}
};
public:
using key_type = Key;
using value_type = Key;
using size_type = size_t;
using node_type = rbtree;
using key_compare = Compare;
using value_compare = Compare;
using difference_type = ptrdiff_t;
using reference = value_type&;
using pointer = value_type*;
using const_pointer = const value_type*;
using const_reference = const value_type&;
class iterator
{
public:
using value_type = Key;
using difference_type = ptrdiff_t;
using pointer = Key*;
using reference = Key&;
iterator() : node_(nullptr)
{
}
reference operator*() const
{
return node_->value;
}
pointer operator->() const
{
return &(node_->value);
}
iterator& operator++()
{
if (node_ == nullptr)
return *this;
if (node_->right) {
node_ = leftmost(node_->right);
} else {
node_type* parent = node_->parent;
while (parent && node_ == parent->right) {
node_ = parent;
parent = parent->parent;
}
node_ = parent;
}
return *this;
}
iterator& operator--()
{
if (node_ == nullptr)
__builtin_trap();
if (node_->left) {
node_ = rightmost(node_->left);
} else {
node_type* parent = node_->parent;
for (;;) {
if (parent == nullptr)
break;
if (node_ == parent->left) {
node_ = parent;
parent = parent->parent;
} else {
break;
}
}
node_ = parent;
}
return *this;
}
iterator operator++(int)
{
iterator tmp = *this;
++(*this);
return tmp;
}
iterator operator--(int)
{
iterator tmp = *this;
--(*this);
return tmp;
}
bool operator==(const iterator& other) const
{
return node_ == other.node_;
}
bool operator!=(const iterator& other) const
{
return !(*this == other);
}
private:
friend class set;
node_type* node_;
explicit iterator(node_type* node) : node_(node)
{
}
};
class reverse_iterator
{
public:
using value_type = Key;
using difference_type = ptrdiff_t;
using pointer = Key*;
using reference = Key&;
reverse_iterator() : node_(nullptr)
{
}
reference operator*() const
{
return node_->value;
}
pointer operator->() const
{
return &(node_->value);
}
reverse_iterator& operator++()
{
if (node_->left) {
node_ = rightmost(node_->left);
} else {
node_type* parent = node_->parent;
while (parent && node_ == parent->left) {
node_ = parent;
parent = parent->parent;
}
node_ = parent;
}
return *this;
}
reverse_iterator operator++(int)
{
reverse_iterator tmp = *this;
++(*this);
return tmp;
}
reverse_iterator& operator--()
{
if (node_->right) {
node_ = leftmost(node_->right);
} else {
node_type* parent = node_->parent;
while (parent && node_ == parent->right) {
node_ = parent;
parent = parent->parent;
}
node_ = parent;
}
return *this;
}
reverse_iterator operator--(int)
{
reverse_iterator tmp = *this;
--(*this);
return tmp;
}
bool operator==(const reverse_iterator& other) const
{
return node_ == other.node_;
}
bool operator!=(const reverse_iterator& other) const
{
return !(*this == other);
}
iterator base() const
{
if (node_ == nullptr)
return iterator(leftmost(root_));
reverse_iterator tmp = *this;
++tmp;
return iterator(tmp.node_);
}
private:
friend class set;
node_type* node_;
explicit reverse_iterator(node_type* node) : node_(node)
{
}
};
using const_iterator = iterator;
using const_reverse_iterator = reverse_iterator;
set() : root_(nullptr), size_(0), comp_(Compare())
{
}
explicit set(const Compare& comp) : root_(nullptr), size_(0), comp_(comp)
{
}
template<class InputIt>
set(InputIt first, InputIt last, const Compare& comp = Compare())
: root_(nullptr), size_(0), comp_(comp)
{
for (; first != last; ++first)
insert(*first);
}
set(const set& other) : root_(nullptr), size_(0)
{
if (other.root_) {
root_ = copier(other.root_);
size_ = other.size_;
}
}
set(set&& other) noexcept : root_(other.root_), size_(other.size_)
{
other.root_ = nullptr;
other.size_ = 0;
}
set(std::initializer_list<value_type> init, const Compare& comp = Compare())
: root_(nullptr), size_(0), comp_(comp)
{
for (const auto& value : init)
insert(value);
}
~set()
{
clear();
}
set& operator=(const set& other)
{
if (this != &other) {
clear();
if (other.root_) {
root_ = copier(other.root_);
size_ = other.size_;
}
}
return *this;
}
set& operator=(set&& other) noexcept
{
if (this != &other) {
clear();
root_ = other.root_;
size_ = other.size_;
other.root_ = nullptr;
other.size_ = 0;
}
return *this;
}
bool empty() const noexcept
{
return size_ == 0;
}
size_type size() const noexcept
{
return size_;
}
iterator begin() noexcept
{
if (root_)
return iterator(leftmost(root_));
return iterator(nullptr);
}
const_iterator begin() const noexcept
{
if (root_)
return const_iterator(leftmost(root_));
return const_iterator(nullptr);
}
const_iterator cbegin() const noexcept
{
return begin();
}
reverse_iterator rbegin()
{
return reverse_iterator(rightmost(root_));
}
const_reverse_iterator rbegin() const
{
return const_reverse_iterator(rightmost(root_));
}
const_reverse_iterator crbegin() const
{
return const_reverse_iterator(rightmost(root_));
}
iterator end() noexcept
{
return iterator(nullptr);
}
const_iterator end() const noexcept
{
return const_iterator(nullptr);
}
const_iterator cend() const noexcept
{
return end();
}
reverse_iterator rend()
{
return reverse_iterator(nullptr);
}
const_reverse_iterator rend() const
{
return const_reverse_iterator(nullptr);
}
const_reverse_iterator crend() const
{
return const_reverse_iterator(nullptr);
}
void clear() noexcept
{
clearer(root_);
root_ = nullptr;
size_ = 0;
}
pair<iterator, bool> insert(value_type&& value)
{
return insert_node(new node_type(ctl::move(value)));
}
pair<iterator, bool> insert(const value_type& value)
{
return insert_node(new node_type(value));
}
iterator insert(const_iterator hint, const value_type& value)
{
return insert(value).first;
}
iterator insert(const_iterator hint, value_type&& value)
{
return insert(ctl::move(value)).first;
}
template<class InputIt>
void insert(InputIt first, InputIt last)
{
for (; first != last; ++first)
insert(*first);
}
void insert(std::initializer_list<value_type> ilist)
{
for (const auto& value : ilist)
insert(value);
}
template<class... Args>
pair<iterator, bool> emplace(Args&&... args)
{
value_type value(ctl::forward<Args>(args)...);
return insert(ctl::move(value));
}
iterator erase(iterator pos)
{
node_type* node = pos.node_;
iterator res = ++pos;
erase_node(node);
return res;
}
size_type erase(const key_type& key)
{
node_type* node = get_element(key);
if (node == nullptr)
return 0;
erase_node(node);
return 1;
}
iterator erase(const_iterator first, const_iterator last)
{
while (first != last)
first = erase(first);
return iterator(const_cast<node_type*>(last.node_));
}
void swap(set& other) noexcept
{
ctl::swap(root_, other.root_);
ctl::swap(size_, other.size_);
}
pair<iterator, iterator> equal_range(const key_type& key)
{
return { iterator(get_floor(key)), iterator(get_ceiling(key)) };
}
pair<const_iterator, const_iterator> equal_range(const key_type& key) const
{
return { const_iterator(get_floor(key)),
const_iterator(get_ceiling(key)) };
}
iterator lower_bound(const key_type& key)
{
return iterator(get_floor(key));
}
const_iterator lower_bound(const key_type& key) const
{
return const_iterator(get_floor(key));
}
iterator upper_bound(const key_type& key)
{
return iterator(get_ceiling(key));
}
const_iterator upper_bound(const key_type& key) const
{
return const_iterator(get_ceiling(key));
}
size_type count(const key_type& key) const
{
return !!get_element(key);
}
iterator find(const key_type& key)
{
return iterator(get_element(key));
}
const_iterator find(const key_type& key) const
{
return const_iterator(get_element(key));
}
template<class... Args>
iterator emplace_hint(const_iterator hint, Args&&... args)
{
return emplace(ctl::forward<Args>(args)...).first;
}
void check() const
{
size_type count = 0;
if (root_ != nullptr) {
if (root_->is_red)
// ILLEGAL TREE: root node must be black
__builtin_trap();
int black_height = -1;
checker(root_, nullptr, 0, black_height);
count = tally(root_);
}
if (count != size_)
// ILLEGAL TREE: unexpected number of nodes
__builtin_trap();
}
private:
static node_type* leftmost(node_type* node) noexcept
{
while (node && node->left)
node = node->left;
return node;
}
static node_type* rightmost(node_type* node) noexcept
{
while (node && node->right)
node = node->right;
return node;
}
static void clearer(node_type* node) noexcept
{
node_type* right;
for (; node; node = right) {
right = node->right;
clearer(node->left);
delete node;
}
}
static node_type* copier(const node_type* node)
{
if (node == nullptr)
return nullptr;
node_type* new_node = new node_type(node->value);
new_node->left = copier(node->left);
new_node->right = copier(node->right);
if (new_node->left)
new_node->left->parent = new_node;
if (new_node->right)
new_node->right->parent = new_node;
return new_node;
}
static size_type tally(const node_type* node)
{
if (node == nullptr)
return 0;
return 1 + tally(node->left) + tally(node->right);
}
template<typename K>
node_type* get_element(const K& key) const
{
node_type* current = root_;
while (current != nullptr) {
if (comp_(key, current->value)) {
current = current->left;
} else if (comp_(current->value, key)) {
current = current->right;
} else {
return current;
}
}
return nullptr;
}
template<typename K>
node_type* get_floor(const K& key) const
{
node_type* result = nullptr;
node_type* current = root_;
while (current != nullptr) {
if (!comp_(current->value, key)) {
result = current;
current = current->left;
} else {
current = current->right;
}
}
return result;
}
template<typename K>
node_type* get_ceiling(const K& key) const
{
node_type* result = nullptr;
node_type* current = root_;
while (current != nullptr) {
if (comp_(key, current->value)) {
result = current;
current = current->left;
} else {
current = current->right;
}
}
return result;
}
pair<iterator, bool> insert_node(node_type* node)
{
if (root_ == nullptr) {
root_ = node;
root_->is_red = false;
size_++;
return { iterator(root_), true };
}
node_type* current = root_;
node_type* parent = nullptr;
while (current != nullptr) {
parent = current;
if (comp_(node->value, current->value)) {
current = current->left;
} else if (comp_(current->value, node->value)) {
current = current->right;
} else {
delete node; // already exists
return { iterator(current), false };
}
}
if (comp_(node->value, parent->value)) {
parent->left = node;
} else {
parent->right = node;
}
node->parent = parent;
size_++;
rebalance_after_insert(node);
return { iterator(node), true };
}
void erase_node(node_type* node)
{
node_type* y = node;
node_type* x = nullptr;
node_type* x_parent = nullptr;
bool y_original_color = y->is_red;
if (node->left == nullptr) {
x = node->right;
transplant(node, node->right);
x_parent = node->parent;
} else if (node->right == nullptr) {
x = node->left;
transplant(node, node->left);
x_parent = node->parent;
} else {
y = leftmost(node->right);
y_original_color = y->is_red;
x = y->right;
if (y->parent == node) {
if (x)
x->parent = y;
x_parent = y;
} else {
transplant(y, y->right);
y->right = node->right;
y->right->parent = y;
x_parent = y->parent;
}
transplant(node, y);
y->left = node->left;
y->left->parent = y;
y->is_red = node->is_red;
}
if (!y_original_color)
rebalance_after_erase(x, x_parent);
delete node;
--size_;
}
void left_rotate(node_type* x)
{
node_type* y = x->right;
x->right = y->left;
if (y->left != nullptr)
y->left->parent = x;
y->parent = x->parent;
if (x->parent == nullptr) {
root_ = y;
} else if (x == x->parent->left) {
x->parent->left = y;
} else {
x->parent->right = y;
}
y->left = x;
x->parent = y;
}
void right_rotate(node_type* y)
{
node_type* x = y->left;
y->left = x->right;
if (x->right != nullptr)
x->right->parent = y;
x->parent = y->parent;
if (y->parent == nullptr) {
root_ = x;
} else if (y == y->parent->right) {
y->parent->right = x;
} else {
y->parent->left = x;
}
x->right = y;
y->parent = x;
}
void transplant(node_type* u, node_type* v)
{
if (u->parent == nullptr) {
root_ = v;
} else if (u == u->parent->left) {
u->parent->left = v;
} else {
u->parent->right = v;
}
if (v != nullptr)
v->parent = u->parent;
}
void checker(const node_type* node,
const node_type* parent,
int black_count,
int& black_height) const
{
if (node == nullptr) {
// Leaf nodes are considered black
if (black_height == -1) {
black_height = black_count;
} else if (black_count != black_height) {
// ILLEGAL TREE: Black height mismatch
__builtin_trap();
}
return;
}
if (node->parent != parent)
// ILLEGAL TREE: Parent link is incorrect
__builtin_trap();
if (parent) {
if (parent->left == node && !comp_(node->value, parent->value))
// ILLEGAL TREE: Binary search property violated on left child
__builtin_trap();
if (parent->right == node && !comp_(parent->value, node->value))
// ILLEGAL TREE: Binary search property violated on right child
__builtin_trap();
}
if (!node->is_red) {
black_count++;
} else if (parent != nullptr && parent->is_red) {
// ILLEGAL TREE: Red node has red child
__builtin_trap();
}
checker(node->left, node, black_count, black_height);
checker(node->right, node, black_count, black_height);
}
void rebalance_after_insert(node_type* node)
{
node->is_red = true;
while (node != root_ && node->parent->is_red) {
if (node->parent == node->parent->parent->left) {
node_type* uncle = node->parent->parent->right;
if (uncle && uncle->is_red) {
node->parent->is_red = false;
uncle->is_red = false;
node->parent->parent->is_red = true;
node = node->parent->parent;
} else {
if (node == node->parent->right) {
node = node->parent;
left_rotate(node);
}
node->parent->is_red = false;
node->parent->parent->is_red = true;
right_rotate(node->parent->parent);
}
} else {
node_type* uncle = node->parent->parent->left;
if (uncle && uncle->is_red) {
node->parent->is_red = false;
uncle->is_red = false;
node->parent->parent->is_red = true;
node = node->parent->parent;
} else {
if (node == node->parent->left) {
node = node->parent;
right_rotate(node);
}
node->parent->is_red = false;
node->parent->parent->is_red = true;
left_rotate(node->parent->parent);
}
}
}
root_->is_red = false;
}
void rebalance_after_erase(node_type* node, node_type* parent)
{
while (node != root_ && (node == nullptr || !node->is_red)) {
if (node == parent->left) {
node_type* sibling = parent->right;
if (sibling->is_red) {
sibling->is_red = false;
parent->is_red = true;
left_rotate(parent);
sibling = parent->right;
}
if ((sibling->left == nullptr || !sibling->left->is_red) &&
(sibling->right == nullptr || !sibling->right->is_red)) {
sibling->is_red = true;
node = parent;
parent = node->parent;
} else {
if (sibling->right == nullptr || !sibling->right->is_red) {
sibling->left->is_red = false;
sibling->is_red = true;
right_rotate(sibling);
sibling = parent->right;
}
sibling->is_red = parent->is_red;
parent->is_red = false;
sibling->right->is_red = false;
left_rotate(parent);
node = root_;
break;
}
} else {
node_type* sibling = parent->left;
if (sibling->is_red) {
sibling->is_red = false;
parent->is_red = true;
right_rotate(parent);
sibling = parent->left;
}
if ((sibling->right == nullptr || !sibling->right->is_red) &&
(sibling->left == nullptr || !sibling->left->is_red)) {
sibling->is_red = true;
node = parent;
parent = node->parent;
} else {
if (sibling->left == nullptr || !sibling->left->is_red) {
sibling->right->is_red = false;
sibling->is_red = true;
left_rotate(sibling);
sibling = parent->left;
}
sibling->is_red = parent->is_red;
parent->is_red = false;
sibling->left->is_red = false;
right_rotate(parent);
node = root_;
break;
}
}
}
if (node != nullptr)
node->is_red = false;
}
node_type* root_;
size_type size_;
Compare comp_;
};
template<class Key, typename Compare = less<Key>>
bool
operator==(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
if (lhs.size() != rhs.size())
return false;
auto i = lhs.cbegin();
auto j = rhs.cbegin();
for (; i != lhs.end(); ++i, ++j) {
if (!(*i == *j))
return false;
}
return true;
}
template<class Key, typename Compare = less<Key>>
bool
operator<(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
auto i = lhs.cbegin();
auto j = rhs.cbegin();
for (; i != lhs.end() && j != rhs.end(); ++i, ++j) {
if (Compare()(*i, *j))
return true;
if (Compare()(*j, *i))
return false;
}
return i == lhs.end() && j != rhs.end();
}
template<class Key, typename Compare = less<Key>>
bool
operator!=(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
return !(lhs == rhs);
}
template<class Key, typename Compare = less<Key>>
bool
operator<=(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
return !(rhs < lhs);
}
template<class Key, typename Compare = less<Key>>
bool
operator>(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
return rhs < lhs;
}
template<class Key, typename Compare = less<Key>>
bool
operator>=(const set<Key, Compare>& lhs, const set<Key, Compare>& rhs)
{
return !(lhs < rhs);
}
template<class Key, typename Compare = less<Key>>
void
swap(set<Key, Compare>& lhs, set<Key, Compare>& rhs) noexcept;
} // namespace ctl
#endif // CTL_SET_H_

View file

@ -1,7 +1,7 @@
// -*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
#ifndef COSMOPOLITAN_CTL_STRING_H_
#define COSMOPOLITAN_CTL_STRING_H_
#ifndef CTL_STRING_H_
#define CTL_STRING_H_
#include "string_view.h"
namespace ctl {
@ -397,4 +397,4 @@ operator"" s(const char* s, size_t n)
}
#pragma GCC diagnostic pop
#endif // COSMOPOLITAN_CTL_STRING_H_
#endif // CTL_STRING_H_