From 933d44a2c8b9148c7a052b7555765f4a115f80aa Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Wed, 7 Apr 2021 17:29:31 +0000 Subject: [PATCH] transform/hlsl: Hoist structure constructors to new var HLSL has some pecular rules around structure constructors. `S s = S(1,2,3)` is not valid, but `S s = {1,2,3}` is. This matches the quirkiness with array initializers, so adjust the array hoisting logic to also support structures. Fixed: tint:702 Change-Id: Ifdcafd98292715ae2482f72ec06c87842176d270 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46875 Commit-Queue: Ben Clayton Reviewed-by: James Price --- src/transform/hlsl.cc | 39 ++--- src/transform/hlsl.h | 8 +- src/transform/hlsl_test.cc | 136 +++++++++++++++++- src/writer/hlsl/generator_impl.cc | 62 +++++--- .../hlsl/generator_impl_constructor_test.cc | 34 +++++ .../hlsl/generator_impl_function_test.cc | 61 ++++---- .../hlsl/generator_impl_sanitizer_test.cc | 39 +++++ 7 files changed, 305 insertions(+), 74 deletions(-) diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc index 918d6e3ddb..9db68e0745 100644 --- a/src/transform/hlsl.cc +++ b/src/transform/hlsl.cc @@ -32,24 +32,24 @@ Hlsl::~Hlsl() = default; Transform::Output Hlsl::Run(const Program* in, const DataMap&) { ProgramBuilder out; CloneContext ctx(&out, in); - PromoteArrayInitializerToConstVar(ctx); + PromoteInitializersToConstVar(ctx); AddEmptyEntryPoint(ctx); ctx.Clone(); return Output{Program(std::move(out))}; } -void Hlsl::PromoteArrayInitializerToConstVar(CloneContext& ctx) const { - // Scan the AST nodes for array initializers which need to be promoted to - // their own constant declaration. +void Hlsl::PromoteInitializersToConstVar(CloneContext& ctx) const { + // Scan the AST nodes for array and structure initializers which + // need to be promoted to their own constant declaration. - // Note: Correct handling of arrays-of-arrays is guaranteed due to the + // Note: Correct handling of nested expressions is guaranteed due to the // depth-first traversal of the ast::Node::Clone() methods: // - // The inner-most array initializers are traversed first, and they are hoisted + // The inner-most initializers are traversed first, and they are hoisted // to const variables declared just above the statement of use. The outer - // array initializer will then be hoisted, inserting themselves between the - // inner array declaration and the statement of use. This pattern applies - // correctly to any nested depth. + // initializer will then be hoisted, inserting themselves between the + // inner declaration and the statement of use. This pattern applies correctly + // to any nested depth. // // Depth-first traversal of the AST is guaranteed because AST nodes are fully // immutable and require their children to be constructed first so their @@ -75,22 +75,23 @@ void Hlsl::PromoteArrayInitializerToConstVar(CloneContext& ctx) const { if (auto* src_var_decl = src_stmt->As()) { if (src_var_decl->variable()->constructor() == src_init) { - // This statement is just a variable declaration with the array - // initializer as the constructor value. This is what we're - // attempting to transform to, and so ignore. + // This statement is just a variable declaration with the initializer + // as the constructor value. This is what we're attempting to + // transform to, and so ignore. continue; } } - if (auto* src_array_ty = src_sem_expr->Type()->As()) { + auto* src_ty = src_sem_expr->Type(); + if (src_ty->IsAnyOf()) { // Create a new symbol for the constant auto dst_symbol = ctx.dst->Symbols().New(); - // Clone the array type - auto* dst_array_ty = ctx.Clone(src_array_ty); - // Clone the array initializer + // Clone the type + auto* dst_ty = ctx.Clone(src_ty); + // Clone the initializer auto* dst_init = ctx.Clone(src_init); - // Construct the constant that holds the array - auto* dst_var = ctx.dst->Const(dst_symbol, dst_array_ty, dst_init); + // Construct the constant that holds the hoisted initializer + auto* dst_var = ctx.dst->Const(dst_symbol, dst_ty, dst_init); // Construct the variable declaration statement auto* dst_var_decl = ctx.dst->create(dst_var); @@ -100,7 +101,7 @@ void Hlsl::PromoteArrayInitializerToConstVar(CloneContext& ctx) const { // Insert the constant before the usage ctx.InsertBefore(src_sem_stmt->Block()->statements(), src_stmt, dst_var_decl); - // Replace the inlined array with a reference to the constant + // Replace the inlined initializer with a reference to the constant ctx.Replace(src_init, dst_ident); } } diff --git a/src/transform/hlsl.h b/src/transform/hlsl.h index df903a75dd..ff7978f6f1 100644 --- a/src/transform/hlsl.h +++ b/src/transform/hlsl.h @@ -40,10 +40,10 @@ class Hlsl : public Transform { Output Run(const Program* program, const DataMap& data = {}) override; private: - /// Hoists the array initializer to a constant variable, declared just before - /// the array usage statement. - /// See crbug.com/tint/406 for more details - void PromoteArrayInitializerToConstVar(CloneContext& ctx) const; + /// Hoists the array and structure initializers to a constant variable, + /// declared just before the statement of usage. See crbug.com/tint/406 for + /// more details + void PromoteInitializersToConstVar(CloneContext& ctx) const; /// Add an empty shader entry point if none exist in the module. void AddEmptyEntryPoint(CloneContext& ctx) const; }; diff --git a/src/transform/hlsl_test.cc b/src/transform/hlsl_test.cc index 4ce2259350..27779bde42 100644 --- a/src/transform/hlsl_test.cc +++ b/src/transform/hlsl_test.cc @@ -51,6 +51,39 @@ fn main() -> void { EXPECT_EQ(expect, str(got)); } +TEST_F(HlslTest, PromoteStructureInitializerToConstVar_Basic) { + auto* src = R"( +struct S { + a : i32; + b : f32; + c : vec3; +}; + +[[stage(vertex)]] +fn main() -> void { + var x : f32 = S(1, 2.0, vec3()).b; +} +)"; + + auto* expect = R"( +struct S { + a : i32; + b : f32; + c : vec3; +}; + +[[stage(vertex)]] +fn main() -> void { + const tint_symbol_1 : S = S(1, 2.0, vec3()); + var x : f32 = tint_symbol_1.b; +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + TEST_F(HlslTest, PromoteArrayInitializerToConstVar_ArrayInArray) { auto* src = R"( [[stage(vertex)]] @@ -74,14 +107,115 @@ fn main() -> void { EXPECT_EQ(expect, str(got)); } -TEST_F(HlslTest, PromoteArrayInitializerToConstVar_NoChangeOnArrayVarDecl) { +TEST_F(HlslTest, PromoteStructureInitializerToConstVar_Nested) { auto* src = R"( +struct S1 { + a : i32; +}; + +struct S2 { + a : i32; + b : S1; + c : i32; +}; + +struct S3 { + a : S2; +}; + +[[stage(vertex)]] +fn main() -> void { + var x : i32 = S3(S2(1, S1(2), 3)).a.b.a; +} +)"; + + auto* expect = R"( +struct S1 { + a : i32; +}; + +struct S2 { + a : i32; + b : S1; + c : i32; +}; + +struct S3 { + a : S2; +}; + +[[stage(vertex)]] +fn main() -> void { + const tint_symbol_1 : S1 = S1(2); + const tint_symbol_4 : S2 = S2(1, tint_symbol_1, 3); + const tint_symbol_8 : S3 = S3(tint_symbol_4); + var x : i32 = tint_symbol_8.a.b.a; +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(HlslTest, PromoteInitializerToConstVar_Mixed) { + auto* src = R"( +struct S1 { + a : i32; +}; + +struct S2 { + a : array; +}; + +[[stage(vertex)]] +fn main() -> void { + var x : i32 = S2(array(S1(1), S1(2), S1(3))).a[1].a; +} +)"; + + auto* expect = R"( +struct S1 { + a : i32; +}; + +struct S2 { + a : array; +}; + +[[stage(vertex)]] +fn main() -> void { + const tint_symbol_1 : S1 = S1(1); + const tint_symbol_4 : S1 = S1(2); + const tint_symbol_5 : S1 = S1(3); + const tint_symbol_6 : array = array(tint_symbol_1, tint_symbol_4, tint_symbol_5); + const tint_symbol_7 : S2 = S2(tint_symbol_6); + var x : i32 = tint_symbol_7.a[1].a; +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(HlslTest, PromoteInitializerToConstVar_NoChangeOnVarDecl) { + auto* src = R"( +struct S { + a : i32; + b : f32; + c : i32; +}; + [[stage(vertex)]] fn main() -> void { var local_arr : array = array(0.0, 1.0, 2.0, 3.0); + var local_str : S = S(1, 2.0, 3); } const module_arr : array = array(0.0, 1.0, 2.0, 3.0); + +const module_str : S = S(1, 2.0, 3); )"; auto* expect = src; diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 90ffaf332d..4f82f6c820 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -1229,7 +1229,16 @@ bool GeneratorImpl::EmitScalarConstructor( bool GeneratorImpl::EmitTypeConstructor(std::ostream& pre, std::ostream& out, ast::TypeConstructorExpression* expr) { - if (expr->type()->Is()) { + // If the type constructor is empty then we need to construct with the zero + // value for all components. + if (expr->values().empty()) { + return EmitZeroValue(out, expr->type()); + } + + bool brackets = + expr->type()->UnwrapAliasIfNeeded()->IsAnyOf(); + + if (brackets) { out << "{"; } else { if (!EmitType(out, expr->type(), "")) { @@ -1238,31 +1247,19 @@ bool GeneratorImpl::EmitTypeConstructor(std::ostream& pre, out << "("; } - // If the type constructor is empty then we need to construct with the zero - // value for all components. - if (expr->values().empty()) { - if (!EmitZeroValue(out, expr->type())) { + bool first = true; + for (auto* e : expr->values()) { + if (!first) { + out << ", "; + } + first = false; + + if (!EmitExpression(pre, out, e)) { return false; } - } else { - bool first = true; - for (auto* e : expr->values()) { - if (!first) { - out << ", "; - } - first = false; - - if (!EmitExpression(pre, out, e)) { - return false; - } - } } - if (expr->type()->Is()) { - out << "}"; - } else { - out << ")"; - } + out << (brackets ? "}" : ")"); return true; } @@ -1994,6 +1991,10 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, type::Type* type) { } else if (type->Is()) { out << "0u"; } else if (auto* vec = type->As()) { + if (!EmitType(out, type, "")) { + return false; + } + ScopedParen sp(out); for (uint32_t i = 0; i < vec->size(); i++) { if (i != 0) { out << ", "; @@ -2003,6 +2004,10 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, type::Type* type) { } } } else if (auto* mat = type->As()) { + if (!EmitType(out, type, "")) { + return false; + } + ScopedParen sp(out); for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) { if (i != 0) { out << ", "; @@ -2011,6 +2016,19 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, type::Type* type) { return false; } } + } else if (auto* str = type->As()) { + out << "{"; + bool first = true; + for (auto* member : str->impl()->members()) { + if (!first) { + out << ", "; + } + first = false; + if (!EmitZeroValue(out, member->type())) { + return false; + } + } + out << "}"; } else { diagnostics_.add_error("Invalid type for zero emission: " + type->type_name()); diff --git a/src/writer/hlsl/generator_impl_constructor_test.cc b/src/writer/hlsl/generator_impl_constructor_test.cc index 2acbedbb68..2bee83b1cc 100644 --- a/src/writer/hlsl/generator_impl_constructor_test.cc +++ b/src/writer/hlsl/generator_impl_constructor_test.cc @@ -194,6 +194,40 @@ TEST_F(HlslGeneratorImplTest_Constructor, Validate(); } +TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Struct) { + auto* str = Structure("S", { + Member("a", ty.i32()), + Member("b", ty.f32()), + Member("c", ty.vec3()), + }); + + WrapInFunction(Construct(str, 1, 2.0f, vec3(3, 4, 5))); + + GeneratorImpl& gen = SanitizeAndBuild(); + + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + EXPECT_THAT(result(), HasSubstr("{1, 2.0f, int3(3, 4, 5)}")); + + Validate(); +} + +TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Struct_Empty) { + auto* str = Structure("S", { + Member("a", ty.i32()), + Member("b", ty.f32()), + Member("c", ty.vec3()), + }); + + WrapInFunction(Construct(str)); + + GeneratorImpl& gen = SanitizeAndBuild(); + + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + EXPECT_THAT(result(), HasSubstr("{0, 0.0f, int3(0, 0, 0)}")); + + Validate(); +} + } // namespace } // namespace hlsl } // namespace writer diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index f9b1d6adc9..7d702777a9 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -124,16 +124,17 @@ TEST_F(HlslGeneratorImplTest_Function, GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate(out)) << gen.error(); - EXPECT_EQ(result(), R"(struct tint_symbol_1 { + EXPECT_EQ(result(), R"(struct tint_symbol_5 { float foo : TEXCOORD0; }; -struct tint_symbol_3 { +struct tint_symbol_2 { float value : SV_Target1; }; -tint_symbol_3 frag_main(tint_symbol_1 tint_symbol_6) { - const float foo = tint_symbol_6.foo; - return tint_symbol_3(foo); +tint_symbol_2 frag_main(tint_symbol_5 tint_symbol_7) { + const float foo = tint_symbol_7.foo; + const tint_symbol_2 tint_symbol_1 = {foo}; + return tint_symbol_1; } )"); @@ -157,16 +158,17 @@ TEST_F(HlslGeneratorImplTest_Function, GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate(out)) << gen.error(); - EXPECT_EQ(result(), R"(struct tint_symbol_1 { + EXPECT_EQ(result(), R"(struct tint_symbol_6 { float4 coord : SV_Position; }; -struct tint_symbol_3 { +struct tint_symbol_2 { float value : SV_Depth; }; -tint_symbol_3 frag_main(tint_symbol_1 tint_symbol_6) { - const float4 coord = tint_symbol_6.coord; - return tint_symbol_3(coord.x); +tint_symbol_2 frag_main(tint_symbol_6 tint_symbol_8) { + const float4 coord = tint_symbol_8.coord; + const tint_symbol_2 tint_symbol_1 = {coord.x}; + return tint_symbol_1; } )"); @@ -213,22 +215,23 @@ TEST_F(HlslGeneratorImplTest_Function, float col1; float col2; }; -struct tint_symbol_4 { +struct tint_symbol_2 { float col1 : TEXCOORD1; float col2 : TEXCOORD2; }; -struct tint_symbol_7 { +struct tint_symbol_8 { float col1 : TEXCOORD1; float col2 : TEXCOORD2; }; -tint_symbol_4 vert_main() { - const Interface tint_symbol_6 = Interface(0.5f, 0.25f); - return tint_symbol_4(tint_symbol_6.col1, tint_symbol_6.col2); +tint_symbol_2 vert_main() { + const Interface tint_symbol_5 = {0.5f, 0.25f}; + const tint_symbol_2 tint_symbol_1 = {tint_symbol_5.col1, tint_symbol_5.col2}; + return tint_symbol_1; } -void frag_main(tint_symbol_7 tint_symbol_9) { - const Interface colors = Interface(tint_symbol_9.col1, tint_symbol_9.col2); +void frag_main(tint_symbol_8 tint_symbol_10) { + const Interface colors = {tint_symbol_10.col1, tint_symbol_10.col2}; const float r = colors.col1; const float g = colors.col2; return; @@ -236,8 +239,7 @@ void frag_main(tint_symbol_7 tint_symbol_9) { )"); - // TODO(crbug.com/tint/702): This is not legal HLSL - // Validate(); + Validate(); } TEST_F(HlslGeneratorImplTest_Function, @@ -281,25 +283,28 @@ TEST_F(HlslGeneratorImplTest_Function, EXPECT_EQ(result(), R"(struct VertexOutput { float4 pos; }; -struct tint_symbol_5 { +struct tint_symbol_2 { float4 pos : SV_Position; }; -struct tint_symbol_8 { +struct tint_symbol_6 { float4 pos : SV_Position; }; VertexOutput foo(float x) { - return VertexOutput(float4(x, x, x, 1.0f)); + const VertexOutput tint_symbol_8 = {float4(x, x, x, 1.0f)}; + return tint_symbol_8; } -tint_symbol_5 vert_main1() { - const VertexOutput tint_symbol_7 = VertexOutput(foo(0.5f)); - return tint_symbol_5(tint_symbol_7.pos); +tint_symbol_2 vert_main1() { + const VertexOutput tint_symbol_4 = {foo(0.5f)}; + const tint_symbol_2 tint_symbol_1 = {tint_symbol_4.pos}; + return tint_symbol_1; } -tint_symbol_8 vert_main2() { - const VertexOutput tint_symbol_10 = VertexOutput(foo(0.25f)); - return tint_symbol_8(tint_symbol_10.pos); +tint_symbol_6 vert_main2() { + const VertexOutput tint_symbol_7 = {foo(0.25f)}; + const tint_symbol_6 tint_symbol_5 = {tint_symbol_7.pos}; + return tint_symbol_5; } )"); diff --git a/src/writer/hlsl/generator_impl_sanitizer_test.cc b/src/writer/hlsl/generator_impl_sanitizer_test.cc index 53b7cd1d73..8766b91181 100644 --- a/src/writer/hlsl/generator_impl_sanitizer_test.cc +++ b/src/writer/hlsl/generator_impl_sanitizer_test.cc @@ -51,6 +51,45 @@ TEST_F(HlslSanitizerTest, PromoteArrayInitializerToConstVar) { EXPECT_EQ(expect, got); } +TEST_F(HlslSanitizerTest, PromoteStructInitializerToConstVar) { + auto* str = Structure("S", { + Member("a", ty.i32()), + Member("b", ty.vec3()), + Member("c", ty.i32()), + }); + auto* struct_init = Construct(str, 1, vec3(2.f, 3.f, 4.f), 4); + auto* struct_access = MemberAccessor(struct_init, "b"); + auto* pos = + Var("pos", ty.vec3(), ast::StorageClass::kFunction, struct_access); + + Func("main", ast::VariableList{}, ty.void_(), + ast::StatementList{ + create(pos), + }, + ast::DecorationList{ + create(ast::PipelineStage::kVertex), + }); + + GeneratorImpl& gen = SanitizeAndBuild(); + + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + + auto got = result(); + auto* expect = R"(struct S { + int a; + float3 b; + int c; +}; + +void main() { + const S tint_symbol_1 = {1, float3(2.0f, 3.0f, 4.0f), 4}; + float3 pos = tint_symbol_1.b; + return; +} + +)"; + EXPECT_EQ(expect, got); +} } // namespace } // namespace hlsl } // namespace writer