From 2f9a98870eae19d1a69d67d8c74ce69f727f5c3d Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Sat, 17 Dec 2022 02:20:04 +0000 Subject: [PATCH] tint: Implement sem::Load The resolver now wraps sem::Expression objects with a sem::Load object anywhere that the load rule is invoked. sem::Expression provides an `UnwrapLoad()` method that returns the inner expression (or passthrough, if no load is present), which is analaguous to Type::UnwrapRef(). The logic for alias analysis in `RegisterLoadIfNeeded` has been folded into the new `Resolver::Load` method. Fixed up many transforms and tests. The only difference in output is for a single SPIR-V backend test, where some IDs have changed due to slight re-ordering of when expressions are generated. There may be further clean-ups possible (e.g. removing unnecessary calls to `UnwrapRef`, and simplifying places in the SPIR-V writer or transforms that deal with memory accesses), but these can be addressed in future patches. Fixed: tint:1654 Change-Id: I69adecfe9251faae46546b64d0cdc29eea26cd4e Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99706 Commit-Queue: James Price Kokoro: Kokoro Reviewed-by: Antonio Maiorano Reviewed-by: Ben Clayton --- src/dawn/tests/end2end/SamplerTests.cpp | 86 ++-- src/tint/BUILD.gn | 1 + src/tint/CMakeLists.txt | 1 + src/tint/resolver/alias_analysis_test.cc | 31 ++ src/tint/resolver/array_accessor_test.cc | 65 ++- src/tint/resolver/builtin_validation_test.cc | 2 +- src/tint/resolver/load_test.cc | 370 ++++++++++++++++++ src/tint/resolver/ptr_ref_test.cc | 11 +- src/tint/resolver/resolver.cc | 175 +++++---- src/tint/resolver/resolver.h | 19 +- src/tint/resolver/resolver_test.cc | 31 +- src/tint/resolver/resolver_test_helper.h | 2 +- src/tint/resolver/side_effects_test.cc | 8 +- src/tint/resolver/uniformity.cc | 6 +- src/tint/resolver/validator.cc | 2 +- src/tint/resolver/variable_test.cc | 12 +- src/tint/sem/expression.cc | 16 + src/tint/sem/expression.h | 6 + src/tint/transform/combine_samplers.cc | 11 +- src/tint/transform/decompose_memory_access.cc | 35 +- .../transform/decompose_strided_matrix.cc | 2 +- src/tint/transform/first_index_offset.cc | 2 +- .../transform/multiplanar_external_texture.cc | 5 +- src/tint/transform/packed_vec3.cc | 12 +- .../transform/promote_side_effects_to_decl.cc | 2 +- src/tint/transform/renamer.cc | 2 +- src/tint/transform/renamer_test.cc | 12 +- src/tint/transform/robustness.cc | 4 +- src/tint/transform/spirv_atomic.cc | 6 +- src/tint/transform/std140.cc | 2 +- src/tint/transform/unshadow.cc | 8 +- src/tint/transform/unshadow_test.cc | 26 ++ src/tint/writer/glsl/generator_impl.cc | 2 +- src/tint/writer/hlsl/generator_impl.cc | 2 +- src/tint/writer/msl/generator_impl.cc | 2 +- src/tint/writer/spirv/builder.cc | 111 +++--- src/tint/writer/spirv/builder.h | 30 +- .../spirv/builder_accessor_expression_test.cc | 6 +- .../access/var/vector.wgsl.expected.spvasm | 4 +- 39 files changed, 808 insertions(+), 322 deletions(-) create mode 100644 src/tint/resolver/load_test.cc diff --git a/src/dawn/tests/end2end/SamplerTests.cpp b/src/dawn/tests/end2end/SamplerTests.cpp index e6de979fbf..822c1136db 100644 --- a/src/dawn/tests/end2end/SamplerTests.cpp +++ b/src/dawn/tests/end2end/SamplerTests.cpp @@ -55,36 +55,6 @@ class SamplerTest : public DawnTest { DawnTest::SetUp(); mRenderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize); - auto vsModule = utils::CreateShaderModule(device, R"( - @vertex - fn main(@builtin(vertex_index) VertexIndex : u32) -> @builtin(position) vec4 { - var pos = array, 6>( - vec2(-2.0, -2.0), - vec2(-2.0, 2.0), - vec2( 2.0, -2.0), - vec2(-2.0, 2.0), - vec2( 2.0, -2.0), - vec2( 2.0, 2.0)); - return vec4(pos[VertexIndex], 0.0, 1.0); - } - )"); - auto fsModule = utils::CreateShaderModule(device, R"( - @group(0) @binding(0) var sampler0 : sampler; - @group(0) @binding(1) var texture0 : texture_2d; - - @fragment - fn main(@builtin(position) FragCoord : vec4) -> @location(0) vec4 { - return textureSample(texture0, sampler0, FragCoord.xy / vec2(2.0, 2.0)); - })"); - - utils::ComboRenderPipelineDescriptor pipelineDescriptor; - pipelineDescriptor.vertex.module = vsModule; - pipelineDescriptor.cFragment.module = fsModule; - pipelineDescriptor.cTargets[0].format = mRenderPass.colorFormat; - - mPipeline = device.CreateRenderPipeline(&pipelineDescriptor); - mBindGroupLayout = mPipeline.GetBindGroupLayout(0); - wgpu::TextureDescriptor descriptor; descriptor.dimension = wgpu::TextureDimension::e2D; descriptor.size.width = 2; @@ -119,6 +89,31 @@ class SamplerTest : public DawnTest { mTextureView = texture.CreateView(); } + void InitShaders(const char* frag_shader) { + auto vsModule = utils::CreateShaderModule(device, R"( + @vertex + fn main(@builtin(vertex_index) VertexIndex : u32) -> @builtin(position) vec4 { + var pos = array, 6>( + vec2(-2.0, -2.0), + vec2(-2.0, 2.0), + vec2( 2.0, -2.0), + vec2(-2.0, 2.0), + vec2( 2.0, -2.0), + vec2( 2.0, 2.0)); + return vec4(pos[VertexIndex], 0.0, 1.0); + } + )"); + auto fsModule = utils::CreateShaderModule(device, frag_shader); + + utils::ComboRenderPipelineDescriptor pipelineDescriptor; + pipelineDescriptor.vertex.module = vsModule; + pipelineDescriptor.cFragment.module = fsModule; + pipelineDescriptor.cTargets[0].format = mRenderPass.colorFormat; + + mPipeline = device.CreateRenderPipeline(&pipelineDescriptor); + mBindGroupLayout = mPipeline.GetBindGroupLayout(0); + } + void TestAddressModes(AddressModeTestCase u, AddressModeTestCase v, AddressModeTestCase w) { wgpu::Sampler sampler; { @@ -169,6 +164,37 @@ class SamplerTest : public DawnTest { // Test drawing a rect with a checkerboard texture with different address modes. TEST_P(SamplerTest, AddressMode) { + InitShaders(R"( + @group(0) @binding(0) var sampler0 : sampler; + @group(0) @binding(1) var texture0 : texture_2d; + + @fragment + fn main(@builtin(position) FragCoord : vec4) -> @location(0) vec4 { + return textureSample(texture0, sampler0, FragCoord.xy / vec2(2.0, 2.0)); + })"); + for (auto u : addressModes) { + for (auto v : addressModes) { + for (auto w : addressModes) { + TestAddressModes(u, v, w); + } + } + } +} + +// Test that passing texture and sampler objects through user-defined functions works correctly. +TEST_P(SamplerTest, PassThroughUserFunctionParameters) { + InitShaders(R"( + @group(0) @binding(0) var sampler0 : sampler; + @group(0) @binding(1) var texture0 : texture_2d; + + fn foo(t : texture_2d, s : sampler, FragCoord : vec4) -> vec4 { + return textureSample(t, s, FragCoord.xy / vec2(2.0, 2.0)); + } + + @fragment + fn main(@builtin(position) FragCoord : vec4) -> @location(0) vec4 { + return foo(texture0, sampler0, FragCoord); + })"); for (auto u : addressModes) { for (auto v : addressModes) { for (auto w : addressModes) { diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 484c6f9ba3..acb05e00e0 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -1205,6 +1205,7 @@ if (tint_build_unittests) { "resolver/intrinsic_table_test.cc", "resolver/is_host_shareable_test.cc", "resolver/is_storeable_test.cc", + "resolver/load_test.cc", "resolver/materialize_test.cc", "resolver/override_test.cc", "resolver/ptr_ref_test.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index ee7c17c9b3..077774654d 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -913,6 +913,7 @@ if(TINT_BUILD_TESTS) resolver/intrinsic_table_test.cc resolver/is_host_shareable_test.cc resolver/is_storeable_test.cc + resolver/load_test.cc resolver/materialize_test.cc resolver/override_test.cc resolver/ptr_ref_test.cc diff --git a/src/tint/resolver/alias_analysis_test.cc b/src/tint/resolver/alias_analysis_test.cc index 38710b6678..490f54cac0 100644 --- a/src/tint/resolver/alias_analysis_test.cc +++ b/src/tint/resolver/alias_analysis_test.cc @@ -804,6 +804,37 @@ TEST_F(ResolverAliasAnalysisTest, Write_MemberAccessor) { 12:34 note: aliases with another argument passed here)"); } +TEST_F(ResolverAliasAnalysisTest, Read_MultiComponentSwizzle) { + // fn f2(p1 : ptr, p2 : ptr) { + // _ = (*p2).zy; + // *p1 = vec4(); + // } + // fn f1() { + // var v : vec4; + // f2(&v, &v); + // } + Structure("S", utils::Vector{Member("a", ty.i32())}); + Func("f2", + utils::Vector{ + Param("p1", ty.pointer(ty.vec4(), ast::AddressSpace::kFunction)), + Param("p2", ty.pointer(ty.vec4(), ast::AddressSpace::kFunction)), + }, + ty.void_(), + utils::Vector{ + Assign(Phony(), MemberAccessor(Deref("p2"), "zy")), + Assign(Deref("p1"), Construct(ty.vec4())), + }); + Func("f1", utils::Empty, ty.void_(), + utils::Vector{ + Decl(Var("v", ty.vec4())), + CallStmt( + Call("f2", AddressOf(Source{{12, 34}}, "v"), AddressOf(Source{{56, 76}}, "v"))), + }); + EXPECT_TRUE(r()->Resolve()) << r()->error(); + EXPECT_EQ(r()->error(), R"(56:76 warning: invalid aliased pointer argument +12:34 note: aliases with another argument passed here)"); +} + TEST_F(ResolverAliasAnalysisTest, SinglePointerReadWrite) { // Test that we can both read and write from a single pointer parameter. // diff --git a/src/tint/resolver/array_accessor_test.cc b/src/tint/resolver/array_accessor_test.cc index 9c4678969d..65fbb8313c 100644 --- a/src/tint/resolver/array_accessor_test.cc +++ b/src/tint/resolver/array_accessor_test.cc @@ -43,7 +43,7 @@ TEST_F(ResolverIndexAccessorTest, Matrix_Dynamic_Ref) { EXPECT_TRUE(r()->Resolve()) << r()->error(); - auto idx_sem = Sem().Get(acc); + auto idx_sem = Sem().Get(acc)->UnwrapLoad()->As(); ASSERT_NE(idx_sem, nullptr); EXPECT_EQ(idx_sem->Index()->Declaration(), acc->index); EXPECT_EQ(idx_sem->Object()->Declaration(), acc->object); @@ -58,7 +58,7 @@ TEST_F(ResolverIndexAccessorTest, Matrix_BothDimensions_Dynamic_Ref) { EXPECT_TRUE(r()->Resolve()) << r()->error(); - auto idx_sem = Sem().Get(acc); + auto idx_sem = Sem().Get(acc)->UnwrapLoad()->As(); ASSERT_NE(idx_sem, nullptr); EXPECT_EQ(idx_sem->Index()->Declaration(), acc->index); EXPECT_EQ(idx_sem->Object()->Declaration(), acc->object); @@ -73,7 +73,7 @@ TEST_F(ResolverIndexAccessorTest, Matrix_Dynamic) { EXPECT_TRUE(r()->Resolve()); EXPECT_EQ(r()->error(), ""); - auto idx_sem = Sem().Get(acc); + auto idx_sem = Sem().Get(acc)->UnwrapLoad()->As(); ASSERT_NE(idx_sem, nullptr); EXPECT_EQ(idx_sem->Index()->Declaration(), acc->index); EXPECT_EQ(idx_sem->Object()->Declaration(), acc->object); @@ -108,13 +108,10 @@ TEST_F(ResolverIndexAccessorTest, Matrix) { EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(acc), nullptr); - ASSERT_TRUE(TypeOf(acc)->Is()); + ASSERT_TRUE(TypeOf(acc)->Is()); + EXPECT_EQ(TypeOf(acc)->As()->Width(), 3u); - auto* ref = TypeOf(acc)->As(); - ASSERT_TRUE(ref->StoreType()->Is()); - EXPECT_EQ(ref->StoreType()->As()->Width(), 3u); - - auto idx_sem = Sem().Get(acc); + auto idx_sem = Sem().Get(acc)->UnwrapLoad()->As(); ASSERT_NE(idx_sem, nullptr); EXPECT_EQ(idx_sem->Index()->Declaration(), acc->index); EXPECT_EQ(idx_sem->Object()->Declaration(), acc->object); @@ -129,12 +126,9 @@ TEST_F(ResolverIndexAccessorTest, Matrix_BothDimensions) { EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(acc), nullptr); - ASSERT_TRUE(TypeOf(acc)->Is()); + EXPECT_TRUE(TypeOf(acc)->Is()); - auto* ref = TypeOf(acc)->As(); - EXPECT_TRUE(ref->StoreType()->Is()); - - auto idx_sem = Sem().Get(acc); + auto idx_sem = Sem().Get(acc)->UnwrapLoad()->As(); ASSERT_NE(idx_sem, nullptr); EXPECT_EQ(idx_sem->Index()->Declaration(), acc->index); EXPECT_EQ(idx_sem->Object()->Declaration(), acc->object); @@ -157,7 +151,7 @@ TEST_F(ResolverIndexAccessorTest, Vector_Dynamic_Ref) { EXPECT_TRUE(r()->Resolve()); - auto idx_sem = Sem().Get(acc); + auto idx_sem = Sem().Get(acc)->UnwrapLoad()->As(); ASSERT_NE(idx_sem, nullptr); EXPECT_EQ(idx_sem->Index()->Declaration(), acc->index); EXPECT_EQ(idx_sem->Object()->Declaration(), acc->object); @@ -181,12 +175,9 @@ TEST_F(ResolverIndexAccessorTest, Vector) { EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(acc), nullptr); - ASSERT_TRUE(TypeOf(acc)->Is()); + EXPECT_TRUE(TypeOf(acc)->Is()); - auto* ref = TypeOf(acc)->As(); - EXPECT_TRUE(ref->StoreType()->Is()); - - auto idx_sem = Sem().Get(acc); + auto idx_sem = Sem().Get(acc)->UnwrapLoad()->As(); ASSERT_NE(idx_sem, nullptr); EXPECT_EQ(idx_sem->Index()->Declaration(), acc->index); EXPECT_EQ(idx_sem->Object()->Declaration(), acc->object); @@ -197,12 +188,9 @@ TEST_F(ResolverIndexAccessorTest, Array_Literal_i32) { auto* acc = IndexAccessor("my_var", 2_i); WrapInFunction(acc); EXPECT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(acc), nullptr); - auto* ref = TypeOf(acc)->As(); - ASSERT_NE(ref, nullptr); - EXPECT_TRUE(ref->StoreType()->Is()); + EXPECT_TRUE(TypeOf(acc)->Is()); - auto idx_sem = Sem().Get(acc); + auto idx_sem = Sem().Get(acc)->UnwrapLoad()->As(); ASSERT_NE(idx_sem, nullptr); EXPECT_EQ(idx_sem->Index()->Declaration(), acc->index); EXPECT_EQ(idx_sem->Object()->Declaration(), acc->object); @@ -213,12 +201,9 @@ TEST_F(ResolverIndexAccessorTest, Array_Literal_u32) { auto* acc = IndexAccessor("my_var", 2_u); WrapInFunction(acc); EXPECT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(acc), nullptr); - auto* ref = TypeOf(acc)->As(); - ASSERT_NE(ref, nullptr); - EXPECT_TRUE(ref->StoreType()->Is()); + EXPECT_TRUE(TypeOf(acc)->Is()); - auto idx_sem = Sem().Get(acc); + auto idx_sem = Sem().Get(acc)->UnwrapLoad()->As(); ASSERT_NE(idx_sem, nullptr); EXPECT_EQ(idx_sem->Index()->Declaration(), acc->index); EXPECT_EQ(idx_sem->Object()->Declaration(), acc->object); @@ -229,12 +214,9 @@ TEST_F(ResolverIndexAccessorTest, Array_Literal_AInt) { auto* acc = IndexAccessor("my_var", 2_a); WrapInFunction(acc); EXPECT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(acc), nullptr); - auto* ref = TypeOf(acc)->As(); - ASSERT_NE(ref, nullptr); - EXPECT_TRUE(ref->StoreType()->Is()); + EXPECT_TRUE(TypeOf(acc)->Is()); - auto idx_sem = Sem().Get(acc); + auto idx_sem = Sem().Get(acc)->UnwrapLoad()->As(); ASSERT_NE(idx_sem, nullptr); EXPECT_EQ(idx_sem->Index()->Declaration(), acc->index); EXPECT_EQ(idx_sem->Object()->Declaration(), acc->object); @@ -251,12 +233,9 @@ TEST_F(ResolverIndexAccessorTest, Alias_Array) { EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(acc), nullptr); - ASSERT_TRUE(TypeOf(acc)->Is()); + EXPECT_TRUE(TypeOf(acc)->Is()); - auto* ref = TypeOf(acc)->As(); - EXPECT_TRUE(ref->StoreType()->Is()); - - auto idx_sem = Sem().Get(acc); + auto idx_sem = Sem().Get(acc)->UnwrapLoad()->As(); ASSERT_NE(idx_sem, nullptr); EXPECT_EQ(idx_sem->Index()->Declaration(), acc->index); EXPECT_EQ(idx_sem->Object()->Declaration(), acc->object); @@ -292,7 +271,7 @@ TEST_F(ResolverIndexAccessorTest, Array_Dynamic_I32) { EXPECT_TRUE(r()->Resolve()); EXPECT_EQ(r()->error(), ""); - auto idx_sem = Sem().Get(acc); + auto idx_sem = Sem().Get(acc)->UnwrapLoad()->As(); ASSERT_NE(idx_sem, nullptr); EXPECT_EQ(idx_sem->Index()->Declaration(), acc->index); EXPECT_EQ(idx_sem->Object()->Declaration(), acc->object); @@ -325,7 +304,7 @@ TEST_F(ResolverIndexAccessorTest, Array_Literal_I32) { }); EXPECT_TRUE(r()->Resolve()) << r()->error(); - auto idx_sem = Sem().Get(acc); + auto idx_sem = Sem().Get(acc)->UnwrapLoad()->As(); ASSERT_NE(idx_sem, nullptr); EXPECT_EQ(idx_sem->Index()->Declaration(), acc->index); EXPECT_EQ(idx_sem->Object()->Declaration(), acc->object); @@ -346,7 +325,7 @@ TEST_F(ResolverIndexAccessorTest, Expr_Deref_FuncGoodParent) { EXPECT_TRUE(r()->Resolve()) << r()->error(); - auto idx_sem = Sem().Get(acc); + auto idx_sem = Sem().Get(acc)->UnwrapLoad()->As(); ASSERT_NE(idx_sem, nullptr); EXPECT_EQ(idx_sem->Index()->Declaration(), acc->index); EXPECT_EQ(idx_sem->Object()->Declaration(), acc->object); diff --git a/src/tint/resolver/builtin_validation_test.cc b/src/tint/resolver/builtin_validation_test.cc index 7340387693..6bda08aefd 100644 --- a/src/tint/resolver/builtin_validation_test.cc +++ b/src/tint/resolver/builtin_validation_test.cc @@ -181,7 +181,7 @@ TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsGlobalVarUsedAsVariable WrapInFunction(Decl(Var("v", use))); ASSERT_TRUE(r()->Resolve()) << r()->error(); - auto* sem = Sem().Get(use); + auto* sem = Sem().Get(use)->UnwrapLoad()->As(); ASSERT_NE(sem, nullptr); EXPECT_EQ(sem->Variable(), Sem().Get(mix)); } diff --git a/src/tint/resolver/load_test.cc b/src/tint/resolver/load_test.cc new file mode 100644 index 0000000000..fea60ff513 --- /dev/null +++ b/src/tint/resolver/load_test.cc @@ -0,0 +1,370 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/resolver/resolver.h" +#include "src/tint/resolver/resolver_test_helper.h" +#include "src/tint/sem/test_helper.h" + +#include "src/tint/sem/load.h" +#include "src/tint/type/reference.h" + +#include "gmock/gmock.h" + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::resolver { +namespace { + +using ResolverLoadTest = ResolverTest; + +TEST_F(ResolverLoadTest, VarInitializer) { + // var ref = 1i; + // var v = ref; + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(1_i)), // + Var("v", ident)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, LetInitializer) { + // var ref = 1i; + // let l = ref; + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(1_i)), // + Let("l", ident)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, Assignment) { + // var ref = 1i; + // var v : i32; + // v = ref; + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(1_i)), // + Var("v", ty.i32()), // + Assign("v", ident)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, CompoundAssignment) { + // var ref = 1i; + // var v : i32; + // v += ref; + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(1_i)), // + Var("v", ty.i32()), // + CompoundAssign("v", ident, ast::BinaryOp::kAdd)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, UnaryOp) { + // var ref = 1i; + // var v = -ref; + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(1_i)), // + Var("v", Negation(ident))); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, UnaryOp_NoLoad) { + // var ref = 1i; + // let v = &ref; + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(1_i)), // + Let("v", AddressOf(ident))); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* var_user = Sem().Get(ident); + ASSERT_NE(var_user, nullptr); + EXPECT_TRUE(var_user->Type()->Is()); + EXPECT_TRUE(var_user->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, BinaryOp) { + // var ref = 1i; + // var v = ref * 1i; + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(1_i)), // + Var("v", Mul(ident, 1_i))); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, Index) { + // var ref = 1i; + // var v = array(1i, 2i, 3i)[ref]; + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(1_i)), // + IndexAccessor(array(1_i, 2_i, 3_i), ident)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, MultiComponentSwizzle) { + // var ref = vec4(1); + // var v = ref.xyz; + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Construct(ty.vec4(), 1_i)), // + Var("v", MemberAccessor(ident, "xyz"))); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, Bitcast) { + // var ref = 1f; + // var v = bitcast(ref); + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(1_f)), // + Bitcast(ident)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, BuiltinArg) { + // var ref = 1f; + // var v = abs(ref); + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(1_f)), // + Call("abs", ident)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, FunctionArg) { + // fn f(x : f32) {} + // var ref = 1f; + // f(ref); + Func("f", utils::Vector{Param("x", ty.f32())}, ty.void_(), utils::Empty); + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(1_f)), // + CallStmt(Call("f", ident))); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, FunctionArg_Handles) { + // @group(0) @binding(0) var t : texture_2d; + // @group(0) @binding(1) var s : sampler; + // fn f(tp : texture_2d, sp : sampler) -> vec4 { + // return textureSampleLevel(tp, sp, vec2(), 0); + // } + // f(t, s); + GlobalVar("t", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()), + utils::Vector{Group(0_a), Binding(0_a)}); + GlobalVar("s", ty.sampler(ast::SamplerKind::kSampler), utils::Vector{Group(0_a), Binding(1_a)}); + Func("f", + utils::Vector{ + Param("tp", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32())), + Param("sp", ty.sampler(ast::SamplerKind::kSampler)), + }, + ty.vec4(), + utils::Vector{ + Return(Call("textureSampleLevel", "tp", "sp", Construct(ty.vec2()), 0_a)), + }); + auto* t_ident = Expr("t"); + auto* s_ident = Expr("s"); + WrapInFunction(CallStmt(Call("f", t_ident, s_ident))); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + { + auto* load = Sem().Get(t_ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); + } + { + auto* load = Sem().Get(s_ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); + } +} + +TEST_F(ResolverLoadTest, FunctionReturn) { + // var ref = 1f; + // return ref; + auto* ident = Expr("ref"); + Func("f", utils::Empty, ty.f32(), + utils::Vector{ + Decl(Var("ref", Expr(1_f))), + Return(ident), + }); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, IfCond) { + // var ref = false; + // if (ref) {} + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(false)), // + If(ident, Block())); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, Switch) { + // var ref = 1i; + // switch (ref) { + // default: + // } + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(1_i)), // + Switch(ident, DefaultCase())); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, BreakIfCond) { + // var ref = false; + // loop { + // continuing { + // break if (ref); + // } + // } + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(false)), // + Loop(Block(), Block(BreakIf(ident)))); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, ForCond) { + // var ref = false; + // for (; ref; ) {} + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(false)), // + For(nullptr, ident, nullptr, Block())); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, WhileCond) { + // var ref = false; + // while (ref) {} + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(false)), // + While(ident, Block())); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* load = Sem().Get(ident); + ASSERT_NE(load, nullptr); + EXPECT_TRUE(load->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->Is()); + EXPECT_TRUE(load->Reference()->Type()->UnwrapRef()->Is()); +} + +TEST_F(ResolverLoadTest, AddressOf) { + // var ref = 1i; + // let l = &ref; + auto* ident = Expr("ref"); + WrapInFunction(Var("ref", Expr(1_i)), // + Let("l", AddressOf(ident))); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* no_load = Sem().Get(ident); + ASSERT_NE(no_load, nullptr); + EXPECT_TRUE(no_load->Type()->Is()); // No load +} + +} // namespace +} // namespace tint::resolver diff --git a/src/tint/resolver/ptr_ref_test.cc b/src/tint/resolver/ptr_ref_test.cc index ecaafa7eb6..417642369b 100644 --- a/src/tint/resolver/ptr_ref_test.cc +++ b/src/tint/resolver/ptr_ref_test.cc @@ -14,6 +14,7 @@ #include "src/tint/resolver/resolver.h" #include "src/tint/resolver/resolver_test_helper.h" +#include "src/tint/sem/load.h" #include "src/tint/type/reference.h" #include "gmock/gmock.h" @@ -52,8 +53,14 @@ TEST_F(ResolverPtrRefTest, AddressOfThenDeref) { EXPECT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_TRUE(TypeOf(expr)->Is()); - EXPECT_TRUE(TypeOf(expr)->As()->StoreType()->Is()); + auto* load = Sem().Get(expr); + ASSERT_NE(load, nullptr); + + auto* ref = load->Reference(); + ASSERT_NE(ref, nullptr); + + ASSERT_TRUE(ref->Type()->Is()); + EXPECT_TRUE(ref->Type()->As()->StoreType()->Is()); } TEST_F(ResolverPtrRefTest, DefaultPtrAddressSpace) { diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index f0cbada3a1..5d9dfca4ed 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -58,6 +58,7 @@ #include "src/tint/sem/function.h" #include "src/tint/sem/if_statement.h" #include "src/tint/sem/index_accessor_expression.h" +#include "src/tint/sem/load.h" #include "src/tint/sem/loop_statement.h" #include "src/tint/sem/materialize.h" #include "src/tint/sem/member_accessor_expression.h" @@ -379,13 +380,11 @@ sem::Variable* Resolver::Let(const ast::Let* v, bool is_global) { return nullptr; } - auto* rhs = Materialize(Expression(v->initializer), ty); + auto* rhs = Load(Materialize(Expression(v->initializer), ty)); if (!rhs) { return nullptr; } - RegisterLoadIfNeeded(rhs); - // If the variable has no declared type, infer it from the RHS if (!ty) { ty = rhs->Type()->UnwrapRef(); // Implicit load of RHS @@ -432,8 +431,11 @@ sem::Variable* Resolver::Override(const ast::Override* v) { const sem::Expression* rhs = nullptr; - // Does the variable have a initializer? + // Does the variable have an initializer? if (v->initializer) { + // Note: RHS must be a const or override expression, which excludes references. + // So there's no need to load or unwrap references here. + ExprEvalStageConstraint constraint{sem::EvaluationStage::kOverride, "override initializer"}; TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); rhs = Materialize(Expression(v->initializer), ty); @@ -443,7 +445,7 @@ sem::Variable* Resolver::Override(const ast::Override* v) { // If the variable has no declared type, infer it from the RHS if (!ty) { - ty = rhs->Type()->UnwrapRef(); // Implicit load of RHS + ty = rhs->Type(); } } else if (!ty) { AddError("override declaration requires a type or initializer", v->source); @@ -529,6 +531,9 @@ sem::Variable* Resolver::Const(const ast::Const* c, bool is_global) { } } + // Note: RHS must be a const expression, which excludes references. + // So there's no need to load or unwrap references here. + if (ty) { // If an explicit type was specified, materialize to that type rhs = Materialize(rhs, ty); @@ -584,16 +589,14 @@ sem::Variable* Resolver::Var(const ast::Var* var, bool is_global) { }; TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); - rhs = Materialize(Expression(var->initializer), storage_ty); + rhs = Load(Materialize(Expression(var->initializer), storage_ty)); if (!rhs) { return nullptr; } // If the variable has no declared type, infer it from the RHS if (!storage_ty) { - storage_ty = rhs->Type()->UnwrapRef(); // Implicit load of RHS + storage_ty = rhs->Type(); } - - RegisterLoadIfNeeded(rhs); } if (!storage_ty) { @@ -1315,7 +1318,7 @@ sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) { auto* sem = builder_->create(stmt, current_compound_statement_, current_function_); return StatementScope(stmt, sem, [&] { - auto* cond = Expression(stmt->condition); + auto* cond = Load(Expression(stmt->condition)); if (!cond) { return false; } @@ -1323,8 +1326,6 @@ sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) { sem->Behaviors() = cond->Behaviors(); sem->Behaviors().Remove(sem::Behavior::kNext); - RegisterLoadIfNeeded(cond); - Mark(stmt->body); auto* body = builder_->create(stmt->body, current_compound_statement_, current_function_); @@ -1412,14 +1413,12 @@ sem::ForLoopStatement* Resolver::ForLoopStatement(const ast::ForLoopStatement* s } if (auto* cond_expr = stmt->condition) { - auto* cond = Expression(cond_expr); + auto* cond = Load(Expression(cond_expr)); if (!cond) { return false; } sem->SetCondition(cond); behaviors.Add(cond->Behaviors()); - - RegisterLoadIfNeeded(cond); } if (auto* continuing = stmt->continuing) { @@ -1457,15 +1456,13 @@ sem::WhileStatement* Resolver::WhileStatement(const ast::WhileStatement* stmt) { return StatementScope(stmt, sem, [&] { auto& behaviors = sem->Behaviors(); - auto* cond = Expression(stmt->condition); + auto* cond = Load(Expression(stmt->condition)); if (!cond) { return false; } sem->SetCondition(cond); behaviors.Add(cond->Behaviors()); - RegisterLoadIfNeeded(cond); - Mark(stmt->body); auto* body = builder_->create( @@ -1592,26 +1589,6 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) { return nullptr; } -void Resolver::RegisterLoadIfNeeded(const sem::Expression* expr) { - if (!expr) { - return; - } - if (!expr->Type()->Is()) { - return; - } - if (!current_function_) { - // There is currently no situation where the Load Rule can be invoked outside of a function. - return; - } - auto& info = alias_analysis_infos_[current_function_]; - Switch( - expr->RootIdentifier(), - [&](const sem::GlobalVariable* global) { - info.module_scope_reads.insert({global, expr}); - }, - [&](const sem::Parameter* param) { info.parameter_reads.insert(param); }); -} - void Resolver::RegisterStore(const sem::Expression* expr) { auto& info = alias_analysis_infos_[current_function_]; Switch( @@ -1778,6 +1755,33 @@ const type::Type* Resolver::ConcreteType(const type::Type* ty, }); } +const sem::Expression* Resolver::Load(const sem::Expression* expr) { + if (!expr) { + // Allow for Load(Expression(blah)), where failures pass through Load() + return nullptr; + } + + if (!expr->Type()->Is()) { + // Expression is not a reference type, so cannot be loaded. Just return expr. + return expr; + } + + auto* load = builder_->create(expr, current_statement_); + load->Behaviors() = expr->Behaviors(); + builder_->Sem().Replace(expr->Declaration(), load); + + // Track the load for the alias analysis. + auto& alias_info = alias_analysis_infos_[current_function_]; + Switch( + expr->RootIdentifier(), + [&](const sem::GlobalVariable* global) { + alias_info.module_scope_reads.insert({global, expr}); + }, + [&](const sem::Parameter* param) { alias_info.parameter_reads.insert(param); }); + + return load; +} + const sem::Expression* Resolver::Materialize(const sem::Expression* expr, const type::Type* target_type /* = nullptr */) { if (!expr) { @@ -1829,8 +1833,8 @@ const sem::Expression* Resolver::Materialize(const sem::Expression* expr, } template -bool Resolver::MaybeMaterializeArguments(utils::Vector& args, - const sem::CallTarget* target) { +bool Resolver::MaybeMaterializeAndLoadArguments(utils::Vector& args, + const sem::CallTarget* target) { for (size_t i = 0, n = std::min(args.Length(), target->Parameters().Length()); i < n; i++) { const auto* param_ty = target->Parameters()[i]->Type(); if (ShouldMaterializeArgument(param_ty)) { @@ -1840,6 +1844,13 @@ bool Resolver::MaybeMaterializeArguments(utils::VectorIs()) { + auto* load = Load(args[i]); + if (!load) { + return false; + } + args[i] = load; + } } return true; } @@ -1875,7 +1886,7 @@ utils::Result> Resolver::ConvertArgumen } sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* expr) { - auto* idx = Materialize(sem_.Get(expr->index)); + auto* idx = Load(Materialize(sem_.Get(expr->index))); if (!idx) { return nullptr; } @@ -1886,7 +1897,6 @@ sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* exp // vec2(1, 2)[runtime-index] obj = Materialize(obj); } - RegisterLoadIfNeeded(idx); if (!obj) { return nullptr; } @@ -1939,7 +1949,7 @@ sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* exp } sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) { - auto* inner = Materialize(sem_.Get(expr->expr)); + auto* inner = Load(Materialize(sem_.Get(expr->expr))); if (!inner) { return nullptr; } @@ -1948,8 +1958,6 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) { return nullptr; } - RegisterLoadIfNeeded(inner); - const constant::Value* val = nullptr; // TODO(crbug.com/tint/1582): short circuit 'expr' once const eval of Bitcast is implemented. if (auto r = const_eval_.Bitcast(ty, inner)) { @@ -1990,8 +1998,6 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { args.Push(arg); args_stage = sem::EarliestStage(args_stage, arg->Stage()); arg_behaviors.Add(arg->Behaviors()); - - RegisterLoadIfNeeded(arg); } arg_behaviors.Remove(sem::Behavior::kNext); @@ -2008,7 +2014,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { if (!ctor_or_conv.target) { return nullptr; } - if (!MaybeMaterializeArguments(args, ctor_or_conv.target)) { + if (!MaybeMaterializeAndLoadArguments(args, ctor_or_conv.target)) { return nullptr; } @@ -2037,7 +2043,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { // initializer call target. auto arr_or_str_init = [&](const type::Type* ty, const sem::CallTarget* call_target) -> sem::Call* { - if (!MaybeMaterializeArguments(args, call_target)) { + if (!MaybeMaterializeAndLoadArguments(args, call_target)) { return nullptr; } @@ -2325,7 +2331,7 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr, } } else { // Materialize arguments if the parameter type is not abstract - if (!MaybeMaterializeArguments(args, builtin.sem)) { + if (!MaybeMaterializeAndLoadArguments(args, builtin.sem)) { return nullptr; } } @@ -2476,14 +2482,17 @@ void Resolver::CollectTextureSamplerPairs(const sem::Builtin* builtin, if (texture_index == -1) { TINT_ICE(Resolver, diagnostics_) << "texture builtin without texture parameter"; } - if (auto* user = args[static_cast(texture_index)]->As()) { + if (auto* user = + args[static_cast(texture_index)]->UnwrapLoad()->As()) { auto* texture = user->Variable(); if (!texture->Type()->UnwrapRef()->Is()) { int sampler_index = signature.IndexOf(sem::ParameterUsage::kSampler); - const sem::Variable* sampler = - sampler_index != -1 - ? args[static_cast(sampler_index)]->As()->Variable() - : nullptr; + const sem::Variable* sampler = sampler_index != -1 + ? args[static_cast(sampler_index)] + ->UnwrapLoad() + ->As() + ->Variable() + : nullptr; current_function_->AddTextureSamplerPair(texture, sampler); } } @@ -2497,7 +2506,7 @@ sem::Call* Resolver::FunctionCall(const ast::CallExpression* expr, auto sym = expr->target.name->symbol; auto name = builder_->Symbols().NameFor(sym); - if (!MaybeMaterializeArguments(args, target)) { + if (!MaybeMaterializeAndLoadArguments(args, target)) { return nullptr; } @@ -2554,11 +2563,11 @@ void Resolver::CollectTextureSamplerPairs(sem::Function* func, const sem::Variable* texture = pair.first; const sem::Variable* sampler = pair.second; if (auto* param = texture->As()) { - texture = args[param->Index()]->As()->Variable(); + texture = args[param->Index()]->UnwrapLoad()->As()->Variable(); } if (sampler) { if (auto* param = sampler->As()) { - sampler = args[param->Index()]->As()->Variable(); + sampler = args[param->Index()]->UnwrapLoad()->As()->Variable(); } } current_function_->AddTextureSamplerPair(texture, sampler); @@ -2820,6 +2829,7 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e return nullptr; } + const sem::Expression* obj_expr = object; if (size == 1) { // A single element swizzle is just the type of the vector. ty = vec->type(); @@ -2831,12 +2841,15 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e // The vector will have a number of components equal to the length of // the swizzle. ty = builder_->create(vec->type(), static_cast(size)); + + // The load rule is invoked before the swizzle, if necessary. + obj_expr = Load(object); } auto val = const_eval_.Swizzle(ty, object, swizzle); if (!val) { return nullptr; } - return builder_->create(expr, ty, current_statement_, val.Get(), object, + return builder_->create(expr, ty, current_statement_, val.Get(), obj_expr, std::move(swizzle), has_side_effects, root_ident); }, @@ -2872,8 +2885,15 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { } } - RegisterLoadIfNeeded(lhs); - RegisterLoadIfNeeded(rhs); + // Load arguments if they are references + lhs = Load(lhs); + if (!lhs) { + return nullptr; + } + rhs = Load(rhs); + if (!rhs) { + return nullptr; + } const constant::Value* value = nullptr; if (stage == sem::EvaluationStage::kConstant) { @@ -2975,6 +2995,14 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { return nullptr; } } + + // Load expr if it is a reference + expr = Load(expr); + if (!expr) { + return nullptr; + } + + stage = expr->Stage(); if (stage == sem::EvaluationStage::kConstant) { if (op.const_eval_fn) { if (auto r = (const_eval_.*op.const_eval_fn)( @@ -2988,7 +3016,6 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { stage = sem::EvaluationStage::kRuntime; } } - RegisterLoadIfNeeded(expr); break; } } @@ -3437,7 +3464,7 @@ sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) { const type::Type* value_ty = nullptr; if (auto* value = stmt->value) { - const auto* expr = Expression(value); + const auto* expr = Load(Expression(value)); if (!expr) { return false; } @@ -3448,9 +3475,8 @@ sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) { } } behaviors.Add(expr->Behaviors() - sem::Behavior::kNext); - value_ty = expr->Type()->UnwrapRef(); - RegisterLoadIfNeeded(expr); + value_ty = expr->Type(); } else { value_ty = builder_->create(); } @@ -3468,15 +3494,13 @@ sem::SwitchStatement* Resolver::SwitchStatement(const ast::SwitchStatement* stmt return StatementScope(stmt, sem, [&] { auto& behaviors = sem->Behaviors(); - const auto* cond = Expression(stmt->condition); + const auto* cond = Load(Expression(stmt->condition)); if (!cond) { return false; } behaviors = cond->Behaviors() - sem::Behavior::kNext; - RegisterLoadIfNeeded(cond); - - auto* cond_ty = cond->Type()->UnwrapRef(); + auto* cond_ty = cond->Type(); // Determine the common type across all selectors and the switch expression // This must materialize to an integer scalar (non-abstract). @@ -3579,7 +3603,10 @@ sem::Statement* Resolver::AssignmentStatement(const ast::AssignmentStatement* st } } - RegisterLoadIfNeeded(rhs); + rhs = Load(rhs); + if (!rhs) { + return false; + } auto& behaviors = sem->Behaviors(); behaviors = rhs->Behaviors(); @@ -3609,7 +3636,7 @@ sem::Statement* Resolver::BreakIfStatement(const ast::BreakIfStatement* stmt) { auto* sem = builder_->create(stmt, current_compound_statement_, current_function_); return StatementScope(stmt, sem, [&] { - auto* cond = Expression(stmt->condition); + auto* cond = Load(Expression(stmt->condition)); if (!cond) { return false; } @@ -3617,8 +3644,6 @@ sem::Statement* Resolver::BreakIfStatement(const ast::BreakIfStatement* stmt) { sem->Behaviors() = cond->Behaviors(); sem->Behaviors().Add(sem::Behavior::kBreak); - RegisterLoadIfNeeded(cond); - return validator_.BreakIfStatement(sem, current_statement_); }); } @@ -3645,12 +3670,11 @@ sem::Statement* Resolver::CompoundAssignmentStatement( return false; } - auto* rhs = Expression(stmt->rhs); + auto* rhs = Load(Expression(stmt->rhs)); if (!rhs) { return false; } - RegisterLoadIfNeeded(rhs); RegisterStore(lhs); sem->Behaviors() = rhs->Behaviors() + lhs->Behaviors(); @@ -3705,7 +3729,6 @@ sem::Statement* Resolver::IncrementDecrementStatement( } sem->Behaviors() = lhs->Behaviors(); - RegisterLoadIfNeeded(lhs); RegisterStore(lhs); return validator_.IncrementDecrementStatement(stmt); diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 8b16abf862..1a29a75452 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -157,10 +157,6 @@ class Resolver { sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*); sem::Expression* UnaryOp(const ast::UnaryOpExpression*); - /// Register a memory load from an expression, to track accesses to root identifiers in order to - /// perform alias analysis. - void RegisterLoadIfNeeded(const sem::Expression* expr); - /// Register a memory store to an expression, to track accesses to root identifiers in order to /// perform alias analysis. void RegisterStore(const sem::Expression* expr); @@ -169,8 +165,11 @@ class Resolver { /// @returns true is the call arguments are free from aliasing issues, false otherwise. bool AliasAnalysis(const sem::Call* call); + /// If `expr` is of a reference type, then Load will create and return a sem::Load node wrapping + /// `expr`. If `expr` is not of a reference type, then Load will just return `expr`. + const sem::Expression* Load(const sem::Expression* expr); + /// If `expr` is not of an abstract-numeric type, then Materialize() will just return `expr`. - /// If `expr` is of an abstract-numeric type: /// * Materialize will create and return a sem::Materialize node wrapping `expr`. /// * The AST -> Sem binding will be updated to point to the new sem::Materialize node. /// * The sem::Materialize node will have a new concrete type, which will be `target_type` if @@ -181,15 +180,19 @@ class Resolver { /// if `expr` has a element type of abstract-float. /// * The sem::Materialize constant value will be the value of `expr` value-converted to the /// materialized type. + /// If `expr` is not of an abstract-numeric type, then Materialize() will just return `expr`. /// If `expr` is nullptr, then Materialize() will also return nullptr. const sem::Expression* Materialize(const sem::Expression* expr, const type::Type* target_type = nullptr); - /// Materializes all the arguments in `args` to the parameter types of `target`. + /// For each argument in `args`: + /// * Calls Materialize() passing the argument and the corresponding parameter type. + /// * Calls Load() passing the argument, iff the corresponding parameter type is not a + /// reference type. /// @returns true on success, false on failure. template - bool MaybeMaterializeArguments(utils::Vector& args, - const sem::CallTarget* target); + bool MaybeMaterializeAndLoadArguments(utils::Vector& args, + const sem::CallTarget* target); /// @returns true if an argument of an abstract numeric type, passed to a parameter of type /// `parameter_ty` should be materialized. diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc index 3c8e3a4044..e5538c558e 100644 --- a/src/tint/resolver/resolver_test.cc +++ b/src/tint/resolver/resolver_test.cc @@ -707,8 +707,7 @@ TEST_F(ResolverTest, Expr_Identifier_GlobalVariable) { EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(ident), nullptr); - ASSERT_TRUE(TypeOf(ident)->Is()); - EXPECT_TRUE(TypeOf(ident)->UnwrapRef()->Is()); + EXPECT_TRUE(TypeOf(ident)->Is()); EXPECT_TRUE(CheckVarUsers(my_var, utils::Vector{ident})); ASSERT_NE(VarOf(ident), nullptr); EXPECT_EQ(VarOf(ident)->Declaration(), my_var); @@ -788,8 +787,7 @@ TEST_F(ResolverTest, Expr_Identifier_FunctionVariable) { EXPECT_TRUE(TypeOf(my_var_a)->UnwrapRef()->Is()); EXPECT_EQ(StmtOf(my_var_a), assign); ASSERT_NE(TypeOf(my_var_b), nullptr); - ASSERT_TRUE(TypeOf(my_var_b)->Is()); - EXPECT_TRUE(TypeOf(my_var_b)->UnwrapRef()->Is()); + EXPECT_TRUE(TypeOf(my_var_b)->Is()); EXPECT_EQ(StmtOf(my_var_b), assign); EXPECT_TRUE(CheckVarUsers(var, utils::Vector{my_var_a, my_var_b})); ASSERT_NE(VarOf(my_var_a), nullptr); @@ -1250,11 +1248,8 @@ TEST_F(ResolverTest, Expr_MemberAccessor_Struct) { EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(mem), nullptr); - ASSERT_TRUE(TypeOf(mem)->Is()); - - auto* ref = TypeOf(mem)->As(); - EXPECT_TRUE(ref->StoreType()->Is()); - auto* sma = Sem().Get(mem)->As(); + EXPECT_TRUE(TypeOf(mem)->Is()); + auto* sma = Sem().Get(mem)->UnwrapLoad()->As(); ASSERT_NE(sma, nullptr); EXPECT_TRUE(sma->Member()->Type()->Is()); EXPECT_EQ(sma->Object()->Declaration(), mem->structure); @@ -1274,11 +1269,8 @@ TEST_F(ResolverTest, Expr_MemberAccessor_Struct_Alias) { EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(mem), nullptr); - ASSERT_TRUE(TypeOf(mem)->Is()); - - auto* ref = TypeOf(mem)->As(); - EXPECT_TRUE(ref->StoreType()->Is()); - auto* sma = Sem().Get(mem)->As(); + EXPECT_TRUE(TypeOf(mem)->Is()); + auto* sma = Sem().Get(mem)->UnwrapLoad()->As(); ASSERT_NE(sma, nullptr); EXPECT_EQ(sma->Object()->Declaration(), mem->structure); EXPECT_TRUE(sma->Member()->Type()->Is()); @@ -1300,7 +1292,7 @@ TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle) { auto* sma = Sem().Get(mem)->As(); ASSERT_NE(sma, nullptr); EXPECT_EQ(sma->Object()->Declaration(), mem->structure); - EXPECT_THAT(sma->As()->Indices(), ElementsAre(0, 2, 1, 3)); + EXPECT_THAT(sma->Indices(), ElementsAre(0, 2, 1, 3)); } TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) { @@ -1312,14 +1304,11 @@ TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) { EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(mem), nullptr); - ASSERT_TRUE(TypeOf(mem)->Is()); - - auto* ref = TypeOf(mem)->As(); - ASSERT_TRUE(ref->StoreType()->Is()); - auto* sma = Sem().Get(mem)->As(); + ASSERT_TRUE(TypeOf(mem)->Is()); + auto* sma = Sem().Get(mem)->UnwrapLoad()->As(); ASSERT_NE(sma, nullptr); EXPECT_EQ(sma->Object()->Declaration(), mem->structure); - EXPECT_THAT(Sem().Get(mem)->As()->Indices(), ElementsAre(2)); + EXPECT_THAT(sma->Indices(), ElementsAre(2)); } TEST_F(ResolverTest, Expr_Accessor_MultiLevel) { diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h index 2515a028d6..86bc5d5506 100644 --- a/src/tint/resolver/resolver_test_helper.h +++ b/src/tint/resolver/resolver_test_helper.h @@ -83,7 +83,7 @@ class TestHelper : public ProgramBuilder { /// @return the resolved sem::Variable of the identifier, or nullptr if /// the expression did not resolve to a variable. const sem::Variable* VarOf(const ast::Expression* expr) { - auto* sem_ident = Sem().Get(expr); + auto* sem_ident = Sem().Get(expr)->UnwrapLoad(); auto* var_user = sem_ident ? sem_ident->As() : nullptr; return var_user ? var_user->Variable() : nullptr; } diff --git a/src/tint/resolver/side_effects_test.cc b/src/tint/resolver/side_effects_test.cc index 46727e2977..53df777840 100644 --- a/src/tint/resolver/side_effects_test.cc +++ b/src/tint/resolver/side_effects_test.cc @@ -83,7 +83,7 @@ TEST_F(SideEffectsTest, VariableUser) { EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* sem = Sem().Get(expr); ASSERT_NE(sem, nullptr); - EXPECT_TRUE(sem->Is()); + EXPECT_TRUE(sem->UnwrapLoad()->Is()); EXPECT_FALSE(sem->HasSideEffects()); } @@ -438,8 +438,8 @@ TEST_F(SideEffectsTest, MemberAccessor_Vector) { EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* sem = Sem().Get(expr); - EXPECT_TRUE(sem->Is()); ASSERT_NE(sem, nullptr); + EXPECT_TRUE(sem->UnwrapLoad()->Is()); EXPECT_FALSE(sem->HasSideEffects()); } @@ -450,8 +450,8 @@ TEST_F(SideEffectsTest, MemberAccessor_VectorSwizzleNoSE) { EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* sem = Sem().Get(expr); - EXPECT_TRUE(sem->Is()); ASSERT_NE(sem, nullptr); + EXPECT_TRUE(sem->Is()); EXPECT_FALSE(sem->HasSideEffects()); } @@ -462,8 +462,8 @@ TEST_F(SideEffectsTest, MemberAccessor_VectorSwizzleSE) { EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* sem = Sem().Get(expr); - EXPECT_TRUE(sem->Is()); ASSERT_NE(sem, nullptr); + EXPECT_TRUE(sem->Is()); EXPECT_TRUE(sem->HasSideEffects()); } diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc index db1578fff5..30dd30f71b 100644 --- a/src/tint/resolver/uniformity.cc +++ b/src/tint/resolver/uniformity.cc @@ -1035,7 +1035,7 @@ class UniformityGraph { }; auto name = builder_->Symbols().NameFor(ident->symbol); - auto* sem = sem_.Get(ident)->UnwrapMaterialize()->As()->Variable(); + auto* sem = sem_.Get(ident)->Unwrap()->As()->Variable(); auto* node = CreateNode(name + "_ident_expr", ident); return Switch( sem, @@ -1203,7 +1203,7 @@ class UniformityGraph { [&](const ast::IdentifierExpression* i) { auto name = builder_->Symbols().NameFor(i->symbol); - auto* sem = sem_.Get(i); + auto* sem = sem_.Get(i)->UnwrapLoad()->As(); if (sem->Variable()->Is()) { return std::make_pair(cf, current_function_->may_be_non_uniform); } else if (auto* local = sem->Variable()->As()) { @@ -1536,7 +1536,7 @@ class UniformityGraph { Switch( non_uniform_source->ast, [&](const ast::IdentifierExpression* ident) { - auto* var = sem_.Get(ident)->Variable(); + auto* var = sem_.Get(ident)->UnwrapLoad()->As()->Variable(); std::string var_type = get_var_type(var); diagnostics_.add_note(diag::System::Resolver, "reading from " + var_type + "'" + diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index 479cb9c1b6..3172e5834c 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -2205,7 +2205,7 @@ bool Validator::Return(const ast::ReturnStatement* ret, } bool Validator::SwitchStatement(const ast::SwitchStatement* s) { - auto* cond_ty = sem_.TypeOf(s->condition)->UnwrapRef(); + auto* cond_ty = sem_.TypeOf(s->condition); if (!cond_ty->is_integer_scalar()) { AddError("switch statement selector expression must be of a scalar integer type", s->condition->source); diff --git a/src/tint/resolver/variable_test.cc b/src/tint/resolver/variable_test.cc index 196151679e..552ad08658 100644 --- a/src/tint/resolver/variable_test.cc +++ b/src/tint/resolver/variable_test.cc @@ -249,7 +249,8 @@ TEST_F(ResolverVariableTest, LocalVar_ShadowsGlobalVar) { ASSERT_NE(local, nullptr); EXPECT_EQ(local->Shadows(), global); - auto* user_v = Sem().Get(local->Declaration()->initializer); + auto* user_v = + Sem().Get(local->Declaration()->initializer)->UnwrapLoad()->As(); ASSERT_NE(user_v, nullptr); EXPECT_EQ(user_v->Variable(), global); } @@ -298,7 +299,8 @@ TEST_F(ResolverVariableTest, LocalVar_ShadowsLocalVar) { ASSERT_NE(local_y, nullptr); EXPECT_EQ(local_y->Shadows(), local_x); - auto* user_y = Sem().Get(local_y->Declaration()->initializer); + auto* user_y = + Sem().Get(local_y->Declaration()->initializer)->UnwrapLoad()->As(); ASSERT_NE(user_y, nullptr); EXPECT_EQ(user_y->Variable(), local_x); } @@ -563,7 +565,8 @@ TEST_F(ResolverVariableTest, LocalLet_ShadowsGlobalVar) { ASSERT_NE(local, nullptr); EXPECT_EQ(local->Shadows(), global); - auto* user = Sem().Get(local->Declaration()->initializer); + auto* user = + Sem().Get(local->Declaration()->initializer)->UnwrapLoad()->As(); ASSERT_NE(user, nullptr); EXPECT_EQ(user->Variable(), global); } @@ -612,7 +615,8 @@ TEST_F(ResolverVariableTest, LocalLet_ShadowsLocalVar) { ASSERT_NE(local_l, nullptr); EXPECT_EQ(local_l->Shadows(), local_v); - auto* user = Sem().Get(local_l->Declaration()->initializer); + auto* user = + Sem().Get(local_l->Declaration()->initializer)->UnwrapLoad()->As(); ASSERT_NE(user, nullptr); EXPECT_EQ(user->Variable(), local_v); } diff --git a/src/tint/sem/expression.cc b/src/tint/sem/expression.cc index 9d231be728..cf338c4665 100644 --- a/src/tint/sem/expression.cc +++ b/src/tint/sem/expression.cc @@ -16,6 +16,7 @@ #include +#include "src/tint/sem/load.h" #include "src/tint/sem/materialize.h" TINT_INSTANTIATE_TYPEINFO(tint::sem::Expression); @@ -49,4 +50,19 @@ const Expression* Expression::UnwrapMaterialize() const { return this; } +const Expression* Expression::UnwrapLoad() const { + if (auto* l = As()) { + return l->Reference(); + } + return this; +} + +const Expression* Expression::Unwrap() const { + return Switch( + this, // note: An expression can only be wrapped by a Load or Materialize, not both. + [&](const Load* load) { return load->Reference(); }, + [&](const Materialize* materialize) { return materialize->Expr(); }, + [&](Default) { return this; }); +} + } // namespace tint::sem diff --git a/src/tint/sem/expression.h b/src/tint/sem/expression.h index e39e84483c..575f3f5cdb 100644 --- a/src/tint/sem/expression.h +++ b/src/tint/sem/expression.h @@ -85,6 +85,12 @@ class Expression : public Castable { /// @return the inner expression node if this is a Materialize, otherwise this. const Expression* UnwrapMaterialize() const; + /// @return the inner reference expression if this is a Load, otherwise this. + const Expression* UnwrapLoad() const; + + /// @return the inner expression node if this is a Materialize or Load, otherwise this. + const Expression* Unwrap() const; + protected: /// The AST expression node for this semantic expression const ast::Expression* const declaration_; diff --git a/src/tint/transform/combine_samplers.cc b/src/tint/transform/combine_samplers.cc index 4f6838808e..ca42335e90 100644 --- a/src/tint/transform/combine_samplers.cc +++ b/src/tint/transform/combine_samplers.cc @@ -250,9 +250,10 @@ struct CombineSamplers::State { const sem::Expression* sampler = sampler_index != -1 ? call->Arguments()[static_cast(sampler_index)] : nullptr; - auto* texture_var = texture->As()->Variable(); + auto* texture_var = texture->UnwrapLoad()->As()->Variable(); auto* sampler_var = - sampler ? sampler->As()->Variable() : nullptr; + sampler ? sampler->UnwrapLoad()->As()->Variable() + : nullptr; sem::VariablePair new_pair(texture_var, sampler_var); for (auto* arg : expr->args) { auto* type = ctx.src->TypeOf(arg)->UnwrapRef(); @@ -296,12 +297,14 @@ struct CombineSamplers::State { const sem::Variable* sampler_var = pair.second; if (auto* param = texture_var->As()) { const sem::Expression* texture = call->Arguments()[param->Index()]; - texture_var = texture->As()->Variable(); + texture_var = + texture->UnwrapLoad()->As()->Variable(); } if (sampler_var) { if (auto* param = sampler_var->As()) { const sem::Expression* sampler = call->Arguments()[param->Index()]; - sampler_var = sampler->As()->Variable(); + sampler_var = + sampler->UnwrapLoad()->As()->Variable(); } } sem::VariablePair new_pair(texture_var, sampler_var); diff --git a/src/tint/transform/decompose_memory_access.cc b/src/tint/transform/decompose_memory_access.cc index df43788764..7c968753ad 100644 --- a/src/tint/transform/decompose_memory_access.cc +++ b/src/tint/transform/decompose_memory_access.cc @@ -126,10 +126,10 @@ struct LoadStoreKey { /// AtomicKey is the unordered map key to an atomic intrinsic. struct AtomicKey { - ast::Access const access; // buffer access + ast::Access const access; // buffer access type::Type const* buf_ty = nullptr; // buffer type type::Type const* el_ty = nullptr; // element type - sem::BuiltinType const op; // atomic op + sem::BuiltinType const op; // atomic op bool operator==(const AtomicKey& rhs) const { return access == rhs.access && buf_ty == rhs.buf_ty && el_ty == rhs.el_ty && op == rhs.op; } @@ -881,15 +881,17 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src, for (auto* node : src->ASTNodes().Objects()) { if (auto* ident = node->As()) { // X - if (auto* var = sem.Get(ident)) { - if (var->Variable()->AddressSpace() == ast::AddressSpace::kStorage || - var->Variable()->AddressSpace() == ast::AddressSpace::kUniform) { - // Variable to a storage or uniform buffer - state.AddAccess(ident, { - var, - state.ToOffset(0u), - var->Type()->UnwrapRef(), - }); + if (auto* sem_ident = sem.Get(ident)) { + if (auto* var = sem_ident->UnwrapLoad()->As()) { + if (var->Variable()->AddressSpace() == ast::AddressSpace::kStorage || + var->Variable()->AddressSpace() == ast::AddressSpace::kUniform) { + // Variable to a storage or uniform buffer + state.AddAccess(ident, { + var, + state.ToOffset(0u), + var->Type()->UnwrapRef(), + }); + } } } continue; @@ -897,7 +899,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src, if (auto* accessor = node->As()) { // X.Y - auto* accessor_sem = sem.Get(accessor); + auto* accessor_sem = sem.Get(accessor)->UnwrapLoad(); if (auto* swizzle = accessor_sem->As()) { if (swizzle->Indices().Length() == 1) { if (auto access = state.TakeAccess(accessor->structure)) { @@ -906,7 +908,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src, state.AddAccess(accessor, { access.var, state.Add(access.offset, offset), - vec_ty->type()->UnwrapRef(), + vec_ty->type(), }); } } @@ -918,7 +920,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src, state.AddAccess(accessor, { access.var, state.Add(access.offset, offset), - member->Type()->UnwrapRef(), + member->Type(), }); } } @@ -933,7 +935,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src, state.AddAccess(accessor, { access.var, state.Add(access.offset, offset), - arr->ElemType()->UnwrapRef(), + arr->ElemType(), }); continue; } @@ -942,7 +944,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src, state.AddAccess(accessor, { access.var, state.Add(access.offset, offset), - vec_ty->type()->UnwrapRef(), + vec_ty->type(), }); continue; } @@ -1014,6 +1016,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src, // All remaining accesses are loads, transform these into calls to the // corresponding load function + // TODO(crbug.com/tint/1784): Use `sem::Load`s instead of maintaining `state.expression_order`. for (auto* expr : state.expression_order) { auto access_it = state.accesses.find(expr); if (access_it == state.accesses.end()) { diff --git a/src/tint/transform/decompose_strided_matrix.cc b/src/tint/transform/decompose_strided_matrix.cc index a874e4e82c..4f2042a8b8 100644 --- a/src/tint/transform/decompose_strided_matrix.cc +++ b/src/tint/transform/decompose_strided_matrix.cc @@ -167,7 +167,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src, // m = arr_to_mat(ssbo.mat) std::unordered_map arr_to_mat; ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* { - if (auto* access = src->Sem().Get(expr)) { + if (auto* access = src->Sem().Get(expr)->UnwrapLoad()->As()) { if (auto info = decomposed.Find(access->Member()->Declaration())) { auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] { auto name = diff --git a/src/tint/transform/first_index_offset.cc b/src/tint/transform/first_index_offset.cc index eb698be572..da1587de40 100644 --- a/src/tint/transform/first_index_offset.cc +++ b/src/tint/transform/first_index_offset.cc @@ -139,7 +139,7 @@ Transform::ApplyResult FirstIndexOffset::Apply(const Program* src, // Fix up all references to the builtins with the offsets ctx.ReplaceAll([=, &ctx](const ast::Expression* expr) -> const ast::Expression* { if (auto* sem = ctx.src->Sem().Get(expr)) { - if (auto* user = sem->As()) { + if (auto* user = sem->UnwrapLoad()->As()) { auto it = builtin_vars.find(user->Variable()); if (it != builtin_vars.end()) { return ctx.dst->Add(ctx.CloneWithoutTransform(expr), diff --git a/src/tint/transform/multiplanar_external_texture.cc b/src/tint/transform/multiplanar_external_texture.cc index e7d8327cce..f8c927b174 100644 --- a/src/tint/transform/multiplanar_external_texture.cc +++ b/src/tint/transform/multiplanar_external_texture.cc @@ -197,7 +197,8 @@ struct MultiplanarExternalTexture::State { if (builtin && !builtin->Parameters().IsEmpty() && builtin->Parameters()[0]->Type()->Is() && builtin->Type() != sem::BuiltinType::kTextureDimensions) { - if (auto* var_user = sem.Get(expr->args[0])) { + if (auto* var_user = + sem.Get(expr->args[0])->UnwrapLoad()->As()) { auto it = new_binding_symbols.find(var_user->Variable()); if (it == new_binding_symbols.end()) { // If valid new binding locations were not provided earlier, we would have @@ -222,7 +223,7 @@ struct MultiplanarExternalTexture::State { // texture_external parameter. These need to be expanded out to multiple plane // textures and the texture parameters structure. for (auto* arg : expr->args) { - if (auto* var_user = sem.Get(arg)) { + if (auto* var_user = sem.Get(arg)->UnwrapLoad()->As()) { // Check if a parameter is a texture_external by trying to find // it in the transform state. auto it = new_binding_symbols.find(var_user->Variable()); diff --git a/src/tint/transform/packed_vec3.cc b/src/tint/transform/packed_vec3.cc index 5b5fdf4c9a..af9516137a 100644 --- a/src/tint/transform/packed_vec3.cc +++ b/src/tint/transform/packed_vec3.cc @@ -74,8 +74,14 @@ struct PackedVec3::State { // that load a whole packed vector (not a scalar / swizzle of the vector). utils::Hashset refs; for (auto* node : ctx.src->ASTNodes().Objects()) { + auto* sem_node = sem.Get(node); + if (sem_node) { + if (auto* expr = sem_node->As()) { + sem_node = expr->UnwrapLoad(); + } + } Switch( - sem.Get(node), // + sem_node, // [&](const sem::StructMemberAccess* access) { if (members.Contains(access->Member())) { // Access to a packed vector member. Seed the expression tracking. @@ -84,11 +90,11 @@ struct PackedVec3::State { }, [&](const sem::IndexAccessorExpression* access) { // Not loading a whole packed vector. Ignore. - refs.Remove(access->Object()); + refs.Remove(access->Object()->UnwrapLoad()); }, [&](const sem::Swizzle* access) { // Not loading a whole packed vector. Ignore. - refs.Remove(access->Object()); + refs.Remove(access->Object()->UnwrapLoad()); }, [&](const sem::VariableUser* user) { auto* v = user->Variable(); diff --git a/src/tint/transform/promote_side_effects_to_decl.cc b/src/tint/transform/promote_side_effects_to_decl.cc index 7a5254dc11..ff1c18f9d4 100644 --- a/src/tint/transform/promote_side_effects_to_decl.cc +++ b/src/tint/transform/promote_side_effects_to_decl.cc @@ -276,7 +276,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase { }, [&](const ast::IdentifierExpression* e) { if (auto* sem_e = sem.Get(e)) { - if (auto* var_user = sem_e->As()) { + if (auto* var_user = sem_e->UnwrapLoad()->As()) { // Don't hoist constants. if (var_user->ConstantValue()) { return false; diff --git a/src/tint/transform/renamer.cc b/src/tint/transform/renamer.cc index a84b004605..8d9872dedf 100644 --- a/src/tint/transform/renamer.cc +++ b/src/tint/transform/renamer.cc @@ -1285,7 +1285,7 @@ Transform::ApplyResult Renamer::Apply(const Program* src, Switch( node, [&](const ast::MemberAccessorExpression* accessor) { - auto* sem = src->Sem().Get(accessor); + auto* sem = src->Sem().Get(accessor)->UnwrapLoad(); if (sem->Is()) { preserved_identifiers.Add(accessor->member); } else if (auto* str_expr = src->Sem().Get(accessor->structure)) { diff --git a/src/tint/transform/renamer_test.cc b/src/tint/transform/renamer_test.cc index 635bf4ea8e..06280638d1 100644 --- a/src/tint/transform/renamer_test.cc +++ b/src/tint/transform/renamer_test.cc @@ -93,7 +93,8 @@ fn entry() -> @builtin(position) vec4 { var v : vec4; var rgba : f32; var xyzw : f32; - return v.zyxw + v.rgab; + var z : f32; + return v.zyxw + v.rgab * v.z; } )"; @@ -103,7 +104,8 @@ fn tint_symbol() -> @builtin(position) vec4 { var tint_symbol_1 : vec4; var tint_symbol_2 : f32; var tint_symbol_3 : f32; - return (tint_symbol_1.zyxw + tint_symbol_1.rgab); + var tint_symbol_4 : f32; + return (tint_symbol_1.zyxw + (tint_symbol_1.rgab * tint_symbol_1.z)); } )"; @@ -115,10 +117,8 @@ fn tint_symbol() -> @builtin(position) vec4 { ASSERT_NE(data, nullptr); Renamer::Data::Remappings expected_remappings = { - {"entry", "tint_symbol"}, - {"v", "tint_symbol_1"}, - {"rgba", "tint_symbol_2"}, - {"xyzw", "tint_symbol_3"}, + {"entry", "tint_symbol"}, {"v", "tint_symbol_1"}, {"rgba", "tint_symbol_2"}, + {"xyzw", "tint_symbol_3"}, {"z", "tint_symbol_4"}, }; EXPECT_THAT(data->remappings, ContainerEq(expected_remappings)); } diff --git a/src/tint/transform/robustness.cc b/src/tint/transform/robustness.cc index 45b138e770..1ef19c96f2 100644 --- a/src/tint/transform/robustness.cc +++ b/src/tint/transform/robustness.cc @@ -67,7 +67,7 @@ struct Robustness::State { /// @return the clamped replacement expression, or nullptr if `expr` should be cloned without /// changes. const ast::IndexAccessorExpression* Transform(const ast::IndexAccessorExpression* expr) { - auto* sem = src->Sem().Get(expr)->UnwrapMaterialize()->As(); + auto* sem = src->Sem().Get(expr)->Unwrap()->As(); auto* ret_type = sem->Type(); auto* ref = ret_type->As(); @@ -78,7 +78,7 @@ struct Robustness::State { // idx return the cloned index expression, as a u32. auto idx = [&]() -> const ast::Expression* { auto* i = ctx.Clone(expr->index); - if (sem->Index()->Type()->UnwrapRef()->is_signed_integer_scalar()) { + if (sem->Index()->Type()->is_signed_integer_scalar()) { return b.Construct(b.ty.u32(), i); // u32(idx) } return i; diff --git a/src/tint/transform/spirv_atomic.cc b/src/tint/transform/spirv_atomic.cc index 6b55e5d966..2471b89d1c 100644 --- a/src/tint/transform/spirv_atomic.cc +++ b/src/tint/transform/spirv_atomic.cc @@ -164,7 +164,7 @@ struct SpirvAtomic::State { void ProcessAtomicExpressions() { for (size_t i = 0; i < atomic_expressions.Length(); i++) { Switch( - atomic_expressions[i], // + atomic_expressions[i]->UnwrapLoad(), // [&](const sem::VariableUser* user) { auto* v = user->Variable()->Declaration(); if (v->type && atomic_variables.emplace(user->Variable()).second) { @@ -262,7 +262,7 @@ struct SpirvAtomic::State { } auto sem_rhs = ctx.src->Sem().Get(assign->rhs); - if (is_ref_to_atomic_var(sem_rhs)) { + if (is_ref_to_atomic_var(sem_rhs->UnwrapLoad())) { ctx.Replace(assign->rhs, [=] { auto* rhs = ctx.CloneWithoutTransform(assign->rhs); return b.Call(sem::str(sem::BuiltinType::kAtomicLoad), @@ -274,7 +274,7 @@ struct SpirvAtomic::State { [&](const ast::VariableDeclStatement* decl) { auto* var = decl->variable; if (auto* sem_init = ctx.src->Sem().Get(var->initializer)) { - if (is_ref_to_atomic_var(sem_init)) { + if (is_ref_to_atomic_var(sem_init->UnwrapLoad())) { ctx.Replace(var->initializer, [=] { auto* rhs = ctx.CloneWithoutTransform(var->initializer); return b.Call(sem::str(sem::BuiltinType::kAtomicLoad), diff --git a/src/tint/transform/std140.cc b/src/tint/transform/std140.cc index 3f9401865d..0b8665a464 100644 --- a/src/tint/transform/std140.cc +++ b/src/tint/transform/std140.cc @@ -511,7 +511,7 @@ struct Std140::State { while (true) { enum class Action { kStop, kContinue, kError }; Action action = Switch( - expr, // + expr->Unwrap(), // [&](const sem::VariableUser* user) { if (user->Variable() == access.var) { // Walked all the way to the root identifier. We're done traversing. diff --git a/src/tint/transform/unshadow.cc b/src/tint/transform/unshadow.cc index 8d2b876381..3ec5f3053e 100644 --- a/src/tint/transform/unshadow.cc +++ b/src/tint/transform/unshadow.cc @@ -96,9 +96,11 @@ struct Unshadow::State { }); ctx.ReplaceAll( [&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* { - if (auto* user = sem.Get(ident)) { - if (auto renamed = renamed_to.Find(user->Variable())) { - return b.Expr(*renamed); + if (auto* sem_ident = sem.Get(ident)) { + if (auto* user = sem_ident->UnwrapLoad()->As()) { + if (auto renamed = renamed_to.Find(user->Variable())) { + return b.Expr(*renamed); + } } } return nullptr; diff --git a/src/tint/transform/unshadow_test.cc b/src/tint/transform/unshadow_test.cc index c731e76956..f9e976b361 100644 --- a/src/tint/transform/unshadow_test.cc +++ b/src/tint/transform/unshadow_test.cc @@ -760,5 +760,31 @@ type a = i32; EXPECT_EQ(expect, str(got)); } +TEST_F(UnshadowTest, RenamedVarHasUsers) { + auto* src = R"( +fn F() { + var a : bool; + { + var a : i32; + var b = a + 1; + } +} +)"; + + auto* expect = R"( +fn F() { + var a : bool; + { + var a_1 : i32; + var b = (a_1 + 1); + } +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + } // namespace } // namespace tint::transform diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc index 2da5c8c918..41705cfbb0 100644 --- a/src/tint/writer/glsl/generator_impl.cc +++ b/src/tint/writer/glsl/generator_impl.cc @@ -2706,7 +2706,7 @@ bool GeneratorImpl::EmitMemberAccessor(std::ostream& out, } out << "."; - auto* sem = builder_.Sem().Get(expr); + auto* sem = builder_.Sem().Get(expr)->UnwrapLoad(); return Switch( sem, diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc index ece37231b4..93589e5fe0 100644 --- a/src/tint/writer/hlsl/generator_impl.cc +++ b/src/tint/writer/hlsl/generator_impl.cc @@ -3724,7 +3724,7 @@ bool GeneratorImpl::EmitMemberAccessor(std::ostream& out, } out << "."; - auto* sem = builder_.Sem().Get(expr); + auto* sem = builder_.Sem().Get(expr)->UnwrapLoad(); return Switch( sem, diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc index c3fbe4efda..ab18316e6e 100644 --- a/src/tint/writer/msl/generator_impl.cc +++ b/src/tint/writer/msl/generator_impl.cc @@ -2339,7 +2339,7 @@ bool GeneratorImpl::EmitMemberAccessor(std::ostream& out, return true; }; - auto* sem = builder_.Sem().Get(expr); + auto* sem = builder_.Sem().Get(expr)->UnwrapLoad(); return Switch( sem, diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc index dab6869f45..9d1092e0b4 100644 --- a/src/tint/writer/spirv/builder.cc +++ b/src/tint/writer/spirv/builder.cc @@ -26,6 +26,7 @@ #include "src/tint/sem/builtin.h" #include "src/tint/sem/call.h" #include "src/tint/sem/function.h" +#include "src/tint/sem/load.h" #include "src/tint/sem/materialize.h" #include "src/tint/sem/member_accessor_expression.h" #include "src/tint/sem/module.h" @@ -436,7 +437,7 @@ bool Builder::GenerateAssignStatement(const ast::AssignmentStatement* assign) { if (lhs_id == 0) { return false; } - auto rhs_id = GenerateExpressionWithLoadIfNeeded(assign->rhs); + auto rhs_id = GenerateExpression(assign->rhs); if (rhs_id == 0) { return false; } @@ -457,7 +458,7 @@ bool Builder::GenerateBreakStatement(const ast::BreakStatement*) { bool Builder::GenerateBreakIfStatement(const ast::BreakIfStatement* stmt) { TINT_ASSERT(Writer, !backedge_stack_.empty()); - const auto cond_id = GenerateExpressionWithLoadIfNeeded(stmt->condition); + const auto cond_id = GenerateExpression(stmt->condition); if (!cond_id) { return false; } @@ -555,14 +556,19 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) { return true; } -uint32_t Builder::GenerateExpression(const ast::Expression* expr) { - if (auto* sem = builder_.Sem().Get(expr)) { - if (auto constant = sem->ConstantValue()) { - return GenerateConstantIfNeeded(constant); +uint32_t Builder::GenerateExpression(const sem::Expression* expr) { + if (auto* constant = expr->ConstantValue()) { + return GenerateConstantIfNeeded(constant); + } + if (auto* load = expr->As()) { + auto ref_id = GenerateExpression(load->Reference()); + if (ref_id == 0) { + return 0; } + return GenerateLoad(load->ReferenceType(), ref_id); } return Switch( - expr, // + expr->Declaration(), // [&](const ast::IndexAccessorExpression* a) { return GenerateAccessorExpression(a); }, [&](const ast::BinaryExpression* b) { return GenerateBinaryExpression(b); }, [&](const ast::BitcastExpression* b) { return GenerateBitcastExpression(b); }, @@ -577,6 +583,10 @@ uint32_t Builder::GenerateExpression(const ast::Expression* expr) { }); } +uint32_t Builder::GenerateExpression(const ast::Expression* expr) { + return GenerateExpression(builder_.Sem().Get(expr)); +} + bool Builder::GenerateFunction(const ast::Function* func_ast) { auto* func = builder_.Sem().Get(func_ast); @@ -686,7 +696,7 @@ bool Builder::GenerateFunctionVariable(const ast::Variable* v) { uint32_t init_id = 0; if (v->initializer) { - init_id = GenerateExpressionWithLoadIfNeeded(v->initializer); + init_id = GenerateExpression(v->initializer); if (init_id == 0) { return false; } @@ -874,7 +884,7 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* v) { } bool Builder::GenerateIndexAccessor(const ast::IndexAccessorExpression* expr, AccessorInfo* info) { - auto idx_id = GenerateExpressionWithLoadIfNeeded(expr->index); + auto idx_id = GenerateExpression(expr->index); if (idx_id == 0) { return 0; } @@ -884,7 +894,7 @@ bool Builder::GenerateIndexAccessor(const ast::IndexAccessorExpression* expr, Ac // See https://github.com/gpuweb/gpuweb/pull/1580 if (info->source_type->Is()) { info->access_chain_indices.push_back(idx_id); - info->source_type = TypeOf(expr); + info->source_type = builder_.Sem().Get(expr)->UnwrapLoad()->Type(); return true; } @@ -936,7 +946,7 @@ bool Builder::GenerateIndexAccessor(const ast::IndexAccessorExpression* expr, Ac bool Builder::GenerateMemberAccessor(const ast::MemberAccessorExpression* expr, AccessorInfo* info) { - auto* expr_sem = builder_.Sem().Get(expr); + auto* expr_sem = builder_.Sem().Get(expr)->UnwrapLoad(); auto* expr_type = expr_sem->Type(); if (auto* access = expr_sem->As()) { @@ -1108,7 +1118,7 @@ uint32_t Builder::GenerateAccessorExpression(const ast::Expression* expr) { } if (!info.access_chain_indices.empty()) { - auto* type = TypeOf(expr); + auto* type = builder_.Sem().Get(expr)->UnwrapLoad()->Type(); auto result_type_id = GenerateTypeIfNeeded(type); if (result_type_id == 0) { return 0; @@ -1133,38 +1143,18 @@ uint32_t Builder::GenerateAccessorExpression(const ast::Expression* expr) { uint32_t Builder::GenerateIdentifierExpression(const ast::IdentifierExpression* expr) { auto* sem = builder_.Sem().Get(expr); - if (auto* user = sem->As()) { - return LookupVariableID(user->Variable()); + if (sem) { + if (auto* user = sem->UnwrapLoad()->As()) { + return LookupVariableID(user->Variable()); + } } error_ = "identifier '" + builder_.Symbols().NameFor(expr->symbol) + "' does not resolve to a variable"; return 0; } -uint32_t Builder::GenerateExpressionWithLoadIfNeeded(const sem::Expression* expr) { - // The semantic node directly knows both the AST node and the resolved type. - if (const auto id = GenerateExpression(expr->Declaration())) { - return GenerateLoadIfNeeded(expr->Type(), id); - } - return 0; -} - -uint32_t Builder::GenerateExpressionWithLoadIfNeeded(const ast::Expression* expr) { - if (const auto id = GenerateExpression(expr)) { - // Perform a lookup to get the resolved type. - return GenerateLoadIfNeeded(TypeOf(expr), id); - } - return 0; -} - -uint32_t Builder::GenerateLoadIfNeeded(const type::Type* type, uint32_t id) { - if (auto* ref = type->As()) { - type = ref->StoreType(); - } else { - return id; - } - - auto type_id = GenerateTypeIfNeeded(type); +uint32_t Builder::GenerateLoad(const type::Reference* type, uint32_t id) { + auto type_id = GenerateTypeIfNeeded(type->StoreType()); auto result = result_op(); auto result_id = std::get(result); if (!push_function_inst(spv::Op::OpLoad, {Operand(type_id), result, Operand(id)})) { @@ -1173,6 +1163,13 @@ uint32_t Builder::GenerateLoadIfNeeded(const type::Type* type, uint32_t id) { return result_id; } +uint32_t Builder::GenerateLoadIfNeeded(const type::Type* type, uint32_t id) { + if (auto* ref = type->As()) { + return GenerateLoad(ref, id); + } + return id; +} + uint32_t Builder::GenerateUnaryOpExpression(const ast::UnaryOpExpression* expr) { auto result = result_op(); auto result_id = std::get(result); @@ -1200,7 +1197,7 @@ uint32_t Builder::GenerateUnaryOpExpression(const ast::UnaryOpExpression* expr) return GenerateExpression(expr->expr); } - auto val_id = GenerateExpressionWithLoadIfNeeded(expr->expr); + auto val_id = GenerateExpression(expr->expr); if (val_id == 0) { return 0; } @@ -1338,7 +1335,7 @@ uint32_t Builder::GenerateTypeInitializerOrConversion(const sem::Call* call, for (auto* e : args) { uint32_t id = 0; - id = GenerateExpressionWithLoadIfNeeded(e); + id = GenerateExpression(e); if (id == 0) { return 0; } @@ -1476,7 +1473,7 @@ uint32_t Builder::GenerateCastOrCopyOrPassthrough(const type::Type* to_type, return 0; } - auto val_id = GenerateExpressionWithLoadIfNeeded(from_expr); + auto val_id = GenerateExpression(from_expr); if (val_id == 0) { return 0; } @@ -1828,7 +1825,7 @@ uint32_t Builder::GenerateConstantVectorSplatIfNeeded(const type::Vector* type, } uint32_t Builder::GenerateShortCircuitBinaryExpression(const ast::BinaryExpression* expr) { - auto lhs_id = GenerateExpressionWithLoadIfNeeded(expr->lhs); + auto lhs_id = GenerateExpression(expr->lhs); if (lhs_id == 0) { return false; } @@ -1869,7 +1866,7 @@ uint32_t Builder::GenerateShortCircuitBinaryExpression(const ast::BinaryExpressi if (!GenerateLabel(block_id)) { return 0; } - auto rhs_id = GenerateExpressionWithLoadIfNeeded(expr->rhs); + auto rhs_id = GenerateExpression(expr->rhs); if (rhs_id == 0) { return 0; } @@ -1989,12 +1986,12 @@ uint32_t Builder::GenerateBinaryExpression(const ast::BinaryExpression* expr) { return GenerateShortCircuitBinaryExpression(expr); } - auto lhs_id = GenerateExpressionWithLoadIfNeeded(expr->lhs); + auto lhs_id = GenerateExpression(expr->lhs); if (lhs_id == 0) { return 0; } - auto rhs_id = GenerateExpressionWithLoadIfNeeded(expr->rhs); + auto rhs_id = GenerateExpression(expr->rhs); if (rhs_id == 0) { return 0; } @@ -2267,7 +2264,7 @@ uint32_t Builder::GenerateFunctionCall(const sem::Call* call, const sem::Functio ops.push_back(Operand(func_id)); for (auto* arg : expr->args) { - auto id = GenerateExpressionWithLoadIfNeeded(arg); + auto id = GenerateExpression(arg); if (id == 0) { return 0; } @@ -2623,10 +2620,7 @@ bool Builder::GenerateTextureBuiltin(const sem::Call* call, auto& arguments = call->Arguments(); // Generates the given expression, returning the operand ID - auto gen = [&](const sem::Expression* expr) { - const auto val_id = GenerateExpressionWithLoadIfNeeded(expr); - return Operand(val_id); - }; + auto gen = [&](const sem::Expression* expr) { return Operand(GenerateExpression(expr)); }; // Returns the argument with the given usage auto arg = [&](Usage usage) { @@ -2754,7 +2748,7 @@ bool Builder::GenerateTextureBuiltin(const sem::Call* call, // Array index needs to be appended to the coordinates. auto* packed = AppendVector(&builder_, arg(Usage::kCoords)->Declaration(), array_index->Declaration()); - auto param = GenerateExpression(packed->Declaration()); + auto param = GenerateExpression(packed); if (param == 0) { return false; } @@ -3097,14 +3091,14 @@ bool Builder::GenerateAtomicBuiltin(const sem::Call* call, return false; } - uint32_t pointer_id = GenerateExpression(call->Arguments()[0]->Declaration()); + uint32_t pointer_id = GenerateExpression(call->Arguments()[0]); if (pointer_id == 0) { return false; } uint32_t value_id = 0; if (call->Arguments().Length() > 1) { - value_id = GenerateExpressionWithLoadIfNeeded(call->Arguments().Back()); + value_id = GenerateExpression(call->Arguments().Back()); if (value_id == 0) { return false; } @@ -3206,8 +3200,7 @@ bool Builder::GenerateAtomicBuiltin(const sem::Call* call, value, }); case sem::BuiltinType::kAtomicCompareExchangeWeak: { - auto comparator = - GenerateExpressionWithLoadIfNeeded(call->Arguments()[1]->Declaration()); + auto comparator = GenerateExpression(call->Arguments()[1]); if (comparator == 0) { return false; } @@ -3312,7 +3305,7 @@ uint32_t Builder::GenerateBitcastExpression(const ast::BitcastExpression* expr) return 0; } - auto val_id = GenerateExpressionWithLoadIfNeeded(expr->expr); + auto val_id = GenerateExpression(expr->expr); if (val_id == 0) { return 0; } @@ -3339,7 +3332,7 @@ uint32_t Builder::GenerateBitcastExpression(const ast::BitcastExpression* expr) bool Builder::GenerateConditionalBlock(const ast::Expression* cond, const ast::BlockStatement* true_body, const ast::Statement* else_stmt) { - auto cond_id = GenerateExpressionWithLoadIfNeeded(cond); + auto cond_id = GenerateExpression(cond); if (cond_id == 0) { return false; } @@ -3420,7 +3413,7 @@ bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) { merge_stack_.push_back(merge_block_id); - auto cond_id = GenerateExpressionWithLoadIfNeeded(stmt->condition); + auto cond_id = GenerateExpression(stmt->condition); if (cond_id == 0) { return false; } @@ -3504,7 +3497,7 @@ bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) { bool Builder::GenerateReturnStatement(const ast::ReturnStatement* stmt) { if (stmt->value) { - auto val_id = GenerateExpressionWithLoadIfNeeded(stmt->value); + auto val_id = GenerateExpression(stmt->value); if (val_id == 0) { return false; } diff --git a/src/tint/writer/spirv/builder.h b/src/tint/writer/spirv/builder.h index 68c8f25deb..6ac749e883 100644 --- a/src/tint/writer/spirv/builder.h +++ b/src/tint/writer/spirv/builder.h @@ -43,6 +43,7 @@ // Forward declarations namespace tint::sem { class Call; +class Load; class TypeInitializer; class TypeConversion; } // namespace tint::sem @@ -274,6 +275,10 @@ class Builder { /// Generates an expression /// @param expr the expression to generate /// @returns the resulting ID of the expression or 0 on error + uint32_t GenerateExpression(const sem::Expression* expr); + /// Generates an expression + /// @param expr the expression to generate + /// @returns the resulting ID of the expression or 0 on error uint32_t GenerateExpression(const ast::Expression* expr); /// Generates the instructions for a function /// @param func the function to generate @@ -440,24 +445,15 @@ class Builder { /// @param stmt the statement to generate /// @returns true if the statement was generated bool GenerateStatement(const ast::Statement* stmt); - /// Generates an expression. If the WGSL expression does not have reference - /// type, then return the SPIR-V ID for the expression. Otherwise implement - /// the WGSL Load Rule: generate an OpLoad and return the ID of the result. - /// Returns 0 if the expression could not be generated. - /// @param expr the semantic expression node to be generated - /// @returns the the ID of the expression, or loaded expression - uint32_t GenerateExpressionWithLoadIfNeeded(const sem::Expression* expr); - /// Generates an expression. If the WGSL expression does not have reference - /// type, then return the SPIR-V ID for the expression. Otherwise implement - /// the WGSL Load Rule: generate an OpLoad and return the ID of the result. - /// Returns 0 if the expression could not be generated. - /// @param expr the AST expression to be generated - /// @returns the the ID of the expression, or loaded expression - uint32_t GenerateExpressionWithLoadIfNeeded(const ast::Expression* expr); - /// Generates an OpLoad on the given ID if it has reference type in WGSL, - /// othewrise return the ID itself. + /// Generates an OpLoad of the given expression type + /// @param type the reference type of the expression + /// @param id the SPIR-V id of the expression + /// @returns the ID of the loaded value or 0 on failure. + uint32_t GenerateLoad(const type::Reference* type, uint32_t id); + /// Generates an OpLoad on the given ID if it has reference type in WGSL, otherwise return the + /// ID itself. /// @param type the type of the expression - /// @param id the SPIR-V id of the experssion + /// @param id the SPIR-V id of the expression /// @returns the ID of the loaded value or `id` if type is not a reference uint32_t GenerateLoadIfNeeded(const type::Type* type, uint32_t id); /// Generates an OpStore. Emits an error and returns false if we're diff --git a/src/tint/writer/spirv/builder_accessor_expression_test.cc b/src/tint/writer/spirv/builder_accessor_expression_test.cc index 5c0671792b..925406a51b 100644 --- a/src/tint/writer/spirv/builder_accessor_expression_test.cc +++ b/src/tint/writer/spirv/builder_accessor_expression_test.cc @@ -1263,12 +1263,12 @@ TEST_F(BuilderTest, MemberAccessor_Swizzle_MultipleNames) { %7 = OpTypeVector %8 3 %6 = OpTypePointer Function %7 %9 = OpConstantNull %7 -%10 = OpTypeVector %8 2 +%11 = OpTypeVector %8 2 )"); EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), R"(%5 = OpVariable %6 Function %9 )"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(%11 = OpLoad %7 %5 -%12 = OpVectorShuffle %10 %11 %11 1 0 + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(%10 = OpLoad %7 %5 +%12 = OpVectorShuffle %11 %10 %10 1 0 OpReturn )"); diff --git a/test/tint/access/var/vector.wgsl.expected.spvasm b/test/tint/access/var/vector.wgsl.expected.spvasm index 11a010c06a..74418bb50d 100644 --- a/test/tint/access/var/vector.wgsl.expected.spvasm +++ b/test/tint/access/var/vector.wgsl.expected.spvasm @@ -24,8 +24,8 @@ %v = OpVariable %_ptr_Function_v3float Function %9 %13 = OpAccessChain %_ptr_Function_float %v %uint_1 %14 = OpLoad %float %13 - %16 = OpLoad %v3float %v - %17 = OpVectorShuffle %v2float %16 %16 0 2 + %15 = OpLoad %v3float %v + %17 = OpVectorShuffle %v2float %15 %15 0 2 %18 = OpLoad %v3float %v %19 = OpVectorShuffle %v3float %18 %18 0 2 1 OpReturn