From 9ef17472e8719f15d35d1df842c1f074aaf8af5f Mon Sep 17 00:00:00 2001 From: Antonio Maiorano Date: Fri, 26 Mar 2021 12:47:58 +0000 Subject: [PATCH] Add semantic::Variable::Type() and use it instead of ast::Variable::type() In anticipation of adding support for type inference, no longer use ast::Variable::type() everywhere, as it will eventually return nullptr for type-inferred variables. Instead, the Resolver now stores the final resolved type into the semantic::Variable, and nearly all code now makes use of that. ast::Variable::type() has been renamed to ast::Variable::declared_type() to help make its usage clear, and to distinguish it from semantic::Variable::Type(). Fixed tests that failed after this change because variables were missing VariableDeclStatements, so there was no path to the variables during resolving, and thus no semantic info generated for them. Bug: tint:672 Change-Id: I0125e2f555839a4892248dc6739a72e9c7f51b1e Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46100 Reviewed-by: Ben Clayton Kokoro: Kokoro Commit-Queue: Antonio Maiorano --- src/ast/function.cc | 4 +- src/ast/variable.cc | 14 +++--- src/ast/variable.h | 14 +++--- src/ast/variable_test.cc | 6 +-- src/inspector/inspector.cc | 21 +++----- .../parser_impl_global_constant_decl_test.cc | 8 +-- .../parser_impl_global_variable_decl_test.cc | 16 +++--- .../wgsl/parser_impl_param_list_test.cc | 12 ++--- src/resolver/resolver.cc | 49 ++++++++++--------- src/resolver/resolver.h | 3 +- src/semantic/sem_function.cc | 15 ++++-- src/semantic/sem_variable.cc | 4 +- src/semantic/variable.h | 12 +++-- src/transform/first_index_offset.cc | 2 +- src/transform/hlsl.cc | 18 +++---- src/transform/msl.cc | 14 +++--- src/transform/spirv.cc | 4 +- src/transform/vertex_pulling.cc | 2 +- src/validator/validator_impl.cc | 3 +- src/writer/hlsl/generator_impl.cc | 48 +++++++++--------- src/writer/hlsl/generator_impl_binary_test.cc | 9 +++- .../generator_impl_module_constant_test.cc | 3 ++ ...rator_impl_variable_decl_statement_test.cc | 1 + src/writer/msl/generator_impl.cc | 47 ++++++++++-------- .../generator_impl_module_constant_test.cc | 2 + src/writer/spirv/builder.cc | 39 ++++++++------- src/writer/wgsl/generator_impl.cc | 6 +-- .../wgsl/generator_impl_variable_test.cc | 11 +++-- 28 files changed, 216 insertions(+), 171 deletions(-) diff --git a/src/ast/function.cc b/src/ast/function.cc index b2d31748d8..6e438d814a 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -125,7 +125,9 @@ std::string Function::type_name() const { out << "__func" + return_type_->type_name(); for (auto* param : params_) { - out << param->type()->type_name(); + // No need for the semantic::Variable here, functions params must have a + // type + out << param->declared_type()->type_name(); } return out.str(); diff --git a/src/ast/variable.cc b/src/ast/variable.cc index a7a5f1bc3c..cab6350fc8 100644 --- a/src/ast/variable.cc +++ b/src/ast/variable.cc @@ -25,20 +25,20 @@ namespace ast { Variable::Variable(const Source& source, const Symbol& sym, - StorageClass sc, - type::Type* type, + StorageClass declared_storage_class, + type::Type* declared_type, bool is_const, Expression* constructor, DecorationList decorations) : Base(source), symbol_(sym), - type_(type), + declared_type_(declared_type), is_const_(is_const), constructor_(constructor), decorations_(std::move(decorations)), - declared_storage_class_(sc) { + declared_storage_class_(declared_storage_class) { TINT_ASSERT(symbol_.IsValid()); - TINT_ASSERT(type_); + TINT_ASSERT(declared_type_); } Variable::Variable(Variable&&) = default; @@ -94,7 +94,7 @@ uint32_t Variable::constant_id() const { Variable* Variable::Clone(CloneContext* ctx) const { auto src = ctx->Clone(source()); auto sym = ctx->Clone(symbol()); - auto* ty = ctx->Clone(type()); + auto* ty = ctx->Clone(declared_type()); auto* ctor = ctx->Clone(constructor()); auto decos = ctx->Clone(decorations()); return ctx->dst->create(src, sym, declared_storage_class(), ty, @@ -111,7 +111,7 @@ void Variable::info_to_str(const semantic::Info& sem, out << (var_sem ? var_sem->StorageClass() : declared_storage_class()) << std::endl; make_indent(out, indent); - out << type_->type_name() << std::endl; + out << declared_type_->type_name() << std::endl; } void Variable::constructor_to_str(const semantic::Info& sem, diff --git a/src/ast/variable.h b/src/ast/variable.h index a7c7b652b6..f5f37cb58f 100644 --- a/src/ast/variable.h +++ b/src/ast/variable.h @@ -79,15 +79,15 @@ class Variable : public Castable { /// Create a variable /// @param source the variable source /// @param sym the variable symbol - /// @param sc the declared storage class - /// @param type the value type + /// @param declared_storage_class the declared storage class + /// @param declared_type the declared variable type /// @param is_const true if the variable is const /// @param constructor the constructor expression /// @param decorations the variable decorations Variable(const Source& source, const Symbol& sym, - StorageClass sc, - type::Type* type, + StorageClass declared_storage_class, + type::Type* declared_type, bool is_const, Expression* constructor, DecorationList decorations); @@ -99,8 +99,8 @@ class Variable : public Castable { /// @returns the variable symbol const Symbol& symbol() const { return symbol_; } - /// @returns the variable's type. - type::Type* type() const { return type_; } + /// @returns the declared type + type::Type* declared_type() const { return declared_type_; } /// @returns the declared storage class StorageClass declared_storage_class() const { @@ -166,7 +166,7 @@ class Variable : public Castable { Symbol const symbol_; // The value type if a const or formal paramter, and the store type if a var - type::Type* const type_; + type::Type* const declared_type_; bool const is_const_; Expression* const constructor_; DecorationList const decorations_; diff --git a/src/ast/variable_test.cc b/src/ast/variable_test.cc index 4d5b967146..2651f1a1e3 100644 --- a/src/ast/variable_test.cc +++ b/src/ast/variable_test.cc @@ -27,7 +27,7 @@ TEST_F(VariableTest, Creation) { EXPECT_EQ(v->symbol(), Symbol(1)); EXPECT_EQ(v->declared_storage_class(), StorageClass::kFunction); - EXPECT_EQ(v->type(), ty.i32()); + EXPECT_EQ(v->declared_type(), ty.i32()); EXPECT_EQ(v->source().range.begin.line, 0u); EXPECT_EQ(v->source().range.begin.column, 0u); EXPECT_EQ(v->source().range.end.line, 0u); @@ -41,7 +41,7 @@ TEST_F(VariableTest, CreationWithSource) { EXPECT_EQ(v->symbol(), Symbol(1)); EXPECT_EQ(v->declared_storage_class(), StorageClass::kPrivate); - EXPECT_EQ(v->type(), ty.f32()); + EXPECT_EQ(v->declared_type(), ty.f32()); EXPECT_EQ(v->source().range.begin.line, 27u); EXPECT_EQ(v->source().range.begin.column, 4u); EXPECT_EQ(v->source().range.end.line, 27u); @@ -55,7 +55,7 @@ TEST_F(VariableTest, CreationEmpty) { EXPECT_EQ(v->symbol(), Symbol(1)); EXPECT_EQ(v->declared_storage_class(), StorageClass::kWorkgroup); - EXPECT_EQ(v->type(), ty.i32()); + EXPECT_EQ(v->declared_type(), ty.i32()); EXPECT_EQ(v->source().range.begin.line, 27u); EXPECT_EQ(v->source().range.begin.column, 4u); EXPECT_EQ(v->source().range.end.line, 27u); diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc index a45c811a1b..5df2847c75 100644 --- a/src/inspector/inspector.cc +++ b/src/inspector/inspector.cc @@ -211,7 +211,7 @@ std::vector Inspector::GetEntryPoints() { stage_variable.name = name; stage_variable.component_type = ComponentType::kUnknown; - auto* type = var->Declaration()->type()->UnwrapAll(); + auto* type = var->Type()->UnwrapAll(); if (type->is_float_scalar_or_vector() || type->is_float_matrix()) { stage_variable.component_type = ComponentType::kFloat; } else if (type->is_unsigned_scalar_or_vector()) { @@ -367,10 +367,9 @@ std::vector Inspector::GetUniformBufferResourceBindings( auto* func_sem = program_->Sem().Get(func); for (auto& ruv : func_sem->ReferencedUniformVariables()) { auto* var = ruv.first; - auto* decl = var->Declaration(); auto binding_info = ruv.second; - auto* unwrapped_type = decl->type()->UnwrapIfNeeded(); + auto* unwrapped_type = var->Type()->UnwrapIfNeeded(); auto* str = unwrapped_type->As(); if (str == nullptr) { continue; @@ -492,7 +491,6 @@ std::vector Inspector::GetDepthTextureResourceBindings( auto* func_sem = program_->Sem().Get(func); for (auto& ref : func_sem->ReferencedDepthTextureVariables()) { auto* var = ref.first; - auto* decl = var->Declaration(); auto binding_info = ref.second; ResourceBinding entry; @@ -500,7 +498,7 @@ std::vector Inspector::GetDepthTextureResourceBindings( entry.bind_group = binding_info.group->value(); entry.binding = binding_info.binding->value(); - auto* texture_type = decl->type()->UnwrapIfNeeded()->As(); + auto* texture_type = var->Type()->UnwrapIfNeeded()->As(); entry.dim = TypeTextureDimensionToResourceBindingTextureDimension( texture_type->dim()); @@ -537,10 +535,9 @@ std::vector Inspector::GetStorageBufferResourceBindingsImpl( std::vector result; for (auto& rsv : func_sem->ReferencedStorageBufferVariables()) { auto* var = rsv.first; - auto* decl = var->Declaration(); auto binding_info = rsv.second; - auto* ac_type = decl->type()->As(); + auto* ac_type = var->Type()->As(); if (ac_type == nullptr) { continue; } @@ -549,7 +546,7 @@ std::vector Inspector::GetStorageBufferResourceBindingsImpl( continue; } - auto* str = decl->type()->UnwrapIfNeeded()->As(); + auto* str = var->Type()->UnwrapIfNeeded()->As(); if (!str) { continue; } @@ -591,7 +588,6 @@ std::vector Inspector::GetSampledTextureResourceBindingsImpl( : func_sem->ReferencedSampledTextureVariables(); for (auto& ref : referenced_variables) { auto* var = ref.first; - auto* decl = var->Declaration(); auto binding_info = ref.second; ResourceBinding entry; @@ -601,7 +597,7 @@ std::vector Inspector::GetSampledTextureResourceBindingsImpl( entry.bind_group = binding_info.group->value(); entry.binding = binding_info.binding->value(); - auto* texture_type = decl->type()->UnwrapIfNeeded()->As(); + auto* texture_type = var->Type()->UnwrapIfNeeded()->As(); entry.dim = TypeTextureDimensionToResourceBindingTextureDimension( texture_type->dim()); @@ -634,10 +630,9 @@ std::vector Inspector::GetStorageTextureResourceBindingsImpl( std::vector result; for (auto& ref : func_sem->ReferencedStorageTextureVariables()) { auto* var = ref.first; - auto* decl = var->Declaration(); auto binding_info = ref.second; - auto* ac_type = decl->type()->As(); + auto* ac_type = var->Type()->As(); if (ac_type == nullptr) { continue; } @@ -654,7 +649,7 @@ std::vector Inspector::GetStorageTextureResourceBindingsImpl( entry.binding = binding_info.binding->value(); auto* texture_type = - decl->type()->UnwrapIfNeeded()->As(); + var->Type()->UnwrapIfNeeded()->As(); entry.dim = TypeTextureDimensionToResourceBindingTextureDimension( texture_type->dim()); diff --git a/src/reader/wgsl/parser_impl_global_constant_decl_test.cc b/src/reader/wgsl/parser_impl_global_constant_decl_test.cc index c6f93d5699..11d5c9331a 100644 --- a/src/reader/wgsl/parser_impl_global_constant_decl_test.cc +++ b/src/reader/wgsl/parser_impl_global_constant_decl_test.cc @@ -32,8 +32,8 @@ TEST_F(ParserImplTest, GlobalConstantDecl) { EXPECT_TRUE(e->is_const()); EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a")); - ASSERT_NE(e->type(), nullptr); - EXPECT_TRUE(e->type()->Is()); + ASSERT_NE(e->declared_type(), nullptr); + EXPECT_TRUE(e->declared_type()->Is()); EXPECT_EQ(e->source().range.begin.line, 1u); EXPECT_EQ(e->source().range.begin.column, 7u); @@ -112,8 +112,8 @@ TEST_F(ParserImplTest, GlobalConstantDec_ConstantId) { EXPECT_TRUE(e->is_const()); EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a")); - ASSERT_NE(e->type(), nullptr); - EXPECT_TRUE(e->type()->Is()); + ASSERT_NE(e->declared_type(), nullptr); + EXPECT_TRUE(e->declared_type()->Is()); EXPECT_EQ(e->source().range.begin.line, 1u); EXPECT_EQ(e->source().range.begin.column, 26u); diff --git a/src/reader/wgsl/parser_impl_global_variable_decl_test.cc b/src/reader/wgsl/parser_impl_global_variable_decl_test.cc index 89d6909687..831d115239 100644 --- a/src/reader/wgsl/parser_impl_global_variable_decl_test.cc +++ b/src/reader/wgsl/parser_impl_global_variable_decl_test.cc @@ -31,7 +31,7 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithoutConstructor) { ASSERT_NE(e.value, nullptr); EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a")); - EXPECT_TRUE(e->type()->Is()); + EXPECT_TRUE(e->declared_type()->Is()); EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kOutput); EXPECT_EQ(e->source().range.begin.line, 1u); @@ -54,7 +54,7 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithConstructor) { ASSERT_NE(e.value, nullptr); EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a")); - EXPECT_TRUE(e->type()->Is()); + EXPECT_TRUE(e->declared_type()->Is()); EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kOutput); EXPECT_EQ(e->source().range.begin.line, 1u); @@ -79,8 +79,8 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithDecoration) { ASSERT_NE(e.value, nullptr); EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a")); - ASSERT_NE(e->type(), nullptr); - EXPECT_TRUE(e->type()->Is()); + ASSERT_NE(e->declared_type(), nullptr); + EXPECT_TRUE(e->declared_type()->Is()); EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kOutput); EXPECT_EQ(e->source().range.begin.line, 1u); @@ -109,8 +109,8 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithDecoration_MulitpleGroups) { ASSERT_NE(e.value, nullptr); EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a")); - ASSERT_NE(e->type(), nullptr); - EXPECT_TRUE(e->type()->Is()); + ASSERT_NE(e->declared_type(), nullptr); + EXPECT_TRUE(e->declared_type()->Is()); EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kOutput); EXPECT_EQ(e->source().range.begin.line, 1u); @@ -180,7 +180,7 @@ TEST_F(ParserImplTest, GlobalVariableDecl_SamplerImplicitStorageClass) { ASSERT_NE(e.value, nullptr); EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("s")); - EXPECT_TRUE(e->type()->Is()); + EXPECT_TRUE(e->declared_type()->Is()); EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kUniformConstant); } @@ -196,7 +196,7 @@ TEST_F(ParserImplTest, GlobalVariableDecl_TextureImplicitStorageClass) { ASSERT_NE(e.value, nullptr); EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("s")); - EXPECT_TRUE(e->type()->UnwrapAll()->Is()); + EXPECT_TRUE(e->declared_type()->UnwrapAll()->Is()); EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kUniformConstant); } diff --git a/src/reader/wgsl/parser_impl_param_list_test.cc b/src/reader/wgsl/parser_impl_param_list_test.cc index 0bf7686420..c581145e5c 100644 --- a/src/reader/wgsl/parser_impl_param_list_test.cc +++ b/src/reader/wgsl/parser_impl_param_list_test.cc @@ -30,7 +30,7 @@ TEST_F(ParserImplTest, ParamList_Single) { EXPECT_EQ(e.value.size(), 1u); EXPECT_EQ(e.value[0]->symbol(), p->builder().Symbols().Get("a")); - EXPECT_EQ(e.value[0]->type(), i32); + EXPECT_EQ(e.value[0]->declared_type(), i32); EXPECT_TRUE(e.value[0]->is_const()); ASSERT_EQ(e.value[0]->source().range.begin.line, 1u); @@ -52,7 +52,7 @@ TEST_F(ParserImplTest, ParamList_Multiple) { EXPECT_EQ(e.value.size(), 3u); EXPECT_EQ(e.value[0]->symbol(), p->builder().Symbols().Get("a")); - EXPECT_EQ(e.value[0]->type(), i32); + EXPECT_EQ(e.value[0]->declared_type(), i32); EXPECT_TRUE(e.value[0]->is_const()); ASSERT_EQ(e.value[0]->source().range.begin.line, 1u); @@ -61,7 +61,7 @@ TEST_F(ParserImplTest, ParamList_Multiple) { ASSERT_EQ(e.value[0]->source().range.end.column, 2u); EXPECT_EQ(e.value[1]->symbol(), p->builder().Symbols().Get("b")); - EXPECT_EQ(e.value[1]->type(), f32); + EXPECT_EQ(e.value[1]->declared_type(), f32); EXPECT_TRUE(e.value[1]->is_const()); ASSERT_EQ(e.value[1]->source().range.begin.line, 1u); @@ -70,7 +70,7 @@ TEST_F(ParserImplTest, ParamList_Multiple) { ASSERT_EQ(e.value[1]->source().range.end.column, 11u); EXPECT_EQ(e.value[2]->symbol(), p->builder().Symbols().Get("c")); - EXPECT_EQ(e.value[2]->type(), vec2); + EXPECT_EQ(e.value[2]->declared_type(), vec2); EXPECT_TRUE(e.value[2]->is_const()); ASSERT_EQ(e.value[2]->source().range.begin.line, 1u); @@ -109,7 +109,7 @@ TEST_F(ParserImplTest, ParamList_Decorations) { ASSERT_EQ(e.value.size(), 2u); EXPECT_EQ(e.value[0]->symbol(), p->builder().Symbols().Get("coord")); - EXPECT_EQ(e.value[0]->type(), vec4); + EXPECT_EQ(e.value[0]->declared_type(), vec4); EXPECT_TRUE(e.value[0]->is_const()); auto decos0 = e.value[0]->decorations(); ASSERT_EQ(decos0.size(), 1u); @@ -123,7 +123,7 @@ TEST_F(ParserImplTest, ParamList_Decorations) { ASSERT_EQ(e.value[0]->source().range.end.column, 30u); EXPECT_EQ(e.value[1]->symbol(), p->builder().Symbols().Get("loc1")); - EXPECT_EQ(e.value[1]->type(), f32); + EXPECT_EQ(e.value[1]->declared_type(), f32); EXPECT_TRUE(e.value[1]->is_const()); auto decos1 = e.value[1]->decorations(); ASSERT_EQ(decos1.size(), 1u); diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 189202ae28..bdf95be95c 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -201,7 +201,8 @@ bool Resolver::ResolveInternal() { return false; } } else if (auto* var = decl->As()) { - variable_stack_.set_global(var->symbol(), CreateVariableInfo(var)); + auto* info = CreateVariableInfo(var); + variable_stack_.set_global(var->symbol(), info); if (var->has_constructor()) { if (!Expression(var->constructor())) { @@ -210,7 +211,7 @@ bool Resolver::ResolveInternal() { } if (!ApplyStorageClassUsageToType(var->declared_storage_class(), - var->type(), var->source())) { + info->type, var->source())) { diagnostics_.add_note("while instantiating variable " + builder_->Symbols().NameFor(var->symbol()), var->source()); @@ -223,7 +224,8 @@ bool Resolver::ResolveInternal() { } bool Resolver::ValidateParameter(const ast::Variable* param) { - if (auto* r = param->type()->UnwrapAll()->As()) { + auto* type = variable_to_info_[param]->type; + if (auto* r = type->UnwrapAll()->As()) { if (r->IsRuntimeArray()) { diagnostics_.add_error( "v-0015", @@ -277,10 +279,6 @@ bool Resolver::ValidateFunction(const ast::Function* func) { bool Resolver::Function(ast::Function* func) { auto* func_info = function_infos_.Create(func); - if (!ValidateFunction(func)) { - return false; - } - ScopedAssignment sa(current_function_, func_info); variable_stack_.push_scope(); @@ -293,6 +291,10 @@ bool Resolver::Function(ast::Function* func) { } variable_stack_.pop_scope(); + if (!ValidateFunction(func)) { + return false; + } + // Register the function information _after_ processing the statements. This // allows us to catch a function calling itself when determining the call // information as this function doesn't exist until it's finished. @@ -780,12 +782,12 @@ bool Resolver::Identifier(ast::IdentifierExpression* expr) { // A constant is the type, but a variable is always a pointer so synthesize // the pointer around the variable type. if (var->declaration->is_const()) { - SetType(expr, var->declaration->type()); - } else if (var->declaration->type()->Is()) { - SetType(expr, var->declaration->type()); + SetType(expr, var->type); + } else if (var->type->Is()) { + SetType(expr, var->type); } else { - SetType(expr, builder_->create(var->declaration->type(), - var->storage_class)); + SetType(expr, + builder_->create(var->type, var->storage_class)); } var->users.push_back(expr); @@ -1200,15 +1202,18 @@ bool Resolver::UnaryOp(ast::UnaryOpExpression* expr) { } bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) { + ast::Variable* var = stmt->variable(); + type::Type* type = var->declared_type(); + if (auto* ctor = stmt->variable()->constructor()) { if (!Expression(ctor)) { return false; } - auto* lhs_type = stmt->variable()->type(); auto* rhs_type = TypeOf(ctor); - if (!IsValidAssignment(lhs_type, rhs_type)) { + + if (!IsValidAssignment(type, rhs_type)) { diagnostics_.add_error( - "variable of type '" + lhs_type->FriendlyName(builder_->Symbols()) + + "variable of type '" + type->FriendlyName(builder_->Symbols()) + "' cannot be initialized with a value of type '" + rhs_type->FriendlyName(builder_->Symbols()) + "'", stmt->source()); @@ -1216,10 +1221,8 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) { } } - auto* var = stmt->variable(); - auto* info = CreateVariableInfo(var); - variable_to_info_.emplace(var, info); + info->type = type; variable_stack_.set(var->symbol(), info); current_block_->decls.push_back(var); @@ -1235,7 +1238,7 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) { } } - if (!ApplyStorageClassUsageToType(info->storage_class, var->type(), + if (!ApplyStorageClassUsageToType(info->storage_class, info->type, var->source())) { diagnostics_.add_note("while instantiating variable " + builder_->Symbols().NameFor(var->symbol()), @@ -1303,8 +1306,8 @@ void Resolver::CreateSemanticNodes() const { } users.push_back(sem_expr); } - sem.Add(var, builder_->create(var, info->storage_class, - std::move(users))); + sem.Add(var, builder_->create( + var, info->type, info->storage_class, std::move(users))); } auto remap_vars = [&sem](const std::vector& in) { @@ -1812,7 +1815,9 @@ std::string Resolver::VectorPretty(uint32_t size, type::Type* element_type) { } Resolver::VariableInfo::VariableInfo(ast::Variable* decl) - : declaration(decl), storage_class(decl->declared_storage_class()) {} + : declaration(decl), + type(decl->declared_type()), + storage_class(decl->declared_storage_class()) {} Resolver::VariableInfo::~VariableInfo() = default; diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 3109dc6f2f..3f39d7387e 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -93,6 +93,7 @@ class Resolver { ~VariableInfo(); ast::Variable* const declaration; + type::Type* type; ast::StorageClass storage_class; std::vector users; }; @@ -290,7 +291,7 @@ class Resolver { ScopeStack variable_stack_; std::unordered_map symbol_to_function_; std::unordered_map function_to_info_; - std::unordered_map variable_to_info_; + std::unordered_map variable_to_info_; std::unordered_map function_calls_; std::unordered_map expr_info_; std::unordered_map struct_info_; diff --git a/src/semantic/sem_function.cc b/src/semantic/sem_function.cc index 15e0b6597f..8e4dcbd07a 100644 --- a/src/semantic/sem_function.cc +++ b/src/semantic/sem_function.cc @@ -32,7 +32,8 @@ ParameterList GetParameters(ast::Function* ast) { ParameterList parameters; parameters.reserve(ast->params().size()); for (auto* param : ast->params()) { - parameters.emplace_back(Parameter{param->type(), Parameter::Usage::kNone}); + parameters.emplace_back( + Parameter{param->declared_type(), Parameter::Usage::kNone}); } return parameters; } @@ -160,7 +161,8 @@ Function::VariableBindings Function::ReferencedStorageTextureVariables() const { VariableBindings ret; for (auto* var : ReferencedModuleVariables()) { - auto* unwrapped_type = var->Declaration()->type()->UnwrapIfNeeded(); + auto* unwrapped_type = + var->Declaration()->declared_type()->UnwrapIfNeeded(); auto* storage_texture = unwrapped_type->As(); if (storage_texture == nullptr) { continue; @@ -182,7 +184,8 @@ Function::VariableBindings Function::ReferencedDepthTextureVariables() const { VariableBindings ret; for (auto* var : ReferencedModuleVariables()) { - auto* unwrapped_type = var->Declaration()->type()->UnwrapIfNeeded(); + auto* unwrapped_type = + var->Declaration()->declared_type()->UnwrapIfNeeded(); auto* storage_texture = unwrapped_type->As(); if (storage_texture == nullptr) { continue; @@ -229,7 +232,8 @@ Function::VariableBindings Function::ReferencedSamplerVariablesImpl( VariableBindings ret; for (auto* var : ReferencedModuleVariables()) { - auto* unwrapped_type = var->Declaration()->type()->UnwrapIfNeeded(); + auto* unwrapped_type = + var->Declaration()->declared_type()->UnwrapIfNeeded(); auto* sampler = unwrapped_type->As(); if (sampler == nullptr || sampler->kind() != kind) { continue; @@ -252,7 +256,8 @@ Function::VariableBindings Function::ReferencedSampledTextureVariablesImpl( VariableBindings ret; for (auto* var : ReferencedModuleVariables()) { - auto* unwrapped_type = var->Declaration()->type()->UnwrapIfNeeded(); + auto* unwrapped_type = + var->Declaration()->declared_type()->UnwrapIfNeeded(); auto* texture = unwrapped_type->As(); if (texture == nullptr) { continue; diff --git a/src/semantic/sem_variable.cc b/src/semantic/sem_variable.cc index 1051a6e6ab..03dc0344bf 100644 --- a/src/semantic/sem_variable.cc +++ b/src/semantic/sem_variable.cc @@ -19,10 +19,12 @@ TINT_INSTANTIATE_TYPEINFO(tint::semantic::Variable); namespace tint { namespace semantic { -Variable::Variable(ast::Variable* declaration, +Variable::Variable(const ast::Variable* declaration, + type::Type* type, ast::StorageClass storage_class, std::vector users) : declaration_(declaration), + type_(type), storage_class_(storage_class), users_(std::move(users)) {} diff --git a/src/semantic/variable.h b/src/semantic/variable.h index 96f023a42d..0e606a27c7 100644 --- a/src/semantic/variable.h +++ b/src/semantic/variable.h @@ -37,9 +37,11 @@ class Variable : public Castable { public: /// Constructor /// @param declaration the AST declaration node + /// @param type the variable type /// @param storage_class the variable storage class /// @param users the expressions that use the variable - explicit Variable(ast::Variable* declaration, + explicit Variable(const ast::Variable* declaration, + type::Type* type, ast::StorageClass storage_class, std::vector users); @@ -47,7 +49,10 @@ class Variable : public Castable { ~Variable() override; /// @returns the AST declaration node - ast::Variable* Declaration() const { return declaration_; } + const ast::Variable* Declaration() const { return declaration_; } + + /// @returns the type for the variable + type::Type* Type() const { return type_; } /// @returns the storage class for the variable ast::StorageClass StorageClass() const { return storage_class_; } @@ -56,7 +61,8 @@ class Variable : public Castable { const std::vector& Users() const { return users_; } private: - ast::Variable* const declaration_; + const ast::Variable* const declaration_; + type::Type* const type_; ast::StorageClass const storage_class_; std::vector const users_; }; diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc index eccd8130b1..cdc0288adb 100644 --- a/src/transform/first_index_offset.cc +++ b/src/transform/first_index_offset.cc @@ -39,7 +39,7 @@ ast::Variable* clone_variable_with_new_name(CloneContext* ctx, // Clone arguments outside of create() call to have deterministic ordering auto source = ctx->Clone(in->source()); auto symbol = ctx->dst->Symbols().Register(new_name); - auto* type = ctx->Clone(in->type()); + auto* type = ctx->Clone(in->declared_type()); auto* constructor = ctx->Clone(in->constructor()); auto decorations = ctx->Clone(in->decorations()); return ctx->dst->create( diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc index 494ca38b4d..8ed06ab907 100644 --- a/src/transform/hlsl.cc +++ b/src/transform/hlsl.cc @@ -145,7 +145,8 @@ void Hlsl::HandleEntryPointIOTypes(CloneContext& ctx) const { // Build a new structure to hold the non-struct input parameters. ast::StructMemberList struct_members; for (auto* param : func->params()) { - if (param->type()->Is()) { + auto* type = ctx.src->Sem().Get(param)->Type(); + if (type->Is()) { // Already a struct, nothing to do. continue; } @@ -159,14 +160,12 @@ void Hlsl::HandleEntryPointIOTypes(CloneContext& ctx) const { auto* deco = param->decorations()[0]; if (auto* builtin = deco->As()) { // Create a struct member with the builtin decoration. - struct_members.push_back( - ctx.dst->Member(name, ctx.Clone(param->type()), - ast::DecorationList{ctx.Clone(builtin)})); + struct_members.push_back(ctx.dst->Member( + name, ctx.Clone(type), ast::DecorationList{ctx.Clone(builtin)})); } else if (auto* loc = deco->As()) { // Create a struct member with the location decoration. - struct_members.push_back( - ctx.dst->Member(name, ctx.Clone(param->type()), - ast::DecorationList{ctx.Clone(loc)})); + struct_members.push_back(ctx.dst->Member( + name, ctx.Clone(type), ast::DecorationList{ctx.Clone(loc)})); } else { TINT_ICE(ctx.dst->Diagnostics()) << "Unsupported entry point parameter decoration"; @@ -195,7 +194,8 @@ void Hlsl::HandleEntryPointIOTypes(CloneContext& ctx) const { // Replace the original parameters with function-scope constants. for (auto* param : func->params()) { - if (param->type()->Is()) { + auto* type = ctx.src->Sem().Get(param)->Type(); + if (type->Is()) { // Keep struct parameters unchanged. new_parameters.push_back(ctx.Clone(param)); continue; @@ -207,7 +207,7 @@ void Hlsl::HandleEntryPointIOTypes(CloneContext& ctx) const { // Initialize it with the value extracted from the struct parameter. auto func_const_symbol = ctx.dst->Symbols().Register(name); auto* func_const = - ctx.dst->Const(func_const_symbol, ctx.Clone(param->type()), + ctx.dst->Const(func_const_symbol, ctx.Clone(type), ctx.dst->MemberAccessor(struct_param_symbol, name)); new_body.push_back(ctx.dst->WrapInStatement(func_const)); diff --git a/src/transform/msl.cc b/src/transform/msl.cc index 3110418310..6245fc7f88 100644 --- a/src/transform/msl.cc +++ b/src/transform/msl.cc @@ -14,6 +14,7 @@ #include "src/transform/msl.h" +#include #include #include #include @@ -326,9 +327,10 @@ void Msl::HandleEntryPointIOTypes(CloneContext& ctx) const { continue; } else if (auto* loc = deco->As()) { // Create a struct member with the location decoration. - struct_members.push_back(ctx.dst->Member( - ctx.src->Symbols().NameFor(param->symbol()), - ctx.Clone(param->type()), ast::DecorationList{ctx.Clone(loc)})); + std::string name = ctx.src->Symbols().NameFor(param->symbol()); + auto* type = ctx.Clone(ctx.src->Sem().Get(param)->Type()); + struct_members.push_back( + ctx.dst->Member(name, type, ast::DecorationList{ctx.Clone(loc)})); } else { TINT_ICE(ctx.dst->Diagnostics()) << "Unsupported entry point parameter decoration"; @@ -368,9 +370,9 @@ void Msl::HandleEntryPointIOTypes(CloneContext& ctx) const { // Create a function-scope const to replace the parameter. // Initialize it with the value extracted from the struct parameter. auto func_const_symbol = ctx.dst->Symbols().Register(name); - auto* func_const = - ctx.dst->Const(func_const_symbol, ctx.Clone(param->type()), - ctx.dst->MemberAccessor(struct_param_symbol, name)); + auto* type = ctx.Clone(ctx.src->Sem().Get(param)->Type()); + auto* constructor = ctx.dst->MemberAccessor(struct_param_symbol, name); + auto* func_const = ctx.dst->Const(func_const_symbol, type, constructor); new_body.push_back(ctx.dst->WrapInStatement(func_const)); diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc index 0f872d3f1b..06bd5f63fc 100644 --- a/src/transform/spirv.cc +++ b/src/transform/spirv.cc @@ -137,8 +137,8 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const { } for (auto* param : func->params()) { - Symbol new_var = - HoistToInputVariables(ctx, func, param->type(), param->decorations()); + Symbol new_var = HoistToInputVariables( + ctx, func, ctx.src->Sem().Get(param)->Type(), param->decorations()); // Replace all uses of the function parameter with the new variable. for (auto* user : ctx.src->Sem().Get(param)->Users()) { diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc index 78e96c7482..638eb53818 100644 --- a/src/transform/vertex_pulling.cc +++ b/src/transform/vertex_pulling.cc @@ -207,7 +207,7 @@ void VertexPulling::State::ConvertVertexInputVariablesToPrivate() { Source{}, // source ctx.dst->Symbols().Register(name), // symbol ast::StorageClass::kPrivate, // storage_class - ctx.Clone(v->type()), // type + ctx.Clone(v->declared_type()), // type false, // is_const nullptr, // constructor ast::DecorationList{}); // decorations diff --git a/src/validator/validator_impl.cc b/src/validator/validator_impl.cc index f9aea11507..a00c5fa899 100644 --- a/src/validator/validator_impl.cc +++ b/src/validator/validator_impl.cc @@ -219,7 +219,8 @@ bool ValidatorImpl::ValidateDeclStatement( // storable. // - types match or the RHS can be dereferenced to equal the LHS type. variable_stack_.set(symbol, decl->variable()); - if (auto* arr = decl->variable()->type()->UnwrapAll()->As()) { + if (auto* arr = + decl->variable()->declared_type()->UnwrapAll()->As()) { if (arr->IsRuntimeArray()) { add_error( decl->source(), "v-0015", diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 1b940e8580..c5e4a4cbd8 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -1515,11 +1515,13 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, } first = false; - if (!EmitType(out, v->type(), builder_.Symbols().NameFor(v->symbol()))) { + auto* type = builder_.Sem().Get(v)->Type(); + + if (!EmitType(out, type, builder_.Symbols().NameFor(v->symbol()))) { return false; } // Array name is output as part of the type - if (!v->type()->Is()) { + if (!type->Is()) { out << " " << builder_.Symbols().NameFor(v->symbol()); } } @@ -1541,8 +1543,8 @@ bool GeneratorImpl::EmitEntryPointData( std::ostream& out, ast::Function* func, std::unordered_set& emitted_globals) { - std::vector> in_variables; - std::vector> outvariables; + std::vector> in_variables; + std::vector> outvariables; auto* func_sem = builder_.Sem().Get(func); auto func_sym = func->symbol(); @@ -1595,7 +1597,7 @@ bool GeneratorImpl::EmitEntryPointData( } // auto* set = data.second.set; - auto* type = decl->type()->UnwrapIfNeeded(); + auto* type = var->Type()->UnwrapIfNeeded(); if (auto* strct = type->As()) { out << "ConstantBuffer<" << builder_.Symbols().NameFor(strct->symbol()) << "> " << builder_.Symbols().NameFor(decl->symbol()) @@ -1638,7 +1640,7 @@ bool GeneratorImpl::EmitEntryPointData( } auto* binding = data.second.binding; - auto* ac = decl->type()->As(); + auto* ac = var->Type()->As(); if (ac == nullptr) { diagnostics_.add_error("access control type required for storage buffer"); return false; @@ -1672,10 +1674,10 @@ bool GeneratorImpl::EmitEntryPointData( for (auto& data : in_variables) { auto* var = data.first; auto* deco = data.second; + auto* type = builder_.Sem().Get(var)->Type(); make_indent(out); - if (!EmitType(out, var->type(), - builder_.Symbols().NameFor(var->symbol()))) { + if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) { return false; } @@ -1722,10 +1724,10 @@ bool GeneratorImpl::EmitEntryPointData( for (auto& data : outvariables) { auto* var = data.first; auto* deco = data.second; + auto* type = builder_.Sem().Get(var)->Type(); make_indent(out); - if (!EmitType(out, var->type(), - builder_.Symbols().NameFor(var->symbol()))) { + if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) { return false; } @@ -1766,7 +1768,7 @@ bool GeneratorImpl::EmitEntryPointData( for (auto* var : func_sem->ReferencedModuleVariables()) { auto* decl = var->Declaration(); - auto* unwrapped_type = decl->type()->UnwrapAll(); + auto* unwrapped_type = var->Type()->UnwrapAll(); if (!unwrapped_type->Is() && !unwrapped_type->Is()) { continue; // Not interested in this type @@ -1776,7 +1778,7 @@ bool GeneratorImpl::EmitEntryPointData( continue; // Global already emitted } - if (!EmitType(out, decl->type(), "")) { + if (!EmitType(out, var->Type(), "")) { return false; } out << " " << namer_.NameFor(builder_.Symbols().NameFor(decl->symbol())) @@ -1861,7 +1863,8 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, // Emit entry point parameters. for (auto* var : func->params()) { - if (!var->type()->Is()) { + auto* type = builder_.Sem().Get(var)->Type(); + if (!type->Is()) { TINT_ICE(diagnostics_) << "Unsupported non-struct entry point parameter"; } @@ -1870,7 +1873,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, } first = false; - if (!EmitType(out, var->type(), "")) { + if (!EmitType(out, type, "")) { return false; } @@ -2024,7 +2027,7 @@ bool GeneratorImpl::EmitLoop(std::ostream& out, ast::LoopStatement* stmt) { if (var->constructor() != nullptr) { out << constructor_out.str(); } else { - if (!EmitZeroValue(out, var->type())) { + if (!EmitZeroValue(out, builder_.Sem().Get(var)->Type())) { return false; } } @@ -2678,10 +2681,11 @@ bool GeneratorImpl::EmitVariable(std::ostream& out, if (var->is_const()) { out << "const "; } - if (!EmitType(out, var->type(), builder_.Symbols().NameFor(var->symbol()))) { + auto* type = builder_.Sem().Get(var)->Type(); + if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) { return false; } - if (!var->type()->Is()) { + if (!type->Is()) { out << " " << builder_.Symbols().NameFor(var->symbol()); } out << constructor_out.str() << ";" << std::endl; @@ -2713,6 +2717,8 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out, out << pre.str(); } + auto* type = builder_.Sem().Get(var)->Type(); + if (var->HasConstantIdDecoration()) { auto const_id = var->constant_id(); @@ -2727,8 +2733,7 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out, } out << "#endif" << std::endl; out << "static const "; - if (!EmitType(out, var->type(), - builder_.Symbols().NameFor(var->symbol()))) { + if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) { return false; } out << " " << builder_.Symbols().NameFor(var->symbol()) @@ -2736,11 +2741,10 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out, out << "#undef WGSL_SPEC_CONSTANT_" << const_id << std::endl; } else { out << "static const "; - if (!EmitType(out, var->type(), - builder_.Symbols().NameFor(var->symbol()))) { + if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) { return false; } - if (!var->type()->Is()) { + if (!type->Is()) { out << " " << builder_.Symbols().NameFor(var->symbol()); } diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc index 0f65366cc9..f5941f0e0b 100644 --- a/src/writer/hlsl/generator_impl_binary_test.cc +++ b/src/writer/hlsl/generator_impl_binary_test.cc @@ -412,6 +412,10 @@ a = (_tint_tmp_0); TEST_F(HlslGeneratorImplTest_Binary, Decl_WithLogical) { // var a : bool = (b && c) || d; + auto* b_decl = Decl(Var("b", ty.bool_(), ast::StorageClass::kFunction)); + auto* c_decl = Decl(Var("c", ty.bool_(), ast::StorageClass::kFunction)); + auto* d_decl = Decl(Var("d", ty.bool_(), ast::StorageClass::kFunction)); + auto* b = Expr("b"); auto* c = Expr("c"); auto* d = Expr("d"); @@ -422,11 +426,12 @@ TEST_F(HlslGeneratorImplTest_Binary, Decl_WithLogical) { ast::BinaryOp::kLogicalOr, create(ast::BinaryOp::kLogicalAnd, b, c), d)); - auto* expr = create(var); + auto* decl = Decl(var); + WrapInFunction(b_decl, c_decl, d_decl, Decl(var)); GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitStatement(out, expr)) << gen.error(); + ASSERT_TRUE(gen.EmitStatement(out, decl)) << gen.error(); EXPECT_EQ(result(), R"(bool _tint_tmp = b; if (_tint_tmp) { _tint_tmp = c; diff --git a/src/writer/hlsl/generator_impl_module_constant_test.cc b/src/writer/hlsl/generator_impl_module_constant_test.cc index b3a197f5d7..405b916bc2 100644 --- a/src/writer/hlsl/generator_impl_module_constant_test.cc +++ b/src/writer/hlsl/generator_impl_module_constant_test.cc @@ -24,6 +24,7 @@ using HlslGeneratorImplTest_ModuleConstant = TestHelper; TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_ModuleConstant) { auto* var = Const("pos", ty.array(), array(1.f, 2.f, 3.f)); + WrapInFunction(Decl(var)); GeneratorImpl& gen = Build(); @@ -36,6 +37,7 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant) { ast::DecorationList{ create(23), }); + WrapInFunction(Decl(var)); GeneratorImpl& gen = Build(); @@ -53,6 +55,7 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoConstructor) { ast::DecorationList{ create(23), }); + WrapInFunction(Decl(var)); GeneratorImpl& gen = Build(); diff --git a/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc b/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc index 8ffce85ce1..495763d096 100644 --- a/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc +++ b/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc @@ -39,6 +39,7 @@ TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Const) { auto* var = Const("a", ty.f32()); auto* stmt = create(var); + WrapInFunction(stmt); GeneratorImpl& gen = Build(); diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index b3142a2e8d..cf7e396b2e 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -964,8 +964,8 @@ bool GeneratorImpl::EmitLiteral(ast::Literal* lit) { bool GeneratorImpl::EmitEntryPointData(ast::Function* func) { auto* func_sem = program_->Sem().Get(func); - std::vector> in_locations; - std::vector> out_variables; + std::vector> in_locations; + std::vector> out_variables; for (auto data : func_sem->ReferencedLocationVariables()) { auto* var = data.first; @@ -1003,7 +1003,8 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) { uint32_t loc = data.second; make_indent(); - if (!EmitType(var->type(), program_->Symbols().NameFor(var->symbol()))) { + if (!EmitType(program_->Sem().Get(var)->Type(), + program_->Symbols().NameFor(var->symbol()))) { return false; } @@ -1039,7 +1040,8 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) { auto* deco = data.second; make_indent(); - if (!EmitType(var->type(), program_->Symbols().NameFor(var->symbol()))) { + if (!EmitType(program_->Sem().Get(var)->Type(), + program_->Symbols().NameFor(var->symbol()))) { return false; } @@ -1252,7 +1254,7 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, first = false; out_ << "thread "; - if (!EmitType(var->Declaration()->type(), "")) { + if (!EmitType(var->Type(), "")) { return false; } out_ << "& " << program_->Symbols().NameFor(var->Declaration()->symbol()); @@ -1267,7 +1269,7 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, out_ << "constant "; // TODO(dsinclair): Can arrays be uniform? If so, fix this ... - if (!EmitType(var->Declaration()->type(), "")) { + if (!EmitType(var->Type(), "")) { return false; } out_ << "& " << program_->Symbols().NameFor(var->Declaration()->symbol()); @@ -1280,7 +1282,7 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, } first = false; - auto* ac = var->Declaration()->type()->As(); + auto* ac = var->Type()->As(); if (ac == nullptr) { diagnostics_.add_error( "invalid type for storage buffer, expected access control"); @@ -1303,11 +1305,13 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, } first = false; - if (!EmitType(v->type(), program_->Symbols().NameFor(v->symbol()))) { + auto* type = program_->Sem().Get(v)->Type(); + + if (!EmitType(type, program_->Symbols().NameFor(v->symbol()))) { return false; } // Array name is output as part of the type - if (!v->type()->Is()) { + if (!type->Is()) { out_ << " " << program_->Symbols().NameFor(v->symbol()); } } @@ -1394,13 +1398,15 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { } first = false; - if (!EmitType(var->type(), "")) { + auto* type = program_->Sem().Get(var)->Type(); + + if (!EmitType(type, "")) { return false; } out_ << " " << program_->Symbols().NameFor(var->symbol()); - if (var->type()->Is()) { + if (type->Is()) { out_ << " [[stage_in]]"; } else { auto& decos = var->decorations(); @@ -1440,7 +1446,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { auto* builtin = data.second; - if (!EmitType(var->Declaration()->type(), "")) { + if (!EmitType(var->Type(), "")) { return false; } @@ -1475,7 +1481,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { out_ << "constant "; // TODO(dsinclair): Can you have a uniform array? If so, this needs to be // updated to handle arrays property. - if (!EmitType(var->Declaration()->type(), "")) { + if (!EmitType(var->Type(), "")) { return false; } out_ << "& " << program_->Symbols().NameFor(var->Declaration()->symbol()) @@ -1495,7 +1501,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { auto* binding = data.second.binding; // auto* set = data.second.set; - auto* ac = var->Declaration()->type()->As(); + auto* ac = var->Type()->As(); if (ac == nullptr) { diagnostics_.add_error( "invalid type for storage buffer, expected access control"); @@ -1640,7 +1646,7 @@ bool GeneratorImpl::EmitLoop(ast::LoopStatement* stmt) { return false; } } else { - if (!EmitZeroValue(var->type())) { + if (!EmitZeroValue(program_->Sem().Get(var)->Type())) { return false; } } @@ -2156,10 +2162,10 @@ bool GeneratorImpl::EmitVariable(const semantic::Variable* var, if (decl->is_const()) { out_ << "const "; } - if (!EmitType(decl->type(), program_->Symbols().NameFor(decl->symbol()))) { + if (!EmitType(var->Type(), program_->Symbols().NameFor(decl->symbol()))) { return false; } - if (!decl->type()->Is()) { + if (!var->Type()->Is()) { out_ << " " << program_->Symbols().NameFor(decl->symbol()); } @@ -2173,7 +2179,7 @@ bool GeneratorImpl::EmitVariable(const semantic::Variable* var, var->StorageClass() == ast::StorageClass::kFunction || var->StorageClass() == ast::StorageClass::kNone || var->StorageClass() == ast::StorageClass::kOutput) { - if (!EmitZeroValue(decl->type())) { + if (!EmitZeroValue(var->Type())) { return false; } } @@ -2198,10 +2204,11 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) { } out_ << "constant "; - if (!EmitType(var->type(), program_->Symbols().NameFor(var->symbol()))) { + auto* type = program_->Sem().Get(var)->Type(); + if (!EmitType(type, program_->Symbols().NameFor(var->symbol()))) { return false; } - if (!var->type()->Is()) { + if (!type->Is()) { out_ << " " << program_->Symbols().NameFor(var->symbol()); } diff --git a/src/writer/msl/generator_impl_module_constant_test.cc b/src/writer/msl/generator_impl_module_constant_test.cc index b17ce9d303..c7932146fc 100644 --- a/src/writer/msl/generator_impl_module_constant_test.cc +++ b/src/writer/msl/generator_impl_module_constant_test.cc @@ -24,6 +24,7 @@ using MslGeneratorImplTest = TestHelper; TEST_F(MslGeneratorImplTest, Emit_ModuleConstant) { auto* var = Const("pos", ty.array(), array(1.f, 2.f, 3.f)); + WrapInFunction(Decl(var)); GeneratorImpl& gen = Build(); @@ -36,6 +37,7 @@ TEST_F(MslGeneratorImplTest, Emit_SpecConstant) { ast::DecorationList{ create(23), }); + WrapInFunction(Decl(var)); GeneratorImpl& gen = Build(); diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 440699a1cb..96b5f65947 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -538,7 +538,8 @@ bool Builder::GenerateFunction(ast::Function* func) { auto param_op = result_op(); auto param_id = param_op.to_i(); - auto param_type_id = GenerateTypeIfNeeded(param->type()); + auto param_type_id = + GenerateTypeIfNeeded(builder_.Sem().Get(param)->Type()); if (param_type_id == 0) { return false; } @@ -592,7 +593,8 @@ uint32_t Builder::GenerateFunctionTypeIfNeeded(ast::Function* func) { OperandList ops = {func_op, Operand::Int(ret_id)}; for (auto* param : func->params()) { - auto param_type_id = GenerateTypeIfNeeded(param->type()); + auto param_type_id = + GenerateTypeIfNeeded(builder_.Sem().Get(param)->Type()); if (param_type_id == 0) { return 0; } @@ -629,7 +631,8 @@ bool Builder::GenerateFunctionVariable(ast::Variable* var) { auto result = result_op(); auto var_id = result.to_i(); auto sc = ast::StorageClass::kFunction; - type::Pointer pt(var->type(), sc); + auto* type = builder_.Sem().Get(var)->Type(); + type::Pointer pt(type, sc); auto type_id = GenerateTypeIfNeeded(&pt); if (type_id == 0) { return false; @@ -641,7 +644,7 @@ bool Builder::GenerateFunctionVariable(ast::Variable* var) { // TODO(dsinclair) We could detect if the constructor is fully const and emit // an initializer value for the variable instead of doing the OpLoad. - auto null_id = GenerateConstantNullIfNeeded(var->type()->UnwrapPtrIfNeeded()); + auto null_id = GenerateConstantNullIfNeeded(type->UnwrapPtrIfNeeded()); if (null_id == 0) { return 0; } @@ -704,7 +707,7 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { ? ast::StorageClass::kPrivate : sem->StorageClass(); - type::Pointer pt(var->type(), sc); + type::Pointer pt(sem->Type(), sc); auto type_id = GenerateTypeIfNeeded(&pt); if (type_id == 0) { return false; @@ -719,11 +722,11 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { // Unwrap after emitting the access control as unwrap all removes access // control types. - auto* type = var->type()->UnwrapAll(); + auto* type_no_ac = sem->Type()->UnwrapAll(); if (var->has_constructor()) { ops.push_back(Operand::Int(init_id)); - } else if (type->Is()) { - if (auto* ac = var->type()->As()) { + } else if (type_no_ac->Is()) { + if (auto* ac = sem->Type()->As()) { switch (ac->access_control()) { case ast::AccessControl::kWriteOnly: push_annot( @@ -739,7 +742,7 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { break; } } - } else if (!type->Is()) { + } else if (!type_no_ac->Is()) { // Certain cases require us to generate a constructor value. // // 1- ConstantId's must be attached to the OpConstant, if we have a @@ -748,17 +751,17 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { // 2- If we don't have a constructor and we're an Output or Private variable // then WGSL requires an initializer. if (var->HasConstantIdDecoration()) { - if (type->Is()) { - ast::FloatLiteral l(Source{}, type, 0.0f); + if (type_no_ac->Is()) { + ast::FloatLiteral l(Source{}, type_no_ac, 0.0f); init_id = GenerateLiteralIfNeeded(var, &l); - } else if (type->Is()) { - ast::UintLiteral l(Source{}, type, 0); + } else if (type_no_ac->Is()) { + ast::UintLiteral l(Source{}, type_no_ac, 0); init_id = GenerateLiteralIfNeeded(var, &l); - } else if (type->Is()) { - ast::SintLiteral l(Source{}, type, 0); + } else if (type_no_ac->Is()) { + ast::SintLiteral l(Source{}, type_no_ac, 0); init_id = GenerateLiteralIfNeeded(var, &l); - } else if (type->Is()) { - ast::BoolLiteral l(Source{}, type, false); + } else if (type_no_ac->Is()) { + ast::BoolLiteral l(Source{}, type_no_ac, false); init_id = GenerateLiteralIfNeeded(var, &l); } else { error_ = "invalid type for constant_id, must be scalar"; @@ -771,7 +774,7 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { } else if (sem->StorageClass() == ast::StorageClass::kPrivate || sem->StorageClass() == ast::StorageClass::kNone || sem->StorageClass() == ast::StorageClass::kOutput) { - init_id = GenerateConstantNullIfNeeded(type); + init_id = GenerateConstantNullIfNeeded(type_no_ac); if (init_id == 0) { return 0; } diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index 6539d93dde..f894c4e63b 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -321,7 +321,7 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) { out_ << program_->Symbols().NameFor(v->symbol()) << " : "; - if (!EmitType(v->type())) { + if (!EmitType(program_->Sem().Get(v)->Type())) { return false; } } @@ -578,13 +578,13 @@ bool GeneratorImpl::EmitVariable(ast::Variable* var) { out_ << "var"; if (sem->StorageClass() != ast::StorageClass::kNone && sem->StorageClass() != ast::StorageClass::kFunction && - !var->type()->UnwrapAll()->is_handle()) { + !sem->Type()->UnwrapAll()->is_handle()) { out_ << "<" << sem->StorageClass() << ">"; } } out_ << " " << program_->Symbols().NameFor(var->symbol()) << " : "; - if (!EmitType(var->type())) { + if (!EmitType(sem->Type())) { return false; } diff --git a/src/writer/wgsl/generator_impl_variable_test.cc b/src/writer/wgsl/generator_impl_variable_test.cc index b80d4bfd39..87660fd727 100644 --- a/src/writer/wgsl/generator_impl_variable_test.cc +++ b/src/writer/wgsl/generator_impl_variable_test.cc @@ -75,23 +75,24 @@ TEST_F(WgslGeneratorImplTest, EmitVariable_Decorated_Multiple) { } TEST_F(WgslGeneratorImplTest, EmitVariable_Constructor) { - auto* v = - Global("a", ty.f32(), ast::StorageClass::kNone, Expr("initializer")); + auto* v = Global("a", ty.f32(), ast::StorageClass::kNone, Expr(1.0f)); + WrapInFunction(Decl(v)); GeneratorImpl& gen = Build(); ASSERT_TRUE(gen.EmitVariable(v)) << gen.error(); - EXPECT_EQ(gen.result(), R"(var a : f32 = initializer; + EXPECT_EQ(gen.result(), R"(var a : f32 = 1.0; )"); } TEST_F(WgslGeneratorImplTest, EmitVariable_Const) { - auto* v = Const("a", ty.f32(), Expr("initializer")); + auto* v = Const("a", ty.f32(), Expr(1.0f)); + WrapInFunction(Decl(v)); GeneratorImpl& gen = Build(); ASSERT_TRUE(gen.EmitVariable(v)) << gen.error(); - EXPECT_EQ(gen.result(), R"(const a : f32 = initializer; + EXPECT_EQ(gen.result(), R"(const a : f32 = 1.0; )"); }