dawn-cmake/src/validator/validator_impl.cc
Ben Clayton 5b36d2c612 Remove all unnecessary includes
All includes from .cc to .h are preserved, even when transitively included.

It's clear that there are far too many includes in header files, and we should be more aggressive with forward declarations. tint:532 will continue to track this work.

There are, however, plenty of includes that have accumulated over time which are no longer required directly or transitively, so this change starts with a clean slate of *required* includes.

Bug: tint:532
Change-Id: Ie1718dad565f8309fa180ef91bcf3920e76dba18
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44042
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
2021-03-09 11:11:17 +00:00

571 lines
17 KiB
C++

// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/validator/validator_impl.h"
#include <utility>
#include "src/ast/call_statement.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/sint_literal.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/uint_literal.h"
#include "src/semantic/call.h"
#include "src/semantic/function.h"
#include "src/semantic/variable.h"
#include "src/type/alias_type.h"
#include "src/type/array_type.h"
#include "src/type/matrix_type.h"
#include "src/type/pointer_type.h"
#include "src/type/struct_type.h"
#include "src/type/u32_type.h"
#include "src/type/vector_type.h"
#include "src/type/void_type.h"
namespace tint {
ValidatorImpl::ValidatorImpl(const Program* program) : program_(program) {}
ValidatorImpl::~ValidatorImpl() = default;
void ValidatorImpl::add_error(const Source& src,
const char* code,
const std::string& msg) {
diag::Diagnostic diag;
diag.severity = diag::Severity::Error;
diag.source = src;
diag.message = msg;
diag.code = code;
diags_.add(std::move(diag));
}
// Records an error diagnostic that carries no validation rule code.
// Used for internal-consistency failures (e.g. missing semantic
// information) rather than user-facing WGSL rule violations.
//   src - source location the error refers to
//   msg - human-readable description of the problem
void ValidatorImpl::add_error(const Source& src, const std::string& msg) {
diags_.add_error(msg, src);
}
// Validates the whole program.
//
// If the program already failed to build, its diagnostics are copied into
// this validator and validation fails immediately. Otherwise every global
// declaration (constructed type, function or global variable) is checked in
// module order, followed by the entry-point checks. Validation stops at the
// first failure.
//
// Returns true if every check passed.
bool ValidatorImpl::Validate() {
  if (!program_->IsValid()) {
    // If we're attempting to validate an invalid program, fail with the
    // program's diagnostics.
    diags_.add(program_->Diagnostics());
    return false;
  }

  // Validate global declarations in the order they appear in the module.
  for (auto* decl : program_->AST().GlobalDeclarations()) {
    if (auto* ty = decl->As<type::Type>()) {
      if (!ValidateConstructedType(ty)) {
        return false;
      }
    } else if (auto* func = decl->As<ast::Function>()) {
      // Return-statement and call validation read current_function_, so it
      // must be set while the function is being validated.
      current_function_ = func;
      const bool func_ok = ValidateFunction(func);
      // Clear it on both the success and failure paths so a failed
      // validation does not leave a stale pointer behind.
      current_function_ = nullptr;
      if (!func_ok) {
        return false;
      }
    } else if (auto* var = decl->As<ast::Variable>()) {
      if (!ValidateGlobalVariable(var)) {
        return false;
      }
    } else {
      TINT_UNREACHABLE(diags_);
      return false;
    }
  }

  if (!ValidateEntryPoint(program_->AST().Functions())) {
    return false;
  }
  return true;
}
// Validates a module-scope constructed type. Only struct types are checked:
// a runtime-sized array member must be the last member (v-0015), and a
// struct holding one must carry the [[block]] decoration (v-0015).
bool ValidatorImpl::ValidateConstructedType(const type::Type* type) {
  auto* str = type->As<type::Struct>();
  if (!str) {
    // Non-struct constructed types have no checks here.
    return true;
  }

  for (auto* member : str->impl()->members()) {
    auto* arr = member->type()->UnwrapAll()->As<type::Array>();
    if (!arr || !arr->IsRuntimeArray()) {
      continue;
    }

    const bool is_last_member = member == str->impl()->members().back();
    if (!is_last_member) {
      add_error(member->source(), "v-0015",
                "runtime arrays may only appear as the last member of "
                "a struct");
      return false;
    }

    if (!str->IsBlockDecorated()) {
      add_error(member->source(), "v-0015",
                "a struct containing a runtime-sized array "
                "requires the [[block]] attribute: '" +
                    program_->Symbols().NameFor(str->symbol()) + "'");
      return false;
    }
  }
  return true;
}
// Validates a module-scope variable or constant, then registers it in the
// global scope of variable_stack_.
//  - the variable must have semantic information;
//  - its name must not already be declared (v-0011);
//  - a `var` must have a storage class (v-0022);
//  - a constant must not have one (v-global01).
bool ValidatorImpl::ValidateGlobalVariable(const ast::Variable* var) {
  auto* info = program_->Sem().Get(var);
  if (info == nullptr) {
    add_error(var->source(), "no semantic information for variable '" +
                                 program_->Symbols().NameFor(var->symbol()) +
                                 "'");
    return false;
  }

  if (variable_stack_.has(var->symbol())) {
    add_error(var->source(), "v-0011",
              "redeclared global identifier '" +
                  program_->Symbols().NameFor(var->symbol()) + "'");
    return false;
  }

  const bool has_storage_class =
      info->StorageClass() != ast::StorageClass::kNone;
  if (var->is_const()) {
    if (has_storage_class) {
      add_error(var->source(), "v-global01",
                "global constants shouldn't have a storage class");
      return false;
    }
  } else if (!has_storage_class) {
    add_error(var->source(), "v-0022",
              "global variables must have a storage class");
    return false;
  }

  variable_stack_.set_global(var->symbol(), var);
  return true;
}
// Validates the entry points among `funcs`. At least one entry point must
// exist (v-0003). Each entry point must:
//  - take no parameters (v-0023);
//  - return void (v-0024);
//  - carry at most one stage decoration (v-0020).
bool ValidatorImpl::ValidateEntryPoint(const ast::FunctionList& funcs) {
  bool has_entry_point = false;

  for (auto* func : funcs) {
    if (!func->IsEntryPoint()) {
      continue;
    }
    has_entry_point = true;

    if (!func->params().empty()) {
      add_error(func->source(), "v-0023",
                "Entry point function must accept no parameters: '" +
                    program_->Symbols().NameFor(func->symbol()) + "'");
      return false;
    }

    if (!func->return_type()->Is<type::Void>()) {
      add_error(func->source(), "v-0024",
                "Entry point function must return void: '" +
                    program_->Symbols().NameFor(func->symbol()) + "'");
      return false;
    }

    int num_stage_decos = 0;
    for (auto* deco : func->decorations()) {
      if (deco->Is<ast::StageDecoration>()) {
        ++num_stage_decos;
      }
    }
    if (num_stage_decos > 1) {
      add_error(func->source(), "v-0020",
                "only one stage decoration permitted per entry point");
      return false;
    }
  }

  if (has_entry_point) {
    return true;
  }
  add_error(Source{}, "v-0003",
            "At least one of vertex, fragment or compute shader must "
            "be present");
  return false;
}
// Validates a function declaration:
//  - its name must be unique among the functions seen so far (v-0016);
//  - each parameter is registered in a new scope and checked by
//    ValidateParameter();
//  - the body statements are checked by ValidateStatements();
//  - a non-void function must end with a return statement (v-0002).
//
// The parameter scope pushed here is popped on every exit path, including
// failures, so variable_stack_ stays balanced.
bool ValidatorImpl::ValidateFunction(const ast::Function* func) {
  if (function_stack_.has(func->symbol())) {
    add_error(func->source(), "v-0016",
              "function names must be unique '" +
                  program_->Symbols().NameFor(func->symbol()) + "'");
    return false;
  }
  function_stack_.set(func->symbol(), func);

  variable_stack_.push_scope();
  bool params_and_body_ok = true;
  for (auto* param : func->params()) {
    variable_stack_.set(param->symbol(), param);
    if (!ValidateParameter(param)) {
      params_and_body_ok = false;
      break;
    }
  }
  if (params_and_body_ok && !ValidateStatements(func->body())) {
    params_and_body_ok = false;
  }
  variable_stack_.pop_scope();
  if (!params_and_body_ok) {
    return false;
  }

  // Use `func` itself rather than current_function_ so this check does not
  // depend on the caller having set current_function_.
  if (!func->return_type()->Is<type::Void>()) {
    auto* last = func->get_last_statement();
    if (!last || !last->Is<ast::ReturnStatement>()) {
      add_error(func->source(), "v-0002",
                "non-void function must end with a return statement");
      return false;
    }
  }
  return true;
}
// Validates a function parameter. A parameter whose type (after unwrapping)
// is a runtime-sized array is rejected (v-0015).
bool ValidatorImpl::ValidateParameter(const ast::Variable* param) {
  auto* arr = param->type()->UnwrapAll()->As<type::Array>();
  const bool is_runtime_array = arr != nullptr && arr->IsRuntimeArray();
  if (!is_runtime_array) {
    return true;
  }
  add_error(param->source(), "v-0015",
            "runtime arrays may only appear as the last member of a struct");
  return false;
}
// Checks that a return statement's value type matches the return type of
// current_function_ (v-000y). A `return;` with no value is treated as
// returning void.
//
// Both the declared return type and the returned value's type are compared
// after UnwrapAll(). A function declared to return an alias therefore
// accepts a value of the aliased type. The types are compared by
// type_name().
//
// Requires current_function_ to be set by the caller.
bool ValidatorImpl::ValidateReturnStatement(const ast::ReturnStatement* ret) {
  // TODO(sarahM0): update this when this issue resolves:
  // https://github.com/gpuweb/gpuweb/issues/996
  type::Type* func_type = current_function_->return_type()->UnwrapAll();

  type::Void void_type;
  type::Type* ret_type = &void_type;
  if (ret->has_value()) {
    auto* value_sem = program_->Sem().Get(ret->value());
    if (value_sem == nullptr) {
      add_error(ret->source(),
                "return value is missing semantic information");
      return false;
    }
    ret_type = value_sem->Type()->UnwrapAll();
  }

  if (func_type->type_name() != ret_type->type_name()) {
    add_error(ret->source(), "v-000y",
              "return statement type must match its function return "
              "type, returned '" +
                  ret_type->type_name() + "', expected '" +
                  func_type->type_name() + "'");
    return false;
  }
  return true;
}
// Validates every statement of `block` inside a fresh variable scope,
// stopping at the first invalid statement. A null block is invalid. The
// scope is always popped before returning.
bool ValidatorImpl::ValidateStatements(const ast::BlockStatement* block) {
  if (block == nullptr) {
    return false;
  }

  variable_stack_.push_scope();
  bool all_ok = true;
  for (auto* stmt : *block) {
    all_ok = ValidateStatement(stmt);
    if (!all_ok) {
      break;
    }
  }
  variable_stack_.pop_scope();

  return all_ok;
}
// Validates a local variable declaration and registers the variable in the
// current scope.
//  - The name must not already be visible: redeclaring a global is v-0013,
//    redeclaring a local is v-0014.
//  - A runtime-sized array type is rejected (v-0015).
bool ValidatorImpl::ValidateDeclStatement(
    const ast::VariableDeclStatement* decl) {
  auto* var = decl->variable();
  auto symbol = var->symbol();

  bool shadows_global = false;
  if (variable_stack_.get(symbol, nullptr, &shadows_global)) {
    add_error(
        decl->source(), shadows_global ? "v-0013" : "v-0014",
        "redeclared identifier '" + program_->Symbols().NameFor(symbol) + "'");
    return false;
  }

  // TODO(dneto): Check type compatibility of the initializer.
  // - if it's non-constant, then is storable or can be dereferenced to be
  //   storable.
  // - types match or the RHS can be dereferenced to equal the LHS type.
  variable_stack_.set(symbol, var);

  auto* arr = var->type()->UnwrapAll()->As<type::Array>();
  if (arr != nullptr && arr->IsRuntimeArray()) {
    add_error(
        decl->source(), "v-0015",
        "runtime arrays may only appear as the last member of a struct");
    return false;
  }
  return true;
}
// Dispatches a single statement to the matching validator. Statement kinds
// without a dedicated validator are accepted. A null statement is invalid.
bool ValidatorImpl::ValidateStatement(const ast::Statement* stmt) {
  if (stmt == nullptr) {
    return false;
  }

  if (auto* decl = stmt->As<ast::VariableDeclStatement>()) {
    // The initializer (if any) is validated before the declaration itself.
    auto* var = decl->variable();
    if (var->has_constructor() && !ValidateExpression(var->constructor())) {
      return false;
    }
    return ValidateDeclStatement(decl);
  }
  if (auto* assign = stmt->As<ast::AssignmentStatement>()) {
    return ValidateAssign(assign);
  }
  if (auto* ret = stmt->As<ast::ReturnStatement>()) {
    return ValidateReturnStatement(ret);
  }
  if (auto* call = stmt->As<ast::CallStatement>()) {
    return ValidateCallExpr(call->expr());
  }
  if (auto* sw = stmt->As<ast::SwitchStatement>()) {
    return ValidateSwitch(sw);
  }
  if (auto* clause = stmt->As<ast::CaseStatement>()) {
    return ValidateCase(clause);
  }
  if (auto* block = stmt->As<ast::BlockStatement>()) {
    return ValidateStatements(block);
  }

  return true;
}
// Validates a switch statement:
//  - the selector expression must validate and be a scalar integer (v-0025);
//  - every case clause must validate;
//  - each case selector literal must have exactly the selector's type
//    (v-0026), and no literal value may appear twice (v-0027);
//  - there must be exactly one default clause (v-0008);
//  - the last clause must not end with a fallthrough statement (v-0028).
bool ValidatorImpl::ValidateSwitch(const ast::SwitchStatement* s) {
  auto* condition = s->condition();
  if (!ValidateExpression(condition)) {
    return false;
  }

  auto* selector_type = program_->Sem().Get(condition)->Type()->UnwrapAll();
  if (!selector_type->is_integer_scalar()) {
    add_error(condition->source(), "v-0025",
              "switch statement selector expression must be of a "
              "scalar integer type");
    return false;
  }

  int num_defaults = 0;
  // Selector values seen so far. u32 values are stored reinterpreted as
  // int32_t. Every selector shares one type, so distinct values stay
  // distinct.
  std::unordered_set<int32_t> seen_values;

  for (auto* clause : s->body()) {
    if (!ValidateStatement(clause)) {
      return false;
    }
    if (clause->IsDefault()) {
      ++num_defaults;
    }

    for (auto* literal : clause->selectors()) {
      if (literal->type() != selector_type) {
        add_error(clause->source(), "v-0026",
                  "the case selector values must have the same "
                  "type as the selector expression.");
        return false;
      }

      int32_t key;
      if (literal->type()->Is<type::U32>()) {
        key = static_cast<int32_t>(literal->As<ast::UintLiteral>()->value());
      } else {
        key = static_cast<int32_t>(literal->As<ast::SintLiteral>()->value());
      }

      // insert() reports false in .second when the value was already there.
      if (!seen_values.insert(key).second) {
        add_error(clause->source(), "v-0027",
                  "a literal value must not appear more than once in "
                  "the case selectors for a switch statement: '" +
                      program_->str(literal) + "'");
        return false;
      }
    }
  }

  if (num_defaults != 1) {
    add_error(s->source(), "v-0008",
              "switch statement must have exactly one default clause");
    return false;
  }

  // The body is non-empty here: it contains exactly one default clause.
  auto* final_stmt = s->body().back()->body()->last();
  if (final_stmt != nullptr && final_stmt->Is<ast::FallthroughStatement>()) {
    add_error(final_stmt->source(), "v-0028",
              "a fallthrough statement must not appear as "
              "the last statement in last clause of a switch");
    return false;
  }
  return true;
}
// Validates a case clause by validating its body block.
bool ValidatorImpl::ValidateCase(const ast::CaseStatement* c) {
  return ValidateStatement(c->body());
}
// Validates a call expression.
//  - A null expression is invalid.
//  - The call must have semantic information.
//  - Calls to intrinsics are accepted without further checks.
//  - Calls to user functions must not call the function currently being
//    validated (direct recursion, v-0004).
//  - Any other call target is rejected.
bool ValidatorImpl::ValidateCallExpr(const ast::CallExpression* expr) {
  if (expr == nullptr) {
    // TODO(sarahM0): Here and other Validate.*: figure out whether return
    // false or true
    return false;
  }

  auto* call = program_->Sem().Get(expr);
  if (!call) {
    add_error(expr->source(), "CallExpression is missing semantic information");
    return false;
  }

  auto* target = call->Target();
  if (target->Is<semantic::Intrinsic>()) {
    // TODO(bclayton): Add intrinsic validation checks here.
    return true;
  }

  auto* callee = target->As<semantic::Function>();
  if (!callee) {
    add_error(expr->source(), "Invalid function call expression");
    return false;
  }

  if (callee->Declaration() == current_function_) {
    add_error(expr->source(), "v-0004",
              "recursion is not allowed: '" +
                  program_->Symbols().NameFor(current_function_->symbol()) +
                  "'");
    return false;
  }
  return true;
}
// Produces a more specific error when an assignment's left-hand side is a
// bare identifier. Returns false (after adding an error) when the identifier
// names a constant (v-0021) or is undeclared (v-0006). Returns true when the
// LHS is not an identifier, or names a non-const variable.
bool ValidatorImpl::ValidateBadAssignmentToIdentifier(
    const ast::AssignmentStatement* assign) {
  auto* ident = assign->lhs()->As<ast::IdentifierExpression>();
  if (ident == nullptr) {
    // It wasn't an identifier in the first place.
    return true;
  }

  const ast::Variable* var = nullptr;
  if (!variable_stack_.get(ident->symbol(), &var)) {
    // The identifier is not defined. This should already have been caught
    // when validating the subexpression.
    add_error(ident->source(), "v-0006",
              "'" + program_->Symbols().NameFor(ident->symbol()) +
                  "' is not declared");
    return false;
  }

  // Give a nicer message if the LHS of the assignment is a const identifier.
  // It's likely to be a common programmer error.
  if (var->is_const()) {
    add_error(assign->source(), "v-0021",
              "cannot re-assign a constant: '" +
                  program_->Symbols().NameFor(ident->symbol()) + "'");
    return false;
  }
  return true;
}
// Validates an assignment statement.
//  - Both sides must validate as expressions.
//  - The right-hand side, after unwrapping, must be storable (v-000x).
//  - The left-hand side must be a pointer (a reference to storage) whose
//    pointee type equals the right-hand side's unwrapped type (v-000x).
// When the LHS is not a pointer, a more specific identifier error is tried
// first, then a generic error is reported.
bool ValidatorImpl::ValidateAssign(const ast::AssignmentStatement* assign) {
  if (assign == nullptr) {
    return false;
  }

  auto* lhs = assign->lhs();
  auto* rhs = assign->rhs();
  if (!ValidateExpression(lhs) || !ValidateExpression(rhs)) {
    return false;
  }

  // Pointers are not storable in WGSL, but the right-hand side must be
  // storable. The raw right-hand side might be a pointer value which must be
  // loaded (dereferenced) to provide the value to be stored.
  auto* rhs_sem_type = program_->Sem().Get(rhs)->Type();
  auto* rhs_type = rhs_sem_type->UnwrapAll();
  if (!IsStorable(rhs_type)) {
    add_error(assign->source(), "v-000x",
              "invalid assignment: right-hand-side is not storable: " +
                  rhs_sem_type->type_name());
    return false;
  }

  auto* lhs_sem_type = program_->Sem().Get(lhs)->Type();
  auto* lhs_ptr = As<type::Pointer>(lhs_sem_type->UnwrapIfNeeded());
  if (lhs_ptr == nullptr) {
    if (!ValidateBadAssignmentToIdentifier(assign)) {
      return false;
    }
    // Issue a generic error.
    add_error(
        assign->source(), "v-000x",
        "invalid assignment: left-hand-side does not reference storage: " +
            lhs_sem_type->type_name());
    return false;
  }

  auto* store_type = lhs_ptr->type()->UnwrapIfNeeded();
  if (store_type != rhs_type) {
    add_error(assign->source(), "v-000x",
              "invalid assignment: can't assign value of type '" +
                  rhs_type->type_name() + "' to '" +
                  store_type->type_name() + "'");
    return false;
  }
  return true;
}
// Validates an expression. Identifiers must be declared and calls must be
// valid. Other expression kinds are accepted without checks. A null
// expression is invalid.
bool ValidatorImpl::ValidateExpression(const ast::Expression* expr) {
  if (expr == nullptr) {
    return false;
  }

  if (auto* ident = expr->As<ast::IdentifierExpression>()) {
    return ValidateIdentifier(ident);
  }
  if (auto* call = expr->As<ast::CallExpression>()) {
    return ValidateCallExpr(call);
  }

  // No rules for other expression kinds yet.
  return true;
}
// Checks that an identifier names a variable visible in the current scope
// stack; otherwise reports v-0006.
bool ValidatorImpl::ValidateIdentifier(const ast::IdentifierExpression* ident) {
  const ast::Variable* var = nullptr;
  const bool declared = variable_stack_.get(ident->symbol(), &var);
  if (declared) {
    return true;
  }

  add_error(ident->source(), "v-0006",
            "'" + program_->Symbols().NameFor(ident->symbol()) +
                "' is not declared");
  return false;
}
// Reports whether a value of `type` can be stored.
// These are storable:
//  - scalars, vectors and matrices;
//  - arrays of storable elements;
//  - structs whose members are all storable;
//  - aliases of storable types.
// Anything else, including null, is not.
bool ValidatorImpl::IsStorable(type::Type* type) {
  if (!type) {
    return false;
  }

  const bool is_basic = type->is_scalar() || type->Is<type::Vector>() ||
                        type->Is<type::Matrix>();
  if (is_basic) {
    return true;
  }

  if (auto* arr = type->As<type::Array>()) {
    return IsStorable(arr->type());
  }

  if (auto* str = type->As<type::Struct>()) {
    bool all_members_storable = true;
    for (const auto* member : str->impl()->members()) {
      if (!IsStorable(member->type())) {
        all_members_storable = false;
        break;
      }
    }
    return all_members_storable;
  }

  if (auto* alias = type->As<type::Alias>()) {
    return IsStorable(alias->type());
  }

  return false;
}
} // namespace tint