Resolver: Track storage class usages of structures

This will be used to validate layout rules, as well as preventing
illegal types from being used in a uniform / storage buffer.

Also: Cleanup logic around VariableDeclStatement
This was spread across 3 places, entirely unnecessarily.

Bug: tint:643
Change-Id: I9d309c3a5dfb5676984f49ce51763a97bcac93bb
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/45125
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
Ben Clayton 2021-03-17 22:47:33 +00:00 committed by Commit Bot service account
parent 893afdfd2c
commit a88090b04d
8 changed files with 282 additions and 64 deletions

View File

@ -473,6 +473,7 @@ if(${TINT_BUILD_TESTS})
resolver/resolver_test_helper.h resolver/resolver_test_helper.h
resolver/resolver_test.cc resolver/resolver_test.cc
resolver/struct_layout_test.cc resolver/struct_layout_test.cc
resolver/struct_storage_class_use_test.cc
resolver/validation_test.cc resolver/validation_test.cc
scope_stack_test.cc scope_stack_test.cc
semantic/sem_intrinsic_test.cc semantic/sem_intrinsic_test.cc

View File

@ -34,6 +34,12 @@ enum class StorageClass {
kFunction kFunction
}; };
/// @returns true if the StorageClass is host-sharable
/// @see https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable
inline bool IsHostSharable(StorageClass sc) {
return sc == ast::StorageClass::kUniform || sc == ast::StorageClass::kStorage;
}
std::ostream& operator<<(std::ostream& out, StorageClass sc); std::ostream& operator<<(std::ostream& out, StorageClass sc);
} // namespace ast } // namespace ast

View File

