spirv reader: replace typ::Type with ast::Type

Bug: tint:724
Change-Id: Idaf807dd1ff75af8e0044731e7362c0915ae7e54
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/50200
Reviewed-by: Ben Clayton <bclayton@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano 2021-05-06 21:23:13 +00:00 committed by Commit Bot service account
parent 467184fb06
commit a2580d6720
7 changed files with 408 additions and 440 deletions

View File

@ -685,13 +685,6 @@ struct LoopStatementBuilder
ast::BlockStatement* continuing = nullptr; ast::BlockStatement* continuing = nullptr;
}; };
// Forwards UnwrapAll to both the ast and sem types of the TypePair
// @param tp the type pair
// @returns the unwrapped type pair
typ::Type UnwrapAll(typ::Type tp) {
return typ::Type{tp.ast->UnwrapAll(), tp.sem->UnwrapAll()};
}
} // namespace } // namespace
BlockInfo::BlockInfo(const spvtools::opt::BasicBlock& bb) BlockInfo::BlockInfo(const spvtools::opt::BasicBlock& bb)
@ -907,7 +900,7 @@ bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) {
// Surprisingly, the "type id" on an OpFunction is the result type of the // Surprisingly, the "type id" on an OpFunction is the result type of the
// function, not the type of the function. This is the one exceptional case // function, not the type of the function. This is the one exceptional case
// in SPIR-V where the type ID is not the type of the result ID. // in SPIR-V where the type ID is not the type of the result ID.
auto ret_ty = parser_impl_.ConvertType(function_.type_id()); auto* ret_ty = parser_impl_.ConvertType(function_.type_id());
if (failed()) { if (failed()) {
return false; return false;
} }
@ -920,7 +913,7 @@ bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) {
ast::VariableList ast_params; ast::VariableList ast_params;
function_.ForEachParam( function_.ForEachParam(
[this, &ast_params](const spvtools::opt::Instruction* param) { [this, &ast_params](const spvtools::opt::Instruction* param) {
auto ast_type = parser_impl_.ConvertType(param->type_id()); auto* ast_type = parser_impl_.ConvertType(param->type_id());
if (ast_type != nullptr) { if (ast_type != nullptr) {
auto* ast_param = parser_impl_.MakeVariable( auto* ast_param = parser_impl_.MakeVariable(
param->result_id(), ast::StorageClass::kNone, ast_type, true, param->result_id(), ast::StorageClass::kNone, ast_type, true,
@ -950,7 +943,7 @@ bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) {
return success(); return success();
} }
typ::Type FunctionEmitter::GetVariableStoreType( ast::Type* FunctionEmitter::GetVariableStoreType(
const spvtools::opt::Instruction& var_decl_inst) { const spvtools::opt::Instruction& var_decl_inst) {
const auto type_id = var_decl_inst.type_id(); const auto type_id = var_decl_inst.type_id();
auto* var_ref_type = type_mgr_->GetType(type_id); auto* var_ref_type = type_mgr_->GetType(type_id);
@ -2013,7 +2006,7 @@ bool FunctionEmitter::EmitFunctionVariables() {
if (inst.opcode() != SpvOpVariable) { if (inst.opcode() != SpvOpVariable) {
continue; continue;
} }
auto var_store_type = GetVariableStoreType(inst); auto* var_store_type = GetVariableStoreType(inst);
if (failed()) { if (failed()) {
return false; return false;
} }
@ -2049,7 +2042,7 @@ TypedExpression FunctionEmitter::MakeExpression(uint32_t id) {
<< id; << id;
return {}; return {};
case SkipReason::kPointSizeBuiltinValue: { case SkipReason::kPointSizeBuiltinValue: {
return {create<sem::F32>(), return {create<ast::F32>(),
create<ast::ScalarConstructorExpression>( create<ast::ScalarConstructorExpression>(
Source{}, create<ast::FloatLiteral>(Source{}, 1.0f))}; Source{}, create<ast::FloatLiteral>(Source{}, 1.0f))};
} }
@ -2660,7 +2653,7 @@ bool FunctionEmitter::EmitNormalTerminator(const BlockInfo& block_info) {
if (result_type->AsVoid() != nullptr) { if (result_type->AsVoid() != nullptr) {
AddStatement(create<ast::ReturnStatement>(Source{})); AddStatement(create<ast::ReturnStatement>(Source{}));
} else { } else {
auto ast_type = parser_impl_.ConvertType(function_.type_id()); auto* ast_type = parser_impl_.ConvertType(function_.type_id());
AddStatement(create<ast::ReturnStatement>( AddStatement(create<ast::ReturnStatement>(
Source{}, parser_impl_.MakeNullValue(ast_type))); Source{}, parser_impl_.MakeNullValue(ast_type)));
} }
@ -2905,7 +2898,7 @@ bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info,
for (auto id : sorted_by_index(block_info.hoisted_ids)) { for (auto id : sorted_by_index(block_info.hoisted_ids)) {
const auto* def_inst = def_use_mgr_->GetDef(id); const auto* def_inst = def_use_mgr_->GetDef(id);
TINT_ASSERT(def_inst); TINT_ASSERT(def_inst);
auto ast_type = auto* ast_type =
RemapStorageClass(parser_impl_.ConvertType(def_inst->type_id()), id); RemapStorageClass(parser_impl_.ConvertType(def_inst->type_id()), id);
AddStatement(create<ast::VariableDeclStatement>( AddStatement(create<ast::VariableDeclStatement>(
Source{}, Source{},
@ -3109,7 +3102,7 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
case SkipReason::kSampleMaskOutBuiltinPointer: case SkipReason::kSampleMaskOutBuiltinPointer:
ptr_id = sample_mask_out_id; ptr_id = sample_mask_out_id;
if (rhs.type != builder_.ty.u32()) { if (!rhs.type->Is<ast::U32>()) {
// WGSL requires sample_mask_out to be signed. // WGSL requires sample_mask_out to be signed.
rhs = TypedExpression{builder_.ty.u32(), rhs = TypedExpression{builder_.ty.u32(),
create<ast::TypeConstructorExpression>( create<ast::TypeConstructorExpression>(
@ -3164,12 +3157,12 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
auto name = namer_.Name(sample_mask_in_id); auto name = namer_.Name(sample_mask_in_id);
ast::Expression* id_expr = create<ast::IdentifierExpression>( ast::Expression* id_expr = create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register(name)); Source{}, builder_.Symbols().Register(name));
auto load_result_type = parser_impl_.ConvertType(inst.type_id()); auto* load_result_type = parser_impl_.ConvertType(inst.type_id());
ast::Expression* ast_expr = nullptr; ast::Expression* ast_expr = nullptr;
if (load_result_type == builder_.ty.i32()) { if (load_result_type->Is<ast::I32>()) {
ast_expr = create<ast::TypeConstructorExpression>( ast_expr = create<ast::TypeConstructorExpression>(
Source{}, builder_.ty.i32(), ast::ExpressionList{id_expr}); Source{}, builder_.ty.i32(), ast::ExpressionList{id_expr});
} else if (load_result_type == builder_.ty.u32()) { } else if (load_result_type->Is<ast::U32>()) {
ast_expr = id_expr; ast_expr = id_expr;
} else { } else {
return Fail() << "loading the whole SampleMask input array is not " return Fail() << "loading the whole SampleMask input array is not "
@ -3188,8 +3181,8 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
} }
// The load result type is the pointee type of its operand. // The load result type is the pointee type of its operand.
TINT_ASSERT(expr.type.ast->Is<ast::Pointer>()); TINT_ASSERT(expr.type->Is<ast::Pointer>());
expr.type = typ::Call_type(typ::As<typ::Pointer>(expr.type)); expr.type = expr.type->As<ast::Pointer>()->type();
return EmitConstDefOrWriteToHoistedVar(inst, expr); return EmitConstDefOrWriteToHoistedVar(inst, expr);
} }
@ -3204,7 +3197,7 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
return true; return true;
} }
auto expr = MakeExpression(value_id); auto expr = MakeExpression(value_id);
if (!expr.type.ast || !expr.expr) { if (!expr.type || !expr.expr) {
return false; return false;
} }
expr.type = RemapStorageClass(expr.type, result_id); expr.type = RemapStorageClass(expr.type, result_id);
@ -3291,7 +3284,7 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
const auto opcode = inst.opcode(); const auto opcode = inst.opcode();
typ::Type ast_type = ast::Type* ast_type =
inst.type_id() != 0 ? parser_impl_.ConvertType(inst.type_id()) : nullptr; inst.type_id() != 0 ? parser_impl_.ConvertType(inst.type_id()) : nullptr;
auto binary_op = ConvertBinaryOp(opcode); auto binary_op = ConvertBinaryOp(opcode);
@ -3464,7 +3457,7 @@ TypedExpression FunctionEmitter::EmitGlslStd450ExtInst(
auto* func = create<ast::IdentifierExpression>( auto* func = create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register(name)); Source{}, builder_.Symbols().Register(name));
ast::ExpressionList operands; ast::ExpressionList operands;
typ::Type first_operand_type = nullptr; ast::Type* first_operand_type = nullptr;
// All parameters to GLSL.std.450 extended instructions are IDs. // All parameters to GLSL.std.450 extended instructions are IDs.
for (uint32_t iarg = 2; iarg < inst.NumInOperands(); ++iarg) { for (uint32_t iarg = 2; iarg < inst.NumInOperands(); ++iarg) {
TypedExpression operand = MakeOperand(inst, iarg); TypedExpression operand = MakeOperand(inst, iarg);
@ -3473,7 +3466,7 @@ TypedExpression FunctionEmitter::EmitGlslStd450ExtInst(
} }
operands.emplace_back(operand.expr); operands.emplace_back(operand.expr);
} }
auto ast_type = parser_impl_.ConvertType(inst.type_id()); auto* ast_type = parser_impl_.ConvertType(inst.type_id());
auto* call = create<ast::CallExpression>(Source{}, func, std::move(operands)); auto* call = create<ast::CallExpression>(Source{}, func, std::move(operands));
TypedExpression call_expr{ast_type, call}; TypedExpression call_expr{ast_type, call};
return parser_impl_.RectifyForcedResultType(call_expr, inst, return parser_impl_.RectifyForcedResultType(call_expr, inst,
@ -3708,9 +3701,9 @@ TypedExpression FunctionEmitter::MakeAccessChain(
} }
const auto pointer_type_id = const auto pointer_type_id =
type_mgr_->FindPointerToType(pointee_type_id, storage_class); type_mgr_->FindPointerToType(pointee_type_id, storage_class);
auto ast_pointer_type = parser_impl_.ConvertType(pointer_type_id); auto* ast_pointer_type = parser_impl_.ConvertType(pointer_type_id);
TINT_ASSERT(ast_pointer_type.ast); TINT_ASSERT(ast_pointer_type);
TINT_ASSERT(ast_pointer_type.ast->Is<ast::Pointer>()); TINT_ASSERT(ast_pointer_type->Is<ast::Pointer>());
current_expr = TypedExpression{ast_pointer_type, next_expr}; current_expr = TypedExpression{ast_pointer_type, next_expr};
} }
return current_expr; return current_expr;
@ -3894,8 +3887,8 @@ TypedExpression FunctionEmitter::MakeVectorShuffle(
// Generate an ast::TypeConstructor expression. // Generate an ast::TypeConstructor expression.
// Assume the literal indices are valid, and there is a valid number of them. // Assume the literal indices are valid, and there is a valid number of them.
auto source = GetSourceForInst(inst); auto source = GetSourceForInst(inst);
typ::Vector result_type = ast::Vector* result_type =
typ::As<typ::Vector>(parser_impl_.ConvertType(inst.type_id())); parser_impl_.ConvertType(inst.type_id())->As<ast::Vector>();
ast::ExpressionList values; ast::ExpressionList values;
for (uint32_t i = 2; i < inst.NumInOperands(); ++i) { for (uint32_t i = 2; i < inst.NumInOperands(); ++i) {
const auto index = inst.GetSingleWordInOperand(i); const auto index = inst.GetSingleWordInOperand(i);
@ -3917,8 +3910,7 @@ TypedExpression FunctionEmitter::MakeVectorShuffle(
source, expr.expr, Swizzle(sub_index))); source, expr.expr, Swizzle(sub_index)));
} else if (index == 0xFFFFFFFF) { } else if (index == 0xFFFFFFFF) {
// By rule, this maps to OpUndef. Instead, make it zero. // By rule, this maps to OpUndef. Instead, make it zero.
values.emplace_back( values.emplace_back(parser_impl_.MakeNullValue(result_type->type()));
parser_impl_.MakeNullValue(typ::Call_type(result_type)));
} else { } else {
Fail() << "invalid vectorshuffle ID %" << inst.result_id() Fail() << "invalid vectorshuffle ID %" << inst.result_id()
<< ": index too large: " << index; << ": index too large: " << index;
@ -3995,8 +3987,8 @@ bool FunctionEmitter::RegisterLocallyDefinedValues() {
const auto* type = type_mgr_->GetType(inst.type_id()); const auto* type = type_mgr_->GetType(inst.type_id());
if (type) { if (type) {
if (type->AsPointer()) { if (type->AsPointer()) {
if (auto ast_type = parser_impl_.ConvertType(inst.type_id())) { if (auto* ast_type = parser_impl_.ConvertType(inst.type_id())) {
if (auto* ptr = ast_type.ast->As<ast::Pointer>()) { if (auto* ptr = ast_type->As<ast::Pointer>()) {
info->storage_class = ptr->storage_class(); info->storage_class = ptr->storage_class();
} }
} }
@ -4040,22 +4032,22 @@ ast::StorageClass FunctionEmitter::GetStorageClassForPointerValue(uint32_t id) {
} }
const auto type_id = def_use_mgr_->GetDef(id)->type_id(); const auto type_id = def_use_mgr_->GetDef(id)->type_id();
if (type_id) { if (type_id) {
auto ast_type = parser_impl_.ConvertType(type_id); auto* ast_type = parser_impl_.ConvertType(type_id);
if (auto ptr = typ::As<typ::Pointer>(ast_type)) { if (auto* ptr = ast_type->As<ast::Pointer>()) {
return ptr.ast->storage_class(); return ptr->storage_class();
} }
} }
return ast::StorageClass::kNone; return ast::StorageClass::kNone;
} }
typ::Type FunctionEmitter::RemapStorageClass(typ::Type type, ast::Type* FunctionEmitter::RemapStorageClass(ast::Type* type,
uint32_t result_id) { uint32_t result_id) {
if (auto ast_ptr_type = typ::As<typ::Pointer>(type)) { if (auto* ast_ptr_type = type->As<ast::Pointer>()) {
// Remap an old-style storage buffer pointer to a new-style storage // Remap an old-style storage buffer pointer to a new-style storage
// buffer pointer. // buffer pointer.
const auto sc = GetStorageClassForPointerValue(result_id); const auto sc = GetStorageClassForPointerValue(result_id);
if (ast_ptr_type.ast->storage_class() != sc) { if (ast_ptr_type->storage_class() != sc) {
return builder_.ty.pointer(typ::Call_type(ast_ptr_type), sc); return builder_.ty.pointer(ast_ptr_type->type(), sc);
} }
} }
return type; return type;
@ -4232,13 +4224,13 @@ const Construct* FunctionEmitter::GetEnclosingScope(uint32_t first_pos,
TypedExpression FunctionEmitter::MakeNumericConversion( TypedExpression FunctionEmitter::MakeNumericConversion(
const spvtools::opt::Instruction& inst) { const spvtools::opt::Instruction& inst) {
const auto opcode = inst.opcode(); const auto opcode = inst.opcode();
auto requested_type = parser_impl_.ConvertType(inst.type_id()); auto* requested_type = parser_impl_.ConvertType(inst.type_id());
auto arg_expr = MakeOperand(inst, 0); auto arg_expr = MakeOperand(inst, 0);
if (!arg_expr.expr || !arg_expr.type) { if (!arg_expr.expr || !arg_expr.type) {
return {}; return {};
} }
typ::Type expr_type = nullptr; ast::Type* expr_type = nullptr;
if ((opcode == SpvOpConvertSToF) || (opcode == SpvOpConvertUToF)) { if ((opcode == SpvOpConvertSToF) || (opcode == SpvOpConvertUToF)) {
if (arg_expr.type->is_integer_scalar_or_vector()) { if (arg_expr.type->is_integer_scalar_or_vector()) {
expr_type = requested_type; expr_type = requested_type;
@ -4276,7 +4268,7 @@ TypedExpression FunctionEmitter::MakeNumericConversion(
Source{}, builder_.ty.MaybeCreateTypename(expr_type), Source{}, builder_.ty.MaybeCreateTypename(expr_type),
std::move(params))}; std::move(params))};
if (requested_type == expr_type) { if (AstTypesEquivalent(requested_type, expr_type)) {
return result; return result;
} }
return {requested_type, create<ast::BitcastExpression>( return {requested_type, create<ast::BitcastExpression>(
@ -4298,13 +4290,13 @@ bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) {
} }
auto* call_expr = auto* call_expr =
create<ast::CallExpression>(Source{}, function, std::move(params)); create<ast::CallExpression>(Source{}, function, std::move(params));
auto result_type = parser_impl_.ConvertType(inst.type_id()); auto* result_type = parser_impl_.ConvertType(inst.type_id());
if (!result_type.ast) { if (!result_type) {
return Fail() << "internal error: no mapped type result of call: " return Fail() << "internal error: no mapped type result of call: "
<< inst.PrettyPrint(); << inst.PrettyPrint();
} }
if (result_type.ast->Is<ast::Void>()) { if (result_type->Is<ast::Void>()) {
return nullptr != return nullptr !=
AddStatement(create<ast::CallStatement>(Source{}, call_expr)); AddStatement(create<ast::CallStatement>(Source{}, call_expr));
} }
@ -4367,7 +4359,7 @@ TypedExpression FunctionEmitter::MakeIntrinsicCall(
Source{}, builder_.Symbols().Register(name)); Source{}, builder_.Symbols().Register(name));
ast::ExpressionList params; ast::ExpressionList params;
typ::Type first_operand_type = nullptr; ast::Type* first_operand_type = nullptr;
for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) { for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) {
TypedExpression operand = MakeOperand(inst, iarg); TypedExpression operand = MakeOperand(inst, iarg);
if (first_operand_type == nullptr) { if (first_operand_type == nullptr) {
@ -4377,8 +4369,8 @@ TypedExpression FunctionEmitter::MakeIntrinsicCall(
} }
auto* call_expr = auto* call_expr =
create<ast::CallExpression>(Source{}, ident, std::move(params)); create<ast::CallExpression>(Source{}, ident, std::move(params));
auto result_type = parser_impl_.ConvertType(inst.type_id()); auto* result_type = parser_impl_.ConvertType(inst.type_id());
if (!result_type.ast) { if (!result_type) {
Fail() << "internal error: no mapped type result of call: " Fail() << "internal error: no mapped type result of call: "
<< inst.PrettyPrint(); << inst.PrettyPrint();
return {}; return {};
@ -4398,7 +4390,7 @@ TypedExpression FunctionEmitter::MakeSimpleSelect(
// - operand1, operand2, and result type to match. // - operand1, operand2, and result type to match.
// - you can't select over pointers or pointer vectors, unless you also have // - you can't select over pointers or pointer vectors, unless you also have
// a VariablePointers* capability, which is not allowed in by WebGPU. // a VariablePointers* capability, which is not allowed in by WebGPU.
auto* op_ty = operand1.type.ast; auto* op_ty = operand1.type;
if (op_ty->Is<ast::Vector>() || op_ty->is_float_scalar() || if (op_ty->Is<ast::Vector>() || op_ty->is_float_scalar() ||
op_ty->is_integer_scalar() || op_ty->Is<ast::Bool>()) { op_ty->is_integer_scalar() || op_ty->Is<ast::Bool>()) {
ast::ExpressionList params; ast::ExpressionList params;
@ -4438,9 +4430,9 @@ const spvtools::opt::Instruction* FunctionEmitter::GetImage(
return image; return image;
} }
typ::Texture FunctionEmitter::GetImageType( ast::Texture* FunctionEmitter::GetImageType(
const spvtools::opt::Instruction& image) { const spvtools::opt::Instruction& image) {
typ::Pointer ptr_type = parser_impl_.GetTypeForHandleVar(image); ast::Pointer* ptr_type = parser_impl_.GetTypeForHandleVar(image);
if (!parser_impl_.success()) { if (!parser_impl_.success()) {
Fail(); Fail();
return {}; return {};
@ -4449,7 +4441,7 @@ typ::Texture FunctionEmitter::GetImageType(
Fail() << "invalid texture type for " << image.PrettyPrint(); Fail() << "invalid texture type for " << image.PrettyPrint();
return {}; return {};
} }
auto result = typ::As<typ::Texture>(UnwrapAll(typ::Call_type(ptr_type))); auto* result = ptr_type->type()->UnwrapAll()->As<ast::Texture>();
if (!result) { if (!result) {
Fail() << "invalid texture type for " << image.PrettyPrint(); Fail() << "invalid texture type for " << image.PrettyPrint();
return {}; return {};
@ -4504,14 +4496,14 @@ bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) {
} }
} }
typ::Pointer texture_ptr_type = parser_impl_.GetTypeForHandleVar(*image); ast::Pointer* texture_ptr_type = parser_impl_.GetTypeForHandleVar(*image);
if (!texture_ptr_type.ast) { if (!texture_ptr_type) {
return Fail(); return Fail();
} }
typ::Texture texture_type = ast::Texture* texture_type =
typ::As<typ::Texture>(UnwrapAll(typ::Call_type(texture_ptr_type))); texture_ptr_type->type()->UnwrapAll()->As<ast::Texture>();
if (!texture_type.ast) { if (!texture_type) {
return Fail(); return Fail();
} }
@ -4612,7 +4604,7 @@ bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) {
} }
TypedExpression lod = MakeOperand(inst, arg_index); TypedExpression lod = MakeOperand(inst, arg_index);
// When sampling from a depth texture, the Lod operand must be an I32. // When sampling from a depth texture, the Lod operand must be an I32.
if (texture_type.ast->Is<ast::DepthTexture>()) { if (texture_type->Is<ast::DepthTexture>()) {
// Convert it to a signed integer type. // Convert it to a signed integer type.
lod = ToI32(lod); lod = ToI32(lod);
} }
@ -4620,8 +4612,8 @@ bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) {
image_operands_mask ^= SpvImageOperandsLodMask; image_operands_mask ^= SpvImageOperandsLodMask;
arg_index++; arg_index++;
} else if ((opcode == SpvOpImageFetch) && } else if ((opcode == SpvOpImageFetch) &&
(texture_type.ast->Is<ast::SampledTexture>() || (texture_type->Is<ast::SampledTexture>() ||
texture_type.ast->Is<ast::DepthTexture>())) { texture_type->Is<ast::DepthTexture>())) {
// textureLoad on sampled texture and depth texture requires an explicit // textureLoad on sampled texture and depth texture requires an explicit
// level-of-detail parameter. // level-of-detail parameter.
params.push_back(parser_impl_.MakeNullValue(builder_.ty.i32())); params.push_back(parser_impl_.MakeNullValue(builder_.ty.i32()));
@ -4682,10 +4674,10 @@ bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) {
ast::Expression* value = call_expr; ast::Expression* value = call_expr;
// The result type, derived from the SPIR-V instruction. // The result type, derived from the SPIR-V instruction.
auto result_type = parser_impl_.ConvertType(inst.type_id()); auto* result_type = parser_impl_.ConvertType(inst.type_id());
auto result_component_type = result_type; auto* result_component_type = result_type;
if (auto result_vector_type = typ::As<typ::Vector>(result_type)) { if (auto* result_vector_type = result_type->As<ast::Vector>()) {
result_component_type = typ::Call_type(result_vector_type); result_component_type = result_vector_type->type();
} }
// For depth textures, the arity might mot match WGSL: // For depth textures, the arity might mot match WGSL:
@ -4699,7 +4691,7 @@ bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) {
// dref gather vec4 ImageFetch vec4 TODO(dneto) // dref gather vec4 ImageFetch vec4 TODO(dneto)
// Construct a 4-element vector with the result from the builtin in the // Construct a 4-element vector with the result from the builtin in the
// first component. // first component.
if (texture_type.ast->Is<ast::DepthTexture>()) { if (texture_type->Is<ast::DepthTexture>()) {
if (is_non_dref_sample || (opcode == SpvOpImageFetch)) { if (is_non_dref_sample || (opcode == SpvOpImageFetch)) {
value = create<ast::TypeConstructorExpression>( value = create<ast::TypeConstructorExpression>(
Source{}, Source{},
@ -4720,14 +4712,14 @@ bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) {
return Fail() << "invalid image type for image memory object declaration " return Fail() << "invalid image type for image memory object declaration "
<< image->PrettyPrint(); << image->PrettyPrint();
} }
auto expected_component_type = auto* expected_component_type =
parser_impl_.ConvertType(spirv_image_type->GetSingleWordInOperand(0)); parser_impl_.ConvertType(spirv_image_type->GetSingleWordInOperand(0));
if (expected_component_type != result_component_type) { if (!AstTypesEquivalent(expected_component_type, result_component_type)) {
// This occurs if one is signed integer and the other is unsigned integer, // This occurs if one is signed integer and the other is unsigned integer,
// or vice versa. Perform a bitcast. // or vice versa. Perform a bitcast.
value = create<ast::BitcastExpression>(Source{}, result_type, call_expr); value = create<ast::BitcastExpression>(Source{}, result_type, call_expr);
} }
if (!expected_component_type.ast->Is<ast::F32>() && if (!expected_component_type->Is<ast::F32>() &&
IsSampledImageAccess(opcode)) { IsSampledImageAccess(opcode)) {
// WGSL permits sampled image access only on float textures. // WGSL permits sampled image access only on float textures.
// Reject this case in the SPIR-V reader, at least until SPIR-V validation // Reject this case in the SPIR-V reader, at least until SPIR-V validation
@ -4750,7 +4742,7 @@ bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) {
if (!image) { if (!image) {
return false; return false;
} }
auto texture_type = GetImageType(*image); auto* texture_type = GetImageType(*image);
if (!texture_type) { if (!texture_type) {
return false; return false;
} }
@ -4778,7 +4770,7 @@ bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) {
Source{}, layers_ident, Source{}, layers_ident,
ast::ExpressionList{GetImageExpression(inst)})); ast::ExpressionList{GetImageExpression(inst)}));
} }
auto result_type = parser_impl_.ConvertType(inst.type_id()); auto* result_type = parser_impl_.ConvertType(inst.type_id());
TypedExpression expr = { TypedExpression expr = {
result_type, result_type,
create<ast::TypeConstructorExpression>( create<ast::TypeConstructorExpression>(
@ -4799,10 +4791,10 @@ bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) {
ast::Expression* ast_expr = create<ast::CallExpression>( ast::Expression* ast_expr = create<ast::CallExpression>(
Source{}, levels_ident, Source{}, levels_ident,
ast::ExpressionList{GetImageExpression(inst)}); ast::ExpressionList{GetImageExpression(inst)});
auto result_type = parser_impl_.ConvertType(inst.type_id()); auto* result_type = parser_impl_.ConvertType(inst.type_id());
// The SPIR-V result type must be integer scalar. The WGSL bulitin // The SPIR-V result type must be integer scalar. The WGSL bulitin
// returns i32. If they aren't the same then convert the result. // returns i32. If they aren't the same then convert the result.
if (result_type != builder_.ty.i32()) { if (!result_type->Is<ast::I32>()) {
ast_expr = create<ast::TypeConstructorExpression>( ast_expr = create<ast::TypeConstructorExpression>(
Source{}, builder_.ty.MaybeCreateTypename(result_type), Source{}, builder_.ty.MaybeCreateTypename(result_type),
ast::ExpressionList{ast_expr}); ast::ExpressionList{ast_expr});
@ -4848,7 +4840,7 @@ ast::ExpressionList FunctionEmitter::MakeCoordinateOperandsForImageAccess(
if (!raw_coords.type) { if (!raw_coords.type) {
return {}; return {};
} }
typ::Texture texture_type = GetImageType(*image); ast::Texture* texture_type = GetImageType(*image);
if (!texture_type) { if (!texture_type) {
return {}; return {};
} }
@ -4863,12 +4855,12 @@ ast::ExpressionList FunctionEmitter::MakeCoordinateOperandsForImageAccess(
} }
const auto num_coords_required = num_axes + (is_arrayed ? 1 : 0); const auto num_coords_required = num_axes + (is_arrayed ? 1 : 0);
uint32_t num_coords_supplied = 0; uint32_t num_coords_supplied = 0;
auto component_type = raw_coords.type; auto* component_type = raw_coords.type;
if (component_type->is_float_scalar() || if (component_type->is_float_scalar() ||
component_type->is_integer_scalar()) { component_type->is_integer_scalar()) {
num_coords_supplied = 1; num_coords_supplied = 1;
} else if (auto vec_type = typ::As<typ::Vector>(raw_coords.type)) { } else if (auto* vec_type = raw_coords.type->As<ast::Vector>()) {
component_type = typ::Call_type(vec_type); component_type = vec_type->type();
num_coords_supplied = vec_type->size(); num_coords_supplied = vec_type->size();
} }
if (num_coords_supplied == 0) { if (num_coords_supplied == 0) {
@ -4892,9 +4884,10 @@ ast::ExpressionList FunctionEmitter::MakeCoordinateOperandsForImageAccess(
// will actually use them. // will actually use them.
auto prefix_swizzle_expr = [this, num_axes, component_type, auto prefix_swizzle_expr = [this, num_axes, component_type,
raw_coords]() -> ast::Expression* { raw_coords]() -> ast::Expression* {
auto swizzle_type = auto* swizzle_type = (num_axes == 1)
(num_axes == 1) ? component_type ? component_type
: typ::Type{builder_.ty.vec(component_type, num_axes)}; : static_cast<ast::Type*>(
builder_.ty.vec(component_type, num_axes));
auto* swizzle = create<ast::MemberAccessorExpression>( auto* swizzle = create<ast::MemberAccessorExpression>(
Source{}, raw_coords.expr, PrefixSwizzle(num_axes)); Source{}, raw_coords.expr, PrefixSwizzle(num_axes));
return ToSignedIfUnsigned({swizzle_type, swizzle}).expr; return ToSignedIfUnsigned({swizzle_type, swizzle}).expr;
@ -4928,32 +4921,32 @@ ast::ExpressionList FunctionEmitter::MakeCoordinateOperandsForImageAccess(
ast::Expression* FunctionEmitter::ConvertTexelForStorage( ast::Expression* FunctionEmitter::ConvertTexelForStorage(
const spvtools::opt::Instruction& inst, const spvtools::opt::Instruction& inst,
TypedExpression texel, TypedExpression texel,
typ::Texture texture_type) { ast::Texture* texture_type) {
auto storage_texture_type = typ::As<typ::StorageTexture>(texture_type); auto* storage_texture_type = texture_type->As<ast::StorageTexture>();
auto src_type = texel.type; auto* src_type = texel.type;
if (!storage_texture_type.ast) { if (!storage_texture_type) {
Fail() << "writing to other than storage texture: " << inst.PrettyPrint(); Fail() << "writing to other than storage texture: " << inst.PrettyPrint();
return nullptr; return nullptr;
} }
const auto format = storage_texture_type.ast->image_format(); const auto format = storage_texture_type->image_format();
auto dest_type = parser_impl_.GetTexelTypeForFormat(format); auto* dest_type = parser_impl_.GetTexelTypeForFormat(format);
if (!dest_type.ast) { if (!dest_type) {
Fail(); Fail();
return nullptr; return nullptr;
} }
if (src_type == dest_type) { if (AstTypesEquivalent(src_type, dest_type)) {
return texel.expr; return texel.expr;
} }
const uint32_t dest_count = const uint32_t dest_count =
dest_type.ast->is_scalar() ? 1 : dest_type.ast->As<ast::Vector>()->size(); dest_type->is_scalar() ? 1 : dest_type->As<ast::Vector>()->size();
if (dest_count == 3) { if (dest_count == 3) {
Fail() << "3-channel storage textures are not supported: " Fail() << "3-channel storage textures are not supported: "
<< inst.PrettyPrint(); << inst.PrettyPrint();
return nullptr; return nullptr;
} }
const uint32_t src_count = const uint32_t src_count =
src_type.ast->is_scalar() ? 1 : src_type.ast->As<ast::Vector>()->size(); src_type->is_scalar() ? 1 : src_type->As<ast::Vector>()->size();
if (src_count < dest_count) { if (src_count < dest_count) {
Fail() << "texel has too few components for storage texture: " << src_count Fail() << "texel has too few components for storage texture: " << src_count
<< " provided but " << dest_count << " provided but " << dest_count
@ -4968,29 +4961,29 @@ ast::Expression* FunctionEmitter::ConvertTexelForStorage(
: create<ast::MemberAccessorExpression>(Source{}, texel.expr, : create<ast::MemberAccessorExpression>(Source{}, texel.expr,
PrefixSwizzle(dest_count)); PrefixSwizzle(dest_count));
if (!(dest_type.ast->is_float_scalar_or_vector() || if (!(dest_type->is_float_scalar_or_vector() ||
dest_type.ast->is_unsigned_scalar_or_vector() || dest_type->is_unsigned_scalar_or_vector() ||
dest_type.ast->is_signed_scalar_or_vector())) { dest_type->is_signed_scalar_or_vector())) {
Fail() << "invalid destination type for storage texture write: " Fail() << "invalid destination type for storage texture write: "
<< dest_type.ast->type_name(); << dest_type->type_name();
return nullptr; return nullptr;
} }
if (!(src_type.ast->is_float_scalar_or_vector() || if (!(src_type->is_float_scalar_or_vector() ||
src_type.ast->is_unsigned_scalar_or_vector() || src_type->is_unsigned_scalar_or_vector() ||
src_type.ast->is_signed_scalar_or_vector())) { src_type->is_signed_scalar_or_vector())) {
Fail() << "invalid texel type for storage texture write: " Fail() << "invalid texel type for storage texture write: "
<< inst.PrettyPrint(); << inst.PrettyPrint();
return nullptr; return nullptr;
} }
if (dest_type.ast->is_float_scalar_or_vector() && if (dest_type->is_float_scalar_or_vector() &&
!src_type.ast->is_float_scalar_or_vector()) { !src_type->is_float_scalar_or_vector()) {
Fail() << "can only write float or float vector to a storage image with " Fail() << "can only write float or float vector to a storage image with "
"floating texel format: " "floating texel format: "
<< inst.PrettyPrint(); << inst.PrettyPrint();
return nullptr; return nullptr;
} }
if (!dest_type.ast->is_float_scalar_or_vector() && if (!dest_type->is_float_scalar_or_vector() &&
src_type.ast->is_float_scalar_or_vector()) { src_type->is_float_scalar_or_vector()) {
Fail() Fail()
<< "float or float vector can only be written to a storage image with " << "float or float vector can only be written to a storage image with "
"floating texel format: " "floating texel format: "
@ -4998,13 +4991,13 @@ ast::Expression* FunctionEmitter::ConvertTexelForStorage(
return nullptr; return nullptr;
} }
if (dest_type.ast->is_float_scalar_or_vector()) { if (dest_type->is_float_scalar_or_vector()) {
return texel_prefix; return texel_prefix;
} }
// The only remaining cases are signed/unsigned source, and signed/unsigned // The only remaining cases are signed/unsigned source, and signed/unsigned
// destination. // destination.
if (dest_type.ast->is_unsigned_scalar_or_vector() == if (dest_type->is_unsigned_scalar_or_vector() ==
src_type.ast->is_unsigned_scalar_or_vector()) { src_type->is_unsigned_scalar_or_vector()) {
return texel_prefix; return texel_prefix;
} }
// We must do a bitcast conversion. // We must do a bitcast conversion.
@ -5012,7 +5005,7 @@ ast::Expression* FunctionEmitter::ConvertTexelForStorage(
} }
TypedExpression FunctionEmitter::ToI32(TypedExpression value) { TypedExpression FunctionEmitter::ToI32(TypedExpression value) {
if (!value.type || value.type == builder_.ty.i32()) { if (!value.type || value.type->Is<ast::I32>()) {
return value; return value;
} }
return {builder_.ty.i32(), return {builder_.ty.i32(),
@ -5024,7 +5017,7 @@ TypedExpression FunctionEmitter::ToSignedIfUnsigned(TypedExpression value) {
if (!value.type || !value.type->is_unsigned_scalar_or_vector()) { if (!value.type || !value.type->is_unsigned_scalar_or_vector()) {
return value; return value;
} }
if (auto* vec_type = value.type.ast->As<ast::Vector>()) { if (auto* vec_type = value.type->As<ast::Vector>()) {
auto new_type = builder_.ty.vec(builder_.ty.i32(), vec_type->size()); auto new_type = builder_.ty.vec(builder_.ty.i32(), vec_type->size());
return {new_type, return {new_type,
builder_.Construct(new_type, ast::ExpressionList{value.expr})}; builder_.Construct(new_type, ast::ExpressionList{value.expr})};
@ -5080,12 +5073,12 @@ TypedExpression FunctionEmitter::MakeOuterProduct(
// Synthesize the result. // Synthesize the result.
auto col = MakeOperand(inst, 0); auto col = MakeOperand(inst, 0);
auto row = MakeOperand(inst, 1); auto row = MakeOperand(inst, 1);
auto col_ty = typ::As<typ::Vector>(col.type); auto* col_ty = col.type->As<ast::Vector>();
auto row_ty = typ::As<typ::Vector>(row.type); auto* row_ty = row.type->As<ast::Vector>();
auto result_ty = auto* result_ty = parser_impl_.ConvertType(inst.type_id())->As<ast::Matrix>();
typ::As<typ::Matrix>(parser_impl_.ConvertType(inst.type_id())); if (!col_ty || !col_ty || !result_ty ||
if (!col_ty || !col_ty || !result_ty || result_ty->type() != col_ty->type() || !AstTypesEquivalent(result_ty->type(), col_ty->type()) ||
result_ty->type() != row_ty->type() || !AstTypesEquivalent(result_ty->type(), row_ty->type()) ||
result_ty->columns() != row_ty->size() || result_ty->columns() != row_ty->size() ||
result_ty->rows() != col_ty->size()) { result_ty->rows() != col_ty->size()) {
Fail() << "invalid outer product instruction: bad types " Fail() << "invalid outer product instruction: bad types "
@ -5135,7 +5128,7 @@ bool FunctionEmitter::MakeVectorInsertDynamic(
// Then use result everywhere the original SPIR-V id is used. Using a const // Then use result everywhere the original SPIR-V id is used. Using a const
// like this avoids constantly reloading the value many times. // like this avoids constantly reloading the value many times.
auto ast_type = parser_impl_.ConvertType(inst.type_id()); auto* ast_type = parser_impl_.ConvertType(inst.type_id());
auto src_vector = MakeOperand(inst, 0); auto src_vector = MakeOperand(inst, 0);
auto component = MakeOperand(inst, 1); auto component = MakeOperand(inst, 1);
auto index = MakeOperand(inst, 2); auto index = MakeOperand(inst, 2);
@ -5183,7 +5176,7 @@ bool FunctionEmitter::MakeCompositeInsert(
// - building up an access-chain like access like for CompositeExtract, but // - building up an access-chain like access like for CompositeExtract, but
// on the left-hand side of the assignment. // on the left-hand side of the assignment.
auto ast_type = parser_impl_.ConvertType(inst.type_id()); auto* ast_type = parser_impl_.ConvertType(inst.type_id());
auto component = MakeOperand(inst, 0); auto component = MakeOperand(inst, 0);
auto src_composite = MakeOperand(inst, 1); auto src_composite = MakeOperand(inst, 1);

View File

@ -25,7 +25,6 @@
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/reader/spirv/construct.h" #include "src/reader/spirv/construct.h"
#include "src/reader/spirv/parser_impl.h" #include "src/reader/spirv/parser_impl.h"
#include "src/typepair.h"
namespace tint { namespace tint {
namespace reader { namespace reader {
@ -516,7 +515,7 @@ class FunctionEmitter {
/// @param type the AST type /// @param type the AST type
/// @param result_id the SPIR-V ID for the locally defined value /// @param result_id the SPIR-V ID for the locally defined value
/// @returns an possibly updated type /// @returns an possibly updated type
typ::Type RemapStorageClass(typ::Type type, uint32_t result_id); ast::Type* RemapStorageClass(ast::Type* type, uint32_t result_id);
/// Marks locally defined values when they should get a 'const' /// Marks locally defined values when they should get a 'const'
/// definition in WGSL, or a 'var' definition at an outer scope. /// definition in WGSL, or a 'var' definition at an outer scope.
@ -857,7 +856,7 @@ class FunctionEmitter {
/// Function parameters /// Function parameters
ast::VariableList params; ast::VariableList params;
/// Function return type /// Function return type
typ::Type return_type; ast::Type* return_type;
/// Function decorations /// Function decorations
ast::DecorationList decorations; ast::DecorationList decorations;
}; };
@ -870,7 +869,7 @@ class FunctionEmitter {
/// @returns the store type for the OpVariable instruction, or /// @returns the store type for the OpVariable instruction, or
/// null on failure. /// null on failure.
typ::Type GetVariableStoreType( ast::Type* GetVariableStoreType(
const spvtools::opt::Instruction& var_decl_inst); const spvtools::opt::Instruction& var_decl_inst);
/// Returns an expression for an instruction operand. Signedness conversion is /// Returns an expression for an instruction operand. Signedness conversion is
@ -938,7 +937,7 @@ class FunctionEmitter {
/// Get the AST texture the SPIR-V image memory object declaration. /// Get the AST texture the SPIR-V image memory object declaration.
/// @param inst the SPIR-V memory object declaration for the image. /// @param inst the SPIR-V memory object declaration for the image.
/// @returns a texture type, or null on error /// @returns a texture type, or null on error
typ::Texture GetImageType(const spvtools::opt::Instruction& inst); ast::Texture* GetImageType(const spvtools::opt::Instruction& inst);
/// Get the expression for the image operand from the first operand to the /// Get the expression for the image operand from the first operand to the
/// given instruction. /// given instruction.
@ -975,7 +974,7 @@ class FunctionEmitter {
ast::Expression* ConvertTexelForStorage( ast::Expression* ConvertTexelForStorage(
const spvtools::opt::Instruction& inst, const spvtools::opt::Instruction& inst,
TypedExpression texel, TypedExpression texel,
typ::Texture texture_type); ast::Texture* texture_type);
/// Returns an expression for an OpSelect, if its operands are scalars /// Returns an expression for an OpSelect, if its operands are scalars
/// or vectors. These translate directly to WGSL select. Otherwise, return /// or vectors. These translate directly to WGSL select. Otherwise, return

View File

@ -232,14 +232,6 @@ bool AssumesResultSignednessMatchesFirstOperand(GLSLstd450 extended_opcode) {
return false; return false;
} }
// Forwards UnwrapIfNeeded to both the ast and sem types of the TypePair
// @param tp the type pair
// @returns the unwrapped type pair
typ::Type UnwrapIfNeeded(typ::Type tp) {
return typ::Type{tp.ast ? tp.ast->UnwrapIfNeeded() : nullptr,
tp.sem ? tp.sem->UnwrapIfNeeded() : nullptr};
}
} // namespace } // namespace
TypedExpression::TypedExpression() = default; TypedExpression::TypedExpression() = default;
@ -248,7 +240,7 @@ TypedExpression::TypedExpression(const TypedExpression&) = default;
TypedExpression& TypedExpression::operator=(const TypedExpression&) = default; TypedExpression& TypedExpression::operator=(const TypedExpression&) = default;
TypedExpression::TypedExpression(typ::Type type_in, ast::Expression* expr_in) TypedExpression::TypedExpression(ast::Type* type_in, ast::Expression* expr_in)
: type(type_in), expr(expr_in) {} : type(type_in), expr(expr_in) {}
ParserImpl::ParserImpl(const std::vector<uint32_t>& spv_binary) ParserImpl::ParserImpl(const std::vector<uint32_t>& spv_binary)
@ -313,7 +305,7 @@ Program ParserImpl::program() {
return tint::Program(std::move(builder_)); return tint::Program(std::move(builder_));
} }
typ::Type ParserImpl::ConvertType(uint32_t type_id) { ast::Type* ParserImpl::ConvertType(uint32_t type_id) {
if (!success_) { if (!success_) {
return nullptr; return nullptr;
} }
@ -330,8 +322,8 @@ typ::Type ParserImpl::ConvertType(uint32_t type_id) {
} }
auto maybe_generate_alias = [this, type_id, auto maybe_generate_alias = [this, type_id,
spirv_type](typ::Type type) -> typ::Type { spirv_type](ast::Type* type) -> ast::Type* {
if (type.ast != nullptr) { if (type != nullptr) {
return MaybeGenerateAlias(type_id, spirv_type, type); return MaybeGenerateAlias(type_id, spirv_type, type);
} }
return {}; return {};
@ -782,17 +774,17 @@ bool ParserImpl::RegisterEntryPoints() {
return success_; return success_;
} }
typ::Type ParserImpl::ConvertType( ast::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Integer* int_ty) { const spvtools::opt::analysis::Integer* int_ty) {
if (int_ty->width() == 32) { if (int_ty->width() == 32) {
return int_ty->IsSigned() ? typ::Type{builder_.ty.i32()} return int_ty->IsSigned() ? static_cast<ast::Type*>(builder_.ty.i32())
: typ::Type{builder_.ty.u32()}; : static_cast<ast::Type*>(builder_.ty.u32());
} }
Fail() << "unhandled integer width: " << int_ty->width(); Fail() << "unhandled integer width: " << int_ty->width();
return nullptr; return nullptr;
} }
typ::Type ParserImpl::ConvertType( ast::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Float* float_ty) { const spvtools::opt::analysis::Float* float_ty) {
if (float_ty->width() == 32) { if (float_ty->width() == 32) {
return builder_.ty.f32(); return builder_.ty.f32();
@ -801,33 +793,33 @@ typ::Type ParserImpl::ConvertType(
return nullptr; return nullptr;
} }
typ::Type ParserImpl::ConvertType( ast::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Vector* vec_ty) { const spvtools::opt::analysis::Vector* vec_ty) {
const auto num_elem = vec_ty->element_count(); const auto num_elem = vec_ty->element_count();
auto ast_elem_ty = ConvertType(type_mgr_->GetId(vec_ty->element_type())); auto* ast_elem_ty = ConvertType(type_mgr_->GetId(vec_ty->element_type()));
if (ast_elem_ty.ast == nullptr) { if (ast_elem_ty == nullptr) {
return nullptr; return nullptr;
} }
return builder_.ty.vec(ast_elem_ty, num_elem); return builder_.ty.vec(ast_elem_ty, num_elem);
} }
typ::Type ParserImpl::ConvertType( ast::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Matrix* mat_ty) { const spvtools::opt::analysis::Matrix* mat_ty) {
const auto* vec_ty = mat_ty->element_type()->AsVector(); const auto* vec_ty = mat_ty->element_type()->AsVector();
const auto* scalar_ty = vec_ty->element_type(); const auto* scalar_ty = vec_ty->element_type();
const auto num_rows = vec_ty->element_count(); const auto num_rows = vec_ty->element_count();
const auto num_columns = mat_ty->element_count(); const auto num_columns = mat_ty->element_count();
auto ast_scalar_ty = ConvertType(type_mgr_->GetId(scalar_ty)); auto* ast_scalar_ty = ConvertType(type_mgr_->GetId(scalar_ty));
if (ast_scalar_ty.ast == nullptr) { if (ast_scalar_ty == nullptr) {
return nullptr; return nullptr;
} }
return builder_.ty.mat(ast_scalar_ty, num_columns, num_rows); return builder_.ty.mat(ast_scalar_ty, num_columns, num_rows);
} }
typ::Type ParserImpl::ConvertType( ast::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::RuntimeArray* rtarr_ty) { const spvtools::opt::analysis::RuntimeArray* rtarr_ty) {
auto ast_elem_ty = ConvertType(type_mgr_->GetId(rtarr_ty->element_type())); auto* ast_elem_ty = ConvertType(type_mgr_->GetId(rtarr_ty->element_type()));
if (ast_elem_ty.ast == nullptr) { if (ast_elem_ty == nullptr) {
return nullptr; return nullptr;
} }
ast::DecorationList decorations; ast::DecorationList decorations;
@ -837,11 +829,11 @@ typ::Type ParserImpl::ConvertType(
return builder_.ty.array(ast_elem_ty, 0, std::move(decorations)); return builder_.ty.array(ast_elem_ty, 0, std::move(decorations));
} }
typ::Type ParserImpl::ConvertType( ast::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Array* arr_ty) { const spvtools::opt::analysis::Array* arr_ty) {
const auto elem_type_id = type_mgr_->GetId(arr_ty->element_type()); const auto elem_type_id = type_mgr_->GetId(arr_ty->element_type());
auto ast_elem_ty = ConvertType(elem_type_id); auto* ast_elem_ty = ConvertType(elem_type_id);
if (ast_elem_ty.ast == nullptr) { if (ast_elem_ty == nullptr) {
return nullptr; return nullptr;
} }
const auto& length_info = arr_ty->length_info(); const auto& length_info = arr_ty->length_info();
@ -912,7 +904,7 @@ bool ParserImpl::ParseArrayDecorations(
return true; return true;
} }
typ::Type ParserImpl::ConvertType( ast::Type* ParserImpl::ConvertType(
uint32_t type_id, uint32_t type_id,
const spvtools::opt::analysis::Struct* struct_ty) { const spvtools::opt::analysis::Struct* struct_ty) {
// Compute the struct decoration. // Compute the struct decoration.
@ -942,8 +934,8 @@ typ::Type ParserImpl::ConvertType(
for (uint32_t member_index = 0; member_index < members.size(); for (uint32_t member_index = 0; member_index < members.size();
++member_index) { ++member_index) {
const auto member_type_id = type_mgr_->GetId(members[member_index]); const auto member_type_id = type_mgr_->GetId(members[member_index]);
auto ast_member_ty = ConvertType(member_type_id); auto* ast_member_ty = ConvertType(member_type_id);
if (ast_member_ty.ast == nullptr) { if (ast_member_ty == nullptr) {
// Already emitted diagnostics. // Already emitted diagnostics.
return nullptr; return nullptr;
} }
@ -1034,12 +1026,11 @@ typ::Type ParserImpl::ConvertType(
} }
auto* ast_struct = create<ast::Struct>(Source{}, sym, std::move(ast_members), auto* ast_struct = create<ast::Struct>(Source{}, sym, std::move(ast_members),
std::move(ast_struct_decorations)); std::move(ast_struct_decorations));
auto result = builder_.ty.struct_(ast_struct);
if (num_non_writable_members == members.size()) { if (num_non_writable_members == members.size()) {
read_only_struct_types_.insert(result.ast->name()); read_only_struct_types_.insert(ast_struct->name());
} }
AddConstructedType(sym, result); AddConstructedType(sym, ast_struct);
return result; return ast_struct;
} }
void ParserImpl::AddConstructedType(Symbol name, ast::NamedType* type) { void ParserImpl::AddConstructedType(Symbol name, ast::NamedType* type) {
@ -1049,7 +1040,7 @@ void ParserImpl::AddConstructedType(Symbol name, ast::NamedType* type) {
} }
} }
typ::Type ParserImpl::ConvertType(uint32_t type_id, ast::Type* ParserImpl::ConvertType(uint32_t type_id,
const spvtools::opt::analysis::Pointer*) { const spvtools::opt::analysis::Pointer*) {
const auto* inst = def_use_mgr_->GetDef(type_id); const auto* inst = def_use_mgr_->GetDef(type_id);
const auto pointee_type_id = inst->GetSingleWordInOperand(1); const auto pointee_type_id = inst->GetSingleWordInOperand(1);
@ -1060,8 +1051,8 @@ typ::Type ParserImpl::ConvertType(uint32_t type_id,
builtin_position_.storage_class = storage_class; builtin_position_.storage_class = storage_class;
return nullptr; return nullptr;
} }
auto ast_elem_ty = ConvertType(pointee_type_id); auto* ast_elem_ty = ConvertType(pointee_type_id);
if (ast_elem_ty.ast == nullptr) { if (ast_elem_ty == nullptr) {
Fail() << "SPIR-V pointer type with ID " << type_id Fail() << "SPIR-V pointer type with ID " << type_id
<< " has invalid pointee type " << pointee_type_id; << " has invalid pointee type " << pointee_type_id;
return nullptr; return nullptr;
@ -1123,7 +1114,7 @@ bool ParserImpl::EmitScalarSpecConstants() {
// that is OpSpecConstantTrue, OpSpecConstantFalse, or OpSpecConstant. // that is OpSpecConstantTrue, OpSpecConstantFalse, or OpSpecConstant.
for (auto& inst : module_->types_values()) { for (auto& inst : module_->types_values()) {
// These will be populated for a valid scalar spec constant. // These will be populated for a valid scalar spec constant.
typ::Type ast_type; ast::Type* ast_type = nullptr;
ast::ScalarConstructorExpression* ast_expr = nullptr; ast::ScalarConstructorExpression* ast_expr = nullptr;
switch (inst.opcode()) { switch (inst.opcode()) {
@ -1138,15 +1129,15 @@ bool ParserImpl::EmitScalarSpecConstants() {
case SpvOpSpecConstant: { case SpvOpSpecConstant: {
ast_type = ConvertType(inst.type_id()); ast_type = ConvertType(inst.type_id());
const uint32_t literal_value = inst.GetSingleWordInOperand(0); const uint32_t literal_value = inst.GetSingleWordInOperand(0);
if (ast_type.ast->Is<ast::I32>()) { if (ast_type->Is<ast::I32>()) {
ast_expr = create<ast::ScalarConstructorExpression>( ast_expr = create<ast::ScalarConstructorExpression>(
Source{}, create<ast::SintLiteral>( Source{}, create<ast::SintLiteral>(
Source{}, static_cast<int32_t>(literal_value))); Source{}, static_cast<int32_t>(literal_value)));
} else if (ast_type.ast->Is<ast::U32>()) { } else if (ast_type->Is<ast::U32>()) {
ast_expr = create<ast::ScalarConstructorExpression>( ast_expr = create<ast::ScalarConstructorExpression>(
Source{}, create<ast::UintLiteral>( Source{}, create<ast::UintLiteral>(
Source{}, static_cast<uint32_t>(literal_value))); Source{}, static_cast<uint32_t>(literal_value)));
} else if (ast_type.ast->Is<ast::F32>()) { } else if (ast_type->Is<ast::F32>()) {
float float_value; float float_value;
// Copy the bits so we can read them as a float. // Copy the bits so we can read them as a float.
std::memcpy(&float_value, &literal_value, sizeof(float_value)); std::memcpy(&float_value, &literal_value, sizeof(float_value));
@ -1182,10 +1173,10 @@ bool ParserImpl::EmitScalarSpecConstants() {
return success_; return success_;
} }
typ::Type ParserImpl::MaybeGenerateAlias( ast::Type* ParserImpl::MaybeGenerateAlias(
uint32_t type_id, uint32_t type_id,
const spvtools::opt::analysis::Type* type, const spvtools::opt::analysis::Type* type,
typ::Type ast_type) { ast::Type* ast_type) {
if (!success_) { if (!success_) {
return {}; return {};
} }
@ -1208,8 +1199,8 @@ typ::Type ParserImpl::MaybeGenerateAlias(
// Ignore constants, and any other types. // Ignore constants, and any other types.
return ast_type; return ast_type;
} }
auto ast_underlying_type = ast_type; auto* ast_underlying_type = ast_type;
if (ast_underlying_type.ast == nullptr) { if (ast_underlying_type == nullptr) {
Fail() << "internal error: no type registered for SPIR-V ID: " << type_id; Fail() << "internal error: no type registered for SPIR-V ID: " << type_id;
return {}; return {};
} }
@ -1261,7 +1252,7 @@ bool ParserImpl::EmitModuleScopeVariables() {
if (!success_) { if (!success_) {
return false; return false;
} }
typ::Type ast_type; ast::Type* ast_type;
if (spirv_storage_class == SpvStorageClassUniformConstant) { if (spirv_storage_class == SpvStorageClassUniformConstant) {
// These are opaque handles: samplers or textures // These are opaque handles: samplers or textures
ast_type = GetTypeForHandleVar(var); ast_type = GetTypeForHandleVar(var);
@ -1270,19 +1261,19 @@ bool ParserImpl::EmitModuleScopeVariables() {
} }
} else { } else {
ast_type = ConvertType(type_id); ast_type = ConvertType(type_id);
if (ast_type.ast == nullptr) { if (ast_type == nullptr) {
return Fail() << "internal error: failed to register Tint AST type for " return Fail() << "internal error: failed to register Tint AST type for "
"SPIR-V type with ID: " "SPIR-V type with ID: "
<< var.type_id(); << var.type_id();
} }
if (!ast_type.ast->Is<ast::Pointer>()) { if (!ast_type->Is<ast::Pointer>()) {
return Fail() << "variable with ID " << var.result_id() return Fail() << "variable with ID " << var.result_id()
<< " has non-pointer type " << var.type_id(); << " has non-pointer type " << var.type_id();
} }
} }
auto ast_store_type = typ::Call_type(typ::As<typ::Pointer>(ast_type)); auto* ast_store_type = ast_type->As<ast::Pointer>()->type();
auto ast_storage_class = ast_type.ast->As<ast::Pointer>()->storage_class(); auto ast_storage_class = ast_type->As<ast::Pointer>()->storage_class();
ast::Expression* ast_constructor = nullptr; ast::Expression* ast_constructor = nullptr;
if (var.NumInOperands() > 1) { if (var.NumInOperands() > 1) {
// SPIR-V initializers are always constants. // SPIR-V initializers are always constants.
@ -1345,18 +1336,18 @@ const spvtools::opt::analysis::IntConstant* ParserImpl::GetArraySize(
ast::Variable* ParserImpl::MakeVariable(uint32_t id, ast::Variable* ParserImpl::MakeVariable(uint32_t id,
ast::StorageClass sc, ast::StorageClass sc,
typ::Type type, ast::Type* type,
bool is_const, bool is_const,
ast::Expression* constructor, ast::Expression* constructor,
ast::DecorationList decorations) { ast::DecorationList decorations) {
if (type.ast == nullptr) { if (type == nullptr) {
Fail() << "internal error: can't make ast::Variable for null type"; Fail() << "internal error: can't make ast::Variable for null type";
return nullptr; return nullptr;
} }
if (sc == ast::StorageClass::kStorage) { if (sc == ast::StorageClass::kStorage) {
bool read_only = false; bool read_only = false;
if (auto* tn = type.ast->As<ast::TypeName>()) { if (auto* tn = type->As<ast::TypeName>()) {
read_only = read_only_struct_types_.count(tn->name()) > 0; read_only = read_only_struct_types_.count(tn->name()) > 0;
} }
@ -1388,7 +1379,7 @@ ast::Variable* ParserImpl::MakeVariable(uint32_t id,
// The SPIR-V variable is likely to be signed (because GLSL // The SPIR-V variable is likely to be signed (because GLSL
// requires signed), but WGSL requires unsigned. Handle specially // requires signed), but WGSL requires unsigned. Handle specially
// so we always perform the conversion at load and store. // so we always perform the conversion at load and store.
if (auto forced_type = UnsignedTypeFor(type)) { if (auto* forced_type = UnsignedTypeFor(type)) {
// Requires conversion and special handling in code generation. // Requires conversion and special handling in code generation.
special_builtins_[id] = spv_builtin; special_builtins_[id] = spv_builtin;
type = forced_type; type = forced_type;
@ -1461,8 +1452,8 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
Fail() << "ID " << id << " is not a registered instruction"; Fail() << "ID " << id << " is not a registered instruction";
return {}; return {};
} }
auto original_ast_type = ConvertType(inst->type_id()); auto* original_ast_type = ConvertType(inst->type_id());
if (original_ast_type.ast == nullptr) { if (original_ast_type == nullptr) {
return {}; return {};
} }
@ -1479,28 +1470,28 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
} }
auto source = GetSourceForInst(inst); auto source = GetSourceForInst(inst);
auto ast_type = UnwrapIfNeeded(original_ast_type); auto* ast_type = original_ast_type->UnwrapIfNeeded();
// TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0. // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0.
// So canonicalization should map that way too. // So canonicalization should map that way too.
// Currently "null<type>" is missing from the WGSL parser. // Currently "null<type>" is missing from the WGSL parser.
// See https://bugs.chromium.org/p/tint/issues/detail?id=34 // See https://bugs.chromium.org/p/tint/issues/detail?id=34
if (ast_type.ast->Is<ast::U32>()) { if (ast_type->Is<ast::U32>()) {
return {ast_type, create<ast::ScalarConstructorExpression>( return {ast_type, create<ast::ScalarConstructorExpression>(
Source{}, create<ast::UintLiteral>( Source{}, create<ast::UintLiteral>(
source, spirv_const->GetU32()))}; source, spirv_const->GetU32()))};
} }
if (ast_type.ast->Is<ast::I32>()) { if (ast_type->Is<ast::I32>()) {
return {ast_type, create<ast::ScalarConstructorExpression>( return {ast_type, create<ast::ScalarConstructorExpression>(
Source{}, create<ast::SintLiteral>( Source{}, create<ast::SintLiteral>(
source, spirv_const->GetS32()))}; source, spirv_const->GetS32()))};
} }
if (ast_type.ast->Is<ast::F32>()) { if (ast_type->Is<ast::F32>()) {
return {ast_type, create<ast::ScalarConstructorExpression>( return {ast_type, create<ast::ScalarConstructorExpression>(
Source{}, create<ast::FloatLiteral>( Source{}, create<ast::FloatLiteral>(
source, spirv_const->GetFloat()))}; source, spirv_const->GetFloat()))};
} }
if (ast_type.ast->Is<ast::Bool>()) { if (ast_type->Is<ast::Bool>()) {
const bool value = spirv_const->AsNullConstant() const bool value = spirv_const->AsNullConstant()
? false ? false
: spirv_const->AsBoolConstant()->value(); : spirv_const->AsBoolConstant()->value();
@ -1556,7 +1547,7 @@ ast::Expression* ParserImpl::MakeNullValue(ast::Type* type) {
} }
auto* original_type = type; auto* original_type = type;
type = UnwrapIfNeeded(type); type = type->UnwrapIfNeeded();
if (type->Is<ast::Bool>()) { if (type->Is<ast::Bool>()) {
return create<ast::ScalarConstructorExpression>( return create<ast::ScalarConstructorExpression>(
@ -1622,15 +1613,15 @@ ast::Expression* ParserImpl::MakeNullValue(ast::Type* type) {
return nullptr; return nullptr;
} }
TypedExpression ParserImpl::MakeNullExpression(typ::Type type) { TypedExpression ParserImpl::MakeNullExpression(ast::Type* type) {
return {type, MakeNullValue(type)}; return {type, MakeNullValue(type)};
} }
typ::Type ParserImpl::UnsignedTypeFor(typ::Type type) { ast::Type* ParserImpl::UnsignedTypeFor(ast::Type* type) {
if (type.ast->Is<ast::I32>()) { if (type->Is<ast::I32>()) {
return builder_.ty.u32(); return builder_.ty.u32();
} }
if (auto* v = type.ast->As<ast::Vector>()) { if (auto* v = type->As<ast::Vector>()) {
if (v->type()->Is<ast::I32>()) { if (v->type()->Is<ast::I32>()) {
return builder_.ty.vec(builder_.ty.u32(), v->size()); return builder_.ty.vec(builder_.ty.u32(), v->size());
} }
@ -1638,11 +1629,11 @@ typ::Type ParserImpl::UnsignedTypeFor(typ::Type type) {
return {}; return {};
} }
typ::Type ParserImpl::SignedTypeFor(typ::Type type) { ast::Type* ParserImpl::SignedTypeFor(ast::Type* type) {
if (type.ast->Is<ast::U32>()) { if (type->Is<ast::U32>()) {
return builder_.ty.i32(); return builder_.ty.i32();
} }
if (auto* v = type.ast->As<ast::Vector>()) { if (auto* v = type->As<ast::Vector>()) {
if (v->type()->Is<ast::U32>()) { if (v->type()->Is<ast::U32>()) {
return builder_.ty.vec(builder_.ty.i32(), v->size()); return builder_.ty.vec(builder_.ty.i32(), v->size());
} }
@ -1673,20 +1664,20 @@ TypedExpression ParserImpl::RectifyOperandSignedness(
Fail() << "internal error: RectifyOperandSignedness given a null expr\n"; Fail() << "internal error: RectifyOperandSignedness given a null expr\n";
return {}; return {};
} }
auto type = expr.type; auto* type = expr.type;
if (!type.ast) { if (!type) {
Fail() << "internal error: unmapped type for: " << builder_.str(expr.expr) Fail() << "internal error: unmapped type for: " << builder_.str(expr.expr)
<< "\n"; << "\n";
return {}; return {};
} }
if (requires_unsigned) { if (requires_unsigned) {
if (auto unsigned_ty = UnsignedTypeFor(type)) { if (auto* unsigned_ty = UnsignedTypeFor(type)) {
// Conversion is required. // Conversion is required.
return {unsigned_ty, return {unsigned_ty,
create<ast::BitcastExpression>(Source{}, unsigned_ty, expr.expr)}; create<ast::BitcastExpression>(Source{}, unsigned_ty, expr.expr)};
} }
} else if (requires_signed) { } else if (requires_signed) {
if (auto signed_ty = SignedTypeFor(type)) { if (auto* signed_ty = SignedTypeFor(type)) {
// Conversion is required. // Conversion is required.
return {signed_ty, return {signed_ty,
create<ast::BitcastExpression>(Source{}, signed_ty, expr.expr)}; create<ast::BitcastExpression>(Source{}, signed_ty, expr.expr)};
@ -1698,9 +1689,9 @@ TypedExpression ParserImpl::RectifyOperandSignedness(
TypedExpression ParserImpl::RectifySecondOperandSignedness( TypedExpression ParserImpl::RectifySecondOperandSignedness(
const spvtools::opt::Instruction& inst, const spvtools::opt::Instruction& inst,
typ::Type first_operand_type, ast::Type* first_operand_type,
TypedExpression&& second_operand_expr) { TypedExpression&& second_operand_expr) {
if ((first_operand_type != second_operand_expr.type) && if (!AstTypesEquivalent(first_operand_type, second_operand_expr.type) &&
AssumesSecondOperandSignednessMatchesFirstOperand(inst.opcode())) { AssumesSecondOperandSignednessMatchesFirstOperand(inst.opcode())) {
// Conversion is required. // Conversion is required.
return {first_operand_type, return {first_operand_type,
@ -1711,8 +1702,8 @@ TypedExpression ParserImpl::RectifySecondOperandSignedness(
return std::move(second_operand_expr); return std::move(second_operand_expr);
} }
typ::Type ParserImpl::ForcedResultType(const spvtools::opt::Instruction& inst, ast::Type* ParserImpl::ForcedResultType(const spvtools::opt::Instruction& inst,
typ::Type first_operand_type) { ast::Type* first_operand_type) {
const auto opcode = inst.opcode(); const auto opcode = inst.opcode();
if (AssumesResultSignednessMatchesFirstOperand(opcode)) { if (AssumesResultSignednessMatchesFirstOperand(opcode)) {
return first_operand_type; return first_operand_type;
@ -1727,16 +1718,15 @@ typ::Type ParserImpl::ForcedResultType(const spvtools::opt::Instruction& inst,
return nullptr; return nullptr;
} }
typ::Type ParserImpl::GetSignedIntMatchingShape(typ::Type other) { ast::Type* ParserImpl::GetSignedIntMatchingShape(ast::Type* other) {
if (other.ast == nullptr) { if (other == nullptr) {
Fail() << "no type provided"; Fail() << "no type provided";
} }
auto i32 = builder_.ty.i32(); auto i32 = builder_.ty.i32();
if (other.ast->Is<ast::F32>() || other.ast->Is<ast::U32>() || if (other->Is<ast::F32>() || other->Is<ast::U32>() || other->Is<ast::I32>()) {
other.ast->Is<ast::I32>()) {
return i32; return i32;
} }
auto* vec_ty = other.ast->As<ast::Vector>(); auto* vec_ty = other->As<ast::Vector>();
if (vec_ty) { if (vec_ty) {
return builder_.ty.vec(i32, vec_ty->size()); return builder_.ty.vec(i32, vec_ty->size());
} }
@ -1744,17 +1734,16 @@ typ::Type ParserImpl::GetSignedIntMatchingShape(typ::Type other) {
return nullptr; return nullptr;
} }
typ::Type ParserImpl::GetUnsignedIntMatchingShape(typ::Type other) { ast::Type* ParserImpl::GetUnsignedIntMatchingShape(ast::Type* other) {
if (other.ast == nullptr) { if (other == nullptr) {
Fail() << "no type provided"; Fail() << "no type provided";
return nullptr; return nullptr;
} }
auto u32 = builder_.ty.u32(); auto u32 = builder_.ty.u32();
if (other.ast->Is<ast::F32>() || other.ast->Is<ast::U32>() || if (other->Is<ast::F32>() || other->Is<ast::U32>() || other->Is<ast::I32>()) {
other.ast->Is<ast::I32>()) {
return u32; return u32;
} }
auto* vec_ty = other.ast->As<ast::Vector>(); auto* vec_ty = other->As<ast::Vector>();
if (vec_ty) { if (vec_ty) {
return builder_.ty.vec(u32, vec_ty->size()); return builder_.ty.vec(u32, vec_ty->size());
} }
@ -1765,9 +1754,10 @@ typ::Type ParserImpl::GetUnsignedIntMatchingShape(typ::Type other) {
TypedExpression ParserImpl::RectifyForcedResultType( TypedExpression ParserImpl::RectifyForcedResultType(
TypedExpression expr, TypedExpression expr,
const spvtools::opt::Instruction& inst, const spvtools::opt::Instruction& inst,
typ::Type first_operand_type) { ast::Type* first_operand_type) {
auto forced_result_ty = ForcedResultType(inst, first_operand_type); auto* forced_result_ty = ForcedResultType(inst, first_operand_type);
if ((forced_result_ty.ast == nullptr) || (forced_result_ty == expr.type)) { if ((forced_result_ty == nullptr) ||
AstTypesEquivalent(forced_result_ty, expr.type)) {
return expr; return expr;
} }
return {expr.type, return {expr.type,
@ -1776,7 +1766,7 @@ TypedExpression ParserImpl::RectifyForcedResultType(
TypedExpression ParserImpl::AsUnsigned(TypedExpression expr) { TypedExpression ParserImpl::AsUnsigned(TypedExpression expr) {
if (expr.type && expr.type->is_signed_scalar_or_vector()) { if (expr.type && expr.type->is_signed_scalar_or_vector()) {
auto new_type = GetUnsignedIntMatchingShape(expr.type); auto* new_type = GetUnsignedIntMatchingShape(expr.type);
return {new_type, return {new_type,
create<ast::BitcastExpression>(Source{}, new_type, expr.expr)}; create<ast::BitcastExpression>(Source{}, new_type, expr.expr)};
} }
@ -1785,7 +1775,7 @@ TypedExpression ParserImpl::AsUnsigned(TypedExpression expr) {
TypedExpression ParserImpl::AsSigned(TypedExpression expr) { TypedExpression ParserImpl::AsSigned(TypedExpression expr) {
if (expr.type && expr.type->is_unsigned_scalar_or_vector()) { if (expr.type && expr.type->is_unsigned_scalar_or_vector()) {
auto new_type = GetSignedIntMatchingShape(expr.type); auto* new_type = GetSignedIntMatchingShape(expr.type);
return {new_type, return {new_type,
create<ast::BitcastExpression>(Source{}, new_type, expr.expr)}; create<ast::BitcastExpression>(Source{}, new_type, expr.expr)};
} }
@ -1962,7 +1952,7 @@ ParserImpl::GetSpirvTypeForHandleMemoryObjectDeclaration(
return raw_handle_type; return raw_handle_type;
} }
typ::Pointer ParserImpl::GetTypeForHandleVar( ast::Pointer* ParserImpl::GetTypeForHandleVar(
const spvtools::opt::Instruction& var) { const spvtools::opt::Instruction& var) {
auto where = handle_type_.find(&var); auto where = handle_type_.find(&var);
if (where != handle_type_.end()) { if (where != handle_type_.end()) {
@ -2046,7 +2036,7 @@ typ::Pointer ParserImpl::GetTypeForHandleVar(
} }
// Construct the Tint handle type. // Construct the Tint handle type.
typ::Type ast_store_type; ast::Type* ast_store_type;
if (usage.IsSampler()) { if (usage.IsSampler()) {
ast_store_type = builder_.ty.sampler( ast_store_type = builder_.ty.sampler(
usage.IsComparisonSampler() ? ast::SamplerKind::kComparisonSampler usage.IsComparisonSampler() ? ast::SamplerKind::kComparisonSampler
@ -2071,7 +2061,7 @@ typ::Pointer ParserImpl::GetTypeForHandleVar(
if (usage.IsSampledTexture() || if (usage.IsSampledTexture() ||
(image_type->format() == SpvImageFormatUnknown)) { (image_type->format() == SpvImageFormatUnknown)) {
// Make a sampled texture type. // Make a sampled texture type.
auto ast_sampled_component_type = auto* ast_sampled_component_type =
ConvertType(raw_handle_type->GetSingleWordInOperand(0)); ConvertType(raw_handle_type->GetSingleWordInOperand(0));
// Vulkan ignores the depth parameter on OpImage, so pay attention to the // Vulkan ignores the depth parameter on OpImage, so pay attention to the
@ -2114,7 +2104,7 @@ typ::Pointer ParserImpl::GetTypeForHandleVar(
return result; return result;
} }
typ::Type ParserImpl::GetComponentTypeForFormat(ast::ImageFormat format) { ast::Type* ParserImpl::GetComponentTypeForFormat(ast::ImageFormat format) {
switch (format) { switch (format) {
case ast::ImageFormat::kR8Uint: case ast::ImageFormat::kR8Uint:
case ast::ImageFormat::kR16Uint: case ast::ImageFormat::kR16Uint:
@ -2163,8 +2153,8 @@ typ::Type ParserImpl::GetComponentTypeForFormat(ast::ImageFormat format) {
return nullptr; return nullptr;
} }
typ::Type ParserImpl::GetTexelTypeForFormat(ast::ImageFormat format) { ast::Type* ParserImpl::GetTexelTypeForFormat(ast::ImageFormat format) {
auto component_type = GetComponentTypeForFormat(format); auto* component_type = GetComponentTypeForFormat(format);
if (!component_type) { if (!component_type) {
return nullptr; return nullptr;
} }

View File

@ -29,7 +29,6 @@
#include "src/reader/spirv/enum_converter.h" #include "src/reader/spirv/enum_converter.h"
#include "src/reader/spirv/namer.h" #include "src/reader/spirv/namer.h"
#include "src/reader/spirv/usage.h" #include "src/reader/spirv/usage.h"
#include "src/typepair.h"
/// This is the implementation of the SPIR-V parser for Tint. /// This is the implementation of the SPIR-V parser for Tint.
@ -52,6 +51,14 @@ namespace tint {
namespace reader { namespace reader {
namespace spirv { namespace spirv {
/// Returns true of the two input ast types are semantically equivalent
/// @param lhs first type to compare
/// @param rhs other type to compare
/// @returns true if both types are semantically equivalent
inline bool AstTypesEquivalent(ast::Type* lhs, ast::Type* rhs) {
return lhs->type_name() == rhs->type_name();
}
/// The binary representation of a SPIR-V decoration enum followed by its /// The binary representation of a SPIR-V decoration enum followed by its
/// operands, if any. /// operands, if any.
/// Example: { SpvDecorationBlock } /// Example: { SpvDecorationBlock }
@ -74,10 +81,10 @@ struct TypedExpression {
/// Constructor /// Constructor
/// @param type_in the type of the expression /// @param type_in the type of the expression
/// @param expr_in the expression /// @param expr_in the expression
TypedExpression(typ::Type type_in, ast::Expression* expr_in); TypedExpression(ast::Type* type_in, ast::Expression* expr_in);
/// The type /// The type
typ::Type type; ast::Type* type;
/// The expression /// The expression
ast::Expression* expr = nullptr; ast::Expression* expr = nullptr;
}; };
@ -156,7 +163,7 @@ class ParserImpl : Reader {
/// after the internal representation of the module has been built. /// after the internal representation of the module has been built.
/// @param type_id the SPIR-V ID of a type. /// @param type_id the SPIR-V ID of a type.
/// @returns a Tint type, or nullptr /// @returns a Tint type, or nullptr
typ::Type ConvertType(uint32_t type_id); ast::Type* ConvertType(uint32_t type_id);
/// Emits an alias type declaration for the given type, if necessary, and /// Emits an alias type declaration for the given type, if necessary, and
/// also updates the mapping of the SPIR-V type ID to the alias type. /// also updates the mapping of the SPIR-V type ID to the alias type.
@ -169,9 +176,9 @@ class ParserImpl : Reader {
/// @param type the type that might get an alias /// @param type the type that might get an alias
/// @param ast_type the ast type that might get an alias /// @param ast_type the ast type that might get an alias
/// @returns an alias type or `ast_type` if no alias was created /// @returns an alias type or `ast_type` if no alias was created
typ::Type MaybeGenerateAlias(uint32_t type_id, ast::Type* MaybeGenerateAlias(uint32_t type_id,
const spvtools::opt::analysis::Type* type, const spvtools::opt::analysis::Type* type,
typ::Type ast_type); ast::Type* ast_type);
/// @returns the fail stream object /// @returns the fail stream object
FailStream& fail_stream() { return fail_stream_; } FailStream& fail_stream() { return fail_stream_; }
@ -321,7 +328,7 @@ class ParserImpl : Reader {
/// in the error case /// in the error case
ast::Variable* MakeVariable(uint32_t id, ast::Variable* MakeVariable(uint32_t id,
ast::StorageClass sc, ast::StorageClass sc,
typ::Type type, ast::Type* type,
bool is_const, bool is_const,
ast::Expression* constructor, ast::Expression* constructor,
ast::DecorationList decorations); ast::DecorationList decorations);
@ -339,7 +346,7 @@ class ParserImpl : Reader {
/// Make a typed expression for the null value for the given type. /// Make a typed expression for the null value for the given type.
/// @param type the AST type /// @param type the AST type
/// @returns a new typed expression /// @returns a new typed expression
TypedExpression MakeNullExpression(typ::Type type); TypedExpression MakeNullExpression(ast::Type* type);
/// Converts a given expression to the signedness demanded for an operand /// Converts a given expression to the signedness demanded for an operand
/// of the given SPIR-V instruction, if required. If the instruction assumes /// of the given SPIR-V instruction, if required. If the instruction assumes
@ -364,7 +371,7 @@ class ParserImpl : Reader {
/// @returns second_operand_expr, or a cast of it /// @returns second_operand_expr, or a cast of it
TypedExpression RectifySecondOperandSignedness( TypedExpression RectifySecondOperandSignedness(
const spvtools::opt::Instruction& inst, const spvtools::opt::Instruction& inst,
typ::Type first_operand_type, ast::Type* first_operand_type,
TypedExpression&& second_operand_expr); TypedExpression&& second_operand_expr);
/// Returns the "forced" result type for the given SPIR-V instruction. /// Returns the "forced" result type for the given SPIR-V instruction.
@ -375,8 +382,8 @@ class ParserImpl : Reader {
/// @param inst the SPIR-V instruction /// @param inst the SPIR-V instruction
/// @param first_operand_type the AST type for the first operand. /// @param first_operand_type the AST type for the first operand.
/// @returns the forced AST result type, or nullptr if no forcing is required. /// @returns the forced AST result type, or nullptr if no forcing is required.
typ::Type ForcedResultType(const spvtools::opt::Instruction& inst, ast::Type* ForcedResultType(const spvtools::opt::Instruction& inst,
typ::Type first_operand_type); ast::Type* first_operand_type);
/// Returns a signed integer scalar or vector type matching the shape (scalar, /// Returns a signed integer scalar or vector type matching the shape (scalar,
/// vector, and component bit width) of another type, which itself is a /// vector, and component bit width) of another type, which itself is a
@ -384,7 +391,7 @@ class ParserImpl : Reader {
/// requirement. /// requirement.
/// @param other the type whose shape must be matched /// @param other the type whose shape must be matched
/// @returns the signed scalar or vector type /// @returns the signed scalar or vector type
typ::Type GetSignedIntMatchingShape(typ::Type other); ast::Type* GetSignedIntMatchingShape(ast::Type* other);
/// Returns a signed integer scalar or vector type matching the shape (scalar, /// Returns a signed integer scalar or vector type matching the shape (scalar,
/// vector, and component bit width) of another type, which itself is a /// vector, and component bit width) of another type, which itself is a
@ -392,7 +399,7 @@ class ParserImpl : Reader {
/// requirement. /// requirement.
/// @param other the type whose shape must be matched /// @param other the type whose shape must be matched
/// @returns the unsigned scalar or vector type /// @returns the unsigned scalar or vector type
typ::Type GetUnsignedIntMatchingShape(typ::Type other); ast::Type* GetUnsignedIntMatchingShape(ast::Type* other);
/// Wraps the given expression in an as-cast to the given expression's type, /// Wraps the given expression in an as-cast to the given expression's type,
/// when the underlying operation produces a forced result type different /// when the underlying operation produces a forced result type different
@ -405,7 +412,7 @@ class ParserImpl : Reader {
TypedExpression RectifyForcedResultType( TypedExpression RectifyForcedResultType(
TypedExpression expr, TypedExpression expr,
const spvtools::opt::Instruction& inst, const spvtools::opt::Instruction& inst,
typ::Type first_operand_type); ast::Type* first_operand_type);
/// Returns the given expression, but ensuring it's an unsigned type of the /// Returns the given expression, but ensuring it's an unsigned type of the
/// same shape as the operand. Wraps the expresison with a bitcast if needed. /// same shape as the operand. Wraps the expresison with a bitcast if needed.
@ -505,18 +512,18 @@ class ParserImpl : Reader {
/// @param var the OpVariable instruction /// @param var the OpVariable instruction
/// @returns the Tint AST type for the poiner-to-{sampler|texture} or null on /// @returns the Tint AST type for the poiner-to-{sampler|texture} or null on
/// error /// error
typ::Pointer GetTypeForHandleVar(const spvtools::opt::Instruction& var); ast::Pointer* GetTypeForHandleVar(const spvtools::opt::Instruction& var);
/// Returns the channel component type corresponding to the given image /// Returns the channel component type corresponding to the given image
/// format. /// format.
/// @param format image texel format /// @param format image texel format
/// @returns the component type, one of f32, i32, u32 /// @returns the component type, one of f32, i32, u32
typ::Type GetComponentTypeForFormat(ast::ImageFormat format); ast::Type* GetComponentTypeForFormat(ast::ImageFormat format);
/// Returns texel type corresponding to the given image format. /// Returns texel type corresponding to the given image format.
/// @param format image texel format /// @param format image texel format
/// @returns the texel format /// @returns the texel format
typ::Type GetTexelTypeForFormat(ast::ImageFormat format); ast::Type* GetTexelTypeForFormat(ast::ImageFormat format);
/// Returns the SPIR-V instruction with the given ID, or nullptr. /// Returns the SPIR-V instruction with the given ID, or nullptr.
/// @param id the SPIR-V result ID /// @param id the SPIR-V result ID
@ -554,19 +561,19 @@ class ParserImpl : Reader {
private: private:
/// Converts a specific SPIR-V type to a Tint type. Integer case /// Converts a specific SPIR-V type to a Tint type. Integer case
typ::Type ConvertType(const spvtools::opt::analysis::Integer* int_ty); ast::Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty);
/// Converts a specific SPIR-V type to a Tint type. Float case /// Converts a specific SPIR-V type to a Tint type. Float case
typ::Type ConvertType(const spvtools::opt::analysis::Float* float_ty); ast::Type* ConvertType(const spvtools::opt::analysis::Float* float_ty);
/// Converts a specific SPIR-V type to a Tint type. Vector case /// Converts a specific SPIR-V type to a Tint type. Vector case
typ::Type ConvertType(const spvtools::opt::analysis::Vector* vec_ty); ast::Type* ConvertType(const spvtools::opt::analysis::Vector* vec_ty);
/// Converts a specific SPIR-V type to a Tint type. Matrix case /// Converts a specific SPIR-V type to a Tint type. Matrix case
typ::Type ConvertType(const spvtools::opt::analysis::Matrix* mat_ty); ast::Type* ConvertType(const spvtools::opt::analysis::Matrix* mat_ty);
/// Converts a specific SPIR-V type to a Tint type. RuntimeArray case /// Converts a specific SPIR-V type to a Tint type. RuntimeArray case
/// @param rtarr_ty the Tint type /// @param rtarr_ty the Tint type
typ::Type ConvertType(const spvtools::opt::analysis::RuntimeArray* rtarr_ty); ast::Type* ConvertType(const spvtools::opt::analysis::RuntimeArray* rtarr_ty);
/// Converts a specific SPIR-V type to a Tint type. Array case /// Converts a specific SPIR-V type to a Tint type. Array case
/// @param arr_ty the Tint type /// @param arr_ty the Tint type
typ::Type ConvertType(const spvtools::opt::analysis::Array* arr_ty); ast::Type* ConvertType(const spvtools::opt::analysis::Array* arr_ty);
/// Converts a specific SPIR-V type to a Tint type. Struct case. /// Converts a specific SPIR-V type to a Tint type. Struct case.
/// SPIR-V allows distinct struct type definitions for two OpTypeStruct /// SPIR-V allows distinct struct type definitions for two OpTypeStruct
/// that otherwise have the same set of members (and struct and member /// that otherwise have the same set of members (and struct and member
@ -578,27 +585,27 @@ class ParserImpl : Reader {
/// not significant to the optimizer's module representation. /// not significant to the optimizer's module representation.
/// @param type_id the SPIR-V ID for the type. /// @param type_id the SPIR-V ID for the type.
/// @param struct_ty the Tint type /// @param struct_ty the Tint type
typ::Type ConvertType(uint32_t type_id, ast::Type* ConvertType(uint32_t type_id,
const spvtools::opt::analysis::Struct* struct_ty); const spvtools::opt::analysis::Struct* struct_ty);
/// Converts a specific SPIR-V type to a Tint type. Pointer case /// Converts a specific SPIR-V type to a Tint type. Pointer case
/// The pointer to gl_PerVertex maps to nullptr, and instead is recorded /// The pointer to gl_PerVertex maps to nullptr, and instead is recorded
/// in member #builtin_position_. /// in member #builtin_position_.
/// @param type_id the SPIR-V ID for the type. /// @param type_id the SPIR-V ID for the type.
/// @param ptr_ty the Tint type /// @param ptr_ty the Tint type
typ::Type ConvertType(uint32_t type_id, ast::Type* ConvertType(uint32_t type_id,
const spvtools::opt::analysis::Pointer* ptr_ty); const spvtools::opt::analysis::Pointer* ptr_ty);
/// If `type` is a signed integral, or vector of signed integral, /// If `type` is a signed integral, or vector of signed integral,
/// returns the unsigned type, otherwise returns `type`. /// returns the unsigned type, otherwise returns `type`.
/// @param type the possibly signed type /// @param type the possibly signed type
/// @returns the unsigned type /// @returns the unsigned type
typ::Type UnsignedTypeFor(typ::Type type); ast::Type* UnsignedTypeFor(ast::Type* type);
/// If `type` is a unsigned integral, or vector of unsigned integral, /// If `type` is a unsigned integral, or vector of unsigned integral,
/// returns the signed type, otherwise returns `type`. /// returns the signed type, otherwise returns `type`.
/// @param type the possibly unsigned type /// @param type the possibly unsigned type
/// @returns the signed type /// @returns the signed type
typ::Type SignedTypeFor(typ::Type type); ast::Type* SignedTypeFor(ast::Type* type);
/// Parses the array or runtime-array decorations. /// Parses the array or runtime-array decorations.
/// @param spv_type the SPIR-V array or runtime-array type. /// @param spv_type the SPIR-V array or runtime-array type.
@ -709,7 +716,7 @@ class ParserImpl : Reader {
// usages implied by usages of the memory-object-declaration. // usages implied by usages of the memory-object-declaration.
std::unordered_map<const spvtools::opt::Instruction*, Usage> handle_usage_; std::unordered_map<const spvtools::opt::Instruction*, Usage> handle_usage_;
// The inferred pointer type for the given handle variable. // The inferred pointer type for the given handle variable.
std::unordered_map<const spvtools::opt::Instruction*, typ::Pointer> std::unordered_map<const spvtools::opt::Instruction*, ast::Pointer*>
handle_type_; handle_type_;
// Set of symbols of constructed types that have been added, used to avoid // Set of symbols of constructed types that have been added, used to avoid

View File

@ -26,14 +26,14 @@ using ::testing::Eq;
TEST_F(SpvParserTest, ConvertType_PreservesExistingFailure) { TEST_F(SpvParserTest, ConvertType_PreservesExistingFailure) {
auto p = parser(std::vector<uint32_t>{}); auto p = parser(std::vector<uint32_t>{});
p->Fail() << "boing"; p->Fail() << "boing";
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
EXPECT_EQ(type, nullptr); EXPECT_EQ(type, nullptr);
EXPECT_THAT(p->error(), Eq("boing")); EXPECT_THAT(p->error(), Eq("boing"));
} }
TEST_F(SpvParserTest, ConvertType_RequiresInternalRepresntation) { TEST_F(SpvParserTest, ConvertType_RequiresInternalRepresntation) {
auto p = parser(std::vector<uint32_t>{}); auto p = parser(std::vector<uint32_t>{});
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
EXPECT_EQ(type, nullptr); EXPECT_EQ(type, nullptr);
EXPECT_THAT( EXPECT_THAT(
p->error(), p->error(),
@ -44,7 +44,7 @@ TEST_F(SpvParserTest, ConvertType_NotAnId) {
auto p = parser(test::Assemble("%1 = OpExtInstImport \"GLSL.std.450\"")); auto p = parser(test::Assemble("%1 = OpExtInstImport \"GLSL.std.450\""));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
EXPECT_EQ(type, nullptr); EXPECT_EQ(type, nullptr);
EXPECT_EQ(nullptr, type); EXPECT_EQ(nullptr, type);
EXPECT_THAT(p->error(), Eq("ID is not a SPIR-V type: 10")); EXPECT_THAT(p->error(), Eq("ID is not a SPIR-V type: 10"));
@ -54,7 +54,7 @@ TEST_F(SpvParserTest, ConvertType_IdExistsButIsNotAType) {
auto p = parser(test::Assemble("%1 = OpExtInstImport \"GLSL.std.450\"")); auto p = parser(test::Assemble("%1 = OpExtInstImport \"GLSL.std.450\""));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(1); auto* type = p->ConvertType(1);
EXPECT_EQ(nullptr, type); EXPECT_EQ(nullptr, type);
EXPECT_THAT(p->error(), Eq("ID is not a SPIR-V type: 1")); EXPECT_THAT(p->error(), Eq("ID is not a SPIR-V type: 1"));
} }
@ -64,7 +64,7 @@ TEST_F(SpvParserTest, ConvertType_UnhandledType) {
auto p = parser(test::Assemble("%70 = OpTypePipe WriteOnly")); auto p = parser(test::Assemble("%70 = OpTypePipe WriteOnly"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(70); auto* type = p->ConvertType(70);
EXPECT_EQ(nullptr, type); EXPECT_EQ(nullptr, type);
EXPECT_THAT(p->error(), EXPECT_THAT(p->error(),
Eq("unknown SPIR-V type with ID 70: %70 = OpTypePipe WriteOnly")); Eq("unknown SPIR-V type with ID 70: %70 = OpTypePipe WriteOnly"));
@ -74,8 +74,8 @@ TEST_F(SpvParserTest, ConvertType_Void) {
auto p = parser(test::Assemble("%1 = OpTypeVoid")); auto p = parser(test::Assemble("%1 = OpTypeVoid"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(1); auto* type = p->ConvertType(1);
EXPECT_TRUE(type.ast->Is<ast::Void>()); EXPECT_TRUE(type->Is<ast::Void>());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
} }
@ -83,8 +83,8 @@ TEST_F(SpvParserTest, ConvertType_Bool) {
auto p = parser(test::Assemble("%100 = OpTypeBool")); auto p = parser(test::Assemble("%100 = OpTypeBool"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(100); auto* type = p->ConvertType(100);
EXPECT_TRUE(type.ast->Is<ast::Bool>()); EXPECT_TRUE(type->Is<ast::Bool>());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
} }
@ -92,8 +92,8 @@ TEST_F(SpvParserTest, ConvertType_I32) {
auto p = parser(test::Assemble("%2 = OpTypeInt 32 1")); auto p = parser(test::Assemble("%2 = OpTypeInt 32 1"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(2); auto* type = p->ConvertType(2);
EXPECT_TRUE(type.ast->Is<ast::I32>()); EXPECT_TRUE(type->Is<ast::I32>());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
} }
@ -101,8 +101,8 @@ TEST_F(SpvParserTest, ConvertType_U32) {
auto p = parser(test::Assemble("%3 = OpTypeInt 32 0")); auto p = parser(test::Assemble("%3 = OpTypeInt 32 0"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(3); auto* type = p->ConvertType(3);
EXPECT_TRUE(type.ast->Is<ast::U32>()); EXPECT_TRUE(type->Is<ast::U32>());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
} }
@ -110,8 +110,8 @@ TEST_F(SpvParserTest, ConvertType_F32) {
auto p = parser(test::Assemble("%4 = OpTypeFloat 32")); auto p = parser(test::Assemble("%4 = OpTypeFloat 32"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(4); auto* type = p->ConvertType(4);
EXPECT_TRUE(type.ast->Is<ast::F32>()); EXPECT_TRUE(type->Is<ast::F32>());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
} }
@ -119,7 +119,7 @@ TEST_F(SpvParserTest, ConvertType_BadIntWidth) {
auto p = parser(test::Assemble("%5 = OpTypeInt 17 1")); auto p = parser(test::Assemble("%5 = OpTypeInt 17 1"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(5); auto* type = p->ConvertType(5);
EXPECT_EQ(type, nullptr); EXPECT_EQ(type, nullptr);
EXPECT_THAT(p->error(), Eq("unhandled integer width: 17")); EXPECT_THAT(p->error(), Eq("unhandled integer width: 17"));
} }
@ -128,7 +128,7 @@ TEST_F(SpvParserTest, ConvertType_BadFloatWidth) {
auto p = parser(test::Assemble("%6 = OpTypeFloat 19")); auto p = parser(test::Assemble("%6 = OpTypeFloat 19"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(6); auto* type = p->ConvertType(6);
EXPECT_EQ(type, nullptr); EXPECT_EQ(type, nullptr);
EXPECT_THAT(p->error(), Eq("unhandled float width: 19")); EXPECT_THAT(p->error(), Eq("unhandled float width: 19"));
} }
@ -140,7 +140,7 @@ TEST_F(SpvParserTest, DISABLED_ConvertType_InvalidVectorElement) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(20); auto* type = p->ConvertType(20);
EXPECT_EQ(type, nullptr); EXPECT_EQ(type, nullptr);
EXPECT_THAT(p->error(), Eq("unknown SPIR-V type: 5")); EXPECT_THAT(p->error(), Eq("unknown SPIR-V type: 5"));
} }
@ -154,20 +154,20 @@ TEST_F(SpvParserTest, ConvertType_VecOverF32) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto v2xf32 = p->ConvertType(20); auto* v2xf32 = p->ConvertType(20);
EXPECT_TRUE(v2xf32.ast->Is<ast::Vector>()); EXPECT_TRUE(v2xf32->Is<ast::Vector>());
EXPECT_TRUE(v2xf32.ast->As<ast::Vector>()->type()->Is<ast::F32>()); EXPECT_TRUE(v2xf32->As<ast::Vector>()->type()->Is<ast::F32>());
EXPECT_EQ(v2xf32.ast->As<ast::Vector>()->size(), 2u); EXPECT_EQ(v2xf32->As<ast::Vector>()->size(), 2u);
auto v3xf32 = p->ConvertType(30); auto* v3xf32 = p->ConvertType(30);
EXPECT_TRUE(v3xf32.ast->Is<ast::Vector>()); EXPECT_TRUE(v3xf32->Is<ast::Vector>());
EXPECT_TRUE(v3xf32.ast->As<ast::Vector>()->type()->Is<ast::F32>()); EXPECT_TRUE(v3xf32->As<ast::Vector>()->type()->Is<ast::F32>());
EXPECT_EQ(v3xf32.ast->As<ast::Vector>()->size(), 3u); EXPECT_EQ(v3xf32->As<ast::Vector>()->size(), 3u);
auto v4xf32 = p->ConvertType(40); auto* v4xf32 = p->ConvertType(40);
EXPECT_TRUE(v4xf32.ast->Is<ast::Vector>()); EXPECT_TRUE(v4xf32->Is<ast::Vector>());
EXPECT_TRUE(v4xf32.ast->As<ast::Vector>()->type()->Is<ast::F32>()); EXPECT_TRUE(v4xf32->As<ast::Vector>()->type()->Is<ast::F32>());
EXPECT_EQ(v4xf32.ast->As<ast::Vector>()->size(), 4u); EXPECT_EQ(v4xf32->As<ast::Vector>()->size(), 4u);
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
} }
@ -181,20 +181,20 @@ TEST_F(SpvParserTest, ConvertType_VecOverI32) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto v2xi32 = p->ConvertType(20); auto* v2xi32 = p->ConvertType(20);
EXPECT_TRUE(v2xi32.ast->Is<ast::Vector>()); EXPECT_TRUE(v2xi32->Is<ast::Vector>());
EXPECT_TRUE(v2xi32.ast->As<ast::Vector>()->type()->Is<ast::I32>()); EXPECT_TRUE(v2xi32->As<ast::Vector>()->type()->Is<ast::I32>());
EXPECT_EQ(v2xi32.ast->As<ast::Vector>()->size(), 2u); EXPECT_EQ(v2xi32->As<ast::Vector>()->size(), 2u);
auto v3xi32 = p->ConvertType(30); auto* v3xi32 = p->ConvertType(30);
EXPECT_TRUE(v3xi32.ast->Is<ast::Vector>()); EXPECT_TRUE(v3xi32->Is<ast::Vector>());
EXPECT_TRUE(v3xi32.ast->As<ast::Vector>()->type()->Is<ast::I32>()); EXPECT_TRUE(v3xi32->As<ast::Vector>()->type()->Is<ast::I32>());
EXPECT_EQ(v3xi32.ast->As<ast::Vector>()->size(), 3u); EXPECT_EQ(v3xi32->As<ast::Vector>()->size(), 3u);
auto v4xi32 = p->ConvertType(40); auto* v4xi32 = p->ConvertType(40);
EXPECT_TRUE(v4xi32.ast->Is<ast::Vector>()); EXPECT_TRUE(v4xi32->Is<ast::Vector>());
EXPECT_TRUE(v4xi32.ast->As<ast::Vector>()->type()->Is<ast::I32>()); EXPECT_TRUE(v4xi32->As<ast::Vector>()->type()->Is<ast::I32>());
EXPECT_EQ(v4xi32.ast->As<ast::Vector>()->size(), 4u); EXPECT_EQ(v4xi32->As<ast::Vector>()->size(), 4u);
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
} }
@ -208,20 +208,20 @@ TEST_F(SpvParserTest, ConvertType_VecOverU32) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto v2xu32 = p->ConvertType(20); auto* v2xu32 = p->ConvertType(20);
EXPECT_TRUE(v2xu32.ast->Is<ast::Vector>()); EXPECT_TRUE(v2xu32->Is<ast::Vector>());
EXPECT_TRUE(v2xu32.ast->As<ast::Vector>()->type()->Is<ast::U32>()); EXPECT_TRUE(v2xu32->As<ast::Vector>()->type()->Is<ast::U32>());
EXPECT_EQ(v2xu32.ast->As<ast::Vector>()->size(), 2u); EXPECT_EQ(v2xu32->As<ast::Vector>()->size(), 2u);
auto v3xu32 = p->ConvertType(30); auto* v3xu32 = p->ConvertType(30);
EXPECT_TRUE(v3xu32.ast->Is<ast::Vector>()); EXPECT_TRUE(v3xu32->Is<ast::Vector>());
EXPECT_TRUE(v3xu32.ast->As<ast::Vector>()->type()->Is<ast::U32>()); EXPECT_TRUE(v3xu32->As<ast::Vector>()->type()->Is<ast::U32>());
EXPECT_EQ(v3xu32.ast->As<ast::Vector>()->size(), 3u); EXPECT_EQ(v3xu32->As<ast::Vector>()->size(), 3u);
auto v4xu32 = p->ConvertType(40); auto* v4xu32 = p->ConvertType(40);
EXPECT_TRUE(v4xu32.ast->Is<ast::Vector>()); EXPECT_TRUE(v4xu32->Is<ast::Vector>());
EXPECT_TRUE(v4xu32.ast->As<ast::Vector>()->type()->Is<ast::U32>()); EXPECT_TRUE(v4xu32->As<ast::Vector>()->type()->Is<ast::U32>());
EXPECT_EQ(v4xu32.ast->As<ast::Vector>()->size(), 4u); EXPECT_EQ(v4xu32->As<ast::Vector>()->size(), 4u);
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
} }
@ -234,7 +234,7 @@ TEST_F(SpvParserTest, DISABLED_ConvertType_InvalidMatrixElement) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(20); auto* type = p->ConvertType(20);
EXPECT_EQ(type, nullptr); EXPECT_EQ(type, nullptr);
EXPECT_THAT(p->error(), Eq("unknown SPIR-V type: 5")); EXPECT_THAT(p->error(), Eq("unknown SPIR-V type: 5"));
} }
@ -260,59 +260,59 @@ TEST_F(SpvParserTest, ConvertType_MatrixOverF32) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto m22 = p->ConvertType(22); auto* m22 = p->ConvertType(22);
EXPECT_TRUE(m22.ast->Is<ast::Matrix>()); EXPECT_TRUE(m22->Is<ast::Matrix>());
EXPECT_TRUE(m22.ast->As<ast::Matrix>()->type()->Is<ast::F32>()); EXPECT_TRUE(m22->As<ast::Matrix>()->type()->Is<ast::F32>());
EXPECT_EQ(m22.ast->As<ast::Matrix>()->rows(), 2u); EXPECT_EQ(m22->As<ast::Matrix>()->rows(), 2u);
EXPECT_EQ(m22.ast->As<ast::Matrix>()->columns(), 2u); EXPECT_EQ(m22->As<ast::Matrix>()->columns(), 2u);
auto m23 = p->ConvertType(23); auto* m23 = p->ConvertType(23);
EXPECT_TRUE(m23.ast->Is<ast::Matrix>()); EXPECT_TRUE(m23->Is<ast::Matrix>());
EXPECT_TRUE(m23.ast->As<ast::Matrix>()->type()->Is<ast::F32>()); EXPECT_TRUE(m23->As<ast::Matrix>()->type()->Is<ast::F32>());
EXPECT_EQ(m23.ast->As<ast::Matrix>()->rows(), 2u); EXPECT_EQ(m23->As<ast::Matrix>()->rows(), 2u);
EXPECT_EQ(m23.ast->As<ast::Matrix>()->columns(), 3u); EXPECT_EQ(m23->As<ast::Matrix>()->columns(), 3u);
auto m24 = p->ConvertType(24); auto* m24 = p->ConvertType(24);
EXPECT_TRUE(m24.ast->Is<ast::Matrix>()); EXPECT_TRUE(m24->Is<ast::Matrix>());
EXPECT_TRUE(m24.ast->As<ast::Matrix>()->type()->Is<ast::F32>()); EXPECT_TRUE(m24->As<ast::Matrix>()->type()->Is<ast::F32>());
EXPECT_EQ(m24.ast->As<ast::Matrix>()->rows(), 2u); EXPECT_EQ(m24->As<ast::Matrix>()->rows(), 2u);
EXPECT_EQ(m24.ast->As<ast::Matrix>()->columns(), 4u); EXPECT_EQ(m24->As<ast::Matrix>()->columns(), 4u);
auto m32 = p->ConvertType(32); auto* m32 = p->ConvertType(32);
EXPECT_TRUE(m32.ast->Is<ast::Matrix>()); EXPECT_TRUE(m32->Is<ast::Matrix>());
EXPECT_TRUE(m32.ast->As<ast::Matrix>()->type()->Is<ast::F32>()); EXPECT_TRUE(m32->As<ast::Matrix>()->type()->Is<ast::F32>());
EXPECT_EQ(m32.ast->As<ast::Matrix>()->rows(), 3u); EXPECT_EQ(m32->As<ast::Matrix>()->rows(), 3u);
EXPECT_EQ(m32.ast->As<ast::Matrix>()->columns(), 2u); EXPECT_EQ(m32->As<ast::Matrix>()->columns(), 2u);
auto m33 = p->ConvertType(33); auto* m33 = p->ConvertType(33);
EXPECT_TRUE(m33.ast->Is<ast::Matrix>()); EXPECT_TRUE(m33->Is<ast::Matrix>());
EXPECT_TRUE(m33.ast->As<ast::Matrix>()->type()->Is<ast::F32>()); EXPECT_TRUE(m33->As<ast::Matrix>()->type()->Is<ast::F32>());
EXPECT_EQ(m33.ast->As<ast::Matrix>()->rows(), 3u); EXPECT_EQ(m33->As<ast::Matrix>()->rows(), 3u);
EXPECT_EQ(m33.ast->As<ast::Matrix>()->columns(), 3u); EXPECT_EQ(m33->As<ast::Matrix>()->columns(), 3u);
auto m34 = p->ConvertType(34); auto* m34 = p->ConvertType(34);
EXPECT_TRUE(m34.ast->Is<ast::Matrix>()); EXPECT_TRUE(m34->Is<ast::Matrix>());
EXPECT_TRUE(m34.ast->As<ast::Matrix>()->type()->Is<ast::F32>()); EXPECT_TRUE(m34->As<ast::Matrix>()->type()->Is<ast::F32>());
EXPECT_EQ(m34.ast->As<ast::Matrix>()->rows(), 3u); EXPECT_EQ(m34->As<ast::Matrix>()->rows(), 3u);
EXPECT_EQ(m34.ast->As<ast::Matrix>()->columns(), 4u); EXPECT_EQ(m34->As<ast::Matrix>()->columns(), 4u);
auto m42 = p->ConvertType(42); auto* m42 = p->ConvertType(42);
EXPECT_TRUE(m42.ast->Is<ast::Matrix>()); EXPECT_TRUE(m42->Is<ast::Matrix>());
EXPECT_TRUE(m42.ast->As<ast::Matrix>()->type()->Is<ast::F32>()); EXPECT_TRUE(m42->As<ast::Matrix>()->type()->Is<ast::F32>());
EXPECT_EQ(m42.ast->As<ast::Matrix>()->rows(), 4u); EXPECT_EQ(m42->As<ast::Matrix>()->rows(), 4u);
EXPECT_EQ(m42.ast->As<ast::Matrix>()->columns(), 2u); EXPECT_EQ(m42->As<ast::Matrix>()->columns(), 2u);
auto m43 = p->ConvertType(43); auto* m43 = p->ConvertType(43);
EXPECT_TRUE(m43.ast->Is<ast::Matrix>()); EXPECT_TRUE(m43->Is<ast::Matrix>());
EXPECT_TRUE(m43.ast->As<ast::Matrix>()->type()->Is<ast::F32>()); EXPECT_TRUE(m43->As<ast::Matrix>()->type()->Is<ast::F32>());
EXPECT_EQ(m43.ast->As<ast::Matrix>()->rows(), 4u); EXPECT_EQ(m43->As<ast::Matrix>()->rows(), 4u);
EXPECT_EQ(m43.ast->As<ast::Matrix>()->columns(), 3u); EXPECT_EQ(m43->As<ast::Matrix>()->columns(), 3u);
auto m44 = p->ConvertType(44); auto* m44 = p->ConvertType(44);
EXPECT_TRUE(m44.ast->Is<ast::Matrix>()); EXPECT_TRUE(m44->Is<ast::Matrix>());
EXPECT_TRUE(m44.ast->As<ast::Matrix>()->type()->Is<ast::F32>()); EXPECT_TRUE(m44->As<ast::Matrix>()->type()->Is<ast::F32>());
EXPECT_EQ(m44.ast->As<ast::Matrix>()->rows(), 4u); EXPECT_EQ(m44->As<ast::Matrix>()->rows(), 4u);
EXPECT_EQ(m44.ast->As<ast::Matrix>()->columns(), 4u); EXPECT_EQ(m44->As<ast::Matrix>()->columns(), 4u);
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
} }
@ -324,10 +324,10 @@ TEST_F(SpvParserTest, ConvertType_RuntimeArray) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr); ASSERT_NE(type, nullptr);
EXPECT_TRUE(type.ast->UnwrapAliasIfNeeded()->Is<ast::Array>()); EXPECT_TRUE(type->UnwrapAliasIfNeeded()->Is<ast::Array>());
auto* arr_type = type.ast->UnwrapAliasIfNeeded()->As<ast::Array>(); auto* arr_type = type->UnwrapAliasIfNeeded()->As<ast::Array>();
EXPECT_TRUE(arr_type->IsRuntimeArray()); EXPECT_TRUE(arr_type->IsRuntimeArray());
ASSERT_NE(arr_type, nullptr); ASSERT_NE(arr_type, nullptr);
EXPECT_EQ(arr_type->size(), 0u); EXPECT_EQ(arr_type->size(), 0u);
@ -345,7 +345,7 @@ TEST_F(SpvParserTest, ConvertType_RuntimeArray_InvalidDecoration) {
%10 = OpTypeRuntimeArray %uint %10 = OpTypeRuntimeArray %uint
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
EXPECT_EQ(type, nullptr); EXPECT_EQ(type, nullptr);
EXPECT_THAT( EXPECT_THAT(
p->error(), p->error(),
@ -359,9 +359,9 @@ TEST_F(SpvParserTest, ConvertType_RuntimeArray_ArrayStride_Valid) {
%10 = OpTypeRuntimeArray %uint %10 = OpTypeRuntimeArray %uint
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr); ASSERT_NE(type, nullptr);
auto* arr_type = type.ast->UnwrapAliasIfNeeded()->As<ast::Array>(); auto* arr_type = type->UnwrapAliasIfNeeded()->As<ast::Array>();
EXPECT_TRUE(arr_type->IsRuntimeArray()); EXPECT_TRUE(arr_type->IsRuntimeArray());
ASSERT_NE(arr_type, nullptr); ASSERT_NE(arr_type, nullptr);
ASSERT_EQ(arr_type->decorations().size(), 1u); ASSERT_EQ(arr_type->decorations().size(), 1u);
@ -378,7 +378,7 @@ TEST_F(SpvParserTest, ConvertType_RuntimeArray_ArrayStride_ZeroIsError) {
%10 = OpTypeRuntimeArray %uint %10 = OpTypeRuntimeArray %uint
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
EXPECT_EQ(type, nullptr); EXPECT_EQ(type, nullptr);
EXPECT_THAT(p->error(), EXPECT_THAT(p->error(),
Eq("invalid array type ID 10: ArrayStride can't be 0")); Eq("invalid array type ID 10: ArrayStride can't be 0"));
@ -393,7 +393,7 @@ TEST_F(SpvParserTest,
%10 = OpTypeRuntimeArray %uint %10 = OpTypeRuntimeArray %uint
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
EXPECT_EQ(type, nullptr); EXPECT_EQ(type, nullptr);
EXPECT_THAT(p->error(), EXPECT_THAT(p->error(),
Eq("invalid array type ID 10: multiple ArrayStride decorations")); Eq("invalid array type ID 10: multiple ArrayStride decorations"));
@ -407,10 +407,10 @@ TEST_F(SpvParserTest, ConvertType_Array) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr); ASSERT_NE(type, nullptr);
EXPECT_TRUE(type.ast->Is<ast::Array>()); EXPECT_TRUE(type->Is<ast::Array>());
auto* arr_type = type.ast->As<ast::Array>(); auto* arr_type = type->As<ast::Array>();
EXPECT_FALSE(arr_type->IsRuntimeArray()); EXPECT_FALSE(arr_type->IsRuntimeArray());
ASSERT_NE(arr_type, nullptr); ASSERT_NE(arr_type, nullptr);
EXPECT_EQ(arr_type->size(), 42u); EXPECT_EQ(arr_type->size(), 42u);
@ -430,7 +430,7 @@ TEST_F(SpvParserTest, ConvertType_ArrayBadLengthIsSpecConstantValue) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
ASSERT_EQ(type, nullptr); ASSERT_EQ(type, nullptr);
EXPECT_THAT(p->error(), EXPECT_THAT(p->error(),
Eq("Array type 10 length is a specialization constant")); Eq("Array type 10 length is a specialization constant"));
@ -445,7 +445,7 @@ TEST_F(SpvParserTest, ConvertType_ArrayBadLengthIsSpecConstantExpr) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
ASSERT_EQ(type, nullptr); ASSERT_EQ(type, nullptr);
EXPECT_THAT(p->error(), EXPECT_THAT(p->error(),
Eq("Array type 10 length is a specialization constant")); Eq("Array type 10 length is a specialization constant"));
@ -463,7 +463,7 @@ TEST_F(SpvParserTest, ConvertType_ArrayBadTooBig) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
ASSERT_EQ(type, nullptr); ASSERT_EQ(type, nullptr);
// TODO(dneto): Right now it's rejected earlier in the flow because // TODO(dneto): Right now it's rejected earlier in the flow because
// we can't even utter the uint64 type. // we can't even utter the uint64 type.
@ -478,7 +478,7 @@ TEST_F(SpvParserTest, ConvertType_Array_InvalidDecoration) {
%10 = OpTypeArray %uint %uint_5 %10 = OpTypeArray %uint %uint_5
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
EXPECT_EQ(type, nullptr); EXPECT_EQ(type, nullptr);
EXPECT_THAT( EXPECT_THAT(
p->error(), p->error(),
@ -494,10 +494,10 @@ TEST_F(SpvParserTest, ConvertType_ArrayStride_Valid) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr); ASSERT_NE(type, nullptr);
EXPECT_TRUE(type.ast->UnwrapAliasIfNeeded()->Is<ast::Array>()); EXPECT_TRUE(type->UnwrapAliasIfNeeded()->Is<ast::Array>());
auto* arr_type = type.ast->UnwrapAliasIfNeeded()->As<ast::Array>(); auto* arr_type = type->UnwrapAliasIfNeeded()->As<ast::Array>();
ASSERT_NE(arr_type, nullptr); ASSERT_NE(arr_type, nullptr);
ASSERT_EQ(arr_type->decorations().size(), 1u); ASSERT_EQ(arr_type->decorations().size(), 1u);
@ -517,7 +517,7 @@ TEST_F(SpvParserTest, ConvertType_ArrayStride_ZeroIsError) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
ASSERT_EQ(type, nullptr); ASSERT_EQ(type, nullptr);
EXPECT_THAT(p->error(), EXPECT_THAT(p->error(),
Eq("invalid array type ID 10: ArrayStride can't be 0")); Eq("invalid array type ID 10: ArrayStride can't be 0"));
@ -533,7 +533,7 @@ TEST_F(SpvParserTest, ConvertType_ArrayStride_SpecifiedTwiceIsError) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
ASSERT_EQ(type, nullptr); ASSERT_EQ(type, nullptr);
EXPECT_THAT(p->error(), EXPECT_THAT(p->error(),
Eq("invalid array type ID 10: multiple ArrayStride decorations")); Eq("invalid array type ID 10: multiple ArrayStride decorations"));
@ -548,12 +548,12 @@ TEST_F(SpvParserTest, ConvertType_StructTwoMembers) {
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
EXPECT_TRUE(p->RegisterUserAndStructMemberNames()); EXPECT_TRUE(p->RegisterUserAndStructMemberNames());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr); ASSERT_NE(type, nullptr);
EXPECT_TRUE(type.ast->Is<ast::Struct>()); EXPECT_TRUE(type->Is<ast::Struct>());
Program program = p->program(); Program program = p->program();
EXPECT_THAT(program.str(type.ast->As<ast::Struct>()), Eq(R"(Struct S { EXPECT_THAT(program.str(type->As<ast::Struct>()), Eq(R"(Struct S {
StructMember{field0: __u32} StructMember{field0: __u32}
StructMember{field1: __f32} StructMember{field1: __f32}
} }
@ -569,12 +569,12 @@ TEST_F(SpvParserTest, ConvertType_StructWithBlockDecoration) {
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
EXPECT_TRUE(p->RegisterUserAndStructMemberNames()); EXPECT_TRUE(p->RegisterUserAndStructMemberNames());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr); ASSERT_NE(type, nullptr);
EXPECT_TRUE(type.ast->Is<ast::Struct>()); EXPECT_TRUE(type->Is<ast::Struct>());
Program program = p->program(); Program program = p->program();
EXPECT_THAT(program.str(type.ast->As<ast::Struct>()), Eq(R"(Struct S { EXPECT_THAT(program.str(type->As<ast::Struct>()), Eq(R"(Struct S {
[[block]] [[block]]
StructMember{field0: __u32} StructMember{field0: __u32}
} }
@ -594,12 +594,12 @@ TEST_F(SpvParserTest, ConvertType_StructWithMemberDecorations) {
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
EXPECT_TRUE(p->RegisterUserAndStructMemberNames()); EXPECT_TRUE(p->RegisterUserAndStructMemberNames());
auto type = p->ConvertType(10); auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr); ASSERT_NE(type, nullptr);
EXPECT_TRUE(type.ast->Is<ast::Struct>()); EXPECT_TRUE(type->Is<ast::Struct>());
Program program = p->program(); Program program = p->program();
EXPECT_THAT(program.str(type.ast->As<ast::Struct>()), Eq(R"(Struct S { EXPECT_THAT(program.str(type->As<ast::Struct>()), Eq(R"(Struct S {
StructMember{[[ offset 0 ]] field0: __f32} StructMember{[[ offset 0 ]] field0: __f32}
StructMember{[[ offset 8 ]] field1: __vec_2__f32} StructMember{[[ offset 8 ]] field1: __vec_2__f32}
StructMember{[[ offset 16 ]] field2: __mat_2_2__f32} StructMember{[[ offset 16 ]] field2: __mat_2_2__f32}
@ -621,7 +621,7 @@ TEST_F(SpvParserTest, ConvertType_InvalidPointeetype) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()) << p->error(); EXPECT_TRUE(p->BuildInternalModule()) << p->error();
auto type = p->ConvertType(3); auto* type = p->ConvertType(3);
EXPECT_EQ(type, nullptr); EXPECT_EQ(type, nullptr);
EXPECT_THAT(p->error(), EXPECT_THAT(p->error(),
Eq("SPIR-V pointer type with ID 3 has invalid pointee type 42")); Eq("SPIR-V pointer type with ID 3 has invalid pointee type 42"));
@ -644,9 +644,9 @@ TEST_F(SpvParserTest, ConvertType_PointerInput) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(3); auto* type = p->ConvertType(3);
EXPECT_TRUE(type.ast->Is<ast::Pointer>()); EXPECT_TRUE(type->Is<ast::Pointer>());
auto* ptr_ty = type.ast->As<ast::Pointer>(); auto* ptr_ty = type->As<ast::Pointer>();
EXPECT_NE(ptr_ty, nullptr); EXPECT_NE(ptr_ty, nullptr);
EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kInput); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kInput);
@ -660,9 +660,9 @@ TEST_F(SpvParserTest, ConvertType_PointerOutput) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(3); auto* type = p->ConvertType(3);
EXPECT_TRUE(type.ast->Is<ast::Pointer>()); EXPECT_TRUE(type->Is<ast::Pointer>());
auto* ptr_ty = type.ast->As<ast::Pointer>(); auto* ptr_ty = type->As<ast::Pointer>();
EXPECT_NE(ptr_ty, nullptr); EXPECT_NE(ptr_ty, nullptr);
EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kOutput); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kOutput);
@ -676,9 +676,9 @@ TEST_F(SpvParserTest, ConvertType_PointerUniform) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(3); auto* type = p->ConvertType(3);
EXPECT_TRUE(type.ast->Is<ast::Pointer>()); EXPECT_TRUE(type->Is<ast::Pointer>());
auto* ptr_ty = type.ast->As<ast::Pointer>(); auto* ptr_ty = type->As<ast::Pointer>();
EXPECT_NE(ptr_ty, nullptr); EXPECT_NE(ptr_ty, nullptr);
EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kUniform); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kUniform);
@ -692,9 +692,9 @@ TEST_F(SpvParserTest, ConvertType_PointerWorkgroup) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(3); auto* type = p->ConvertType(3);
EXPECT_TRUE(type.ast->Is<ast::Pointer>()); EXPECT_TRUE(type->Is<ast::Pointer>());
auto* ptr_ty = type.ast->As<ast::Pointer>(); auto* ptr_ty = type->As<ast::Pointer>();
EXPECT_NE(ptr_ty, nullptr); EXPECT_NE(ptr_ty, nullptr);
EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kWorkgroup); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kWorkgroup);
@ -708,9 +708,9 @@ TEST_F(SpvParserTest, ConvertType_PointerUniformConstant) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(3); auto* type = p->ConvertType(3);
EXPECT_TRUE(type.ast->Is<ast::Pointer>()); EXPECT_TRUE(type->Is<ast::Pointer>());
auto* ptr_ty = type.ast->As<ast::Pointer>(); auto* ptr_ty = type->As<ast::Pointer>();
EXPECT_NE(ptr_ty, nullptr); EXPECT_NE(ptr_ty, nullptr);
EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kUniformConstant); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kUniformConstant);
@ -724,9 +724,9 @@ TEST_F(SpvParserTest, ConvertType_PointerStorageBuffer) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(3); auto* type = p->ConvertType(3);
EXPECT_TRUE(type.ast->Is<ast::Pointer>()); EXPECT_TRUE(type->Is<ast::Pointer>());
auto* ptr_ty = type.ast->As<ast::Pointer>(); auto* ptr_ty = type->As<ast::Pointer>();
EXPECT_NE(ptr_ty, nullptr); EXPECT_NE(ptr_ty, nullptr);
EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kStorage); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kStorage);
@ -740,9 +740,9 @@ TEST_F(SpvParserTest, ConvertType_PointerImage) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(3); auto* type = p->ConvertType(3);
EXPECT_TRUE(type.ast->Is<ast::Pointer>()); EXPECT_TRUE(type->Is<ast::Pointer>());
auto* ptr_ty = type.ast->As<ast::Pointer>(); auto* ptr_ty = type->As<ast::Pointer>();
EXPECT_NE(ptr_ty, nullptr); EXPECT_NE(ptr_ty, nullptr);
EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kImage); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kImage);
@ -756,9 +756,9 @@ TEST_F(SpvParserTest, ConvertType_PointerPrivate) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(3); auto* type = p->ConvertType(3);
EXPECT_TRUE(type.ast->Is<ast::Pointer>()); EXPECT_TRUE(type->Is<ast::Pointer>());
auto* ptr_ty = type.ast->As<ast::Pointer>(); auto* ptr_ty = type->As<ast::Pointer>();
EXPECT_NE(ptr_ty, nullptr); EXPECT_NE(ptr_ty, nullptr);
EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kPrivate); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kPrivate);
@ -772,9 +772,9 @@ TEST_F(SpvParserTest, ConvertType_PointerFunction) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(3); auto* type = p->ConvertType(3);
EXPECT_TRUE(type.ast->Is<ast::Pointer>()); EXPECT_TRUE(type->Is<ast::Pointer>());
auto* ptr_ty = type.ast->As<ast::Pointer>(); auto* ptr_ty = type->As<ast::Pointer>();
EXPECT_NE(ptr_ty, nullptr); EXPECT_NE(ptr_ty, nullptr);
EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kFunction); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kFunction);
@ -790,11 +790,11 @@ TEST_F(SpvParserTest, ConvertType_PointerToPointer) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(3); auto* type = p->ConvertType(3);
EXPECT_NE(type, nullptr); EXPECT_NE(type, nullptr);
EXPECT_TRUE(type.ast->Is<ast::Pointer>()); EXPECT_TRUE(type->Is<ast::Pointer>());
auto* ptr_ty = type.ast->As<ast::Pointer>(); auto* ptr_ty = type->As<ast::Pointer>();
EXPECT_NE(ptr_ty, nullptr); EXPECT_NE(ptr_ty, nullptr);
EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kInput); EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kInput);
EXPECT_TRUE(ptr_ty->type()->Is<ast::Pointer>()); EXPECT_TRUE(ptr_ty->type()->Is<ast::Pointer>());
@ -814,8 +814,8 @@ TEST_F(SpvParserTest, ConvertType_Sampler_PretendVoid) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(1); auto* type = p->ConvertType(1);
EXPECT_TRUE(type.ast->Is<ast::Void>()); EXPECT_TRUE(type->Is<ast::Void>());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
} }
@ -827,8 +827,8 @@ TEST_F(SpvParserTest, ConvertType_Image_PretendVoid) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(1); auto* type = p->ConvertType(1);
EXPECT_TRUE(type.ast->Is<ast::Void>()); EXPECT_TRUE(type->Is<ast::Void>());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
} }
@ -840,8 +840,8 @@ TEST_F(SpvParserTest, ConvertType_SampledImage_PretendVoid) {
)")); )"));
EXPECT_TRUE(p->BuildInternalModule()); EXPECT_TRUE(p->BuildInternalModule());
auto type = p->ConvertType(1); auto* type = p->ConvertType(1);
EXPECT_TRUE(type.ast->Is<ast::Void>()); EXPECT_TRUE(type->Is<ast::Void>());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
} }

View File

@ -158,7 +158,7 @@ class ParserImplWrapperForTest {
/// after the internal representation of the module has been built. /// after the internal representation of the module has been built.
/// @param id the SPIR-V ID of a type. /// @param id the SPIR-V ID of a type.
/// @returns a Tint type, or nullptr /// @returns a Tint type, or nullptr
typ::Type ConvertType(uint32_t id) { return impl_.ConvertType(id); } ast::Type* ConvertType(uint32_t id) { return impl_.ConvertType(id); }
/// Gets the list of decorations for a SPIR-V result ID. Returns an empty /// Gets the list of decorations for a SPIR-V result ID. Returns an empty
/// vector if the ID is not a result ID, or if no decorations target that ID. /// vector if the ID is not a result ID, or if no decorations target that ID.

View File

@ -271,27 +271,6 @@ inline auto MakeTypePair(AST* ast, SEM* sem) {
return TypePair<AST, SEM>{ast, sem}; return TypePair<AST, SEM>{ast, sem};
} }
/// Performs an As operation on the `ast` and `sem` members of the input type
/// pair, deducing the mapped type from typ::* to ast::* and sem::*
/// respectively.
/// @param tp the type pair to call As on
/// @returns a new type pair after As has been called on each of `sem` and `ast`
template <typename TargetTYP, typename AST, typename SEM>
auto As(TypePair<AST, SEM> tp)
-> TypePair<typename TargetTYP::AST_TYPE, typename TargetTYP::SEM_TYPE> {
return MakeTypePair(
tp.ast ? tp.ast->template As<typename TargetTYP::AST_TYPE>() : nullptr,
tp.sem ? tp.sem->template As<typename TargetTYP::SEM_TYPE>() : nullptr);
}
/// Invokes the `type()` member function on each of `ast` and `sem` of the input
/// type pair
/// @param tp the type pair
/// @returns a type pair with the result of calling `type()` on `ast` and `sem`
template <typename AST, typename SEM>
TypePair<AST, SEM> Call_type(TypePair<AST, SEM> tp) {
return MakeTypePair(tp.ast->type(), tp.sem->type());
}
} // namespace typ } // namespace typ