diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 4c6eb94e8a..8095beaa8b 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -413,20 +413,57 @@ sem::Type* Resolver::Type(const ast::Type* ty) { return s; } -Resolver::VariableInfo* Resolver::Variable( - ast::Variable* var, - const sem::Type* type, /* = nullptr */ - std::string type_name /* = "" */) { - auto it = variable_to_info_.find(var); - if (it != variable_to_info_.end()) { - return it->second; +Resolver::VariableInfo* Resolver::Variable(ast::Variable* var, + bool is_parameter) { + if (variable_to_info_.count(var)) { + TINT_ICE(diagnostics_) << "Variable " + << builder_->Symbols().NameFor(var->symbol()) + << " already resolved"; + return nullptr; } - if (type == nullptr && var->type()) { - type = Type(var->type()); - type_name = var->type()->FriendlyName(builder_->Symbols()); + // If the variable has a declared type, resolve it. + std::string type_name; + const sem::Type* type = nullptr; + if (auto* ty = var->type()) { + type_name = ty->FriendlyName(builder_->Symbols()); + type = Type(ty); + if (!type) { + return nullptr; + } } - if (type == nullptr) { + + // Does the variable have a constructor? + if (auto* ctor = var->constructor()) { + Mark(var->constructor()); + if (!Expression(var->constructor())) { + return nullptr; + } + + // Fetch the constructor's type + auto* rhs_type = TypeOf(ctor); + if (!rhs_type) { + return nullptr; + } + + // If the variable has no declared type, infer it from the RHS + if (type == nullptr) { + type_name = TypeNameOf(ctor); + type = rhs_type->UnwrapPtr(); + } + + if (!IsValidAssignment(type, rhs_type)) { + diagnostics_.add_error( + "variable of type '" + type_name + + "' cannot be initialized with a value of type '" + + TypeNameOf(ctor) + "'", + var->source()); + return nullptr; + } + } else if (var->is_const() && !is_parameter && + !ast::HasDecoration(var->decorations())) { + diagnostics_.add_error("let declarations must have initializers", + var->source()); return nullptr; } @@ -446,7 +483,7 @@ bool Resolver::GlobalVariable(ast::Variable* var) { return false; } - auto* info = Variable(var); + auto* info = Variable(var, /* is_parameter */ false); if (!info) { return false; } @@ -472,20 +509,6 @@ bool Resolver::GlobalVariable(ast::Variable* var) { info->binding_point = {bp.group->value(), bp.binding->value()}; } - if (var->has_constructor()) { - Mark(var->constructor()); - if (!Expression(var->constructor())) { - return false; - } - } else { - if (var->is_const() && - !ast::HasDecoration(var->decorations())) { - diagnostics_.add_error("let declarations must have initializers", - var->source()); - return false; - } - } - if (!ValidateGlobalVariable(info)) { return false; } @@ -1020,7 +1043,7 @@ bool Resolver::Function(ast::Function* func) { variable_stack_.push_scope(); for (auto* param : func->params()) { Mark(param); - auto* param_info = Variable(param); + auto* param_info = Variable(param, /* is_parameter */ true); if (!param_info) { return false; } @@ -2072,17 +2095,6 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) { ast::Variable* var = stmt->variable(); Mark(var); - // If the variable has a declared type, resolve it. - std::string type_name; - const sem::Type* type = nullptr; - if (auto* ast_ty = var->type()) { - type_name = ast_ty->FriendlyName(builder_->Symbols()); - type = Type(ast_ty); - if (!type) { - return false; - } - } - bool is_global = false; if (variable_stack_.get(var->symbol(), nullptr, &is_global)) { const char* error_code = is_global ? "v-0013" : "v-0014"; @@ -2093,33 +2105,9 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) { return false; } - if (auto* ctor = stmt->variable()->constructor()) { - Mark(ctor); - if (!Expression(ctor)) { - return false; - } - auto* rhs_type = TypeOf(ctor); - - // If the variable has no type, infer it from the rhs - if (type == nullptr) { - type_name = TypeNameOf(ctor); - type = rhs_type->UnwrapPtr(); - } - - if (!IsValidAssignment(type, rhs_type)) { - diagnostics_.add_error( - "variable of type '" + type_name + - "' cannot be initialized with a value of type '" + - TypeNameOf(ctor) + "'", - stmt->source()); - return false; - } - } else { - if (stmt->variable()->is_const()) { - diagnostics_.add_error("let declarations must have initializers", - var->source()); - return false; - } + auto* info = Variable(var, /* is_parameter */ false); + if (!info) { + return false; } for (auto* deco : var->decorations()) { @@ -2127,13 +2115,6 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) { Mark(deco); } - auto* info = Variable(var, type, type_name); - if (!info) { - return false; - } - // TODO(bclayton): Remove this and fix tests. We're overriding the semantic - // type stored in info->type here with a possibly non-canonicalized type. - info->type = const_cast(type); variable_stack_.set(var->symbol(), info); current_block_->decls.push_back(var); @@ -2212,8 +2193,7 @@ void Resolver::SetType(ast::Expression* expr, TINT_ASSERT(type.sem); } if (expr_info_.count(expr)) { - TINT_ICE(builder_->Diagnostics()) - << "SetType() called twice for the same expression"; + TINT_ICE(diagnostics_) << "SetType() called twice for the same expression"; } expr_info_.emplace(expr, ExpressionInfo{type, type_name, current_statement_}); } @@ -2260,9 +2240,8 @@ void Resolver::CreateSemanticNodes() const { } else { auto* sem_user = sem_expr->As(); if (!sem_user) { - TINT_ICE(builder_->Diagnostics()) - << "expected sem::VariableUser, got " - << sem_expr->TypeInfo().name; + TINT_ICE(diagnostics_) << "expected sem::VariableUser, got " + << sem_expr->TypeInfo().name; } sem_var->AddUser(sem_user); } diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 02dce4f1cc..95a73652cd 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -262,13 +262,11 @@ class Resolver { /// @returns the VariableInfo for the variable `var`, building it if it hasn't /// been constructed already. If an error is raised, nullptr is returned. + /// @note this method does not resolve the decorations as these are + /// context-dependent (global, local, parameter) /// @param var the variable to create or return the `VariableInfo` for - /// @param type optional type of `var` to use instead of `var->type()`. - /// @param type_name optional type name of `var` to use instead of - /// `var->type()->FriendlyName()`. - VariableInfo* Variable(ast::Variable* var, - const sem::Type* type = nullptr, - std::string type_name = ""); + /// @param is_parameter true if the variable represents a parameter + VariableInfo* Variable(ast::Variable* var, bool is_parameter); /// Records the storage class usage for the given type, and any transient /// dependencies of the type. Validates that the type can be used for the diff --git a/src/resolver/type_validation_test.cc b/src/resolver/type_validation_test.cc index 82a231ea72..a97d6779a5 100644 --- a/src/resolver/type_validation_test.cc +++ b/src/resolver/type_validation_test.cc @@ -90,8 +90,8 @@ TEST_F(ResolverTypeValidationTest, GlobalConstantWithStorageClass_Fail) { // const global_var: f32; AST().AddGlobalVariable( create(Source{{12, 34}}, Symbols().Register("global_var"), - ast::StorageClass::kInput, ty.f32(), true, nullptr, - ast::DecorationList{})); + ast::StorageClass::kInput, ty.f32(), true, + Expr(1.23f), ast::DecorationList{})); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), @@ -113,7 +113,7 @@ TEST_F(ResolverTypeValidationTest, GlobalVariableUnique_Pass) { Global("global_var0", ty.f32(), ast::StorageClass::kPrivate, Expr(0.1f)); Global(Source{{12, 34}}, "global_var1", ty.f32(), ast::StorageClass::kPrivate, - Expr(0)); + Expr(1.0f)); EXPECT_TRUE(r()->Resolve()) << r()->error(); } diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc index 53c087c2ce..87a9608080 100644 --- a/src/resolver/validation_test.cc +++ b/src/resolver/validation_test.cc @@ -164,10 +164,8 @@ TEST_F(ResolverValidationTest, Stmt_Else_NonBool) { TEST_F(ResolverValidationTest, Stmt_VariableDecl_MismatchedTypeScalarConstructor) { u32 unsigned_value = 2u; // Type does not match variable type - auto* var = - Var("my_var", ty.i32(), ast::StorageClass::kNone, Expr(unsigned_value)); - - auto* decl = Decl(Source{{{3, 3}, {3, 22}}}, var); + auto* decl = Decl(Var(Source{{3, 3}}, "my_var", ty.i32(), + ast::StorageClass::kNone, Expr(unsigned_value))); WrapInFunction(decl); EXPECT_FALSE(r()->Resolve()); @@ -181,10 +179,8 @@ TEST_F(ResolverValidationTest, auto* my_int = ty.alias("MyInt", ty.i32()); AST().AddConstructedType(my_int); u32 unsigned_value = 2u; // Type does not match variable type - auto* var = - Var("my_var", my_int, ast::StorageClass::kNone, Expr(unsigned_value)); - - auto* decl = Decl(Source{{{3, 3}, {3, 22}}}, var); + auto* decl = Decl(Var(Source{{3, 3}}, "my_var", my_int, + ast::StorageClass::kNone, Expr(unsigned_value))); WrapInFunction(decl); EXPECT_FALSE(r()->Resolve()); diff --git a/src/writer/spirv/builder_function_variable_test.cc b/src/writer/spirv/builder_function_variable_test.cc index 1f4217a2a9..7fae8b3bc9 100644 --- a/src/writer/spirv/builder_function_variable_test.cc +++ b/src/writer/spirv/builder_function_variable_test.cc @@ -45,7 +45,7 @@ TEST_F(BuilderTest, FunctionVar_NoStorageClass) { TEST_F(BuilderTest, FunctionVar_WithConstantConstructor) { auto* init = vec3(1.f, 1.f, 3.f); - auto* v = Global("var", ty.f32(), ast::StorageClass::kOutput, init); + auto* v = Global("var", ty.vec3(), ast::StorageClass::kOutput, init); spirv::Builder& b = Build(); @@ -60,8 +60,8 @@ TEST_F(BuilderTest, FunctionVar_WithConstantConstructor) { %3 = OpConstant %2 1 %4 = OpConstant %2 3 %5 = OpConstantComposite %1 %3 %3 %4 -%7 = OpTypePointer Function %2 -%8 = OpConstantNull %2 +%7 = OpTypePointer Function %1 +%8 = OpConstantNull %1 )"); EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), R"(%6 = OpVariable %7 Function %8 diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc index 23e8b785e5..79b8d5cfce 100644 --- a/src/writer/spirv/builder_global_variable_test.cc +++ b/src/writer/spirv/builder_global_variable_test.cc @@ -57,7 +57,7 @@ TEST_F(BuilderTest, GlobalVar_WithStorageClass_Input) { TEST_F(BuilderTest, GlobalVar_WithConstructor) { auto* init = vec3(1.f, 1.f, 3.f); - auto* v = Global("var", ty.f32(), ast::StorageClass::kOutput, init); + auto* v = Global("var", ty.vec3(), ast::StorageClass::kOutput, init); spirv::Builder& b = Build(); @@ -71,7 +71,7 @@ TEST_F(BuilderTest, GlobalVar_WithConstructor) { %3 = OpConstant %2 1 %4 = OpConstant %2 3 %5 = OpConstantComposite %1 %3 %3 %4 -%7 = OpTypePointer Output %2 +%7 = OpTypePointer Output %1 %6 = OpVariable %7 Output %5 )"); } @@ -79,7 +79,7 @@ TEST_F(BuilderTest, GlobalVar_WithConstructor) { TEST_F(BuilderTest, GlobalVar_Const) { auto* init = vec3(1.f, 1.f, 3.f); - auto* v = GlobalConst("var", ty.f32(), init); + auto* v = GlobalConst("var", ty.vec3(), init); spirv::Builder& b = Build(); @@ -99,7 +99,7 @@ TEST_F(BuilderTest, GlobalVar_Const) { TEST_F(BuilderTest, GlobalVar_Complex_Constructor) { auto* init = vec3(ast::ExpressionList{Expr(1.f), Expr(2.f), Expr(3.f)}); - auto* v = GlobalConst("var", ty.f32(), init); + auto* v = GlobalConst("var", ty.vec3(), init); spirv::Builder& b = Build(); @@ -118,7 +118,7 @@ TEST_F(BuilderTest, GlobalVar_Complex_Constructor) { TEST_F(BuilderTest, GlobalVar_Complex_ConstructorWithExtract) { auto* init = vec3(vec2(1.f, 2.f), 3.f); - auto* v = GlobalConst("var", ty.f32(), init); + auto* v = GlobalConst("var", ty.vec3(), init); spirv::Builder& b = Build(); diff --git a/src/writer/spirv/builder_ident_expression_test.cc b/src/writer/spirv/builder_ident_expression_test.cc index 55f2a128da..821f9e5824 100644 --- a/src/writer/spirv/builder_ident_expression_test.cc +++ b/src/writer/spirv/builder_ident_expression_test.cc @@ -25,7 +25,7 @@ using BuilderTest = TestHelper; TEST_F(BuilderTest, IdentifierExpression_GlobalConst) { auto* init = vec3(1.f, 1.f, 3.f); - auto* v = GlobalConst("var", ty.f32(), init); + auto* v = GlobalConst("var", ty.vec3(), init); auto* expr = Expr("var"); WrapInFunction(expr);