diff --git a/src/ast/expression.h b/src/ast/expression.h index 3055bce77f..f82fe9be93 100644 --- a/src/ast/expression.h +++ b/src/ast/expression.h @@ -93,7 +93,7 @@ class Expression : public Node { /// @returns the expression as a unary op expression const UnaryOpExpression* AsUnaryOp() const; - /// @returns the expression as an array accessor + /// @returns the expression as an array accessor ArrayAccessorExpression* AsArrayAccessor(); /// @returns the expression as an as AsExpression* AsAs(); diff --git a/src/ast/type/type.cc b/src/ast/type/type.cc index 1814876efe..55d045a779 100644 --- a/src/ast/type/type.cc +++ b/src/ast/type/type.cc @@ -36,6 +36,13 @@ Type::Type() = default; Type::~Type() = default; +Type* Type::UnwrapPtrIfNeeded() { + if (IsPointer()) { + return AsPointer()->type(); + } + return this; +} + bool Type::IsAlias() const { return false; } diff --git a/src/ast/type/type.h b/src/ast/type/type.h index 4a76ccfa67..52e8c81952 100644 --- a/src/ast/type/type.h +++ b/src/ast/type/type.h @@ -66,6 +66,9 @@ class Type { /// @returns the name for this type. The |type_name| is unique over all types. virtual std::string type_name() const = 0; + /// @returns the pointee type if this is a pointer, |this| otherwise + Type* UnwrapPtrIfNeeded(); + /// @returns true if this type is a float scalar bool is_float_scalar(); /// @returns true if this type is a float matrix diff --git a/src/reader/spirv/function_cfg_test.cc b/src/reader/spirv/function_cfg_test.cc index d86182d7c8..295f1bff4e 100644 --- a/src/reader/spirv/function_cfg_test.cc +++ b/src/reader/spirv/function_cfg_test.cc @@ -1159,8 +1159,7 @@ TEST_F(SpvParserTest, EXPECT_THAT(fe.block_order(), ElementsAre(10, 30, 20, 80, 99)); } -TEST_F(SpvParserTest, - ComputeBlockOrder_Switch_DefaultSameAsACase) { +TEST_F(SpvParserTest, ComputeBlockOrder_Switch_DefaultSameAsACase) { auto* p = parser(test::Assemble(CommonTypes() + R"( %100 = OpFunction %void None %voidfn @@ -1373,8 +1372,7 @@ TEST_F(SpvParserTest, << assembly; } -TEST_F(SpvParserTest, - ComputeBlockOrder_Nest_If_Contains_If) { +TEST_F(SpvParserTest, ComputeBlockOrder_Nest_If_Contains_If) { auto assembly = CommonTypes() + R"( %100 = OpFunction %void None %voidfn @@ -1424,8 +1422,7 @@ TEST_F(SpvParserTest, << assembly; } -TEST_F(SpvParserTest, - ComputeBlockOrder_Nest_If_In_SwitchCase) { +TEST_F(SpvParserTest, ComputeBlockOrder_Nest_If_In_SwitchCase) { auto assembly = CommonTypes() + R"( %100 = OpFunction %void None %voidfn @@ -1475,8 +1472,7 @@ TEST_F(SpvParserTest, << assembly; } -TEST_F(SpvParserTest, - ComputeBlockOrder_Nest_IfFallthrough_In_SwitchCase) { +TEST_F(SpvParserTest, ComputeBlockOrder_Nest_IfFallthrough_In_SwitchCase) { auto assembly = CommonTypes() + R"( %100 = OpFunction %void None %voidfn @@ -1526,8 +1522,7 @@ TEST_F(SpvParserTest, << assembly; } -TEST_F(SpvParserTest, - ComputeBlockOrder_Nest_IfBreak_In_SwitchCase) { +TEST_F(SpvParserTest, ComputeBlockOrder_Nest_IfBreak_In_SwitchCase) { auto assembly = CommonTypes() + R"( %100 = OpFunction %void None %voidfn diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 9b8cf73847..b081eade5b 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -38,6 +38,7 @@ #include "src/ast/type/bool_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/matrix_type.h" +#include "src/ast/type/pointer_type.h" #include "src/ast/type/struct_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" @@ -274,19 +275,30 @@ bool TypeDeterminer::DetermineArrayAccessor( if (!DetermineResultType(expr->array())) { return false; } - auto* parent_type = expr->array()->result_type(); + + auto* res = expr->array()->result_type(); + auto* parent_type = res->UnwrapPtrIfNeeded(); + ast::type::Type* ret = nullptr; if (parent_type->IsArray()) { - expr->set_result_type(parent_type->AsArray()->type()); + ret = parent_type->AsArray()->type(); } else if (parent_type->IsVector()) { - expr->set_result_type(parent_type->AsVector()->type()); + ret = parent_type->AsVector()->type(); } else if (parent_type->IsMatrix()) { auto* m = parent_type->AsMatrix(); - expr->set_result_type(ctx_.type_mgr().Get( - std::make_unique(m->type(), m->rows()))); + ret = ctx_.type_mgr().Get( + std::make_unique(m->type(), m->rows())); } else { set_error(expr->source(), "invalid parent type in array accessor"); return false; } + + // If we're extracting from a pointer, we return a pointer. + if (res->IsPointer()) { + ret = ctx_.type_mgr().Get(std::make_unique( + ret, res->AsPointer()->storage_class())); + } + expr->set_result_type(ret); + return true; } @@ -365,7 +377,15 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) { auto name = expr->name(); ast::Variable* var; if (variable_stack_.get(name, &var)) { - expr->set_result_type(var->type()); + // A constant is the type, but a variable is always a pointer so synthesize + // the pointer around the variable type. + if (var->is_const()) { + expr->set_result_type(var->type()); + } else { + expr->set_result_type( + ctx_.type_mgr().Get(std::make_unique( + var->type(), var->storage_class()))); + } return true; } @@ -384,43 +404,52 @@ bool TypeDeterminer::DetermineMemberAccessor( return false; } - auto* data_type = expr->structure()->result_type(); + auto* res = expr->structure()->result_type(); + auto* data_type = res->UnwrapPtrIfNeeded(); + ast::type::Type* ret = nullptr; if (data_type->IsStruct()) { auto* strct = data_type->AsStruct()->impl(); auto name = expr->member()->name(); for (const auto& member : strct->members()) { - if (member->name() != name) { - continue; + if (member->name() == name) { + ret = member->type(); + break; } - - expr->set_result_type(member->type()); - return true; } - set_error(expr->source(), "struct member not found"); - return false; - } - if (data_type->IsVector()) { + if (ret == nullptr) { + set_error(expr->source(), "struct member " + name + " not found"); + return false; + } + } else if (data_type->IsVector()) { auto* vec = data_type->AsVector(); auto size = expr->member()->name().size(); if (size == 1) { // A single element swizzle is just the type of the vector. - expr->set_result_type(vec->type()); + ret = vec->type(); } else { // The vector will have a number of components equal to the length of the // swizzle. This assumes the validator will check that the swizzle // is correct. - expr->set_result_type(ctx_.type_mgr().Get( - std::make_unique(vec->type(), size))); + ret = ctx_.type_mgr().Get( + std::make_unique(vec->type(), size)); } - return true; + } else { + set_error(expr->source(), + "invalid type " + data_type->type_name() + " in member accessor"); + return false; } - set_error(expr->source(), - "invalid type " + data_type->type_name() + " in member accessor"); - return false; + // If we're extracting from a pointer, we return a pointer. + if (res->IsPointer()) { + ret = ctx_.type_mgr().Get(std::make_unique( + ret, res->AsPointer()->storage_class())); + } + expr->set_result_type(ret); + + return true; } bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) { @@ -432,7 +461,7 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) { if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() || expr->IsShiftRight() || expr->IsShiftRightArith() || expr->IsAdd() || expr->IsSubtract() || expr->IsDivide() || expr->IsModulo()) { - expr->set_result_type(expr->lhs()->result_type()); + expr->set_result_type(expr->lhs()->result_type()->UnwrapPtrIfNeeded()); return true; } // Result type is a scalar or vector of boolean type @@ -441,7 +470,7 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) { expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) { auto* bool_type = ctx_.type_mgr().Get(std::make_unique()); - auto* param_type = expr->lhs()->result_type(); + auto* param_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded(); if (param_type->IsVector()) { expr->set_result_type( ctx_.type_mgr().Get(std::make_unique( @@ -452,8 +481,8 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) { return true; } if (expr->IsMultiply()) { - auto* lhs_type = expr->lhs()->result_type(); - auto* rhs_type = expr->rhs()->result_type(); + auto* lhs_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded(); + auto* rhs_type = expr->rhs()->result_type()->UnwrapPtrIfNeeded(); // Note, the ordering here matters. The later checks depend on the prior // checks having been done. @@ -504,7 +533,7 @@ bool TypeDeterminer::DetermineUnaryDerivative( if (!DetermineResultType(expr->param())) { return false; } - expr->set_result_type(expr->param()->result_type()); + expr->set_result_type(expr->param()->result_type()->UnwrapPtrIfNeeded()); return true; } @@ -531,7 +560,7 @@ bool TypeDeterminer::DetermineUnaryMethod(ast::UnaryMethodExpression* expr) { auto* bool_type = ctx_.type_mgr().Get(std::make_unique()); - auto* param_type = expr->params()[0]->result_type(); + auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded(); if (param_type->IsVector()) { expr->set_result_type( ctx_.type_mgr().Get(std::make_unique( @@ -552,8 +581,8 @@ bool TypeDeterminer::DetermineUnaryMethod(ast::UnaryMethodExpression* expr) { "incorrect number of parameters for outer product"); return false; } - auto* param0_type = expr->params()[0]->result_type(); - auto* param1_type = expr->params()[1]->result_type(); + auto* param0_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded(); + auto* param1_type = expr->params()[1]->result_type()->UnwrapPtrIfNeeded(); if (!param0_type->IsVector() || !param1_type->IsVector()) { set_error(expr->source(), "invalid parameter type for outer product"); return false; @@ -574,7 +603,7 @@ bool TypeDeterminer::DetermineUnaryOp(ast::UnaryOpExpression* expr) { if (!DetermineResultType(expr->expr())) { return false; } - expr->set_result_type(expr->expr()->result_type()); + expr->set_result_type(expr->expr()->result_type()->UnwrapPtrIfNeeded()); return true; } diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 10a2a3bcd2..5e626b4c4d 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -46,6 +46,7 @@ #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" #include "src/ast/type/matrix_type.h" +#include "src/ast/type/pointer_type.h" #include "src/ast/type/struct_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" @@ -433,8 +434,33 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Array) { auto idx = std::make_unique( std::make_unique(&i32, 2)); - auto var = - std::make_unique("my_var", ast::StorageClass::kNone, &ary); + auto var = std::make_unique( + "my_var", ast::StorageClass::kFunction, &ary); + mod()->AddGlobalVariable(std::move(var)); + + // Register the global + EXPECT_TRUE(td()->Determine()); + + ast::ArrayAccessorExpression acc( + std::make_unique("my_var"), std::move(idx)); + EXPECT_TRUE(td()->DetermineResultType(&acc)); + ASSERT_NE(acc.result_type(), nullptr); + ASSERT_TRUE(acc.result_type()->IsPointer()); + + auto* ptr = acc.result_type()->AsPointer(); + EXPECT_TRUE(ptr->type()->IsF32()); +} + +TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Array_Constant) { + ast::type::I32Type i32; + ast::type::F32Type f32; + ast::type::ArrayType ary(&f32, 3); + + auto idx = std::make_unique( + std::make_unique(&i32, 2)); + auto var = std::make_unique( + "my_var", ast::StorageClass::kFunction, &ary); + var->set_is_const(true); mod()->AddGlobalVariable(std::move(var)); // Register the global @@ -465,8 +491,11 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix) { std::make_unique("my_var"), std::move(idx)); EXPECT_TRUE(td()->DetermineResultType(&acc)); ASSERT_NE(acc.result_type(), nullptr); - ASSERT_TRUE(acc.result_type()->IsVector()); - EXPECT_EQ(acc.result_type()->AsVector()->size(), 3u); + ASSERT_TRUE(acc.result_type()->IsPointer()); + + auto* ptr = acc.result_type()->AsPointer(); + ASSERT_TRUE(ptr->type()->IsVector()); + EXPECT_EQ(ptr->type()->AsVector()->size(), 3u); } TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix_BothDimensions) { @@ -493,7 +522,10 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix_BothDimensions) { EXPECT_TRUE(td()->DetermineResultType(&acc)); ASSERT_NE(acc.result_type(), nullptr); - EXPECT_TRUE(acc.result_type()->IsF32()); + ASSERT_TRUE(acc.result_type()->IsPointer()); + + auto* ptr = acc.result_type()->AsPointer(); + EXPECT_TRUE(ptr->type()->IsF32()); } TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Vector) { @@ -514,7 +546,10 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Vector) { std::make_unique("my_var"), std::move(idx)); EXPECT_TRUE(td()->DetermineResultType(&acc)); ASSERT_NE(acc.result_type(), nullptr); - EXPECT_TRUE(acc.result_type()->IsF32()); + ASSERT_TRUE(acc.result_type()->IsPointer()); + + auto* ptr = acc.result_type()->AsPointer(); + EXPECT_TRUE(ptr->type()->IsF32()); } TEST_F(TypeDeterminerTest, Expr_As) { @@ -643,12 +678,55 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalVariable) { // Register the global EXPECT_TRUE(td()->Determine()); + ast::IdentifierExpression ident("my_var"); + EXPECT_TRUE(td()->DetermineResultType(&ident)); + ASSERT_NE(ident.result_type(), nullptr); + EXPECT_TRUE(ident.result_type()->IsPointer()); + EXPECT_TRUE(ident.result_type()->AsPointer()->type()->IsF32()); +} + +TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalConstant) { + ast::type::F32Type f32; + auto var = + std::make_unique("my_var", ast::StorageClass::kNone, &f32); + var->set_is_const(true); + mod()->AddGlobalVariable(std::move(var)); + + // Register the global + EXPECT_TRUE(td()->Determine()); + ast::IdentifierExpression ident("my_var"); EXPECT_TRUE(td()->DetermineResultType(&ident)); ASSERT_NE(ident.result_type(), nullptr); EXPECT_TRUE(ident.result_type()->IsF32()); } +TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable_Const) { + ast::type::F32Type f32; + + auto my_var = std::make_unique("my_var"); + auto* my_var_ptr = my_var.get(); + + auto var = + std::make_unique("my_var", ast::StorageClass::kNone, &f32); + var->set_is_const(true); + + ast::StatementList body; + body.push_back(std::make_unique(std::move(var))); + + body.push_back(std::make_unique( + std::move(my_var), + std::make_unique("my_var"))); + + ast::Function f("my_func", {}, &f32); + f.set_body(std::move(body)); + + EXPECT_TRUE(td()->DetermineFunction(&f)); + + ASSERT_NE(my_var_ptr->result_type(), nullptr); + EXPECT_TRUE(my_var_ptr->result_type()->IsF32()); +} + TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) { ast::type::F32Type f32; @@ -670,7 +748,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) { EXPECT_TRUE(td()->DetermineFunction(&f)); ASSERT_NE(my_var_ptr->result_type(), nullptr); - EXPECT_TRUE(my_var_ptr->result_type()->IsF32()); + EXPECT_TRUE(my_var_ptr->result_type()->IsPointer()); + EXPECT_TRUE(my_var_ptr->result_type()->AsPointer()->type()->IsF32()); } TEST_F(TypeDeterminerTest, Expr_Identifier_Function) { @@ -720,7 +799,10 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) { ast::MemberAccessorExpression mem(std::move(ident), std::move(mem_ident)); EXPECT_TRUE(td()->DetermineResultType(&mem)); ASSERT_NE(mem.result_type(), nullptr); - EXPECT_TRUE(mem.result_type()->IsF32()); + ASSERT_TRUE(mem.result_type()->IsPointer()); + + auto* ptr = mem.result_type()->AsPointer(); + EXPECT_TRUE(ptr->type()->IsF32()); } TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) { @@ -740,9 +822,12 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) { ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle)); EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error(); ASSERT_NE(mem.result_type(), nullptr); - ASSERT_TRUE(mem.result_type()->IsVector()); - EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32()); - EXPECT_EQ(mem.result_type()->AsVector()->size(), 2u); + ASSERT_TRUE(mem.result_type()->IsPointer()); + + auto* ptr = mem.result_type()->AsPointer(); + ASSERT_TRUE(ptr->type()->IsVector()); + EXPECT_TRUE(ptr->type()->AsVector()->type()->IsF32()); + EXPECT_EQ(ptr->type()->AsVector()->size(), 2u); } TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) { @@ -762,10 +847,13 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) { ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle)); EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error(); ASSERT_NE(mem.result_type(), nullptr); - ASSERT_TRUE(mem.result_type()->IsF32()); + ASSERT_TRUE(mem.result_type()->IsPointer()); + + auto* ptr = mem.result_type()->AsPointer(); + ASSERT_TRUE(ptr->type()->IsF32()); } -TEST_F(TypeDeterminerTest, Expr_MultiLevel) { +TEST_F(TypeDeterminerTest, Expr_Accessor_MultiLevel) { // struct b { // vec4 foo // } @@ -803,6 +891,7 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) { auto strctB = std::make_unique(ast::StructDecoration::kNone, std::move(b_members)); ast::type::StructType stB(std::move(strctB)); + stB.set_name("B"); ast::type::VectorType vecB(&stB, 3); @@ -814,6 +903,7 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) { std::move(a_members)); ast::type::StructType stA(std::move(strctA)); + stA.set_name("A"); auto var = std::make_unique("c", ast::StorageClass::kNone, &stA); @@ -838,10 +928,14 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) { std::move(foo_ident)), std::move(swizzle)); EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error(); + ASSERT_NE(mem.result_type(), nullptr); - ASSERT_TRUE(mem.result_type()->IsVector()); - EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32()); - EXPECT_EQ(mem.result_type()->AsVector()->size(), 2u); + ASSERT_TRUE(mem.result_type()->IsPointer()); + + auto* ptr = mem.result_type()->AsPointer(); + ASSERT_TRUE(ptr->type()->IsVector()); + EXPECT_TRUE(ptr->type()->AsVector()->type()->IsF32()); + EXPECT_EQ(ptr->type()->AsVector()->size(), 2u); } using Expr_Binary_BitwiseTest = TypeDeterminerTestWithParam; diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index be2c64ee3e..17e5f58be4 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -226,34 +226,6 @@ uint32_t Builder::GenerateExpression(ast::Expression* expr) { return 0; } -uint32_t Builder::GenerateExpressionAndLoad(ast::Expression* expr) { - auto id = GenerateExpression(expr); - if (id == 0) { - return false; - } - - // Only need to load identifiers - if (!expr->IsIdentifier()) { - return id; - } - if (spirv_id_to_variable_.find(id) == spirv_id_to_variable_.end()) { - error_ = "missing generated ID for variable"; - return 0; - } - auto* var = spirv_id_to_variable_[id]; - if (var->is_const()) { - return id; - } - - auto type_id = GenerateTypeIfNeeded(expr->result_type()); - auto result = result_op(); - auto result_id = result.to_i(); - push_function_inst(spv::Op::OpLoad, - {Operand::Int(type_id), result, Operand::Int(id)}); - - return result_id; -} - bool Builder::GenerateFunction(ast::Function* func) { uint32_t func_type_id = GenerateFunctionTypeIfNeeded(func); if (func_type_id == 0) { @@ -468,7 +440,8 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) { auto* mem_accessor = source->AsMemberAccessor(); source = mem_accessor->structure(); - auto* data_type = mem_accessor->structure()->result_type(); + auto* data_type = + mem_accessor->structure()->result_type()->UnwrapPtrIfNeeded(); if (data_type->IsStruct()) { auto* strct = data_type->AsStruct()->impl(); auto name = mem_accessor->member()->name(); @@ -476,11 +449,9 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) { uint32_t i = 0; for (; i < strct->members().size(); ++i) { const auto& member = strct->members()[i]; - if (member->name() != name) { - continue; + if (member->name() == name) { + break; } - - break; } ast::type::U32Type u32; @@ -507,9 +478,7 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) { return 0; } - // The access chain results in a pointer, so wrap the return type. - ast::type::PointerType ptr(expr->result_type(), ast::StorageClass::kFunction); - auto type_id = GenerateTypeIfNeeded(&ptr); + auto type_id = GenerateTypeIfNeeded(expr->result_type()); if (type_id == 0) { return 0; } @@ -535,12 +504,23 @@ uint32_t Builder::GenerateIdentifierExpression( if (val == 0) { error_ = "unable to lookup: " + expr->name() + " in " + expr->path(); } - } else if (!scope_stack_.get(expr->name(), &val)) { - error_ = "unable to find name for identifier: " + expr->name(); - return 0; + return val; + } + if (scope_stack_.get(expr->name(), &val)) { + return val; } - return val; + error_ = "unable to find name for identifier: " + expr->name(); + return 0; +} + +uint32_t Builder::GenerateLoad(ast::type::Type* type, uint32_t id) { + auto type_id = GenerateTypeIfNeeded(type->UnwrapPtrIfNeeded()); + auto result = result_op(); + auto result_id = result.to_i(); + push_function_inst(spv::Op::OpLoad, + {Operand::Int(type_id), result, Operand::Int(id)}); + return result_id; } uint32_t Builder::GenerateUnaryOpExpression(ast::UnaryOpExpression* expr) { @@ -686,14 +666,21 @@ uint32_t Builder::GenerateLiteralIfNeeded(ast::Literal* lit) { } uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { - auto lhs_id = GenerateExpressionAndLoad(expr->lhs()); + auto lhs_id = GenerateExpression(expr->lhs()); if (lhs_id == 0) { return 0; } - auto rhs_id = GenerateExpressionAndLoad(expr->rhs()); + if (expr->lhs()->result_type()->IsPointer()) { + lhs_id = GenerateLoad(expr->lhs()->result_type(), lhs_id); + } + + auto rhs_id = GenerateExpression(expr->rhs()); if (rhs_id == 0) { return 0; } + if (expr->rhs()->result_type()->IsPointer()) { + rhs_id = GenerateLoad(expr->rhs()->result_type(), rhs_id); + } auto result = result_op(); auto result_id = result.to_i(); @@ -705,8 +692,8 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { // Handle int and float and the vectors of those types. Other types // should have been rejected by validation. - auto* lhs_type = expr->lhs()->result_type(); - auto* rhs_type = expr->rhs()->result_type(); + auto* lhs_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded(); + auto* rhs_type = expr->rhs()->result_type()->UnwrapPtrIfNeeded(); bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector(); bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector(); @@ -819,6 +806,7 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { } else if (expr->IsXor()) { op = spv::Op::OpBitwiseXor; } else { + error_ = "unknown binary expression"; return 0; } diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index 6d0fd69b01..c30f4089d9 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -157,10 +157,6 @@ class Builder { /// @param expr the expression to generate /// @returns the resulting ID of the expression or 0 on error uint32_t GenerateExpression(ast::Expression* expr); - /// Generates an expression and emits a load if necessary - /// @param expr the expression - /// @returns the SPIR-V result id - uint32_t GenerateExpressionAndLoad(ast::Expression* expr); /// Generates the instructions for a function /// @param func the function to generate /// @returns true if the instructions were generated @@ -240,6 +236,11 @@ class Builder { /// @param list the statement list to generate /// @returns true on successful generation bool GenerateStatementList(const ast::StatementList& list); + /// Geneates an OpLoad + /// @param type the type to load + /// @param id the variable id to load + /// @returns the ID of the loaded value + uint32_t GenerateLoad(ast::type::Type* type, uint32_t id); /// Geneates an OpStore /// @param to the ID to store too /// @param from the ID to store from diff --git a/src/writer/spirv/builder_accessor_expression_test.cc b/src/writer/spirv/builder_accessor_expression_test.cc index 28c1fde2d4..396028b882 100644 --- a/src/writer/spirv/builder_accessor_expression_test.cc +++ b/src/writer/spirv/builder_accessor_expression_test.cc @@ -132,6 +132,13 @@ TEST_F(BuilderTest, ArrayAccessor_MultiLevel) { TEST_F(BuilderTest, MemberAccessor) { ast::type::F32Type f32; + // my_struct { + // a : f32 + // b : f32 + // } + // var ident : my_struct + // ident.b + ast::StructMemberDecorationList decos; ast::StructMemberList members; members.push_back( @@ -180,6 +187,15 @@ TEST_F(BuilderTest, MemberAccessor) { TEST_F(BuilderTest, MemberAccessor_Nested) { ast::type::F32Type f32; + // inner_struct { + // a : f32 + // } + // my_struct { + // inner : inner_struct + // } + // + // var ident : my_struct + // ident.inner.a ast::StructMemberDecorationList decos; ast::StructMemberList inner_members; inner_members.push_back( diff --git a/src/writer/spirv/builder_assign_test.cc b/src/writer/spirv/builder_assign_test.cc index c567386ba5..99c3af730b 100644 --- a/src/writer/spirv/builder_assign_test.cc +++ b/src/writer/spirv/builder_assign_test.cc @@ -21,6 +21,8 @@ #include "src/ast/scalar_constructor_expression.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/vector_type.h" +#include "src/context.h" +#include "src/type_determiner.h" #include "src/writer/spirv/builder.h" #include "src/writer/spirv/spv_dump.h" @@ -37,18 +39,23 @@ TEST_F(BuilderTest, Assign_Var) { ast::Variable v("var", ast::StorageClass::kOutput, &f32); - ast::Module mod; - Builder b(&mod); - EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error(); - ASSERT_FALSE(b.has_error()) << b.error(); - auto ident = std::make_unique("var"); auto val = std::make_unique( std::make_unique(&f32, 1.0f)); ast::AssignmentStatement assign(std::move(ident), std::move(val)); + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(&v); + + ASSERT_TRUE(td.DetermineResultType(&assign)) << td.error(); + + Builder b(&mod); b.push_function(Function{}); + EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error(); + ASSERT_FALSE(b.has_error()) << b.error(); EXPECT_TRUE(b.GenerateAssignStatement(&assign)) << b.error(); EXPECT_FALSE(b.has_error()); diff --git a/src/writer/spirv/builder_ident_expression_test.cc b/src/writer/spirv/builder_ident_expression_test.cc index c69289446f..e69bb73cab 100644 --- a/src/writer/spirv/builder_ident_expression_test.cc +++ b/src/writer/spirv/builder_ident_expression_test.cc @@ -56,7 +56,11 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalConst) { v.set_constructor(std::move(init)); v.set_is_const(true); + Context ctx; ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(&v); + Builder b(&mod); EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error(); ASSERT_FALSE(b.has_error()) << b.error(); @@ -69,6 +73,8 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalConst) { )"); ast::IdentifierExpression expr("var"); + ASSERT_TRUE(td.DetermineResultType(&expr)); + EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 5u); } @@ -76,8 +82,13 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalVar) { ast::type::F32Type f32; ast::Variable v("var", ast::StorageClass::kOutput, &f32); + Context ctx; ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(&v); + Builder b(&mod); + b.push_function(Function{}); EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error(); EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "var" )"); @@ -87,6 +98,7 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalVar) { )"); ast::IdentifierExpression expr("var"); + ASSERT_TRUE(td.DetermineResultType(&expr)); EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 1u); } @@ -109,7 +121,11 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionConst) { v.set_constructor(std::move(init)); v.set_is_const(true); + Context ctx; ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(&v); + Builder b(&mod); EXPECT_TRUE(b.GenerateFunctionVariable(&v)) << b.error(); ASSERT_FALSE(b.has_error()) << b.error(); @@ -122,6 +138,7 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionConst) { )"); ast::IdentifierExpression expr("var"); + ASSERT_TRUE(td.DetermineResultType(&expr)); EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 5u); } @@ -129,7 +146,11 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionVar) { ast::type::F32Type f32; ast::Variable v("var", ast::StorageClass::kNone, &f32); + Context ctx; ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(&v); + Builder b(&mod); b.push_function(Function{}); EXPECT_TRUE(b.GenerateFunctionVariable(&v)) << b.error(); @@ -144,6 +165,7 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionVar) { )"); ast::IdentifierExpression expr("var"); + ASSERT_TRUE(td.DetermineResultType(&expr)); EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 1u); } @@ -170,7 +192,7 @@ TEST_F(BuilderTest, IdentifierExpression_Load) { b.push_function(Function{}); ASSERT_TRUE(b.GenerateGlobalVariable(&var)) << b.error(); - ASSERT_EQ(b.GenerateBinaryExpression(&expr), 6u) << b.error(); + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 6u) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1 %2 = OpTypePointer Private %3 %1 = OpVariable %2 Private @@ -217,31 +239,6 @@ TEST_F(BuilderTest, IdentifierExpression_NoLoadConst) { )"); } -TEST_F(BuilderTest, IdentifierExpression_ImportMethod) { - auto imp = std::make_unique("GLSL.std.450", "std"); - imp->AddMethodId("round", 42u); - - ast::Module mod; - mod.AddImport(std::move(imp)); - Builder b(&mod); - - ast::IdentifierExpression expr(std::vector({"std", "round"})); - EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 42u) << b.error(); -} - -TEST_F(BuilderTest, IdentifierExpression_ImportMethod_NotFound) { - auto imp = std::make_unique("GLSL.std.450", "std"); - - ast::Module mod; - mod.AddImport(std::move(imp)); - Builder b(&mod); - - ast::IdentifierExpression expr(std::vector({"std", "ceil"})); - EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 0u); - ASSERT_TRUE(b.has_error()); - EXPECT_EQ(b.error(), "unable to lookup: ceil in std"); -} - } // namespace } // namespace spirv } // namespace writer