tint/resolver: Correctly validate f16 usage
There was limited validation for this. Validate all ways to use a f16. Change-Id: Ibdcde1f304e704790da3db379c79fcc0844cad67 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/114140 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: James Price <jrprice@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
520d6002f3
commit
188ed1793a
|
@ -1185,6 +1185,7 @@ if (tint_build_unittests) {
|
|||
"resolver/dependency_graph_test.cc",
|
||||
"resolver/entry_point_validation_test.cc",
|
||||
"resolver/evaluation_stage_test.cc",
|
||||
"resolver/f16_extension_test.cc",
|
||||
"resolver/function_validation_test.cc",
|
||||
"resolver/host_shareable_validation_test.cc",
|
||||
"resolver/increment_decrement_validation_test.cc",
|
||||
|
|
|
@ -897,6 +897,7 @@ if(TINT_BUILD_TESTS)
|
|||
resolver/dependency_graph_test.cc
|
||||
resolver/entry_point_validation_test.cc
|
||||
resolver/evaluation_stage_test.cc
|
||||
resolver/f16_extension_test.cc
|
||||
resolver/function_validation_test.cc
|
||||
resolver/host_shareable_validation_test.cc
|
||||
resolver/increment_decrement_validation_test.cc
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/resolver/resolver.h"
|
||||
#include "src/tint/resolver/resolver_test_helper.h"
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
using namespace tint::number_suffixes; // NOLINT
|
||||
|
||||
namespace tint::resolver {
|
||||
namespace {
|
||||
|
||||
using ResolverF16ExtensionTest = ResolverTest;
|
||||
|
||||
TEST_F(ResolverF16ExtensionTest, TypeUsedWithExtension) {
|
||||
// enable f16;
|
||||
// var<private> v : f16;
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("v", ty.f16(), ast::AddressSpace::kPrivate);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverF16ExtensionTest, TypeUsedWithoutExtension) {
|
||||
// var<private> v : f16;
|
||||
GlobalVar("v", ty.f16(Source{{12, 34}}), ast::AddressSpace::kPrivate);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled");
|
||||
}
|
||||
|
||||
TEST_F(ResolverF16ExtensionTest, Vec2TypeUsedWithExtension) {
|
||||
// enable f16;
|
||||
// var<private> v : vec2<f16>;
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("v", ty.vec2<f16>(), ast::AddressSpace::kPrivate);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverF16ExtensionTest, Vec2TypeUsedWithoutExtension) {
|
||||
// var<private> v : vec2<f16>;
|
||||
GlobalVar("v", ty.vec2(ty.f16(Source{{12, 34}})), ast::AddressSpace::kPrivate);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled");
|
||||
}
|
||||
|
||||
TEST_F(ResolverF16ExtensionTest, Vec2TypeInitUsedWithExtension) {
|
||||
// enable f16;
|
||||
// var<private> v = vec2<f16>();
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("v", Construct(ty.vec2<f16>()), ast::AddressSpace::kPrivate);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverF16ExtensionTest, Vec2TypeInitUsedWithoutExtension) {
|
||||
// var<private> v = vec2<f16>();
|
||||
GlobalVar("v", Construct(ty.vec2(ty.f16(Source{{12, 34}}))), ast::AddressSpace::kPrivate);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled");
|
||||
}
|
||||
|
||||
TEST_F(ResolverF16ExtensionTest, Vec2TypeConvUsedWithExtension) {
|
||||
// enable f16;
|
||||
// var<private> v = vec2<f16>(vec2<f32>());
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("v", Construct(ty.vec2<f16>(), Construct(ty.vec2<f32>())),
|
||||
ast::AddressSpace::kPrivate);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverF16ExtensionTest, Vec2TypeConvUsedWithoutExtension) {
|
||||
// var<private> v = vec2<f16>(vec2<f32>());
|
||||
GlobalVar("v", Construct(ty.vec2(ty.f16(Source{{12, 34}})), Construct(ty.vec2<f32>())),
|
||||
ast::AddressSpace::kPrivate);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled");
|
||||
}
|
||||
|
||||
TEST_F(ResolverF16ExtensionTest, F16LiteralUsedWithExtension) {
|
||||
// enable f16;
|
||||
// var<private> v = 16h;
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("v", Expr(16_h), ast::AddressSpace::kPrivate);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverF16ExtensionTest, F16LiteralUsedWithoutExtension) {
|
||||
// var<private> v = 16h;
|
||||
GlobalVar("v", Expr(Source{{12, 34}}, 16_h), ast::AddressSpace::kPrivate);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled");
|
||||
}
|
||||
|
||||
using ResolverF16ExtensionShortNameTest = ResolverTestWithParam<const char*>;
|
||||
|
||||
TEST_P(ResolverF16ExtensionShortNameTest, Vec2hTypeUsedWithExtension) {
|
||||
// enable f16;
|
||||
// var<private> v : vec2h;
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("v", ty.type_name(Source{{12, 34}}, GetParam()), ast::AddressSpace::kPrivate);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_P(ResolverF16ExtensionShortNameTest, Vec2hTypeUsedWithoutExtension) {
|
||||
// var<private> v : vec2h;
|
||||
GlobalVar("v", ty.type_name(Source{{12, 34}}, GetParam()), ast::AddressSpace::kPrivate);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled");
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(ResolverF16ExtensionShortNameTest,
|
||||
ResolverF16ExtensionShortNameTest,
|
||||
testing::Values("vec2h", "vec3h", "vec4h"));
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::resolver
|
|
@ -208,12 +208,7 @@ type::Type* Resolver::Type(const ast::Type* ty) {
|
|||
[&](const ast::I32*) { return builder_->create<type::I32>(); },
|
||||
[&](const ast::U32*) { return builder_->create<type::U32>(); },
|
||||
[&](const ast::F16* t) -> type::F16* {
|
||||
// Validate if f16 type is allowed.
|
||||
if (!enabled_extensions_.Contains(ast::Extension::kF16)) {
|
||||
AddError("f16 used without 'f16' extension enabled", t->source);
|
||||
return nullptr;
|
||||
}
|
||||
return builder_->create<type::F16>();
|
||||
return validator_.CheckF16Enabled(t->source) ? builder_->create<type::F16>() : nullptr;
|
||||
},
|
||||
[&](const ast::F32*) { return builder_->create<type::F32>(); },
|
||||
[&](const ast::Vector* t) -> type::Vector* {
|
||||
|
@ -337,9 +332,7 @@ type::Type* Resolver::Type(const ast::Type* ty) {
|
|||
AddError("cannot use builtin '" + name + "' as type", ty->source);
|
||||
return nullptr;
|
||||
}
|
||||
if (auto* t = BuiltinTypeAlias(tn->name)) {
|
||||
return t;
|
||||
}
|
||||
return ShortName(tn->name, tn->source);
|
||||
}
|
||||
TINT_UNREACHABLE(Resolver, diagnostics_)
|
||||
<< "Unhandled resolved type '"
|
||||
|
@ -2048,7 +2041,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
auto stage = args_stage; // The evaluation stage of the call
|
||||
auto stage = args_stage; // The evaluation stage of the call
|
||||
const constant::Constant* value = nullptr; // The constant value for the call
|
||||
if (stage == sem::EvaluationStage::kConstant) {
|
||||
if (auto r = const_eval_.ArrayOrStructInit(ty, args)) {
|
||||
|
@ -2083,7 +2076,11 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
|
|||
},
|
||||
[&](const type::I32*) { return ct_init_or_conv(InitConvIntrinsic::kI32, nullptr); },
|
||||
[&](const type::U32*) { return ct_init_or_conv(InitConvIntrinsic::kU32, nullptr); },
|
||||
[&](const type::F16*) { return ct_init_or_conv(InitConvIntrinsic::kF16, nullptr); },
|
||||
[&](const type::F16*) {
|
||||
return validator_.CheckF16Enabled(expr->source)
|
||||
? ct_init_or_conv(InitConvIntrinsic::kF16, nullptr)
|
||||
: nullptr;
|
||||
},
|
||||
[&](const type::F32*) { return ct_init_or_conv(InitConvIntrinsic::kF32, nullptr); },
|
||||
[&](const type::Bool*) { return ct_init_or_conv(InitConvIntrinsic::kBool, nullptr); },
|
||||
[&](const type::Array* arr) -> sem::Call* {
|
||||
|
@ -2285,17 +2282,13 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
|
|||
},
|
||||
[&](Default) -> sem::Call* {
|
||||
auto name = builder_->Symbols().NameFor(ident->symbol);
|
||||
if (auto* alias = BuiltinTypeAlias(ident->symbol)) {
|
||||
return ty_init_or_conv(alias);
|
||||
}
|
||||
if (auto builtin_type = sem::ParseBuiltinType(name);
|
||||
builtin_type != sem::BuiltinType::kNone) {
|
||||
return BuiltinCall(expr, builtin_type, args);
|
||||
}
|
||||
TINT_ICE(Resolver, diagnostics_)
|
||||
<< expr->source << " unhandled CallExpression target:\n"
|
||||
<< "resolved: " << (resolved ? resolved->TypeInfo().name : "<null>") << "\n"
|
||||
<< "name: " << builder_->Symbols().NameFor(ident->symbol);
|
||||
if (auto* alias = ShortName(ident->symbol, ident->source)) {
|
||||
return ty_init_or_conv(alias);
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
}
|
||||
|
@ -2391,7 +2384,7 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
|
|||
return call;
|
||||
}
|
||||
|
||||
type::Type* Resolver::BuiltinTypeAlias(Symbol sym) const {
|
||||
type::Type* Resolver::ShortName(Symbol sym, const Source& source) const {
|
||||
auto name = builder_->Symbols().NameFor(sym);
|
||||
auto& b = *builder_;
|
||||
switch (type::ParseShortName(name)) {
|
||||
|
@ -2402,11 +2395,17 @@ type::Type* Resolver::BuiltinTypeAlias(Symbol sym) const {
|
|||
case type::ShortName::kVec4F:
|
||||
return b.create<type::Vector>(b.create<type::F32>(), 4u);
|
||||
case type::ShortName::kVec2H:
|
||||
return b.create<type::Vector>(b.create<type::F16>(), 2u);
|
||||
return validator_.CheckF16Enabled(source)
|
||||
? b.create<type::Vector>(b.create<type::F16>(), 2u)
|
||||
: nullptr;
|
||||
case type::ShortName::kVec3H:
|
||||
return b.create<type::Vector>(b.create<type::F16>(), 3u);
|
||||
return validator_.CheckF16Enabled(source)
|
||||
? b.create<type::Vector>(b.create<type::F16>(), 3u)
|
||||
: nullptr;
|
||||
case type::ShortName::kVec4H:
|
||||
return b.create<type::Vector>(b.create<type::F16>(), 4u);
|
||||
return validator_.CheckF16Enabled(source)
|
||||
? b.create<type::Vector>(b.create<type::F16>(), 4u)
|
||||
: nullptr;
|
||||
case type::ShortName::kVec2I:
|
||||
return b.create<type::Vector>(b.create<type::I32>(), 2u);
|
||||
case type::ShortName::kVec3I:
|
||||
|
@ -2422,6 +2421,8 @@ type::Type* Resolver::BuiltinTypeAlias(Symbol sym) const {
|
|||
case type::ShortName::kUndefined:
|
||||
break;
|
||||
}
|
||||
|
||||
TINT_ICE(Resolver, diagnostics_) << source << " unhandled type short name '" << name << "'";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -2534,6 +2535,8 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
|
|||
case ast::IntLiteralExpression::Suffix::kU:
|
||||
return builder_->create<type::U32>();
|
||||
}
|
||||
TINT_UNREACHABLE(Resolver, builder_->Diagnostics())
|
||||
<< "Unhandled integer literal suffix: " << i->suffix;
|
||||
return nullptr;
|
||||
},
|
||||
[&](const ast::FloatLiteralExpression* f) -> type::Type* {
|
||||
|
@ -2543,21 +2546,22 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
|
|||
case ast::FloatLiteralExpression::Suffix::kF:
|
||||
return builder_->create<type::F32>();
|
||||
case ast::FloatLiteralExpression::Suffix::kH:
|
||||
return builder_->create<type::F16>();
|
||||
return validator_.CheckF16Enabled(literal->source)
|
||||
? builder_->create<type::F16>()
|
||||
: nullptr;
|
||||
}
|
||||
TINT_UNREACHABLE(Resolver, builder_->Diagnostics())
|
||||
<< "Unhandled float literal suffix: " << f->suffix;
|
||||
return nullptr;
|
||||
},
|
||||
[&](const ast::BoolLiteralExpression*) { return builder_->create<type::Bool>(); },
|
||||
[&](Default) { return nullptr; });
|
||||
[&](Default) {
|
||||
TINT_UNREACHABLE(Resolver, builder_->Diagnostics())
|
||||
<< "Unhandled literal type: " << literal->TypeInfo().name;
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
if (ty == nullptr) {
|
||||
TINT_UNREACHABLE(Resolver, builder_->Diagnostics())
|
||||
<< "Unhandled literal type: " << literal->TypeInfo().name;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if ((ty->Is<type::F16>()) && (!enabled_extensions_.Contains(tint::ast::Extension::kF16))) {
|
||||
AddError("f16 literal used without 'f16' extension enabled", literal->source);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -3161,11 +3165,11 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
|
|||
sem_members.Reserve(str->members.Length());
|
||||
|
||||
// Calculate the effective size and alignment of each field, and the overall size of the
|
||||
// structure. For size, use the size attribute if provided, otherwise use the default size for
|
||||
// the type. For alignment, use the alignment attribute if provided, otherwise use the default
|
||||
// alignment for the member type. Diagnostic errors are raised if a basic rule is violated.
|
||||
// Validation of storage-class rules requires analyzing the actual variable usage of the
|
||||
// structure, and so is performed as part of the variable validation.
|
||||
// structure. For size, use the size attribute if provided, otherwise use the default size
|
||||
// for the type. For alignment, use the alignment attribute if provided, otherwise use the
|
||||
// default alignment for the member type. Diagnostic errors are raised if a basic rule is
|
||||
// violated. Validation of storage-class rules requires analyzing the actual variable usage
|
||||
// of the structure, and so is performed as part of the variable validation.
|
||||
uint64_t struct_size = 0;
|
||||
uint64_t struct_align = 1;
|
||||
utils::Hashmap<Symbol, const ast::StructMember*, 8> member_map;
|
||||
|
|
|
@ -420,8 +420,9 @@ class Resolver {
|
|||
/// @returns true if the symbol is the name of a builtin function.
|
||||
bool IsBuiltin(Symbol) const;
|
||||
|
||||
/// @returns the builtin type alias for the given symbol
|
||||
type::Type* BuiltinTypeAlias(Symbol) const;
|
||||
/// @returns the type short-name alias for the symbol @p symbol at @p source
|
||||
/// @note: Will raise an ICE if @p symbol is not a short-name type.
|
||||
type::Type* ShortName(Symbol symbol, const Source& source) const;
|
||||
|
||||
// ArrayInitializerSig represents a unique array initializer signature.
|
||||
// It is a tuple of the array type, number of arguments provided and earliest evaluation stage.
|
||||
|
|
|
@ -2361,23 +2361,6 @@ TEST_F(ResolverTest, MaxExpressionDepth_Fail) {
|
|||
std::to_string(kMaxExpressionDepth)));
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Literal_F16WithoutExtension) {
|
||||
// fn test() {_ = 1.23h;}
|
||||
WrapInFunction(Ignore(Expr(f16(1.23f))));
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(), HasSubstr("error: f16 literal used without 'f16' extension enabled"));
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Literal_F16WithExtension) {
|
||||
// enable f16;
|
||||
// fn test() {_ = 1.23h;}
|
||||
Enable(ast::Extension::kF16);
|
||||
WrapInFunction(Ignore(Expr(f16(1.23f))));
|
||||
|
||||
EXPECT_TRUE(r()->Resolve());
|
||||
}
|
||||
|
||||
// Windows debug builds have significantly smaller stack than other builds, and these tests will
|
||||
// stack overflow.
|
||||
#if !defined(NDEBUG)
|
||||
|
|
|
@ -893,24 +893,6 @@ TEST_F(ResolverTypeValidationTest, BuiltinAsType) {
|
|||
EXPECT_EQ(r()->error(), "error: cannot use builtin 'max' as type");
|
||||
}
|
||||
|
||||
TEST_F(ResolverTypeValidationTest, F16TypeUsedWithExtension) {
|
||||
// enable f16;
|
||||
// var<private> v : f16;
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("v", ty.f16(), ast::AddressSpace::kPrivate);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverTypeValidationTest, F16TypeUsedWithoutExtension) {
|
||||
// var<private> v : f16;
|
||||
GlobalVar("v", ty.f16(), ast::AddressSpace::kPrivate);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "error: f16 used without 'f16' extension enabled");
|
||||
}
|
||||
|
||||
namespace GetCanonicalTests {
|
||||
struct Params {
|
||||
builder::ast_type_func_ptr create_ast_type;
|
||||
|
|
|
@ -1625,6 +1625,15 @@ bool Validator::RequiredExtensionForBuiltinFunction(const sem::Call* call) const
|
|||
return true;
|
||||
}
|
||||
|
||||
bool Validator::CheckF16Enabled(const Source& source) const {
|
||||
// Validate if f16 type is allowed.
|
||||
if (!enabled_extensions_.Contains(ast::Extension::kF16)) {
|
||||
AddError("f16 type used without 'f16' extension enabled", source);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Validator::FunctionCall(const sem::Call* call, sem::Statement* current_statement) const {
|
||||
auto* decl = call->Declaration();
|
||||
auto* target = call->Target()->As<sem::Function>();
|
||||
|
|
|
@ -444,6 +444,11 @@ class Validator {
|
|||
/// @returns true on success, false otherwise
|
||||
bool RequiredExtensionForBuiltinFunction(const sem::Call* call) const;
|
||||
|
||||
/// Validates that 'f16' extension is enabled for f16 usage at @p source
|
||||
/// @param source the source of the f16 usage
|
||||
/// @returns true on success, false otherwise
|
||||
bool CheckF16Enabled(const Source& source) const;
|
||||
|
||||
/// Validates there are no duplicate attributes
|
||||
/// @param attributes the list of attributes to validate
|
||||
/// @returns true on success, false otherwise.
|
||||
|
|
Loading…
Reference in New Issue