resolver: Validate increment/decrement statements

These can only be applied to scalar integer references.

These currently cannot be used in a for-loop initializer.

Bug: tint:1488
Change-Id: I218c438c573ff3f5917d058718d12603f9b4057f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/86002
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2022-04-07 13:42:45 +00:00
parent ebe9741d0c
commit 2f9e31cefb
7 changed files with 309 additions and 0 deletions

View File

@ -747,6 +747,7 @@ if(TINT_BUILD_TESTS)
resolver/entry_point_validation_test.cc resolver/entry_point_validation_test.cc
resolver/function_validation_test.cc resolver/function_validation_test.cc
resolver/host_shareable_validation_test.cc resolver/host_shareable_validation_test.cc
resolver/increment_decrement_validation_test.cc
resolver/inferred_type_test.cc resolver/inferred_type_test.cc
resolver/is_host_shareable_test.cc resolver/is_host_shareable_test.cc
resolver/is_storeable_test.cc resolver/is_storeable_test.cc

View File

@ -234,6 +234,9 @@ class DependencyScanner {
TraverseStatement(l->continuing); TraverseStatement(l->continuing);
TraverseStatement(l->body); TraverseStatement(l->body);
}, },
[&](const ast::IncrementDecrementStatement* i) {
TraverseExpression(i->lhs);
},
[&](const ast::LoopStatement* l) { [&](const ast::LoopStatement* l) {
scope_stack_.Push(); scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop()); TINT_DEFER(scope_stack_.Pop());

View File

@ -0,0 +1,234 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/resolver/resolver.h"
#include "gmock/gmock.h"
#include "src/tint/resolver/resolver_test_helper.h"
namespace tint::resolver {
namespace {
using ResolverIncrementDecrementValidationTest = ResolverTest;
TEST_F(ResolverIncrementDecrementValidationTest, Increment_Signed) {
// var a : i32 = 2;
// a++;
auto* var = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2));
WrapInFunction(var, Increment(Source{{12, 34}}, "a"));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverIncrementDecrementValidationTest, Decrement_Signed) {
// var a : i32 = 2;
// a--;
auto* var = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2));
WrapInFunction(var, Decrement(Source{{12, 34}}, "a"));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverIncrementDecrementValidationTest, Increment_Unsigned) {
// var a : u32 = 2u;
// a++;
auto* var = Var("a", ty.u32(), ast::StorageClass::kNone, Expr(2u));
WrapInFunction(var, Increment(Source{{12, 34}}, "a"));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverIncrementDecrementValidationTest, Decrement_Unsigned) {
// var a : u32 = 2u;
// a--;
auto* var = Var("a", ty.u32(), ast::StorageClass::kNone, Expr(2u));
WrapInFunction(var, Decrement(Source{{12, 34}}, "a"));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverIncrementDecrementValidationTest, ThroughPointer) {
// var a : i32;
// let b : ptr<function,i32> = &a;
// *b++;
auto* var_a = Var("a", ty.i32(), ast::StorageClass::kFunction);
auto* var_b = Const("b", ty.pointer<int>(ast::StorageClass::kFunction),
AddressOf(Expr("a")));
WrapInFunction(var_a, var_b, Increment(Source{{12, 34}}, Deref("b")));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverIncrementDecrementValidationTest, ThroughArray) {
// var a : array<i32, 4>;
// a[1]++;
auto* var_a = Var("a", ty.array(ty.i32(), 4), ast::StorageClass::kNone);
WrapInFunction(var_a, Increment(Source{{12, 34}}, IndexAccessor("a", 1)));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverIncrementDecrementValidationTest, ThroughVector_Index) {
// var a : vec4<i32>;
// a.y++;
auto* var_a = Var("a", ty.vec4(ty.i32()), ast::StorageClass::kNone);
WrapInFunction(var_a, Increment(Source{{12, 34}}, IndexAccessor("a", 1)));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverIncrementDecrementValidationTest, ThroughVector_Member) {
// var a : vec4<i32>;
// a.y++;
auto* var_a = Var("a", ty.vec4(ty.i32()), ast::StorageClass::kNone);
WrapInFunction(var_a, Increment(Source{{12, 34}}, MemberAccessor("a", "y")));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverIncrementDecrementValidationTest, Float) {
// var a : f32 = 2.0;
// a++;
auto* var = Var("a", ty.f32(), ast::StorageClass::kNone, Expr(2.f));
auto* inc = Increment(Expr(Source{{12, 34}}, "a"));
WrapInFunction(var, inc);
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: increment statement can only be applied to an "
"integer scalar");
}
TEST_F(ResolverIncrementDecrementValidationTest, Vector) {
// var a : vec4<f32>;
// a++;
auto* var = Var("a", ty.vec4<i32>(), ast::StorageClass::kNone);
auto* inc = Increment(Expr(Source{{12, 34}}, "a"));
WrapInFunction(var, inc);
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: increment statement can only be applied to an "
"integer scalar");
}
TEST_F(ResolverIncrementDecrementValidationTest, Atomic) {
// var<workgroup> a : atomic<i32>;
// a++;
Global(Source{{12, 34}}, "a", ty.atomic(ty.i32()),
ast::StorageClass::kWorkgroup);
WrapInFunction(Increment(Expr(Source{{56, 78}}, "a")));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"56:78 error: increment statement can only be applied to an "
"integer scalar");
}
TEST_F(ResolverIncrementDecrementValidationTest, Literal) {
// 1++;
WrapInFunction(Increment(Expr(Source{{56, 78}}, 1)));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "56:78 error: cannot modify value of type 'i32'");
}
TEST_F(ResolverIncrementDecrementValidationTest, Constant) {
// let a = 1;
// a++;
auto* a = Const(Source{{12, 34}}, "a", nullptr, Expr(1));
WrapInFunction(a, Increment(Expr(Source{{56, 78}}, "a")));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(56:78 error: cannot modify constant value
12:34 note: 'a' is declared here:)");
}
TEST_F(ResolverIncrementDecrementValidationTest, Parameter) {
// fn func(a : i32)
// {
// a++;
// }
auto* a = Param(Source{{12, 34}}, "a", ty.i32());
Func("func", {a}, ty.void_(), {Increment(Expr(Source{{56, 78}}, "a"))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(56:78 error: cannot modify function parameter
12:34 note: 'a' is declared here:)");
}
TEST_F(ResolverIncrementDecrementValidationTest, ReturnValue) {
// fn func() -> i32 {
// return 0;
// }
// {
// a++;
// }
Func("func", {}, ty.i32(), {Return(0)});
WrapInFunction(Increment(Call(Source{{56, 78}}, "func")));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(56:78 error: cannot modify value of type 'i32')");
}
TEST_F(ResolverIncrementDecrementValidationTest, ReadOnlyBuffer) {
// @group(0) @binding(0) var<storage,read> a : i32;
// {
// a++;
// }
Global(Source{{12, 34}}, "a", ty.i32(), ast::StorageClass::kStorage,
ast::Access::kRead, GroupAndBinding(0, 0));
WrapInFunction(Increment(Source{{56, 78}}, "a"));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"56:78 error: cannot modify read-only type 'ref<storage, i32, read>'");
}
TEST_F(ResolverIncrementDecrementValidationTest, Phony) {
// _++;
WrapInFunction(Increment(Phony(Source{{56, 78}})));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "56:78 error: cannot modify value of type 'void'");
}
TEST_F(ResolverIncrementDecrementValidationTest, InForLoopInit) {
// var a : i32 = 2;
// for (a++; ; ) {
// break;
// }
auto* a = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2));
auto* loop =
For(Increment(Source{{56, 78}}, "a"), nullptr, nullptr, Block(Break()));
WrapInFunction(a, loop);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverIncrementDecrementValidationTest, InForLoopCont) {
// var a : i32 = 2;
// for (; ; a++) {
// break;
// }
auto* a = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2));
auto* loop =
For(nullptr, nullptr, Increment(Source{{56, 78}}, "a"), Block(Break()));
WrapInFunction(a, loop);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
} // namespace
} // namespace tint::resolver

