diff --git a/src/resolver/function_validation_test.cc b/src/resolver/function_validation_test.cc index 93852c8045..0771eeb8a3 100644 --- a/src/resolver/function_validation_test.cc +++ b/src/resolver/function_validation_test.cc @@ -479,5 +479,47 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) { "i32 module-scope constant"); } +TEST_F(ResolverFunctionValidationTest, ReturnIsAtomicFreePlain_NonPlain) { + auto* ret_type = + ty.pointer(Source{{12, 34}}, ty.i32(), ast::StorageClass::kFunction); + Func("f", {}, ret_type, {}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + "12:34 error: function return type must be an atomic-free plain type"); +} + +TEST_F(ResolverFunctionValidationTest, ReturnIsAtomicFreePlain_AtomicInt) { + auto* ret_type = ty.atomic(Source{{12, 34}}, ty.i32()); + Func("f", {}, ret_type, {}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + "12:34 error: function return type must be an atomic-free plain type"); +} + +TEST_F(ResolverFunctionValidationTest, ReturnIsAtomicFreePlain_ArrayOfAtomic) { + auto* ret_type = ty.array(Source{{12, 34}}, ty.atomic(ty.i32())); + Func("f", {}, ret_type, {}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + "12:34 error: function return type must be an atomic-free plain type"); +} + +TEST_F(ResolverFunctionValidationTest, ReturnIsAtomicFreePlain_StructOfAtomic) { + Structure("S", {Member("m", ty.atomic(ty.i32()))}); + auto* ret_type = ty.type_name(Source{{12, 34}}, "S"); + Func("f", {}, ret_type, {}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + "12:34 error: function return type must be an atomic-free plain type"); +} + } // namespace } // namespace tint diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 25b1dbf3f6..5d01d6442b 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -176,19 +176,45 @@ bool Resolver::Resolve() { } // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section -bool Resolver::IsPlain(const sem::Type* type) { +bool Resolver::IsPlain(const sem::Type* type) const { return type->is_scalar() || type->Is() || type->Is() || type->Is() || type->Is() || type->Is(); } +// https://gpuweb.github.io/gpuweb/wgsl/#atomic-free +bool Resolver::IsAtomicFreePlain(const sem::Type* type) const { + if (type->Is()) { + return false; + } + + if (type->is_scalar() || type->Is() || type->Is()) { + return true; + } + + if (auto* arr = type->As()) { + return IsAtomicFreePlain(arr->ElemType()); + } + + if (auto* str = type->As()) { + for (auto* m : str->Members()) { + if (!IsAtomicFreePlain(m->Type())) { + return false; + } + } + return true; + } + + return false; +} + // https://gpuweb.github.io/gpuweb/wgsl.html#storable-types -bool Resolver::IsStorable(const sem::Type* type) { +bool Resolver::IsStorable(const sem::Type* type) const { return IsPlain(type) || type->Is() || type->Is(); } // https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types -bool Resolver::IsHostShareable(const sem::Type* type) { +bool Resolver::IsHostShareable(const sem::Type* type) const { if (type->IsAnyOf()) { return true; } @@ -1013,6 +1039,13 @@ bool Resolver::ValidateFunction(const ast::Function* func, } if (!info->return_type->Is()) { + if (!IsAtomicFreePlain(info->return_type)) { + diagnostics_.add_error( + "function return type must be an atomic-free plain type", + func->return_type()->source()); + return false; + } + if (func->body()) { if (!func->get_last_statement() || !func->get_last_statement()->Is()) { diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index a8930ae44b..cf303c099a 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -77,15 +77,19 @@ class Resolver { /// @param type the given type /// @returns true if the given type is a plain type - bool IsPlain(const sem::Type* type); + bool IsPlain(const sem::Type* type) const; + + /// @param type the given type + /// @returns true if the given type is a atomic-free plain type + bool IsAtomicFreePlain(const sem::Type* type) const; /// @param type the given type /// @returns true if the given type is storable - bool IsStorable(const sem::Type* type); + bool IsStorable(const sem::Type* type) const; /// @param type the given type /// @returns true if the given type is host-shareable - bool IsHostShareable(const sem::Type* type); + bool IsHostShareable(const sem::Type* type) const; private: /// Describes the context in which a variable is declared