Update type determiner to create pointers.

This CL updates the type determiner such that variable result types
end up wrapped inside pointers, constants do not. The result of Member
and Array accessors are also pointers if the source was a pointer.

Change-Id: I6694367daf6ba1db929e54a975dfea8404fca40c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/20265
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-04-23 22:26:52 +00:00 committed by David Neto
parent 1a4d90667b
commit 8eddb78433
11 changed files with 275 additions and 138 deletions

View File

@ -93,7 +93,7 @@ class Expression : public Node {
/// @returns the expression as a unary op expression
const UnaryOpExpression* AsUnaryOp() const;
/// @returns the expression as an array accessor
/// @returns the expression as an array accessor
ArrayAccessorExpression* AsArrayAccessor();
/// @returns the expression as an as
AsExpression* AsAs();

View File

@ -36,6 +36,13 @@ Type::Type() = default;
Type::~Type() = default;
Type* Type::UnwrapPtrIfNeeded() {
if (IsPointer()) {
return AsPointer()->type();
}
return this;
}
bool Type::IsAlias() const {
return false;
}

View File

@ -66,6 +66,9 @@ class Type {
/// @returns the name for this type. The |type_name| is unique over all types.
virtual std::string type_name() const = 0;
/// @returns the pointee type if this is a pointer, |this| otherwise
Type* UnwrapPtrIfNeeded();
/// @returns true if this type is a float scalar
bool is_float_scalar();
/// @returns true if this type is a float matrix

View File

@ -1159,8 +1159,7 @@ TEST_F(SpvParserTest,
EXPECT_THAT(fe.block_order(), ElementsAre(10, 30, 20, 80, 99));
}
TEST_F(SpvParserTest,
ComputeBlockOrder_Switch_DefaultSameAsACase) {
TEST_F(SpvParserTest, ComputeBlockOrder_Switch_DefaultSameAsACase) {
auto* p = parser(test::Assemble(CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
@ -1373,8 +1372,7 @@ TEST_F(SpvParserTest,
<< assembly;
}
TEST_F(SpvParserTest,
ComputeBlockOrder_Nest_If_Contains_If) {
TEST_F(SpvParserTest, ComputeBlockOrder_Nest_If_Contains_If) {
auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
@ -1424,8 +1422,7 @@ TEST_F(SpvParserTest,
<< assembly;
}
TEST_F(SpvParserTest,
ComputeBlockOrder_Nest_If_In_SwitchCase) {
TEST_F(SpvParserTest, ComputeBlockOrder_Nest_If_In_SwitchCase) {
auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
@ -1475,8 +1472,7 @@ TEST_F(SpvParserTest,
<< assembly;
}
TEST_F(SpvParserTest,
ComputeBlockOrder_Nest_IfFallthrough_In_SwitchCase) {
TEST_F(SpvParserTest, ComputeBlockOrder_Nest_IfFallthrough_In_SwitchCase) {
auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
@ -1526,8 +1522,7 @@ TEST_F(SpvParserTest,
<< assembly;
}
TEST_F(SpvParserTest,
ComputeBlockOrder_Nest_IfBreak_In_SwitchCase) {
TEST_F(SpvParserTest, ComputeBlockOrder_Nest_IfBreak_In_SwitchCase) {
auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn

View File

@ -38,6 +38,7 @@
#include "src/ast/type/bool_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/pointer_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
@ -274,19 +275,30 @@ bool TypeDeterminer::DetermineArrayAccessor(
if (!DetermineResultType(expr->array())) {
return false;
}
auto* parent_type = expr->array()->result_type();
auto* res = expr->array()->result_type();
auto* parent_type = res->UnwrapPtrIfNeeded();
ast::type::Type* ret = nullptr;
if (parent_type->IsArray()) {
expr->set_result_type(parent_type->AsArray()->type());
ret = parent_type->AsArray()->type();
} else if (parent_type->IsVector()) {
expr->set_result_type(parent_type->AsVector()->type());
ret = parent_type->AsVector()->type();
} else if (parent_type->IsMatrix()) {
auto* m = parent_type->AsMatrix();
expr->set_result_type(ctx_.type_mgr().Get(
std::make_unique<ast::type::VectorType>(m->type(), m->rows())));
ret = ctx_.type_mgr().Get(
std::make_unique<ast::type::VectorType>(m->type(), m->rows()));
} else {
set_error(expr->source(), "invalid parent type in array accessor");
return false;
}
// If we're extracting from a pointer, we return a pointer.
if (res->IsPointer()) {
ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
ret, res->AsPointer()->storage_class()));
}
expr->set_result_type(ret);
return true;
}
@ -365,7 +377,15 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) {
auto name = expr->name();
ast::Variable* var;
if (variable_stack_.get(name, &var)) {
expr->set_result_type(var->type());
// A constant is the type, but a variable is always a pointer so synthesize
// the pointer around the variable type.
if (var->is_const()) {
expr->set_result_type(var->type());
} else {
expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
var->type(), var->storage_class())));
}
return true;
}
@ -384,43 +404,52 @@ bool TypeDeterminer::DetermineMemberAccessor(
return false;
}
auto* data_type = expr->structure()->result_type();
auto* res = expr->structure()->result_type();
auto* data_type = res->UnwrapPtrIfNeeded();
ast::type::Type* ret = nullptr;
if (data_type->IsStruct()) {
auto* strct = data_type->AsStruct()->impl();
auto name = expr->member()->name();
for (const auto& member : strct->members()) {
if (member->name() != name) {
continue;
if (member->name() == name) {
ret = member->type();
break;
}
expr->set_result_type(member->type());
return true;
}
set_error(expr->source(), "struct member not found");
return false;
}
if (data_type->IsVector()) {
if (ret == nullptr) {
set_error(expr->source(), "struct member " + name + " not found");
return false;
}
} else if (data_type->IsVector()) {
auto* vec = data_type->AsVector();
auto size = expr->member()->name().size();
if (size == 1) {
// A single element swizzle is just the type of the vector.
expr->set_result_type(vec->type());
ret = vec->type();
} else {
// The vector will have a number of components equal to the length of the
// swizzle. This assumes the validator will check that the swizzle
// is correct.
expr->set_result_type(ctx_.type_mgr().Get(
std::make_unique<ast::type::VectorType>(vec->type(), size)));
ret = ctx_.type_mgr().Get(
std::make_unique<ast::type::VectorType>(vec->type(), size));
}
return true;
} else {
set_error(expr->source(),
"invalid type " + data_type->type_name() + " in member accessor");
return false;
}
set_error(expr->source(),
"invalid type " + data_type->type_name() + " in member accessor");
return false;
// If we're extracting from a pointer, we return a pointer.
if (res->IsPointer()) {
ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
ret, res->AsPointer()->storage_class()));
}
expr->set_result_type(ret);
return true;
}
bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
@ -432,7 +461,7 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
expr->IsShiftRight() || expr->IsShiftRightArith() || expr->IsAdd() ||
expr->IsSubtract() || expr->IsDivide() || expr->IsModulo()) {
expr->set_result_type(expr->lhs()->result_type());
expr->set_result_type(expr->lhs()->result_type()->UnwrapPtrIfNeeded());
return true;
}
// Result type is a scalar or vector of boolean type
@ -441,7 +470,7 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
auto* bool_type =
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
auto* param_type = expr->lhs()->result_type();
auto* param_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
if (param_type->IsVector()) {
expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
@ -452,8 +481,8 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
return true;
}
if (expr->IsMultiply()) {
auto* lhs_type = expr->lhs()->result_type();
auto* rhs_type = expr->rhs()->result_type();
auto* lhs_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
auto* rhs_type = expr->rhs()->result_type()->UnwrapPtrIfNeeded();
// Note, the ordering here matters. The later checks depend on the prior
// checks having been done.
@ -504,7 +533,7 @@ bool TypeDeterminer::DetermineUnaryDerivative(
if (!DetermineResultType(expr->param())) {
return false;
}
expr->set_result_type(expr->param()->result_type());
expr->set_result_type(expr->param()->result_type()->UnwrapPtrIfNeeded());
return true;
}
@ -531,7 +560,7 @@ bool TypeDeterminer::DetermineUnaryMethod(ast::UnaryMethodExpression* expr) {
auto* bool_type =
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
auto* param_type = expr->params()[0]->result_type();
auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
if (param_type->IsVector()) {
expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
@ -552,8 +581,8 @@ bool TypeDeterminer::DetermineUnaryMethod(ast::UnaryMethodExpression* expr) {
"incorrect number of parameters for outer product");
return false;
}
auto* param0_type = expr->params()[0]->result_type();
auto* param1_type = expr->params()[1]->result_type();
auto* param0_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
auto* param1_type = expr->params()[1]->result_type()->UnwrapPtrIfNeeded();
if (!param0_type->IsVector() || !param1_type->IsVector()) {
set_error(expr->source(), "invalid parameter type for outer product");
return false;
@ -574,7 +603,7 @@ bool TypeDeterminer::DetermineUnaryOp(ast::UnaryOpExpression* expr) {
if (!DetermineResultType(expr->expr())) {
return false;
}
expr->set_result_type(expr->expr()->result_type());
expr->set_result_type(expr->expr()->result_type()->UnwrapPtrIfNeeded());
return true;
}

View File

@ -46,6 +46,7 @@
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/pointer_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
@ -433,8 +434,33 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Array) {
auto idx = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 2));
auto var =
std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &ary);
auto var = std::make_unique<ast::Variable>(
"my_var", ast::StorageClass::kFunction, &ary);
mod()->AddGlobalVariable(std::move(var));
// Register the global
EXPECT_TRUE(td()->Determine());
ast::ArrayAccessorExpression acc(
std::make_unique<ast::IdentifierExpression>("my_var"), std::move(idx));
EXPECT_TRUE(td()->DetermineResultType(&acc));
ASSERT_NE(acc.result_type(), nullptr);
ASSERT_TRUE(acc.result_type()->IsPointer());
auto* ptr = acc.result_type()->AsPointer();
EXPECT_TRUE(ptr->type()->IsF32());
}
TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Array_Constant) {
ast::type::I32Type i32;
ast::type::F32Type f32;
ast::type::ArrayType ary(&f32, 3);
auto idx = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 2));
auto var = std::make_unique<ast::Variable>(
"my_var", ast::StorageClass::kFunction, &ary);
var->set_is_const(true);
mod()->AddGlobalVariable(std::move(var));
// Register the global
@ -465,8 +491,11 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix) {
std::make_unique<ast::IdentifierExpression>("my_var"), std::move(idx));
EXPECT_TRUE(td()->DetermineResultType(&acc));
ASSERT_NE(acc.result_type(), nullptr);
ASSERT_TRUE(acc.result_type()->IsVector());
EXPECT_EQ(acc.result_type()->AsVector()->size(), 3u);
ASSERT_TRUE(acc.result_type()->IsPointer());
auto* ptr = acc.result_type()->AsPointer();
ASSERT_TRUE(ptr->type()->IsVector());
EXPECT_EQ(ptr->type()->AsVector()->size(), 3u);
}
TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix_BothDimensions) {
@ -493,7 +522,10 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix_BothDimensions) {
EXPECT_TRUE(td()->DetermineResultType(&acc));
ASSERT_NE(acc.result_type(), nullptr);
EXPECT_TRUE(acc.result_type()->IsF32());
ASSERT_TRUE(acc.result_type()->IsPointer());
auto* ptr = acc.result_type()->AsPointer();
EXPECT_TRUE(ptr->type()->IsF32());
}
TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Vector) {
@ -514,7 +546,10 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Vector) {
std::make_unique<ast::IdentifierExpression>("my_var"), std::move(idx));
EXPECT_TRUE(td()->DetermineResultType(&acc));
ASSERT_NE(acc.result_type(), nullptr);
EXPECT_TRUE(acc.result_type()->IsF32());
ASSERT_TRUE(acc.result_type()->IsPointer());
auto* ptr = acc.result_type()->AsPointer();
EXPECT_TRUE(ptr->type()->IsF32());
}
TEST_F(TypeDeterminerTest, Expr_As) {
@ -643,12 +678,55 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalVariable) {
// Register the global
EXPECT_TRUE(td()->Determine());
ast::IdentifierExpression ident("my_var");
EXPECT_TRUE(td()->DetermineResultType(&ident));
ASSERT_NE(ident.result_type(), nullptr);
EXPECT_TRUE(ident.result_type()->IsPointer());
EXPECT_TRUE(ident.result_type()->AsPointer()->type()->IsF32());
}
TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalConstant) {
ast::type::F32Type f32;
auto var =
std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &f32);
var->set_is_const(true);
mod()->AddGlobalVariable(std::move(var));
// Register the global
EXPECT_TRUE(td()->Determine());
ast::IdentifierExpression ident("my_var");
EXPECT_TRUE(td()->DetermineResultType(&ident));
ASSERT_NE(ident.result_type(), nullptr);
EXPECT_TRUE(ident.result_type()->IsF32());
}
TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable_Const) {
ast::type::F32Type f32;
auto my_var = std::make_unique<ast::IdentifierExpression>("my_var");
auto* my_var_ptr = my_var.get();
auto var =
std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &f32);
var->set_is_const(true);
ast::StatementList body;
body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::move(my_var),
std::make_unique<ast::IdentifierExpression>("my_var")));
ast::Function f("my_func", {}, &f32);
f.set_body(std::move(body));
EXPECT_TRUE(td()->DetermineFunction(&f));
ASSERT_NE(my_var_ptr->result_type(), nullptr);
EXPECT_TRUE(my_var_ptr->result_type()->IsF32());
}
TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) {
ast::type::F32Type f32;
@ -670,7 +748,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) {
EXPECT_TRUE(td()->DetermineFunction(&f));
ASSERT_NE(my_var_ptr->result_type(), nullptr);
EXPECT_TRUE(my_var_ptr->result_type()->IsF32());
EXPECT_TRUE(my_var_ptr->result_type()->IsPointer());
EXPECT_TRUE(my_var_ptr->result_type()->AsPointer()->type()->IsF32());
}
TEST_F(TypeDeterminerTest, Expr_Identifier_Function) {
@ -720,7 +799,10 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) {
ast::MemberAccessorExpression mem(std::move(ident), std::move(mem_ident));
EXPECT_TRUE(td()->DetermineResultType(&mem));
ASSERT_NE(mem.result_type(), nullptr);
EXPECT_TRUE(mem.result_type()->IsF32());
ASSERT_TRUE(mem.result_type()->IsPointer());
auto* ptr = mem.result_type()->AsPointer();
EXPECT_TRUE(ptr->type()->IsF32());
}
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
@ -740,9 +822,12 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle));
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
ASSERT_NE(mem.result_type(), nullptr);
ASSERT_TRUE(mem.result_type()->IsVector());
EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2u);
ASSERT_TRUE(mem.result_type()->IsPointer());
auto* ptr = mem.result_type()->AsPointer();
ASSERT_TRUE(ptr->type()->IsVector());
EXPECT_TRUE(ptr->type()->AsVector()->type()->IsF32());
EXPECT_EQ(ptr->type()->AsVector()->size(), 2u);
}
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
@ -762,10 +847,13 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle));
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
ASSERT_NE(mem.result_type(), nullptr);
ASSERT_TRUE(mem.result_type()->IsF32());
ASSERT_TRUE(mem.result_type()->IsPointer());
auto* ptr = mem.result_type()->AsPointer();
ASSERT_TRUE(ptr->type()->IsF32());
}
TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
TEST_F(TypeDeterminerTest, Expr_Accessor_MultiLevel) {
// struct b {
// vec4<f32> foo
// }
@ -803,6 +891,7 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
auto strctB = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
std::move(b_members));
ast::type::StructType stB(std::move(strctB));
stB.set_name("B");
ast::type::VectorType vecB(&stB, 3);
@ -814,6 +903,7 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
std::move(a_members));
ast::type::StructType stA(std::move(strctA));
stA.set_name("A");
auto var =
std::make_unique<ast::Variable>("c", ast::StorageClass::kNone, &stA);
@ -838,10 +928,14 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
std::move(foo_ident)),
std::move(swizzle));
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
ASSERT_NE(mem.result_type(), nullptr);
ASSERT_TRUE(mem.result_type()->IsVector());
EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2u);
ASSERT_TRUE(mem.result_type()->IsPointer());
auto* ptr = mem.result_type()->AsPointer();
ASSERT_TRUE(ptr->type()->IsVector());
EXPECT_TRUE(ptr->type()->AsVector()->type()->IsF32());
EXPECT_EQ(ptr->type()->AsVector()->size(), 2u);
}
using Expr_Binary_BitwiseTest = TypeDeterminerTestWithParam<ast::BinaryOp>;

