spirv-reader: don't dedup composite constants

The SPIR-V optimizer's representation deduplicates constants
by structural equality.  We don't want that for WGSL.

Fixed: tint:1173
Change-Id: I7a3936fcd4803a1cda02e71cbaa7c4be89eba433
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/64701
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: David Neto <dneto@google.com>
This commit is contained in:
David Neto 2021-09-21 18:07:28 +00:00 committed by Tint LUCI CQ
parent 0ed87c8182
commit 92265504fe
3 changed files with 105 additions and 50 deletions

View File

@ -248,6 +248,61 @@ TEST_F(SpvParserTest_Composite_Construct, Struct) {
})")); })"));
} }
TEST_F(SpvParserTest_Composite_Construct,
ConstantComposite_Struct_NoDeduplication) {
const auto assembly = Preamble() + R"(
%200 = OpTypeStruct %uint
%300 = OpTypeStruct %uint ; isomorphic structures
%201 = OpConstantComposite %200 %uint_10
%301 = OpConstantComposite %300 %uint_10 ; isomorphic constants
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%2 = OpCopyObject %200 %201
%3 = OpCopyObject %300 %301
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
auto fe = p->function_emitter(100);
EXPECT_TRUE(fe.EmitBody()) << p->error();
const auto got = ToString(p->builder(), fe.ast_body());
const auto expected = std::string(
R"(VariableDeclStatement{
VariableConst{
x_2
none
undefined
__type_name_S_1
{
TypeConstructor[not set]{
__type_name_S_1
ScalarConstructor[not set]{10u}
}
}
}
}
VariableDeclStatement{
VariableConst{
x_3
none
undefined
__type_name_S_2
{
TypeConstructor[not set]{
__type_name_S_2
ScalarConstructor[not set]{10u}
}
}
}
}
Return{}
)");
EXPECT_EQ(got, expected) << got;
}
using SpvParserTest_CompositeExtract = SpvParserTest; using SpvParserTest_CompositeExtract = SpvParserTest;
TEST_F(SpvParserTest_CompositeExtract, Vector) { TEST_F(SpvParserTest_CompositeExtract, Vector) {

View File

@ -1890,19 +1890,19 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
Source{}, ast_type->Build(builder_), Source{}, ast_type->Build(builder_),
ast::ExpressionList{x.expr, y.expr, z.expr})}; ast::ExpressionList{x.expr, y.expr, z.expr})};
} else if (id == workgroup_size_builtin_.x_id) { } else if (id == workgroup_size_builtin_.x_id) {
return MakeConstantExpressionForSpirvConstant( return MakeConstantExpressionForScalarSpirvConstant(
Source{}, ConvertType(workgroup_size_builtin_.component_type_id), Source{}, ConvertType(workgroup_size_builtin_.component_type_id),
constant_mgr_->GetConstant( constant_mgr_->GetConstant(
type_mgr_->GetType(workgroup_size_builtin_.component_type_id), type_mgr_->GetType(workgroup_size_builtin_.component_type_id),
{workgroup_size_builtin_.x_value})); {workgroup_size_builtin_.x_value}));
} else if (id == workgroup_size_builtin_.y_id) { } else if (id == workgroup_size_builtin_.y_id) {
return MakeConstantExpressionForSpirvConstant( return MakeConstantExpressionForScalarSpirvConstant(
Source{}, ConvertType(workgroup_size_builtin_.component_type_id), Source{}, ConvertType(workgroup_size_builtin_.component_type_id),
constant_mgr_->GetConstant( constant_mgr_->GetConstant(
type_mgr_->GetType(workgroup_size_builtin_.component_type_id), type_mgr_->GetType(workgroup_size_builtin_.component_type_id),
{workgroup_size_builtin_.y_value})); {workgroup_size_builtin_.y_value}));
} else if (id == workgroup_size_builtin_.z_id) { } else if (id == workgroup_size_builtin_.z_id) {
return MakeConstantExpressionForSpirvConstant( return MakeConstantExpressionForScalarSpirvConstant(
Source{}, ConvertType(workgroup_size_builtin_.component_type_id), Source{}, ConvertType(workgroup_size_builtin_.component_type_id),
constant_mgr_->GetConstant( constant_mgr_->GetConstant(
type_mgr_->GetType(workgroup_size_builtin_.component_type_id), type_mgr_->GetType(workgroup_size_builtin_.component_type_id),
@ -1916,29 +1916,59 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
Fail() << "ID " << id << " is not a registered instruction"; Fail() << "ID " << id << " is not a registered instruction";
return {}; return {};
} }
auto source = GetSourceForInst(inst);
// TODO(dneto): Handle spec constants too?
auto* original_ast_type = ConvertType(inst->type_id()); auto* original_ast_type = ConvertType(inst->type_id());
if (original_ast_type == nullptr) { if (original_ast_type == nullptr) {
return {}; return {};
} }
if (inst->opcode() == SpvOpUndef) { switch (inst->opcode()) {
// Remap undef to null. case SpvOpUndef: // Remap undef to null.
return {original_ast_type, MakeNullValue(original_ast_type)}; case SpvOpConstantNull:
} return {original_ast_type, MakeNullValue(original_ast_type)};
case SpvOpConstantTrue:
case SpvOpConstantFalse:
case SpvOpConstant: {
const auto* spirv_const = constant_mgr_->FindDeclaredConstant(id);
if (spirv_const == nullptr) {
Fail() << "ID " << id << " is not a constant";
return {};
}
return MakeConstantExpressionForScalarSpirvConstant(
source, original_ast_type, spirv_const);
}
case SpvOpConstantComposite: {
// Handle vector, matrix, array, and struct
// TODO(dneto): Handle spec constants too? // Generate a composite from explicit components.
const auto* spirv_const = constant_mgr_->FindDeclaredConstant(id); ast::ExpressionList ast_components;
if (spirv_const == nullptr) { if (!inst->WhileEachInId([&](const uint32_t* id_ref) -> bool {
Fail() << "ID " << id << " is not a constant"; auto component = MakeConstantExpression(*id_ref);
return {}; if (!component) {
this->Fail() << "invalid constant with ID " << *id_ref;
return false;
}
ast_components.emplace_back(component.expr);
return true;
})) {
// We've already emitted a diagnostic.
return {};
}
return {original_ast_type, create<ast::TypeConstructorExpression>(
source, original_ast_type->Build(builder_),
std::move(ast_components))};
}
default:
break;
} }
Fail() << "unhandled constant instruction " << inst->PrettyPrint();
auto source = GetSourceForInst(inst); return {};
return MakeConstantExpressionForSpirvConstant(source, original_ast_type,
spirv_const);
} }
TypedExpression ParserImpl::MakeConstantExpressionForSpirvConstant( TypedExpression ParserImpl::MakeConstantExpressionForScalarSpirvConstant(
Source source, Source source,
const Type* original_ast_type, const Type* original_ast_type,
const spvtools::opt::analysis::Constant* spirv_const) { const spvtools::opt::analysis::Constant* spirv_const) {
@ -1970,37 +2000,7 @@ TypedExpression ParserImpl::MakeConstantExpressionForSpirvConstant(
return {ty_.Bool(), create<ast::ScalarConstructorExpression>( return {ty_.Bool(), create<ast::ScalarConstructorExpression>(
Source{}, create<ast::BoolLiteral>(source, value))}; Source{}, create<ast::BoolLiteral>(source, value))};
} }
auto* spirv_composite_const = spirv_const->AsCompositeConstant(); Fail() << "expected scalar constant";
if (spirv_composite_const != nullptr) {
// Handle vector, matrix, array, and struct
// TODO(dneto): Handle the spirv_composite_const->IsZero() case specially.
// See https://github.com/gpuweb/gpuweb/issues/685
// Generate a composite from explicit components.
ast::ExpressionList ast_components;
for (const auto* component : spirv_composite_const->GetComponents()) {
auto* def = constant_mgr_->GetDefiningInstruction(component);
if (def == nullptr) {
Fail() << "internal error: SPIR-V constant doesn't have defining "
"instruction";
return {};
}
auto ast_component = MakeConstantExpression(def->result_id());
if (!success_) {
// We've already emitted a diagnostic.
return {};
}
ast_components.emplace_back(ast_component.expr);
}
return {original_ast_type, create<ast::TypeConstructorExpression>(
Source{}, original_ast_type->Build(builder_),
std::move(ast_components))};
}
if (spirv_const->AsNullConstant()) {
return {original_ast_type, MakeNullValue(original_ast_type)};
}
Fail() << "Unhandled constant type ";
return {}; return {};
} }

View File

@ -443,12 +443,12 @@ class ParserImpl : Reader {
/// @returns a new expression /// @returns a new expression
TypedExpression MakeConstantExpression(uint32_t id); TypedExpression MakeConstantExpression(uint32_t id);
/// Creates an AST expression node for a SPIR-V constant. /// Creates an AST expression node for a scalar SPIR-V constant.
/// @param source the source location /// @param source the source location
/// @param ast_type the AST type for the value /// @param ast_type the AST type for the value
/// @param spirv_const the internal representation of the SPIR-V constant. /// @param spirv_const the internal representation of the SPIR-V constant.
/// @returns a new expression /// @returns a new expression
TypedExpression MakeConstantExpressionForSpirvConstant( TypedExpression MakeConstantExpressionForScalarSpirvConstant(
Source source, Source source,
const Type* ast_type, const Type* ast_type,
const spvtools::opt::analysis::Constant* spirv_const); const spvtools::opt::analysis::Constant* spirv_const);