Basic CTL shared_ptr implementation (#1267)

This commit is contained in:
Steven Dee (Jōshin) 2024-08-31 14:00:56 -04:00 committed by GitHub
parent a6fe62cf13
commit e1528a71e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 708 additions and 0 deletions

View file

@ -17,6 +17,9 @@ struct conditional<false, T, F>
typedef F type;
};
template<bool B, typename T, typename F>
using conditional_t = typename conditional<B, T, F>::type;
} // namespace ctl
#endif // CTL_CONDITIONAL_H_

View file

@ -19,6 +19,9 @@ template<typename _Tp>
struct is_void : public is_void_<typename ctl::remove_cv<_Tp>::type>::type
{};
template<typename T>
inline constexpr bool is_void_v = is_void<T>::value;
} // namespace ctl
#endif // CTL_IS_VOID_H_

454
ctl/shared_ptr.h Normal file
View file

@ -0,0 +1,454 @@
// -*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
#ifndef CTL_SHARED_PTR_H_
#define CTL_SHARED_PTR_H_
#include "exception.h"
#include "is_convertible.h"
#include "remove_extent.h"
#include "unique_ptr.h"
// XXX currently needed to use placement-new syntax (move to cxx.inc?)
void*
operator new(size_t, void*) noexcept;
namespace ctl {
class bad_weak_ptr : public exception
{
public:
const char* what() const noexcept override
{
return "ctl::bad_weak_ptr";
}
};
namespace __ {
template<typename T>
struct ptr_ref
{
using type = T&;
};
template<>
struct ptr_ref<void>
{
using type = void;
};
static inline __attribute__((always_inline)) void
incref(size_t* r) noexcept
{
#ifdef NDEBUG
__atomic_fetch_add(r, 1, __ATOMIC_RELAXED);
#else
size_t refs = __atomic_fetch_add(r, 1, __ATOMIC_RELAXED);
if (refs > ((size_t)-1) >> 1)
__builtin_trap();
#endif
}
static inline __attribute__((always_inline)) bool
decref(size_t* r) noexcept
{
if (!__atomic_fetch_sub(r, 1, __ATOMIC_RELEASE)) {
__atomic_thread_fence(__ATOMIC_ACQUIRE);
return true;
}
return false;
}
class shared_ref
{
public:
constexpr shared_ref() noexcept = default;
shared_ref(const shared_ref&) = delete;
shared_ref& operator=(const shared_ref&) = delete;
virtual ~shared_ref() = default;
void keep_shared() noexcept
{
incref(&shared);
}
void drop_shared() noexcept
{
if (decref(&shared)) {
dispose();
drop_weak();
}
}
void keep_weak() noexcept
{
incref(&weak);
}
void drop_weak() noexcept
{
if (decref(&weak)) {
delete this;
}
}
size_t use_count() const noexcept
{
return shared + 1;
}
size_t weak_count() const noexcept
{
return weak;
}
private:
virtual void dispose() noexcept = 0;
size_t shared = 0;
size_t weak = 0;
};
template<typename T, typename D>
class shared_pointer : public shared_ref
{
public:
static shared_pointer* make(T* const p, D d)
{
return make(unique_ptr<T, D>(p, move(d)));
}
static shared_pointer* make(unique_ptr<T, D> p)
{
return new shared_pointer(p.release(), move(p.get_deleter()));
}
private:
shared_pointer(T* const p, D d) noexcept : p(p), d(move(d))
{
}
void dispose() noexcept override
{
move(d)(p);
}
T* const p;
[[no_unique_address]] D d;
};
template<typename T>
class shared_emplace : public shared_ref
{
public:
union
{
T t;
};
template<typename... Args>
void construct(Args&&... args)
{
::new (&t) T(forward<Args>(args)...);
}
static unique_ptr<shared_emplace> make()
{
return unique_ptr(new shared_emplace());
}
private:
explicit constexpr shared_emplace() noexcept = default;
void dispose() noexcept override
{
t.~T();
}
};
template<typename T, typename U>
concept shared_ptr_compatible = is_convertible_v<U*, T*>;
} // namespace __
template<typename T>
class weak_ptr;
template<typename T>
class shared_ptr
{
public:
using element_type = remove_extent_t<T>;
using weak_type = weak_ptr<T>;
constexpr shared_ptr() noexcept = default;
constexpr shared_ptr(nullptr_t) noexcept
{
}
template<typename U>
requires __::shared_ptr_compatible<T, U>
explicit shared_ptr(U* const p) : shared_ptr(p, default_delete<U>())
{
}
template<typename U, typename D>
requires __::shared_ptr_compatible<T, U>
shared_ptr(U* const p, D d)
: p(p), rc(__::shared_pointer<U, D>::make(p, move(d)))
{
}
template<typename U>
shared_ptr(const shared_ptr<U>& r, element_type* p) noexcept
: p(p), rc(r.rc)
{
if (rc)
rc->keep_shared();
}
template<typename U>
shared_ptr(shared_ptr<U>&& r, element_type* p) noexcept : p(p), rc(r.rc)
{
r.p = nullptr;
r.rc = nullptr;
}
template<typename U>
requires __::shared_ptr_compatible<T, U>
shared_ptr(const shared_ptr<U>& r) noexcept : p(r.p), rc(r.rc)
{
if (rc)
rc->keep_shared();
}
template<typename U>
requires __::shared_ptr_compatible<T, U>
shared_ptr(shared_ptr<U>&& r) noexcept : p(r.p), rc(r.rc)
{
r.p = nullptr;
r.rc = nullptr;
}
shared_ptr(const shared_ptr& r) noexcept : p(r.p), rc(r.rc)
{
if (rc)
rc->keep_shared();
}
shared_ptr(shared_ptr&& r) noexcept : p(r.p), rc(r.rc)
{
r.p = nullptr;
r.rc = nullptr;
}
template<typename U>
requires __::shared_ptr_compatible<T, U>
explicit shared_ptr(const weak_ptr<U>& r) : p(r.p), rc(r.rc)
{
if (r.expired()) {
throw bad_weak_ptr();
}
rc->keep_shared();
}
template<typename U, typename D>
requires __::shared_ptr_compatible<T, U>
shared_ptr(unique_ptr<U, D>&& r)
: p(r.p), rc(__::shared_pointer<U, D>::make(move(r)))
{
}
~shared_ptr()
{
if (rc)
rc->drop_shared();
}
shared_ptr& operator=(shared_ptr r) noexcept
{
swap(r);
return *this;
}
template<typename U>
requires __::shared_ptr_compatible<T, U>
shared_ptr& operator=(shared_ptr<U> r) noexcept
{
shared_ptr<T>(move(r)).swap(*this);
return *this;
}
void reset() noexcept
{
shared_ptr().swap(*this);
}
template<typename U>
requires __::shared_ptr_compatible<T, U>
void reset(U* const p2)
{
shared_ptr<T>(p2).swap(*this);
}
template<typename U, typename D>
requires __::shared_ptr_compatible<T, U>
void reset(U* const p2, D d)
{
shared_ptr<T>(p2, d).swap(*this);
}
void swap(shared_ptr& r) noexcept
{
using ctl::swap;
swap(p, r.p);
swap(rc, r.rc);
}
element_type* get() const noexcept
{
return p;
}
typename __::ptr_ref<T>::type operator*() const noexcept
{
if (!p)
__builtin_trap();
return *p;
}
T* operator->() const noexcept
{
if (!p)
__builtin_trap();
return p;
}
long use_count() const noexcept
{
return rc ? rc->use_count() : 0;
}
explicit operator bool() const noexcept
{
return p;
}
template<typename U>
bool owner_before(const shared_ptr<U>& r) const noexcept
{
return p < r.p;
}
template<typename U>
bool owner_before(const weak_ptr<U>& r) const noexcept
{
return !r.owner_before(*this);
}
private:
template<typename U>
friend class weak_ptr;
template<typename U>
friend class shared_ptr;
template<typename U, typename... Args>
friend shared_ptr<U> make_shared(Args&&... args);
element_type* p = nullptr;
__::shared_ref* rc = nullptr;
};
template<typename T>
class weak_ptr
{
public:
using element_type = remove_extent_t<T>;
constexpr weak_ptr() noexcept = default;
template<typename U>
requires __::shared_ptr_compatible<T, U>
weak_ptr(const shared_ptr<U>& r) noexcept : p(r.p), rc(r.rc)
{
if (rc)
rc->keep_weak();
}
~weak_ptr()
{
if (rc)
rc->drop_weak();
}
long use_count() const noexcept
{
return rc ? rc->use_count() : 0;
}
bool expired() const noexcept
{
return !use_count();
}
void reset() noexcept
{
weak_ptr().swap(*this);
}
void swap(weak_ptr& r) noexcept
{
using ctl::swap;
swap(p, r.p);
swap(rc, r.rc);
}
shared_ptr<T> lock() const noexcept
{
if (expired())
return nullptr;
shared_ptr<T> r;
r.p = p;
r.rc = rc;
if (rc)
rc->keep_shared();
return r;
}
template<typename U>
bool owner_before(const weak_ptr<U>& r) const noexcept
{
return p < r.p;
}
template<typename U>
bool owner_before(const shared_ptr<U>& r) const noexcept
{
return p < r.p;
}
private:
template<typename U>
friend class shared_ptr;
element_type* p = nullptr;
__::shared_ref* rc = nullptr;
};
template<typename T, typename... Args>
shared_ptr<T>
make_shared(Args&&... args)
{
auto rc = __::shared_emplace<T>::make();
rc->construct(forward<Args>(args)...);
shared_ptr<T> r;
r.p = &rc->t;
r.rc = rc.release();
return r;
}
} // namespace ctl
#endif // CTL_SHARED_PTR_H_

