diff --git a/src/ast/transform/vertex_pulling_transform.cc b/src/ast/transform/vertex_pulling_transform.cc index fc571d5737..528c7e57f7 100644 --- a/src/ast/transform/vertex_pulling_transform.cc +++ b/src/ast/transform/vertex_pulling_transform.cc @@ -48,6 +48,7 @@ static const char kVertexBufferNamePrefix[] = "tint_pulling_vertex_buffer_"; static const char kStructBufferName[] = "data"; static const char kPullingPosVarName[] = "tint_pulling_pos"; static const char kDefaultVertexIndexName[] = "tint_pulling_vertex_index"; +static const char kDefaultInstanceIndexName[] = "tint_pulling_instance_index"; } // namespace VertexPullingTransform::VertexPullingTransform(Context* ctx, Module* mod) @@ -100,7 +101,8 @@ bool VertexPullingTransform::Run() { // TODO(idanr): Make sure we covered all error cases, to guarantee the // following stages will pass - FindOrInsertVertexIndex(); + FindOrInsertVertexIndexIfUsed(); + FindOrInsertInstanceIndexIfUsed(); ConvertVertexInputVariablesToPrivate(); AddVertexStorageBuffers(); AddVertexPullingPreamble(vertex_func); @@ -116,7 +118,19 @@ std::string VertexPullingTransform::GetVertexBufferName(uint32_t index) { return kVertexBufferNamePrefix + std::to_string(index); } -void VertexPullingTransform::FindOrInsertVertexIndex() { +void VertexPullingTransform::FindOrInsertVertexIndexIfUsed() { + bool uses_vertex_step_mode = false; + for (const VertexBufferLayoutDescriptor& buffer_layout : + vertex_state_->vertex_buffers) { + if (buffer_layout.step_mode == InputStepMode::kVertex) { + uses_vertex_step_mode = true; + break; + } + } + if (!uses_vertex_step_mode) { + return; + } + // Look for an existing vertex index builtin for (auto& v : mod_->global_variables()) { if (!v->IsDecorated() || v->storage_class() != StorageClass::kInput) { @@ -145,6 +159,47 @@ void VertexPullingTransform::FindOrInsertVertexIndex() { mod_->AddGlobalVariable(std::move(var)); } +void VertexPullingTransform::FindOrInsertInstanceIndexIfUsed() { + bool uses_instance_step_mode = false; + for (const VertexBufferLayoutDescriptor& buffer_layout : + vertex_state_->vertex_buffers) { + if (buffer_layout.step_mode == InputStepMode::kInstance) { + uses_instance_step_mode = true; + break; + } + } + if (!uses_instance_step_mode) { + return; + } + + // Look for an existing instance index builtin + for (auto& v : mod_->global_variables()) { + if (!v->IsDecorated() || v->storage_class() != StorageClass::kInput) { + continue; + } + + for (auto& d : v->AsDecorated()->decorations()) { + if (d->IsBuiltin() && d->AsBuiltin()->value() == Builtin::kInstanceIdx) { + instance_index_name_ = v->name(); + return; + } + } + } + + // We didn't find an instance index builtin, so create one + instance_index_name_ = kDefaultInstanceIndexName; + + auto var = std::make_unique(std::make_unique( + instance_index_name_, StorageClass::kInput, GetI32Type())); + + VariableDecorationList decorations; + decorations.push_back( + std::make_unique(Builtin::kInstanceIdx)); + + var->set_decorations(std::move(decorations)); + mod_->AddGlobalVariable(std::move(var)); +} + void VertexPullingTransform::ConvertVertexInputVariablesToPrivate() { for (auto& v : mod_->global_variables()) { if (!v->IsDecorated() || v->storage_class() != StorageClass::kInput) { @@ -228,12 +283,17 @@ void VertexPullingTransform::AddVertexPullingPreamble(Function* vertex_func) { } auto* v = it->second; + // Identifier to index by + auto index_identifier = std::make_unique( + buffer_layout.step_mode == InputStepMode::kVertex + ? vertex_index_name_ + : instance_index_name_); + // An expression for the start of the read in the buffer in bytes auto pos_value = std::make_unique( BinaryOp::kAdd, std::make_unique( - BinaryOp::kMultiply, - std::make_unique(vertex_index_name_), + BinaryOp::kMultiply, std::move(index_identifier), GenUint(static_cast(buffer_layout.array_stride))), GenUint(static_cast(attribute_desc.offset))); diff --git a/src/ast/transform/vertex_pulling_transform.h b/src/ast/transform/vertex_pulling_transform.h index 636ddc44fa..4ad6321381 100644 --- a/src/ast/transform/vertex_pulling_transform.h +++ b/src/ast/transform/vertex_pulling_transform.h @@ -159,7 +159,10 @@ class VertexPullingTransform { std::string GetVertexBufferName(uint32_t index); /// Inserts vertex_idx binding, or finds the existing one - void FindOrInsertVertexIndex(); + void FindOrInsertVertexIndexIfUsed(); + + /// Inserts instance_idx binding, or finds the existing one + void FindOrInsertInstanceIndexIfUsed(); /// Converts var with a location decoration to var void ConvertVertexInputVariablesToPrivate(); @@ -237,6 +240,7 @@ class VertexPullingTransform { std::string error_; std::string vertex_index_name_; + std::string instance_index_name_; std::unordered_map location_to_var_; std::unique_ptr vertex_state_; diff --git a/src/ast/transform/vertex_pulling_transform_test.cc b/src/ast/transform/vertex_pulling_transform_test.cc index c0cc84171d..cafefda2ab 100644 --- a/src/ast/transform/vertex_pulling_transform_test.cc +++ b/src/ast/transform/vertex_pulling_transform_test.cc @@ -201,27 +201,14 @@ TEST_F(VertexPullingTransformTest, OneAttribute) { mod()->to_str()); } -// We expect the transform to use an existing vertex_idx builtin variable if it -// finds one -TEST_F(VertexPullingTransformTest, ExistingVertexIndex) { +TEST_F(VertexPullingTransformTest, OneInstancedAttribute) { InitBasicModule(); type::F32Type f32; AddVertexInputVariable(0, "var_a", &f32); - type::I32Type i32; - auto vertex_index_var = - std::make_unique(std::make_unique( - "custom_vertex_index", StorageClass::kInput, &i32)); - - VariableDecorationList decorations; - decorations.push_back( - std::make_unique(Builtin::kVertexIdx)); - - vertex_index_var->set_decorations(std::move(decorations)); - mod()->AddGlobalVariable(std::move(vertex_index_var)); - - InitTransform({{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}}); + InitTransform( + {{{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 0}}}}}); EXPECT_TRUE(transform()->Run()); @@ -231,6 +218,122 @@ TEST_F(VertexPullingTransformTest, ExistingVertexIndex) { private __f32 } + DecoratedVariable{ + Decorations{ + BuiltinDecoration{instance_idx} + } + tint_pulling_instance_index + in + __i32 + } + DecoratedVariable{ + Decorations{ + BindingDecoration{0} + SetDecoration{0} + } + tint_pulling_vertex_buffer_0 + storage_buffer + __struct_ + } + EntryPoint{vertex as main = vtx_main} + Function vtx_main -> __void + () + { + Block{ + VariableDeclStatement{ + Variable{ + tint_pulling_pos + function + __i32 + } + } + Assignment{ + Identifier{tint_pulling_pos} + Binary{ + Binary{ + Identifier{tint_pulling_instance_index} + multiply + ScalarConstructor{4} + } + add + ScalarConstructor{0} + } + } + Assignment{ + Identifier{var_a} + As<__f32>{ + ArrayAccessor{ + MemberAccessor{ + Identifier{tint_pulling_vertex_buffer_0} + Identifier{data} + } + Binary{ + Identifier{tint_pulling_pos} + divide + ScalarConstructor{4} + } + } + } + } + } + } +} +)", + mod()->to_str()); +} + +// We expect the transform to use an existing builtin variables if it finds them +TEST_F(VertexPullingTransformTest, ExistingVertexIndexAndInstanceIndex) { + InitBasicModule(); + + type::F32Type f32; + AddVertexInputVariable(0, "var_a", &f32); + AddVertexInputVariable(1, "var_b", &f32); + + type::I32Type i32; + { + auto vertex_index_var = + std::make_unique(std::make_unique( + "custom_vertex_index", StorageClass::kInput, &i32)); + + VariableDecorationList decorations; + decorations.push_back( + std::make_unique(Builtin::kVertexIdx)); + + vertex_index_var->set_decorations(std::move(decorations)); + mod()->AddGlobalVariable(std::move(vertex_index_var)); + } + + { + auto instance_index_var = + std::make_unique(std::make_unique( + "custom_instance_index", StorageClass::kInput, &i32)); + + VariableDecorationList decorations; + decorations.push_back( + std::make_unique(Builtin::kInstanceIdx)); + + instance_index_var->set_decorations(std::move(decorations)); + mod()->AddGlobalVariable(std::move(instance_index_var)); + } + + InitTransform( + {{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}, + {4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 1}}}}}); + + EXPECT_TRUE(transform()->Run()); + + EXPECT_EQ(R"(Module{ + Variable{ + var_a + private + __f32 + } + Variable{ + var_b + private + __f32 + } DecoratedVariable{ Decorations{ BuiltinDecoration{vertex_idx} @@ -239,6 +342,14 @@ TEST_F(VertexPullingTransformTest, ExistingVertexIndex) { in __i32 } + DecoratedVariable{ + Decorations{ + BuiltinDecoration{instance_idx} + } + custom_instance_index + in + __i32 + } DecoratedVariable{ Decorations{ BindingDecoration{0} @@ -248,6 +359,15 @@ TEST_F(VertexPullingTransformTest, ExistingVertexIndex) { storage_buffer __struct_ } + DecoratedVariable{ + Decorations{ + BindingDecoration{1} + SetDecoration{0} + } + tint_pulling_vertex_buffer_1 + storage_buffer + __struct_ + } EntryPoint{vertex as main = vtx_main} Function vtx_main -> __void () @@ -288,6 +408,34 @@ TEST_F(VertexPullingTransformTest, ExistingVertexIndex) { } } } + Assignment{ + Identifier{tint_pulling_pos} + Binary{ + Binary{ + Identifier{custom_instance_index} + multiply + ScalarConstructor{4} + } + add + ScalarConstructor{0} + } + } + Assignment{ + Identifier{var_b} + As<__f32>{ + ArrayAccessor{ + MemberAccessor{ + Identifier{tint_pulling_vertex_buffer_1} + Identifier{data} + } + Binary{ + Identifier{tint_pulling_pos} + divide + ScalarConstructor{4} + } + } + } + } } } }