diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 6b4866dbe8..c84338642b 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -355,7 +355,7 @@ void Builder::GenerateLabel(uint32_t id) { uint32_t Builder::GenerateU32Literal(uint32_t val) { ast::type::U32Type u32; ast::SintLiteral lit(&u32, val); - return GenerateLiteralIfNeeded(&lit); + return GenerateLiteralIfNeeded(nullptr, &lit); } bool Builder::GenerateAssignStatement(ast::AssignmentStatement* assign) { @@ -465,7 +465,7 @@ uint32_t Builder::GenerateExpression(ast::Expression* expr) { return GenerateCallExpression(expr->AsCall()); } if (expr->IsConstructor()) { - return GenerateConstructorExpression(expr->AsConstructor(), false); + return GenerateConstructorExpression(nullptr, expr->AsConstructor(), false); } if (expr->IsIdentifier()) { return GenerateIdentifierExpression(expr->AsIdentifier()); @@ -611,7 +611,7 @@ bool Builder::GenerateFunctionVariable(ast::Variable* var) { // TODO(dsinclair) We could detect if the constructor is fully const and emit // an initializer value for the variable instead of doing the OpLoad. ast::NullLiteral nl(var->type()->UnwrapPtrIfNeeded()); - auto null_id = GenerateLiteralIfNeeded(&nl); + auto null_id = GenerateLiteralIfNeeded(var, &nl); if (null_id == 0) { return 0; } @@ -642,8 +642,8 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { return false; } - init_id = GenerateConstructorExpression(var->constructor()->AsConstructor(), - true); + init_id = GenerateConstructorExpression( + var, var->constructor()->AsConstructor(), true); if (init_id == 0) { return false; } @@ -689,7 +689,7 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { var->storage_class() == ast::StorageClass::kNone || var->storage_class() == ast::StorageClass::kOutput) { ast::NullLiteral nl(var->type()->UnwrapPtrIfNeeded()); - init_id = GenerateLiteralIfNeeded(&nl); + init_id = GenerateLiteralIfNeeded(var, &nl); if (init_id == 0) { return 0; } @@ -718,6 +718,8 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { spv::Op::OpDecorate, {Operand::Int(var_id), Operand::Int(SpvDecorationDescriptorSet), Operand::Int(deco->AsSet()->value())}); + } else if (deco->IsConstantId()) { + // Spec constants are handled elsewhere } else { error_ = "unknown decoration"; return false; @@ -1033,10 +1035,11 @@ void Builder::GenerateGLSLstd450Import() { } uint32_t Builder::GenerateConstructorExpression( + ast::Variable* var, ast::ConstructorExpression* expr, bool is_global_init) { if (expr->IsScalarConstructor()) { - return GenerateLiteralIfNeeded(expr->AsScalarConstructor()->literal()); + return GenerateLiteralIfNeeded(var, expr->AsScalarConstructor()->literal()); } if (expr->IsTypeConstructor()) { return GenerateTypeConstructorExpression(expr->AsTypeConstructor(), @@ -1055,7 +1058,7 @@ uint32_t Builder::GenerateTypeConstructorExpression( // Generate the zero initializer if there are no values provided. if (values.empty()) { ast::NullLiteral nl(init->type()->UnwrapPtrIfNeeded()); - return GenerateLiteralIfNeeded(&nl); + return GenerateLiteralIfNeeded(nullptr, &nl); } std::ostringstream out; @@ -1102,7 +1105,8 @@ uint32_t Builder::GenerateTypeConstructorExpression( for (const auto& e : values) { uint32_t id = 0; if (constructor_is_const) { - id = GenerateConstructorExpression(e->AsConstructor(), is_global_init); + id = GenerateConstructorExpression(nullptr, e->AsConstructor(), + is_global_init); } else { id = GenerateExpression(e.get()); id = GenerateLoadIfNeeded(e->result_type(), id); @@ -1268,12 +1272,21 @@ uint32_t Builder::GenerateCastOrCopy(ast::type::Type* to_type, return result_id; } -uint32_t Builder::GenerateLiteralIfNeeded(ast::Literal* lit) { +uint32_t Builder::GenerateLiteralIfNeeded(ast::Variable* var, + ast::Literal* lit) { auto type_id = GenerateTypeIfNeeded(lit->type()); if (type_id == 0) { return 0; } + auto name = lit->name(); + bool is_spec_constant = false; + if (var && var->IsDecorated() && + var->AsDecorated()->HasConstantIdDecoration()) { + name = "__spec" + name; + is_spec_constant = true; + } + auto val = const_to_id_.find(name); if (val != const_to_id_.end()) { return val->second; @@ -1282,21 +1295,34 @@ uint32_t Builder::GenerateLiteralIfNeeded(ast::Literal* lit) { auto result = result_op(); auto result_id = result.to_i(); + if (is_spec_constant) { + push_annot(spv::Op::OpDecorate, + {Operand::Int(result_id), Operand::Int(SpvDecorationSpecId), + Operand::Int(var->AsDecorated()->constant_id())}); + } + if (lit->IsBool()) { if (lit->AsBool()->IsTrue()) { - push_type(spv::Op::OpConstantTrue, {Operand::Int(type_id), result}); + push_type(is_spec_constant ? spv::Op::OpSpecConstantTrue + : spv::Op::OpConstantTrue, + {Operand::Int(type_id), result}); } else { - push_type(spv::Op::OpConstantFalse, {Operand::Int(type_id), result}); + push_type(is_spec_constant ? spv::Op::OpSpecConstantFalse + : spv::Op::OpConstantFalse, + {Operand::Int(type_id), result}); } } else if (lit->IsSint()) { - push_type(spv::Op::OpConstant, {Operand::Int(type_id), result, - Operand::Int(lit->AsSint()->value())}); + push_type( + is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant, + {Operand::Int(type_id), result, Operand::Int(lit->AsSint()->value())}); } else if (lit->IsUint()) { - push_type(spv::Op::OpConstant, {Operand::Int(type_id), result, - Operand::Int(lit->AsUint()->value())}); + push_type( + is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant, + {Operand::Int(type_id), result, Operand::Int(lit->AsUint()->value())}); } else if (lit->IsFloat()) { - push_type(spv::Op::OpConstant, {Operand::Int(type_id), result, - Operand::Float(lit->AsFloat()->value())}); + push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant, + {Operand::Int(type_id), result, + Operand::Float(lit->AsFloat()->value())}); } else if (lit->IsNull()) { push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result}); } else { @@ -1730,7 +1756,8 @@ uint32_t Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident, spirv_params.push_back(Operand::Int(SpvImageOperandsLodMask)); ast::type::F32Type f32; ast::FloatLiteral float_0(&f32, 0.0); - spirv_params.push_back(Operand::Int(GenerateLiteralIfNeeded(&float_0))); + spirv_params.push_back( + Operand::Int(GenerateLiteralIfNeeded(nullptr, &float_0))); } if (op == spv::Op::OpNop) { error_ = "unable to determine operator for: " + ident->name(); diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index ccfce3912a..b6e3f7ce79 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -249,10 +249,12 @@ class Builder { /// Generates an import instruction void GenerateGLSLstd450Import(); /// Generates a constructor expression + /// @param var the variable generated for, nullptr if no variable associated. /// @param expr the expression to generate /// @param is_global_init set true if this is a global variable constructor /// @returns the ID of the expression or 0 on failure. - uint32_t GenerateConstructorExpression(ast::ConstructorExpression* expr, + uint32_t GenerateConstructorExpression(ast::Variable* var, + ast::ConstructorExpression* expr, bool is_global_init); /// Generates a type constructor expression /// @param init the expression to generate @@ -262,9 +264,10 @@ class Builder { ast::TypeConstructorExpression* init, bool is_global_init); /// Generates a literal constant if needed + /// @param var the variable generated for, nullptr if no variable associated. /// @param lit the literal to generate /// @returns the ID on success or 0 on failure - uint32_t GenerateLiteralIfNeeded(ast::Literal* lit); + uint32_t GenerateLiteralIfNeeded(ast::Variable* var, ast::Literal* lit); /// Generates a binary expression /// @param expr the expression to generate /// @returns the expression ID on success or 0 otherwise diff --git a/src/writer/spirv/builder_constructor_expression_test.cc b/src/writer/spirv/builder_constructor_expression_test.cc index 73f07ad55c..12ac999f5d 100644 --- a/src/writer/spirv/builder_constructor_expression_test.cc +++ b/src/writer/spirv/builder_constructor_expression_test.cc @@ -56,7 +56,7 @@ TEST_F(BuilderTest, Constructor_Const) { ast::Module mod; Builder b(&mod); - EXPECT_EQ(b.GenerateConstructorExpression(&c, true), 2u); + EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &c, true), 2u); ASSERT_FALSE(b.has_error()) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32 @@ -84,7 +84,7 @@ TEST_F(BuilderTest, Constructor_Type) { EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); Builder b(&mod); - EXPECT_EQ(b.GenerateConstructorExpression(&t, true), 5u); + EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &t, true), 5u); ASSERT_FALSE(b.has_error()) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 @@ -190,7 +190,7 @@ TEST_F(BuilderTest, Constructor_Type_NonConst_Value_Fails) { ast::Module mod; Builder b(&mod); - EXPECT_EQ(b.GenerateConstructorExpression(&t, true), 0u); + EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &t, true), 0u); EXPECT_TRUE(b.has_error()); EXPECT_EQ(b.error(), R"(constructor must be a constant expression)"); } @@ -765,7 +765,7 @@ TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec3_With_F32_Vec2) { Builder b(&mod); b.push_function(Function{}); - EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u); + EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 %1 = OpTypeVector %2 3 @@ -807,7 +807,7 @@ TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec3_With_Vec2_F32) { Builder b(&mod); b.push_function(Function{}); - EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u); + EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 %1 = OpTypeVector %2 3 @@ -851,7 +851,7 @@ TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec4_With_F32_F32_Vec2) { Builder b(&mod); b.push_function(Function{}); - EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u); + EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 %1 = OpTypeVector %2 4 @@ -895,7 +895,7 @@ TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec4_With_F32_Vec2_F32) { Builder b(&mod); b.push_function(Function{}); - EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u); + EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 %1 = OpTypeVector %2 4 @@ -939,7 +939,7 @@ TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec4_With_Vec2_F32_F32) { Builder b(&mod); b.push_function(Function{}); - EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u); + EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 %1 = OpTypeVector %2 4 @@ -987,7 +987,7 @@ TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec4_With_Vec2_Vec2) { Builder b(&mod); b.push_function(Function{}); - EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 13u); + EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 13u); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 %1 = OpTypeVector %2 4 @@ -1033,7 +1033,7 @@ TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec4_With_F32_Vec3) { Builder b(&mod); b.push_function(Function{}); - EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 13u); + EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 13u); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 %1 = OpTypeVector %2 4 @@ -1079,7 +1079,7 @@ TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec4_With_Vec3_F32) { Builder b(&mod); b.push_function(Function{}); - EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 13u); + EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 13u); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 %1 = OpTypeVector %2 4 diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc index 1e60c0fcfc..73ae3035a0 100644 --- a/src/writer/spirv/builder_global_variable_test.cc +++ b/src/writer/spirv/builder_global_variable_test.cc @@ -16,14 +16,18 @@ #include "gtest/gtest.h" #include "src/ast/binding_decoration.h" +#include "src/ast/bool_literal.h" #include "src/ast/builtin.h" #include "src/ast/builtin_decoration.h" +#include "src/ast/constant_id_decoration.h" #include "src/ast/decorated_variable.h" #include "src/ast/float_literal.h" #include "src/ast/location_decoration.h" +#include "src/ast/module.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/set_decoration.h" #include "src/ast/storage_class.h" +#include "src/ast/type/bool_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" @@ -359,6 +363,58 @@ TEST_F(BuilderTest, GlobalVar_WithBuiltin) { )"); } +TEST_F(BuilderTest, GlobalVar_ConstantId_Bool) { + ast::type::BoolType bool_type; + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(1200)); + + ast::DecoratedVariable v(std::make_unique( + "var", ast::StorageClass::kNone, &bool_type)); + v.set_decorations(std::move(decos)); + v.set_constructor(std::make_unique( + std::make_unique(&bool_type, true))); + + ast::Module mod; + Builder b(&mod); + EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error(); + EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %3 "var" +)"); + EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 1200 +)"); + EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool +%2 = OpSpecConstantTrue %1 +%4 = OpTypePointer Private %1 +%3 = OpVariable %4 Private %2 +)"); +} + +TEST_F(BuilderTest, GlobalVar_ConstantId_Scalar) { + ast::type::F32Type f32; + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + + ast::DecoratedVariable v( + std::make_unique("var", ast::StorageClass::kNone, &f32)); + v.set_decorations(std::move(decos)); + v.set_constructor(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::Module mod; + Builder b(&mod); + EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error(); + EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %3 "var" +)"); + EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 0 +)"); + EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32 +%2 = OpSpecConstant %1 2 +%4 = OpTypePointer Private %1 +%3 = OpVariable %4 Private %2 +)"); +} + struct BuiltinData { ast::Builtin builtin; SpvBuiltIn result; diff --git a/src/writer/spirv/builder_literal_test.cc b/src/writer/spirv/builder_literal_test.cc index 1c96661b73..294c6b6c15 100644 --- a/src/writer/spirv/builder_literal_test.cc +++ b/src/writer/spirv/builder_literal_test.cc @@ -38,7 +38,7 @@ TEST_F(BuilderTest, Literal_Bool_True) { ast::Module mod; Builder b(&mod); - auto id = b.GenerateLiteralIfNeeded(&b_true); + auto id = b.GenerateLiteralIfNeeded(nullptr, &b_true); ASSERT_FALSE(b.has_error()) << b.error(); EXPECT_EQ(2u, id); @@ -53,7 +53,7 @@ TEST_F(BuilderTest, Literal_Bool_False) { ast::Module mod; Builder b(&mod); - auto id = b.GenerateLiteralIfNeeded(&b_false); + auto id = b.GenerateLiteralIfNeeded(nullptr, &b_false); ASSERT_FALSE(b.has_error()) << b.error(); EXPECT_EQ(2u, id); @@ -69,11 +69,11 @@ TEST_F(BuilderTest, Literal_Bool_Dedup) { ast::Module mod; Builder b(&mod); - ASSERT_NE(b.GenerateLiteralIfNeeded(&b_true), 0u); + ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &b_true), 0u); ASSERT_FALSE(b.has_error()) << b.error(); - ASSERT_NE(b.GenerateLiteralIfNeeded(&b_false), 0u); + ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &b_false), 0u); ASSERT_FALSE(b.has_error()) << b.error(); - ASSERT_NE(b.GenerateLiteralIfNeeded(&b_true), 0u); + ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &b_true), 0u); ASSERT_FALSE(b.has_error()) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool @@ -88,7 +88,7 @@ TEST_F(BuilderTest, Literal_I32) { ast::Module mod; Builder b(&mod); - auto id = b.GenerateLiteralIfNeeded(&i); + auto id = b.GenerateLiteralIfNeeded(nullptr, &i); ASSERT_FALSE(b.has_error()) << b.error(); EXPECT_EQ(2u, id); @@ -104,8 +104,8 @@ TEST_F(BuilderTest, Literal_I32_Dedup) { ast::Module mod; Builder b(&mod); - ASSERT_NE(b.GenerateLiteralIfNeeded(&i1), 0u); - ASSERT_NE(b.GenerateLiteralIfNeeded(&i2), 0u); + ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i1), 0u); + ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i2), 0u); ASSERT_FALSE(b.has_error()) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1 @@ -119,7 +119,7 @@ TEST_F(BuilderTest, Literal_U32) { ast::Module mod; Builder b(&mod); - auto id = b.GenerateLiteralIfNeeded(&i); + auto id = b.GenerateLiteralIfNeeded(nullptr, &i); ASSERT_FALSE(b.has_error()) << b.error(); EXPECT_EQ(2u, id); @@ -135,8 +135,8 @@ TEST_F(BuilderTest, Literal_U32_Dedup) { ast::Module mod; Builder b(&mod); - ASSERT_NE(b.GenerateLiteralIfNeeded(&i1), 0u); - ASSERT_NE(b.GenerateLiteralIfNeeded(&i2), 0u); + ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i1), 0u); + ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i2), 0u); ASSERT_FALSE(b.has_error()) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0 @@ -150,7 +150,7 @@ TEST_F(BuilderTest, Literal_F32) { ast::Module mod; Builder b(&mod); - auto id = b.GenerateLiteralIfNeeded(&i); + auto id = b.GenerateLiteralIfNeeded(nullptr, &i); ASSERT_FALSE(b.has_error()) << b.error(); EXPECT_EQ(2u, id); @@ -166,8 +166,8 @@ TEST_F(BuilderTest, Literal_F32_Dedup) { ast::Module mod; Builder b(&mod); - ASSERT_NE(b.GenerateLiteralIfNeeded(&i1), 0u); - ASSERT_NE(b.GenerateLiteralIfNeeded(&i2), 0u); + ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i1), 0u); + ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i2), 0u); ASSERT_FALSE(b.has_error()) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32