tint/resolver: Resolve static_assert

No readers produce this, yet.

Bug: tint:1625
Change-Id: I94ce3e5afd7bd81b0a5059451136aa0eed7e9283
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/97961
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-08-02 23:28:28 +00:00 committed by Dawn LUCI CQ
parent bfd1a81364
commit 02791e95f3
7 changed files with 162 additions and 8 deletions

View File

@ -1126,6 +1126,7 @@ if (tint_build_unittests) {
"resolver/resolver_test_helper.cc",
"resolver/resolver_test_helper.h",
"resolver/side_effects_test.cc",
"resolver/static_assert_test.cc",
"resolver/source_variable_test.cc",
"resolver/storage_class_layout_validation_test.cc",
"resolver/storage_class_validation_test.cc",

View File

@ -809,6 +809,7 @@ if(TINT_BUILD_TESTS)
resolver/resolver_test_helper.h
resolver/resolver_test.cc
resolver/side_effects_test.cc
resolver/static_assert_test.cc
resolver/source_variable_test.cc
resolver/storage_class_layout_validation_test.cc
resolver/storage_class_validation_test.cc

View File

@ -204,6 +204,7 @@ class DependencyScanner {
[&](const ast::Enable*) {
// Enable directives do not effect the dependency graph.
},
[&](const ast::StaticAssert* assertion) { TraverseExpression(assertion->condition); },
[&](Default) { UnhandledNode(diagnostics_, global->node); });
}
@ -315,6 +316,7 @@ class DependencyScanner {
TraverseExpression(w->condition);
TraverseStatement(w->body);
},
[&](const ast::StaticAssert* assertion) { TraverseExpression(assertion->condition); },
[&](Default) {
if (!stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
ast::DiscardStatement, ast::FallthroughStatement>()) {
@ -515,6 +517,8 @@ struct DependencyAnalysis {
[&](const ast::TypeDecl* td) { return td->name; },
[&](const ast::Function* func) { return func->symbol; },
[&](const ast::Variable* var) { return var->symbol; },
[&](const ast::Enable*) { return Symbol(); },
[&](const ast::StaticAssert*) { return Symbol(); },
[&](Default) {
UnhandledNode(diagnostics_, node);
return Symbol{};
@ -533,11 +537,12 @@ struct DependencyAnalysis {
/// declaration
std::string KindOf(const ast::Node* node) {
return Switch(
node, //
[&](const ast::Struct*) { return "struct"; }, //
[&](const ast::Alias*) { return "alias"; }, //
[&](const ast::Function*) { return "function"; }, //
[&](const ast::Variable* v) { return v->Kind(); }, //
node, //
[&](const ast::Struct*) { return "struct"; }, //
[&](const ast::Alias*) { return "alias"; }, //
[&](const ast::Function*) { return "function"; }, //
[&](const ast::Variable* v) { return v->Kind(); }, //
[&](const ast::StaticAssert*) { return "static_assert"; }, //
[&](Default) {
UnhandledNode(diagnostics_, node);
return "<error>";
@ -549,9 +554,8 @@ struct DependencyAnalysis {
void GatherGlobals(const ast::Module& module) {
for (auto* node : module.GlobalDeclarations()) {
auto* global = allocator_.Create(node);
// Enable directives do not form a symbol. Skip them.
if (!node->Is<ast::Enable>()) {
globals_.emplace(SymbolOf(node), global);
if (auto symbol = SymbolOf(node); symbol.IsValid()) {
globals_.emplace(symbol, global);
}
declaration_order_.emplace_back(global);
}

View File

@ -140,6 +140,7 @@ bool Resolver::ResolveInternal() {
[&](const ast::TypeDecl* td) { return TypeDecl(td); },
[&](const ast::Function* func) { return Function(func); },
[&](const ast::Variable* var) { return GlobalVariable(var); },
[&](const ast::StaticAssert* sa) { return StaticAssert(sa); },
[&](Default) {
TINT_UNREACHABLE(Resolver, diagnostics_)
<< "unhandled global declaration: " << decl->TypeInfo().name;
@ -737,6 +738,33 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) {
return sem;
}
sem::Statement* Resolver::StaticAssert(const ast::StaticAssert* assertion) {
auto* expr = Expression(assertion->condition);
if (!expr) {
return nullptr;
}
auto* cond = expr->ConstantValue();
if (!cond) {
AddError("static assertion condition must be a constant expression",
assertion->condition->source);
return nullptr;
}
if (auto* ty = cond->Type(); !ty->Is<sem::Bool>()) {
AddError(
"static assertion condition must be a bool, got '" + builder_->FriendlyName(ty) + "'",
assertion->condition->source);
return nullptr;
}
if (!cond->As<bool>()) {
AddError("static assertion failed", assertion->source);
return nullptr;
}
auto* sem =
builder_->create<sem::Statement>(assertion, current_compound_statement_, current_function_);
builder_->Sem().Add(assertion, sem);
return sem;
}
sem::Function* Resolver::Function(const ast::Function* decl) {
uint32_t parameter_index = 0;
std::unordered_map<Symbol, Source> parameter_names;
@ -1042,6 +1070,7 @@ sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
[&](const ast::IncrementDecrementStatement* i) { return IncrementDecrementStatement(i); },
[&](const ast::ReturnStatement* r) { return ReturnStatement(r); },
[&](const ast::VariableDeclStatement* v) { return VariableDeclStatement(v); },
[&](const ast::StaticAssert* sa) { return StaticAssert(sa); },
// Error cases
[&](const ast::CaseStatement*) {

View File

@ -250,6 +250,7 @@ class Resolver {
sem::LoopStatement* LoopStatement(const ast::LoopStatement*);
sem::Statement* ReturnStatement(const ast::ReturnStatement*);
sem::Statement* Statement(const ast::Statement*);
sem::Statement* StaticAssert(const ast::StaticAssert*);
sem::SwitchStatement* SwitchStatement(const ast::SwitchStatement* s);
sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
bool Statements(utils::VectorRef<const ast::Statement*>);

View File

@ -0,0 +1,110 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/resolver/resolver.h"
#include "gmock/gmock.h"
#include "src/tint/resolver/resolver_test_helper.h"
using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver {
namespace {
using ResolverStaticAssertTest = ResolverTest;
TEST_F(ResolverStaticAssertTest, Global_True_Pass) {
GlobalStaticAssert(true);
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverStaticAssertTest, Global_False_Fail) {
GlobalStaticAssert(Source{{12, 34}}, false);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: static assertion failed");
}
TEST_F(ResolverStaticAssertTest, Global_Const_Pass) {
GlobalConst("C", ty.bool_(), Expr(true));
GlobalStaticAssert("C");
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverStaticAssertTest, Global_Const_Fail) {
GlobalConst("C", ty.bool_(), Expr(false));
GlobalStaticAssert(Source{{12, 34}}, "C");
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: static assertion failed");
}
// TODO(crbug.com/tint/1581): Enable once the '<' operator is implemented for constant evaluation.
TEST_F(ResolverStaticAssertTest, DISABLED_Global_LessThan_Pass) {
GlobalStaticAssert(LessThan(2_i, 3_i));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
// TODO(crbug.com/tint/1581): Enable once the '<' operator is implemented for constant evaluation.
TEST_F(ResolverStaticAssertTest, DISABLED_Global_LessThan_Fail) {
GlobalStaticAssert(Source{{12, 34}}, LessThan(4_i, 3_i));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: static assertion failed");
}
TEST_F(ResolverStaticAssertTest, Local_True_Pass) {
WrapInFunction(StaticAssert(true));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverStaticAssertTest, Local_False_Fail) {
WrapInFunction(StaticAssert(Source{{12, 34}}, false));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: static assertion failed");
}
TEST_F(ResolverStaticAssertTest, Local_Const_Pass) {
GlobalConst("C", ty.bool_(), Expr(true));
WrapInFunction(StaticAssert("C"));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverStaticAssertTest, Local_Const_Fail) {
GlobalConst("C", ty.bool_(), Expr(false));
WrapInFunction(StaticAssert(Source{{12, 34}}, "C"));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: static assertion failed");
}
TEST_F(ResolverStaticAssertTest, Local_NonConst) {
GlobalVar("V", ty.bool_(), Expr(true), ast::StorageClass::kPrivate);
WrapInFunction(StaticAssert(Expr(Source{{12, 34}}, "V")));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: static assertion condition must be a constant expression");
}
// TODO(crbug.com/tint/1581): Enable once the '<' operator is implemented for constant evaluation.
TEST_F(ResolverStaticAssertTest, DISABLED_Local_LessThan_Pass) {
WrapInFunction(StaticAssert(LessThan(2_i, 3_i)));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
// TODO(crbug.com/tint/1581): Enable once the '<' operator is implemented for constant evaluation.
TEST_F(ResolverStaticAssertTest, DISABLED_Local_LessThan_Fail) {
WrapInFunction(StaticAssert(Source{{12, 34}}, LessThan(4_i, 3_i)));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: static assertion failed");
}
} // namespace
} // namespace tint::resolver

View File

@ -847,6 +847,7 @@ class UniformityGraph {
return cfx;
}
},
[&](const ast::ReturnStatement* r) {
Node* cf_ret;
if (r->value) {
@ -870,6 +871,7 @@ class UniformityGraph {
return cf_ret;
},
[&](const ast::SwitchStatement* s) {
auto* sem_switch = sem_.Get(s);
auto [cfx, v_cond] = ProcessExpression(cf, s->condition);
@ -938,6 +940,7 @@ class UniformityGraph {
return cf_end ? cf_end : cf;
},
[&](const ast::VariableDeclStatement* decl) {
Node* node;
if (decl->variable->constructor) {
@ -956,6 +959,11 @@ class UniformityGraph {
return cf;
},
[&](const ast::StaticAssert*) {
return cf; // No impact on uniformity
},
[&](Default) {
TINT_ICE(Resolver, diagnostics_)
<< "unknown statement type: " << std::string(stmt->TypeInfo().name);