diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 037793a5dc..a51bd22851 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -1126,6 +1126,7 @@ if (tint_build_unittests) { "resolver/resolver_test_helper.cc", "resolver/resolver_test_helper.h", "resolver/side_effects_test.cc", + "resolver/static_assert_test.cc", "resolver/source_variable_test.cc", "resolver/storage_class_layout_validation_test.cc", "resolver/storage_class_validation_test.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 5f93189f91..9a6a42d616 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -809,6 +809,7 @@ if(TINT_BUILD_TESTS) resolver/resolver_test_helper.h resolver/resolver_test.cc resolver/side_effects_test.cc + resolver/static_assert_test.cc resolver/source_variable_test.cc resolver/storage_class_layout_validation_test.cc resolver/storage_class_validation_test.cc diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc index 36bff3f79b..2ddcaaa0fd 100644 --- a/src/tint/resolver/dependency_graph.cc +++ b/src/tint/resolver/dependency_graph.cc @@ -204,6 +204,7 @@ class DependencyScanner { [&](const ast::Enable*) { // Enable directives do not effect the dependency graph. }, + [&](const ast::StaticAssert* assertion) { TraverseExpression(assertion->condition); }, [&](Default) { UnhandledNode(diagnostics_, global->node); }); } @@ -315,6 +316,7 @@ class DependencyScanner { TraverseExpression(w->condition); TraverseStatement(w->body); }, + [&](const ast::StaticAssert* assertion) { TraverseExpression(assertion->condition); }, [&](Default) { if (!stmt->IsAnyOf()) { @@ -515,6 +517,8 @@ struct DependencyAnalysis { [&](const ast::TypeDecl* td) { return td->name; }, [&](const ast::Function* func) { return func->symbol; }, [&](const ast::Variable* var) { return var->symbol; }, + [&](const ast::Enable*) { return Symbol(); }, + [&](const ast::StaticAssert*) { return Symbol(); }, [&](Default) { UnhandledNode(diagnostics_, node); return Symbol{}; @@ -533,11 +537,12 @@ struct DependencyAnalysis { /// declaration std::string KindOf(const ast::Node* node) { return Switch( - node, // - [&](const ast::Struct*) { return "struct"; }, // - [&](const ast::Alias*) { return "alias"; }, // - [&](const ast::Function*) { return "function"; }, // - [&](const ast::Variable* v) { return v->Kind(); }, // + node, // + [&](const ast::Struct*) { return "struct"; }, // + [&](const ast::Alias*) { return "alias"; }, // + [&](const ast::Function*) { return "function"; }, // + [&](const ast::Variable* v) { return v->Kind(); }, // + [&](const ast::StaticAssert*) { return "static_assert"; }, // [&](Default) { UnhandledNode(diagnostics_, node); return ""; @@ -549,9 +554,8 @@ struct DependencyAnalysis { void GatherGlobals(const ast::Module& module) { for (auto* node : module.GlobalDeclarations()) { auto* global = allocator_.Create(node); - // Enable directives do not form a symbol. Skip them. - if (!node->Is()) { - globals_.emplace(SymbolOf(node), global); + if (auto symbol = SymbolOf(node); symbol.IsValid()) { + globals_.emplace(symbol, global); } declaration_order_.emplace_back(global); } diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 64c3e033d2..4295babbd7 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -140,6 +140,7 @@ bool Resolver::ResolveInternal() { [&](const ast::TypeDecl* td) { return TypeDecl(td); }, [&](const ast::Function* func) { return Function(func); }, [&](const ast::Variable* var) { return GlobalVariable(var); }, + [&](const ast::StaticAssert* sa) { return StaticAssert(sa); }, [&](Default) { TINT_UNREACHABLE(Resolver, diagnostics_) << "unhandled global declaration: " << decl->TypeInfo().name; @@ -737,6 +738,33 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) { return sem; } +sem::Statement* Resolver::StaticAssert(const ast::StaticAssert* assertion) { + auto* expr = Expression(assertion->condition); + if (!expr) { + return nullptr; + } + auto* cond = expr->ConstantValue(); + if (!cond) { + AddError("static assertion condition must be a constant expression", + assertion->condition->source); + return nullptr; + } + if (auto* ty = cond->Type(); !ty->Is()) { + AddError( + "static assertion condition must be a bool, got '" + builder_->FriendlyName(ty) + "'", + assertion->condition->source); + return nullptr; + } + if (!cond->As()) { + AddError("static assertion failed", assertion->source); + return nullptr; + } + auto* sem = + builder_->create(assertion, current_compound_statement_, current_function_); + builder_->Sem().Add(assertion, sem); + return sem; +} + sem::Function* Resolver::Function(const ast::Function* decl) { uint32_t parameter_index = 0; std::unordered_map parameter_names; @@ -1042,6 +1070,7 @@ sem::Statement* Resolver::Statement(const ast::Statement* stmt) { [&](const ast::IncrementDecrementStatement* i) { return IncrementDecrementStatement(i); }, [&](const ast::ReturnStatement* r) { return ReturnStatement(r); }, [&](const ast::VariableDeclStatement* v) { return VariableDeclStatement(v); }, + [&](const ast::StaticAssert* sa) { return StaticAssert(sa); }, // Error cases [&](const ast::CaseStatement*) { diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 10b5806e5d..2b6a2404e7 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -250,6 +250,7 @@ class Resolver { sem::LoopStatement* LoopStatement(const ast::LoopStatement*); sem::Statement* ReturnStatement(const ast::ReturnStatement*); sem::Statement* Statement(const ast::Statement*); + sem::Statement* StaticAssert(const ast::StaticAssert*); sem::SwitchStatement* SwitchStatement(const ast::SwitchStatement* s); sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*); bool Statements(utils::VectorRef); diff --git a/src/tint/resolver/static_assert_test.cc b/src/tint/resolver/static_assert_test.cc new file mode 100644 index 0000000000..3cb67c9ebd --- /dev/null +++ b/src/tint/resolver/static_assert_test.cc @@ -0,0 +1,110 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/resolver/resolver.h" + +#include "gmock/gmock.h" +#include "src/tint/resolver/resolver_test_helper.h" + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::resolver { +namespace { + +using ResolverStaticAssertTest = ResolverTest; + +TEST_F(ResolverStaticAssertTest, Global_True_Pass) { + GlobalStaticAssert(true); + ASSERT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverStaticAssertTest, Global_False_Fail) { + GlobalStaticAssert(Source{{12, 34}}, false); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: static assertion failed"); +} + +TEST_F(ResolverStaticAssertTest, Global_Const_Pass) { + GlobalConst("C", ty.bool_(), Expr(true)); + GlobalStaticAssert("C"); + ASSERT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverStaticAssertTest, Global_Const_Fail) { + GlobalConst("C", ty.bool_(), Expr(false)); + GlobalStaticAssert(Source{{12, 34}}, "C"); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: static assertion failed"); +} + +// TODO(crbug.com/tint/1581): Enable once the '<' operator is implemented for constant evaluation. +TEST_F(ResolverStaticAssertTest, DISABLED_Global_LessThan_Pass) { + GlobalStaticAssert(LessThan(2_i, 3_i)); + ASSERT_TRUE(r()->Resolve()) << r()->error(); +} + +// TODO(crbug.com/tint/1581): Enable once the '<' operator is implemented for constant evaluation. +TEST_F(ResolverStaticAssertTest, DISABLED_Global_LessThan_Fail) { + GlobalStaticAssert(Source{{12, 34}}, LessThan(4_i, 3_i)); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: static assertion failed"); +} + +TEST_F(ResolverStaticAssertTest, Local_True_Pass) { + WrapInFunction(StaticAssert(true)); + ASSERT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverStaticAssertTest, Local_False_Fail) { + WrapInFunction(StaticAssert(Source{{12, 34}}, false)); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: static assertion failed"); +} + +TEST_F(ResolverStaticAssertTest, Local_Const_Pass) { + GlobalConst("C", ty.bool_(), Expr(true)); + WrapInFunction(StaticAssert("C")); + ASSERT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverStaticAssertTest, Local_Const_Fail) { + GlobalConst("C", ty.bool_(), Expr(false)); + WrapInFunction(StaticAssert(Source{{12, 34}}, "C")); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: static assertion failed"); +} + +TEST_F(ResolverStaticAssertTest, Local_NonConst) { + GlobalVar("V", ty.bool_(), Expr(true), ast::StorageClass::kPrivate); + WrapInFunction(StaticAssert(Expr(Source{{12, 34}}, "V"))); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: static assertion condition must be a constant expression"); +} + +// TODO(crbug.com/tint/1581): Enable once the '<' operator is implemented for constant evaluation. +TEST_F(ResolverStaticAssertTest, DISABLED_Local_LessThan_Pass) { + WrapInFunction(StaticAssert(LessThan(2_i, 3_i))); + ASSERT_TRUE(r()->Resolve()) << r()->error(); +} + +// TODO(crbug.com/tint/1581): Enable once the '<' operator is implemented for constant evaluation. +TEST_F(ResolverStaticAssertTest, DISABLED_Local_LessThan_Fail) { + WrapInFunction(StaticAssert(Source{{12, 34}}, LessThan(4_i, 3_i))); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: static assertion failed"); +} + +} // namespace +} // namespace tint::resolver diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc index 555ea1aabb..9ccb7ef856 100644 --- a/src/tint/resolver/uniformity.cc +++ b/src/tint/resolver/uniformity.cc @@ -847,6 +847,7 @@ class UniformityGraph { return cfx; } }, + [&](const ast::ReturnStatement* r) { Node* cf_ret; if (r->value) { @@ -870,6 +871,7 @@ class UniformityGraph { return cf_ret; }, + [&](const ast::SwitchStatement* s) { auto* sem_switch = sem_.Get(s); auto [cfx, v_cond] = ProcessExpression(cf, s->condition); @@ -938,6 +940,7 @@ class UniformityGraph { return cf_end ? cf_end : cf; }, + [&](const ast::VariableDeclStatement* decl) { Node* node; if (decl->variable->constructor) { @@ -956,6 +959,11 @@ class UniformityGraph { return cf; }, + + [&](const ast::StaticAssert*) { + return cf; // No impact on uniformity + }, + [&](Default) { TINT_ICE(Resolver, diagnostics_) << "unknown statement type: " << std::string(stmt->TypeInfo().name);