diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index ae9ef18beb..88ca0ea1ea 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -747,6 +747,7 @@ if(TINT_BUILD_TESTS) resolver/entry_point_validation_test.cc resolver/function_validation_test.cc resolver/host_shareable_validation_test.cc + resolver/increment_decrement_validation_test.cc resolver/inferred_type_test.cc resolver/is_host_shareable_test.cc resolver/is_storeable_test.cc diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc index 8a0db8bacd..be83f7391f 100644 --- a/src/tint/resolver/dependency_graph.cc +++ b/src/tint/resolver/dependency_graph.cc @@ -234,6 +234,9 @@ class DependencyScanner { TraverseStatement(l->continuing); TraverseStatement(l->body); }, + [&](const ast::IncrementDecrementStatement* i) { + TraverseExpression(i->lhs); + }, [&](const ast::LoopStatement* l) { scope_stack_.Push(); TINT_DEFER(scope_stack_.Pop()); diff --git a/src/tint/resolver/increment_decrement_validation_test.cc b/src/tint/resolver/increment_decrement_validation_test.cc new file mode 100644 index 0000000000..d97facf448 --- /dev/null +++ b/src/tint/resolver/increment_decrement_validation_test.cc @@ -0,0 +1,234 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/resolver/resolver.h" + +#include "gmock/gmock.h" +#include "src/tint/resolver/resolver_test_helper.h" + +namespace tint::resolver { +namespace { + +using ResolverIncrementDecrementValidationTest = ResolverTest; + +TEST_F(ResolverIncrementDecrementValidationTest, Increment_Signed) { + // var a : i32 = 2; + // a++; + auto* var = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2)); + WrapInFunction(var, Increment(Source{{12, 34}}, "a")); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverIncrementDecrementValidationTest, Decrement_Signed) { + // var a : i32 = 2; + // a--; + auto* var = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2)); + WrapInFunction(var, Decrement(Source{{12, 34}}, "a")); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverIncrementDecrementValidationTest, Increment_Unsigned) { + // var a : u32 = 2u; + // a++; + auto* var = Var("a", ty.u32(), ast::StorageClass::kNone, Expr(2u)); + WrapInFunction(var, Increment(Source{{12, 34}}, "a")); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverIncrementDecrementValidationTest, Decrement_Unsigned) { + // var a : u32 = 2u; + // a--; + auto* var = Var("a", ty.u32(), ast::StorageClass::kNone, Expr(2u)); + WrapInFunction(var, Decrement(Source{{12, 34}}, "a")); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverIncrementDecrementValidationTest, ThroughPointer) { + // var a : i32; + // let b : ptr = &a; + // *b++; + auto* var_a = Var("a", ty.i32(), ast::StorageClass::kFunction); + auto* var_b = Const("b", ty.pointer(ast::StorageClass::kFunction), + AddressOf(Expr("a"))); + WrapInFunction(var_a, var_b, Increment(Source{{12, 34}}, Deref("b"))); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverIncrementDecrementValidationTest, ThroughArray) { + // var a : array; + // a[1]++; + auto* var_a = Var("a", ty.array(ty.i32(), 4), ast::StorageClass::kNone); + WrapInFunction(var_a, Increment(Source{{12, 34}}, IndexAccessor("a", 1))); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverIncrementDecrementValidationTest, ThroughVector_Index) { + // var a : vec4; + // a.y++; + auto* var_a = Var("a", ty.vec4(ty.i32()), ast::StorageClass::kNone); + WrapInFunction(var_a, Increment(Source{{12, 34}}, IndexAccessor("a", 1))); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverIncrementDecrementValidationTest, ThroughVector_Member) { + // var a : vec4; + // a.y++; + auto* var_a = Var("a", ty.vec4(ty.i32()), ast::StorageClass::kNone); + WrapInFunction(var_a, Increment(Source{{12, 34}}, MemberAccessor("a", "y"))); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverIncrementDecrementValidationTest, Float) { + // var a : f32 = 2.0; + // a++; + auto* var = Var("a", ty.f32(), ast::StorageClass::kNone, Expr(2.f)); + auto* inc = Increment(Expr(Source{{12, 34}}, "a")); + WrapInFunction(var, inc); + + ASSERT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: increment statement can only be applied to an " + "integer scalar"); +} + +TEST_F(ResolverIncrementDecrementValidationTest, Vector) { + // var a : vec4; + // a++; + auto* var = Var("a", ty.vec4(), ast::StorageClass::kNone); + auto* inc = Increment(Expr(Source{{12, 34}}, "a")); + WrapInFunction(var, inc); + + ASSERT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: increment statement can only be applied to an " + "integer scalar"); +} + +TEST_F(ResolverIncrementDecrementValidationTest, Atomic) { + // var a : atomic; + // a++; + Global(Source{{12, 34}}, "a", ty.atomic(ty.i32()), + ast::StorageClass::kWorkgroup); + WrapInFunction(Increment(Expr(Source{{56, 78}}, "a"))); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "56:78 error: increment statement can only be applied to an " + "integer scalar"); +} + +TEST_F(ResolverIncrementDecrementValidationTest, Literal) { + // 1++; + WrapInFunction(Increment(Expr(Source{{56, 78}}, 1))); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "56:78 error: cannot modify value of type 'i32'"); +} + +TEST_F(ResolverIncrementDecrementValidationTest, Constant) { + // let a = 1; + // a++; + auto* a = Const(Source{{12, 34}}, "a", nullptr, Expr(1)); + WrapInFunction(a, Increment(Expr(Source{{56, 78}}, "a"))); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(56:78 error: cannot modify constant value +12:34 note: 'a' is declared here:)"); +} + +TEST_F(ResolverIncrementDecrementValidationTest, Parameter) { + // fn func(a : i32) + // { + // a++; + // } + auto* a = Param(Source{{12, 34}}, "a", ty.i32()); + Func("func", {a}, ty.void_(), {Increment(Expr(Source{{56, 78}}, "a"))}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(56:78 error: cannot modify function parameter +12:34 note: 'a' is declared here:)"); +} + +TEST_F(ResolverIncrementDecrementValidationTest, ReturnValue) { + // fn func() -> i32 { + // return 0; + // } + // { + // a++; + // } + Func("func", {}, ty.i32(), {Return(0)}); + WrapInFunction(Increment(Call(Source{{56, 78}}, "func"))); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(56:78 error: cannot modify value of type 'i32')"); +} + +TEST_F(ResolverIncrementDecrementValidationTest, ReadOnlyBuffer) { + // @group(0) @binding(0) var a : i32; + // { + // a++; + // } + Global(Source{{12, 34}}, "a", ty.i32(), ast::StorageClass::kStorage, + ast::Access::kRead, GroupAndBinding(0, 0)); + WrapInFunction(Increment(Source{{56, 78}}, "a")); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + "56:78 error: cannot modify read-only type 'ref'"); +} + +TEST_F(ResolverIncrementDecrementValidationTest, Phony) { + // _++; + WrapInFunction(Increment(Phony(Source{{56, 78}}))); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "56:78 error: cannot modify value of type 'void'"); +} + +TEST_F(ResolverIncrementDecrementValidationTest, InForLoopInit) { + // var a : i32 = 2; + // for (a++; ; ) { + // break; + // } + auto* a = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2)); + auto* loop = + For(Increment(Source{{56, 78}}, "a"), nullptr, nullptr, Block(Break())); + WrapInFunction(a, loop); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverIncrementDecrementValidationTest, InForLoopCont) { + // var a : i32 = 2; + // for (; ; a++) { + // break; + // } + auto* a = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2)); + auto* loop = + For(nullptr, nullptr, Increment(Source{{56, 78}}, "a"), Block(Break())); + WrapInFunction(a, loop); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +} // namespace +} // namespace tint::resolver diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index f6797b19c1..befbd8b99a 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -868,6 +868,9 @@ sem::Statement* Resolver::Statement(const ast::Statement* stmt) { [&](const ast::FallthroughStatement* f) { return FallthroughStatement(f); }, + [&](const ast::IncrementDecrementStatement* i) { + return IncrementDecrementStatement(i); + }, [&](const ast::ReturnStatement* r) { return ReturnStatement(r); }, [&](const ast::VariableDeclStatement* v) { return VariableDeclStatement(v); @@ -2685,6 +2688,21 @@ sem::Statement* Resolver::FallthroughStatement( }); } +sem::Statement* Resolver::IncrementDecrementStatement( + const ast::IncrementDecrementStatement* stmt) { + auto* sem = builder_->create( + stmt, current_compound_statement_, current_function_); + return StatementScope(stmt, sem, [&] { + auto* lhs = Expression(stmt->lhs); + if (!lhs) { + return false; + } + sem->Behaviors() = lhs->Behaviors(); + + return ValidateIncrementDecrementStatement(stmt); + }); +} + bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, sem::Type* ty, const Source& usage) { diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index d5d1632e85..e257682ff9 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -224,6 +224,8 @@ class Resolver { sem::GlobalVariable* GlobalVariable(const ast::Variable*); sem::Statement* Parameter(const ast::Variable*); sem::IfStatement* IfStatement(const ast::IfStatement*); + sem::Statement* IncrementDecrementStatement( + const ast::IncrementDecrementStatement*); sem::LoopStatement* LoopStatement(const ast::LoopStatement*); sem::Statement* ReturnStatement(const ast::ReturnStatement*); sem::Statement* Statement(const ast::Statement*); @@ -263,6 +265,8 @@ class Resolver { bool ValidateFunctionCall(const sem::Call* call); bool ValidateGlobalVariable(const sem::Variable* var); bool ValidateIfStatement(const sem::IfStatement* stmt); + bool ValidateIncrementDecrementStatement( + const ast::IncrementDecrementStatement* stmt); bool ValidateInterpolateAttribute(const ast::InterpolateAttribute* attr, const sem::Type* storage_type); bool ValidateBuiltinCall(const sem::Call* call); diff --git a/src/tint/resolver/resolver_validation.cc b/src/tint/resolver/resolver_validation.cc index 02ec786872..fb018d64b0 100644 --- a/src/tint/resolver/resolver_validation.cc +++ b/src/tint/resolver/resolver_validation.cc @@ -2347,6 +2347,54 @@ bool Resolver::ValidateAssignment(const ast::Statement* a, return true; } +bool Resolver::ValidateIncrementDecrementStatement( + const ast::IncrementDecrementStatement* inc) { + const ast::Expression* lhs = inc->lhs; + + // https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement + + if (auto* var = ResolvedSymbol(lhs)) { + auto* decl = var->Declaration(); + if (var->Is()) { + AddError("cannot modify function parameter", lhs->source); + AddNote("'" + builder_->Symbols().NameFor(decl->symbol) + + "' is declared here:", + decl->source); + return false; + } + if (decl->is_const) { + AddError("cannot modify constant value", lhs->source); + AddNote("'" + builder_->Symbols().NameFor(decl->symbol) + + "' is declared here:", + decl->source); + return false; + } + } + + auto const* lhs_ty = TypeOf(lhs); + auto* lhs_ref = lhs_ty->As(); + if (!lhs_ref) { + // LHS is not a reference, so it has no storage. + AddError("cannot modify value of type '" + TypeNameOf(lhs_ty) + "'", + lhs->source); + return false; + } + + if (!lhs_ref->StoreType()->is_integer_scalar()) { + const std::string kind = inc->increment ? "increment" : "decrement"; + AddError(kind + " statement can only be applied to an integer scalar", + lhs->source); + return false; + } + + if (lhs_ref->Access() == ast::Access::kRead) { + AddError("cannot modify read-only type '" + RawTypeNameOf(lhs_ty) + "'", + inc->source); + return false; + } + return true; +} + bool Resolver::ValidateNoDuplicateAttributes( const ast::AttributeList& attributes) { std::unordered_map seen; diff --git a/test/tint/BUILD.gn b/test/tint/BUILD.gn index 7cbdff782d..fa835d63e3 100644 --- a/test/tint/BUILD.gn +++ b/test/tint/BUILD.gn @@ -253,6 +253,7 @@ tint_unittests_source_set("tint_unittests_resolver_src") { "../../src/tint/resolver/entry_point_validation_test.cc", "../../src/tint/resolver/function_validation_test.cc", "../../src/tint/resolver/host_shareable_validation_test.cc", + "../../src/tint/resolver/increment_decrement_validation_test.cc", "../../src/tint/resolver/is_host_shareable_test.cc", "../../src/tint/resolver/is_storeable_test.cc", "../../src/tint/resolver/pipeline_overridable_constant_test.cc",