From 188ed1793a20602bba04792f4e7b426f329bff4a Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Wed, 14 Dec 2022 11:09:47 +0000 Subject: [PATCH] tint/resolver: Correctly validate f16 usage There was limited validation for this. Validate all ways to use a f16. Change-Id: Ibdcde1f304e704790da3db379c79fcc0844cad67 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/114140 Kokoro: Kokoro Reviewed-by: James Price Commit-Queue: Ben Clayton --- src/tint/BUILD.gn | 1 + src/tint/CMakeLists.txt | 1 + src/tint/resolver/f16_extension_test.cc | 144 ++++++++++++++++++++++ src/tint/resolver/resolver.cc | 76 ++++++------ src/tint/resolver/resolver.h | 5 +- src/tint/resolver/resolver_test.cc | 17 --- src/tint/resolver/type_validation_test.cc | 18 --- src/tint/resolver/validator.cc | 9 ++ src/tint/resolver/validator.h | 5 + 9 files changed, 203 insertions(+), 73 deletions(-) create mode 100644 src/tint/resolver/f16_extension_test.cc diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 45a965d8fe..5ac2ea9efe 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -1185,6 +1185,7 @@ if (tint_build_unittests) { "resolver/dependency_graph_test.cc", "resolver/entry_point_validation_test.cc", "resolver/evaluation_stage_test.cc", + "resolver/f16_extension_test.cc", "resolver/function_validation_test.cc", "resolver/host_shareable_validation_test.cc", "resolver/increment_decrement_validation_test.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 630f9988d2..d2f2a19872 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -897,6 +897,7 @@ if(TINT_BUILD_TESTS) resolver/dependency_graph_test.cc resolver/entry_point_validation_test.cc resolver/evaluation_stage_test.cc + resolver/f16_extension_test.cc resolver/function_validation_test.cc resolver/host_shareable_validation_test.cc resolver/increment_decrement_validation_test.cc diff --git a/src/tint/resolver/f16_extension_test.cc b/src/tint/resolver/f16_extension_test.cc new file mode 100644 index 0000000000..9161410bc7 --- /dev/null +++ b/src/tint/resolver/f16_extension_test.cc @@ -0,0 +1,144 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/resolver/resolver.h" +#include "src/tint/resolver/resolver_test_helper.h" + +#include "gmock/gmock.h" + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::resolver { +namespace { + +using ResolverF16ExtensionTest = ResolverTest; + +TEST_F(ResolverF16ExtensionTest, TypeUsedWithExtension) { + // enable f16; + // var v : f16; + Enable(ast::Extension::kF16); + + GlobalVar("v", ty.f16(), ast::AddressSpace::kPrivate); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverF16ExtensionTest, TypeUsedWithoutExtension) { + // var v : f16; + GlobalVar("v", ty.f16(Source{{12, 34}}), ast::AddressSpace::kPrivate); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled"); +} + +TEST_F(ResolverF16ExtensionTest, Vec2TypeUsedWithExtension) { + // enable f16; + // var v : vec2; + Enable(ast::Extension::kF16); + + GlobalVar("v", ty.vec2(), ast::AddressSpace::kPrivate); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverF16ExtensionTest, Vec2TypeUsedWithoutExtension) { + // var v : vec2; + GlobalVar("v", ty.vec2(ty.f16(Source{{12, 34}})), ast::AddressSpace::kPrivate); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled"); +} + +TEST_F(ResolverF16ExtensionTest, Vec2TypeInitUsedWithExtension) { + // enable f16; + // var v = vec2(); + Enable(ast::Extension::kF16); + + GlobalVar("v", Construct(ty.vec2()), ast::AddressSpace::kPrivate); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverF16ExtensionTest, Vec2TypeInitUsedWithoutExtension) { + // var v = vec2(); + GlobalVar("v", Construct(ty.vec2(ty.f16(Source{{12, 34}}))), ast::AddressSpace::kPrivate); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled"); +} + +TEST_F(ResolverF16ExtensionTest, Vec2TypeConvUsedWithExtension) { + // enable f16; + // var v = vec2(vec2()); + Enable(ast::Extension::kF16); + + GlobalVar("v", Construct(ty.vec2(), Construct(ty.vec2())), + ast::AddressSpace::kPrivate); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverF16ExtensionTest, Vec2TypeConvUsedWithoutExtension) { + // var v = vec2(vec2()); + GlobalVar("v", Construct(ty.vec2(ty.f16(Source{{12, 34}})), Construct(ty.vec2())), + ast::AddressSpace::kPrivate); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled"); +} + +TEST_F(ResolverF16ExtensionTest, F16LiteralUsedWithExtension) { + // enable f16; + // var v = 16h; + Enable(ast::Extension::kF16); + + GlobalVar("v", Expr(16_h), ast::AddressSpace::kPrivate); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverF16ExtensionTest, F16LiteralUsedWithoutExtension) { + // var v = 16h; + GlobalVar("v", Expr(Source{{12, 34}}, 16_h), ast::AddressSpace::kPrivate); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled"); +} + +using ResolverF16ExtensionShortNameTest = ResolverTestWithParam; + +TEST_P(ResolverF16ExtensionShortNameTest, Vec2hTypeUsedWithExtension) { + // enable f16; + // var v : vec2h; + Enable(ast::Extension::kF16); + + GlobalVar("v", ty.type_name(Source{{12, 34}}, GetParam()), ast::AddressSpace::kPrivate); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_P(ResolverF16ExtensionShortNameTest, Vec2hTypeUsedWithoutExtension) { + // var v : vec2h; + GlobalVar("v", ty.type_name(Source{{12, 34}}, GetParam()), ast::AddressSpace::kPrivate); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled"); +} + +INSTANTIATE_TEST_SUITE_P(ResolverF16ExtensionShortNameTest, + ResolverF16ExtensionShortNameTest, + testing::Values("vec2h", "vec3h", "vec4h")); + +} // namespace +} // namespace tint::resolver diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 281fa94605..5087472b99 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -208,12 +208,7 @@ type::Type* Resolver::Type(const ast::Type* ty) { [&](const ast::I32*) { return builder_->create(); }, [&](const ast::U32*) { return builder_->create(); }, [&](const ast::F16* t) -> type::F16* { - // Validate if f16 type is allowed. - if (!enabled_extensions_.Contains(ast::Extension::kF16)) { - AddError("f16 used without 'f16' extension enabled", t->source); - return nullptr; - } - return builder_->create(); + return validator_.CheckF16Enabled(t->source) ? builder_->create() : nullptr; }, [&](const ast::F32*) { return builder_->create(); }, [&](const ast::Vector* t) -> type::Vector* { @@ -337,9 +332,7 @@ type::Type* Resolver::Type(const ast::Type* ty) { AddError("cannot use builtin '" + name + "' as type", ty->source); return nullptr; } - if (auto* t = BuiltinTypeAlias(tn->name)) { - return t; - } + return ShortName(tn->name, tn->source); } TINT_UNREACHABLE(Resolver, diagnostics_) << "Unhandled resolved type '" @@ -2048,7 +2041,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { return nullptr; } - auto stage = args_stage; // The evaluation stage of the call + auto stage = args_stage; // The evaluation stage of the call const constant::Constant* value = nullptr; // The constant value for the call if (stage == sem::EvaluationStage::kConstant) { if (auto r = const_eval_.ArrayOrStructInit(ty, args)) { @@ -2083,7 +2076,11 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { }, [&](const type::I32*) { return ct_init_or_conv(InitConvIntrinsic::kI32, nullptr); }, [&](const type::U32*) { return ct_init_or_conv(InitConvIntrinsic::kU32, nullptr); }, - [&](const type::F16*) { return ct_init_or_conv(InitConvIntrinsic::kF16, nullptr); }, + [&](const type::F16*) { + return validator_.CheckF16Enabled(expr->source) + ? ct_init_or_conv(InitConvIntrinsic::kF16, nullptr) + : nullptr; + }, [&](const type::F32*) { return ct_init_or_conv(InitConvIntrinsic::kF32, nullptr); }, [&](const type::Bool*) { return ct_init_or_conv(InitConvIntrinsic::kBool, nullptr); }, [&](const type::Array* arr) -> sem::Call* { @@ -2285,17 +2282,13 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { }, [&](Default) -> sem::Call* { auto name = builder_->Symbols().NameFor(ident->symbol); - if (auto* alias = BuiltinTypeAlias(ident->symbol)) { - return ty_init_or_conv(alias); - } if (auto builtin_type = sem::ParseBuiltinType(name); builtin_type != sem::BuiltinType::kNone) { return BuiltinCall(expr, builtin_type, args); } - TINT_ICE(Resolver, diagnostics_) - << expr->source << " unhandled CallExpression target:\n" - << "resolved: " << (resolved ? resolved->TypeInfo().name : "") << "\n" - << "name: " << builder_->Symbols().NameFor(ident->symbol); + if (auto* alias = ShortName(ident->symbol, ident->source)) { + return ty_init_or_conv(alias); + } return nullptr; }); } @@ -2391,7 +2384,7 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr, return call; } -type::Type* Resolver::BuiltinTypeAlias(Symbol sym) const { +type::Type* Resolver::ShortName(Symbol sym, const Source& source) const { auto name = builder_->Symbols().NameFor(sym); auto& b = *builder_; switch (type::ParseShortName(name)) { @@ -2402,11 +2395,17 @@ type::Type* Resolver::BuiltinTypeAlias(Symbol sym) const { case type::ShortName::kVec4F: return b.create(b.create(), 4u); case type::ShortName::kVec2H: - return b.create(b.create(), 2u); + return validator_.CheckF16Enabled(source) + ? b.create(b.create(), 2u) + : nullptr; case type::ShortName::kVec3H: - return b.create(b.create(), 3u); + return validator_.CheckF16Enabled(source) + ? b.create(b.create(), 3u) + : nullptr; case type::ShortName::kVec4H: - return b.create(b.create(), 4u); + return validator_.CheckF16Enabled(source) + ? b.create(b.create(), 4u) + : nullptr; case type::ShortName::kVec2I: return b.create(b.create(), 2u); case type::ShortName::kVec3I: @@ -2422,6 +2421,8 @@ type::Type* Resolver::BuiltinTypeAlias(Symbol sym) const { case type::ShortName::kUndefined: break; } + + TINT_ICE(Resolver, diagnostics_) << source << " unhandled type short name '" << name << "'"; return nullptr; } @@ -2534,6 +2535,8 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) { case ast::IntLiteralExpression::Suffix::kU: return builder_->create(); } + TINT_UNREACHABLE(Resolver, builder_->Diagnostics()) + << "Unhandled integer literal suffix: " << i->suffix; return nullptr; }, [&](const ast::FloatLiteralExpression* f) -> type::Type* { @@ -2543,21 +2546,22 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) { case ast::FloatLiteralExpression::Suffix::kF: return builder_->create(); case ast::FloatLiteralExpression::Suffix::kH: - return builder_->create(); + return validator_.CheckF16Enabled(literal->source) + ? builder_->create() + : nullptr; } + TINT_UNREACHABLE(Resolver, builder_->Diagnostics()) + << "Unhandled float literal suffix: " << f->suffix; return nullptr; }, [&](const ast::BoolLiteralExpression*) { return builder_->create(); }, - [&](Default) { return nullptr; }); + [&](Default) { + TINT_UNREACHABLE(Resolver, builder_->Diagnostics()) + << "Unhandled literal type: " << literal->TypeInfo().name; + return nullptr; + }); if (ty == nullptr) { - TINT_UNREACHABLE(Resolver, builder_->Diagnostics()) - << "Unhandled literal type: " << literal->TypeInfo().name; - return nullptr; - } - - if ((ty->Is()) && (!enabled_extensions_.Contains(tint::ast::Extension::kF16))) { - AddError("f16 literal used without 'f16' extension enabled", literal->source); return nullptr; } @@ -3161,11 +3165,11 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { sem_members.Reserve(str->members.Length()); // Calculate the effective size and alignment of each field, and the overall size of the - // structure. For size, use the size attribute if provided, otherwise use the default size for - // the type. For alignment, use the alignment attribute if provided, otherwise use the default - // alignment for the member type. Diagnostic errors are raised if a basic rule is violated. - // Validation of storage-class rules requires analyzing the actual variable usage of the - // structure, and so is performed as part of the variable validation. + // structure. For size, use the size attribute if provided, otherwise use the default size + // for the type. For alignment, use the alignment attribute if provided, otherwise use the + // default alignment for the member type. Diagnostic errors are raised if a basic rule is + // violated. Validation of storage-class rules requires analyzing the actual variable usage + // of the structure, and so is performed as part of the variable validation. uint64_t struct_size = 0; uint64_t struct_align = 1; utils::Hashmap member_map; diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 3e160b9286..186c653443 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -420,8 +420,9 @@ class Resolver { /// @returns true if the symbol is the name of a builtin function. bool IsBuiltin(Symbol) const; - /// @returns the builtin type alias for the given symbol - type::Type* BuiltinTypeAlias(Symbol) const; + /// @returns the type short-name alias for the symbol @p symbol at @p source + /// @note: Will raise an ICE if @p symbol is not a short-name type. + type::Type* ShortName(Symbol symbol, const Source& source) const; // ArrayInitializerSig represents a unique array initializer signature. // It is a tuple of the array type, number of arguments provided and earliest evaluation stage. diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc index a59887515a..3c8e3a4044 100644 --- a/src/tint/resolver/resolver_test.cc +++ b/src/tint/resolver/resolver_test.cc @@ -2361,23 +2361,6 @@ TEST_F(ResolverTest, MaxExpressionDepth_Fail) { std::to_string(kMaxExpressionDepth))); } -TEST_F(ResolverTest, Literal_F16WithoutExtension) { - // fn test() {_ = 1.23h;} - WrapInFunction(Ignore(Expr(f16(1.23f)))); - - EXPECT_FALSE(r()->Resolve()); - EXPECT_THAT(r()->error(), HasSubstr("error: f16 literal used without 'f16' extension enabled")); -} - -TEST_F(ResolverTest, Literal_F16WithExtension) { - // enable f16; - // fn test() {_ = 1.23h;} - Enable(ast::Extension::kF16); - WrapInFunction(Ignore(Expr(f16(1.23f)))); - - EXPECT_TRUE(r()->Resolve()); -} - // Windows debug builds have significantly smaller stack than other builds, and these tests will // stack overflow. #if !defined(NDEBUG) diff --git a/src/tint/resolver/type_validation_test.cc b/src/tint/resolver/type_validation_test.cc index 9c37b68c01..ed6fa82810 100644 --- a/src/tint/resolver/type_validation_test.cc +++ b/src/tint/resolver/type_validation_test.cc @@ -893,24 +893,6 @@ TEST_F(ResolverTypeValidationTest, BuiltinAsType) { EXPECT_EQ(r()->error(), "error: cannot use builtin 'max' as type"); } -TEST_F(ResolverTypeValidationTest, F16TypeUsedWithExtension) { - // enable f16; - // var v : f16; - Enable(ast::Extension::kF16); - - GlobalVar("v", ty.f16(), ast::AddressSpace::kPrivate); - - EXPECT_TRUE(r()->Resolve()) << r()->error(); -} - -TEST_F(ResolverTypeValidationTest, F16TypeUsedWithoutExtension) { - // var v : f16; - GlobalVar("v", ty.f16(), ast::AddressSpace::kPrivate); - - EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "error: f16 used without 'f16' extension enabled"); -} - namespace GetCanonicalTests { struct Params { builder::ast_type_func_ptr create_ast_type; diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index a3bc938ab4..24b2e7b529 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -1625,6 +1625,15 @@ bool Validator::RequiredExtensionForBuiltinFunction(const sem::Call* call) const return true; } +bool Validator::CheckF16Enabled(const Source& source) const { + // Validate if f16 type is allowed. + if (!enabled_extensions_.Contains(ast::Extension::kF16)) { + AddError("f16 type used without 'f16' extension enabled", source); + return false; + } + return true; +} + bool Validator::FunctionCall(const sem::Call* call, sem::Statement* current_statement) const { auto* decl = call->Declaration(); auto* target = call->Target()->As(); diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h index 0d6d3432d0..d7587ed852 100644 --- a/src/tint/resolver/validator.h +++ b/src/tint/resolver/validator.h @@ -444,6 +444,11 @@ class Validator { /// @returns true on success, false otherwise bool RequiredExtensionForBuiltinFunction(const sem::Call* call) const; + /// Validates that 'f16' extension is enabled for f16 usage at @p source + /// @param source the source of the f16 usage + /// @returns true on success, false otherwise + bool CheckF16Enabled(const Source& source) const; + /// Validates there are no duplicate attributes /// @param attributes the list of attributes to validate /// @returns true on success, false otherwise.