View File

@ -226,34 +226,6 @@ uint32_t Builder::GenerateExpression(ast::Expression* expr) {
return 0;
}
uint32_t Builder::GenerateExpressionAndLoad(ast::Expression* expr) {
auto id = GenerateExpression(expr);
if (id == 0) {
return false;
}
// Only need to load identifiers
if (!expr->IsIdentifier()) {
return id;
}
if (spirv_id_to_variable_.find(id) == spirv_id_to_variable_.end()) {
error_ = "missing generated ID for variable";
return 0;
}
auto* var = spirv_id_to_variable_[id];
if (var->is_const()) {
return id;
}
auto type_id = GenerateTypeIfNeeded(expr->result_type());
auto result = result_op();
auto result_id = result.to_i();
push_function_inst(spv::Op::OpLoad,
{Operand::Int(type_id), result, Operand::Int(id)});
return result_id;
}
bool Builder::GenerateFunction(ast::Function* func) {
uint32_t func_type_id = GenerateFunctionTypeIfNeeded(func);
if (func_type_id == 0) {
@ -468,7 +440,8 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
auto* mem_accessor = source->AsMemberAccessor();
source = mem_accessor->structure();
auto* data_type = mem_accessor->structure()->result_type();
auto* data_type =
mem_accessor->structure()->result_type()->UnwrapPtrIfNeeded();
if (data_type->IsStruct()) {
auto* strct = data_type->AsStruct()->impl();
auto name = mem_accessor->member()->name();
@ -476,11 +449,9 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
uint32_t i = 0;
for (; i < strct->members().size(); ++i) {
const auto& member = strct->members()[i];
if (member->name() != name) {
continue;
if (member->name() == name) {
break;
}
break;
}
ast::type::U32Type u32;
@ -507,9 +478,7 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
return 0;
}
// The access chain results in a pointer, so wrap the return type.
ast::type::PointerType ptr(expr->result_type(), ast::StorageClass::kFunction);
auto type_id = GenerateTypeIfNeeded(&ptr);
auto type_id = GenerateTypeIfNeeded(expr->result_type());
if (type_id == 0) {
return 0;
}
@ -535,12 +504,23 @@ uint32_t Builder::GenerateIdentifierExpression(
if (val == 0) {
error_ = "unable to lookup: " + expr->name() + " in " + expr->path();
}
} else if (!scope_stack_.get(expr->name(), &val)) {
error_ = "unable to find name for identifier: " + expr->name();
return 0;
return val;
}
if (scope_stack_.get(expr->name(), &val)) {
return val;
}
return val;
error_ = "unable to find name for identifier: " + expr->name();
return 0;
}
uint32_t Builder::GenerateLoad(ast::type::Type* type, uint32_t id) {
auto type_id = GenerateTypeIfNeeded(type->UnwrapPtrIfNeeded());
auto result = result_op();
auto result_id = result.to_i();
push_function_inst(spv::Op::OpLoad,
{Operand::Int(type_id), result, Operand::Int(id)});
return result_id;
}
uint32_t Builder::GenerateUnaryOpExpression(ast::UnaryOpExpression* expr) {
@ -686,14 +666,21 @@ uint32_t Builder::GenerateLiteralIfNeeded(ast::Literal* lit) {
}
uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
auto lhs_id = GenerateExpressionAndLoad(expr->lhs());
auto lhs_id = GenerateExpression(expr->lhs());
if (lhs_id == 0) {
return 0;
}
auto rhs_id = GenerateExpressionAndLoad(expr->rhs());
if (expr->lhs()->result_type()->IsPointer()) {
lhs_id = GenerateLoad(expr->lhs()->result_type(), lhs_id);
}
auto rhs_id = GenerateExpression(expr->rhs());
if (rhs_id == 0) {
return 0;
}
if (expr->rhs()->result_type()->IsPointer()) {
rhs_id = GenerateLoad(expr->rhs()->result_type(), rhs_id);
}
auto result = result_op();
auto result_id = result.to_i();
@ -705,8 +692,8 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
// Handle int and float and the vectors of those types. Other types
// should have been rejected by validation.
auto* lhs_type = expr->lhs()->result_type();
auto* rhs_type = expr->rhs()->result_type();
auto* lhs_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
auto* rhs_type = expr->rhs()->result_type()->UnwrapPtrIfNeeded();
bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
@ -819,6 +806,7 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
} else if (expr->IsXor()) {
op = spv::Op::OpBitwiseXor;
} else {
error_ = "unknown binary expression";
return 0;
}

View File

@ -157,10 +157,6 @@ class Builder {
/// @param expr the expression to generate
/// @returns the resulting ID of the expression or 0 on error
uint32_t GenerateExpression(ast::Expression* expr);
/// Generates an expression and emits a load if necessary
/// @param expr the expression
/// @returns the SPIR-V result id
uint32_t GenerateExpressionAndLoad(ast::Expression* expr);
/// Generates the instructions for a function
/// @param func the function to generate
/// @returns true if the instructions were generated
@ -240,6 +236,11 @@ class Builder {
/// @param list the statement list to generate
/// @returns true on successful generation
bool GenerateStatementList(const ast::StatementList& list);
/// Geneates an OpLoad
/// @param type the type to load
/// @param id the variable id to load
/// @returns the ID of the loaded value
uint32_t GenerateLoad(ast::type::Type* type, uint32_t id);
/// Geneates an OpStore
/// @param to the ID to store too
/// @param from the ID to store from

View File

@ -132,6 +132,13 @@ TEST_F(BuilderTest, ArrayAccessor_MultiLevel) {
TEST_F(BuilderTest, MemberAccessor) {
ast::type::F32Type f32;
// my_struct {
// a : f32
// b : f32
// }
// var ident : my_struct
// ident.b
ast::StructMemberDecorationList decos;
ast::StructMemberList members;
members.push_back(
@ -180,6 +187,15 @@ TEST_F(BuilderTest, MemberAccessor) {
TEST_F(BuilderTest, MemberAccessor_Nested) {
ast::type::F32Type f32;
// inner_struct {
// a : f32
// }
// my_struct {
// inner : inner_struct
// }
//
// var ident : my_struct
// ident.inner.a
ast::StructMemberDecorationList decos;
ast::StructMemberList inner_members;
inner_members.push_back(

View File

@ -21,6 +21,8 @@
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/context.h"
#include "src/type_determiner.h"
#include "src/writer/spirv/builder.h"
#include "src/writer/spirv/spv_dump.h"
@ -37,18 +39,23 @@ TEST_F(BuilderTest, Assign_Var) {
ast::Variable v("var", ast::StorageClass::kOutput, &f32);
ast::Module mod;
Builder b(&mod);
EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
ASSERT_FALSE(b.has_error()) << b.error();
auto ident = std::make_unique<ast::IdentifierExpression>("var");
auto val = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f));
ast::AssignmentStatement assign(std::move(ident), std::move(val));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&v);
ASSERT_TRUE(td.DetermineResultType(&assign)) << td.error();
Builder b(&mod);
b.push_function(Function{});
EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_TRUE(b.GenerateAssignStatement(&assign)) << b.error();
EXPECT_FALSE(b.has_error());

View File

@ -56,7 +56,11 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalConst) {
v.set_constructor(std::move(init));
v.set_is_const(true);
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&v);
Builder b(&mod);
EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
ASSERT_FALSE(b.has_error()) << b.error();
@ -69,6 +73,8 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalConst) {
)");
ast::IdentifierExpression expr("var");
ASSERT_TRUE(td.DetermineResultType(&expr));
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 5u);
}
@ -76,8 +82,13 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalVar) {
ast::type::F32Type f32;
ast::Variable v("var", ast::StorageClass::kOutput, &f32);
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&v);
Builder b(&mod);
b.push_function(Function{});
EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "var"
)");
@ -87,6 +98,7 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalVar) {
)");
ast::IdentifierExpression expr("var");
ASSERT_TRUE(td.DetermineResultType(&expr));
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 1u);
}
@ -109,7 +121,11 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionConst) {
v.set_constructor(std::move(init));
v.set_is_const(true);
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&v);
Builder b(&mod);
EXPECT_TRUE(b.GenerateFunctionVariable(&v)) << b.error();
ASSERT_FALSE(b.has_error()) << b.error();
@ -122,6 +138,7 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionConst) {
)");
ast::IdentifierExpression expr("var");
ASSERT_TRUE(td.DetermineResultType(&expr));
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 5u);
}
@ -129,7 +146,11 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionVar) {
ast::type::F32Type f32;
ast::Variable v("var", ast::StorageClass::kNone, &f32);
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&v);
Builder b(&mod);
b.push_function(Function{});
EXPECT_TRUE(b.GenerateFunctionVariable(&v)) << b.error();
@ -144,6 +165,7 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionVar) {
)");
ast::IdentifierExpression expr("var");
ASSERT_TRUE(td.DetermineResultType(&expr));
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 1u);
}
@ -170,7 +192,7 @@ TEST_F(BuilderTest, IdentifierExpression_Load) {
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(&var)) << b.error();
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 6u) << b.error();
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 6u) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1
%2 = OpTypePointer Private %3
%1 = OpVariable %2 Private
@ -217,31 +239,6 @@ TEST_F(BuilderTest, IdentifierExpression_NoLoadConst) {
)");
}
TEST_F(BuilderTest, IdentifierExpression_ImportMethod) {
auto imp = std::make_unique<ast::Import>("GLSL.std.450", "std");
imp->AddMethodId("round", 42u);
ast::Module mod;
mod.AddImport(std::move(imp));
Builder b(&mod);
ast::IdentifierExpression expr(std::vector<std::string>({"std", "round"}));
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 42u) << b.error();
}
TEST_F(BuilderTest, IdentifierExpression_ImportMethod_NotFound) {
auto imp = std::make_unique<ast::Import>("GLSL.std.450", "std");
ast::Module mod;
mod.AddImport(std::move(imp));
Builder b(&mod);
ast::IdentifierExpression expr(std::vector<std::string>({"std", "ceil"}));
EXPECT_EQ(b.GenerateIdentifierExpression(&expr), 0u);
ASSERT_TRUE(b.has_error());
EXPECT_EQ(b.error(), "unable to lookup: ceil in std");
}
} // namespace
} // namespace spirv
} // namespace writer