@ -151,6 +151,11 @@ bool Resolver::ResolveInternal() {
return false; return false;
} }
} }
if (!ApplyStorageClassUsageToType(var->declared_storage_class(),
var->type())) {
return false;
}
} }
if (!Functions(builder_->AST().Functions())) { if (!Functions(builder_->AST().Functions())) {
@ -200,16 +205,6 @@ bool Resolver::BlockStatement(const ast::BlockStatement* stmt) {
bool Resolver::Statements(const ast::StatementList& stmts) { bool Resolver::Statements(const ast::StatementList& stmts) {
for (auto* stmt : stmts) { for (auto* stmt : stmts) {
if (auto* decl = stmt->As<ast::VariableDeclStatement>()) {
if (!VariableDeclStatement(decl)) {
return false;
}
}
if (!VariableStorageClass(stmt)) {
return false;
}
if (!Statement(stmt)) { if (!Statement(stmt)) {
return false; return false;
} }
@ -217,36 +212,6 @@ bool Resolver::Statements(const ast::StatementList& stmts) {
return true; return true;
} }
bool Resolver::VariableStorageClass(ast::Statement* stmt) {
auto* var_decl = stmt->As<ast::VariableDeclStatement>();
if (var_decl == nullptr) {
return true;
}
auto* var = var_decl->variable();
auto* info = CreateVariableInfo(var);
variable_to_info_.emplace(var, info);
// Nothing to do for const
if (var->is_const()) {
return true;
}
if (info->storage_class == ast::StorageClass::kFunction) {
return true;
}
if (info->storage_class != ast::StorageClass::kNone) {
diagnostics_.add_error("function variable has a non-function storage class",
stmt->source());
return false;
}
info->storage_class = ast::StorageClass::kFunction;
return true;
}
bool Resolver::Statement(ast::Statement* stmt) { bool Resolver::Statement(ast::Statement* stmt) {
auto* sem_statement = builder_->create<semantic::Statement>(stmt); auto* sem_statement = builder_->create<semantic::Statement>(stmt);
@ -336,10 +301,7 @@ bool Resolver::Statement(ast::Statement* stmt) {
return true; return true;
} }
if (auto* v = stmt->As<ast::VariableDeclStatement>()) { if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
variable_stack_.set(v->variable()->symbol(), return VariableDeclStatement(v);
variable_to_info_.at(v->variable()));
current_block_->decls.push_back(v->variable());
return Expression(v->variable()->constructor());
} }
diagnostics_.add_error( diagnostics_.add_error(
@ -1118,21 +1080,44 @@ bool Resolver::UnaryOp(ast::UnaryOpExpression* expr) {
} }
bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) { bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
auto* ctor = stmt->variable()->constructor(); if (auto* ctor = stmt->variable()->constructor()) {
if (!ctor) { if (!Expression(ctor)) {
return true;
}
if (auto* sce = ctor->As<ast::ScalarConstructorExpression>()) {
auto* lhs_type = stmt->variable()->type()->UnwrapAliasIfNeeded();
auto* rhs_type = sce->literal()->type()->UnwrapAliasIfNeeded();
if (lhs_type != rhs_type) {
diagnostics_.add_error(
"constructor expression type does not match variable type",
stmt->source());
return false; return false;
} }
if (auto* sce = ctor->As<ast::ScalarConstructorExpression>()) {
auto* lhs_type = stmt->variable()->type()->UnwrapAliasIfNeeded();
auto* rhs_type = sce->literal()->type()->UnwrapAliasIfNeeded();
if (lhs_type != rhs_type) {
diagnostics_.add_error(
"constructor expression type does not match variable type",
stmt->source());
return false;
}
}
}
auto* var = stmt->variable();
auto* info = CreateVariableInfo(var);
variable_to_info_.emplace(var, info);
variable_stack_.set(var->symbol(), info);
current_block_->decls.push_back(var);
if (!var->is_const()) {
if (info->storage_class != ast::StorageClass::kFunction) {
if (info->storage_class != ast::StorageClass::kNone) {
diagnostics_.add_error(
"function variable has a non-function storage class",
stmt->source());
return false;
}
info->storage_class = ast::StorageClass::kFunction;
}
}
if (!ApplyStorageClassUsageToType(info->storage_class, var->type())) {
return false;
} }
return true; return true;
@ -1247,9 +1232,10 @@ void Resolver::CreateSemanticNodes() const {
for (auto it : struct_info_) { for (auto it : struct_info_) {
auto* str = it.first; auto* str = it.first;
auto* info = it.second; auto* info = it.second;
builder_->Sem().Add(str, builder_->create<semantic::Struct>( builder_->Sem().Add(
str, std::move(info->members), info->align, str, builder_->create<semantic::Struct>(
info->size, info->size_no_padding)); str, std::move(info->members), info->align, info->size,
info->size_no_padding, info->storage_class_usage));
} }
} }
@ -1470,6 +1456,44 @@ Resolver::StructInfo* Resolver::Structure(type::Struct* str) {
return info; return info;
} }
bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
type::Type* ty) {
ty = ty->UnwrapAliasIfNeeded();
if (auto* str = ty->As<type::Struct>()) {
auto* info = Structure(str);
if (!info) {
return false;
}
if (info->storage_class_usage.count(sc)) {
return true; // Already applied
}
info->storage_class_usage.emplace(sc);
for (auto* member : str->impl()->members()) {
// TODO(amaiorano): Determine the host-sharable types
bool can_be_host_sharable = true;
if (ast::IsHostSharable(sc) && !can_be_host_sharable) {
std::stringstream err;
err << "Structure '" << str->FriendlyName(builder_->Symbols())
<< "' is used by storage class " << sc
<< " which contains a member of non-host-sharable type "
<< member->type()->FriendlyName(builder_->Symbols());
diagnostics_.add_error(err.str(), member->source());
return false;
}
if (!ApplyStorageClassUsageToType(sc, member->type())) {
return false;
}
}
}
if (auto* arr = ty->As<type::Array>()) {
return ApplyStorageClassUsageToType(sc, arr->type());
}
return true;
}
template <typename F> template <typename F>
bool Resolver::BlockScope(BlockInfo::Type type, F&& callback) { bool Resolver::BlockScope(BlockInfo::Type type, F&& callback) {
BlockInfo block_info(type, current_block_); BlockInfo block_info(type, current_block_);

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include "src/intrinsic_table.h" #include "src/intrinsic_table.h"
@ -124,6 +125,7 @@ class Resolver {
uint32_t align = 0; uint32_t align = 0;
uint32_t size = 0; uint32_t size = 0;
uint32_t size_no_padding = 0; uint32_t size_no_padding = 0;
std::unordered_set<ast::StorageClass> storage_class_usage;
}; };
/// Structure holding semantic information about a block (i.e. scope), such as /// Structure holding semantic information about a block (i.e. scope), such as
@ -206,7 +208,6 @@ class Resolver {
bool Statements(const ast::StatementList&); bool Statements(const ast::StatementList&);
bool UnaryOp(ast::UnaryOpExpression*); bool UnaryOp(ast::UnaryOpExpression*);
bool VariableDeclStatement(const ast::VariableDeclStatement*); bool VariableDeclStatement(const ast::VariableDeclStatement*);
bool VariableStorageClass(ast::Statement*);
/// @returns the semantic information for the array `arr`, building it if it /// @returns the semantic information for the array `arr`, building it if it
/// hasn't been constructed already. If an error is raised, nullptr is /// hasn't been constructed already. If an error is raised, nullptr is
@ -217,6 +218,12 @@ class Resolver {
/// been constructed already. If an error is raised, nullptr is returned. /// been constructed already. If an error is raised, nullptr is returned.
StructInfo* Structure(type::Struct* str); StructInfo* Structure(type::Struct* str);
/// Records the storage class usage for the given type, and any transient
/// dependencies of the type. Validates that the type can be used for the
/// given storage class, erroring if it cannot.
/// @returns true on success, false on error
bool ApplyStorageClassUsageToType(ast::StorageClass, type::Type*);
/// @param align the output default alignment in bytes for the type `ty` /// @param align the output default alignment in bytes for the type `ty`
/// @param size the output default size in bytes for the type `ty` /// @param size the output default size in bytes for the type `ty`
/// @returns true on success, false on error /// @returns true on success, false on error

View File

@ -0,0 +1,161 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/resolver/resolver.h"
#include "gmock/gmock.h"
#include "src/resolver/resolver_test_helper.h"
#include "src/semantic/struct.h"
using ::testing::UnorderedElementsAre;
namespace tint {
namespace resolver {
namespace {
using ResolverStorageClassUseTest = ResolverTest;
TEST_F(ResolverStorageClassUseTest, UnreachableStruct) {
auto* s = Structure("S", {Member("a", ty.f32())});
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(s);
ASSERT_NE(sem, nullptr);
EXPECT_TRUE(sem->StorageClassUsage().empty());
}
TEST_F(ResolverStorageClassUseTest, StructReachableFromGlobal) {
auto* s = Structure("S", {Member("a", ty.f32())});
Global("g", s, ast::StorageClass::kStorage);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(s);
ASSERT_NE(sem, nullptr);
EXPECT_THAT(sem->StorageClassUsage(),
UnorderedElementsAre(ast::StorageClass::kStorage));
}
TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalAlias) {
auto* s = Structure("S", {Member("a", ty.f32())});
auto* a = ty.alias("A", s);
Global("g", a, ast::StorageClass::kStorage);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(s);
ASSERT_NE(sem, nullptr);
EXPECT_THAT(sem->StorageClassUsage(),
UnorderedElementsAre(ast::StorageClass::kStorage));
}
TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalStruct) {
auto* s = Structure("S", {Member("a", ty.f32())});
auto* o = Structure("O", {Member("a", s)});
Global("g", o, ast::StorageClass::kStorage);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(s);
ASSERT_NE(sem, nullptr);
EXPECT_THAT(sem->StorageClassUsage(),
UnorderedElementsAre(ast::StorageClass::kStorage));
}
TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalArray) {
auto* s = Structure("S", {Member("a", ty.f32())});
auto* a = ty.array(s, 3);
Global("g", a, ast::StorageClass::kStorage);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(s);
ASSERT_NE(sem, nullptr);
EXPECT_THAT(sem->StorageClassUsage(),
UnorderedElementsAre(ast::StorageClass::kStorage));
}
TEST_F(ResolverStorageClassUseTest, StructReachableFromLocal) {
auto* s = Structure("S", {Member("a", ty.f32())});
WrapInFunction(Var("g", s, ast::StorageClass::kFunction));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(s);
ASSERT_NE(sem, nullptr);
EXPECT_THAT(sem->StorageClassUsage(),
UnorderedElementsAre(ast::StorageClass::kFunction));
}
TEST_F(ResolverStorageClassUseTest, StructReachableViaLocalAlias) {
auto* s = Structure("S", {Member("a", ty.f32())});
auto* a = ty.alias("A", s);
WrapInFunction(Var("g", a, ast::StorageClass::kFunction));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(s);
ASSERT_NE(sem, nullptr);
EXPECT_THAT(sem->StorageClassUsage(),
UnorderedElementsAre(ast::StorageClass::kFunction));
}
TEST_F(ResolverStorageClassUseTest, StructReachableViaLocalStruct) {
auto* s = Structure("S", {Member("a", ty.f32())});
auto* o = Structure("O", {Member("a", s)});
WrapInFunction(Var("g", o, ast::StorageClass::kFunction));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(s);
ASSERT_NE(sem, nullptr);
EXPECT_THAT(sem->StorageClassUsage(),
UnorderedElementsAre(ast::StorageClass::kFunction));
}
TEST_F(ResolverStorageClassUseTest, StructReachableViaLocalArray) {
auto* s = Structure("S", {Member("a", ty.f32())});
auto* a = ty.array(s, 3);
WrapInFunction(Var("g", a, ast::StorageClass::kFunction));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(s);
ASSERT_NE(sem, nullptr);
EXPECT_THAT(sem->StorageClassUsage(),
UnorderedElementsAre(ast::StorageClass::kFunction));
}
TEST_F(ResolverStorageClassUseTest, StructMultipleStorageClassUses) {
auto* s = Structure("S", {Member("a", ty.f32())});
Global("x", s, ast::StorageClass::kStorage);
Global("y", s, ast::StorageClass::kUniform);
WrapInFunction(Var("g", s, ast::StorageClass::kFunction));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(s);
ASSERT_NE(sem, nullptr);
EXPECT_THAT(sem->StorageClassUsage(),
UnorderedElementsAre(ast::StorageClass::kStorage,
ast::StorageClass::kUniform,
ast::StorageClass::kFunction));
}
} // namespace
} // namespace resolver
} // namespace tint

View File

@ -24,12 +24,14 @@ Struct::Struct(type::Struct* type,
StructMemberList members, StructMemberList members,
uint32_t align, uint32_t align,
uint32_t size, uint32_t size,
uint32_t size_no_padding) uint32_t size_no_padding,
std::unordered_set<ast::StorageClass> storage_class_usage)
: type_(type), : type_(type),
members_(std::move(members)), members_(std::move(members)),
align_(align), align_(align),
size_(size), size_(size),
size_no_padding_(size_no_padding) {} size_no_padding_(size_no_padding),
storage_class_usage_(std::move(storage_class_usage)) {}
Struct::~Struct() = default; Struct::~Struct() = default;

View File

@ -17,8 +17,10 @@
#include <stdint.h> #include <stdint.h>
#include <unordered_set>
#include <vector> #include <vector>
#include "src/ast/storage_class.h"
#include "src/semantic/node.h" #include "src/semantic/node.h"
namespace tint { namespace tint {
@ -48,11 +50,13 @@ class Struct : public Castable<Struct, Node> {
/// @param size the byte size of the structure /// @param size the byte size of the structure
/// @param size_no_padding size of the members without the end of structure /// @param size_no_padding size of the members without the end of structure
/// alignment padding /// alignment padding
/// @param storage_class_usage a set of all the storage class usages
Struct(type::Struct* type, Struct(type::Struct* type,
StructMemberList members, StructMemberList members,
uint32_t align, uint32_t align,
uint32_t size, uint32_t size,
uint32_t size_no_padding); uint32_t size_no_padding,
std::unordered_set<ast::StorageClass> storage_class_usage);
/// Destructor /// Destructor
~Struct() override; ~Struct() override;
@ -79,12 +83,24 @@ class Struct : public Castable<Struct, Node> {
/// alignment padding /// alignment padding
uint32_t SizeNoPadding() const { return size_no_padding_; } uint32_t SizeNoPadding() const { return size_no_padding_; }
/// @returns the set of storage class uses of this structure
const std::unordered_set<ast::StorageClass>& StorageClassUsage() const {
return storage_class_usage_;
}
/// @param usage the ast::StorageClass usage type to query
/// @returns true iff this structure has been used as the given storage class
bool UsedAs(ast::StorageClass usage) const {
return storage_class_usage_.count(usage) > 0;
}
private: private:
type::Struct* const type_; type::Struct* const type_;
StructMemberList const members_; StructMemberList const members_;
uint32_t const align_; uint32_t const align_;
uint32_t const size_; uint32_t const size_;
uint32_t const size_no_padding_; uint32_t const size_no_padding_;
std::unordered_set<ast::StorageClass> const storage_class_usage_;
}; };
/// StructMember holds the semantic information for structure members. /// StructMember holds the semantic information for structure members.

View File

@ -175,6 +175,7 @@ source_set("tint_unittests_core_src") {
"../src/resolver/resolver_test_helper.h", "../src/resolver/resolver_test_helper.h",
"../src/resolver/resolver_test.cc", "../src/resolver/resolver_test.cc",
"../src/resolver/struct_layout_test.cc", "../src/resolver/struct_layout_test.cc",
"../src/resolver/struct_storage_class_use_test.cc",
"../src/resolver/validation_test.cc", "../src/resolver/validation_test.cc",
"../src/scope_stack_test.cc", "../src/scope_stack_test.cc",
"../src/semantic/sem_intrinsic_test.cc", "../src/semantic/sem_intrinsic_test.cc",