248
test/ctl/shared_ptr_test.cc Normal file
View file

@ -0,0 +1,248 @@
// -*- mode:c++; indent-tabs-mode:nil; c-basic-offset:4; coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Justine Alexandra Roberts Tunney
//
// Permission to use, copy, modify, and/or distribute this software for
// any purpose with or without fee is hereby granted, provided that the
// above copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL
// WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE
// AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL
// DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR
// PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER
// TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
// PERFORMANCE OF THIS SOFTWARE.
#include "ctl/shared_ptr.h"
#include "ctl/vector.h"
#include "libc/mem/leaks.h"
// #include <memory>
// #include <vector>
// #define ctl std
using ctl::bad_weak_ptr;
using ctl::make_shared;
using ctl::move;
using ctl::shared_ptr;
using ctl::unique_ptr;
using ctl::vector;
using ctl::weak_ptr;
#undef ctl
static int g = 0;
struct ConstructG
{
ConstructG()
{
++g;
}
};
struct DestructG
{
~DestructG()
{
++g;
}
};
struct CallG
{
void operator()(auto*) const noexcept
{
++g;
}
};
struct Base
{};
struct Derived : Base
{};
int
main()
{
int a, b;
{
// Shouldn't cause memory leaks.
shared_ptr<int> x(new int(5));
}
{
// Objects get destroyed when the last shared_ptr is reset.
shared_ptr<int> x(&a, CallG());
shared_ptr<int> y(x);
x.reset();
if (g)
return 1;
y.reset();
if (g != 1)
return 2;
}
{
g = 0;
// Weak pointers don't prevent object destruction.
shared_ptr<int> x(&a, CallG());
weak_ptr<int> y(x);
x.reset();
if (g != 1)
return 3;
}
{
g = 0;
// Weak pointers can be promoted to shared pointers.
shared_ptr<int> x(&a, CallG());
weak_ptr<int> y(x);
auto z = y.lock();
x.reset();
if (g)
return 4;
y.reset();
if (g)
return 5;
z.reset();
if (g != 1)
return 6;
}
{
// Shared null pointers are falsey.
shared_ptr<int> x;
if (x)
return 7;
x.reset(new int);
if (!x)
return 8;
}
{
// You can cast a shared pointer validly.
shared_ptr<Derived> x(new Derived);
shared_ptr<Base> y(x);
// But not invalidly:
// shared_ptr<Base> x(new Derived);
// shared_ptr<Derived> y(x);
}
{
// You can cast a shared pointer to void to retain a reference.
shared_ptr<int> x(new int);
shared_ptr<void> y(x);
}
{
// You can also create a shared pointer to void in the first place.
shared_ptr<void> x(new int);
}
{
// You can take a shared pointer to a subobject, and it will free the
// base object.
shared_ptr<vector<int>> x(new vector<int>);
x->push_back(5);
shared_ptr<int> y(x, &x->at(0));
x.reset();
if (*y != 5)
return 9;
}
{
g = 0;
// You can create a shared_ptr from a unique_ptr.
unique_ptr<int, CallG> x(&a, CallG());
shared_ptr<int> y(move(x));
if (x)
return 10;
y.reset();
if (g != 1)
return 11;
}
{
g = 0;
// You can reassign shared_ptrs.
shared_ptr<int> x(&a, CallG());
shared_ptr<int> y;
y = x;
x.reset();
if (g)
return 12;
y.reset();
if (g != 1)
return 13;
}
{
// owner_before works across shared and weak pointers.
shared_ptr<int> x(&a, CallG());
shared_ptr<int> y(&b, CallG());
if (!x.owner_before(y))
return 14;
if (!x.owner_before(weak_ptr<int>(y)))
return 15;
}
{
// Use counts work like you'd expect
shared_ptr<int> x(new int);
if (x.use_count() != 1)
return 16;
shared_ptr<int> y(x);
if (x.use_count() != 2 || y.use_count() != 2)
return 17;
x.reset();
if (x.use_count() != 0 || y.use_count() != 1)
return 18;
}
{
// There is a make_shared that will allocate an object for you safely.
auto x = make_shared<int>(5);
if (!x)
return 19;
if (*x != 5)
return 20;
}
{
// Expired weak pointers lock to nullptr, and throw when promoted to
// shared pointer by constructor.
auto x = make_shared<int>();
weak_ptr<int> y(x);
x.reset();
if (y.lock())
return 21;
int caught = 0;
try {
shared_ptr<int> z(y);
} catch (bad_weak_ptr& e) {
caught = 1;
}
if (!caught)
return 22;
}
{
// nullptr is always expired.
shared_ptr<int> x(nullptr);
weak_ptr<int> y(x);
if (!y.expired())
return 23;
}
// TODO(mrdomino): exercise threads / races. The reference count should be
// atomically maintained.
CheckForMemoryLeaks();
return 0;
}