[spirv-reader] Support access chain

Bug: tint:3
Change-Id: Ibdb6698c4a97ce66ed533a9bf007bc352a09244e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/21641
Reviewed-by: dan sinclair <dsinclair@google.com>
This commit is contained in:
David Neto 2020-05-20 20:36:18 +00:00
parent 91c5a496d2
commit 7e5e02f805
8 changed files with 645 additions and 26 deletions

View File

@ -23,10 +23,12 @@
#include "source/opt/function.h"
#include "source/opt/instruction.h"
#include "source/opt/module.h"
#include "src/ast/array_accessor_expression.h"
#include "src/ast/as_expression.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/binary_expression.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/storage_class.h"
#include "src/ast/uint_literal.h"
@ -1492,32 +1494,31 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
return Fail() << "unhandled instruction with opcode " << inst.opcode();
}
TypedExpression FunctionEmitter::MakeOperand(
const spvtools::opt::Instruction& inst,
uint32_t operand_index) {
auto expr = this->MakeExpression(inst.GetSingleWordInOperand(operand_index));
return parser_impl_.RectifyOperandSignedness(inst.opcode(), std::move(expr));
}
TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
const spvtools::opt::Instruction& inst) {
if (inst.result_id() == 0) {
return {};
}
// TODO(dneto): Fill in the following cases.
auto operand = [this, &inst](uint32_t operand_index) {
auto expr =
this->MakeExpression(inst.GetSingleWordInOperand(operand_index));
return parser_impl_.RectifyOperandSignedness(inst.opcode(),
std::move(expr));
};
const auto opcode = inst.opcode();
ast::type::Type* ast_type =
inst.type_id() != 0 ? parser_impl_.ConvertType(inst.type_id()) : nullptr;
auto binary_op = ConvertBinaryOp(inst.opcode());
auto binary_op = ConvertBinaryOp(opcode);
if (binary_op != ast::BinaryOp::kNone) {
auto arg0 = operand(0);
auto arg1 = operand(1);
auto arg0 = MakeOperand(inst, 0);
auto arg1 = MakeOperand(inst, 1);
auto binary_expr = std::make_unique<ast::BinaryExpression>(
binary_op, std::move(arg0.expr), std::move(arg1.expr));
auto* forced_result_ty =
parser_impl_.ForcedResultType(inst.opcode(), arg0.type);
auto* forced_result_ty = parser_impl_.ForcedResultType(opcode, arg0.type);
if (forced_result_ty && forced_result_ty != ast_type) {
return {ast_type, std::make_unique<ast::AsExpression>(
ast_type, std::move(binary_expr))};
@ -1526,12 +1527,11 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
}
auto unary_op = ast::UnaryOp::kNegation;
if (GetUnaryOp(inst.opcode(), &unary_op)) {
auto arg0 = operand(0);
if (GetUnaryOp(opcode, &unary_op)) {
auto arg0 = MakeOperand(inst, 0);
auto unary_expr = std::make_unique<ast::UnaryOpExpression>(
unary_op, std::move(arg0.expr));
auto* forced_result_ty =
parser_impl_.ForcedResultType(inst.opcode(), arg0.type);
auto* forced_result_ty = parser_impl_.ForcedResultType(opcode, arg0.type);
if (forced_result_ty && forced_result_ty != ast_type) {
return {ast_type, std::make_unique<ast::AsExpression>(
ast_type, std::move(unary_expr))};
@ -1539,16 +1539,19 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
return {ast_type, std::move(unary_expr)};
}
if (inst.opcode() == SpvOpBitcast) {
auto* target_ty = parser_impl_.ConvertType(inst.type_id());
return {target_ty,
std::make_unique<ast::AsExpression>(target_ty, operand(0).expr)};
if (opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain) {
return MakeAccessChain(inst);
}
auto negated_op = NegatedFloatCompare(inst.opcode());
if (opcode == SpvOpBitcast) {
return {ast_type, std::make_unique<ast::AsExpression>(
ast_type, MakeOperand(inst, 0).expr)};
}
auto negated_op = NegatedFloatCompare(opcode);
if (negated_op != ast::BinaryOp::kNone) {
auto arg0 = operand(0);
auto arg1 = operand(1);
auto arg0 = MakeOperand(inst, 0);
auto arg1 = MakeOperand(inst, 1);
auto binary_expr = std::make_unique<ast::BinaryExpression>(
negated_op, std::move(arg0.expr), std::move(arg1.expr));
auto negated_expr = std::make_unique<ast::UnaryOpExpression>(
@ -1578,8 +1581,6 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
// OpGenericCastToPtr // Not in Vulkan
// OpGenericCastToPtrExplicit // Not in Vulkan
//
// OpAccessChain
// OpInBoundsAccessChain
// OpArrayLength
// OpVectorExtractDynamic
// OpVectorInsertDynamic
@ -1589,6 +1590,130 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
return {};
}
TypedExpression FunctionEmitter::MakeAccessChain(
const spvtools::opt::Instruction& inst) {
if (inst.NumInOperands() < 1) {
// Binary parsing will fail on this anyway.
Fail() << "invalid access chain: has no input operands";
return {};
}
// A SPIR-V access chain is a single instruction with multiple indices
// walking down into composites. The Tint AST represents this as ever-deeper
// nested indexing expresions.
// Start off with an expression for the base, and then bury that inside
// nested indexing expressions.
TypedExpression current_expr(MakeOperand(inst, 0));
const auto constants = constant_mgr_->GetOperandConstants(&inst);
static const char* swizzles[] = {"x", "y", "z", "w"};
const auto base_id = inst.GetSingleWordInOperand(0);
const auto ptr_ty_id = def_use_mgr_->GetDef(base_id)->type_id();
const auto* ptr_type = type_mgr_->GetType(ptr_ty_id);
if (!ptr_type || !ptr_type->AsPointer()) {
Fail() << "Access chain %" << inst.result_id()
<< " base pointer is not of pointer type";
return {};
}
const auto* pointee_type = ptr_type->AsPointer()->pointee_type();
const auto num_in_operands = inst.NumInOperands();
for (uint32_t index = 1; index < num_in_operands; ++index) {
const auto* index_const =
constants[index] ? constants[index]->AsIntConstant() : nullptr;
const int64_t index_const_val =
index_const ? index_const->GetSignExtendedValue() : 0;
std::unique_ptr<ast::Expression> next_expr;
switch (pointee_type->kind()) {
case spvtools::opt::analysis::Type::kVector:
if (index_const) {
// Try generating a MemberAccessor expression.
if (index_const_val < 0 ||
pointee_type->AsVector()->element_count() <= index_const_val) {
Fail() << "Access chain %" << inst.result_id() << " index %"
<< inst.GetSingleWordInOperand(index) << " value "
<< index_const_val
<< " is out of bounds for vector of "
<< pointee_type->AsVector()->element_count()
<< " elements";
return {};
}
if (uint64_t(index_const_val) >=
sizeof(swizzles) / sizeof(swizzles[0])) {
Fail() << "internal error: swizzle index " << index_const_val
<< " is too big. Max handled index is "
<< ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
}
auto letter_index = std::make_unique<ast::IdentifierExpression>(
swizzles[index_const_val]);
next_expr = std::make_unique<ast::MemberAccessorExpression>(
std::move(current_expr.expr), std::move(letter_index));
} else {
// Non-constant index. Use array syntax
next_expr = std::make_unique<ast::ArrayAccessorExpression>(
std::move(current_expr.expr),
std::move(MakeOperand(inst, index).expr));
}
pointee_type = pointee_type->AsVector()->element_type();
break;
case spvtools::opt::analysis::Type::kMatrix:
// Use array syntax.
next_expr = std::make_unique<ast::ArrayAccessorExpression>(
std::move(current_expr.expr), std::move(MakeOperand(inst, index).expr));
pointee_type = pointee_type->AsMatrix()->element_type();
break;
case spvtools::opt::analysis::Type::kArray:
next_expr = std::make_unique<ast::ArrayAccessorExpression>(
std::move(current_expr.expr), std::move(MakeOperand(inst, index).expr));
pointee_type = pointee_type->AsArray()->element_type();
break;
case spvtools::opt::analysis::Type::kRuntimeArray:
next_expr = std::make_unique<ast::ArrayAccessorExpression>(
std::move(current_expr.expr), std::move(MakeOperand(inst, index).expr));
pointee_type = pointee_type->AsRuntimeArray()->element_type();
break;
case spvtools::opt::analysis::Type::kStruct: {
if (!index_const) {
Fail() << "Access chain %" << inst.result_id() << " index %"
<< inst.GetSingleWordInOperand(index)
<< " is a non-constant index into a structure %"
<< type_mgr_->GetId(pointee_type);
return {};
}
if ((index_const_val < 0) ||
pointee_type->AsStruct()->element_types().size() <=
uint64_t(index_const_val)) {
Fail() << "Access chain %" << inst.result_id()
<< " index value " << index_const_val
<< " is out of bounds for structure %"
<< type_mgr_->GetId(pointee_type) << " having "
<< pointee_type->AsStruct()->element_types().size()
<< " elements";
return {};
}
auto member_access =
std::make_unique<ast::IdentifierExpression>(namer_.GetMemberName(
type_mgr_->GetId(pointee_type), uint32_t(index_const_val)));
next_expr = std::make_unique<ast::MemberAccessorExpression>(
std::move(current_expr.expr), std::move(member_access));
pointee_type =
pointee_type->AsStruct()->element_types()[index_const_val];
break;
}
default:
Fail() << "Access chain with unknown pointee type %"
<< type_mgr_->GetId(pointee_type) << " "
<< pointee_type->str();
return {};
}
current_expr.reset(TypedExpression(
parser_impl_.ConvertType(type_mgr_->GetId(pointee_type)),
std::move(next_expr)));
}
return current_expr;
}
} // namespace spirv
} // namespace reader
} // namespace tint

View File

@ -316,6 +316,21 @@ class FunctionEmitter {
ast::type::Type* GetVariableStoreType(
const spvtools::opt::Instruction& var_decl_inst);
/// Returns an expression for an instruction operand. Signedness conversion is
/// performed to match the result type of the SPIR-V instruction.
/// @param inst the SPIR-V instruction
/// @param operand_index the index of the operand, counting 0 as the first
/// input operand
/// @returns a new expression node
TypedExpression MakeOperand(const spvtools::opt::Instruction& inst,
uint32_t operand_index);
/// Returns an expression for a SPIR-V OpAccessChain or OpInBoundsAccessChain
/// instruction.
/// @param inst the SPIR-V instruction
/// @returns an expression
TypedExpression MakeAccessChain(const spvtools::opt::Instruction& inst);
/// Finds the header block for a structured construct that we can "break"
/// out from, from deeply nested control flow, if such a block exists.
/// If the construct is:

View File

@ -26,6 +26,7 @@ namespace reader {
namespace spirv {
namespace {
using ::testing::Eq;
using ::testing::HasSubstr;
TEST_F(SpvParserTest, EmitStatement_StoreBoolConst) {
@ -279,6 +280,434 @@ TEST_F(SpvParserTest, EmitStatement_StoreToModuleScopeVar) {
})"));
}
TEST_F(SpvParserTest, EmitStatement_AccessChain_NoOperands) {
auto err = test::AssembleFailure(R"(
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%ty = OpTypeInt 32 0
%val = OpConstant %ty 42
%ptr_ty = OpTypePointer Workgroup %ty
%1 = OpVariable %ptr_ty Workgroup
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%2 = OpAccessChain %ptr_ty ; Needs a base operand
OpStore %1 %val
OpReturn
)");
EXPECT_THAT(err,
Eq("11:5: Expected operand, found next instruction instead."));
}
TEST_F(SpvParserTest, EmitStatement_AccessChain_BaseIsNotPointer) {
auto* p = parser(test::Assemble(R"(
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%10 = OpTypeInt 32 0
%val = OpConstant %10 42
%ptr_ty = OpTypePointer Workgroup %10
%20 = OpVariable %10 Workgroup ; bad pointer type
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%1 = OpAccessChain %ptr_ty %20
OpStore %1 %val
OpReturn
)"));
EXPECT_FALSE(p->BuildAndParseInternalModuleExceptFunctions());
EXPECT_THAT(p->error(), Eq("variable with ID 20 has non-pointer type 10"));
}
TEST_F(SpvParserTest, EmitStatement_AccessChain_VectorSwizzle) {
const std::string assembly = R"(
OpName %1 "myvar"
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%uint = OpTypeInt 32 0
%store_ty = OpTypeVector %uint 4
%uint_2 = OpConstant %uint 2
%uint_42 = OpConstant %uint 42
%elem_ty = OpTypePointer Workgroup %uint
%var_ty = OpTypePointer Workgroup %store_ty
%1 = OpVariable %var_ty Workgroup
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%2 = OpAccessChain %elem_ty %1 %uint_2
OpStore %2 %uint_42
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
<< assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody());
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
MemberAccessor{
Identifier{myvar}
Identifier{z}
}
ScalarConstructor{42}
})"));
}
TEST_F(SpvParserTest, EmitStatement_AccessChain_VectorConstOutOfBounds) {
const std::string assembly = R"(
OpName %1 "myvar"
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%uint = OpTypeInt 32 0
%store_ty = OpTypeVector %uint 4
%42 = OpConstant %uint 42
%uint_99 = OpConstant %uint 99
%elem_ty = OpTypePointer Workgroup %uint
%var_ty = OpTypePointer Workgroup %store_ty
%1 = OpVariable %var_ty Workgroup
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%2 = OpAccessChain %elem_ty %1 %42
OpStore %2 %uint_99
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
<< assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_FALSE(fe.EmitBody());
EXPECT_THAT(p->error(), Eq("Access chain %2 index %42 value 42 is out of "
"bounds for vector of 4 elements"));
}
TEST_F(SpvParserTest, EmitStatement_AccessChain_VectorNonConstIndex) {
const std::string assembly = R"(
OpName %1 "myvar"
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%uint = OpTypeInt 32 0
%store_ty = OpTypeVector %uint 4
%uint_2 = OpConstant %uint 2
%uint_42 = OpConstant %uint 42
%elem_ty = OpTypePointer Workgroup %uint
%var_ty = OpTypePointer Workgroup %store_ty
%1 = OpVariable %var_ty Workgroup
%10 = OpVariable %var_ty Workgroup
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%11 = OpLoad %uint %10
%2 = OpAccessChain %elem_ty %1 %11
OpStore %2 %uint_42
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
<< assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody());
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
ArrayAccessor{
Identifier{myvar}
Identifier{x_11}
}
ScalarConstructor{42}
})"));
}
TEST_F(SpvParserTest, EmitStatement_AccessChain_Matrix) {
const std::string assembly = R"(
OpName %1 "myvar"
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%m3v4float = OpTypeMatrix %v4float 3
%elem_ty = OpTypePointer Workgroup %v4float
%var_ty = OpTypePointer Workgroup %m3v4float
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
%float_42 = OpConstant %float 42
%v4float_42 = OpConstantComposite %v4float %float_42 %float_42 %float_42 %float_42
%1 = OpVariable %var_ty Workgroup
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%2 = OpAccessChain %elem_ty %1 %uint_2
OpStore %2 %v4float_42
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
<< assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody());
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
ArrayAccessor{
Identifier{myvar}
ScalarConstructor{2}
}
TypeConstructor{
__vec_4__f32
ScalarConstructor{42.000000}
ScalarConstructor{42.000000}
ScalarConstructor{42.000000}
ScalarConstructor{42.000000}
}
})"));
}
TEST_F(SpvParserTest, EmitStatement_AccessChain_Array) {
const std::string assembly = R"(
OpName %1 "myvar"
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%m3v4float = OpTypeMatrix %v4float 3
%elem_ty = OpTypePointer Workgroup %v4float
%var_ty = OpTypePointer Workgroup %m3v4float
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
%float_42 = OpConstant %float 42
%v4float_42 = OpConstantComposite %v4float %float_42 %float_42 %float_42 %float_42
%1 = OpVariable %var_ty Workgroup
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%2 = OpAccessChain %elem_ty %1 %uint_2
OpStore %2 %v4float_42
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
<< assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody());
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
ArrayAccessor{
Identifier{myvar}
ScalarConstructor{2}
}
TypeConstructor{
__vec_4__f32
ScalarConstructor{42.000000}
ScalarConstructor{42.000000}
ScalarConstructor{42.000000}
ScalarConstructor{42.000000}
}
})"));
}
TEST_F(SpvParserTest, EmitStatement_AccessChain_Struct) {
const std::string assembly = R"(
OpName %1 "myvar"
OpMemberName %strct 1 "age"
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%float = OpTypeFloat 32
%float_42 = OpConstant %float 42
%strct = OpTypeStruct %float %float
%elem_ty = OpTypePointer Workgroup %float
%var_ty = OpTypePointer Workgroup %strct
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%1 = OpVariable %var_ty Workgroup
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%2 = OpAccessChain %elem_ty %1 %uint_1
OpStore %2 %float_42
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
<< assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody());
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
MemberAccessor{
Identifier{myvar}
Identifier{age}
}
ScalarConstructor{42.000000}
})"));
}
TEST_F(SpvParserTest, EmitStatement_AccessChain_StructNonConstIndex) {
const std::string assembly = R"(
OpName %1 "myvar"
OpMemberName %55 1 "age"
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%float = OpTypeFloat 32
%float_42 = OpConstant %float 42
%55 = OpTypeStruct %float %float
%elem_ty = OpTypePointer Workgroup %float
%var_ty = OpTypePointer Workgroup %55
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%uint_ptr = OpTypePointer Workgroup %uint
%uintvar = OpVariable %uint_ptr Workgroup
%1 = OpVariable %var_ty Workgroup
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%10 = OpLoad %uint %uintvar
%2 = OpAccessChain %elem_ty %1 %10
OpStore %2 %float_42
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
<< assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_FALSE(fe.EmitBody());
EXPECT_THAT(p->error(), Eq("Access chain %2 index %10 is a non-constant "
"index into a structure %55"));
}
TEST_F(SpvParserTest, EmitStatement_AccessChain_StructConstOutOfBounds) {
const std::string assembly = R"(
OpName %1 "myvar"
OpMemberName %55 1 "age"
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%float = OpTypeFloat 32
%float_42 = OpConstant %float 42
%55 = OpTypeStruct %float %float
%elem_ty = OpTypePointer Workgroup %float
%var_ty = OpTypePointer Workgroup %55
%uint = OpTypeInt 32 0
%uint_99 = OpConstant %uint 99
%1 = OpVariable %var_ty Workgroup
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%2 = OpAccessChain %elem_ty %1 %uint_99
OpStore %2 %float_42
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
<< assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_FALSE(fe.EmitBody());
EXPECT_THAT(p->error(), Eq("Access chain %2 index value 99 is out of bounds "
"for structure %55 having 2 elements"));
}
TEST_F(SpvParserTest, EmitStatement_AccessChain_Struct_RuntimeArray) {
const std::string assembly = R"(
OpName %1 "myvar"
OpMemberName %strct 1 "age"
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%float = OpTypeFloat 32
%float_42 = OpConstant %float 42
%rtarr = OpTypeRuntimeArray %float
%strct = OpTypeStruct %float %rtarr
%elem_ty = OpTypePointer Workgroup %float
%var_ty = OpTypePointer Workgroup %strct
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
%1 = OpVariable %var_ty Workgroup
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%2 = OpAccessChain %elem_ty %1 %uint_1 %uint_2
OpStore %2 %float_42
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
<< assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody());
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
ArrayAccessor{
MemberAccessor{
Identifier{myvar}
Identifier{age}
}
ScalarConstructor{2}
}
ScalarConstructor{42.000000}
})"));
}
TEST_F(SpvParserTest, EmitStatement_AccessChain_Compound_Matrix_Vector) {
const std::string assembly = R"(
OpName %1 "myvar"
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%m3v4float = OpTypeMatrix %v4float 3
%elem_ty = OpTypePointer Workgroup %float
%var_ty = OpTypePointer Workgroup %m3v4float
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
%uint_3 = OpConstant %uint 3
%float_42 = OpConstant %float 42
%1 = OpVariable %var_ty Workgroup
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%2 = OpAccessChain %elem_ty %1 %uint_2 %uint_3
OpStore %2 %float_42
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
<< assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody());
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
MemberAccessor{
ArrayAccessor{
Identifier{myvar}
ScalarConstructor{2}
}
Identifier{w}
}
ScalarConstructor{42.000000}
})"));
}
TEST_F(SpvParserTest, EmitStatement_AccessChain_InvalidPointeeType) {
const std::string assembly = R"(
OpName %1 "myvar"
%55 = OpTypeVoid
%voidfn = OpTypeFunction %55
%float = OpTypeFloat 32
%60 = OpTypePointer Workgroup %55
%var_ty = OpTypePointer Workgroup %60
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
%1 = OpVariable %var_ty Workgroup
%100 = OpFunction %55 None %voidfn
%entry = OpLabel
%2 = OpAccessChain %60 %1 %uint_2
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
<< assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_FALSE(fe.EmitBody());
EXPECT_THAT(p->error(),
HasSubstr("Access chain with unknown pointee type %60 void"));
}
} // namespace
} // namespace spirv
} // namespace reader

