resolver: Fixes for bitcasts

Fix dependency graph traversal for bitcasts. These were not being traversed, leading to an ICE if the bitcast type was an alias, as the symbol was not resolved for later use by the resolver.

Add missing validation for bitcasts. We were permitting any bitcast that wasn't a being cast to a pointer type, when the spec only allows:
 * numeric_scalar to numeric_scalar
 * vecN<numeric_scalar> to vecN<numeric_scalar>

Add lots of tests.

Fixed: chromium:1276320
Change-Id: I9e5487ec7649ac543f73fc878e7e282bf932d8cb
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/71681
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Ben Clayton 2021-12-03 21:29:13 +00:00 committed by Tint LUCI CQ
parent b9e8a0b87d
commit c830130bb8
12 changed files with 312 additions and 78 deletions

View File

@ -673,6 +673,7 @@ if(${TINT_BUILD_TESTS})
resolver/assignment_validation_test.cc
resolver/atomics_test.cc
resolver/atomics_validation_test.cc
resolver/bitcast_validation_test.cc
resolver/builtins_validation_test.cc
resolver/call_test.cc
resolver/call_validation_test.cc

View File

@ -0,0 +1,228 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/ast/bitcast_expression.h"
#include "src/resolver/resolver.h"
#include "src/resolver/resolver_test_helper.h"
#include "gmock/gmock.h"
namespace tint {
namespace resolver {
namespace {
struct Type {
template <typename T>
static constexpr Type Create() {
return Type{builder::DataType<T>::AST, builder::DataType<T>::Sem,
builder::DataType<T>::Expr};
}
builder::ast_type_func_ptr ast;
builder::sem_type_func_ptr sem;
builder::ast_expr_func_ptr expr;
};
static constexpr Type kNumericScalars[] = {
Type::Create<builder::f32>(),
Type::Create<builder::i32>(),
Type::Create<builder::u32>(),
};
static constexpr Type kVec2NumericScalars[] = {
Type::Create<builder::vec2<builder::f32>>(),
Type::Create<builder::vec2<builder::i32>>(),
Type::Create<builder::vec2<builder::u32>>(),
};
static constexpr Type kVec3NumericScalars[] = {
Type::Create<builder::vec3<builder::f32>>(),
Type::Create<builder::vec3<builder::i32>>(),
Type::Create<builder::vec3<builder::u32>>(),
};
static constexpr Type kVec4NumericScalars[] = {
Type::Create<builder::vec4<builder::f32>>(),
Type::Create<builder::vec4<builder::i32>>(),
Type::Create<builder::vec4<builder::u32>>(),
};
static constexpr Type kInvalid[] = {
// A non-exhaustive selection of uncastable types
Type::Create<bool>(),
Type::Create<builder::vec2<bool>>(),
Type::Create<builder::vec3<bool>>(),
Type::Create<builder::vec4<bool>>(),
Type::Create<builder::array<2, builder::i32>>(),
Type::Create<builder::array<3, builder::u32>>(),
Type::Create<builder::array<4, builder::f32>>(),
Type::Create<builder::array<5, bool>>(),
Type::Create<builder::mat2x2<builder::f32>>(),
Type::Create<builder::mat3x3<builder::f32>>(),
Type::Create<builder::mat4x4<builder::f32>>(),
Type::Create<builder::ptr<builder::i32>>(),
Type::Create<builder::ptr<builder::array<2, builder::i32>>>(),
Type::Create<builder::ptr<builder::mat2x2<builder::f32>>>(),
};
using ResolverBitcastValidationTest =
ResolverTestWithParam<std::tuple<Type, Type>>;
////////////////////////////////////////////////////////////////////////////////
// Valid bitcasts
////////////////////////////////////////////////////////////////////////////////
using ResolverBitcastValidationTestPass = ResolverBitcastValidationTest;
TEST_P(ResolverBitcastValidationTestPass, Test) {
auto src = std::get<0>(GetParam());
auto dst = std::get<1>(GetParam());
auto* cast = Bitcast(dst.ast(*this), src.expr(*this, 0));
WrapInFunction(cast);
ASSERT_TRUE(r()->Resolve()) << r()->error();
EXPECT_EQ(TypeOf(cast), dst.sem(*this));
}
INSTANTIATE_TEST_SUITE_P(Scalars,
ResolverBitcastValidationTestPass,
testing::Combine(testing::ValuesIn(kNumericScalars),
testing::ValuesIn(kNumericScalars)));
INSTANTIATE_TEST_SUITE_P(
Vec2,
ResolverBitcastValidationTestPass,
testing::Combine(testing::ValuesIn(kVec2NumericScalars),
testing::ValuesIn(kVec2NumericScalars)));
INSTANTIATE_TEST_SUITE_P(
Vec3,
ResolverBitcastValidationTestPass,
testing::Combine(testing::ValuesIn(kVec3NumericScalars),
testing::ValuesIn(kVec3NumericScalars)));
INSTANTIATE_TEST_SUITE_P(
Vec4,
ResolverBitcastValidationTestPass,
testing::Combine(testing::ValuesIn(kVec4NumericScalars),
testing::ValuesIn(kVec4NumericScalars)));
////////////////////////////////////////////////////////////////////////////////
// Invalid source type for bitcasts
////////////////////////////////////////////////////////////////////////////////
using ResolverBitcastValidationTestInvalidSrcTy = ResolverBitcastValidationTest;
TEST_P(ResolverBitcastValidationTestInvalidSrcTy, Test) {
auto src = std::get<0>(GetParam());
auto dst = std::get<1>(GetParam());
auto* cast = Bitcast(dst.ast(*this), Expr(Source{{12, 34}}, "src"));
WrapInFunction(Const("src", nullptr, src.expr(*this, 0)), cast);
auto expected = "12:34 error: '" + src.sem(*this)->FriendlyName(Symbols()) +
"' cannot be bitcast";
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), expected);
}
INSTANTIATE_TEST_SUITE_P(Scalars,
ResolverBitcastValidationTestInvalidSrcTy,
testing::Combine(testing::ValuesIn(kInvalid),
testing::ValuesIn(kNumericScalars)));
INSTANTIATE_TEST_SUITE_P(
Vec2,
ResolverBitcastValidationTestInvalidSrcTy,
testing::Combine(testing::ValuesIn(kInvalid),
testing::ValuesIn(kVec2NumericScalars)));
INSTANTIATE_TEST_SUITE_P(
Vec3,
ResolverBitcastValidationTestInvalidSrcTy,
testing::Combine(testing::ValuesIn(kInvalid),
testing::ValuesIn(kVec3NumericScalars)));
INSTANTIATE_TEST_SUITE_P(
Vec4,
ResolverBitcastValidationTestInvalidSrcTy,
testing::Combine(testing::ValuesIn(kInvalid),
testing::ValuesIn(kVec4NumericScalars)));
////////////////////////////////////////////////////////////////////////////////
// Invalid target type for bitcasts
////////////////////////////////////////////////////////////////////////////////
using ResolverBitcastValidationTestInvalidDstTy = ResolverBitcastValidationTest;
TEST_P(ResolverBitcastValidationTestInvalidDstTy, Test) {
auto src = std::get<0>(GetParam());
auto dst = std::get<1>(GetParam());
// Use an alias so we can put a Source on the bitcast type
Alias("T", dst.ast(*this));
WrapInFunction(
Bitcast(ty.type_name(Source{{12, 34}}, "T"), src.expr(*this, 0)));
auto expected = "12:34 error: cannot bitcast to '" +
dst.sem(*this)->FriendlyName(Symbols()) + "'";
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), expected);
}
INSTANTIATE_TEST_SUITE_P(Scalars,
ResolverBitcastValidationTestInvalidDstTy,
testing::Combine(testing::ValuesIn(kNumericScalars),
testing::ValuesIn(kInvalid)));
INSTANTIATE_TEST_SUITE_P(
Vec2,
ResolverBitcastValidationTestInvalidDstTy,
testing::Combine(testing::ValuesIn(kVec2NumericScalars),
testing::ValuesIn(kInvalid)));
INSTANTIATE_TEST_SUITE_P(
Vec3,
ResolverBitcastValidationTestInvalidDstTy,
testing::Combine(testing::ValuesIn(kVec3NumericScalars),
testing::ValuesIn(kInvalid)));
INSTANTIATE_TEST_SUITE_P(
Vec4,
ResolverBitcastValidationTestInvalidDstTy,
testing::Combine(testing::ValuesIn(kVec4NumericScalars),
testing::ValuesIn(kInvalid)));
////////////////////////////////////////////////////////////////////////////////
// Incompatible bitcast, but both src and dst types are valid
////////////////////////////////////////////////////////////////////////////////
using ResolverBitcastValidationTestIncompatible = ResolverBitcastValidationTest;
TEST_P(ResolverBitcastValidationTestIncompatible, Test) {
auto src = std::get<0>(GetParam());
auto dst = std::get<1>(GetParam());
WrapInFunction(Bitcast(Source{{12, 34}}, dst.ast(*this), src.expr(*this, 0)));
auto expected = "12:34 error: cannot bitcast from '" +
src.sem(*this)->FriendlyName(Symbols()) + "' to '" +
dst.sem(*this)->FriendlyName(Symbols()) + "'";
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), expected);
}
INSTANTIATE_TEST_SUITE_P(
ScalarToVec2,
ResolverBitcastValidationTestIncompatible,
testing::Combine(testing::ValuesIn(kNumericScalars),
testing::ValuesIn(kVec2NumericScalars)));
INSTANTIATE_TEST_SUITE_P(
Vec2ToVec3,
ResolverBitcastValidationTestIncompatible,
testing::Combine(testing::ValuesIn(kVec2NumericScalars),
testing::ValuesIn(kVec3NumericScalars)));
INSTANTIATE_TEST_SUITE_P(
Vec3ToVec4,
ResolverBitcastValidationTestIncompatible,
testing::Combine(testing::ValuesIn(kVec3NumericScalars),
testing::ValuesIn(kVec4NumericScalars)));
INSTANTIATE_TEST_SUITE_P(
Vec4ToScalar,
ResolverBitcastValidationTestIncompatible,
testing::Combine(testing::ValuesIn(kVec4NumericScalars),
testing::ValuesIn(kNumericScalars)));
} // namespace
} // namespace resolver
} // namespace tint

