tint/resolver: Correctly validate f16 usage

There was limited validation for this. Validate all ways to use a f16.

Change-Id: Ibdcde1f304e704790da3db379c79fcc0844cad67
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/114140
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-12-14 11:09:47 +00:00 committed by Dawn LUCI CQ
parent 520d6002f3
commit 188ed1793a
9 changed files with 203 additions and 73 deletions

View File

@ -1185,6 +1185,7 @@ if (tint_build_unittests) {
"resolver/dependency_graph_test.cc",
"resolver/entry_point_validation_test.cc",
"resolver/evaluation_stage_test.cc",
"resolver/f16_extension_test.cc",
"resolver/function_validation_test.cc",
"resolver/host_shareable_validation_test.cc",
"resolver/increment_decrement_validation_test.cc",

View File

@ -897,6 +897,7 @@ if(TINT_BUILD_TESTS)
resolver/dependency_graph_test.cc
resolver/entry_point_validation_test.cc
resolver/evaluation_stage_test.cc
resolver/f16_extension_test.cc
resolver/function_validation_test.cc
resolver/host_shareable_validation_test.cc
resolver/increment_decrement_validation_test.cc

View File

@ -0,0 +1,144 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/resolver/resolver.h"
#include "src/tint/resolver/resolver_test_helper.h"
#include "gmock/gmock.h"
using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver {
namespace {
using ResolverF16ExtensionTest = ResolverTest;
TEST_F(ResolverF16ExtensionTest, TypeUsedWithExtension) {
// enable f16;
// var<private> v : f16;
Enable(ast::Extension::kF16);
GlobalVar("v", ty.f16(), ast::AddressSpace::kPrivate);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverF16ExtensionTest, TypeUsedWithoutExtension) {
// var<private> v : f16;
GlobalVar("v", ty.f16(Source{{12, 34}}), ast::AddressSpace::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled");
}
TEST_F(ResolverF16ExtensionTest, Vec2TypeUsedWithExtension) {
// enable f16;
// var<private> v : vec2<f16>;
Enable(ast::Extension::kF16);
GlobalVar("v", ty.vec2<f16>(), ast::AddressSpace::kPrivate);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverF16ExtensionTest, Vec2TypeUsedWithoutExtension) {
// var<private> v : vec2<f16>;
GlobalVar("v", ty.vec2(ty.f16(Source{{12, 34}})), ast::AddressSpace::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled");
}
TEST_F(ResolverF16ExtensionTest, Vec2TypeInitUsedWithExtension) {
// enable f16;
// var<private> v = vec2<f16>();
Enable(ast::Extension::kF16);
GlobalVar("v", Construct(ty.vec2<f16>()), ast::AddressSpace::kPrivate);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverF16ExtensionTest, Vec2TypeInitUsedWithoutExtension) {
// var<private> v = vec2<f16>();
GlobalVar("v", Construct(ty.vec2(ty.f16(Source{{12, 34}}))), ast::AddressSpace::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled");
}
TEST_F(ResolverF16ExtensionTest, Vec2TypeConvUsedWithExtension) {
// enable f16;
// var<private> v = vec2<f16>(vec2<f32>());
Enable(ast::Extension::kF16);
GlobalVar("v", Construct(ty.vec2<f16>(), Construct(ty.vec2<f32>())),
ast::AddressSpace::kPrivate);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverF16ExtensionTest, Vec2TypeConvUsedWithoutExtension) {
// var<private> v = vec2<f16>(vec2<f32>());
GlobalVar("v", Construct(ty.vec2(ty.f16(Source{{12, 34}})), Construct(ty.vec2<f32>())),
ast::AddressSpace::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled");
}
TEST_F(ResolverF16ExtensionTest, F16LiteralUsedWithExtension) {
// enable f16;
// var<private> v = 16h;
Enable(ast::Extension::kF16);
GlobalVar("v", Expr(16_h), ast::AddressSpace::kPrivate);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverF16ExtensionTest, F16LiteralUsedWithoutExtension) {
// var<private> v = 16h;
GlobalVar("v", Expr(Source{{12, 34}}, 16_h), ast::AddressSpace::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled");
}
using ResolverF16ExtensionShortNameTest = ResolverTestWithParam<const char*>;
TEST_P(ResolverF16ExtensionShortNameTest, Vec2hTypeUsedWithExtension) {
// enable f16;
// var<private> v : vec2h;
Enable(ast::Extension::kF16);
GlobalVar("v", ty.type_name(Source{{12, 34}}, GetParam()), ast::AddressSpace::kPrivate);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_P(ResolverF16ExtensionShortNameTest, Vec2hTypeUsedWithoutExtension) {
// var<private> v : vec2h;
GlobalVar("v", ty.type_name(Source{{12, 34}}, GetParam()), ast::AddressSpace::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled");
}
INSTANTIATE_TEST_SUITE_P(ResolverF16ExtensionShortNameTest,
ResolverF16ExtensionShortNameTest,
testing::Values("vec2h", "vec3h", "vec4h"));
} // namespace
} // namespace tint::resolver

View File

@ -208,12 +208,7 @@ type::Type* Resolver::Type(const ast::Type* ty) {
[&](const ast::I32*) { return builder_->create<type::I32>(); },
[&](const ast::U32*) { return builder_->create<type::U32>(); },
[&](const ast::F16* t) -> type::F16* {
// Validate if f16 type is allowed.
if (!enabled_extensions_.Contains(ast::Extension::kF16)) {
AddError("f16 used without 'f16' extension enabled", t->source);
return nullptr;
}
return builder_->create<type::F16>();
return validator_.CheckF16Enabled(t->source) ? builder_->create<type::F16>() : nullptr;
},
[&](const ast::F32*) { return builder_->create<type::F32>(); },
[&](const ast::Vector* t) -> type::Vector* {
@ -337,9 +332,7 @@ type::Type* Resolver::Type(const ast::Type* ty) {
AddError("cannot use builtin '" + name + "' as type", ty->source);
return nullptr;
}
if (auto* t = BuiltinTypeAlias(tn->name)) {
return t;
}
return ShortName(tn->name, tn->source);
}
TINT_UNREACHABLE(Resolver, diagnostics_)
<< "Unhandled resolved type '"
@ -2048,7 +2041,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
return nullptr;
}
auto stage = args_stage; // The evaluation stage of the call
auto stage = args_stage; // The evaluation stage of the call
const constant::Constant* value = nullptr; // The constant value for the call
if (stage == sem::EvaluationStage::kConstant) {
if (auto r = const_eval_.ArrayOrStructInit(ty, args)) {
@ -2083,7 +2076,11 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
},
[&](const type::I32*) { return ct_init_or_conv(InitConvIntrinsic::kI32, nullptr); },
[&](const type::U32*) { return ct_init_or_conv(InitConvIntrinsic::kU32, nullptr); },
[&](const type::F16*) { return ct_init_or_conv(InitConvIntrinsic::kF16, nullptr); },
[&](const type::F16*) {
return validator_.CheckF16Enabled(expr->source)
? ct_init_or_conv(InitConvIntrinsic::kF16, nullptr)
: nullptr;
},
[&](const type::F32*) { return ct_init_or_conv(InitConvIntrinsic::kF32, nullptr); },
[&](const type::Bool*) { return ct_init_or_conv(InitConvIntrinsic::kBool, nullptr); },
[&](const type::Array* arr) -> sem::Call* {
@ -2285,17 +2282,13 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
},
[&](Default) -> sem::Call* {
auto name = builder_->Symbols().NameFor(ident->symbol);
if (auto* alias = BuiltinTypeAlias(ident->symbol)) {
return ty_init_or_conv(alias);
}
if (auto builtin_type = sem::ParseBuiltinType(name);
builtin_type != sem::BuiltinType::kNone) {
return BuiltinCall(expr, builtin_type, args);
}
TINT_ICE(Resolver, diagnostics_)
<< expr->source << " unhandled CallExpression target:\n"
<< "resolved: " << (resolved ? resolved->TypeInfo().name : "<null>") << "\n"
<< "name: " << builder_->Symbols().NameFor(ident->symbol);
if (auto* alias = ShortName(ident->symbol, ident->source)) {
return ty_init_or_conv(alias);
}
return nullptr;
});
}
@ -2391,7 +2384,7 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
return call;
}
type::Type* Resolver::BuiltinTypeAlias(Symbol sym) const {
type::Type* Resolver::ShortName(Symbol sym, const Source& source) const {
auto name = builder_->Symbols().NameFor(sym);
auto& b = *builder_;
switch (type::ParseShortName(name)) {
@ -2402,11 +2395,17 @@ type::Type* Resolver::BuiltinTypeAlias(Symbol sym) const {
case type::ShortName::kVec4F:
return b.create<type::Vector>(b.create<type::F32>(), 4u);
case type::ShortName::kVec2H:
return b.create<type::Vector>(b.create<type::F16>(), 2u);
return validator_.CheckF16Enabled(source)
? b.create<type::Vector>(b.create<type::F16>(), 2u)
: nullptr;
case type::ShortName::kVec3H:
return b.create<type::Vector>(b.create<type::F16>(), 3u);
return validator_.CheckF16Enabled(source)
? b.create<type::Vector>(b.create<type::F16>(), 3u)
: nullptr;
case type::ShortName::kVec4H:
return b.create<type::Vector>(b.create<type::F16>(), 4u);
return validator_.CheckF16Enabled(source)
? b.create<type::Vector>(b.create<type::F16>(), 4u)
: nullptr;
case type::ShortName::kVec2I:
return b.create<type::Vector>(b.create<type::I32>(), 2u);
case type::ShortName::kVec3I:
@ -2422,6 +2421,8 @@ type::Type* Resolver::BuiltinTypeAlias(Symbol sym) const {
case type::ShortName::kUndefined:
break;
}
TINT_ICE(Resolver, diagnostics_) << source << " unhandled type short name '" << name << "'";
return nullptr;
}
@ -2534,6 +2535,8 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
case ast::IntLiteralExpression::Suffix::kU:
return builder_->create<type::U32>();
}
TINT_UNREACHABLE(Resolver, builder_->Diagnostics())
<< "Unhandled integer literal suffix: " << i->suffix;
return nullptr;
},
[&](const ast::FloatLiteralExpression* f) -> type::Type* {
@ -2543,21 +2546,22 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
case ast::FloatLiteralExpression::Suffix::kF:
return builder_->create<type::F32>();
case ast::FloatLiteralExpression::Suffix::kH:
return builder_->create<type::F16>();
return validator_.CheckF16Enabled(literal->source)
? builder_->create<type::F16>()
: nullptr;
}
TINT_UNREACHABLE(Resolver, builder_->Diagnostics())
<< "Unhandled float literal suffix: " << f->suffix;
return nullptr;
},
[&](const ast::BoolLiteralExpression*) { return builder_->create<type::Bool>(); },
[&](Default) { return nullptr; });
[&](Default) {
TINT_UNREACHABLE(Resolver, builder_->Diagnostics())
<< "Unhandled literal type: " << literal->TypeInfo().name;
return nullptr;
});
if (ty == nullptr) {
TINT_UNREACHABLE(Resolver, builder_->Diagnostics())
<< "Unhandled literal type: " << literal->TypeInfo().name;
return nullptr;
}
if ((ty->Is<type::F16>()) && (!enabled_extensions_.Contains(tint::ast::Extension::kF16))) {
AddError("f16 literal used without 'f16' extension enabled", literal->source);
return nullptr;
}
@ -3161,11 +3165,11 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
sem_members.Reserve(str->members.Length());
// Calculate the effective size and alignment of each field, and the overall size of the
// structure. For size, use the size attribute if provided, otherwise use the default size for
// the type. For alignment, use the alignment attribute if provided, otherwise use the default
// alignment for the member type. Diagnostic errors are raised if a basic rule is violated.
// Validation of storage-class rules requires analyzing the actual variable usage of the
// structure, and so is performed as part of the variable validation.
// structure. For size, use the size attribute if provided, otherwise use the default size
// for the type. For alignment, use the alignment attribute if provided, otherwise use the
// default alignment for the member type. Diagnostic errors are raised if a basic rule is
// violated. Validation of storage-class rules requires analyzing the actual variable usage
// of the structure, and so is performed as part of the variable validation.
uint64_t struct_size = 0;
uint64_t struct_align = 1;
utils::Hashmap<Symbol, const ast::StructMember*, 8> member_map;

View File

@ -420,8 +420,9 @@ class Resolver {
/// @returns true if the symbol is the name of a builtin function.
bool IsBuiltin(Symbol) const;
/// @returns the builtin type alias for the given symbol
type::Type* BuiltinTypeAlias(Symbol) const;
/// @returns the type short-name alias for the symbol @p symbol at @p source
/// @note: Will raise an ICE if @p symbol is not a short-name type.
type::Type* ShortName(Symbol symbol, const Source& source) const;
// ArrayInitializerSig represents a unique array initializer signature.
// It is a tuple of the array type, number of arguments provided and earliest evaluation stage.

View File

@ -2361,23 +2361,6 @@ TEST_F(ResolverTest, MaxExpressionDepth_Fail) {
std::to_string(kMaxExpressionDepth)));
}
TEST_F(ResolverTest, Literal_F16WithoutExtension) {
// fn test() {_ = 1.23h;}
WrapInFunction(Ignore(Expr(f16(1.23f))));
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(), HasSubstr("error: f16 literal used without 'f16' extension enabled"));
}
TEST_F(ResolverTest, Literal_F16WithExtension) {
// enable f16;
// fn test() {_ = 1.23h;}
Enable(ast::Extension::kF16);
WrapInFunction(Ignore(Expr(f16(1.23f))));
EXPECT_TRUE(r()->Resolve());
}
// Windows debug builds have significantly smaller stack than other builds, and these tests will
// stack overflow.
#if !defined(NDEBUG)

View File

@ -893,24 +893,6 @@ TEST_F(ResolverTypeValidationTest, BuiltinAsType) {
EXPECT_EQ(r()->error(), "error: cannot use builtin 'max' as type");
}
TEST_F(ResolverTypeValidationTest, F16TypeUsedWithExtension) {
// enable f16;
// var<private> v : f16;
Enable(ast::Extension::kF16);
GlobalVar("v", ty.f16(), ast::AddressSpace::kPrivate);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTypeValidationTest, F16TypeUsedWithoutExtension) {
// var<private> v : f16;
GlobalVar("v", ty.f16(), ast::AddressSpace::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "error: f16 used without 'f16' extension enabled");
}
namespace GetCanonicalTests {
struct Params {
builder::ast_type_func_ptr create_ast_type;

View File

@ -1625,6 +1625,15 @@ bool Validator::RequiredExtensionForBuiltinFunction(const sem::Call* call) const
return true;
}
bool Validator::CheckF16Enabled(const Source& source) const {
// Validate if f16 type is allowed.
if (!enabled_extensions_.Contains(ast::Extension::kF16)) {
AddError("f16 type used without 'f16' extension enabled", source);
return false;
}
return true;
}
bool Validator::FunctionCall(const sem::Call* call, sem::Statement* current_statement) const {
auto* decl = call->Declaration();
auto* target = call->Target()->As<sem::Function>();

View File

@ -444,6 +444,11 @@ class Validator {
/// @returns true on success, false otherwise
bool RequiredExtensionForBuiltinFunction(const sem::Call* call) const;
/// Validates that 'f16' extension is enabled for f16 usage at @p source
/// @param source the source of the f16 usage
/// @returns true on success, false otherwise
bool CheckF16Enabled(const Source& source) const;
/// Validates there are no duplicate attributes
/// @param attributes the list of attributes to validate
/// @returns true on success, false otherwise.