spirv-reader: refactor swizzle creation
Change-Id: I6a09756026b7cbc436d5f232be9331255615e8c3 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/34040 Auto-Submit: David Neto <dneto@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: David Neto <dneto@google.com>
This commit is contained in:
parent
3ec71a63aa
commit
61ec48b99e
|
@ -172,6 +172,8 @@ namespace spirv {
|
|||
|
||||
namespace {
|
||||
|
||||
constexpr uint32_t kMaxVectorLen = 4;
|
||||
|
||||
// Gets the AST unary opcode for the given SPIR-V opcode, if any
|
||||
// @param opcode SPIR-V opcode
|
||||
// @param ast_unary_op return parameter
|
||||
|
@ -2874,6 +2876,16 @@ TypedExpression FunctionEmitter::EmitGlslStd450ExtInst(
|
|||
return {ast_type, call};
|
||||
}
|
||||
|
||||
ast::IdentifierExpression* FunctionEmitter::Swizzle(uint32_t i) {
|
||||
if (i >= kMaxVectorLen) {
|
||||
Fail() << "vector component index is larger than " << kMaxVectorLen - 1
|
||||
<< ": " << i;
|
||||
return nullptr;
|
||||
}
|
||||
const char* names[] = {"x", "y", "z", "w"};
|
||||
return ast_module_.create<ast::IdentifierExpression>(names[i & 3]);
|
||||
}
|
||||
|
||||
TypedExpression FunctionEmitter::MakeAccessChain(
|
||||
const spvtools::opt::Instruction& inst) {
|
||||
if (inst.NumInOperands() < 1) {
|
||||
|
@ -2888,7 +2900,6 @@ TypedExpression FunctionEmitter::MakeAccessChain(
|
|||
// for the base, and then bury that inside nested indexing expressions.
|
||||
TypedExpression current_expr(MakeOperand(inst, 0));
|
||||
const auto constants = constant_mgr_->GetOperandConstants(&inst);
|
||||
static const char* swizzles[] = {"x", "y", "z", "w"};
|
||||
|
||||
const auto base_id = inst.GetSingleWordInOperand(0);
|
||||
auto ptr_ty_id = def_use_mgr_->GetDef(base_id)->type_id();
|
||||
|
@ -2981,16 +2992,12 @@ TypedExpression FunctionEmitter::MakeAccessChain(
|
|||
<< num_elems << " elements";
|
||||
return {};
|
||||
}
|
||||
if (uint64_t(index_const_val) >=
|
||||
sizeof(swizzles) / sizeof(swizzles[0])) {
|
||||
if (uint64_t(index_const_val) >= kMaxVectorLen) {
|
||||
Fail() << "internal error: swizzle index " << index_const_val
|
||||
<< " is too big. Max handled index is "
|
||||
<< ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
|
||||
<< " is too big. Max handled index is " << kMaxVectorLen - 1;
|
||||
}
|
||||
auto* letter_index =
|
||||
create<ast::IdentifierExpression>(swizzles[index_const_val]);
|
||||
next_expr = create<ast::MemberAccessorExpression>(current_expr.expr,
|
||||
letter_index);
|
||||
next_expr = create<ast::MemberAccessorExpression>(
|
||||
current_expr.expr, Swizzle(uint32_t(index_const_val)));
|
||||
} else {
|
||||
// Non-constant index. Use array syntax
|
||||
next_expr = create<ast::ArrayAccessorExpression>(
|
||||
|
@ -3072,7 +3079,6 @@ TypedExpression FunctionEmitter::MakeCompositeExtract(
|
|||
return create<ast::ScalarConstructorExpression>(
|
||||
create<ast::UintLiteral>(&u32, literal));
|
||||
};
|
||||
static const char* swizzles[] = {"x", "y", "z", "w"};
|
||||
|
||||
const auto composite = inst.GetSingleWordInOperand(0);
|
||||
auto current_type_id = def_use_mgr_->GetDef(composite)->type_id();
|
||||
|
@ -3102,15 +3108,12 @@ TypedExpression FunctionEmitter::MakeCompositeExtract(
|
|||
<< " elements";
|
||||
return {};
|
||||
}
|
||||
if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) {
|
||||
if (index_val >= kMaxVectorLen) {
|
||||
Fail() << "internal error: swizzle index " << index_val
|
||||
<< " is too big. Max handled index is "
|
||||
<< ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
|
||||
<< " is too big. Max handled index is " << kMaxVectorLen - 1;
|
||||
}
|
||||
auto* letter_index =
|
||||
create<ast::IdentifierExpression>(swizzles[index_val]);
|
||||
next_expr = create<ast::MemberAccessorExpression>(current_expr.expr,
|
||||
letter_index);
|
||||
Swizzle(index_val));
|
||||
// All vector components are the same type.
|
||||
current_type_id = current_type_inst->GetSingleWordInOperand(0);
|
||||
break;
|
||||
|
@ -3124,10 +3127,9 @@ TypedExpression FunctionEmitter::MakeCompositeExtract(
|
|||
<< " elements";
|
||||
return {};
|
||||
}
|
||||
if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) {
|
||||
if (index_val >= kMaxVectorLen) {
|
||||
Fail() << "internal error: swizzle index " << index_val
|
||||
<< " is too big. Max handled index is "
|
||||
<< ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
|
||||
<< " is too big. Max handled index is " << kMaxVectorLen - 1;
|
||||
}
|
||||
// Use array syntax.
|
||||
next_expr = create<ast::ArrayAccessorExpression>(current_expr.expr,
|
||||
|
@ -3197,7 +3199,6 @@ TypedExpression FunctionEmitter::MakeVectorShuffle(
|
|||
type_mgr_->GetType(vec1.type_id())->AsVector()->element_count();
|
||||
|
||||
// Idiomatic vector accessors.
|
||||
const char* swizzles[] = {"x", "y", "z", "w"};
|
||||
|
||||
// Generate an ast::TypeConstructor expression.
|
||||
// Assume the literal indices are valid, and there is a valid number of them.
|
||||
|
@ -3207,16 +3208,13 @@ TypedExpression FunctionEmitter::MakeVectorShuffle(
|
|||
for (uint32_t i = 2; i < inst.NumInOperands(); ++i) {
|
||||
const auto index = inst.GetSingleWordInOperand(i);
|
||||
if (index < vec0_len) {
|
||||
assert(index < sizeof(swizzles) / sizeof(swizzles[0]));
|
||||
values.emplace_back(create<ast::MemberAccessorExpression>(
|
||||
MakeExpression(vec0_id).expr,
|
||||
create<ast::IdentifierExpression>(swizzles[index])));
|
||||
MakeExpression(vec0_id).expr, Swizzle(index)));
|
||||
} else if (index < vec0_len + vec1_len) {
|
||||
const auto sub_index = index - vec0_len;
|
||||
assert(sub_index < sizeof(swizzles) / sizeof(swizzles[0]));
|
||||
assert(sub_index < kMaxVectorLen);
|
||||
values.emplace_back(create<ast::MemberAccessorExpression>(
|
||||
MakeExpression(vec1_id).expr,
|
||||
create<ast::IdentifierExpression>(swizzles[sub_index])));
|
||||
MakeExpression(vec1_id).expr, Swizzle(sub_index)));
|
||||
} else if (index == 0xFFFFFFFF) {
|
||||
// By rule, this maps to OpUndef. Instead, make it zero.
|
||||
values.emplace_back(parser_impl_.MakeNullValue(result_type->type()));
|
||||
|
|
|
@ -672,6 +672,13 @@ class FunctionEmitter {
|
|||
/// @returns the associated loop construct, or nullptr
|
||||
const Construct* SiblingLoopConstruct(const Construct* c) const;
|
||||
|
||||
/// Returns an identifier expression for the swizzle name of the given
|
||||
/// index into a vector. Emits an error and returns nullptr if the
|
||||
/// index is out of range, i.e. 4 or higher.
|
||||
/// @param i index of the subcomponent
|
||||
/// @returns the identifier expression for the @p i'th component
|
||||
ast::IdentifierExpression* Swizzle(uint32_t i);
|
||||
|
||||
private:
|
||||
/// @returns the store type for the OpVariable instruction, or
|
||||
/// null on failure.
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "src/ast/identifier_expression.h"
|
||||
#include "src/reader/spirv/function.h"
|
||||
#include "src/reader/spirv/parser_impl.h"
|
||||
#include "src/reader/spirv/parser_impl_test_helper.h"
|
||||
|
@ -295,6 +296,53 @@ TEST_F(SpvParserTestMiscInstruction, OpNop) {
|
|||
)")) << ToString(fe.ast_body());
|
||||
}
|
||||
|
||||
// Test swizzle generation.
|
||||
|
||||
struct SwizzleCase {
|
||||
uint32_t index;
|
||||
std::string expected_expr;
|
||||
std::string expected_error;
|
||||
};
|
||||
using SpvParserSwizzleTest =
|
||||
SpvParserTestBase<::testing::TestWithParam<SwizzleCase>>;
|
||||
|
||||
TEST_P(SpvParserSwizzleTest, Sample) {
|
||||
// We need a function so we can get a FunctionEmitter.
|
||||
const auto assembly = CommonTypes() + R"(
|
||||
%100 = OpFunction %void None %voidfn
|
||||
%entry = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
auto p = parser(test::Assemble(assembly));
|
||||
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
|
||||
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
|
||||
|
||||
auto* result = fe.Swizzle(GetParam().index);
|
||||
if (GetParam().expected_error.empty()) {
|
||||
EXPECT_TRUE(fe.success());
|
||||
ASSERT_NE(result, nullptr);
|
||||
std::ostringstream ss;
|
||||
result->to_str(ss, 0);
|
||||
EXPECT_THAT(ss.str(), Eq(GetParam().expected_expr));
|
||||
} else {
|
||||
EXPECT_EQ(result, nullptr);
|
||||
EXPECT_FALSE(fe.success());
|
||||
EXPECT_THAT(p->error(), Eq(GetParam().expected_error));
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ValidIndex,
|
||||
SpvParserSwizzleTest,
|
||||
::testing::ValuesIn(std::vector<SwizzleCase>{
|
||||
{0, "Identifier[not set]{x}\n", ""},
|
||||
{1, "Identifier[not set]{y}\n", ""},
|
||||
{2, "Identifier[not set]{z}\n", ""},
|
||||
{3, "Identifier[not set]{w}\n", ""},
|
||||
{4, "", "vector component index is larger than 3: 4"},
|
||||
{99999, "", "vector component index is larger than 3: 99999"}}));
|
||||
|
||||
// TODO(dneto): OpSizeof : requires Kernel (OpenCL)
|
||||
|
||||
} // namespace
|
||||
|
|
Loading…
Reference in New Issue