View File

@ -329,6 +329,9 @@ class DependencyScanner {
utils::Lookup(graph_.resolved_symbols, call->target.type));
}
}
if (auto* cast = expr->As<ast::BitcastExpression>()) {
TraverseType(cast->type);
}
return ast::TraverseAction::Descend;
});
}

View File

@ -1282,6 +1282,7 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
Block(Assign(V, V)), //
Else(V, //
Block(Assign(V, V)))), //
Ignore(Bitcast(T, V)), //
For(Decl(Var(Sym(), T, V)), //
Equal(V, V), //
Assign(V, V), //

View File

@ -171,19 +171,6 @@ TEST_F(ResolverPtrRefValidationTest, InferredPtrAccessMismatch) {
"'ptr<storage, i32, read_write>'");
}
TEST_F(ResolverTest, Expr_Bitcast_ptr) {
auto* vf = Var("vf", ty.f32());
auto* bitcast = create<ast::BitcastExpression>(
Source{{12, 34}}, ty.pointer<i32>(ast::StorageClass::kFunction),
Expr("vf"));
auto* ip =
Const("ip", ty.pointer<i32>(ast::StorageClass::kFunction), bitcast);
WrapInFunction(Decl(vf), Decl(ip));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: cannot cast to a pointer");
}
} // namespace
} // namespace resolver
} // namespace tint

