From a2580d6720cfcfe2f9e46ec828ed95c1ce46b8e4 Mon Sep 17 00:00:00 2001 From: Antonio Maiorano Date: Thu, 6 May 2021 21:23:13 +0000 Subject: [PATCH] spirv reader: replace typ::Type with ast::Type Bug: tint:724 Change-Id: Idaf807dd1ff75af8e0044731e7362c0915ae7e54 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/50200 Reviewed-by: Ben Clayton Reviewed-by: David Neto Commit-Queue: Antonio Maiorano --- src/reader/spirv/function.cc | 223 ++++++------ src/reader/spirv/function.h | 11 +- src/reader/spirv/parser_impl.cc | 190 +++++----- src/reader/spirv/parser_impl.h | 69 ++-- .../spirv/parser_impl_convert_type_test.cc | 332 +++++++++--------- src/reader/spirv/parser_impl_test_helper.h | 2 +- src/typepair.h | 21 -- 7 files changed, 408 insertions(+), 440 deletions(-) diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index dd3974df84..523c370ddf 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -685,13 +685,6 @@ struct LoopStatementBuilder ast::BlockStatement* continuing = nullptr; }; -// Forwards UnwrapAll to both the ast and sem types of the TypePair -// @param tp the type pair -// @returns the unwrapped type pair -typ::Type UnwrapAll(typ::Type tp) { - return typ::Type{tp.ast->UnwrapAll(), tp.sem->UnwrapAll()}; -} - } // namespace BlockInfo::BlockInfo(const spvtools::opt::BasicBlock& bb) @@ -907,7 +900,7 @@ bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) { // Surprisingly, the "type id" on an OpFunction is the result type of the // function, not the type of the function. This is the one exceptional case // in SPIR-V where the type ID is not the type of the result ID. - auto ret_ty = parser_impl_.ConvertType(function_.type_id()); + auto* ret_ty = parser_impl_.ConvertType(function_.type_id()); if (failed()) { return false; } @@ -920,7 +913,7 @@ bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) { ast::VariableList ast_params; function_.ForEachParam( [this, &ast_params](const spvtools::opt::Instruction* param) { - auto ast_type = parser_impl_.ConvertType(param->type_id()); + auto* ast_type = parser_impl_.ConvertType(param->type_id()); if (ast_type != nullptr) { auto* ast_param = parser_impl_.MakeVariable( param->result_id(), ast::StorageClass::kNone, ast_type, true, @@ -950,7 +943,7 @@ bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) { return success(); } -typ::Type FunctionEmitter::GetVariableStoreType( +ast::Type* FunctionEmitter::GetVariableStoreType( const spvtools::opt::Instruction& var_decl_inst) { const auto type_id = var_decl_inst.type_id(); auto* var_ref_type = type_mgr_->GetType(type_id); @@ -2013,7 +2006,7 @@ bool FunctionEmitter::EmitFunctionVariables() { if (inst.opcode() != SpvOpVariable) { continue; } - auto var_store_type = GetVariableStoreType(inst); + auto* var_store_type = GetVariableStoreType(inst); if (failed()) { return false; } @@ -2049,7 +2042,7 @@ TypedExpression FunctionEmitter::MakeExpression(uint32_t id) { << id; return {}; case SkipReason::kPointSizeBuiltinValue: { - return {create(), + return {create(), create( Source{}, create(Source{}, 1.0f))}; } @@ -2660,7 +2653,7 @@ bool FunctionEmitter::EmitNormalTerminator(const BlockInfo& block_info) { if (result_type->AsVoid() != nullptr) { AddStatement(create(Source{})); } else { - auto ast_type = parser_impl_.ConvertType(function_.type_id()); + auto* ast_type = parser_impl_.ConvertType(function_.type_id()); AddStatement(create( Source{}, parser_impl_.MakeNullValue(ast_type))); } @@ -2905,7 +2898,7 @@ bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info, for (auto id : sorted_by_index(block_info.hoisted_ids)) { const auto* def_inst = def_use_mgr_->GetDef(id); TINT_ASSERT(def_inst); - auto ast_type = + auto* ast_type = RemapStorageClass(parser_impl_.ConvertType(def_inst->type_id()), id); AddStatement(create( Source{}, @@ -3109,7 +3102,7 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { case SkipReason::kSampleMaskOutBuiltinPointer: ptr_id = sample_mask_out_id; - if (rhs.type != builder_.ty.u32()) { + if (!rhs.type->Is()) { // WGSL requires sample_mask_out to be signed. rhs = TypedExpression{builder_.ty.u32(), create( @@ -3164,12 +3157,12 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { auto name = namer_.Name(sample_mask_in_id); ast::Expression* id_expr = create( Source{}, builder_.Symbols().Register(name)); - auto load_result_type = parser_impl_.ConvertType(inst.type_id()); + auto* load_result_type = parser_impl_.ConvertType(inst.type_id()); ast::Expression* ast_expr = nullptr; - if (load_result_type == builder_.ty.i32()) { + if (load_result_type->Is()) { ast_expr = create( Source{}, builder_.ty.i32(), ast::ExpressionList{id_expr}); - } else if (load_result_type == builder_.ty.u32()) { + } else if (load_result_type->Is()) { ast_expr = id_expr; } else { return Fail() << "loading the whole SampleMask input array is not " @@ -3188,8 +3181,8 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { } // The load result type is the pointee type of its operand. - TINT_ASSERT(expr.type.ast->Is()); - expr.type = typ::Call_type(typ::As(expr.type)); + TINT_ASSERT(expr.type->Is()); + expr.type = expr.type->As()->type(); return EmitConstDefOrWriteToHoistedVar(inst, expr); } @@ -3204,7 +3197,7 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { return true; } auto expr = MakeExpression(value_id); - if (!expr.type.ast || !expr.expr) { + if (!expr.type || !expr.expr) { return false; } expr.type = RemapStorageClass(expr.type, result_id); @@ -3291,7 +3284,7 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue( const auto opcode = inst.opcode(); - typ::Type ast_type = + ast::Type* ast_type = inst.type_id() != 0 ? parser_impl_.ConvertType(inst.type_id()) : nullptr; auto binary_op = ConvertBinaryOp(opcode); @@ -3464,7 +3457,7 @@ TypedExpression FunctionEmitter::EmitGlslStd450ExtInst( auto* func = create( Source{}, builder_.Symbols().Register(name)); ast::ExpressionList operands; - typ::Type first_operand_type = nullptr; + ast::Type* first_operand_type = nullptr; // All parameters to GLSL.std.450 extended instructions are IDs. for (uint32_t iarg = 2; iarg < inst.NumInOperands(); ++iarg) { TypedExpression operand = MakeOperand(inst, iarg); @@ -3473,7 +3466,7 @@ TypedExpression FunctionEmitter::EmitGlslStd450ExtInst( } operands.emplace_back(operand.expr); } - auto ast_type = parser_impl_.ConvertType(inst.type_id()); + auto* ast_type = parser_impl_.ConvertType(inst.type_id()); auto* call = create(Source{}, func, std::move(operands)); TypedExpression call_expr{ast_type, call}; return parser_impl_.RectifyForcedResultType(call_expr, inst, @@ -3708,9 +3701,9 @@ TypedExpression FunctionEmitter::MakeAccessChain( } const auto pointer_type_id = type_mgr_->FindPointerToType(pointee_type_id, storage_class); - auto ast_pointer_type = parser_impl_.ConvertType(pointer_type_id); - TINT_ASSERT(ast_pointer_type.ast); - TINT_ASSERT(ast_pointer_type.ast->Is()); + auto* ast_pointer_type = parser_impl_.ConvertType(pointer_type_id); + TINT_ASSERT(ast_pointer_type); + TINT_ASSERT(ast_pointer_type->Is()); current_expr = TypedExpression{ast_pointer_type, next_expr}; } return current_expr; @@ -3894,8 +3887,8 @@ TypedExpression FunctionEmitter::MakeVectorShuffle( // Generate an ast::TypeConstructor expression. // Assume the literal indices are valid, and there is a valid number of them. auto source = GetSourceForInst(inst); - typ::Vector result_type = - typ::As(parser_impl_.ConvertType(inst.type_id())); + ast::Vector* result_type = + parser_impl_.ConvertType(inst.type_id())->As(); ast::ExpressionList values; for (uint32_t i = 2; i < inst.NumInOperands(); ++i) { const auto index = inst.GetSingleWordInOperand(i); @@ -3917,8 +3910,7 @@ TypedExpression FunctionEmitter::MakeVectorShuffle( source, expr.expr, Swizzle(sub_index))); } else if (index == 0xFFFFFFFF) { // By rule, this maps to OpUndef. Instead, make it zero. - values.emplace_back( - parser_impl_.MakeNullValue(typ::Call_type(result_type))); + values.emplace_back(parser_impl_.MakeNullValue(result_type->type())); } else { Fail() << "invalid vectorshuffle ID %" << inst.result_id() << ": index too large: " << index; @@ -3995,8 +3987,8 @@ bool FunctionEmitter::RegisterLocallyDefinedValues() { const auto* type = type_mgr_->GetType(inst.type_id()); if (type) { if (type->AsPointer()) { - if (auto ast_type = parser_impl_.ConvertType(inst.type_id())) { - if (auto* ptr = ast_type.ast->As()) { + if (auto* ast_type = parser_impl_.ConvertType(inst.type_id())) { + if (auto* ptr = ast_type->As()) { info->storage_class = ptr->storage_class(); } } @@ -4040,22 +4032,22 @@ ast::StorageClass FunctionEmitter::GetStorageClassForPointerValue(uint32_t id) { } const auto type_id = def_use_mgr_->GetDef(id)->type_id(); if (type_id) { - auto ast_type = parser_impl_.ConvertType(type_id); - if (auto ptr = typ::As(ast_type)) { - return ptr.ast->storage_class(); + auto* ast_type = parser_impl_.ConvertType(type_id); + if (auto* ptr = ast_type->As()) { + return ptr->storage_class(); } } return ast::StorageClass::kNone; } -typ::Type FunctionEmitter::RemapStorageClass(typ::Type type, - uint32_t result_id) { - if (auto ast_ptr_type = typ::As(type)) { +ast::Type* FunctionEmitter::RemapStorageClass(ast::Type* type, + uint32_t result_id) { + if (auto* ast_ptr_type = type->As()) { // Remap an old-style storage buffer pointer to a new-style storage // buffer pointer. const auto sc = GetStorageClassForPointerValue(result_id); - if (ast_ptr_type.ast->storage_class() != sc) { - return builder_.ty.pointer(typ::Call_type(ast_ptr_type), sc); + if (ast_ptr_type->storage_class() != sc) { + return builder_.ty.pointer(ast_ptr_type->type(), sc); } } return type; @@ -4232,13 +4224,13 @@ const Construct* FunctionEmitter::GetEnclosingScope(uint32_t first_pos, TypedExpression FunctionEmitter::MakeNumericConversion( const spvtools::opt::Instruction& inst) { const auto opcode = inst.opcode(); - auto requested_type = parser_impl_.ConvertType(inst.type_id()); + auto* requested_type = parser_impl_.ConvertType(inst.type_id()); auto arg_expr = MakeOperand(inst, 0); if (!arg_expr.expr || !arg_expr.type) { return {}; } - typ::Type expr_type = nullptr; + ast::Type* expr_type = nullptr; if ((opcode == SpvOpConvertSToF) || (opcode == SpvOpConvertUToF)) { if (arg_expr.type->is_integer_scalar_or_vector()) { expr_type = requested_type; @@ -4276,7 +4268,7 @@ TypedExpression FunctionEmitter::MakeNumericConversion( Source{}, builder_.ty.MaybeCreateTypename(expr_type), std::move(params))}; - if (requested_type == expr_type) { + if (AstTypesEquivalent(requested_type, expr_type)) { return result; } return {requested_type, create( @@ -4298,13 +4290,13 @@ bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) { } auto* call_expr = create(Source{}, function, std::move(params)); - auto result_type = parser_impl_.ConvertType(inst.type_id()); - if (!result_type.ast) { + auto* result_type = parser_impl_.ConvertType(inst.type_id()); + if (!result_type) { return Fail() << "internal error: no mapped type result of call: " << inst.PrettyPrint(); } - if (result_type.ast->Is()) { + if (result_type->Is()) { return nullptr != AddStatement(create(Source{}, call_expr)); } @@ -4367,7 +4359,7 @@ TypedExpression FunctionEmitter::MakeIntrinsicCall( Source{}, builder_.Symbols().Register(name)); ast::ExpressionList params; - typ::Type first_operand_type = nullptr; + ast::Type* first_operand_type = nullptr; for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) { TypedExpression operand = MakeOperand(inst, iarg); if (first_operand_type == nullptr) { @@ -4377,8 +4369,8 @@ TypedExpression FunctionEmitter::MakeIntrinsicCall( } auto* call_expr = create(Source{}, ident, std::move(params)); - auto result_type = parser_impl_.ConvertType(inst.type_id()); - if (!result_type.ast) { + auto* result_type = parser_impl_.ConvertType(inst.type_id()); + if (!result_type) { Fail() << "internal error: no mapped type result of call: " << inst.PrettyPrint(); return {}; @@ -4398,7 +4390,7 @@ TypedExpression FunctionEmitter::MakeSimpleSelect( // - operand1, operand2, and result type to match. // - you can't select over pointers or pointer vectors, unless you also have // a VariablePointers* capability, which is not allowed in by WebGPU. - auto* op_ty = operand1.type.ast; + auto* op_ty = operand1.type; if (op_ty->Is() || op_ty->is_float_scalar() || op_ty->is_integer_scalar() || op_ty->Is()) { ast::ExpressionList params; @@ -4438,9 +4430,9 @@ const spvtools::opt::Instruction* FunctionEmitter::GetImage( return image; } -typ::Texture FunctionEmitter::GetImageType( +ast::Texture* FunctionEmitter::GetImageType( const spvtools::opt::Instruction& image) { - typ::Pointer ptr_type = parser_impl_.GetTypeForHandleVar(image); + ast::Pointer* ptr_type = parser_impl_.GetTypeForHandleVar(image); if (!parser_impl_.success()) { Fail(); return {}; @@ -4449,7 +4441,7 @@ typ::Texture FunctionEmitter::GetImageType( Fail() << "invalid texture type for " << image.PrettyPrint(); return {}; } - auto result = typ::As(UnwrapAll(typ::Call_type(ptr_type))); + auto* result = ptr_type->type()->UnwrapAll()->As(); if (!result) { Fail() << "invalid texture type for " << image.PrettyPrint(); return {}; @@ -4504,14 +4496,14 @@ bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) { } } - typ::Pointer texture_ptr_type = parser_impl_.GetTypeForHandleVar(*image); - if (!texture_ptr_type.ast) { + ast::Pointer* texture_ptr_type = parser_impl_.GetTypeForHandleVar(*image); + if (!texture_ptr_type) { return Fail(); } - typ::Texture texture_type = - typ::As(UnwrapAll(typ::Call_type(texture_ptr_type))); + ast::Texture* texture_type = + texture_ptr_type->type()->UnwrapAll()->As(); - if (!texture_type.ast) { + if (!texture_type) { return Fail(); } @@ -4612,7 +4604,7 @@ bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) { } TypedExpression lod = MakeOperand(inst, arg_index); // When sampling from a depth texture, the Lod operand must be an I32. - if (texture_type.ast->Is()) { + if (texture_type->Is()) { // Convert it to a signed integer type. lod = ToI32(lod); } @@ -4620,8 +4612,8 @@ bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) { image_operands_mask ^= SpvImageOperandsLodMask; arg_index++; } else if ((opcode == SpvOpImageFetch) && - (texture_type.ast->Is() || - texture_type.ast->Is())) { + (texture_type->Is() || + texture_type->Is())) { // textureLoad on sampled texture and depth texture requires an explicit // level-of-detail parameter. params.push_back(parser_impl_.MakeNullValue(builder_.ty.i32())); @@ -4682,10 +4674,10 @@ bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) { ast::Expression* value = call_expr; // The result type, derived from the SPIR-V instruction. - auto result_type = parser_impl_.ConvertType(inst.type_id()); - auto result_component_type = result_type; - if (auto result_vector_type = typ::As(result_type)) { - result_component_type = typ::Call_type(result_vector_type); + auto* result_type = parser_impl_.ConvertType(inst.type_id()); + auto* result_component_type = result_type; + if (auto* result_vector_type = result_type->As()) { + result_component_type = result_vector_type->type(); } // For depth textures, the arity might mot match WGSL: @@ -4699,7 +4691,7 @@ bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) { // dref gather vec4 ImageFetch vec4 TODO(dneto) // Construct a 4-element vector with the result from the builtin in the // first component. - if (texture_type.ast->Is()) { + if (texture_type->Is()) { if (is_non_dref_sample || (opcode == SpvOpImageFetch)) { value = create( Source{}, @@ -4720,14 +4712,14 @@ bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) { return Fail() << "invalid image type for image memory object declaration " << image->PrettyPrint(); } - auto expected_component_type = + auto* expected_component_type = parser_impl_.ConvertType(spirv_image_type->GetSingleWordInOperand(0)); - if (expected_component_type != result_component_type) { + if (!AstTypesEquivalent(expected_component_type, result_component_type)) { // This occurs if one is signed integer and the other is unsigned integer, // or vice versa. Perform a bitcast. value = create(Source{}, result_type, call_expr); } - if (!expected_component_type.ast->Is() && + if (!expected_component_type->Is() && IsSampledImageAccess(opcode)) { // WGSL permits sampled image access only on float textures. // Reject this case in the SPIR-V reader, at least until SPIR-V validation @@ -4750,7 +4742,7 @@ bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) { if (!image) { return false; } - auto texture_type = GetImageType(*image); + auto* texture_type = GetImageType(*image); if (!texture_type) { return false; } @@ -4778,7 +4770,7 @@ bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) { Source{}, layers_ident, ast::ExpressionList{GetImageExpression(inst)})); } - auto result_type = parser_impl_.ConvertType(inst.type_id()); + auto* result_type = parser_impl_.ConvertType(inst.type_id()); TypedExpression expr = { result_type, create( @@ -4799,10 +4791,10 @@ bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) { ast::Expression* ast_expr = create( Source{}, levels_ident, ast::ExpressionList{GetImageExpression(inst)}); - auto result_type = parser_impl_.ConvertType(inst.type_id()); + auto* result_type = parser_impl_.ConvertType(inst.type_id()); // The SPIR-V result type must be integer scalar. The WGSL bulitin // returns i32. If they aren't the same then convert the result. - if (result_type != builder_.ty.i32()) { + if (!result_type->Is()) { ast_expr = create( Source{}, builder_.ty.MaybeCreateTypename(result_type), ast::ExpressionList{ast_expr}); @@ -4848,7 +4840,7 @@ ast::ExpressionList FunctionEmitter::MakeCoordinateOperandsForImageAccess( if (!raw_coords.type) { return {}; } - typ::Texture texture_type = GetImageType(*image); + ast::Texture* texture_type = GetImageType(*image); if (!texture_type) { return {}; } @@ -4863,12 +4855,12 @@ ast::ExpressionList FunctionEmitter::MakeCoordinateOperandsForImageAccess( } const auto num_coords_required = num_axes + (is_arrayed ? 1 : 0); uint32_t num_coords_supplied = 0; - auto component_type = raw_coords.type; + auto* component_type = raw_coords.type; if (component_type->is_float_scalar() || component_type->is_integer_scalar()) { num_coords_supplied = 1; - } else if (auto vec_type = typ::As(raw_coords.type)) { - component_type = typ::Call_type(vec_type); + } else if (auto* vec_type = raw_coords.type->As()) { + component_type = vec_type->type(); num_coords_supplied = vec_type->size(); } if (num_coords_supplied == 0) { @@ -4892,9 +4884,10 @@ ast::ExpressionList FunctionEmitter::MakeCoordinateOperandsForImageAccess( // will actually use them. auto prefix_swizzle_expr = [this, num_axes, component_type, raw_coords]() -> ast::Expression* { - auto swizzle_type = - (num_axes == 1) ? component_type - : typ::Type{builder_.ty.vec(component_type, num_axes)}; + auto* swizzle_type = (num_axes == 1) + ? component_type + : static_cast( + builder_.ty.vec(component_type, num_axes)); auto* swizzle = create( Source{}, raw_coords.expr, PrefixSwizzle(num_axes)); return ToSignedIfUnsigned({swizzle_type, swizzle}).expr; @@ -4928,32 +4921,32 @@ ast::ExpressionList FunctionEmitter::MakeCoordinateOperandsForImageAccess( ast::Expression* FunctionEmitter::ConvertTexelForStorage( const spvtools::opt::Instruction& inst, TypedExpression texel, - typ::Texture texture_type) { - auto storage_texture_type = typ::As(texture_type); - auto src_type = texel.type; - if (!storage_texture_type.ast) { + ast::Texture* texture_type) { + auto* storage_texture_type = texture_type->As(); + auto* src_type = texel.type; + if (!storage_texture_type) { Fail() << "writing to other than storage texture: " << inst.PrettyPrint(); return nullptr; } - const auto format = storage_texture_type.ast->image_format(); - auto dest_type = parser_impl_.GetTexelTypeForFormat(format); - if (!dest_type.ast) { + const auto format = storage_texture_type->image_format(); + auto* dest_type = parser_impl_.GetTexelTypeForFormat(format); + if (!dest_type) { Fail(); return nullptr; } - if (src_type == dest_type) { + if (AstTypesEquivalent(src_type, dest_type)) { return texel.expr; } const uint32_t dest_count = - dest_type.ast->is_scalar() ? 1 : dest_type.ast->As()->size(); + dest_type->is_scalar() ? 1 : dest_type->As()->size(); if (dest_count == 3) { Fail() << "3-channel storage textures are not supported: " << inst.PrettyPrint(); return nullptr; } const uint32_t src_count = - src_type.ast->is_scalar() ? 1 : src_type.ast->As()->size(); + src_type->is_scalar() ? 1 : src_type->As()->size(); if (src_count < dest_count) { Fail() << "texel has too few components for storage texture: " << src_count << " provided but " << dest_count @@ -4968,29 +4961,29 @@ ast::Expression* FunctionEmitter::ConvertTexelForStorage( : create(Source{}, texel.expr, PrefixSwizzle(dest_count)); - if (!(dest_type.ast->is_float_scalar_or_vector() || - dest_type.ast->is_unsigned_scalar_or_vector() || - dest_type.ast->is_signed_scalar_or_vector())) { + if (!(dest_type->is_float_scalar_or_vector() || + dest_type->is_unsigned_scalar_or_vector() || + dest_type->is_signed_scalar_or_vector())) { Fail() << "invalid destination type for storage texture write: " - << dest_type.ast->type_name(); + << dest_type->type_name(); return nullptr; } - if (!(src_type.ast->is_float_scalar_or_vector() || - src_type.ast->is_unsigned_scalar_or_vector() || - src_type.ast->is_signed_scalar_or_vector())) { + if (!(src_type->is_float_scalar_or_vector() || + src_type->is_unsigned_scalar_or_vector() || + src_type->is_signed_scalar_or_vector())) { Fail() << "invalid texel type for storage texture write: " << inst.PrettyPrint(); return nullptr; } - if (dest_type.ast->is_float_scalar_or_vector() && - !src_type.ast->is_float_scalar_or_vector()) { + if (dest_type->is_float_scalar_or_vector() && + !src_type->is_float_scalar_or_vector()) { Fail() << "can only write float or float vector to a storage image with " "floating texel format: " << inst.PrettyPrint(); return nullptr; } - if (!dest_type.ast->is_float_scalar_or_vector() && - src_type.ast->is_float_scalar_or_vector()) { + if (!dest_type->is_float_scalar_or_vector() && + src_type->is_float_scalar_or_vector()) { Fail() << "float or float vector can only be written to a storage image with " "floating texel format: " @@ -4998,13 +4991,13 @@ ast::Expression* FunctionEmitter::ConvertTexelForStorage( return nullptr; } - if (dest_type.ast->is_float_scalar_or_vector()) { + if (dest_type->is_float_scalar_or_vector()) { return texel_prefix; } // The only remaining cases are signed/unsigned source, and signed/unsigned // destination. - if (dest_type.ast->is_unsigned_scalar_or_vector() == - src_type.ast->is_unsigned_scalar_or_vector()) { + if (dest_type->is_unsigned_scalar_or_vector() == + src_type->is_unsigned_scalar_or_vector()) { return texel_prefix; } // We must do a bitcast conversion. @@ -5012,7 +5005,7 @@ ast::Expression* FunctionEmitter::ConvertTexelForStorage( } TypedExpression FunctionEmitter::ToI32(TypedExpression value) { - if (!value.type || value.type == builder_.ty.i32()) { + if (!value.type || value.type->Is()) { return value; } return {builder_.ty.i32(), @@ -5024,7 +5017,7 @@ TypedExpression FunctionEmitter::ToSignedIfUnsigned(TypedExpression value) { if (!value.type || !value.type->is_unsigned_scalar_or_vector()) { return value; } - if (auto* vec_type = value.type.ast->As()) { + if (auto* vec_type = value.type->As()) { auto new_type = builder_.ty.vec(builder_.ty.i32(), vec_type->size()); return {new_type, builder_.Construct(new_type, ast::ExpressionList{value.expr})}; @@ -5080,12 +5073,12 @@ TypedExpression FunctionEmitter::MakeOuterProduct( // Synthesize the result. auto col = MakeOperand(inst, 0); auto row = MakeOperand(inst, 1); - auto col_ty = typ::As(col.type); - auto row_ty = typ::As(row.type); - auto result_ty = - typ::As(parser_impl_.ConvertType(inst.type_id())); - if (!col_ty || !col_ty || !result_ty || result_ty->type() != col_ty->type() || - result_ty->type() != row_ty->type() || + auto* col_ty = col.type->As(); + auto* row_ty = row.type->As(); + auto* result_ty = parser_impl_.ConvertType(inst.type_id())->As(); + if (!col_ty || !col_ty || !result_ty || + !AstTypesEquivalent(result_ty->type(), col_ty->type()) || + !AstTypesEquivalent(result_ty->type(), row_ty->type()) || result_ty->columns() != row_ty->size() || result_ty->rows() != col_ty->size()) { Fail() << "invalid outer product instruction: bad types " @@ -5135,7 +5128,7 @@ bool FunctionEmitter::MakeVectorInsertDynamic( // Then use result everywhere the original SPIR-V id is used. Using a const // like this avoids constantly reloading the value many times. - auto ast_type = parser_impl_.ConvertType(inst.type_id()); + auto* ast_type = parser_impl_.ConvertType(inst.type_id()); auto src_vector = MakeOperand(inst, 0); auto component = MakeOperand(inst, 1); auto index = MakeOperand(inst, 2); @@ -5183,7 +5176,7 @@ bool FunctionEmitter::MakeCompositeInsert( // - building up an access-chain like access like for CompositeExtract, but // on the left-hand side of the assignment. - auto ast_type = parser_impl_.ConvertType(inst.type_id()); + auto* ast_type = parser_impl_.ConvertType(inst.type_id()); auto component = MakeOperand(inst, 0); auto src_composite = MakeOperand(inst, 1); diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index bb3a138201..1f41bd1d69 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -25,7 +25,6 @@ #include "src/program_builder.h" #include "src/reader/spirv/construct.h" #include "src/reader/spirv/parser_impl.h" -#include "src/typepair.h" namespace tint { namespace reader { @@ -516,7 +515,7 @@ class FunctionEmitter { /// @param type the AST type /// @param result_id the SPIR-V ID for the locally defined value /// @returns an possibly updated type - typ::Type RemapStorageClass(typ::Type type, uint32_t result_id); + ast::Type* RemapStorageClass(ast::Type* type, uint32_t result_id); /// Marks locally defined values when they should get a 'const' /// definition in WGSL, or a 'var' definition at an outer scope. @@ -857,7 +856,7 @@ class FunctionEmitter { /// Function parameters ast::VariableList params; /// Function return type - typ::Type return_type; + ast::Type* return_type; /// Function decorations ast::DecorationList decorations; }; @@ -870,7 +869,7 @@ class FunctionEmitter { /// @returns the store type for the OpVariable instruction, or /// null on failure. - typ::Type GetVariableStoreType( + ast::Type* GetVariableStoreType( const spvtools::opt::Instruction& var_decl_inst); /// Returns an expression for an instruction operand. Signedness conversion is @@ -938,7 +937,7 @@ class FunctionEmitter { /// Get the AST texture the SPIR-V image memory object declaration. /// @param inst the SPIR-V memory object declaration for the image. /// @returns a texture type, or null on error - typ::Texture GetImageType(const spvtools::opt::Instruction& inst); + ast::Texture* GetImageType(const spvtools::opt::Instruction& inst); /// Get the expression for the image operand from the first operand to the /// given instruction. @@ -975,7 +974,7 @@ class FunctionEmitter { ast::Expression* ConvertTexelForStorage( const spvtools::opt::Instruction& inst, TypedExpression texel, - typ::Texture texture_type); + ast::Texture* texture_type); /// Returns an expression for an OpSelect, if its operands are scalars /// or vectors. These translate directly to WGSL select. Otherwise, return diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index 78eb256b6b..e1f6484c0c 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -232,14 +232,6 @@ bool AssumesResultSignednessMatchesFirstOperand(GLSLstd450 extended_opcode) { return false; } -// Forwards UnwrapIfNeeded to both the ast and sem types of the TypePair -// @param tp the type pair -// @returns the unwrapped type pair -typ::Type UnwrapIfNeeded(typ::Type tp) { - return typ::Type{tp.ast ? tp.ast->UnwrapIfNeeded() : nullptr, - tp.sem ? tp.sem->UnwrapIfNeeded() : nullptr}; -} - } // namespace TypedExpression::TypedExpression() = default; @@ -248,7 +240,7 @@ TypedExpression::TypedExpression(const TypedExpression&) = default; TypedExpression& TypedExpression::operator=(const TypedExpression&) = default; -TypedExpression::TypedExpression(typ::Type type_in, ast::Expression* expr_in) +TypedExpression::TypedExpression(ast::Type* type_in, ast::Expression* expr_in) : type(type_in), expr(expr_in) {} ParserImpl::ParserImpl(const std::vector& spv_binary) @@ -313,7 +305,7 @@ Program ParserImpl::program() { return tint::Program(std::move(builder_)); } -typ::Type ParserImpl::ConvertType(uint32_t type_id) { +ast::Type* ParserImpl::ConvertType(uint32_t type_id) { if (!success_) { return nullptr; } @@ -330,8 +322,8 @@ typ::Type ParserImpl::ConvertType(uint32_t type_id) { } auto maybe_generate_alias = [this, type_id, - spirv_type](typ::Type type) -> typ::Type { - if (type.ast != nullptr) { + spirv_type](ast::Type* type) -> ast::Type* { + if (type != nullptr) { return MaybeGenerateAlias(type_id, spirv_type, type); } return {}; @@ -782,17 +774,17 @@ bool ParserImpl::RegisterEntryPoints() { return success_; } -typ::Type ParserImpl::ConvertType( +ast::Type* ParserImpl::ConvertType( const spvtools::opt::analysis::Integer* int_ty) { if (int_ty->width() == 32) { - return int_ty->IsSigned() ? typ::Type{builder_.ty.i32()} - : typ::Type{builder_.ty.u32()}; + return int_ty->IsSigned() ? static_cast(builder_.ty.i32()) + : static_cast(builder_.ty.u32()); } Fail() << "unhandled integer width: " << int_ty->width(); return nullptr; } -typ::Type ParserImpl::ConvertType( +ast::Type* ParserImpl::ConvertType( const spvtools::opt::analysis::Float* float_ty) { if (float_ty->width() == 32) { return builder_.ty.f32(); @@ -801,33 +793,33 @@ typ::Type ParserImpl::ConvertType( return nullptr; } -typ::Type ParserImpl::ConvertType( +ast::Type* ParserImpl::ConvertType( const spvtools::opt::analysis::Vector* vec_ty) { const auto num_elem = vec_ty->element_count(); - auto ast_elem_ty = ConvertType(type_mgr_->GetId(vec_ty->element_type())); - if (ast_elem_ty.ast == nullptr) { + auto* ast_elem_ty = ConvertType(type_mgr_->GetId(vec_ty->element_type())); + if (ast_elem_ty == nullptr) { return nullptr; } return builder_.ty.vec(ast_elem_ty, num_elem); } -typ::Type ParserImpl::ConvertType( +ast::Type* ParserImpl::ConvertType( const spvtools::opt::analysis::Matrix* mat_ty) { const auto* vec_ty = mat_ty->element_type()->AsVector(); const auto* scalar_ty = vec_ty->element_type(); const auto num_rows = vec_ty->element_count(); const auto num_columns = mat_ty->element_count(); - auto ast_scalar_ty = ConvertType(type_mgr_->GetId(scalar_ty)); - if (ast_scalar_ty.ast == nullptr) { + auto* ast_scalar_ty = ConvertType(type_mgr_->GetId(scalar_ty)); + if (ast_scalar_ty == nullptr) { return nullptr; } return builder_.ty.mat(ast_scalar_ty, num_columns, num_rows); } -typ::Type ParserImpl::ConvertType( +ast::Type* ParserImpl::ConvertType( const spvtools::opt::analysis::RuntimeArray* rtarr_ty) { - auto ast_elem_ty = ConvertType(type_mgr_->GetId(rtarr_ty->element_type())); - if (ast_elem_ty.ast == nullptr) { + auto* ast_elem_ty = ConvertType(type_mgr_->GetId(rtarr_ty->element_type())); + if (ast_elem_ty == nullptr) { return nullptr; } ast::DecorationList decorations; @@ -837,11 +829,11 @@ typ::Type ParserImpl::ConvertType( return builder_.ty.array(ast_elem_ty, 0, std::move(decorations)); } -typ::Type ParserImpl::ConvertType( +ast::Type* ParserImpl::ConvertType( const spvtools::opt::analysis::Array* arr_ty) { const auto elem_type_id = type_mgr_->GetId(arr_ty->element_type()); - auto ast_elem_ty = ConvertType(elem_type_id); - if (ast_elem_ty.ast == nullptr) { + auto* ast_elem_ty = ConvertType(elem_type_id); + if (ast_elem_ty == nullptr) { return nullptr; } const auto& length_info = arr_ty->length_info(); @@ -912,7 +904,7 @@ bool ParserImpl::ParseArrayDecorations( return true; } -typ::Type ParserImpl::ConvertType( +ast::Type* ParserImpl::ConvertType( uint32_t type_id, const spvtools::opt::analysis::Struct* struct_ty) { // Compute the struct decoration. @@ -942,8 +934,8 @@ typ::Type ParserImpl::ConvertType( for (uint32_t member_index = 0; member_index < members.size(); ++member_index) { const auto member_type_id = type_mgr_->GetId(members[member_index]); - auto ast_member_ty = ConvertType(member_type_id); - if (ast_member_ty.ast == nullptr) { + auto* ast_member_ty = ConvertType(member_type_id); + if (ast_member_ty == nullptr) { // Already emitted diagnostics. return nullptr; } @@ -1034,12 +1026,11 @@ typ::Type ParserImpl::ConvertType( } auto* ast_struct = create(Source{}, sym, std::move(ast_members), std::move(ast_struct_decorations)); - auto result = builder_.ty.struct_(ast_struct); if (num_non_writable_members == members.size()) { - read_only_struct_types_.insert(result.ast->name()); + read_only_struct_types_.insert(ast_struct->name()); } - AddConstructedType(sym, result); - return result; + AddConstructedType(sym, ast_struct); + return ast_struct; } void ParserImpl::AddConstructedType(Symbol name, ast::NamedType* type) { @@ -1049,8 +1040,8 @@ void ParserImpl::AddConstructedType(Symbol name, ast::NamedType* type) { } } -typ::Type ParserImpl::ConvertType(uint32_t type_id, - const spvtools::opt::analysis::Pointer*) { +ast::Type* ParserImpl::ConvertType(uint32_t type_id, + const spvtools::opt::analysis::Pointer*) { const auto* inst = def_use_mgr_->GetDef(type_id); const auto pointee_type_id = inst->GetSingleWordInOperand(1); const auto storage_class = SpvStorageClass(inst->GetSingleWordInOperand(0)); @@ -1060,8 +1051,8 @@ typ::Type ParserImpl::ConvertType(uint32_t type_id, builtin_position_.storage_class = storage_class; return nullptr; } - auto ast_elem_ty = ConvertType(pointee_type_id); - if (ast_elem_ty.ast == nullptr) { + auto* ast_elem_ty = ConvertType(pointee_type_id); + if (ast_elem_ty == nullptr) { Fail() << "SPIR-V pointer type with ID " << type_id << " has invalid pointee type " << pointee_type_id; return nullptr; @@ -1123,7 +1114,7 @@ bool ParserImpl::EmitScalarSpecConstants() { // that is OpSpecConstantTrue, OpSpecConstantFalse, or OpSpecConstant. for (auto& inst : module_->types_values()) { // These will be populated for a valid scalar spec constant. - typ::Type ast_type; + ast::Type* ast_type = nullptr; ast::ScalarConstructorExpression* ast_expr = nullptr; switch (inst.opcode()) { @@ -1138,15 +1129,15 @@ bool ParserImpl::EmitScalarSpecConstants() { case SpvOpSpecConstant: { ast_type = ConvertType(inst.type_id()); const uint32_t literal_value = inst.GetSingleWordInOperand(0); - if (ast_type.ast->Is()) { + if (ast_type->Is()) { ast_expr = create( Source{}, create( Source{}, static_cast(literal_value))); - } else if (ast_type.ast->Is()) { + } else if (ast_type->Is()) { ast_expr = create( Source{}, create( Source{}, static_cast(literal_value))); - } else if (ast_type.ast->Is()) { + } else if (ast_type->Is()) { float float_value; // Copy the bits so we can read them as a float. std::memcpy(&float_value, &literal_value, sizeof(float_value)); @@ -1182,10 +1173,10 @@ bool ParserImpl::EmitScalarSpecConstants() { return success_; } -typ::Type ParserImpl::MaybeGenerateAlias( +ast::Type* ParserImpl::MaybeGenerateAlias( uint32_t type_id, const spvtools::opt::analysis::Type* type, - typ::Type ast_type) { + ast::Type* ast_type) { if (!success_) { return {}; } @@ -1208,8 +1199,8 @@ typ::Type ParserImpl::MaybeGenerateAlias( // Ignore constants, and any other types. return ast_type; } - auto ast_underlying_type = ast_type; - if (ast_underlying_type.ast == nullptr) { + auto* ast_underlying_type = ast_type; + if (ast_underlying_type == nullptr) { Fail() << "internal error: no type registered for SPIR-V ID: " << type_id; return {}; } @@ -1261,7 +1252,7 @@ bool ParserImpl::EmitModuleScopeVariables() { if (!success_) { return false; } - typ::Type ast_type; + ast::Type* ast_type; if (spirv_storage_class == SpvStorageClassUniformConstant) { // These are opaque handles: samplers or textures ast_type = GetTypeForHandleVar(var); @@ -1270,19 +1261,19 @@ bool ParserImpl::EmitModuleScopeVariables() { } } else { ast_type = ConvertType(type_id); - if (ast_type.ast == nullptr) { + if (ast_type == nullptr) { return Fail() << "internal error: failed to register Tint AST type for " "SPIR-V type with ID: " << var.type_id(); } - if (!ast_type.ast->Is()) { + if (!ast_type->Is()) { return Fail() << "variable with ID " << var.result_id() << " has non-pointer type " << var.type_id(); } } - auto ast_store_type = typ::Call_type(typ::As(ast_type)); - auto ast_storage_class = ast_type.ast->As()->storage_class(); + auto* ast_store_type = ast_type->As()->type(); + auto ast_storage_class = ast_type->As()->storage_class(); ast::Expression* ast_constructor = nullptr; if (var.NumInOperands() > 1) { // SPIR-V initializers are always constants. @@ -1345,18 +1336,18 @@ const spvtools::opt::analysis::IntConstant* ParserImpl::GetArraySize( ast::Variable* ParserImpl::MakeVariable(uint32_t id, ast::StorageClass sc, - typ::Type type, + ast::Type* type, bool is_const, ast::Expression* constructor, ast::DecorationList decorations) { - if (type.ast == nullptr) { + if (type == nullptr) { Fail() << "internal error: can't make ast::Variable for null type"; return nullptr; } if (sc == ast::StorageClass::kStorage) { bool read_only = false; - if (auto* tn = type.ast->As()) { + if (auto* tn = type->As()) { read_only = read_only_struct_types_.count(tn->name()) > 0; } @@ -1388,7 +1379,7 @@ ast::Variable* ParserImpl::MakeVariable(uint32_t id, // The SPIR-V variable is likely to be signed (because GLSL // requires signed), but WGSL requires unsigned. Handle specially // so we always perform the conversion at load and store. - if (auto forced_type = UnsignedTypeFor(type)) { + if (auto* forced_type = UnsignedTypeFor(type)) { // Requires conversion and special handling in code generation. special_builtins_[id] = spv_builtin; type = forced_type; @@ -1461,8 +1452,8 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) { Fail() << "ID " << id << " is not a registered instruction"; return {}; } - auto original_ast_type = ConvertType(inst->type_id()); - if (original_ast_type.ast == nullptr) { + auto* original_ast_type = ConvertType(inst->type_id()); + if (original_ast_type == nullptr) { return {}; } @@ -1479,28 +1470,28 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) { } auto source = GetSourceForInst(inst); - auto ast_type = UnwrapIfNeeded(original_ast_type); + auto* ast_type = original_ast_type->UnwrapIfNeeded(); // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0. // So canonicalization should map that way too. // Currently "null" is missing from the WGSL parser. // See https://bugs.chromium.org/p/tint/issues/detail?id=34 - if (ast_type.ast->Is()) { + if (ast_type->Is()) { return {ast_type, create( Source{}, create( source, spirv_const->GetU32()))}; } - if (ast_type.ast->Is()) { + if (ast_type->Is()) { return {ast_type, create( Source{}, create( source, spirv_const->GetS32()))}; } - if (ast_type.ast->Is()) { + if (ast_type->Is()) { return {ast_type, create( Source{}, create( source, spirv_const->GetFloat()))}; } - if (ast_type.ast->Is()) { + if (ast_type->Is()) { const bool value = spirv_const->AsNullConstant() ? false : spirv_const->AsBoolConstant()->value(); @@ -1556,7 +1547,7 @@ ast::Expression* ParserImpl::MakeNullValue(ast::Type* type) { } auto* original_type = type; - type = UnwrapIfNeeded(type); + type = type->UnwrapIfNeeded(); if (type->Is()) { return create( @@ -1622,15 +1613,15 @@ ast::Expression* ParserImpl::MakeNullValue(ast::Type* type) { return nullptr; } -TypedExpression ParserImpl::MakeNullExpression(typ::Type type) { +TypedExpression ParserImpl::MakeNullExpression(ast::Type* type) { return {type, MakeNullValue(type)}; } -typ::Type ParserImpl::UnsignedTypeFor(typ::Type type) { - if (type.ast->Is()) { +ast::Type* ParserImpl::UnsignedTypeFor(ast::Type* type) { + if (type->Is()) { return builder_.ty.u32(); } - if (auto* v = type.ast->As()) { + if (auto* v = type->As()) { if (v->type()->Is()) { return builder_.ty.vec(builder_.ty.u32(), v->size()); } @@ -1638,11 +1629,11 @@ typ::Type ParserImpl::UnsignedTypeFor(typ::Type type) { return {}; } -typ::Type ParserImpl::SignedTypeFor(typ::Type type) { - if (type.ast->Is()) { +ast::Type* ParserImpl::SignedTypeFor(ast::Type* type) { + if (type->Is()) { return builder_.ty.i32(); } - if (auto* v = type.ast->As()) { + if (auto* v = type->As()) { if (v->type()->Is()) { return builder_.ty.vec(builder_.ty.i32(), v->size()); } @@ -1673,20 +1664,20 @@ TypedExpression ParserImpl::RectifyOperandSignedness( Fail() << "internal error: RectifyOperandSignedness given a null expr\n"; return {}; } - auto type = expr.type; - if (!type.ast) { + auto* type = expr.type; + if (!type) { Fail() << "internal error: unmapped type for: " << builder_.str(expr.expr) << "\n"; return {}; } if (requires_unsigned) { - if (auto unsigned_ty = UnsignedTypeFor(type)) { + if (auto* unsigned_ty = UnsignedTypeFor(type)) { // Conversion is required. return {unsigned_ty, create(Source{}, unsigned_ty, expr.expr)}; } } else if (requires_signed) { - if (auto signed_ty = SignedTypeFor(type)) { + if (auto* signed_ty = SignedTypeFor(type)) { // Conversion is required. return {signed_ty, create(Source{}, signed_ty, expr.expr)}; @@ -1698,9 +1689,9 @@ TypedExpression ParserImpl::RectifyOperandSignedness( TypedExpression ParserImpl::RectifySecondOperandSignedness( const spvtools::opt::Instruction& inst, - typ::Type first_operand_type, + ast::Type* first_operand_type, TypedExpression&& second_operand_expr) { - if ((first_operand_type != second_operand_expr.type) && + if (!AstTypesEquivalent(first_operand_type, second_operand_expr.type) && AssumesSecondOperandSignednessMatchesFirstOperand(inst.opcode())) { // Conversion is required. return {first_operand_type, @@ -1711,8 +1702,8 @@ TypedExpression ParserImpl::RectifySecondOperandSignedness( return std::move(second_operand_expr); } -typ::Type ParserImpl::ForcedResultType(const spvtools::opt::Instruction& inst, - typ::Type first_operand_type) { +ast::Type* ParserImpl::ForcedResultType(const spvtools::opt::Instruction& inst, + ast::Type* first_operand_type) { const auto opcode = inst.opcode(); if (AssumesResultSignednessMatchesFirstOperand(opcode)) { return first_operand_type; @@ -1727,16 +1718,15 @@ typ::Type ParserImpl::ForcedResultType(const spvtools::opt::Instruction& inst, return nullptr; } -typ::Type ParserImpl::GetSignedIntMatchingShape(typ::Type other) { - if (other.ast == nullptr) { +ast::Type* ParserImpl::GetSignedIntMatchingShape(ast::Type* other) { + if (other == nullptr) { Fail() << "no type provided"; } auto i32 = builder_.ty.i32(); - if (other.ast->Is() || other.ast->Is() || - other.ast->Is()) { + if (other->Is() || other->Is() || other->Is()) { return i32; } - auto* vec_ty = other.ast->As(); + auto* vec_ty = other->As(); if (vec_ty) { return builder_.ty.vec(i32, vec_ty->size()); } @@ -1744,17 +1734,16 @@ typ::Type ParserImpl::GetSignedIntMatchingShape(typ::Type other) { return nullptr; } -typ::Type ParserImpl::GetUnsignedIntMatchingShape(typ::Type other) { - if (other.ast == nullptr) { +ast::Type* ParserImpl::GetUnsignedIntMatchingShape(ast::Type* other) { + if (other == nullptr) { Fail() << "no type provided"; return nullptr; } auto u32 = builder_.ty.u32(); - if (other.ast->Is() || other.ast->Is() || - other.ast->Is()) { + if (other->Is() || other->Is() || other->Is()) { return u32; } - auto* vec_ty = other.ast->As(); + auto* vec_ty = other->As(); if (vec_ty) { return builder_.ty.vec(u32, vec_ty->size()); } @@ -1765,9 +1754,10 @@ typ::Type ParserImpl::GetUnsignedIntMatchingShape(typ::Type other) { TypedExpression ParserImpl::RectifyForcedResultType( TypedExpression expr, const spvtools::opt::Instruction& inst, - typ::Type first_operand_type) { - auto forced_result_ty = ForcedResultType(inst, first_operand_type); - if ((forced_result_ty.ast == nullptr) || (forced_result_ty == expr.type)) { + ast::Type* first_operand_type) { + auto* forced_result_ty = ForcedResultType(inst, first_operand_type); + if ((forced_result_ty == nullptr) || + AstTypesEquivalent(forced_result_ty, expr.type)) { return expr; } return {expr.type, @@ -1776,7 +1766,7 @@ TypedExpression ParserImpl::RectifyForcedResultType( TypedExpression ParserImpl::AsUnsigned(TypedExpression expr) { if (expr.type && expr.type->is_signed_scalar_or_vector()) { - auto new_type = GetUnsignedIntMatchingShape(expr.type); + auto* new_type = GetUnsignedIntMatchingShape(expr.type); return {new_type, create(Source{}, new_type, expr.expr)}; } @@ -1785,7 +1775,7 @@ TypedExpression ParserImpl::AsUnsigned(TypedExpression expr) { TypedExpression ParserImpl::AsSigned(TypedExpression expr) { if (expr.type && expr.type->is_unsigned_scalar_or_vector()) { - auto new_type = GetSignedIntMatchingShape(expr.type); + auto* new_type = GetSignedIntMatchingShape(expr.type); return {new_type, create(Source{}, new_type, expr.expr)}; } @@ -1962,7 +1952,7 @@ ParserImpl::GetSpirvTypeForHandleMemoryObjectDeclaration( return raw_handle_type; } -typ::Pointer ParserImpl::GetTypeForHandleVar( +ast::Pointer* ParserImpl::GetTypeForHandleVar( const spvtools::opt::Instruction& var) { auto where = handle_type_.find(&var); if (where != handle_type_.end()) { @@ -2046,7 +2036,7 @@ typ::Pointer ParserImpl::GetTypeForHandleVar( } // Construct the Tint handle type. - typ::Type ast_store_type; + ast::Type* ast_store_type; if (usage.IsSampler()) { ast_store_type = builder_.ty.sampler( usage.IsComparisonSampler() ? ast::SamplerKind::kComparisonSampler @@ -2071,7 +2061,7 @@ typ::Pointer ParserImpl::GetTypeForHandleVar( if (usage.IsSampledTexture() || (image_type->format() == SpvImageFormatUnknown)) { // Make a sampled texture type. - auto ast_sampled_component_type = + auto* ast_sampled_component_type = ConvertType(raw_handle_type->GetSingleWordInOperand(0)); // Vulkan ignores the depth parameter on OpImage, so pay attention to the @@ -2114,7 +2104,7 @@ typ::Pointer ParserImpl::GetTypeForHandleVar( return result; } -typ::Type ParserImpl::GetComponentTypeForFormat(ast::ImageFormat format) { +ast::Type* ParserImpl::GetComponentTypeForFormat(ast::ImageFormat format) { switch (format) { case ast::ImageFormat::kR8Uint: case ast::ImageFormat::kR16Uint: @@ -2163,8 +2153,8 @@ typ::Type ParserImpl::GetComponentTypeForFormat(ast::ImageFormat format) { return nullptr; } -typ::Type ParserImpl::GetTexelTypeForFormat(ast::ImageFormat format) { - auto component_type = GetComponentTypeForFormat(format); +ast::Type* ParserImpl::GetTexelTypeForFormat(ast::ImageFormat format) { + auto* component_type = GetComponentTypeForFormat(format); if (!component_type) { return nullptr; } diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index 1a18a0c833..492d7be531 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h @@ -29,7 +29,6 @@ #include "src/reader/spirv/enum_converter.h" #include "src/reader/spirv/namer.h" #include "src/reader/spirv/usage.h" -#include "src/typepair.h" /// This is the implementation of the SPIR-V parser for Tint. @@ -52,6 +51,14 @@ namespace tint { namespace reader { namespace spirv { +/// Returns true of the two input ast types are semantically equivalent +/// @param lhs first type to compare +/// @param rhs other type to compare +/// @returns true if both types are semantically equivalent +inline bool AstTypesEquivalent(ast::Type* lhs, ast::Type* rhs) { + return lhs->type_name() == rhs->type_name(); +} + /// The binary representation of a SPIR-V decoration enum followed by its /// operands, if any. /// Example: { SpvDecorationBlock } @@ -74,10 +81,10 @@ struct TypedExpression { /// Constructor /// @param type_in the type of the expression /// @param expr_in the expression - TypedExpression(typ::Type type_in, ast::Expression* expr_in); + TypedExpression(ast::Type* type_in, ast::Expression* expr_in); /// The type - typ::Type type; + ast::Type* type; /// The expression ast::Expression* expr = nullptr; }; @@ -156,7 +163,7 @@ class ParserImpl : Reader { /// after the internal representation of the module has been built. /// @param type_id the SPIR-V ID of a type. /// @returns a Tint type, or nullptr - typ::Type ConvertType(uint32_t type_id); + ast::Type* ConvertType(uint32_t type_id); /// Emits an alias type declaration for the given type, if necessary, and /// also updates the mapping of the SPIR-V type ID to the alias type. @@ -169,9 +176,9 @@ class ParserImpl : Reader { /// @param type the type that might get an alias /// @param ast_type the ast type that might get an alias /// @returns an alias type or `ast_type` if no alias was created - typ::Type MaybeGenerateAlias(uint32_t type_id, - const spvtools::opt::analysis::Type* type, - typ::Type ast_type); + ast::Type* MaybeGenerateAlias(uint32_t type_id, + const spvtools::opt::analysis::Type* type, + ast::Type* ast_type); /// @returns the fail stream object FailStream& fail_stream() { return fail_stream_; } @@ -321,7 +328,7 @@ class ParserImpl : Reader { /// in the error case ast::Variable* MakeVariable(uint32_t id, ast::StorageClass sc, - typ::Type type, + ast::Type* type, bool is_const, ast::Expression* constructor, ast::DecorationList decorations); @@ -339,7 +346,7 @@ class ParserImpl : Reader { /// Make a typed expression for the null value for the given type. /// @param type the AST type /// @returns a new typed expression - TypedExpression MakeNullExpression(typ::Type type); + TypedExpression MakeNullExpression(ast::Type* type); /// Converts a given expression to the signedness demanded for an operand /// of the given SPIR-V instruction, if required. If the instruction assumes @@ -364,7 +371,7 @@ class ParserImpl : Reader { /// @returns second_operand_expr, or a cast of it TypedExpression RectifySecondOperandSignedness( const spvtools::opt::Instruction& inst, - typ::Type first_operand_type, + ast::Type* first_operand_type, TypedExpression&& second_operand_expr); /// Returns the "forced" result type for the given SPIR-V instruction. @@ -375,8 +382,8 @@ class ParserImpl : Reader { /// @param inst the SPIR-V instruction /// @param first_operand_type the AST type for the first operand. /// @returns the forced AST result type, or nullptr if no forcing is required. - typ::Type ForcedResultType(const spvtools::opt::Instruction& inst, - typ::Type first_operand_type); + ast::Type* ForcedResultType(const spvtools::opt::Instruction& inst, + ast::Type* first_operand_type); /// Returns a signed integer scalar or vector type matching the shape (scalar, /// vector, and component bit width) of another type, which itself is a @@ -384,7 +391,7 @@ class ParserImpl : Reader { /// requirement. /// @param other the type whose shape must be matched /// @returns the signed scalar or vector type - typ::Type GetSignedIntMatchingShape(typ::Type other); + ast::Type* GetSignedIntMatchingShape(ast::Type* other); /// Returns a signed integer scalar or vector type matching the shape (scalar, /// vector, and component bit width) of another type, which itself is a @@ -392,7 +399,7 @@ class ParserImpl : Reader { /// requirement. /// @param other the type whose shape must be matched /// @returns the unsigned scalar or vector type - typ::Type GetUnsignedIntMatchingShape(typ::Type other); + ast::Type* GetUnsignedIntMatchingShape(ast::Type* other); /// Wraps the given expression in an as-cast to the given expression's type, /// when the underlying operation produces a forced result type different @@ -405,7 +412,7 @@ class ParserImpl : Reader { TypedExpression RectifyForcedResultType( TypedExpression expr, const spvtools::opt::Instruction& inst, - typ::Type first_operand_type); + ast::Type* first_operand_type); /// Returns the given expression, but ensuring it's an unsigned type of the /// same shape as the operand. Wraps the expresison with a bitcast if needed. @@ -505,18 +512,18 @@ class ParserImpl : Reader { /// @param var the OpVariable instruction /// @returns the Tint AST type for the poiner-to-{sampler|texture} or null on /// error - typ::Pointer GetTypeForHandleVar(const spvtools::opt::Instruction& var); + ast::Pointer* GetTypeForHandleVar(const spvtools::opt::Instruction& var); /// Returns the channel component type corresponding to the given image /// format. /// @param format image texel format /// @returns the component type, one of f32, i32, u32 - typ::Type GetComponentTypeForFormat(ast::ImageFormat format); + ast::Type* GetComponentTypeForFormat(ast::ImageFormat format); /// Returns texel type corresponding to the given image format. /// @param format image texel format /// @returns the texel format - typ::Type GetTexelTypeForFormat(ast::ImageFormat format); + ast::Type* GetTexelTypeForFormat(ast::ImageFormat format); /// Returns the SPIR-V instruction with the given ID, or nullptr. /// @param id the SPIR-V result ID @@ -554,19 +561,19 @@ class ParserImpl : Reader { private: /// Converts a specific SPIR-V type to a Tint type. Integer case - typ::Type ConvertType(const spvtools::opt::analysis::Integer* int_ty); + ast::Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty); /// Converts a specific SPIR-V type to a Tint type. Float case - typ::Type ConvertType(const spvtools::opt::analysis::Float* float_ty); + ast::Type* ConvertType(const spvtools::opt::analysis::Float* float_ty); /// Converts a specific SPIR-V type to a Tint type. Vector case - typ::Type ConvertType(const spvtools::opt::analysis::Vector* vec_ty); + ast::Type* ConvertType(const spvtools::opt::analysis::Vector* vec_ty); /// Converts a specific SPIR-V type to a Tint type. Matrix case - typ::Type ConvertType(const spvtools::opt::analysis::Matrix* mat_ty); + ast::Type* ConvertType(const spvtools::opt::analysis::Matrix* mat_ty); /// Converts a specific SPIR-V type to a Tint type. RuntimeArray case /// @param rtarr_ty the Tint type - typ::Type ConvertType(const spvtools::opt::analysis::RuntimeArray* rtarr_ty); + ast::Type* ConvertType(const spvtools::opt::analysis::RuntimeArray* rtarr_ty); /// Converts a specific SPIR-V type to a Tint type. Array case /// @param arr_ty the Tint type - typ::Type ConvertType(const spvtools::opt::analysis::Array* arr_ty); + ast::Type* ConvertType(const spvtools::opt::analysis::Array* arr_ty); /// Converts a specific SPIR-V type to a Tint type. Struct case. /// SPIR-V allows distinct struct type definitions for two OpTypeStruct /// that otherwise have the same set of members (and struct and member @@ -578,27 +585,27 @@ class ParserImpl : Reader { /// not significant to the optimizer's module representation. /// @param type_id the SPIR-V ID for the type. /// @param struct_ty the Tint type - typ::Type ConvertType(uint32_t type_id, - const spvtools::opt::analysis::Struct* struct_ty); + ast::Type* ConvertType(uint32_t type_id, + const spvtools::opt::analysis::Struct* struct_ty); /// Converts a specific SPIR-V type to a Tint type. Pointer case /// The pointer to gl_PerVertex maps to nullptr, and instead is recorded /// in member #builtin_position_. /// @param type_id the SPIR-V ID for the type. /// @param ptr_ty the Tint type - typ::Type ConvertType(uint32_t type_id, - const spvtools::opt::analysis::Pointer* ptr_ty); + ast::Type* ConvertType(uint32_t type_id, + const spvtools::opt::analysis::Pointer* ptr_ty); /// If `type` is a signed integral, or vector of signed integral, /// returns the unsigned type, otherwise returns `type`. /// @param type the possibly signed type /// @returns the unsigned type - typ::Type UnsignedTypeFor(typ::Type type); + ast::Type* UnsignedTypeFor(ast::Type* type); /// If `type` is a unsigned integral, or vector of unsigned integral, /// returns the signed type, otherwise returns `type`. /// @param type the possibly unsigned type /// @returns the signed type - typ::Type SignedTypeFor(typ::Type type); + ast::Type* SignedTypeFor(ast::Type* type); /// Parses the array or runtime-array decorations. /// @param spv_type the SPIR-V array or runtime-array type. @@ -709,7 +716,7 @@ class ParserImpl : Reader { // usages implied by usages of the memory-object-declaration. std::unordered_map handle_usage_; // The inferred pointer type for the given handle variable. - std::unordered_map + std::unordered_map handle_type_; // Set of symbols of constructed types that have been added, used to avoid diff --git a/src/reader/spirv/parser_impl_convert_type_test.cc b/src/reader/spirv/parser_impl_convert_type_test.cc index 319e8a61dd..180f119088 100644 --- a/src/reader/spirv/parser_impl_convert_type_test.cc +++ b/src/reader/spirv/parser_impl_convert_type_test.cc @@ -26,14 +26,14 @@ using ::testing::Eq; TEST_F(SpvParserTest, ConvertType_PreservesExistingFailure) { auto p = parser(std::vector{}); p->Fail() << "boing"; - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); EXPECT_EQ(type, nullptr); EXPECT_THAT(p->error(), Eq("boing")); } TEST_F(SpvParserTest, ConvertType_RequiresInternalRepresntation) { auto p = parser(std::vector{}); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); EXPECT_EQ(type, nullptr); EXPECT_THAT( p->error(), @@ -44,7 +44,7 @@ TEST_F(SpvParserTest, ConvertType_NotAnId) { auto p = parser(test::Assemble("%1 = OpExtInstImport \"GLSL.std.450\"")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); EXPECT_EQ(type, nullptr); EXPECT_EQ(nullptr, type); EXPECT_THAT(p->error(), Eq("ID is not a SPIR-V type: 10")); @@ -54,7 +54,7 @@ TEST_F(SpvParserTest, ConvertType_IdExistsButIsNotAType) { auto p = parser(test::Assemble("%1 = OpExtInstImport \"GLSL.std.450\"")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(1); + auto* type = p->ConvertType(1); EXPECT_EQ(nullptr, type); EXPECT_THAT(p->error(), Eq("ID is not a SPIR-V type: 1")); } @@ -64,7 +64,7 @@ TEST_F(SpvParserTest, ConvertType_UnhandledType) { auto p = parser(test::Assemble("%70 = OpTypePipe WriteOnly")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(70); + auto* type = p->ConvertType(70); EXPECT_EQ(nullptr, type); EXPECT_THAT(p->error(), Eq("unknown SPIR-V type with ID 70: %70 = OpTypePipe WriteOnly")); @@ -74,8 +74,8 @@ TEST_F(SpvParserTest, ConvertType_Void) { auto p = parser(test::Assemble("%1 = OpTypeVoid")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(1); - EXPECT_TRUE(type.ast->Is()); + auto* type = p->ConvertType(1); + EXPECT_TRUE(type->Is()); EXPECT_TRUE(p->error().empty()); } @@ -83,8 +83,8 @@ TEST_F(SpvParserTest, ConvertType_Bool) { auto p = parser(test::Assemble("%100 = OpTypeBool")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(100); - EXPECT_TRUE(type.ast->Is()); + auto* type = p->ConvertType(100); + EXPECT_TRUE(type->Is()); EXPECT_TRUE(p->error().empty()); } @@ -92,8 +92,8 @@ TEST_F(SpvParserTest, ConvertType_I32) { auto p = parser(test::Assemble("%2 = OpTypeInt 32 1")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(2); - EXPECT_TRUE(type.ast->Is()); + auto* type = p->ConvertType(2); + EXPECT_TRUE(type->Is()); EXPECT_TRUE(p->error().empty()); } @@ -101,8 +101,8 @@ TEST_F(SpvParserTest, ConvertType_U32) { auto p = parser(test::Assemble("%3 = OpTypeInt 32 0")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(3); - EXPECT_TRUE(type.ast->Is()); + auto* type = p->ConvertType(3); + EXPECT_TRUE(type->Is()); EXPECT_TRUE(p->error().empty()); } @@ -110,8 +110,8 @@ TEST_F(SpvParserTest, ConvertType_F32) { auto p = parser(test::Assemble("%4 = OpTypeFloat 32")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(4); - EXPECT_TRUE(type.ast->Is()); + auto* type = p->ConvertType(4); + EXPECT_TRUE(type->Is()); EXPECT_TRUE(p->error().empty()); } @@ -119,7 +119,7 @@ TEST_F(SpvParserTest, ConvertType_BadIntWidth) { auto p = parser(test::Assemble("%5 = OpTypeInt 17 1")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(5); + auto* type = p->ConvertType(5); EXPECT_EQ(type, nullptr); EXPECT_THAT(p->error(), Eq("unhandled integer width: 17")); } @@ -128,7 +128,7 @@ TEST_F(SpvParserTest, ConvertType_BadFloatWidth) { auto p = parser(test::Assemble("%6 = OpTypeFloat 19")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(6); + auto* type = p->ConvertType(6); EXPECT_EQ(type, nullptr); EXPECT_THAT(p->error(), Eq("unhandled float width: 19")); } @@ -140,7 +140,7 @@ TEST_F(SpvParserTest, DISABLED_ConvertType_InvalidVectorElement) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(20); + auto* type = p->ConvertType(20); EXPECT_EQ(type, nullptr); EXPECT_THAT(p->error(), Eq("unknown SPIR-V type: 5")); } @@ -154,20 +154,20 @@ TEST_F(SpvParserTest, ConvertType_VecOverF32) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto v2xf32 = p->ConvertType(20); - EXPECT_TRUE(v2xf32.ast->Is()); - EXPECT_TRUE(v2xf32.ast->As()->type()->Is()); - EXPECT_EQ(v2xf32.ast->As()->size(), 2u); + auto* v2xf32 = p->ConvertType(20); + EXPECT_TRUE(v2xf32->Is()); + EXPECT_TRUE(v2xf32->As()->type()->Is()); + EXPECT_EQ(v2xf32->As()->size(), 2u); - auto v3xf32 = p->ConvertType(30); - EXPECT_TRUE(v3xf32.ast->Is()); - EXPECT_TRUE(v3xf32.ast->As()->type()->Is()); - EXPECT_EQ(v3xf32.ast->As()->size(), 3u); + auto* v3xf32 = p->ConvertType(30); + EXPECT_TRUE(v3xf32->Is()); + EXPECT_TRUE(v3xf32->As()->type()->Is()); + EXPECT_EQ(v3xf32->As()->size(), 3u); - auto v4xf32 = p->ConvertType(40); - EXPECT_TRUE(v4xf32.ast->Is()); - EXPECT_TRUE(v4xf32.ast->As()->type()->Is()); - EXPECT_EQ(v4xf32.ast->As()->size(), 4u); + auto* v4xf32 = p->ConvertType(40); + EXPECT_TRUE(v4xf32->Is()); + EXPECT_TRUE(v4xf32->As()->type()->Is()); + EXPECT_EQ(v4xf32->As()->size(), 4u); EXPECT_TRUE(p->error().empty()); } @@ -181,20 +181,20 @@ TEST_F(SpvParserTest, ConvertType_VecOverI32) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto v2xi32 = p->ConvertType(20); - EXPECT_TRUE(v2xi32.ast->Is()); - EXPECT_TRUE(v2xi32.ast->As()->type()->Is()); - EXPECT_EQ(v2xi32.ast->As()->size(), 2u); + auto* v2xi32 = p->ConvertType(20); + EXPECT_TRUE(v2xi32->Is()); + EXPECT_TRUE(v2xi32->As()->type()->Is()); + EXPECT_EQ(v2xi32->As()->size(), 2u); - auto v3xi32 = p->ConvertType(30); - EXPECT_TRUE(v3xi32.ast->Is()); - EXPECT_TRUE(v3xi32.ast->As()->type()->Is()); - EXPECT_EQ(v3xi32.ast->As()->size(), 3u); + auto* v3xi32 = p->ConvertType(30); + EXPECT_TRUE(v3xi32->Is()); + EXPECT_TRUE(v3xi32->As()->type()->Is()); + EXPECT_EQ(v3xi32->As()->size(), 3u); - auto v4xi32 = p->ConvertType(40); - EXPECT_TRUE(v4xi32.ast->Is()); - EXPECT_TRUE(v4xi32.ast->As()->type()->Is()); - EXPECT_EQ(v4xi32.ast->As()->size(), 4u); + auto* v4xi32 = p->ConvertType(40); + EXPECT_TRUE(v4xi32->Is()); + EXPECT_TRUE(v4xi32->As()->type()->Is()); + EXPECT_EQ(v4xi32->As()->size(), 4u); EXPECT_TRUE(p->error().empty()); } @@ -208,20 +208,20 @@ TEST_F(SpvParserTest, ConvertType_VecOverU32) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto v2xu32 = p->ConvertType(20); - EXPECT_TRUE(v2xu32.ast->Is()); - EXPECT_TRUE(v2xu32.ast->As()->type()->Is()); - EXPECT_EQ(v2xu32.ast->As()->size(), 2u); + auto* v2xu32 = p->ConvertType(20); + EXPECT_TRUE(v2xu32->Is()); + EXPECT_TRUE(v2xu32->As()->type()->Is()); + EXPECT_EQ(v2xu32->As()->size(), 2u); - auto v3xu32 = p->ConvertType(30); - EXPECT_TRUE(v3xu32.ast->Is()); - EXPECT_TRUE(v3xu32.ast->As()->type()->Is()); - EXPECT_EQ(v3xu32.ast->As()->size(), 3u); + auto* v3xu32 = p->ConvertType(30); + EXPECT_TRUE(v3xu32->Is()); + EXPECT_TRUE(v3xu32->As()->type()->Is()); + EXPECT_EQ(v3xu32->As()->size(), 3u); - auto v4xu32 = p->ConvertType(40); - EXPECT_TRUE(v4xu32.ast->Is()); - EXPECT_TRUE(v4xu32.ast->As()->type()->Is()); - EXPECT_EQ(v4xu32.ast->As()->size(), 4u); + auto* v4xu32 = p->ConvertType(40); + EXPECT_TRUE(v4xu32->Is()); + EXPECT_TRUE(v4xu32->As()->type()->Is()); + EXPECT_EQ(v4xu32->As()->size(), 4u); EXPECT_TRUE(p->error().empty()); } @@ -234,7 +234,7 @@ TEST_F(SpvParserTest, DISABLED_ConvertType_InvalidMatrixElement) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(20); + auto* type = p->ConvertType(20); EXPECT_EQ(type, nullptr); EXPECT_THAT(p->error(), Eq("unknown SPIR-V type: 5")); } @@ -260,59 +260,59 @@ TEST_F(SpvParserTest, ConvertType_MatrixOverF32) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto m22 = p->ConvertType(22); - EXPECT_TRUE(m22.ast->Is()); - EXPECT_TRUE(m22.ast->As()->type()->Is()); - EXPECT_EQ(m22.ast->As()->rows(), 2u); - EXPECT_EQ(m22.ast->As()->columns(), 2u); + auto* m22 = p->ConvertType(22); + EXPECT_TRUE(m22->Is()); + EXPECT_TRUE(m22->As()->type()->Is()); + EXPECT_EQ(m22->As()->rows(), 2u); + EXPECT_EQ(m22->As()->columns(), 2u); - auto m23 = p->ConvertType(23); - EXPECT_TRUE(m23.ast->Is()); - EXPECT_TRUE(m23.ast->As()->type()->Is()); - EXPECT_EQ(m23.ast->As()->rows(), 2u); - EXPECT_EQ(m23.ast->As()->columns(), 3u); + auto* m23 = p->ConvertType(23); + EXPECT_TRUE(m23->Is()); + EXPECT_TRUE(m23->As()->type()->Is()); + EXPECT_EQ(m23->As()->rows(), 2u); + EXPECT_EQ(m23->As()->columns(), 3u); - auto m24 = p->ConvertType(24); - EXPECT_TRUE(m24.ast->Is()); - EXPECT_TRUE(m24.ast->As()->type()->Is()); - EXPECT_EQ(m24.ast->As()->rows(), 2u); - EXPECT_EQ(m24.ast->As()->columns(), 4u); + auto* m24 = p->ConvertType(24); + EXPECT_TRUE(m24->Is()); + EXPECT_TRUE(m24->As()->type()->Is()); + EXPECT_EQ(m24->As()->rows(), 2u); + EXPECT_EQ(m24->As()->columns(), 4u); - auto m32 = p->ConvertType(32); - EXPECT_TRUE(m32.ast->Is()); - EXPECT_TRUE(m32.ast->As()->type()->Is()); - EXPECT_EQ(m32.ast->As()->rows(), 3u); - EXPECT_EQ(m32.ast->As()->columns(), 2u); + auto* m32 = p->ConvertType(32); + EXPECT_TRUE(m32->Is()); + EXPECT_TRUE(m32->As()->type()->Is()); + EXPECT_EQ(m32->As()->rows(), 3u); + EXPECT_EQ(m32->As()->columns(), 2u); - auto m33 = p->ConvertType(33); - EXPECT_TRUE(m33.ast->Is()); - EXPECT_TRUE(m33.ast->As()->type()->Is()); - EXPECT_EQ(m33.ast->As()->rows(), 3u); - EXPECT_EQ(m33.ast->As()->columns(), 3u); + auto* m33 = p->ConvertType(33); + EXPECT_TRUE(m33->Is()); + EXPECT_TRUE(m33->As()->type()->Is()); + EXPECT_EQ(m33->As()->rows(), 3u); + EXPECT_EQ(m33->As()->columns(), 3u); - auto m34 = p->ConvertType(34); - EXPECT_TRUE(m34.ast->Is()); - EXPECT_TRUE(m34.ast->As()->type()->Is()); - EXPECT_EQ(m34.ast->As()->rows(), 3u); - EXPECT_EQ(m34.ast->As()->columns(), 4u); + auto* m34 = p->ConvertType(34); + EXPECT_TRUE(m34->Is()); + EXPECT_TRUE(m34->As()->type()->Is()); + EXPECT_EQ(m34->As()->rows(), 3u); + EXPECT_EQ(m34->As()->columns(), 4u); - auto m42 = p->ConvertType(42); - EXPECT_TRUE(m42.ast->Is()); - EXPECT_TRUE(m42.ast->As()->type()->Is()); - EXPECT_EQ(m42.ast->As()->rows(), 4u); - EXPECT_EQ(m42.ast->As()->columns(), 2u); + auto* m42 = p->ConvertType(42); + EXPECT_TRUE(m42->Is()); + EXPECT_TRUE(m42->As()->type()->Is()); + EXPECT_EQ(m42->As()->rows(), 4u); + EXPECT_EQ(m42->As()->columns(), 2u); - auto m43 = p->ConvertType(43); - EXPECT_TRUE(m43.ast->Is()); - EXPECT_TRUE(m43.ast->As()->type()->Is()); - EXPECT_EQ(m43.ast->As()->rows(), 4u); - EXPECT_EQ(m43.ast->As()->columns(), 3u); + auto* m43 = p->ConvertType(43); + EXPECT_TRUE(m43->Is()); + EXPECT_TRUE(m43->As()->type()->Is()); + EXPECT_EQ(m43->As()->rows(), 4u); + EXPECT_EQ(m43->As()->columns(), 3u); - auto m44 = p->ConvertType(44); - EXPECT_TRUE(m44.ast->Is()); - EXPECT_TRUE(m44.ast->As()->type()->Is()); - EXPECT_EQ(m44.ast->As()->rows(), 4u); - EXPECT_EQ(m44.ast->As()->columns(), 4u); + auto* m44 = p->ConvertType(44); + EXPECT_TRUE(m44->Is()); + EXPECT_TRUE(m44->As()->type()->Is()); + EXPECT_EQ(m44->As()->rows(), 4u); + EXPECT_EQ(m44->As()->columns(), 4u); EXPECT_TRUE(p->error().empty()); } @@ -324,10 +324,10 @@ TEST_F(SpvParserTest, ConvertType_RuntimeArray) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); ASSERT_NE(type, nullptr); - EXPECT_TRUE(type.ast->UnwrapAliasIfNeeded()->Is()); - auto* arr_type = type.ast->UnwrapAliasIfNeeded()->As(); + EXPECT_TRUE(type->UnwrapAliasIfNeeded()->Is()); + auto* arr_type = type->UnwrapAliasIfNeeded()->As(); EXPECT_TRUE(arr_type->IsRuntimeArray()); ASSERT_NE(arr_type, nullptr); EXPECT_EQ(arr_type->size(), 0u); @@ -345,7 +345,7 @@ TEST_F(SpvParserTest, ConvertType_RuntimeArray_InvalidDecoration) { %10 = OpTypeRuntimeArray %uint )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); EXPECT_EQ(type, nullptr); EXPECT_THAT( p->error(), @@ -359,9 +359,9 @@ TEST_F(SpvParserTest, ConvertType_RuntimeArray_ArrayStride_Valid) { %10 = OpTypeRuntimeArray %uint )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); ASSERT_NE(type, nullptr); - auto* arr_type = type.ast->UnwrapAliasIfNeeded()->As(); + auto* arr_type = type->UnwrapAliasIfNeeded()->As(); EXPECT_TRUE(arr_type->IsRuntimeArray()); ASSERT_NE(arr_type, nullptr); ASSERT_EQ(arr_type->decorations().size(), 1u); @@ -378,7 +378,7 @@ TEST_F(SpvParserTest, ConvertType_RuntimeArray_ArrayStride_ZeroIsError) { %10 = OpTypeRuntimeArray %uint )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); EXPECT_EQ(type, nullptr); EXPECT_THAT(p->error(), Eq("invalid array type ID 10: ArrayStride can't be 0")); @@ -393,7 +393,7 @@ TEST_F(SpvParserTest, %10 = OpTypeRuntimeArray %uint )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); EXPECT_EQ(type, nullptr); EXPECT_THAT(p->error(), Eq("invalid array type ID 10: multiple ArrayStride decorations")); @@ -407,10 +407,10 @@ TEST_F(SpvParserTest, ConvertType_Array) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); ASSERT_NE(type, nullptr); - EXPECT_TRUE(type.ast->Is()); - auto* arr_type = type.ast->As(); + EXPECT_TRUE(type->Is()); + auto* arr_type = type->As(); EXPECT_FALSE(arr_type->IsRuntimeArray()); ASSERT_NE(arr_type, nullptr); EXPECT_EQ(arr_type->size(), 42u); @@ -430,7 +430,7 @@ TEST_F(SpvParserTest, ConvertType_ArrayBadLengthIsSpecConstantValue) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); ASSERT_EQ(type, nullptr); EXPECT_THAT(p->error(), Eq("Array type 10 length is a specialization constant")); @@ -445,7 +445,7 @@ TEST_F(SpvParserTest, ConvertType_ArrayBadLengthIsSpecConstantExpr) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); ASSERT_EQ(type, nullptr); EXPECT_THAT(p->error(), Eq("Array type 10 length is a specialization constant")); @@ -463,7 +463,7 @@ TEST_F(SpvParserTest, ConvertType_ArrayBadTooBig) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); ASSERT_EQ(type, nullptr); // TODO(dneto): Right now it's rejected earlier in the flow because // we can't even utter the uint64 type. @@ -478,7 +478,7 @@ TEST_F(SpvParserTest, ConvertType_Array_InvalidDecoration) { %10 = OpTypeArray %uint %uint_5 )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); EXPECT_EQ(type, nullptr); EXPECT_THAT( p->error(), @@ -494,10 +494,10 @@ TEST_F(SpvParserTest, ConvertType_ArrayStride_Valid) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); ASSERT_NE(type, nullptr); - EXPECT_TRUE(type.ast->UnwrapAliasIfNeeded()->Is()); - auto* arr_type = type.ast->UnwrapAliasIfNeeded()->As(); + EXPECT_TRUE(type->UnwrapAliasIfNeeded()->Is()); + auto* arr_type = type->UnwrapAliasIfNeeded()->As(); ASSERT_NE(arr_type, nullptr); ASSERT_EQ(arr_type->decorations().size(), 1u); @@ -517,7 +517,7 @@ TEST_F(SpvParserTest, ConvertType_ArrayStride_ZeroIsError) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); ASSERT_EQ(type, nullptr); EXPECT_THAT(p->error(), Eq("invalid array type ID 10: ArrayStride can't be 0")); @@ -533,7 +533,7 @@ TEST_F(SpvParserTest, ConvertType_ArrayStride_SpecifiedTwiceIsError) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); ASSERT_EQ(type, nullptr); EXPECT_THAT(p->error(), Eq("invalid array type ID 10: multiple ArrayStride decorations")); @@ -548,12 +548,12 @@ TEST_F(SpvParserTest, ConvertType_StructTwoMembers) { EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->RegisterUserAndStructMemberNames()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); ASSERT_NE(type, nullptr); - EXPECT_TRUE(type.ast->Is()); + EXPECT_TRUE(type->Is()); Program program = p->program(); - EXPECT_THAT(program.str(type.ast->As()), Eq(R"(Struct S { + EXPECT_THAT(program.str(type->As()), Eq(R"(Struct S { StructMember{field0: __u32} StructMember{field1: __f32} } @@ -569,12 +569,12 @@ TEST_F(SpvParserTest, ConvertType_StructWithBlockDecoration) { EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->RegisterUserAndStructMemberNames()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); ASSERT_NE(type, nullptr); - EXPECT_TRUE(type.ast->Is()); + EXPECT_TRUE(type->Is()); Program program = p->program(); - EXPECT_THAT(program.str(type.ast->As()), Eq(R"(Struct S { + EXPECT_THAT(program.str(type->As()), Eq(R"(Struct S { [[block]] StructMember{field0: __u32} } @@ -594,12 +594,12 @@ TEST_F(SpvParserTest, ConvertType_StructWithMemberDecorations) { EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->RegisterUserAndStructMemberNames()); - auto type = p->ConvertType(10); + auto* type = p->ConvertType(10); ASSERT_NE(type, nullptr); - EXPECT_TRUE(type.ast->Is()); + EXPECT_TRUE(type->Is()); Program program = p->program(); - EXPECT_THAT(program.str(type.ast->As()), Eq(R"(Struct S { + EXPECT_THAT(program.str(type->As()), Eq(R"(Struct S { StructMember{[[ offset 0 ]] field0: __f32} StructMember{[[ offset 8 ]] field1: __vec_2__f32} StructMember{[[ offset 16 ]] field2: __mat_2_2__f32} @@ -621,7 +621,7 @@ TEST_F(SpvParserTest, ConvertType_InvalidPointeetype) { )")); EXPECT_TRUE(p->BuildInternalModule()) << p->error(); - auto type = p->ConvertType(3); + auto* type = p->ConvertType(3); EXPECT_EQ(type, nullptr); EXPECT_THAT(p->error(), Eq("SPIR-V pointer type with ID 3 has invalid pointee type 42")); @@ -644,9 +644,9 @@ TEST_F(SpvParserTest, ConvertType_PointerInput) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(3); - EXPECT_TRUE(type.ast->Is()); - auto* ptr_ty = type.ast->As(); + auto* type = p->ConvertType(3); + EXPECT_TRUE(type->Is()); + auto* ptr_ty = type->As(); EXPECT_NE(ptr_ty, nullptr); EXPECT_TRUE(ptr_ty->type()->Is()); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kInput); @@ -660,9 +660,9 @@ TEST_F(SpvParserTest, ConvertType_PointerOutput) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(3); - EXPECT_TRUE(type.ast->Is()); - auto* ptr_ty = type.ast->As(); + auto* type = p->ConvertType(3); + EXPECT_TRUE(type->Is()); + auto* ptr_ty = type->As(); EXPECT_NE(ptr_ty, nullptr); EXPECT_TRUE(ptr_ty->type()->Is()); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kOutput); @@ -676,9 +676,9 @@ TEST_F(SpvParserTest, ConvertType_PointerUniform) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(3); - EXPECT_TRUE(type.ast->Is()); - auto* ptr_ty = type.ast->As(); + auto* type = p->ConvertType(3); + EXPECT_TRUE(type->Is()); + auto* ptr_ty = type->As(); EXPECT_NE(ptr_ty, nullptr); EXPECT_TRUE(ptr_ty->type()->Is()); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kUniform); @@ -692,9 +692,9 @@ TEST_F(SpvParserTest, ConvertType_PointerWorkgroup) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(3); - EXPECT_TRUE(type.ast->Is()); - auto* ptr_ty = type.ast->As(); + auto* type = p->ConvertType(3); + EXPECT_TRUE(type->Is()); + auto* ptr_ty = type->As(); EXPECT_NE(ptr_ty, nullptr); EXPECT_TRUE(ptr_ty->type()->Is()); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kWorkgroup); @@ -708,9 +708,9 @@ TEST_F(SpvParserTest, ConvertType_PointerUniformConstant) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(3); - EXPECT_TRUE(type.ast->Is()); - auto* ptr_ty = type.ast->As(); + auto* type = p->ConvertType(3); + EXPECT_TRUE(type->Is()); + auto* ptr_ty = type->As(); EXPECT_NE(ptr_ty, nullptr); EXPECT_TRUE(ptr_ty->type()->Is()); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kUniformConstant); @@ -724,9 +724,9 @@ TEST_F(SpvParserTest, ConvertType_PointerStorageBuffer) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(3); - EXPECT_TRUE(type.ast->Is()); - auto* ptr_ty = type.ast->As(); + auto* type = p->ConvertType(3); + EXPECT_TRUE(type->Is()); + auto* ptr_ty = type->As(); EXPECT_NE(ptr_ty, nullptr); EXPECT_TRUE(ptr_ty->type()->Is()); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kStorage); @@ -740,9 +740,9 @@ TEST_F(SpvParserTest, ConvertType_PointerImage) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(3); - EXPECT_TRUE(type.ast->Is()); - auto* ptr_ty = type.ast->As(); + auto* type = p->ConvertType(3); + EXPECT_TRUE(type->Is()); + auto* ptr_ty = type->As(); EXPECT_NE(ptr_ty, nullptr); EXPECT_TRUE(ptr_ty->type()->Is()); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kImage); @@ -756,9 +756,9 @@ TEST_F(SpvParserTest, ConvertType_PointerPrivate) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(3); - EXPECT_TRUE(type.ast->Is()); - auto* ptr_ty = type.ast->As(); + auto* type = p->ConvertType(3); + EXPECT_TRUE(type->Is()); + auto* ptr_ty = type->As(); EXPECT_NE(ptr_ty, nullptr); EXPECT_TRUE(ptr_ty->type()->Is()); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kPrivate); @@ -772,9 +772,9 @@ TEST_F(SpvParserTest, ConvertType_PointerFunction) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(3); - EXPECT_TRUE(type.ast->Is()); - auto* ptr_ty = type.ast->As(); + auto* type = p->ConvertType(3); + EXPECT_TRUE(type->Is()); + auto* ptr_ty = type->As(); EXPECT_NE(ptr_ty, nullptr); EXPECT_TRUE(ptr_ty->type()->Is()); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kFunction); @@ -790,11 +790,11 @@ TEST_F(SpvParserTest, ConvertType_PointerToPointer) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(3); + auto* type = p->ConvertType(3); EXPECT_NE(type, nullptr); - EXPECT_TRUE(type.ast->Is()); + EXPECT_TRUE(type->Is()); - auto* ptr_ty = type.ast->As(); + auto* ptr_ty = type->As(); EXPECT_NE(ptr_ty, nullptr); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kInput); EXPECT_TRUE(ptr_ty->type()->Is()); @@ -814,8 +814,8 @@ TEST_F(SpvParserTest, ConvertType_Sampler_PretendVoid) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(1); - EXPECT_TRUE(type.ast->Is()); + auto* type = p->ConvertType(1); + EXPECT_TRUE(type->Is()); EXPECT_TRUE(p->error().empty()); } @@ -827,8 +827,8 @@ TEST_F(SpvParserTest, ConvertType_Image_PretendVoid) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(1); - EXPECT_TRUE(type.ast->Is()); + auto* type = p->ConvertType(1); + EXPECT_TRUE(type->Is()); EXPECT_TRUE(p->error().empty()); } @@ -840,8 +840,8 @@ TEST_F(SpvParserTest, ConvertType_SampledImage_PretendVoid) { )")); EXPECT_TRUE(p->BuildInternalModule()); - auto type = p->ConvertType(1); - EXPECT_TRUE(type.ast->Is()); + auto* type = p->ConvertType(1); + EXPECT_TRUE(type->Is()); EXPECT_TRUE(p->error().empty()); } diff --git a/src/reader/spirv/parser_impl_test_helper.h b/src/reader/spirv/parser_impl_test_helper.h index 1794f7fd96..bfa5ca6369 100644 --- a/src/reader/spirv/parser_impl_test_helper.h +++ b/src/reader/spirv/parser_impl_test_helper.h @@ -158,7 +158,7 @@ class ParserImplWrapperForTest { /// after the internal representation of the module has been built. /// @param id the SPIR-V ID of a type. /// @returns a Tint type, or nullptr - typ::Type ConvertType(uint32_t id) { return impl_.ConvertType(id); } + ast::Type* ConvertType(uint32_t id) { return impl_.ConvertType(id); } /// Gets the list of decorations for a SPIR-V result ID. Returns an empty /// vector if the ID is not a result ID, or if no decorations target that ID. diff --git a/src/typepair.h b/src/typepair.h index 4fe005a699..422e72f344 100644 --- a/src/typepair.h +++ b/src/typepair.h @@ -271,27 +271,6 @@ inline auto MakeTypePair(AST* ast, SEM* sem) { return TypePair{ast, sem}; } -/// Performs an As operation on the `ast` and `sem` members of the input type -/// pair, deducing the mapped type from typ::* to ast::* and sem::* -/// respectively. -/// @param tp the type pair to call As on -/// @returns a new type pair after As has been called on each of `sem` and `ast` -template -auto As(TypePair tp) - -> TypePair { - return MakeTypePair( - tp.ast ? tp.ast->template As() : nullptr, - tp.sem ? tp.sem->template As() : nullptr); -} - -/// Invokes the `type()` member function on each of `ast` and `sem` of the input -/// type pair -/// @param tp the type pair -/// @returns a type pair with the result of calling `type()` on `ast` and `sem` -template -TypePair Call_type(TypePair tp) { - return MakeTypePair(tp.ast->type(), tp.sem->type()); -} } // namespace typ