Update type determiner to create pointers.

This CL updates the type determiner such that variable result types
end up wrapped inside pointers, constants do not. The result of Member
and Array accessors are also pointers if the source was a pointer.

Change-Id: I6694367daf6ba1db929e54a975dfea8404fca40c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/20265
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-04-23 22:26:52 +00:00 committed by David Neto
parent 1a4d90667b
commit 8eddb78433
11 changed files with 275 additions and 138 deletions

View File

@ -93,7 +93,7 @@ class Expression : public Node {
/// @returns the expression as a unary op expression /// @returns the expression as a unary op expression
const UnaryOpExpression* AsUnaryOp() const; const UnaryOpExpression* AsUnaryOp() const;
/// @returns the expression as an array accessor /// @returns the expression as an array accessor
ArrayAccessorExpression* AsArrayAccessor(); ArrayAccessorExpression* AsArrayAccessor();
/// @returns the expression as an as /// @returns the expression as an as
AsExpression* AsAs(); AsExpression* AsAs();

View File

@ -36,6 +36,13 @@ Type::Type() = default;
Type::~Type() = default; Type::~Type() = default;
Type* Type::UnwrapPtrIfNeeded() {
if (IsPointer()) {
return AsPointer()->type();
}
return this;
}
bool Type::IsAlias() const { bool Type::IsAlias() const {
return false; return false;
} }

View File

@ -66,6 +66,9 @@ class Type {
/// @returns the name for this type. The |type_name| is unique over all types. /// @returns the name for this type. The |type_name| is unique over all types.
virtual std::string type_name() const = 0; virtual std::string type_name() const = 0;
/// @returns the pointee type if this is a pointer, |this| otherwise
Type* UnwrapPtrIfNeeded();
/// @returns true if this type is a float scalar /// @returns true if this type is a float scalar
bool is_float_scalar(); bool is_float_scalar();
/// @returns true if this type is a float matrix /// @returns true if this type is a float matrix

View File

@ -1159,8 +1159,7 @@ TEST_F(SpvParserTest,
EXPECT_THAT(fe.block_order(), ElementsAre(10, 30, 20, 80, 99)); EXPECT_THAT(fe.block_order(), ElementsAre(10, 30, 20, 80, 99));
} }
TEST_F(SpvParserTest, TEST_F(SpvParserTest, ComputeBlockOrder_Switch_DefaultSameAsACase) {
ComputeBlockOrder_Switch_DefaultSameAsACase) {
auto* p = parser(test::Assemble(CommonTypes() + R"( auto* p = parser(test::Assemble(CommonTypes() + R"(
%100 = OpFunction %void None %voidfn %100 = OpFunction %void None %voidfn
@ -1373,8 +1372,7 @@ TEST_F(SpvParserTest,
<< assembly; << assembly;
} }
TEST_F(SpvParserTest, TEST_F(SpvParserTest, ComputeBlockOrder_Nest_If_Contains_If) {
ComputeBlockOrder_Nest_If_Contains_If) {
auto assembly = CommonTypes() + R"( auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn %100 = OpFunction %void None %voidfn
@ -1424,8 +1422,7 @@ TEST_F(SpvParserTest,
<< assembly; << assembly;
} }
TEST_F(SpvParserTest, TEST_F(SpvParserTest, ComputeBlockOrder_Nest_If_In_SwitchCase) {
ComputeBlockOrder_Nest_If_In_SwitchCase) {
auto assembly = CommonTypes() + R"( auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn %100 = OpFunction %void None %voidfn
@ -1475,8 +1472,7 @@ TEST_F(SpvParserTest,
<< assembly; << assembly;
} }
TEST_F(SpvParserTest, TEST_F(SpvParserTest, ComputeBlockOrder_Nest_IfFallthrough_In_SwitchCase) {
ComputeBlockOrder_Nest_IfFallthrough_In_SwitchCase) {
auto assembly = CommonTypes() + R"( auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn %100 = OpFunction %void None %voidfn
@ -1526,8 +1522,7 @@ TEST_F(SpvParserTest,
<< assembly; << assembly;
} }
TEST_F(SpvParserTest, TEST_F(SpvParserTest, ComputeBlockOrder_Nest_IfBreak_In_SwitchCase) {
ComputeBlockOrder_Nest_IfBreak_In_SwitchCase) {
auto assembly = CommonTypes() + R"( auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn %100 = OpFunction %void None %voidfn

View File

@ -38,6 +38,7 @@
#include "src/ast/type/bool_type.h" #include "src/ast/type/bool_type.h"
#include "src/ast/type/f32_type.h" #include "src/ast/type/f32_type.h"
#include "src/ast/type/matrix_type.h" #include "src/ast/type/matrix_type.h"
#include "src/ast/type/pointer_type.h"
#include "src/ast/type/struct_type.h" #include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h" #include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h" #include "src/ast/type_constructor_expression.h"
@ -274,19 +275,30 @@ bool TypeDeterminer::DetermineArrayAccessor(
if (!DetermineResultType(expr->array())) { if (!DetermineResultType(expr->array())) {
return false; return false;
} }
auto* parent_type = expr->array()->result_type();
auto* res = expr->array()->result_type();
auto* parent_type = res->UnwrapPtrIfNeeded();
ast::type::Type* ret = nullptr;
if (parent_type->IsArray()) { if (parent_type->IsArray()) {
expr->set_result_type(parent_type->AsArray()->type()); ret = parent_type->AsArray()->type();
} else if (parent_type->IsVector()) { } else if (parent_type->IsVector()) {
expr->set_result_type(parent_type->AsVector()->type()); ret = parent_type->AsVector()->type();
} else if (parent_type->IsMatrix()) { } else if (parent_type->IsMatrix()) {
auto* m = parent_type->AsMatrix(); auto* m = parent_type->AsMatrix();
expr->set_result_type(ctx_.type_mgr().Get( ret = ctx_.type_mgr().Get(
std::make_unique<ast::type::VectorType>(m->type(), m->rows()))); std::make_unique<ast::type::VectorType>(m->type(), m->rows()));
} else { } else {
set_error(expr->source(), "invalid parent type in array accessor"); set_error(expr->source(), "invalid parent type in array accessor");
return false; return false;
} }
// If we're extracting from a pointer, we return a pointer.
if (res->IsPointer()) {
ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
ret, res->AsPointer()->storage_class()));
}
expr->set_result_type(ret);
return true; return true;
} }
@ -365,7 +377,15 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) {
auto name = expr->name(); auto name = expr->name();
ast::Variable* var; ast::Variable* var;
if (variable_stack_.get(name, &var)) { if (variable_stack_.get(name, &var)) {
expr->set_result_type(var->type()); // A constant is the type, but a variable is always a pointer so synthesize
// the pointer around the variable type.
if (var->is_const()) {
expr->set_result_type(var->type());
} else {
expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
var->type(), var->storage_class())));
}
return true; return true;
} }
@ -384,43 +404,52 @@ bool TypeDeterminer::DetermineMemberAccessor(
return false; return false;
} }
auto* data_type = expr->structure()->result_type(); auto* res = expr->structure()->result_type();
auto* data_type = res->UnwrapPtrIfNeeded();
ast::type::Type* ret = nullptr;
if (data_type->IsStruct()) { if (data_type->IsStruct()) {
auto* strct = data_type->AsStruct()->impl(); auto* strct = data_type->AsStruct()->impl();
auto name = expr->member()->name(); auto name = expr->member()->name();
for (const auto& member : strct->members()) { for (const auto& member : strct->members()) {
if (member->name() != name) { if (member->name() == name) {
continue; ret = member->type();
break;
} }
expr->set_result_type(member->type());
return true;
} }
set_error(expr->source(), "struct member not found"); if (ret == nullptr) {
return false; set_error(expr->source(), "struct member " + name + " not found");
} return false;
if (data_type->IsVector()) { }
} else if (data_type->IsVector()) {
auto* vec = data_type->AsVector(); auto* vec = data_type->AsVector();
auto size = expr->member()->name().size(); auto size = expr->member()->name().size();
if (size == 1) { if (size == 1) {
// A single element swizzle is just the type of the vector. // A single element swizzle is just the type of the vector.
expr->set_result_type(vec->type()); ret = vec->type();
} else { } else {
// The vector will have a number of components equal to the length of the // The vector will have a number of components equal to the length of the
// swizzle. This assumes the validator will check that the swizzle // swizzle. This assumes the validator will check that the swizzle
// is correct. // is correct.
expr->set_result_type(ctx_.type_mgr().Get( ret = ctx_.type_mgr().Get(
std::make_unique<ast::type::VectorType>(vec->type(), size))); std::make_unique<ast::type::VectorType>(vec->type(), size));
} }
return true; } else {
set_error(expr->source(),
"invalid type " + data_type->type_name() + " in member accessor");
return false;
} }
set_error(expr->source(), // If we're extracting from a pointer, we return a pointer.
"invalid type " + data_type->type_name() + " in member accessor"); if (res->IsPointer()) {
return false; ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
ret, res->AsPointer()->storage_class()));
}
expr->set_result_type(ret);
return true;
} }
bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) { bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
@ -432,7 +461,7 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() || if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
expr->IsShiftRight() || expr->IsShiftRightArith() || expr->IsAdd() || expr->IsShiftRight() || expr->IsShiftRightArith() || expr->IsAdd() ||
expr->IsSubtract() || expr->IsDivide() || expr->IsModulo()) { expr->IsSubtract() || expr->IsDivide() || expr->IsModulo()) {
expr->set_result_type(expr->lhs()->result_type()); expr->set_result_type(expr->lhs()->result_type()->UnwrapPtrIfNeeded());
return true; return true;
} }
// Result type is a scalar or vector of boolean type // Result type is a scalar or vector of boolean type
@ -441,7 +470,7 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) { expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
auto* bool_type = auto* bool_type =
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>()); ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
auto* param_type = expr->lhs()->result_type(); auto* param_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
if (param_type->IsVector()) { if (param_type->IsVector()) {
expr->set_result_type( expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>( ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
@ -452,8 +481,8 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
return true; return true;
} }
if (expr->IsMultiply()) { if (expr->IsMultiply()) {
auto* lhs_type = expr->lhs()->result_type(); auto* lhs_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
auto* rhs_type = expr->rhs()->result_type(); auto* rhs_type = expr->rhs()->result_type()->UnwrapPtrIfNeeded();
// Note, the ordering here matters. The later checks depend on the prior // Note, the ordering here matters. The later checks depend on the prior
// checks having been done. // checks having been done.
@ -504,7 +533,7 @@ bool TypeDeterminer::DetermineUnaryDerivative(
if (!DetermineResultType(expr->param())) { if (!DetermineResultType(expr->param())) {
return false; return false;
} }
expr->set_result_type(expr->param()->result_type()); expr->set_result_type(expr->param()->result_type()->UnwrapPtrIfNeeded());
return true; return true;
} }
@ -531,7 +560,7 @@ bool TypeDeterminer::DetermineUnaryMethod(ast::UnaryMethodExpression* expr) {
auto* bool_type = auto* bool_type =
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>()); ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
auto* param_type = expr->params()[0]->result_type(); auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
if (param_type->IsVector()) { if (param_type->IsVector()) {
expr->set_result_type( expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>( ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
@ -552,8 +581,8 @@ bool TypeDeterminer::DetermineUnaryMethod(ast::UnaryMethodExpression* expr) {
"incorrect number of parameters for outer product"); "incorrect number of parameters for outer product");
return false; return false;
} }
auto* param0_type = expr->params()[0]->result_type(); auto* param0_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
auto* param1_type = expr->params()[1]->result_type(); auto* param1_type = expr->params()[1]->result_type()->UnwrapPtrIfNeeded();
if (!param0_type->IsVector() || !param1_type->IsVector()) { if (!param0_type->IsVector() || !param1_type->IsVector()) {
set_error(expr->source(), "invalid parameter type for outer product"); set_error(expr->source(), "invalid parameter type for outer product");
return false; return false;
@ -574,7 +603,7 @@ bool TypeDeterminer::DetermineUnaryOp(ast::UnaryOpExpression* expr) {
if (!DetermineResultType(expr->expr())) { if (!DetermineResultType(expr->expr())) {
return false; return false;
} }
expr->set_result_type(expr->expr()->result_type()); expr->set_result_type(expr->expr()->result_type()->UnwrapPtrIfNeeded());
return true; return true;
} }

View File

@ -46,6 +46,7 @@
#include "src/ast/type/f32_type.h" #include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h" #include "src/ast/type/i32_type.h"
#include "src/ast/type/matrix_type.h" #include "src/ast/type/matrix_type.h"
#include "src/ast/type/pointer_type.h"
#include "src/ast/type/struct_type.h" #include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h" #include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h" #include "src/ast/type_constructor_expression.h"
@ -433,8 +434,33 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Array) {
auto idx = std::make_unique<ast::ScalarConstructorExpression>( auto idx = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 2)); std::make_unique<ast::IntLiteral>(&i32, 2));
auto var = auto var = std::make_unique<ast::Variable>(
std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &ary); "my_var", ast::StorageClass::kFunction, &ary);
mod()->AddGlobalVariable(std::move(var));
// Register the global
EXPECT_TRUE(td()->Determine());
ast::ArrayAccessorExpression acc(
std::make_unique<ast::IdentifierExpression>("my_var"), std::move(idx));
EXPECT_TRUE(td()->DetermineResultType(&acc));
ASSERT_NE(acc.result_type(), nullptr);
ASSERT_TRUE(acc.result_type()->IsPointer());
auto* ptr = acc.result_type()->AsPointer();
EXPECT_TRUE(ptr->type()->IsF32());
}
TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Array_Constant) {
ast::type::I32Type i32;
ast::type::F32Type f32;
ast::type::ArrayType ary(&f32, 3);
auto idx = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 2));
auto var = std::make_unique<ast::Variable>(
"my_var", ast::StorageClass::kFunction, &ary);
var->set_is_const(true);
mod()->AddGlobalVariable(std::move(var)); mod()->AddGlobalVariable(std::move(var));
// Register the global // Register the global
@ -465,8 +491,11 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix) {
std::make_unique<ast::IdentifierExpression>("my_var"), std::move(idx)); std::make_unique<ast::IdentifierExpression>("my_var"), std::move(idx));
EXPECT_TRUE(td()->DetermineResultType(&acc)); EXPECT_TRUE(td()->DetermineResultType(&acc));
ASSERT_NE(acc.result_type(), nullptr); ASSERT_NE(acc.result_type(), nullptr);
ASSERT_TRUE(acc.result_type()->IsVector()); ASSERT_TRUE(acc.result_type()->IsPointer());
EXPECT_EQ(acc.result_type()->AsVector()->size(), 3u);
auto* ptr = acc.result_type()->AsPointer();
ASSERT_TRUE(ptr->type()->IsVector());
EXPECT_EQ(ptr->type()->AsVector()->size(), 3u);
} }
TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix_BothDimensions) { TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix_BothDimensions) {
@ -493,7 +522,10 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix_BothDimensions) {
EXPECT_TRUE(td()->DetermineResultType(&acc)); EXPECT_TRUE(td()->DetermineResultType(&acc));
ASSERT_NE(acc.result_type(), nullptr); ASSERT_NE(acc.result_type(), nullptr);
EXPECT_TRUE(acc.result_type()->IsF32()); ASSERT_TRUE(acc.result_type()->IsPointer());
auto* ptr = acc.result_type()->AsPointer();
EXPECT_TRUE(ptr->type()->IsF32());
} }
TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Vector) { TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Vector) {
@ -514,7 +546,10 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Vector) {
std::make_unique<ast::IdentifierExpression>("my_var"), std::move(idx)); std::make_unique<ast::IdentifierExpression>("my_var"), std::move(idx));
EXPECT_TRUE(td()->DetermineResultType(&acc)); EXPECT_TRUE(td()->DetermineResultType(&acc));
ASSERT_NE(acc.result_type(), nullptr); ASSERT_NE(acc.result_type(), nullptr);
EXPECT_TRUE(acc.result_type()->IsF32()); ASSERT_TRUE(acc.result_type()->IsPointer());
auto* ptr = acc.result_type()->AsPointer();
EXPECT_TRUE(ptr->type()->IsF32());
} }
TEST_F(TypeDeterminerTest, Expr_As) { TEST_F(TypeDeterminerTest, Expr_As) {
@ -643,12 +678,55 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalVariable) {
// Register the global // Register the global
EXPECT_TRUE(td()->Determine()); EXPECT_TRUE(td()->Determine());
ast::IdentifierExpression ident("my_var");
EXPECT_TRUE(td()->DetermineResultType(&ident));
ASSERT_NE(ident.result_type(), nullptr);
EXPECT_TRUE(ident.result_type()->IsPointer());
EXPECT_TRUE(ident.result_type()->AsPointer()->type()->IsF32());
}
TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalConstant) {
ast::type::F32Type f32;
auto var =
std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &f32);
var->set_is_const(true);
mod()->AddGlobalVariable(std::move(var));
// Register the global
EXPECT_TRUE(td()->Determine());
ast::IdentifierExpression ident("my_var"); ast::IdentifierExpression ident("my_var");
EXPECT_TRUE(td()->DetermineResultType(&ident)); EXPECT_TRUE(td()->DetermineResultType(&ident));
ASSERT_NE(ident.result_type(), nullptr); ASSERT_NE(ident.result_type(), nullptr);
EXPECT_TRUE(ident.result_type()->IsF32()); EXPECT_TRUE(ident.result_type()->IsF32());
} }
TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable_Const) {
ast::type::F32Type f32;
auto my_var = std::make_unique<ast::IdentifierExpression>("my_var");
auto* my_var_ptr = my_var.get();
auto var =
std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &f32);
var->set_is_const(true);
ast::StatementList body;
body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::move(my_var),
std::make_unique<ast::IdentifierExpression>("my_var")));
ast::Function f("my_func", {}, &f32);
f.set_body(std::move(body));
EXPECT_TRUE(td()->DetermineFunction(&f));
ASSERT_NE(my_var_ptr->result_type(), nullptr);
EXPECT_TRUE(my_var_ptr->result_type()->IsF32());
}
TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) { TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) {
ast::type::F32Type f32; ast::type::F32Type f32;
@ -670,7 +748,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) {
EXPECT_TRUE(td()->DetermineFunction(&f)); EXPECT_TRUE(td()->DetermineFunction(&f));
ASSERT_NE(my_var_ptr->result_type(), nullptr); ASSERT_NE(my_var_ptr->result_type(), nullptr);
EXPECT_TRUE(my_var_ptr->result_type()->IsF32()); EXPECT_TRUE(my_var_ptr->result_type()->IsPointer());
EXPECT_TRUE(my_var_ptr->result_type()->AsPointer()->type()->IsF32());
} }
TEST_F(TypeDeterminerTest, Expr_Identifier_Function) { TEST_F(TypeDeterminerTest, Expr_Identifier_Function) {
@ -720,7 +799,10 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) {
ast::MemberAccessorExpression mem(std::move(ident), std::move(mem_ident)); ast::MemberAccessorExpression mem(std::move(ident), std::move(mem_ident));
EXPECT_TRUE(td()->DetermineResultType(&mem)); EXPECT_TRUE(td()->DetermineResultType(&mem));
ASSERT_NE(mem.result_type(), nullptr); ASSERT_NE(mem.result_type(), nullptr);
EXPECT_TRUE(mem.result_type()->IsF32()); ASSERT_TRUE(mem.result_type()->IsPointer());
auto* ptr = mem.result_type()->AsPointer();
EXPECT_TRUE(ptr->type()->IsF32());
} }
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) { TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
@ -740,9 +822,12 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle)); ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle));
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error(); EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
ASSERT_NE(mem.result_type(), nullptr); ASSERT_NE(mem.result_type(), nullptr);
ASSERT_TRUE(mem.result_type()->IsVector()); ASSERT_TRUE(mem.result_type()->IsPointer());
EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2u); auto* ptr = mem.result_type()->AsPointer();
ASSERT_TRUE(ptr->type()->IsVector());
EXPECT_TRUE(ptr->type()->AsVector()->type()->IsF32());
EXPECT_EQ(ptr->type()->AsVector()->size(), 2u);
} }
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) { TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
@ -762,10 +847,13 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle)); ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle));
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error(); EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
ASSERT_NE(mem.result_type(), nullptr); ASSERT_NE(mem.result_type(), nullptr);
ASSERT_TRUE(mem.result_type()->IsF32()); ASSERT_TRUE(mem.result_type()->IsPointer());
auto* ptr = mem.result_type()->AsPointer();
ASSERT_TRUE(ptr->type()->IsF32());
} }
TEST_F(TypeDeterminerTest, Expr_MultiLevel) { TEST_F(TypeDeterminerTest, Expr_Accessor_MultiLevel) {
// struct b { // struct b {
// vec4<f32> foo // vec4<f32> foo
// } // }
@ -803,6 +891,7 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
auto strctB = std::make_unique<ast::Struct>(ast::StructDecoration::kNone, auto strctB = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
std::move(b_members)); std::move(b_members));
ast::type::StructType stB(std::move(strctB)); ast::type::StructType stB(std::move(strctB));
stB.set_name("B");
ast::type::VectorType vecB(&stB, 3); ast::type::VectorType vecB(&stB, 3);
@ -814,6 +903,7 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
std::move(a_members)); std::move(a_members));
ast::type::StructType stA(std::move(strctA)); ast::type::StructType stA(std::move(strctA));
stA.set_name("A");
auto var = auto var =
std::make_unique<ast::Variable>("c", ast::StorageClass::kNone, &stA); std::make_unique<ast::Variable>("c", ast::StorageClass::kNone, &stA);
@ -838,10 +928,14 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
std::move(foo_ident)), std::move(foo_ident)),
std::move(swizzle)); std::move(swizzle));
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error(); EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
ASSERT_NE(mem.result_type(), nullptr); ASSERT_NE(mem.result_type(), nullptr);
ASSERT_TRUE(mem.result_type()->IsVector()); ASSERT_TRUE(mem.result_type()->IsPointer());
EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2u); auto* ptr = mem.result_type()->AsPointer();
ASSERT_TRUE(ptr->type()->IsVector());
EXPECT_TRUE(ptr->type()->AsVector()->type()->IsF32());
EXPECT_EQ(ptr->type()->AsVector()->size(), 2u);
} }
using Expr_Binary_BitwiseTest = TypeDeterminerTestWithParam<ast::BinaryOp>; using Expr_Binary_BitwiseTest = TypeDeterminerTestWithParam<ast::BinaryOp>;