View File

@ -1221,15 +1221,17 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
if (!ty) {
return nullptr;
}
if (ty->Is<sem::Pointer>()) {
AddError("cannot cast to a pointer", expr->source);
return nullptr;
}
auto val = EvaluateConstantValue(expr, ty);
auto* sem =
builder_->create<sem::Expression>(expr, ty, current_statement_, val);
sem->Behaviors() = inner->Behaviors();
if (!ValidateBitcast(expr, ty)) {
return nullptr;
}
return sem;
}

View File

@ -238,6 +238,7 @@ class Resolver {
bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s);
bool ValidateAtomicVariable(const sem::Variable* var);
bool ValidateAssignment(const ast::AssignmentStatement* a);
bool ValidateBitcast(const ast::BitcastExpression* cast, const sem::Type* to);
bool ValidateBreakStatement(const sem::Statement* stmt);
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type,

View File

@ -171,6 +171,9 @@ using alias2 = alias<TO, 2>;
template <typename TO>
using alias3 = alias<TO, 3>;
template <typename TO>
struct ptr {};
using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b,
int elem_value);
@ -387,6 +390,36 @@ struct DataType<alias<T, ID>> {
}
};
/// Helper for building pointer types and expressions
template <typename T>
struct DataType<ptr<T>> {
/// true if the pointer type is a composite type
static constexpr bool is_composite = false;
/// @param b the ProgramBuilder
/// @return a new AST alias type
static inline const ast::Type* AST(ProgramBuilder& b) {
return b.create<ast::Pointer>(DataType<T>::AST(b),
ast::StorageClass::kPrivate,
ast::Access::kReadWrite);
}
/// @param b the ProgramBuilder
/// @return the semantic aliased type
static inline const sem::Type* Sem(ProgramBuilder& b) {
return b.create<sem::Pointer>(DataType<T>::Sem(b),
ast::StorageClass::kPrivate,
ast::Access::kReadWrite);
}
/// @param b the ProgramBuilder
/// @return a new AST expression of the alias type
static inline const ast::Expression* Expr(ProgramBuilder& b, int /*unused*/) {
auto sym = b.Symbols().New("global_for_ptr");
b.Global(sym, DataType<T>::AST(b), ast::StorageClass::kPrivate);
return b.AddressOf(sym);
}
};
/// Helper for building array types and expressions
template <int N, typename T>
struct DataType<array<N, T>> {
@ -401,7 +434,14 @@ struct DataType<array<N, T>> {
/// @param b the ProgramBuilder
/// @return the semantic array type
static inline const sem::Type* Sem(ProgramBuilder& b) {
return b.create<sem::Array>(DataType<T>::Sem(b), N);
auto* el = DataType<T>::Sem(b);
return b.create<sem::Array>(
/* element */ el,
/* count */ N,
/* align */ el->Align(),
/* size */ el->Size(),
/* stride */ el->Align(),
/* implicit_stride */ el->Align());
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element in the array will be initialized

View File

@ -1347,6 +1347,36 @@ bool Resolver::ValidateStatements(const ast::StatementList& stmts) {
return true;
}
bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast,
const sem::Type* to) {
auto* from = TypeOf(cast->expr)->UnwrapRef();
if (!from->is_numeric_scalar_or_vector()) {
AddError("'" + TypeNameOf(from) + "' cannot be bitcast",
cast->expr->source);
return false;
}
if (!to->is_numeric_scalar_or_vector()) {
AddError("cannot bitcast to '" + TypeNameOf(to) + "'", cast->type->source);
return false;
}
auto width = [&](const sem::Type* ty) {
if (auto* vec = ty->As<sem::Vector>()) {
return vec->Width();
}
return 1u;
};
if (width(from) != width(to)) {
AddError("cannot bitcast from '" + TypeNameOf(from) + "' to '" +
TypeNameOf(to) + "'",
cast->source);
return false;
}
return true;
}
bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
if (!stmt->FindFirstParent<sem::LoopBlockStatement, sem::CaseStatement>()) {
AddError("break statement must be in a loop or switch case",

View File

@ -452,36 +452,6 @@ bool a = (tint_tmp);
)");
}
TEST_F(GlslGeneratorImplTest_Binary, Bitcast_WithLogical) {
// as<i32>(a && (b || c))
Global("a", ty.bool_(), ast::StorageClass::kPrivate);
Global("b", ty.bool_(), ast::StorageClass::kPrivate);
Global("c", ty.bool_(), ast::StorageClass::kPrivate);
auto* expr = create<ast::BitcastExpression>(
ty.i32(), create<ast::BinaryExpression>(
ast::BinaryOp::kLogicalAnd, Expr("a"),
create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
Expr("b"), Expr("c"))));
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
if (tint_tmp) {
bool tint_tmp_1 = b;
if (!tint_tmp_1) {
tint_tmp_1 = c;
}
tint_tmp = (tint_tmp_1);
}
)");
EXPECT_EQ(out.str(), R"(int((tint_tmp)))");
}
TEST_F(GlslGeneratorImplTest_Binary, Call_WithLogical) {
// foo(a && b, c || d, (a || c) && (b || d))

View File

@ -452,36 +452,6 @@ bool a = (tint_tmp);
)");
}
TEST_F(HlslGeneratorImplTest_Binary, Bitcast_WithLogical) {
// as<i32>(a && (b || c))
Global("a", ty.bool_(), ast::StorageClass::kPrivate);
Global("b", ty.bool_(), ast::StorageClass::kPrivate);
Global("c", ty.bool_(), ast::StorageClass::kPrivate);
auto* expr = create<ast::BitcastExpression>(
ty.i32(), create<ast::BinaryExpression>(
ast::BinaryOp::kLogicalAnd, Expr("a"),
create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
Expr("b"), Expr("c"))));
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
if (tint_tmp) {
bool tint_tmp_1 = b;
if (!tint_tmp_1) {
tint_tmp_1 = c;
}
tint_tmp = (tint_tmp_1);
}
)");
EXPECT_EQ(out.str(), R"(asint((tint_tmp)))");
}
TEST_F(HlslGeneratorImplTest_Binary, Call_WithLogical) {
// foo(a && b, c || d, (a || c) && (b || d))

View File

@ -236,6 +236,7 @@ tint_unittests_source_set("tint_unittests_resolver_src") {
"../src/resolver/assignment_validation_test.cc",
"../src/resolver/atomics_test.cc",
"../src/resolver/atomics_validation_test.cc",
"../src/resolver/bitcast_validation_test.cc",
"../src/resolver/builtins_validation_test.cc",
"../src/resolver/call_test.cc",
"../src/resolver/call_validation_test.cc",