View File

@ -868,6 +868,9 @@ sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
[&](const ast::FallthroughStatement* f) { [&](const ast::FallthroughStatement* f) {
return FallthroughStatement(f); return FallthroughStatement(f);
}, },
[&](const ast::IncrementDecrementStatement* i) {
return IncrementDecrementStatement(i);
},
[&](const ast::ReturnStatement* r) { return ReturnStatement(r); }, [&](const ast::ReturnStatement* r) { return ReturnStatement(r); },
[&](const ast::VariableDeclStatement* v) { [&](const ast::VariableDeclStatement* v) {
return VariableDeclStatement(v); return VariableDeclStatement(v);
@ -2685,6 +2688,21 @@ sem::Statement* Resolver::FallthroughStatement(
}); });
} }
sem::Statement* Resolver::IncrementDecrementStatement(
const ast::IncrementDecrementStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
auto* lhs = Expression(stmt->lhs);
if (!lhs) {
return false;
}
sem->Behaviors() = lhs->Behaviors();
return ValidateIncrementDecrementStatement(stmt);
});
}
bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
sem::Type* ty, sem::Type* ty,
const Source& usage) { const Source& usage) {

View File

@ -224,6 +224,8 @@ class Resolver {
sem::GlobalVariable* GlobalVariable(const ast::Variable*); sem::GlobalVariable* GlobalVariable(const ast::Variable*);
sem::Statement* Parameter(const ast::Variable*); sem::Statement* Parameter(const ast::Variable*);
sem::IfStatement* IfStatement(const ast::IfStatement*); sem::IfStatement* IfStatement(const ast::IfStatement*);
sem::Statement* IncrementDecrementStatement(
const ast::IncrementDecrementStatement*);
sem::LoopStatement* LoopStatement(const ast::LoopStatement*); sem::LoopStatement* LoopStatement(const ast::LoopStatement*);
sem::Statement* ReturnStatement(const ast::ReturnStatement*); sem::Statement* ReturnStatement(const ast::ReturnStatement*);
sem::Statement* Statement(const ast::Statement*); sem::Statement* Statement(const ast::Statement*);
@ -263,6 +265,8 @@ class Resolver {
bool ValidateFunctionCall(const sem::Call* call); bool ValidateFunctionCall(const sem::Call* call);
bool ValidateGlobalVariable(const sem::Variable* var); bool ValidateGlobalVariable(const sem::Variable* var);
bool ValidateIfStatement(const sem::IfStatement* stmt); bool ValidateIfStatement(const sem::IfStatement* stmt);
bool ValidateIncrementDecrementStatement(
const ast::IncrementDecrementStatement* stmt);
bool ValidateInterpolateAttribute(const ast::InterpolateAttribute* attr, bool ValidateInterpolateAttribute(const ast::InterpolateAttribute* attr,
const sem::Type* storage_type); const sem::Type* storage_type);
bool ValidateBuiltinCall(const sem::Call* call); bool ValidateBuiltinCall(const sem::Call* call);

View File

@ -2347,6 +2347,54 @@ bool Resolver::ValidateAssignment(const ast::Statement* a,
return true; return true;
} }
bool Resolver::ValidateIncrementDecrementStatement(
const ast::IncrementDecrementStatement* inc) {
const ast::Expression* lhs = inc->lhs;
// https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement
if (auto* var = ResolvedSymbol<sem::Variable>(lhs)) {
auto* decl = var->Declaration();
if (var->Is<sem::Parameter>()) {
AddError("cannot modify function parameter", lhs->source);
AddNote("'" + builder_->Symbols().NameFor(decl->symbol) +
"' is declared here:",
decl->source);
return false;
}
if (decl->is_const) {
AddError("cannot modify constant value", lhs->source);
AddNote("'" + builder_->Symbols().NameFor(decl->symbol) +
"' is declared here:",
decl->source);
return false;
}
}
auto const* lhs_ty = TypeOf(lhs);
auto* lhs_ref = lhs_ty->As<sem::Reference>();
if (!lhs_ref) {
// LHS is not a reference, so it has no storage.
AddError("cannot modify value of type '" + TypeNameOf(lhs_ty) + "'",
lhs->source);
return false;
}
if (!lhs_ref->StoreType()->is_integer_scalar()) {
const std::string kind = inc->increment ? "increment" : "decrement";
AddError(kind + " statement can only be applied to an integer scalar",
lhs->source);
return false;
}
if (lhs_ref->Access() == ast::Access::kRead) {
AddError("cannot modify read-only type '" + RawTypeNameOf(lhs_ty) + "'",
inc->source);
return false;
}
return true;
}
bool Resolver::ValidateNoDuplicateAttributes( bool Resolver::ValidateNoDuplicateAttributes(
const ast::AttributeList& attributes) { const ast::AttributeList& attributes) {
std::unordered_map<const TypeInfo*, Source> seen; std::unordered_map<const TypeInfo*, Source> seen;

View File

@ -253,6 +253,7 @@ tint_unittests_source_set("tint_unittests_resolver_src") {
"../../src/tint/resolver/entry_point_validation_test.cc", "../../src/tint/resolver/entry_point_validation_test.cc",
"../../src/tint/resolver/function_validation_test.cc", "../../src/tint/resolver/function_validation_test.cc",
"../../src/tint/resolver/host_shareable_validation_test.cc", "../../src/tint/resolver/host_shareable_validation_test.cc",
"../../src/tint/resolver/increment_decrement_validation_test.cc",
"../../src/tint/resolver/is_host_shareable_test.cc", "../../src/tint/resolver/is_host_shareable_test.cc",
"../../src/tint/resolver/is_storeable_test.cc", "../../src/tint/resolver/is_storeable_test.cc",
"../../src/tint/resolver/pipeline_overridable_constant_test.cc", "../../src/tint/resolver/pipeline_overridable_constant_test.cc",