diff --git a/ext/tm/include/tm/optional.hpp b/ext/tm/include/tm/optional.hpp index 9ffdc4803a..f239d6d56b 100644 --- a/ext/tm/include/tm/optional.hpp +++ b/ext/tm/include/tm/optional.hpp @@ -2,7 +2,7 @@ #include #include -#include +#include #include namespace TM { @@ -20,8 +20,9 @@ class Optional { * ``` */ Optional(const T &value) - : m_present { true } - , m_value { value } { } + : m_present { true } { + new (m_value) T { value }; + } /** * Constructs a new Optional with a value. @@ -33,8 +34,9 @@ class Optional { * ``` */ Optional(T &&value) - : m_present { true } - , m_value { std::move(value) } { } + : m_present { true } { + new (m_value) T { std::move(value) }; + } /** * Constructs a new Optional without a value. @@ -61,7 +63,7 @@ class Optional { Optional(const Optional &other) : m_present { other.m_present } { if (m_present) - m_value = other.m_value; + *reinterpret_cast(m_value) = *reinterpret_cast(other.m_value); } /** @@ -75,11 +77,16 @@ class Optional { * assert_eq(obj, opt2.value()); * ``` */ - Optional(Optional &&other) - : m_present { other.m_present } { + Optional(Optional &&other) { + m_present = other.m_present; if (m_present) { - m_value = std::move(other.m_value); - other.m_present = false; + if constexpr (std::is_trivially_copyable_v) { + memcpy(m_value, other.m_value, sizeof(m_value)); + other.clear_after_trivial_copy(); + } else { + *reinterpret_cast(m_value) = std::move(*reinterpret_cast(other.m_value)); + other.clear(); + } } } @@ -101,9 +108,13 @@ class Optional { * ``` */ Optional &operator=(const Optional &other) { + if (this == &other) + return *this; + if (m_present) + clear(); m_present = other.m_present; if (m_present) - m_value = other.m_value; + *reinterpret_cast(m_value) = *reinterpret_cast(other.m_value); return *this; } @@ -121,10 +132,19 @@ class Optional { * ``` */ Optional &operator=(Optional &&other) { + if (this == &other) + return *this; + if (m_present) + clear(); m_present = other.m_present; if (m_present) { - m_value = std::move(other.m_value); - other.m_present = false; + if constexpr (std::is_trivially_copyable_v) { + memcpy(m_value, other.m_value, sizeof(m_value)); + other.clear_after_trivial_copy(); + } else { + *reinterpret_cast(m_value) = std::move(*reinterpret_cast(other.m_value)); + other.clear(); + } } return *this; } @@ -141,8 +161,10 @@ class Optional { * ``` */ Optional &operator=(T &&value) { + if (m_present) + clear(); m_present = true; - m_value = std::move(value); + *reinterpret_cast(m_value) = std::move(value); return *this; } @@ -164,7 +186,7 @@ class Optional { */ T &value() { assert(m_present); - return m_value; + return *reinterpret_cast(m_value); } /** @@ -185,7 +207,7 @@ class Optional { */ T const &value() const { assert(m_present); - return m_value; + return *reinterpret_cast(m_value); } /** @@ -233,7 +255,7 @@ class Optional { */ T value_or(std::function fallback) const { if (present()) - return value(); + return *reinterpret_cast(m_value); else return fallback(); } @@ -259,7 +281,7 @@ class Optional { */ T &operator*() { assert(m_present); - return m_value; + return *reinterpret_cast(m_value); } /** @@ -280,7 +302,7 @@ class Optional { */ T const &operator*() const { assert(m_present); - return m_value; + return *reinterpret_cast(m_value); } /** @@ -301,7 +323,7 @@ class Optional { */ T *operator->() { assert(m_present); - return &m_value; + return reinterpret_cast(m_value); } /** @@ -322,7 +344,7 @@ class Optional { */ T const *operator->() const { assert(m_present); - return &m_value; + return reinterpret_cast(m_value); } /** @@ -335,7 +357,13 @@ class Optional { * assert_not(opt); * ``` */ - void clear() { m_present = false; } + void clear() { + if (m_present) { + reinterpret_cast(m_value)->~T(); + memset(m_value, 0, sizeof(m_value)); + } + m_present = false; + } /** * Returns true if the Optional contains a value. @@ -362,9 +390,14 @@ class Optional { bool present() const { return m_present; } private: + void clear_after_trivial_copy() { + m_present = false; + memset(m_value, 0, sizeof(m_value)); + } + bool m_present; - T m_value {}; + alignas(T) unsigned char m_value[sizeof(T)] {}; }; }