diff --git a/src/common/Result.h b/src/common/Result.h index cd63a7a848..3e33052fb3 100644 --- a/src/common/Result.h +++ b/src/common/Result.h @@ -20,6 +20,7 @@ #include #include +#include #include // Result is the following sum type (Haskell notation): @@ -117,8 +118,11 @@ class DAWN_NO_DISCARD Result { Result(T* success); Result(E* error); - Result(Result&& other); - Result& operator=(Result&& other); + // Support returning a Result from a Result + template + Result(Result&& other); + template + Result& operator=(Result&& other); ~Result(); @@ -129,6 +133,9 @@ class DAWN_NO_DISCARD Result { E* AcquireError(); private: + template + friend class Result; + intptr_t mPayload = detail::kEmptyPayload; }; @@ -265,13 +272,17 @@ Result::Result(E* error) : mPayload(detail::MakePayload(error, detail::E } template -Result::Result(Result&& other) : mPayload(other.mPayload) { +template +Result::Result(Result&& other) : mPayload(other.mPayload) { other.mPayload = detail::kEmptyPayload; + static_assert(std::is_same::value || std::is_base_of::value, ""); } template -Result& Result::operator=(Result&& other) { +template +Result& Result::operator=(Result&& other) { ASSERT(mPayload == detail::kEmptyPayload); + static_assert(std::is_same::value || std::is_base_of::value, ""); mPayload = other.mPayload; other.mPayload = detail::kEmptyPayload; return *this; diff --git a/src/tests/unittests/ResultTests.cpp b/src/tests/unittests/ResultTests.cpp index e9e5e366c9..d991462b13 100644 --- a/src/tests/unittests/ResultTests.cpp +++ b/src/tests/unittests/ResultTests.cpp @@ -139,6 +139,31 @@ TEST(ResultBothPointer, ReturningSuccess) { TestSuccess(&result, &dummySuccess); } +// Tests converting from a Result +TEST(ResultBothPointer, ConversionFromChildClass) { + struct T { + int a; + }; + struct TChild : T {}; + + TChild child; + T* childAsT = &child; + { + Result result(&child); + TestSuccess(&result, childAsT); + } + { + Result resultChild(&child); + Result result(std::move(resultChild)); + TestSuccess(&result, childAsT); + } + { + Result resultChild(&child); + Result result = std::move(resultChild); + TestSuccess(&result, childAsT); + } +} + // Result // Test constructing an error Result