diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index ce0a78b9f2..95a77f23c5 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -2092,7 +2092,7 @@ bool GeneratorImpl::EmitLiteral(std::ostream& out, ast::Literal* lit) { return true; } -bool GeneratorImpl::EmitZeroValue(std::ostream& out, sem::Type* type) { +bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) { if (type->Is()) { out << "false"; } else if (type->Is()) { @@ -2142,6 +2142,18 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, sem::Type* type) { } } out << "}"; + } else if (auto* arr = type->As()) { + out << "{"; + auto* elem = arr->ElemType(); + for (size_t i = 0; i < arr->Count(); i++) { + if (i > 0) { + out << ", "; + } + if (!EmitZeroValue(out, elem)) { + return false; + } + } + out << "}"; } else { diagnostics_.add_error("Invalid type for zero emission: " + type->type_name()); @@ -2626,6 +2638,9 @@ bool GeneratorImpl::EmitVariable(std::ostream& out, bool skip_constructor) { make_indent(out); + auto* sem = builder_.Sem().Get(var); + auto* type = sem->Type(); + // TODO(dsinclair): Handle variable decorations if (!var->decorations().empty()) { diagnostics_.add_error("Variable decorations are not handled yet"); @@ -2633,21 +2648,25 @@ bool GeneratorImpl::EmitVariable(std::ostream& out, } std::ostringstream constructor_out; - if (!skip_constructor && var->constructor() != nullptr) { + if (!skip_constructor) { constructor_out << " = "; - std::ostringstream pre; - if (!EmitExpression(pre, constructor_out, var->constructor())) { - return false; + if (var->constructor()) { + std::ostringstream pre; + if (!EmitExpression(pre, constructor_out, var->constructor())) { + return false; + } + out << pre.str(); + } else { + if (!EmitZeroValue(constructor_out, type)) { + return false; + } } - out << pre.str(); } if (var->is_const()) { out << "const "; } - auto* sem = builder_.Sem().Get(var); - auto* type = sem->Type(); if (!EmitType(out, type, sem->StorageClass(), sem->AccessControl(), builder_.Symbols().NameFor(var->symbol()))) { return false; diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h index 2fb4fedf1a..250782b426 100644 --- a/src/writer/hlsl/generator_impl.h +++ b/src/writer/hlsl/generator_impl.h @@ -315,7 +315,7 @@ class GeneratorImpl : public TextGenerator { /// @param out the output stream /// @param type the type to emit the value for /// @returns true if the zero value was successfully emitted. - bool EmitZeroValue(std::ostream& out, sem::Type* type); + bool EmitZeroValue(std::ostream& out, const sem::Type* type); /// Handles generating a variable /// @param out the output stream /// @param var the variable to generate diff --git a/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc b/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc index 62431b875c..2b6d3904b1 100644 --- a/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc +++ b/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc @@ -35,7 +35,7 @@ TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement) { gen.increment_indent(); ASSERT_TRUE(gen.EmitStatement(out, stmt)) << gen.error(); - EXPECT_EQ(result(), " float a;\n"); + EXPECT_EQ(result(), " float a = 0.0f;\n"); } TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Const) { @@ -61,7 +61,8 @@ TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Array) { gen.increment_indent(); ASSERT_TRUE(gen.Generate(out)) << gen.error(); - EXPECT_THAT(result(), HasSubstr(" float a[5];\n")); + EXPECT_THAT(result(), + HasSubstr(" float a[5] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f};\n")); } TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Private) { diff --git a/test/access/var/matrix.wgsl.expected.hlsl b/test/access/var/matrix.wgsl.expected.hlsl index 91cb9367c1..00e5fe4baf 100644 --- a/test/access/var/matrix.wgsl.expected.hlsl +++ b/test/access/var/matrix.wgsl.expected.hlsl @@ -1,6 +1,6 @@ [numthreads(1, 1, 1)] void main() { - float3x3 m; + float3x3 m = float3x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); const float3 v = m[1]; const float f = v[1]; return; diff --git a/test/access/var/vector.wgsl.expected.hlsl b/test/access/var/vector.wgsl.expected.hlsl index d978baea7a..c13a798285 100644 --- a/test/access/var/vector.wgsl.expected.hlsl +++ b/test/access/var/vector.wgsl.expected.hlsl @@ -1,6 +1,6 @@ [numthreads(1, 1, 1)] void main() { - float3 v; + float3 v = float3(0.0f, 0.0f, 0.0f); const float scalar = v.y; const float2 swizzle2 = v.xz; const float3 swizzle3 = v.xzy;