diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index c1be6f18f5..ec9781e59a 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -2289,6 +2289,19 @@ bool FunctionEmitter::EmitFunctionVariables() { return success(); } +TypedExpression FunctionEmitter::AddressOfIfNeeded( + TypedExpression expr, + const spvtools::opt::Instruction* inst) { + if (inst && expr) { + if (auto* spirv_type = type_mgr_->GetType(inst->type_id())) { + if (expr.type->Is() && spirv_type->AsPointer()) { + return AddressOf(expr); + } + } + } + return expr; +} + TypedExpression FunctionEmitter::MakeExpression(uint32_t id) { if (failed()) { return {}; @@ -3227,11 +3240,7 @@ bool FunctionEmitter::EmitConstDefinition( if (!expr) { return false; } - if (expr.type->Is()) { - // `let` declarations cannot hold references, so we need to take the address - // of the RHS, and make the `let` be a pointer. - expr = AddressOf(expr); - } + expr = AddressOfIfNeeded(expr, &inst); auto* ast_const = parser_impl_.MakeVariable( inst.result_id(), ast::StorageClass::kNone, expr.type, true, expr.expr, ast::DecorationList{}); @@ -3246,6 +3255,11 @@ bool FunctionEmitter::EmitConstDefinition( bool FunctionEmitter::EmitConstDefOrWriteToHoistedVar( const spvtools::opt::Instruction& inst, TypedExpression expr) { + return WriteIfHoistedVar(inst, expr) || EmitConstDefinition(inst, expr); +} + +bool FunctionEmitter::WriteIfHoistedVar(const spvtools::opt::Instruction& inst, + TypedExpression expr) { const auto result_id = inst.result_id(); const auto* def_info = GetDefInfo(result_id); if (def_info && def_info->requires_hoisted_def) { @@ -3258,7 +3272,7 @@ bool FunctionEmitter::EmitConstDefOrWriteToHoistedVar( expr.expr)); return true; } - return EmitConstDefinition(inst, expr); + return false; } bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { @@ -3490,15 +3504,10 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { GetDefInfo(inst.result_id())->skip = skip; return true; } - auto expr = MakeExpression(value_id); + auto expr = AddressOfIfNeeded(MakeExpression(value_id), &inst); if (!expr) { return false; } - if (expr.type->Is()) { - // If the source is a reference, then we need to take the address of the - // expression. - expr = AddressOf(expr); - } expr.type = RemapStorageClass(expr.type, result_id); return EmitConstDefOrWriteToHoistedVar(inst, expr); } @@ -3546,7 +3555,7 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { TypedExpression FunctionEmitter::MakeOperand( const spvtools::opt::Instruction& inst, uint32_t operand_index) { - auto expr = this->MakeExpression(inst.GetSingleWordInOperand(operand_index)); + auto expr = MakeExpression(inst.GetSingleWordInOperand(operand_index)); if (!expr) { return {}; } @@ -4583,11 +4592,10 @@ bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) { if (!expr) { return false; } - if (expr.type->Is()) { - // Functions cannot use references as parameters, so we need to pass by - // pointer. - expr = AddressOf(expr); - } + // Functions cannot use references as parameters, so we need to pass by + // pointer if the operand is of pointer type. + expr = AddressOfIfNeeded( + expr, def_use_mgr_->GetDef(inst.GetSingleWordInOperand(iarg))); args.emplace_back(expr.expr); } if (failed()) { @@ -5435,6 +5443,9 @@ bool FunctionEmitter::MakeVectorInsertDynamic( // Then use result everywhere the original SPIR-V id is used. Using a const // like this avoids constantly reloading the value many times. + // TODO(dneto): crbug.com/tint/804: handle the case where %src_vector is + // has been hoisted into a variable. + auto* ast_type = parser_impl_.ConvertType(inst.type_id()); auto src_vector = MakeOperand(inst, 0); auto component = MakeOperand(inst, 1); @@ -5468,40 +5479,57 @@ bool FunctionEmitter::MakeCompositeInsert( const spvtools::opt::Instruction& inst) { // For // %result = OpCompositeInsert %type %object %composite 1 2 3 ... - // generate statements like this: + // there are two cases. + // + // Case 1: + // The %composite value has already been hoisted into a variable. + // In this case, assign %composite to that variable, then write the + // component into the right spot: + // + // hoisted = composite; + // hoisted[index].x = object; + // + // Case 2: + // The %composite value is not hoisted. In this case, make a temporary + // variable with the %composite contents, then write the component, + // and then make a let-declaration that reads the value out: // // var temp : type = composite; // temp[index].x = object; // let result : type = temp; // - // Then use result everywhere the original SPIR-V id is used. Using a const - // like this avoids constantly reloading the value many times. + // Then use result everywhere the original SPIR-V id is used. Using a const + // like this avoids constantly reloading the value many times. // - // This technique is a combination of: - // - making a temporary variable and constant declaration, like what we do - // for VectorInsertDynamic, and - // - building up an access-chain like access like for CompositeExtract, but - // on the left-hand side of the assignment. + // This technique is a combination of: + // - making a temporary variable and constant declaration, like what we do + // for VectorInsertDynamic, and + // - building up an access-chain like access like for CompositeExtract, but + // on the left-hand side of the assignment. - auto* ast_type = parser_impl_.ConvertType(inst.type_id()); + auto* type = parser_impl_.ConvertType(inst.type_id()); auto component = MakeOperand(inst, 0); auto src_composite = MakeOperand(inst, 1); - // Synthesize the temporary variable. - // It doesn't correspond to a SPIR-V ID, so we don't use the ordinary - // API in parser_impl_. - auto result_name = namer_.Name(inst.result_id()); - auto temp_name = namer_.MakeDerivedName(result_name); - auto registered_temp_name = builder_.Symbols().Register(temp_name); + std::string var_name; + auto original_value_name = namer_.Name(inst.result_id()); + const bool hoisted = WriteIfHoistedVar(inst, src_composite); + if (hoisted) { + // The variable was already declared in an earlier block. + var_name = original_value_name; + // Assign the source composite value to it. + builder_.Assign({}, builder_.Expr(var_name), src_composite.expr); + } else { + // Synthesize a temporary variable. + // It doesn't correspond to a SPIR-V ID, so we don't use the ordinary + // API in parser_impl_. + var_name = namer_.MakeDerivedName(original_value_name); + auto* temp_var = builder_.Var(var_name, type->Build(builder_), + ast::StorageClass::kNone, src_composite.expr); + AddStatement(builder_.Decl({}, temp_var)); + } - auto* temp_var = create( - Source{}, registered_temp_name, ast::StorageClass::kNone, - ast::Access::kUndefined, ast_type->Build(builder_), false, - src_composite.expr, ast::DecorationList{}); - AddStatement(create(Source{}, temp_var)); - - TypedExpression seed_expr{ast_type, create( - Source{}, registered_temp_name)}; + TypedExpression seed_expr{type, builder_.Expr(var_name)}; // The left-hand side of the assignment *looks* like a decomposition. TypedExpression lhs = @@ -5513,9 +5541,13 @@ bool FunctionEmitter::MakeCompositeInsert( AddStatement( create(Source{}, lhs.expr, component.expr)); - return EmitConstDefinition( - inst, - {ast_type, create(registered_temp_name)}); + if (hoisted) { + // The hoisted variable itself stands for this result ID. + return success(); + } + // Create a new let-declaration that is initialized by the contents + // of the temporary variable. + return EmitConstDefinition(inst, {type, builder_.Expr(var_name)}); } TypedExpression FunctionEmitter::AddressOf(TypedExpression expr) { diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index 72a357f5a2..ef920ba3ee 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -281,7 +281,7 @@ struct DefInfo { /// to that variable, and each SPIR-V use becomes a WGSL read from the /// variable. /// TODO(dneto): This works for constants of storable type, but not, for - /// example, pointers. + /// example, pointers. crbug.com/tint/98 bool requires_hoisted_def = false; /// If the definition is an OpPhi, then `phi_var` is the name of the @@ -682,7 +682,17 @@ class FunctionEmitter { bool EmitConstDefOrWriteToHoistedVar(const spvtools::opt::Instruction& inst, TypedExpression ast_expr); - /// Makes an expression + /// If the result ID of the given instruction is hoisted, then emits + /// a statement to write the expression to the hoisted variable, and + /// returns true. Otherwise return false. + /// @param inst the SPIR-V instruction defining a value. + /// @param ast_expr the expression to assign. + /// @returns true if the instruction has an associated hoisted variable. + bool WriteIfHoistedVar(const spvtools::opt::Instruction& inst, + TypedExpression ast_expr); + + /// Makes an expression from a SPIR-V ID. + /// if the SPIR-V result type is a pointer. /// @param id the SPIR-V ID of the value /// @returns true if emission has not yet failed. TypedExpression MakeExpression(uint32_t id); @@ -1122,6 +1132,14 @@ class FunctionEmitter { /// @note `expr` must be a reference type TypedExpression AddressOf(TypedExpression expr); + /// Returns AddressOf(expr) if expr is has reference type and + /// the instruction has a pointer result type. Otherwise returns expr. + /// @param expr the expression to take the address of + /// @returns a TypedExpression that is the address-of `expr` (`&expr`) + /// @note `expr` must be a reference type + TypedExpression AddressOfIfNeeded(TypedExpression expr, + const spvtools::opt::Instruction* inst); + /// @param expr the expression to dereference /// @returns a TypedExpression that is the dereference-of `expr` (`*expr`) /// @note `expr` must be a pointer type diff --git a/src/reader/spirv/function_call_test.cc b/src/reader/spirv/function_call_test.cc index 33e77cd30f..007707a144 100644 --- a/src/reader/spirv/function_call_test.cc +++ b/src/reader/spirv/function_call_test.cc @@ -162,8 +162,9 @@ TEST_F(SpvParserTest, EmitStatement_ScalarCallNoParamsUsedTwice) { { auto fe = p->function_emitter(100); EXPECT_TRUE(fe.EmitBody()) << p->error(); - EXPECT_THAT(ToString(p->builder(), fe.ast_body()), - HasSubstr(R"(VariableDeclStatement{ + const auto got = ToString(p->builder(), fe.ast_body()); + const std::string expected = + R"(VariableDeclStatement{ Variable{ x_10 none @@ -194,7 +195,9 @@ Assignment{ Identifier[not set]{x_10} Identifier[not set]{x_1} } -Return{})")); +Return{} +)"; + EXPECT_EQ(got, expected); } { auto fe = p->function_emitter(50); diff --git a/src/reader/spirv/function_composite_test.cc b/src/reader/spirv/function_composite_test.cc index 7621aa50a2..1f4acf3f00 100644 --- a/src/reader/spirv/function_composite_test.cc +++ b/src/reader/spirv/function_composite_test.cc @@ -607,8 +607,9 @@ TEST_F(SpvParserTest_CompositeInsert, Vector) { ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; auto fe = p->function_emitter(100); EXPECT_TRUE(fe.EmitBody()) << p->error(); - auto body_str = ToString(p->builder(), fe.ast_body()); - EXPECT_THAT(body_str, HasSubstr(R"(VariableDeclStatement{ + auto got = ToString(p->builder(), fe.ast_body()); + const auto* expected = + R"(VariableDeclStatement{ Variable{ x_1_1 none @@ -640,7 +641,10 @@ VariableDeclStatement{ Identifier[not set]{x_1_1} } } -})")) << body_str; +} +Return{} +)"; + EXPECT_EQ(got, expected); } TEST_F(SpvParserTest_CompositeInsert, Vector_IndexTooBigError) { diff --git a/src/reader/spirv/function_var_test.cc b/src/reader/spirv/function_var_test.cc index ca6f5b3124..36ba2d09b2 100644 --- a/src/reader/spirv/function_var_test.cc +++ b/src/reader/spirv/function_var_test.cc @@ -59,14 +59,18 @@ std::string CommonTypes() { %uint_0 = OpConstant %uint 0 %uint_1 = OpConstant %uint 1 %int_m1 = OpConstant %int -1 + %int_0 = OpConstant %int 0 %uint_2 = OpConstant %uint 2 %uint_3 = OpConstant %uint 3 %uint_4 = OpConstant %uint 4 %uint_5 = OpConstant %uint 5 + %v2int = OpTypeVector %int 2 %v2float = OpTypeVector %float 2 %m3v2float = OpTypeMatrix %v2float 3 + %v2int_null = OpConstantNull %v2int + %arr2uint = OpTypeArray %uint %uint_2 %strct = OpTypeStruct %uint %float %arr2uint )"; @@ -2282,7 +2286,7 @@ TEST_F(SpvParserFunctionVarTest, auto got = ToString(p->builder(), fe.ast_body()); auto* expect = R"(VariableDeclStatement{ Variable{ - x_35_phi + x_38_phi none undefined __u32 @@ -2308,14 +2312,14 @@ Switch{ Else{ { Assignment{ - Identifier[not set]{x_35_phi} + Identifier[not set]{x_38_phi} ScalarConstructor[not set]{0u} } Break{} } } Assignment{ - Identifier[not set]{x_35_phi} + Identifier[not set]{x_38_phi} ScalarConstructor[not set]{1u} } } @@ -2323,12 +2327,12 @@ Switch{ } VariableDeclStatement{ VariableConst{ - x_35 + x_38 none undefined __u32 { - Identifier[not set]{x_35_phi} + Identifier[not set]{x_38_phi} } } } @@ -2520,6 +2524,84 @@ Return{} EXPECT_EQ(got, expected); } +TEST_F(SpvParserFunctionVarTest, EmitStatement_Hoist_CompositeInsert) { + // From crbug.com/tint/804 + const auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpSelectionMerge %50 None + OpBranchConditional %true %20 %30 + + %20 = OpLabel + %200 = OpCompositeInsert %v2int %int_0 %v2int_null 0 + OpBranch %50 + + %30 = OpLabel + OpReturn + + %50 = OpLabel ; dominated by %20, but %200 needs to be hoisted + %201 = OpCopyObject %v2int %200 + OpReturn + OpFunctionEnd +)"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly; + auto fe = p->function_emitter(100); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + const auto* expected = R"(VariableDeclStatement{ + Variable{ + x_200 + none + undefined + __vec_2__i32 + } +} +If{ + ( + ScalarConstructor[not set]{true} + ) + { + Assignment{ + Identifier[not set]{x_200} + TypeConstructor[not set]{ + __vec_2__i32 + ScalarConstructor[not set]{0} + ScalarConstructor[not set]{0} + } + } + Assignment{ + MemberAccessor[not set]{ + Identifier[not set]{x_200} + Identifier[not set]{x} + } + ScalarConstructor[not set]{0} + } + } +} +Else{ + { + Return{} + } +} +VariableDeclStatement{ + VariableConst{ + x_201 + none + undefined + __vec_2__i32 + { + Identifier[not set]{x_200} + } + } +} +Return{} +)"; + const auto got = ToString(p->builder(), fe.ast_body()); + EXPECT_EQ(got, expected); +} + } // namespace } // namespace spirv } // namespace reader