View File

@ -226,34 +226,6 @@ uint32_t Builder::GenerateExpression(ast::Expression* expr) {
return 0; return 0;
} }
uint32_t Builder::GenerateExpressionAndLoad(ast::Expression* expr) {
auto id = GenerateExpression(expr);
if (id == 0) {
return false;
}
// Only need to load identifiers
if (!expr->IsIdentifier()) {
return id;
}
if (spirv_id_to_variable_.find(id) == spirv_id_to_variable_.end()) {
error_ = "missing generated ID for variable";
return 0;
}
auto* var = spirv_id_to_variable_[id];
if (var->is_const()) {
return id;
}
auto type_id = GenerateTypeIfNeeded(expr->result_type());
auto result = result_op();
auto result_id = result.to_i();
push_function_inst(spv::Op::OpLoad,
{Operand::Int(type_id), result, Operand::Int(id)});
return result_id;
}
bool Builder::GenerateFunction(ast::Function* func) { bool Builder::GenerateFunction(ast::Function* func) {
uint32_t func_type_id = GenerateFunctionTypeIfNeeded(func); uint32_t func_type_id = GenerateFunctionTypeIfNeeded(func);
if (func_type_id == 0) { if (func_type_id == 0) {
@ -468,7 +440,8 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
auto* mem_accessor = source->AsMemberAccessor(); auto* mem_accessor = source->AsMemberAccessor();
source = mem_accessor->structure(); source = mem_accessor->structure();
auto* data_type = mem_accessor->structure()->result_type(); auto* data_type =
mem_accessor->structure()->result_type()->UnwrapPtrIfNeeded();
if (data_type->IsStruct()) { if (data_type->IsStruct()) {
auto* strct = data_type->AsStruct()->impl(); auto* strct = data_type->AsStruct()->impl();
auto name = mem_accessor->member()->name(); auto name = mem_accessor->member()->name();
@ -476,11 +449,9 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
uint32_t i = 0; uint32_t i = 0;
for (; i < strct->members().size(); ++i) { for (; i < strct->members().size(); ++i) {
const auto& member = strct->members()[i]; const auto& member = strct->members()[i];
if (member->name() != name) { if (member->name() == name) {
continue; break;
} }
break;
} }
ast::type::U32Type u32; ast::type::U32Type u32;
@ -507,9 +478,7 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
return 0; return 0;
} }
// The access chain results in a pointer, so wrap the return type. auto type_id = GenerateTypeIfNeeded(expr->result_type());
ast::type::PointerType ptr(expr->result_type(), ast::StorageClass::kFunction);
auto type_id = GenerateTypeIfNeeded(&ptr);
if (type_id == 0) { if (type_id == 0) {
return 0; return 0;
} }
@ -535,12 +504,23 @@ uint32_t Builder::GenerateIdentifierExpression(
if (val == 0) { if (val == 0) {
error_ = "unable to lookup: " + expr->name() + " in " + expr->path(); error_ = "unable to lookup: " + expr->name() + " in " + expr->path();
} }
} else if (!scope_stack_.get(expr->name(), &val)) { return val;
error_ = "unable to find name for identifier: " + expr->name(); }
return 0; if (scope_stack_.get(expr->name(), &val)) {
return val;
} }
return val; error_ = "unable to find name for identifier: " + expr->name();
return 0;
}
uint32_t Builder::GenerateLoad(ast::type::Type* type, uint32_t id) {
auto type_id = GenerateTypeIfNeeded(type->UnwrapPtrIfNeeded());
auto result = result_op();
auto result_id = result.to_i();
push_function_inst(spv::Op::OpLoad,
{Operand::Int(type_id), result, Operand::Int(id)});
return result_id;
} }
uint32_t Builder::GenerateUnaryOpExpression(ast::UnaryOpExpression* expr) { uint32_t Builder::GenerateUnaryOpExpression(ast::UnaryOpExpression* expr) {
@ -686,14 +666,21 @@ uint32_t Builder::GenerateLiteralIfNeeded(ast::Literal* lit) {
} }
uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
auto lhs_id = GenerateExpressionAndLoad(expr->lhs()); auto lhs_id = GenerateExpression(expr->lhs());
if (lhs_id == 0) { if (lhs_id == 0) {
return 0; return 0;
} }
auto rhs_id = GenerateExpressionAndLoad(expr->rhs()); if (expr->lhs()->result_type()->IsPointer()) {
lhs_id = GenerateLoad(expr->lhs()->result_type(), lhs_id);
}
auto rhs_id = GenerateExpression(expr->rhs());
if (rhs_id == 0) { if (rhs_id == 0) {
return 0; return 0;
} }
if (expr->rhs()->result_type()->IsPointer()) {
rhs_id = GenerateLoad(expr->rhs()->result_type(), rhs_id);
}
auto result = result_op(); auto result = result_op();
auto result_id = result.to_i(); auto result_id = result.to_i();
@ -705,8 +692,8 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
// Handle int and float and the vectors of those types. Other types // Handle int and float and the vectors of those types. Other types
// should have been rejected by validation. // should have been rejected by validation.
auto* lhs_type = expr->lhs()->result_type(); auto* lhs_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
auto* rhs_type = expr->rhs()->result_type(); auto* rhs_type = expr->rhs()->result_type()->UnwrapPtrIfNeeded();
bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector(); bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector(); bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
@ -819,6 +806,7 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
} else if (expr->IsXor()) { } else if (expr->IsXor()) {
op = spv::Op::OpBitwiseXor; op = spv::Op::OpBitwiseXor;
} else { } else {
error_ = "unknown binary expression";
return 0; return 0;
} }

