spirv-reader: refactor swizzle creation

Change-Id: I6a09756026b7cbc436d5f232be9331255615e8c3
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/34040
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: David Neto <dneto@google.com>
This commit is contained in:
David Neto 2020-11-26 17:45:52 +00:00 committed by Commit Bot service account
parent 3ec71a63aa
commit 61ec48b99e
3 changed files with 79 additions and 26 deletions

View File

@ -172,6 +172,8 @@ namespace spirv {
namespace {
constexpr uint32_t kMaxVectorLen = 4;
// Gets the AST unary opcode for the given SPIR-V opcode, if any
// @param opcode SPIR-V opcode
// @param ast_unary_op return parameter
@ -2874,6 +2876,16 @@ TypedExpression FunctionEmitter::EmitGlslStd450ExtInst(
return {ast_type, call};
}
ast::IdentifierExpression* FunctionEmitter::Swizzle(uint32_t i) {
if (i >= kMaxVectorLen) {
Fail() << "vector component index is larger than " << kMaxVectorLen - 1
<< ": " << i;
return nullptr;
}
const char* names[] = {"x", "y", "z", "w"};
return ast_module_.create<ast::IdentifierExpression>(names[i & 3]);
}
TypedExpression FunctionEmitter::MakeAccessChain(
const spvtools::opt::Instruction& inst) {
if (inst.NumInOperands() < 1) {
@ -2888,7 +2900,6 @@ TypedExpression FunctionEmitter::MakeAccessChain(
// for the base, and then bury that inside nested indexing expressions.
TypedExpression current_expr(MakeOperand(inst, 0));
const auto constants = constant_mgr_->GetOperandConstants(&inst);
static const char* swizzles[] = {"x", "y", "z", "w"};
const auto base_id = inst.GetSingleWordInOperand(0);
auto ptr_ty_id = def_use_mgr_->GetDef(base_id)->type_id();
@ -2981,16 +2992,12 @@ TypedExpression FunctionEmitter::MakeAccessChain(
<< num_elems << " elements";
return {};
}
if (uint64_t(index_const_val) >=
sizeof(swizzles) / sizeof(swizzles[0])) {
if (uint64_t(index_const_val) >= kMaxVectorLen) {
Fail() << "internal error: swizzle index " << index_const_val
<< " is too big. Max handled index is "
<< ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
<< " is too big. Max handled index is " << kMaxVectorLen - 1;
}
auto* letter_index =
create<ast::IdentifierExpression>(swizzles[index_const_val]);
next_expr = create<ast::MemberAccessorExpression>(current_expr.expr,
letter_index);
next_expr = create<ast::MemberAccessorExpression>(
current_expr.expr, Swizzle(uint32_t(index_const_val)));
} else {
// Non-constant index. Use array syntax
next_expr = create<ast::ArrayAccessorExpression>(
@ -3072,7 +3079,6 @@ TypedExpression FunctionEmitter::MakeCompositeExtract(
return create<ast::ScalarConstructorExpression>(
create<ast::UintLiteral>(&u32, literal));
};
static const char* swizzles[] = {"x", "y", "z", "w"};
const auto composite = inst.GetSingleWordInOperand(0);
auto current_type_id = def_use_mgr_->GetDef(composite)->type_id();
@ -3102,15 +3108,12 @@ TypedExpression FunctionEmitter::MakeCompositeExtract(
<< " elements";
return {};
}
if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) {
if (index_val >= kMaxVectorLen) {
Fail() << "internal error: swizzle index " << index_val
<< " is too big. Max handled index is "
<< ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
<< " is too big. Max handled index is " << kMaxVectorLen - 1;
}
auto* letter_index =
create<ast::IdentifierExpression>(swizzles[index_val]);
next_expr = create<ast::MemberAccessorExpression>(current_expr.expr,
letter_index);
Swizzle(index_val));
// All vector components are the same type.
current_type_id = current_type_inst->GetSingleWordInOperand(0);
break;
@ -3124,10 +3127,9 @@ TypedExpression FunctionEmitter::MakeCompositeExtract(
<< " elements";
return {};
}
if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) {
if (index_val >= kMaxVectorLen) {
Fail() << "internal error: swizzle index " << index_val
<< " is too big. Max handled index is "
<< ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
<< " is too big. Max handled index is " << kMaxVectorLen - 1;
}
// Use array syntax.
next_expr = create<ast::ArrayAccessorExpression>(current_expr.expr,
@ -3197,7 +3199,6 @@ TypedExpression FunctionEmitter::MakeVectorShuffle(
type_mgr_->GetType(vec1.type_id())->AsVector()->element_count();
// Idiomatic vector accessors.
const char* swizzles[] = {"x", "y", "z", "w"};
// Generate an ast::TypeConstructor expression.
// Assume the literal indices are valid, and there is a valid number of them.
@ -3207,16 +3208,13 @@ TypedExpression FunctionEmitter::MakeVectorShuffle(
for (uint32_t i = 2; i < inst.NumInOperands(); ++i) {
const auto index = inst.GetSingleWordInOperand(i);
if (index < vec0_len) {
assert(index < sizeof(swizzles) / sizeof(swizzles[0]));
values.emplace_back(create<ast::MemberAccessorExpression>(
MakeExpression(vec0_id).expr,
create<ast::IdentifierExpression>(swizzles[index])));
MakeExpression(vec0_id).expr, Swizzle(index)));
} else if (index < vec0_len + vec1_len) {
const auto sub_index = index - vec0_len;
assert(sub_index < sizeof(swizzles) / sizeof(swizzles[0]));
assert(sub_index < kMaxVectorLen);
values.emplace_back(create<ast::MemberAccessorExpression>(
MakeExpression(vec1_id).expr,
create<ast::IdentifierExpression>(swizzles[sub_index])));
MakeExpression(vec1_id).expr, Swizzle(sub_index)));
} else if (index == 0xFFFFFFFF) {
// By rule, this maps to OpUndef. Instead, make it zero.
values.emplace_back(parser_impl_.MakeNullValue(result_type->type()));

View File

@ -672,6 +672,13 @@ class FunctionEmitter {
/// @returns the associated loop construct, or nullptr
const Construct* SiblingLoopConstruct(const Construct* c) const;
/// Returns an identifier expression for the swizzle name of the given
/// index into a vector. Emits an error and returns nullptr if the
/// index is out of range, i.e. 4 or higher.
/// @param i index of the subcomponent
/// @returns the identifier expression for the @p i'th component
ast::IdentifierExpression* Swizzle(uint32_t i);
private:
/// @returns the store type for the OpVariable instruction, or
/// null on failure.

View File

@ -16,6 +16,7 @@
#include <vector>
#include "gmock/gmock.h"
#include "src/ast/identifier_expression.h"
#include "src/reader/spirv/function.h"
#include "src/reader/spirv/parser_impl.h"
#include "src/reader/spirv/parser_impl_test_helper.h"
@ -295,6 +296,53 @@ TEST_F(SpvParserTestMiscInstruction, OpNop) {
)")) << ToString(fe.ast_body());
}
// Test swizzle generation.
struct SwizzleCase {
uint32_t index;
std::string expected_expr;
std::string expected_error;
};
using SpvParserSwizzleTest =
SpvParserTestBase<::testing::TestWithParam<SwizzleCase>>;
TEST_P(SpvParserSwizzleTest, Sample) {
// We need a function so we can get a FunctionEmitter.
const auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
auto* result = fe.Swizzle(GetParam().index);
if (GetParam().expected_error.empty()) {
EXPECT_TRUE(fe.success());
ASSERT_NE(result, nullptr);
std::ostringstream ss;
result->to_str(ss, 0);
EXPECT_THAT(ss.str(), Eq(GetParam().expected_expr));
} else {
EXPECT_EQ(result, nullptr);
EXPECT_FALSE(fe.success());
EXPECT_THAT(p->error(), Eq(GetParam().expected_error));
}
}
INSTANTIATE_TEST_SUITE_P(
ValidIndex,
SpvParserSwizzleTest,
::testing::ValuesIn(std::vector<SwizzleCase>{
{0, "Identifier[not set]{x}\n", ""},
{1, "Identifier[not set]{y}\n", ""},
{2, "Identifier[not set]{z}\n", ""},
{3, "Identifier[not set]{w}\n", ""},
{4, "", "vector component index is larger than 3: 4"},
{99999, "", "vector component index is larger than 3: 99999"}}));
// TODO(dneto): OpSizeof : requires Kernel (OpenCL)
} // namespace