diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 27300fff03..8a74ad5014 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -688,8 +688,33 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { if (var->is_const()) { if (!var->has_constructor()) { - error_ = "missing constructor for constant"; - return false; + // Constants must have an initializer unless they have an override + // decoration. + if (!ast::HasDecoration(var->decorations())) { + error_ = "missing constructor for constant"; + return false; + } + + // SPIR-V requires specialization constants to have initializers. + if (sem->Type()->Is()) { + ast::FloatLiteral l(ProgramID(), Source{}, 0.0f); + init_id = GenerateLiteralIfNeeded(var, &l); + } else if (sem->Type()->Is()) { + ast::UintLiteral l(ProgramID(), Source{}, 0); + init_id = GenerateLiteralIfNeeded(var, &l); + } else if (sem->Type()->Is()) { + ast::SintLiteral l(ProgramID(), Source{}, 0); + init_id = GenerateLiteralIfNeeded(var, &l); + } else if (sem->Type()->Is()) { + ast::BoolLiteral l(ProgramID(), Source{}, false); + init_id = GenerateLiteralIfNeeded(var, &l); + } else { + error_ = "invalid type for pipeline constant ID, must be scalar"; + return false; + } + if (init_id == 0) { + return 0; + } } push_debug(spv::Op::OpName, {Operand::Int(init_id), @@ -743,37 +768,10 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { } } } else if (!type_no_ac->Is()) { - // Certain cases require us to generate a constructor value. - // - // 1- Pipeline constant IDs must be attached to the OpConstant, if we have a - // variable with an override attribute that doesn't have a constructor we - // make one - // 2- If we don't have a constructor and we're an Output or Private variable - // then WGSL requires an initializer. - if (ast::HasDecoration(var->decorations())) { - if (type_no_ac->Is()) { - ast::FloatLiteral l(ProgramID(), Source{}, 0.0f); - init_id = GenerateLiteralIfNeeded(var, &l); - } else if (type_no_ac->Is()) { - ast::UintLiteral l(ProgramID(), Source{}, 0); - init_id = GenerateLiteralIfNeeded(var, &l); - } else if (type_no_ac->Is()) { - ast::SintLiteral l(ProgramID(), Source{}, 0); - init_id = GenerateLiteralIfNeeded(var, &l); - } else if (type_no_ac->Is()) { - ast::BoolLiteral l(ProgramID(), Source{}, false); - init_id = GenerateLiteralIfNeeded(var, &l); - } else { - error_ = "invalid type for pipeline constant ID, must be scalar"; - return false; - } - if (init_id == 0) { - return 0; - } - ops.push_back(Operand::Int(init_id)); - } else if (sem->StorageClass() == ast::StorageClass::kPrivate || - sem->StorageClass() == ast::StorageClass::kNone || - sem->StorageClass() == ast::StorageClass::kOutput) { + // If we don't have a constructor and we're an Output or Private variable, + // then WGSL requires that we zero-initialize. + if (sem->StorageClass() == ast::StorageClass::kPrivate || + sem->StorageClass() == ast::StorageClass::kOutput) { init_id = GenerateConstantNullIfNeeded(type_no_ac); if (init_id == 0) { return 0; diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc index 51a5b9db1f..c264fc5014 100644 --- a/src/writer/spirv/builder_global_variable_test.cc +++ b/src/writer/spirv/builder_global_variable_test.cc @@ -203,123 +203,111 @@ TEST_F(BuilderTest, GlobalVar_WithBuiltin) { )"); } -TEST_F(BuilderTest, GlobalVar_ConstantId_Bool) { - auto* v = Global("var", ty.bool_(), ast::StorageClass::kInput, Expr(true), - ast::DecorationList{ - create(1200), - }); +TEST_F(BuilderTest, GlobalVar_Override_Bool) { + auto* v = GlobalConst("var", ty.bool_(), Expr(true), + ast::DecorationList{ + create(1200), + }); spirv::Builder& b = Build(); EXPECT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); - EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %3 "var" + EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "var" )"); EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 1200 )"); EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool %2 = OpSpecConstantTrue %1 -%4 = OpTypePointer Input %1 -%3 = OpVariable %4 Input %2 )"); } -TEST_F(BuilderTest, GlobalVar_ConstantId_Bool_NoConstructor) { - auto* v = Global("var", ty.bool_(), ast::StorageClass::kInput, nullptr, - ast::DecorationList{ - create(1200), - }); +TEST_F(BuilderTest, GlobalVar_Override_Bool_NoConstructor) { + auto* v = GlobalConst("var", ty.bool_(), nullptr, + ast::DecorationList{ + create(1200), + }); spirv::Builder& b = Build(); EXPECT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); - EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "var" + EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "var" )"); - EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %4 SpecId 1200 + EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 1200 )"); - EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeBool -%2 = OpTypePointer Input %3 -%4 = OpSpecConstantFalse %3 -%1 = OpVariable %2 Input %4 + EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool +%2 = OpSpecConstantFalse %1 )"); } -TEST_F(BuilderTest, GlobalVar_ConstantId_Scalar) { - auto* v = Global("var", ty.f32(), ast::StorageClass::kInput, Expr(2.f), - ast::DecorationList{ - create(0), - }); +TEST_F(BuilderTest, GlobalVar_Override_Scalar) { + auto* v = GlobalConst("var", ty.f32(), Expr(2.f), + ast::DecorationList{ + create(0), + }); spirv::Builder& b = Build(); EXPECT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); - EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %3 "var" + EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "var" )"); EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 0 )"); EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32 %2 = OpSpecConstant %1 2 -%4 = OpTypePointer Input %1 -%3 = OpVariable %4 Input %2 )"); } -TEST_F(BuilderTest, GlobalVar_ConstantId_Scalar_F32_NoConstructor) { - auto* v = Global("var", ty.f32(), ast::StorageClass::kInput, nullptr, - ast::DecorationList{ - create(0), - }); +TEST_F(BuilderTest, GlobalVar_Override_Scalar_F32_NoConstructor) { + auto* v = GlobalConst("var", ty.f32(), nullptr, + ast::DecorationList{ + create(0), + }); spirv::Builder& b = Build(); EXPECT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); - EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "var" + EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "var" )"); - EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %4 SpecId 0 + EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 0 )"); - EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32 -%2 = OpTypePointer Input %3 -%4 = OpSpecConstant %3 0 -%1 = OpVariable %2 Input %4 + EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32 +%2 = OpSpecConstant %1 0 )"); } -TEST_F(BuilderTest, GlobalVar_ConstantId_Scalar_I32_NoConstructor) { - auto* v = Global("var", ty.i32(), ast::StorageClass::kInput, nullptr, - ast::DecorationList{ - create(0), - }); +TEST_F(BuilderTest, GlobalVar_Override_Scalar_I32_NoConstructor) { + auto* v = GlobalConst("var", ty.i32(), nullptr, + ast::DecorationList{ + create(0), + }); spirv::Builder& b = Build(); EXPECT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); - EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "var" + EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "var" )"); - EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %4 SpecId 0 + EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 0 )"); - EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1 -%2 = OpTypePointer Input %3 -%4 = OpSpecConstant %3 0 -%1 = OpVariable %2 Input %4 + EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1 +%2 = OpSpecConstant %1 0 )"); } -TEST_F(BuilderTest, GlobalVar_ConstantId_Scalar_U32_NoConstructor) { - auto* v = Global("var", ty.u32(), ast::StorageClass::kInput, nullptr, - ast::DecorationList{ - create(0), - }); +TEST_F(BuilderTest, GlobalVar_Override_Scalar_U32_NoConstructor) { + auto* v = GlobalConst("var", ty.u32(), nullptr, + ast::DecorationList{ + create(0), + }); spirv::Builder& b = Build(); EXPECT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); - EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "var" + EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "var" )"); - EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %4 SpecId 0 + EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 0 )"); - EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 0 -%2 = OpTypePointer Input %3 -%4 = OpSpecConstant %3 0 -%1 = OpVariable %2 Input %4 + EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0 +%2 = OpSpecConstant %1 0 )"); }