View File

@ -157,10 +157,6 @@ class Builder {
/// @param expr the expression to generate /// @param expr the expression to generate
/// @returns the resulting ID of the expression or 0 on error /// @returns the resulting ID of the expression or 0 on error
uint32_t GenerateExpression(ast::Expression* expr); uint32_t GenerateExpression(ast::Expression* expr);
/// Generates an expression and emits a load if necessary
/// @param expr the expression
/// @returns the SPIR-V result id
uint32_t GenerateExpressionAndLoad(ast::Expression* expr);
/// Generates the instructions for a function /// Generates the instructions for a function
/// @param func the function to generate /// @param func the function to generate
/// @returns true if the instructions were generated /// @returns true if the instructions were generated
@ -240,6 +236,11 @@ class Builder {
/// @param list the statement list to generate /// @param list the statement list to generate
/// @returns true on successful generation /// @returns true on successful generation
bool GenerateStatementList(const ast::StatementList& list); bool GenerateStatementList(const ast::StatementList& list);
/// Geneates an OpLoad
/// @param type the type to load
/// @param id the variable id to load
/// @returns the ID of the loaded value
uint32_t GenerateLoad(ast::type::Type* type, uint32_t id);
/// Geneates an OpStore /// Geneates an OpStore
/// @param to the ID to store too /// @param to the ID to store too
/// @param from the ID to store from /// @param from the ID to store from

View File

@ -132,6 +132,13 @@ TEST_F(BuilderTest, ArrayAccessor_MultiLevel) {
TEST_F(BuilderTest, MemberAccessor) { TEST_F(BuilderTest, MemberAccessor) {
ast::type::F32Type f32; ast::type::F32Type f32;
// my_struct {
// a : f32
// b : f32
// }
// var ident : my_struct
// ident.b
ast::StructMemberDecorationList decos; ast::StructMemberDecorationList decos;
ast::StructMemberList members; ast::StructMemberList members;
members.push_back( members.push_back(
@ -180,6 +187,15 @@ TEST_F(BuilderTest, MemberAccessor) {
TEST_F(BuilderTest, MemberAccessor_Nested) { TEST_F(BuilderTest, MemberAccessor_Nested) {
ast::type::F32Type f32; ast::type::F32Type f32;
// inner_struct {
// a : f32
// }
// my_struct {
// inner : inner_struct
// }
//
// var ident : my_struct
// ident.inner.a
ast::StructMemberDecorationList decos; ast::StructMemberDecorationList decos;
ast::StructMemberList inner_members; ast::StructMemberList inner_members;
inner_members.push_back( inner_members.push_back(

View File

@ -21,6 +21,8 @@
#include "src/ast/scalar_constructor_expression.h" #include "src/ast/scalar_constructor_expression.h"
#include "src/ast/type/f32_type.h" #include "src/ast/type/f32_type.h"
#include "src/ast/type/vector_type.h" #include "src/ast/type/vector_type.h"
#include "src/context.h"
#include "src/type_determiner.h"
#include "src/writer/spirv/builder.h" #include "src/writer/spirv/builder.h"
#include "src/writer/spirv/spv_dump.h" #include "src/writer/spirv/spv_dump.h"
@ -37,18 +39,23 @@ TEST_F(BuilderTest, Assign_Var) {
ast::Variable v("var", ast::StorageClass::kOutput, &f32); ast::Variable v("var", ast::StorageClass::kOutput, &f32);
ast::Module mod;
Builder b(&mod);
EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
ASSERT_FALSE(b.has_error()) << b.error();
auto ident = std::make_unique<ast::IdentifierExpression>("var"); auto ident = std::make_unique<ast::IdentifierExpression>("var");
auto val = std::make_unique<ast::ScalarConstructorExpression>( auto val = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)); std::make_unique<ast::FloatLiteral>(&f32, 1.0f));
ast::AssignmentStatement assign(std::move(ident), std::move(val)); ast::AssignmentStatement assign(std::move(ident), std::move(val));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&v);
ASSERT_TRUE(td.DetermineResultType(&assign)) << td.error();
Builder b(&mod);
b.push_function(Function{}); b.push_function(Function{});
EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_TRUE(b.GenerateAssignStatement(&assign)) << b.error(); EXPECT_TRUE(b.GenerateAssignStatement(&assign)) << b.error();
EXPECT_FALSE(b.has_error()); EXPECT_FALSE(b.has_error());

View File

@ -56,7 +56,11 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalConst) {
v.set_constructor(std::move(init)); v.set_constructor(std::move(init));
v.set_is_const(true); v.set_is_const(true);
Context ctx;
ast::Module mod; ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&v);
Builder b(&mod); Builder b(&mod);
EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error(); EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
ASSERT_FALSE(b.has_error()) << b.error(); ASSERT_FALSE(b.has_error()) << b.error();
@ -69,6 +73,8 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalConst) {
)"); )");
ast::IdentifierExpression expr("var"); ast::IdentifierExpression expr("var");
ASSERT_TRUE(td.DetermineResultType(&expr));
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 5u); EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 5u);
} }
@ -76,8 +82,13 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalVar) {
ast::type::F32Type f32; ast::type::F32Type f32;
ast::Variable v("var", ast::StorageClass::kOutput, &f32); ast::Variable v("var", ast::StorageClass::kOutput, &f32);
Context ctx;
ast::Module mod; ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&v);
Builder b(&mod); Builder b(&mod);
b.push_function(Function{});
EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error(); EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "var" EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "var"
)"); )");
@ -87,6 +98,7 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalVar) {
)"); )");
ast::IdentifierExpression expr("var"); ast::IdentifierExpression expr("var");
ASSERT_TRUE(td.DetermineResultType(&expr));
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 1u); EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 1u);
} }
@ -109,7 +121,11 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionConst) {
v.set_constructor(std::move(init)); v.set_constructor(std::move(init));
v.set_is_const(true); v.set_is_const(true);
Context ctx;
ast::Module mod; ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&v);
Builder b(&mod); Builder b(&mod);
EXPECT_TRUE(b.GenerateFunctionVariable(&v)) << b.error(); EXPECT_TRUE(b.GenerateFunctionVariable(&v)) << b.error();
ASSERT_FALSE(b.has_error()) << b.error(); ASSERT_FALSE(b.has_error()) << b.error();
@ -122,6 +138,7 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionConst) {
)"); )");
ast::IdentifierExpression expr("var"); ast::IdentifierExpression expr("var");
ASSERT_TRUE(td.DetermineResultType(&expr));
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 5u); EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 5u);
} }
@ -129,7 +146,11 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionVar) {
ast::type::F32Type f32; ast::type::F32Type f32;
ast::Variable v("var", ast::StorageClass::kNone, &f32); ast::Variable v("var", ast::StorageClass::kNone, &f32);
Context ctx;
ast::Module mod; ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&v);
Builder b(&mod); Builder b(&mod);
b.push_function(Function{}); b.push_function(Function{});
EXPECT_TRUE(b.GenerateFunctionVariable(&v)) << b.error(); EXPECT_TRUE(b.GenerateFunctionVariable(&v)) << b.error();
@ -144,6 +165,7 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionVar) {
)"); )");
ast::IdentifierExpression expr("var"); ast::IdentifierExpression expr("var");
ASSERT_TRUE(td.DetermineResultType(&expr));
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 1u); EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 1u);
} }
@ -170,7 +192,7 @@ TEST_F(BuilderTest, IdentifierExpression_Load) {
b.push_function(Function{}); b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(&var)) << b.error(); ASSERT_TRUE(b.GenerateGlobalVariable(&var)) << b.error();
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 6u) << b.error(); EXPECT_EQ(b.GenerateBinaryExpression(&expr), 6u) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1 EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1
%2 = OpTypePointer Private %3 %2 = OpTypePointer Private %3
%1 = OpVariable %2 Private %1 = OpVariable %2 Private
@ -217,31 +239,6 @@ TEST_F(BuilderTest, IdentifierExpression_NoLoadConst) {
)"); )");
} }
TEST_F(BuilderTest, IdentifierExpression_ImportMethod) {
auto imp = std::make_unique<ast::Import>("GLSL.std.450", "std");
imp->AddMethodId("round", 42u);
ast::Module mod;
mod.AddImport(std::move(imp));
Builder b(&mod);
ast::IdentifierExpression expr(std::vector<std::string>({"std", "round"}));
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 42u) << b.error();
}
TEST_F(BuilderTest, IdentifierExpression_ImportMethod_NotFound) {
auto imp = std::make_unique<ast::Import>("GLSL.std.450", "std");
ast::Module mod;
mod.AddImport(std::move(imp));
Builder b(&mod);
ast::IdentifierExpression expr(std::vector<std::string>({"std", "ceil"}));
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 0u);
ASSERT_TRUE(b.has_error());
EXPECT_EQ(b.error(), "unable to lookup: ceil in std");
}
} // namespace } // namespace
} // namespace spirv } // namespace spirv
} // namespace writer } // namespace writer