diff --git a/src/validator/validator_function_test.cc b/src/validator/validator_function_test.cc index aa857a58a4..13f6e5985a 100644 --- a/src/validator/validator_function_test.cc +++ b/src/validator/validator_function_test.cc @@ -116,9 +116,8 @@ TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_Pass) { }); ValidatorImpl& v = Build(); - const Program* program = v.program(); - EXPECT_TRUE(v.ValidateFunctions(program->AST().Functions())) << v.error(); + EXPECT_TRUE(v.Validate()) << v.error(); } TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_fail) { diff --git a/src/validator/validator_impl.cc b/src/validator/validator_impl.cc index 019423eb73..f379508d05 100644 --- a/src/validator/validator_impl.cc +++ b/src/validator/validator_impl.cc @@ -286,101 +286,85 @@ void ValidatorImpl::add_error(const Source& src, const std::string& msg) { } bool ValidatorImpl::Validate() { - function_stack_.push_scope(); - if (!ValidateGlobalVariables(program_->AST().GlobalVariables())) { - return false; - } - if (!ValidateConstructedTypes(program_->AST().ConstructedTypes())) { - return false; - } - if (!ValidateFunctions(program_->AST().Functions())) { - return false; + // Validate global declarations in the order they appear in the module. + for (auto* decl : program_->AST().GlobalDeclarations()) { + if (auto* ty = decl->As()) { + if (!ValidateConstructedType(ty)) { + return false; + } + } else if (auto* func = decl->As()) { + current_function_ = func; + if (!ValidateFunction(func)) { + return false; + } + current_function_ = nullptr; + } else if (auto* var = decl->As()) { + if (!ValidateGlobalVariable(var)) { + return false; + } + } else { + assert(false /* unreachable */); + } } if (!ValidateEntryPoint(program_->AST().Functions())) { return false; } - function_stack_.pop_scope(); return true; } -bool ValidatorImpl::ValidateConstructedTypes( - const std::vector& constructed_types) { - for (auto* const ct : constructed_types) { - if (auto* st = ct->As()) { - for (auto* member : st->impl()->members()) { - if (auto* r = member->type()->UnwrapAll()->As()) { - if (r->IsRuntimeArray()) { - if (member != st->impl()->members().back()) { - add_error(member->source(), "v-0015", - "runtime arrays may only appear as the last member of " - "a struct"); - return false; - } - if (!st->IsBlockDecorated()) { - add_error(member->source(), "v-0015", - "a struct containing a runtime-sized array " - "requires the [[block]] attribute: '" + - program_->Symbols().NameFor(st->symbol()) + "'"); - return false; - } +bool ValidatorImpl::ValidateConstructedType(const type::Type* type) { + if (auto* st = type->As()) { + for (auto* member : st->impl()->members()) { + if (auto* r = member->type()->UnwrapAll()->As()) { + if (r->IsRuntimeArray()) { + if (member != st->impl()->members().back()) { + add_error(member->source(), "v-0015", + "runtime arrays may only appear as the last member of " + "a struct"); + return false; + } + if (!st->IsBlockDecorated()) { + add_error(member->source(), "v-0015", + "a struct containing a runtime-sized array " + "requires the [[block]] attribute: '" + + program_->Symbols().NameFor(st->symbol()) + "'"); + return false; } } } } } + return true; } -bool ValidatorImpl::ValidateGlobalVariables( - const ast::VariableList& global_vars) { - for (auto* var : global_vars) { - auto* sem = program_->Sem().Get(var); - if (!sem) { - add_error(var->source(), "no semantic information for variable '" + - program_->Symbols().NameFor(var->symbol()) + - "'"); - return false; - } - - if (variable_stack_.has(var->symbol())) { - add_error(var->source(), "v-0011", - "redeclared global identifier '" + - program_->Symbols().NameFor(var->symbol()) + "'"); - return false; - } - if (!var->is_const() && sem->StorageClass() == ast::StorageClass::kNone) { - add_error(var->source(), "v-0022", - "global variables must have a storage class"); - return false; - } - if (var->is_const() && !(sem->StorageClass() == ast::StorageClass::kNone)) { - add_error(var->source(), "v-global01", - "global constants shouldn't have a storage class"); - return false; - } - variable_stack_.set_global(var->symbol(), var); - } - return true; -} - -bool ValidatorImpl::ValidateFunctions(const ast::FunctionList& funcs) { - for (auto* func : funcs) { - if (function_stack_.has(func->symbol())) { - add_error(func->source(), "v-0016", - "function names must be unique '" + - program_->Symbols().NameFor(func->symbol()) + "'"); - return false; - } - - function_stack_.set(func->symbol(), func); - current_function_ = func; - if (!ValidateFunction(func)) { - return false; - } - current_function_ = nullptr; +bool ValidatorImpl::ValidateGlobalVariable(const ast::Variable* var) { + auto* sem = program_->Sem().Get(var); + if (!sem) { + add_error(var->source(), "no semantic information for variable '" + + program_->Symbols().NameFor(var->symbol()) + + "'"); + return false; } + if (variable_stack_.has(var->symbol())) { + add_error(var->source(), "v-0011", + "redeclared global identifier '" + + program_->Symbols().NameFor(var->symbol()) + "'"); + return false; + } + if (!var->is_const() && sem->StorageClass() == ast::StorageClass::kNone) { + add_error(var->source(), "v-0022", + "global variables must have a storage class"); + return false; + } + if (var->is_const() && !(sem->StorageClass() == ast::StorageClass::kNone)) { + add_error(var->source(), "v-global01", + "global constants shouldn't have a storage class"); + return false; + } + variable_stack_.set_global(var->symbol(), var); return true; } @@ -425,6 +409,15 @@ bool ValidatorImpl::ValidateEntryPoint(const ast::FunctionList& funcs) { } bool ValidatorImpl::ValidateFunction(const ast::Function* func) { + if (function_stack_.has(func->symbol())) { + add_error(func->source(), "v-0016", + "function names must be unique '" + + program_->Symbols().NameFor(func->symbol()) + "'"); + return false; + } + + function_stack_.set(func->symbol(), func); + variable_stack_.push_scope(); for (auto* param : func->params()) { @@ -906,7 +899,7 @@ bool ValidatorImpl::ValidateBadAssignmentToIdentifier( // It wasn't an identifier in the first place. return true; } - ast::Variable* var; + const ast::Variable* var; if (variable_stack_.get(ident->symbol(), &var)) { // Give a nicer message if the LHS of the assignment is a const identifier. // It's likely to be a common programmer error. @@ -989,7 +982,7 @@ bool ValidatorImpl::ValidateExpression(const ast::Expression* expr) { } bool ValidatorImpl::ValidateIdentifier(const ast::IdentifierExpression* ident) { - ast::Variable* var; + const ast::Variable* var; if (!variable_stack_.get(ident->symbol(), &var)) { add_error(ident->source(), "v-0006", "'" + program_->Symbols().NameFor(ident->symbol()) + diff --git a/src/validator/validator_impl.h b/src/validator/validator_impl.h index 0119aa35f7..0ad3330df1 100644 --- a/src/validator/validator_impl.h +++ b/src/validator/validator_impl.h @@ -76,14 +76,10 @@ class ValidatorImpl { /// @param msg the error message void add_error(const Source& src, const std::string& msg); - /// Validate global variables - /// @param global_vars list of global variables to check + /// Validates a global variable + /// @param var the global variable to check /// @returns true if the validation was successful - bool ValidateGlobalVariables(const ast::VariableList& global_vars); - /// Validates Functions - /// @param funcs the functions to check - /// @returns true if the validation was successful - bool ValidateFunctions(const ast::FunctionList& funcs); + bool ValidateGlobalVariable(const ast::Variable* var); /// Validates a function /// @param func the function to check /// @returns true if the validation was successful @@ -144,11 +140,10 @@ class ValidatorImpl { /// @returns true if the valdiation was successful bool ValidateEntryPoint(const ast::FunctionList& funcs); - /// Validates constructed types - /// @param constructed_types the types to check + /// Validates a constructed type + /// @param type the type to check /// @returns true if the valdiation was successful - bool ValidateConstructedTypes( - const std::vector& constructed_types); + bool ValidateConstructedType(const type::Type* type); /// Returns true if the given type is storable. This uses and /// updates `storable_` and `not_storable_`. @@ -165,8 +160,8 @@ class ValidatorImpl { private: const Program* program_; diag::List diags_; - ScopeStack variable_stack_; - ScopeStack function_stack_; + ScopeStack variable_stack_; + ScopeStack function_stack_; ast::Function* current_function_ = nullptr; }; diff --git a/src/validator/validator_test.cc b/src/validator/validator_test.cc index 4b0f379a81..1a88227836 100644 --- a/src/validator/validator_test.cc +++ b/src/validator/validator_test.cc @@ -375,15 +375,13 @@ TEST_F(ValidatorTest, AssignIncompatibleTypesInNestedBlockStatement_Fail) { TEST_F(ValidatorTest, GlobalVariableWithStorageClass_Pass) { // var gloabl_var: f32; - Global(Source{Source::Location{12, 34}}, "global_var", - ast::StorageClass::kInput, ty.f32(), nullptr, - ast::VariableDecorationList{}); + auto* var = Global(Source{Source::Location{12, 34}}, "global_var", + ast::StorageClass::kInput, ty.f32(), nullptr, + ast::VariableDecorationList{}); ValidatorImpl& v = Build(); - const Program* program = v.program(); - EXPECT_TRUE(v.ValidateGlobalVariables(program->AST().GlobalVariables())) - << v.error(); + EXPECT_TRUE(v.ValidateGlobalVariable(var)) << v.error(); } TEST_F(ValidatorTest, GlobalVariableNoStorageClass_Fail) { @@ -449,6 +447,32 @@ TEST_F(ValidatorTest, UsingUndefinedVariableGlobalVariable_Fail) { EXPECT_EQ(v.error(), "12:34 v-0006: 'not_global_var' is not declared"); } +TEST_F(ValidatorTest, UsingUndefinedVariableGlobalVariableAfter_Fail) { + // fn my_func() -> void { + // global_var = 3.14f; + // } + // var global_var: f32 = 2.1; + + SetSource(Source{Source::Location{12, 34}}); + auto* lhs = Expr("global_var"); + auto* rhs = Expr(3.14f); + + Func("my_func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + create(lhs, rhs), + }, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kVertex)}); + + Global("global_var", ast::StorageClass::kPrivate, ty.f32(), Expr(2.1f), + ast::VariableDecorationList{}); + + ValidatorImpl& v = Build(); + + EXPECT_FALSE(v.Validate()); + EXPECT_EQ(v.error(), "12:34 v-0006: 'global_var' is not declared"); +} + TEST_F(ValidatorTest, UsingUndefinedVariableGlobalVariable_Pass) { // var global_var: f32 = 2.1; // fn my_func() -> void { @@ -579,18 +603,17 @@ TEST_F(ValidatorTest, UsingUndefinedVariableDifferentScope_Fail) { TEST_F(ValidatorTest, GlobalVariableUnique_Pass) { // var global_var0 : f32 = 0.1; // var global_var1 : i32 = 0; - Global("global_var0", ast::StorageClass::kPrivate, ty.f32(), Expr(0.1f), - ast::VariableDecorationList{}); + auto* var0 = Global("global_var0", ast::StorageClass::kPrivate, ty.f32(), + Expr(0.1f), ast::VariableDecorationList{}); - Global(Source{Source::Location{12, 34}}, "global_var1", - ast::StorageClass::kPrivate, ty.f32(), Expr(0), - ast::VariableDecorationList{}); + auto* var1 = Global(Source{Source::Location{12, 34}}, "global_var1", + ast::StorageClass::kPrivate, ty.f32(), Expr(0), + ast::VariableDecorationList{}); ValidatorImpl& v = Build(); - const Program* program = v.program(); - EXPECT_TRUE(v.ValidateGlobalVariables(program->AST().GlobalVariables())) - << v.error(); + EXPECT_TRUE(v.ValidateGlobalVariable(var0)) << v.error(); + EXPECT_TRUE(v.ValidateGlobalVariable(var1)) << v.error(); } TEST_F(ValidatorTest, GlobalVariableNotUnique_Fail) { @@ -604,9 +627,8 @@ TEST_F(ValidatorTest, GlobalVariableNotUnique_Fail) { ast::VariableDecorationList{}); ValidatorImpl& v = Build(); - const Program* program = v.program(); - EXPECT_FALSE(v.ValidateGlobalVariables(program->AST().GlobalVariables())); + EXPECT_FALSE(v.Validate()); EXPECT_EQ(v.error(), "12:34 v-0011: redeclared global identifier 'global_var'"); } @@ -639,6 +661,30 @@ TEST_F(ValidatorTest, AssignToConstant_Fail) { EXPECT_EQ(v.error(), "12:34 v-0021: cannot re-assign a constant: 'a'"); } +TEST_F(ValidatorTest, GlobalVariableFunctionVariableNotUnique_Pass) { + // fn my_func -> void { + // var a: f32 = 2.0; + // } + // var a: f32 = 2.1; + + auto* var = Var("a", ast::StorageClass::kNone, ty.f32(), Expr(2.0f), + ast::VariableDecorationList{}); + + Func("my_func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + create(var), + }, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kVertex)}); + + Global("a", ast::StorageClass::kPrivate, ty.f32(), Expr(2.1f), + ast::VariableDecorationList{}); + + ValidatorImpl& v = Build(); + + EXPECT_TRUE(v.Validate()) << v.error(); +} + TEST_F(ValidatorTest, GlobalVariableFunctionVariableNotUnique_Fail) { // var a: f32 = 2.1; // fn my_func -> void { @@ -665,7 +711,7 @@ TEST_F(ValidatorTest, GlobalVariableFunctionVariableNotUnique_Fail) { EXPECT_EQ(v.error(), "12:34 v-0013: redeclared identifier 'a'"); } -TEST_F(ValidatorTest, RedeclaredIndentifier_Fail) { +TEST_F(ValidatorTest, RedeclaredIdentifier_Fail) { // fn my_func() -> void { // var a :i32 = 2; // var a :f21 = 2.0; diff --git a/src/validator/validator_type_test.cc b/src/validator/validator_type_test.cc index 5f136eaadf..1df5059d8d 100644 --- a/src/validator/validator_type_test.cc +++ b/src/validator/validator_type_test.cc @@ -54,9 +54,8 @@ TEST_F(ValidatorTypeTest, RuntimeArrayIsLast_Pass) { AST().AddConstructedType(struct_type); ValidatorImpl& v = Build(); - const Program* program = v.program(); - EXPECT_TRUE(v.ValidateConstructedTypes(program->AST().ConstructedTypes())); + EXPECT_TRUE(v.ValidateConstructedType(struct_type)); } TEST_F(ValidatorTypeTest, RuntimeArrayIsLastNoBlock_Fail) { @@ -75,9 +74,8 @@ TEST_F(ValidatorTypeTest, RuntimeArrayIsLastNoBlock_Fail) { AST().AddConstructedType(struct_type); ValidatorImpl& v = Build(); - const Program* program = v.program(); - EXPECT_FALSE(v.ValidateConstructedTypes(program->AST().ConstructedTypes())); + EXPECT_FALSE(v.ValidateConstructedType(struct_type)); EXPECT_EQ(v.error(), "v-0015: a struct containing a runtime-sized array requires the " "[[block]] attribute: 'Foo'"); @@ -104,9 +102,8 @@ TEST_F(ValidatorTypeTest, RuntimeArrayIsNotLast_Fail) { AST().AddConstructedType(struct_type); ValidatorImpl& v = Build(); - const Program* program = v.program(); - EXPECT_FALSE(v.ValidateConstructedTypes(program->AST().ConstructedTypes())); + EXPECT_FALSE(v.ValidateConstructedType(struct_type)); EXPECT_EQ(v.error(), "12:34 v-0015: runtime arrays may only appear as the last member " "of a struct"); @@ -131,9 +128,8 @@ TEST_F(ValidatorTypeTest, AliasRuntimeArrayIsNotLast_Fail) { AST().AddConstructedType(struct_type); ValidatorImpl& v = Build(); - const Program* program = v.program(); - EXPECT_FALSE(v.ValidateConstructedTypes(program->AST().ConstructedTypes())); + EXPECT_FALSE(v.ValidateConstructedType(struct_type)); EXPECT_EQ(v.error(), "v-0015: runtime arrays may only appear as the last member " "of a struct"); @@ -158,9 +154,8 @@ TEST_F(ValidatorTypeTest, AliasRuntimeArrayIsLast_Pass) { AST().AddConstructedType(struct_type); ValidatorImpl& v = Build(); - const Program* program = v.program(); - EXPECT_TRUE(v.ValidateConstructedTypes(program->AST().ConstructedTypes())); + EXPECT_TRUE(v.ValidateConstructedType(struct_type)); } TEST_F(ValidatorTypeTest, RuntimeArrayInFunction_Fail) {