View File

@ -183,6 +183,11 @@ TypedExpression::TypedExpression(TypedExpression&& other)
TypedExpression::~TypedExpression() {}
void TypedExpression::reset(TypedExpression&& other) {
type = other.type;
expr = std::move(other.expr);
}
ParserImpl::ParserImpl(Context* ctx, const std::vector<uint32_t>& spv_binary)
: Reader(ctx),
spv_binary_(spv_binary),
@ -786,6 +791,10 @@ bool ParserImpl::EmitModuleScopeVariables() {
"SPIR-V type with ID: "
<< var.type_id();
}
if (!ast_type->IsPointer()) {
return Fail() << "variable with ID " << var.result_id()
<< " has non-pointer type " << var.type_id();
}
auto* ast_store_type = ast_type->AsPointer()->type();
auto ast_var =
MakeVariable(var.result_id(), ast_storage_class, ast_store_type);

View File

@ -65,6 +65,9 @@ struct TypedExpression {
TypedExpression(TypedExpression&& other);
/// Destructor
~TypedExpression();
/// Takes values from another typed expression.
/// @param other the other typed expression
void reset(TypedExpression&& other);
/// The type
ast::type::Type* type;
/// The expression

View File

@ -99,6 +99,20 @@ TEST_F(SpvParserTest, ModuleScopeVar_BadPointerType) {
"AST type for SPIR-V type with ID: 3"));
}
TEST_F(SpvParserTest, ModuleScopeVar_NonPointerType) {
auto* p = parser(test::Assemble(R"(
%float = OpTypeFloat 32
%5 = OpTypeFunction %float
%3 = OpTypePointer Private %5
%52 = OpVariable %float Private
)"));
EXPECT_TRUE(p->BuildInternalModule());
EXPECT_FALSE(p->RegisterTypes());
EXPECT_THAT(
p->error(),
HasSubstr("SPIR-V pointer type with ID 3 has invalid pointee type 5"));
}
TEST_F(SpvParserTest, ModuleScopeVar_AnonWorkgroupVar) {
auto* p = parser(test::Assemble(R"(
%float = OpTypeFloat 32

View File

@ -49,6 +49,26 @@ std::vector<uint32_t> Assemble(const std::string& spirv_assembly) {
return result;
}
std::string AssembleFailure(const std::string& spirv_assembly) {
// TODO(dneto): Use ScopedTrace?
// (The target environment doesn't affect assembly.
spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
std::stringstream errors;
std::vector<uint32_t> result;
tools.SetMessageConsumer([&errors](spv_message_level_t, const char*,
const spv_position_t& position,
const char* message) {
errors << position.line << ":" << position.column << ": " << message;
});
const auto success = tools.Assemble(
spirv_assembly, &result, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
EXPECT_FALSE(success);
return errors.str();
}
} // namespace test
} // namespace spirv
} // namespace reader

View File

@ -28,6 +28,10 @@ namespace test {
/// are preserved.
std::vector<uint32_t> Assemble(const std::string& spirv_assembly);
/// Attempts to assemble given SPIR-V assembly text. Expect it to fail.
/// @returns the failure message.
std::string AssembleFailure(const std::string& spirv_assembly);
} // namespace test
} // namespace spirv
} // namespace reader