Update type determiner to create pointers.
This CL updates the type determiner such that variable result types end up wrapped inside pointers, constants do not. The result of Member and Array accessors are also pointers if the source was a pointer. Change-Id: I6694367daf6ba1db929e54a975dfea8404fca40c Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/20265 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
1a4d90667b
commit
8eddb78433
|
@ -93,7 +93,7 @@ class Expression : public Node {
|
|||
/// @returns the expression as a unary op expression
|
||||
const UnaryOpExpression* AsUnaryOp() const;
|
||||
|
||||
/// @returns the expression as an array accessor
|
||||
/// @returns the expression as an array accessor
|
||||
ArrayAccessorExpression* AsArrayAccessor();
|
||||
/// @returns the expression as an as
|
||||
AsExpression* AsAs();
|
||||
|
|
|
@ -36,6 +36,13 @@ Type::Type() = default;
|
|||
|
||||
Type::~Type() = default;
|
||||
|
||||
Type* Type::UnwrapPtrIfNeeded() {
|
||||
if (IsPointer()) {
|
||||
return AsPointer()->type();
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
bool Type::IsAlias() const {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -66,6 +66,9 @@ class Type {
|
|||
/// @returns the name for this type. The |type_name| is unique over all types.
|
||||
virtual std::string type_name() const = 0;
|
||||
|
||||
/// @returns the pointee type if this is a pointer, |this| otherwise
|
||||
Type* UnwrapPtrIfNeeded();
|
||||
|
||||
/// @returns true if this type is a float scalar
|
||||
bool is_float_scalar();
|
||||
/// @returns true if this type is a float matrix
|
||||
|
|
|
@ -1159,8 +1159,7 @@ TEST_F(SpvParserTest,
|
|||
EXPECT_THAT(fe.block_order(), ElementsAre(10, 30, 20, 80, 99));
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest,
|
||||
ComputeBlockOrder_Switch_DefaultSameAsACase) {
|
||||
TEST_F(SpvParserTest, ComputeBlockOrder_Switch_DefaultSameAsACase) {
|
||||
auto* p = parser(test::Assemble(CommonTypes() + R"(
|
||||
%100 = OpFunction %void None %voidfn
|
||||
|
||||
|
@ -1373,8 +1372,7 @@ TEST_F(SpvParserTest,
|
|||
<< assembly;
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest,
|
||||
ComputeBlockOrder_Nest_If_Contains_If) {
|
||||
TEST_F(SpvParserTest, ComputeBlockOrder_Nest_If_Contains_If) {
|
||||
auto assembly = CommonTypes() + R"(
|
||||
%100 = OpFunction %void None %voidfn
|
||||
|
||||
|
@ -1424,8 +1422,7 @@ TEST_F(SpvParserTest,
|
|||
<< assembly;
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest,
|
||||
ComputeBlockOrder_Nest_If_In_SwitchCase) {
|
||||
TEST_F(SpvParserTest, ComputeBlockOrder_Nest_If_In_SwitchCase) {
|
||||
auto assembly = CommonTypes() + R"(
|
||||
%100 = OpFunction %void None %voidfn
|
||||
|
||||
|
@ -1475,8 +1472,7 @@ TEST_F(SpvParserTest,
|
|||
<< assembly;
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest,
|
||||
ComputeBlockOrder_Nest_IfFallthrough_In_SwitchCase) {
|
||||
TEST_F(SpvParserTest, ComputeBlockOrder_Nest_IfFallthrough_In_SwitchCase) {
|
||||
auto assembly = CommonTypes() + R"(
|
||||
%100 = OpFunction %void None %voidfn
|
||||
|
||||
|
@ -1526,8 +1522,7 @@ TEST_F(SpvParserTest,
|
|||
<< assembly;
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest,
|
||||
ComputeBlockOrder_Nest_IfBreak_In_SwitchCase) {
|
||||
TEST_F(SpvParserTest, ComputeBlockOrder_Nest_IfBreak_In_SwitchCase) {
|
||||
auto assembly = CommonTypes() + R"(
|
||||
%100 = OpFunction %void None %voidfn
|
||||
|
||||
|
|
|
@ -38,6 +38,7 @@
|
|||
#include "src/ast/type/bool_type.h"
|
||||
#include "src/ast/type/f32_type.h"
|
||||
#include "src/ast/type/matrix_type.h"
|
||||
#include "src/ast/type/pointer_type.h"
|
||||
#include "src/ast/type/struct_type.h"
|
||||
#include "src/ast/type/vector_type.h"
|
||||
#include "src/ast/type_constructor_expression.h"
|
||||
|
@ -274,19 +275,30 @@ bool TypeDeterminer::DetermineArrayAccessor(
|
|||
if (!DetermineResultType(expr->array())) {
|
||||
return false;
|
||||
}
|
||||
auto* parent_type = expr->array()->result_type();
|
||||
|
||||
auto* res = expr->array()->result_type();
|
||||
auto* parent_type = res->UnwrapPtrIfNeeded();
|
||||
ast::type::Type* ret = nullptr;
|
||||
if (parent_type->IsArray()) {
|
||||
expr->set_result_type(parent_type->AsArray()->type());
|
||||
ret = parent_type->AsArray()->type();
|
||||
} else if (parent_type->IsVector()) {
|
||||
expr->set_result_type(parent_type->AsVector()->type());
|
||||
ret = parent_type->AsVector()->type();
|
||||
} else if (parent_type->IsMatrix()) {
|
||||
auto* m = parent_type->AsMatrix();
|
||||
expr->set_result_type(ctx_.type_mgr().Get(
|
||||
std::make_unique<ast::type::VectorType>(m->type(), m->rows())));
|
||||
ret = ctx_.type_mgr().Get(
|
||||
std::make_unique<ast::type::VectorType>(m->type(), m->rows()));
|
||||
} else {
|
||||
set_error(expr->source(), "invalid parent type in array accessor");
|
||||
return false;
|
||||
}
|
||||
|
||||
// If we're extracting from a pointer, we return a pointer.
|
||||
if (res->IsPointer()) {
|
||||
ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
|
||||
ret, res->AsPointer()->storage_class()));
|
||||
}
|
||||
expr->set_result_type(ret);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -365,7 +377,15 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) {
|
|||
auto name = expr->name();
|
||||
ast::Variable* var;
|
||||
if (variable_stack_.get(name, &var)) {
|
||||
expr->set_result_type(var->type());
|
||||
// A constant is the type, but a variable is always a pointer so synthesize
|
||||
// the pointer around the variable type.
|
||||
if (var->is_const()) {
|
||||
expr->set_result_type(var->type());
|
||||
} else {
|
||||
expr->set_result_type(
|
||||
ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
|
||||
var->type(), var->storage_class())));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -384,43 +404,52 @@ bool TypeDeterminer::DetermineMemberAccessor(
|
|||
return false;
|
||||
}
|
||||
|
||||
auto* data_type = expr->structure()->result_type();
|
||||
auto* res = expr->structure()->result_type();
|
||||
auto* data_type = res->UnwrapPtrIfNeeded();
|
||||
ast::type::Type* ret = nullptr;
|
||||
if (data_type->IsStruct()) {
|
||||
auto* strct = data_type->AsStruct()->impl();
|
||||
auto name = expr->member()->name();
|
||||
|
||||
for (const auto& member : strct->members()) {
|
||||
if (member->name() != name) {
|
||||
continue;
|
||||
if (member->name() == name) {
|
||||
ret = member->type();
|
||||
break;
|
||||
}
|
||||
|
||||
expr->set_result_type(member->type());
|
||||
return true;
|
||||
}
|
||||
|
||||
set_error(expr->source(), "struct member not found");
|
||||
return false;
|
||||
}
|
||||
if (data_type->IsVector()) {
|
||||
if (ret == nullptr) {
|
||||
set_error(expr->source(), "struct member " + name + " not found");
|
||||
return false;
|
||||
}
|
||||
} else if (data_type->IsVector()) {
|
||||
auto* vec = data_type->AsVector();
|
||||
|
||||
auto size = expr->member()->name().size();
|
||||
if (size == 1) {
|
||||
// A single element swizzle is just the type of the vector.
|
||||
expr->set_result_type(vec->type());
|
||||
ret = vec->type();
|
||||
} else {
|
||||
// The vector will have a number of components equal to the length of the
|
||||
// swizzle. This assumes the validator will check that the swizzle
|
||||
// is correct.
|
||||
expr->set_result_type(ctx_.type_mgr().Get(
|
||||
std::make_unique<ast::type::VectorType>(vec->type(), size)));
|
||||
ret = ctx_.type_mgr().Get(
|
||||
std::make_unique<ast::type::VectorType>(vec->type(), size));
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
set_error(expr->source(),
|
||||
"invalid type " + data_type->type_name() + " in member accessor");
|
||||
return false;
|
||||
}
|
||||
|
||||
set_error(expr->source(),
|
||||
"invalid type " + data_type->type_name() + " in member accessor");
|
||||
return false;
|
||||
// If we're extracting from a pointer, we return a pointer.
|
||||
if (res->IsPointer()) {
|
||||
ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
|
||||
ret, res->AsPointer()->storage_class()));
|
||||
}
|
||||
expr->set_result_type(ret);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
|
||||
|
@ -432,7 +461,7 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
|
|||
if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
|
||||
expr->IsShiftRight() || expr->IsShiftRightArith() || expr->IsAdd() ||
|
||||
expr->IsSubtract() || expr->IsDivide() || expr->IsModulo()) {
|
||||
expr->set_result_type(expr->lhs()->result_type());
|
||||
expr->set_result_type(expr->lhs()->result_type()->UnwrapPtrIfNeeded());
|
||||
return true;
|
||||
}
|
||||
// Result type is a scalar or vector of boolean type
|
||||
|
@ -441,7 +470,7 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
|
|||
expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
|
||||
auto* bool_type =
|
||||
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
|
||||
auto* param_type = expr->lhs()->result_type();
|
||||
auto* param_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
|
||||
if (param_type->IsVector()) {
|
||||
expr->set_result_type(
|
||||
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
|
||||
|
@ -452,8 +481,8 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
|
|||
return true;
|
||||
}
|
||||
if (expr->IsMultiply()) {
|
||||
auto* lhs_type = expr->lhs()->result_type();
|
||||
auto* rhs_type = expr->rhs()->result_type();
|
||||
auto* lhs_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
|
||||
auto* rhs_type = expr->rhs()->result_type()->UnwrapPtrIfNeeded();
|
||||
|
||||
// Note, the ordering here matters. The later checks depend on the prior
|
||||
// checks having been done.
|
||||
|
@ -504,7 +533,7 @@ bool TypeDeterminer::DetermineUnaryDerivative(
|
|||
if (!DetermineResultType(expr->param())) {
|
||||
return false;
|
||||
}
|
||||
expr->set_result_type(expr->param()->result_type());
|
||||
expr->set_result_type(expr->param()->result_type()->UnwrapPtrIfNeeded());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -531,7 +560,7 @@ bool TypeDeterminer::DetermineUnaryMethod(ast::UnaryMethodExpression* expr) {
|
|||
|
||||
auto* bool_type =
|
||||
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
|
||||
auto* param_type = expr->params()[0]->result_type();
|
||||
auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
|
||||
if (param_type->IsVector()) {
|
||||
expr->set_result_type(
|
||||
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
|
||||
|
@ -552,8 +581,8 @@ bool TypeDeterminer::DetermineUnaryMethod(ast::UnaryMethodExpression* expr) {
|
|||
"incorrect number of parameters for outer product");
|
||||
return false;
|
||||
}
|
||||
auto* param0_type = expr->params()[0]->result_type();
|
||||
auto* param1_type = expr->params()[1]->result_type();
|
||||
auto* param0_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
|
||||
auto* param1_type = expr->params()[1]->result_type()->UnwrapPtrIfNeeded();
|
||||
if (!param0_type->IsVector() || !param1_type->IsVector()) {
|
||||
set_error(expr->source(), "invalid parameter type for outer product");
|
||||
return false;
|
||||
|
@ -574,7 +603,7 @@ bool TypeDeterminer::DetermineUnaryOp(ast::UnaryOpExpression* expr) {
|
|||
if (!DetermineResultType(expr->expr())) {
|
||||
return false;
|
||||
}
|
||||
expr->set_result_type(expr->expr()->result_type());
|
||||
expr->set_result_type(expr->expr()->result_type()->UnwrapPtrIfNeeded());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -46,6 +46,7 @@
|
|||
#include "src/ast/type/f32_type.h"
|
||||
#include "src/ast/type/i32_type.h"
|
||||
#include "src/ast/type/matrix_type.h"
|
||||
#include "src/ast/type/pointer_type.h"
|
||||
#include "src/ast/type/struct_type.h"
|
||||
#include "src/ast/type/vector_type.h"
|
||||
#include "src/ast/type_constructor_expression.h"
|
||||
|
@ -433,8 +434,33 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Array) {
|
|||
|
||||
auto idx = std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 2));
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &ary);
|
||||
auto var = std::make_unique<ast::Variable>(
|
||||
"my_var", ast::StorageClass::kFunction, &ary);
|
||||
mod()->AddGlobalVariable(std::move(var));
|
||||
|
||||
// Register the global
|
||||
EXPECT_TRUE(td()->Determine());
|
||||
|
||||
ast::ArrayAccessorExpression acc(
|
||||
std::make_unique<ast::IdentifierExpression>("my_var"), std::move(idx));
|
||||
EXPECT_TRUE(td()->DetermineResultType(&acc));
|
||||
ASSERT_NE(acc.result_type(), nullptr);
|
||||
ASSERT_TRUE(acc.result_type()->IsPointer());
|
||||
|
||||
auto* ptr = acc.result_type()->AsPointer();
|
||||
EXPECT_TRUE(ptr->type()->IsF32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Array_Constant) {
|
||||
ast::type::I32Type i32;
|
||||
ast::type::F32Type f32;
|
||||
ast::type::ArrayType ary(&f32, 3);
|
||||
|
||||
auto idx = std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 2));
|
||||
auto var = std::make_unique<ast::Variable>(
|
||||
"my_var", ast::StorageClass::kFunction, &ary);
|
||||
var->set_is_const(true);
|
||||
mod()->AddGlobalVariable(std::move(var));
|
||||
|
||||
// Register the global
|
||||
|
@ -465,8 +491,11 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix) {
|
|||
std::make_unique<ast::IdentifierExpression>("my_var"), std::move(idx));
|
||||
EXPECT_TRUE(td()->DetermineResultType(&acc));
|
||||
ASSERT_NE(acc.result_type(), nullptr);
|
||||
ASSERT_TRUE(acc.result_type()->IsVector());
|
||||
EXPECT_EQ(acc.result_type()->AsVector()->size(), 3u);
|
||||
ASSERT_TRUE(acc.result_type()->IsPointer());
|
||||
|
||||
auto* ptr = acc.result_type()->AsPointer();
|
||||
ASSERT_TRUE(ptr->type()->IsVector());
|
||||
EXPECT_EQ(ptr->type()->AsVector()->size(), 3u);
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix_BothDimensions) {
|
||||
|
@ -493,7 +522,10 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix_BothDimensions) {
|
|||
|
||||
EXPECT_TRUE(td()->DetermineResultType(&acc));
|
||||
ASSERT_NE(acc.result_type(), nullptr);
|
||||
EXPECT_TRUE(acc.result_type()->IsF32());
|
||||
ASSERT_TRUE(acc.result_type()->IsPointer());
|
||||
|
||||
auto* ptr = acc.result_type()->AsPointer();
|
||||
EXPECT_TRUE(ptr->type()->IsF32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Vector) {
|
||||
|
@ -514,7 +546,10 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Vector) {
|
|||
std::make_unique<ast::IdentifierExpression>("my_var"), std::move(idx));
|
||||
EXPECT_TRUE(td()->DetermineResultType(&acc));
|
||||
ASSERT_NE(acc.result_type(), nullptr);
|
||||
EXPECT_TRUE(acc.result_type()->IsF32());
|
||||
ASSERT_TRUE(acc.result_type()->IsPointer());
|
||||
|
||||
auto* ptr = acc.result_type()->AsPointer();
|
||||
EXPECT_TRUE(ptr->type()->IsF32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_As) {
|
||||
|
@ -643,12 +678,55 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalVariable) {
|
|||
// Register the global
|
||||
EXPECT_TRUE(td()->Determine());
|
||||
|
||||
ast::IdentifierExpression ident("my_var");
|
||||
EXPECT_TRUE(td()->DetermineResultType(&ident));
|
||||
ASSERT_NE(ident.result_type(), nullptr);
|
||||
EXPECT_TRUE(ident.result_type()->IsPointer());
|
||||
EXPECT_TRUE(ident.result_type()->AsPointer()->type()->IsF32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalConstant) {
|
||||
ast::type::F32Type f32;
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &f32);
|
||||
var->set_is_const(true);
|
||||
mod()->AddGlobalVariable(std::move(var));
|
||||
|
||||
// Register the global
|
||||
EXPECT_TRUE(td()->Determine());
|
||||
|
||||
ast::IdentifierExpression ident("my_var");
|
||||
EXPECT_TRUE(td()->DetermineResultType(&ident));
|
||||
ASSERT_NE(ident.result_type(), nullptr);
|
||||
EXPECT_TRUE(ident.result_type()->IsF32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable_Const) {
|
||||
ast::type::F32Type f32;
|
||||
|
||||
auto my_var = std::make_unique<ast::IdentifierExpression>("my_var");
|
||||
auto* my_var_ptr = my_var.get();
|
||||
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &f32);
|
||||
var->set_is_const(true);
|
||||
|
||||
ast::StatementList body;
|
||||
body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
|
||||
|
||||
body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||
std::move(my_var),
|
||||
std::make_unique<ast::IdentifierExpression>("my_var")));
|
||||
|
||||
ast::Function f("my_func", {}, &f32);
|
||||
f.set_body(std::move(body));
|
||||
|
||||
EXPECT_TRUE(td()->DetermineFunction(&f));
|
||||
|
||||
ASSERT_NE(my_var_ptr->result_type(), nullptr);
|
||||
EXPECT_TRUE(my_var_ptr->result_type()->IsF32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) {
|
||||
ast::type::F32Type f32;
|
||||
|
||||
|
@ -670,7 +748,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) {
|
|||
EXPECT_TRUE(td()->DetermineFunction(&f));
|
||||
|
||||
ASSERT_NE(my_var_ptr->result_type(), nullptr);
|
||||
EXPECT_TRUE(my_var_ptr->result_type()->IsF32());
|
||||
EXPECT_TRUE(my_var_ptr->result_type()->IsPointer());
|
||||
EXPECT_TRUE(my_var_ptr->result_type()->AsPointer()->type()->IsF32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Identifier_Function) {
|
||||
|
@ -720,7 +799,10 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) {
|
|||
ast::MemberAccessorExpression mem(std::move(ident), std::move(mem_ident));
|
||||
EXPECT_TRUE(td()->DetermineResultType(&mem));
|
||||
ASSERT_NE(mem.result_type(), nullptr);
|
||||
EXPECT_TRUE(mem.result_type()->IsF32());
|
||||
ASSERT_TRUE(mem.result_type()->IsPointer());
|
||||
|
||||
auto* ptr = mem.result_type()->AsPointer();
|
||||
EXPECT_TRUE(ptr->type()->IsF32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
|
||||
|
@ -740,9 +822,12 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
|
|||
ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle));
|
||||
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
|
||||
ASSERT_NE(mem.result_type(), nullptr);
|
||||
ASSERT_TRUE(mem.result_type()->IsVector());
|
||||
EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
|
||||
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2u);
|
||||
ASSERT_TRUE(mem.result_type()->IsPointer());
|
||||
|
||||
auto* ptr = mem.result_type()->AsPointer();
|
||||
ASSERT_TRUE(ptr->type()->IsVector());
|
||||
EXPECT_TRUE(ptr->type()->AsVector()->type()->IsF32());
|
||||
EXPECT_EQ(ptr->type()->AsVector()->size(), 2u);
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
|
||||
|
@ -762,10 +847,13 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
|
|||
ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle));
|
||||
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
|
||||
ASSERT_NE(mem.result_type(), nullptr);
|
||||
ASSERT_TRUE(mem.result_type()->IsF32());
|
||||
ASSERT_TRUE(mem.result_type()->IsPointer());
|
||||
|
||||
auto* ptr = mem.result_type()->AsPointer();
|
||||
ASSERT_TRUE(ptr->type()->IsF32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
|
||||
TEST_F(TypeDeterminerTest, Expr_Accessor_MultiLevel) {
|
||||
// struct b {
|
||||
// vec4<f32> foo
|
||||
// }
|
||||
|
@ -803,6 +891,7 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
|
|||
auto strctB = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
|
||||
std::move(b_members));
|
||||
ast::type::StructType stB(std::move(strctB));
|
||||
stB.set_name("B");
|
||||
|
||||
ast::type::VectorType vecB(&stB, 3);
|
||||
|
||||
|
@ -814,6 +903,7 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
|
|||
std::move(a_members));
|
||||
|
||||
ast::type::StructType stA(std::move(strctA));
|
||||
stA.set_name("A");
|
||||
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("c", ast::StorageClass::kNone, &stA);
|
||||
|
@ -838,10 +928,14 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
|
|||
std::move(foo_ident)),
|
||||
std::move(swizzle));
|
||||
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
|
||||
|
||||
ASSERT_NE(mem.result_type(), nullptr);
|
||||
ASSERT_TRUE(mem.result_type()->IsVector());
|
||||
EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
|
||||
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2u);
|
||||
ASSERT_TRUE(mem.result_type()->IsPointer());
|
||||
|
||||
auto* ptr = mem.result_type()->AsPointer();
|
||||
ASSERT_TRUE(ptr->type()->IsVector());
|
||||
EXPECT_TRUE(ptr->type()->AsVector()->type()->IsF32());
|
||||
EXPECT_EQ(ptr->type()->AsVector()->size(), 2u);
|
||||
}
|
||||
|
||||
using Expr_Binary_BitwiseTest = TypeDeterminerTestWithParam<ast::BinaryOp>;
|
||||
|
|
|
@ -226,34 +226,6 @@ uint32_t Builder::GenerateExpression(ast::Expression* expr) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
uint32_t Builder::GenerateExpressionAndLoad(ast::Expression* expr) {
|
||||
auto id = GenerateExpression(expr);
|
||||
if (id == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Only need to load identifiers
|
||||
if (!expr->IsIdentifier()) {
|
||||
return id;
|
||||
}
|
||||
if (spirv_id_to_variable_.find(id) == spirv_id_to_variable_.end()) {
|
||||
error_ = "missing generated ID for variable";
|
||||
return 0;
|
||||
}
|
||||
auto* var = spirv_id_to_variable_[id];
|
||||
if (var->is_const()) {
|
||||
return id;
|
||||
}
|
||||
|
||||
auto type_id = GenerateTypeIfNeeded(expr->result_type());
|
||||
auto result = result_op();
|
||||
auto result_id = result.to_i();
|
||||
push_function_inst(spv::Op::OpLoad,
|
||||
{Operand::Int(type_id), result, Operand::Int(id)});
|
||||
|
||||
return result_id;
|
||||
}
|
||||
|
||||
bool Builder::GenerateFunction(ast::Function* func) {
|
||||
uint32_t func_type_id = GenerateFunctionTypeIfNeeded(func);
|
||||
if (func_type_id == 0) {
|
||||
|
@ -468,7 +440,8 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
|
|||
auto* mem_accessor = source->AsMemberAccessor();
|
||||
source = mem_accessor->structure();
|
||||
|
||||
auto* data_type = mem_accessor->structure()->result_type();
|
||||
auto* data_type =
|
||||
mem_accessor->structure()->result_type()->UnwrapPtrIfNeeded();
|
||||
if (data_type->IsStruct()) {
|
||||
auto* strct = data_type->AsStruct()->impl();
|
||||
auto name = mem_accessor->member()->name();
|
||||
|
@ -476,11 +449,9 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
|
|||
uint32_t i = 0;
|
||||
for (; i < strct->members().size(); ++i) {
|
||||
const auto& member = strct->members()[i];
|
||||
if (member->name() != name) {
|
||||
continue;
|
||||
if (member->name() == name) {
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
ast::type::U32Type u32;
|
||||
|
@ -507,9 +478,7 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
// The access chain results in a pointer, so wrap the return type.
|
||||
ast::type::PointerType ptr(expr->result_type(), ast::StorageClass::kFunction);
|
||||
auto type_id = GenerateTypeIfNeeded(&ptr);
|
||||
auto type_id = GenerateTypeIfNeeded(expr->result_type());
|
||||
if (type_id == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
@ -535,12 +504,23 @@ uint32_t Builder::GenerateIdentifierExpression(
|
|||
if (val == 0) {
|
||||
error_ = "unable to lookup: " + expr->name() + " in " + expr->path();
|
||||
}
|
||||
} else if (!scope_stack_.get(expr->name(), &val)) {
|
||||
error_ = "unable to find name for identifier: " + expr->name();
|
||||
return 0;
|
||||
return val;
|
||||
}
|
||||
if (scope_stack_.get(expr->name(), &val)) {
|
||||
return val;
|
||||
}
|
||||
|
||||
return val;
|
||||
error_ = "unable to find name for identifier: " + expr->name();
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t Builder::GenerateLoad(ast::type::Type* type, uint32_t id) {
|
||||
auto type_id = GenerateTypeIfNeeded(type->UnwrapPtrIfNeeded());
|
||||
auto result = result_op();
|
||||
auto result_id = result.to_i();
|
||||
push_function_inst(spv::Op::OpLoad,
|
||||
{Operand::Int(type_id), result, Operand::Int(id)});
|
||||
return result_id;
|
||||
}
|
||||
|
||||
uint32_t Builder::GenerateUnaryOpExpression(ast::UnaryOpExpression* expr) {
|
||||
|
@ -686,14 +666,21 @@ uint32_t Builder::GenerateLiteralIfNeeded(ast::Literal* lit) {
|
|||
}
|
||||
|
||||
uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
||||
auto lhs_id = GenerateExpressionAndLoad(expr->lhs());
|
||||
auto lhs_id = GenerateExpression(expr->lhs());
|
||||
if (lhs_id == 0) {
|
||||
return 0;
|
||||
}
|
||||
auto rhs_id = GenerateExpressionAndLoad(expr->rhs());
|
||||
if (expr->lhs()->result_type()->IsPointer()) {
|
||||
lhs_id = GenerateLoad(expr->lhs()->result_type(), lhs_id);
|
||||
}
|
||||
|
||||
auto rhs_id = GenerateExpression(expr->rhs());
|
||||
if (rhs_id == 0) {
|
||||
return 0;
|
||||
}
|
||||
if (expr->rhs()->result_type()->IsPointer()) {
|
||||
rhs_id = GenerateLoad(expr->rhs()->result_type(), rhs_id);
|
||||
}
|
||||
|
||||
auto result = result_op();
|
||||
auto result_id = result.to_i();
|
||||
|
@ -705,8 +692,8 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
|||
|
||||
// Handle int and float and the vectors of those types. Other types
|
||||
// should have been rejected by validation.
|
||||
auto* lhs_type = expr->lhs()->result_type();
|
||||
auto* rhs_type = expr->rhs()->result_type();
|
||||
auto* lhs_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
|
||||
auto* rhs_type = expr->rhs()->result_type()->UnwrapPtrIfNeeded();
|
||||
bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
|
||||
bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
|
||||
|
||||
|
@ -819,6 +806,7 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
|||
} else if (expr->IsXor()) {
|
||||
op = spv::Op::OpBitwiseXor;
|
||||
} else {
|
||||
error_ = "unknown binary expression";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -157,10 +157,6 @@ class Builder {
|
|||
/// @param expr the expression to generate
|
||||
/// @returns the resulting ID of the expression or 0 on error
|
||||
uint32_t GenerateExpression(ast::Expression* expr);
|
||||
/// Generates an expression and emits a load if necessary
|
||||
/// @param expr the expression
|
||||
/// @returns the SPIR-V result id
|
||||
uint32_t GenerateExpressionAndLoad(ast::Expression* expr);
|
||||
/// Generates the instructions for a function
|
||||
/// @param func the function to generate
|
||||
/// @returns true if the instructions were generated
|
||||
|
@ -240,6 +236,11 @@ class Builder {
|
|||
/// @param list the statement list to generate
|
||||
/// @returns true on successful generation
|
||||
bool GenerateStatementList(const ast::StatementList& list);
|
||||
/// Geneates an OpLoad
|
||||
/// @param type the type to load
|
||||
/// @param id the variable id to load
|
||||
/// @returns the ID of the loaded value
|
||||
uint32_t GenerateLoad(ast::type::Type* type, uint32_t id);
|
||||
/// Geneates an OpStore
|
||||
/// @param to the ID to store too
|
||||
/// @param from the ID to store from
|
||||
|
|
|
@ -132,6 +132,13 @@ TEST_F(BuilderTest, ArrayAccessor_MultiLevel) {
|
|||
TEST_F(BuilderTest, MemberAccessor) {
|
||||
ast::type::F32Type f32;
|
||||
|
||||
// my_struct {
|
||||
// a : f32
|
||||
// b : f32
|
||||
// }
|
||||
// var ident : my_struct
|
||||
// ident.b
|
||||
|
||||
ast::StructMemberDecorationList decos;
|
||||
ast::StructMemberList members;
|
||||
members.push_back(
|
||||
|
@ -180,6 +187,15 @@ TEST_F(BuilderTest, MemberAccessor) {
|
|||
TEST_F(BuilderTest, MemberAccessor_Nested) {
|
||||
ast::type::F32Type f32;
|
||||
|
||||
// inner_struct {
|
||||
// a : f32
|
||||
// }
|
||||
// my_struct {
|
||||
// inner : inner_struct
|
||||
// }
|
||||
//
|
||||
// var ident : my_struct
|
||||
// ident.inner.a
|
||||
ast::StructMemberDecorationList decos;
|
||||
ast::StructMemberList inner_members;
|
||||
inner_members.push_back(
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
#include "src/ast/scalar_constructor_expression.h"
|
||||
#include "src/ast/type/f32_type.h"
|
||||
#include "src/ast/type/vector_type.h"
|
||||
#include "src/context.h"
|
||||
#include "src/type_determiner.h"
|
||||
#include "src/writer/spirv/builder.h"
|
||||
#include "src/writer/spirv/spv_dump.h"
|
||||
|
||||
|
@ -37,18 +39,23 @@ TEST_F(BuilderTest, Assign_Var) {
|
|||
|
||||
ast::Variable v("var", ast::StorageClass::kOutput, &f32);
|
||||
|
||||
ast::Module mod;
|
||||
Builder b(&mod);
|
||||
EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
|
||||
auto ident = std::make_unique<ast::IdentifierExpression>("var");
|
||||
auto val = std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.0f));
|
||||
|
||||
ast::AssignmentStatement assign(std::move(ident), std::move(val));
|
||||
|
||||
Context ctx;
|
||||
ast::Module mod;
|
||||
TypeDeterminer td(&ctx, &mod);
|
||||
td.RegisterVariableForTesting(&v);
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&assign)) << td.error();
|
||||
|
||||
Builder b(&mod);
|
||||
b.push_function(Function{});
|
||||
EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
|
||||
EXPECT_TRUE(b.GenerateAssignStatement(&assign)) << b.error();
|
||||
EXPECT_FALSE(b.has_error());
|
||||
|
|
|
@ -56,7 +56,11 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalConst) {
|
|||
v.set_constructor(std::move(init));
|
||||
v.set_is_const(true);
|
||||
|
||||
Context ctx;
|
||||
ast::Module mod;
|
||||
TypeDeterminer td(&ctx, &mod);
|
||||
td.RegisterVariableForTesting(&v);
|
||||
|
||||
Builder b(&mod);
|
||||
EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
|
@ -69,6 +73,8 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalConst) {
|
|||
)");
|
||||
|
||||
ast::IdentifierExpression expr("var");
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr));
|
||||
|
||||
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 5u);
|
||||
}
|
||||
|
||||
|
@ -76,8 +82,13 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalVar) {
|
|||
ast::type::F32Type f32;
|
||||
ast::Variable v("var", ast::StorageClass::kOutput, &f32);
|
||||
|
||||
Context ctx;
|
||||
ast::Module mod;
|
||||
TypeDeterminer td(&ctx, &mod);
|
||||
td.RegisterVariableForTesting(&v);
|
||||
|
||||
Builder b(&mod);
|
||||
b.push_function(Function{});
|
||||
EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "var"
|
||||
)");
|
||||
|
@ -87,6 +98,7 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalVar) {
|
|||
)");
|
||||
|
||||
ast::IdentifierExpression expr("var");
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr));
|
||||
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 1u);
|
||||
}
|
||||
|
||||
|
@ -109,7 +121,11 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionConst) {
|
|||
v.set_constructor(std::move(init));
|
||||
v.set_is_const(true);
|
||||
|
||||
Context ctx;
|
||||
ast::Module mod;
|
||||
TypeDeterminer td(&ctx, &mod);
|
||||
td.RegisterVariableForTesting(&v);
|
||||
|
||||
Builder b(&mod);
|
||||
EXPECT_TRUE(b.GenerateFunctionVariable(&v)) << b.error();
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
|
@ -122,6 +138,7 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionConst) {
|
|||
)");
|
||||
|
||||
ast::IdentifierExpression expr("var");
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr));
|
||||
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 5u);
|
||||
}
|
||||
|
||||
|
@ -129,7 +146,11 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionVar) {
|
|||
ast::type::F32Type f32;
|
||||
ast::Variable v("var", ast::StorageClass::kNone, &f32);
|
||||
|
||||
Context ctx;
|
||||
ast::Module mod;
|
||||
TypeDeterminer td(&ctx, &mod);
|
||||
td.RegisterVariableForTesting(&v);
|
||||
|
||||
Builder b(&mod);
|
||||
b.push_function(Function{});
|
||||
EXPECT_TRUE(b.GenerateFunctionVariable(&v)) << b.error();
|
||||
|
@ -144,6 +165,7 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionVar) {
|
|||
)");
|
||||
|
||||
ast::IdentifierExpression expr("var");
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr));
|
||||
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 1u);
|
||||
}
|
||||
|
||||
|
@ -170,7 +192,7 @@ TEST_F(BuilderTest, IdentifierExpression_Load) {
|
|||
b.push_function(Function{});
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(&var)) << b.error();
|
||||
|
||||
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 6u) << b.error();
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 6u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1
|
||||
%2 = OpTypePointer Private %3
|
||||
%1 = OpVariable %2 Private
|
||||
|
@ -217,31 +239,6 @@ TEST_F(BuilderTest, IdentifierExpression_NoLoadConst) {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, IdentifierExpression_ImportMethod) {
|
||||
auto imp = std::make_unique<ast::Import>("GLSL.std.450", "std");
|
||||
imp->AddMethodId("round", 42u);
|
||||
|
||||
ast::Module mod;
|
||||
mod.AddImport(std::move(imp));
|
||||
Builder b(&mod);
|
||||
|
||||
ast::IdentifierExpression expr(std::vector<std::string>({"std", "round"}));
|
||||
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 42u) << b.error();
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, IdentifierExpression_ImportMethod_NotFound) {
|
||||
auto imp = std::make_unique<ast::Import>("GLSL.std.450", "std");
|
||||
|
||||
ast::Module mod;
|
||||
mod.AddImport(std::move(imp));
|
||||
Builder b(&mod);
|
||||
|
||||
ast::IdentifierExpression expr(std::vector<std::string>({"std", "ceil"}));
|
||||
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 0u);
|
||||
ASSERT_TRUE(b.has_error());
|
||||
EXPECT_EQ(b.error(), "unable to lookup: ceil in std");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace spirv
|
||||
} // namespace writer
|
||||
|
|
Loading…
Reference in New Issue