diff --git a/src/type_determiner.h b/src/type_determiner.h index 8f50a9ec00..132dd216da 100644 --- a/src/type_determiner.h +++ b/src/type_determiner.h @@ -82,6 +82,12 @@ class TypeDeterminer { /// @returns false on error bool DetermineVariableStorageClass(ast::Statement* stmt); + /// Testing method to set a given variable into the type stack + /// @param var the variable to set + void RegisterVariableForTesting(ast::Variable* var) { + variable_stack_.set(var->name(), var); + } + private: void set_error(const Source& src, const std::string& msg); diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 83277a6794..21e8f42fe7 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -207,6 +207,34 @@ uint32_t Builder::GenerateExpression(ast::Expression* expr) { return 0; } +uint32_t Builder::GenerateExpressionAndLoad(ast::Expression* expr) { + auto id = GenerateExpression(expr); + if (id == 0) { + return false; + } + + // Only need to load identifiers + if (!expr->IsIdentifier()) { + return id; + } + if (spirv_id_to_variable_.find(id) == spirv_id_to_variable_.end()) { + error_ = "missing generated ID for variable"; + return 0; + } + auto var = spirv_id_to_variable_[id]; + if (var->is_const()) { + return id; + } + + auto type_id = GenerateTypeIfNeeded(expr->result_type()); + auto result = result_op(); + auto result_id = result.to_i(); + push_function_inst(spv::Op::OpLoad, + {Operand::Int(type_id), result, Operand::Int(id)}); + + return result_id; +} + bool Builder::GenerateFunction(ast::Function* func) { uint32_t func_type_id = GenerateFunctionTypeIfNeeded(func); if (func_type_id == 0) { @@ -283,6 +311,7 @@ bool Builder::GenerateFunctionVariable(ast::Variable* var) { return false; } scope_stack_.set(var->name(), init_id); + spirv_id_to_variable_[init_id] = var; return true; } @@ -308,6 +337,7 @@ bool Builder::GenerateFunctionVariable(ast::Variable* var) { } scope_stack_.set(var->name(), var_id); + spirv_id_to_variable_[var_id] = var; return true; } @@ -337,6 +367,7 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { return false; } scope_stack_.set_global(var->name(), init_id); + spirv_id_to_variable_[init_id] = var; return true; } @@ -390,6 +421,7 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { } } scope_stack_.set_global(var->name(), var_id); + spirv_id_to_variable_[var_id] = var; return true; } @@ -516,11 +548,11 @@ uint32_t Builder::GenerateLiteralIfNeeded(ast::Literal* lit) { uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { if (expr->IsAdd()) { - auto lhs_id = GenerateExpression(expr->lhs()); + auto lhs_id = GenerateExpressionAndLoad(expr->lhs()); if (lhs_id == 0) { return 0; } - auto rhs_id = GenerateExpression(expr->rhs()); + auto rhs_id = GenerateExpressionAndLoad(expr->rhs()); if (rhs_id == 0) { return 0; } diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index ad1805b5e2..fa765ab4ab 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -156,6 +156,10 @@ class Builder { /// @param expr the expression to generate /// @returns the resulting ID of the expression or 0 on error uint32_t GenerateExpression(ast::Expression* expr); + /// Generates an expression and emits a load if necessary + /// @param expr the expression + /// @returns the SPIR-V result id + uint32_t GenerateExpressionAndLoad(ast::Expression* expr); /// Generates the instructions for a function /// @param func the function to generate /// @returns true if the instructions were generated @@ -266,6 +270,7 @@ class Builder { std::unordered_map type_name_to_id_; std::unordered_map const_to_id_; ScopeStack scope_stack_; + std::unordered_map spirv_id_to_variable_; }; } // namespace spirv diff --git a/src/writer/spirv/builder_ident_expression_test.cc b/src/writer/spirv/builder_ident_expression_test.cc index 235c9e209f..d38a3b9e45 100644 --- a/src/writer/spirv/builder_ident_expression_test.cc +++ b/src/writer/spirv/builder_ident_expression_test.cc @@ -18,12 +18,17 @@ #include "src/ast/float_literal.h" #include "src/ast/identifier_expression.h" #include "src/ast/scalar_constructor_expression.h" +#include "src/ast/type/i32_type.h" #include "src/ast/type/f32_type.h" +#include "src/context.h" +#include "src/type_determiner.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" #include "src/ast/variable.h" +#include "src/ast/int_literal.h" #include "src/writer/spirv/builder.h" #include "src/writer/spirv/spv_dump.h" +#include "src/ast/binary_expression.h" namespace tint { namespace writer { @@ -138,7 +143,73 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionVar) { EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 1); } -TEST_F(BuilderTest, DISABLED_IdentifierExpression_MultiName) {} +TEST_F(BuilderTest, IdentifierExpression_Load) { + ast::type::I32Type i32; + + Context ctx; + TypeDeterminer td(&ctx); + + ast::Variable var("var", ast::StorageClass::kPrivate, &i32); + + td.RegisterVariableForTesting(&var); + + auto lhs = std::make_unique("var"); + auto rhs = std::make_unique("var"); + + ast::BinaryExpression expr(ast::BinaryOp::kAdd, std::move(lhs), + std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b; + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(&var)) << b.error(); + + ASSERT_EQ(b.GenerateBinaryExpression(&expr), 6) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1 +%2 = OpTypePointer Private %3 +%1 = OpVariable %2 Private +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%4 = OpLoad %3 %1 +%5 = OpLoad %3 %1 +%6 = OpIAdd %3 %4 %5 +)"); +} + +TEST_F(BuilderTest, IdentifierExpression_NoLoadConst) { + ast::type::I32Type i32; + + Context ctx; + TypeDeterminer td(&ctx); + + ast::Variable var("var", ast::StorageClass::kNone, &i32); + var.set_constructor(std::make_unique( + std::make_unique(&i32, 2))); + var.set_is_const(true); + + td.RegisterVariableForTesting(&var); + + auto lhs = std::make_unique("var"); + auto rhs = std::make_unique("var"); + + ast::BinaryExpression expr(ast::BinaryOp::kAdd, std::move(lhs), + std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b; + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(&var)) << b.error(); + + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 3) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1 +%2 = OpConstant %1 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%3 = OpIAdd %1 %2 %2 +)"); +} } // namespace } // namespace spirv