Rework Resolver so that we construct semantic types in a single pass.

The semantic nodes cannot be fully immutable, as they contain cyclic
references. Remove Resolver::CreateSemanticNodes(), and instead
construct and mutate the semantic nodes in the single traversal pass.

Give up on trying to maintain the 'authored' type names (aliased names).
These are a nightmare to maintain, and provided limited use.

Significantly simplfies the Resolver, and allows us to generate more
semantic to semantic references, reducing sem -> ast -> sem hops.

Note: This change introduces constant value propagation across constant
variables. This is unlocked by the earlier construction of the
sem::Variable.

Change-Id: I592092fdc47fe24d30e512952511c9ab7c16d7a1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68406
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Ben Clayton 2021-11-05 16:51:38 +00:00 committed by Tint LUCI CQ
parent 2423df3e04
commit a9156ff091
43 changed files with 1448 additions and 1550 deletions

View File

@ -22,6 +22,7 @@
#include "src/castable.h"
#include "src/program.h"
#include "src/sem/block_statement.h"
#include "src/sem/function.h"
#include "src/sem/statement.h"
#include "src/sem/variable.h"
@ -74,16 +75,15 @@ std::vector<const sem::Variable*> GetAllVarsInScope(
}
// Process function parameters.
for (const auto* param : curr_stmt->Function()->params) {
const auto* sem_var = program.Sem().Get(param);
if (pred(sem_var)) {
result.push_back(sem_var);
for (const auto* param : curr_stmt->Function()->Parameters()) {
if (pred(param)) {
result.push_back(param);
}
}
// Global variables do not belong to any ast::BlockStatement.
for (const auto* global_decl : program.AST().GlobalDeclarations()) {
if (global_decl == curr_stmt->Function()) {
if (global_decl == curr_stmt->Function()->Declaration()) {
// The same situation as in the previous loop. The current function has
// been reached. If there are any variables declared below, they won't be
// visible in this function. Thus, exit the loop.

View File

@ -827,11 +827,11 @@ void Inspector::GenerateSamplerTargets() {
}
auto* call_func = call->Stmt()->Function();
std::vector<Symbol> entry_points;
if (call_func->IsEntryPoint()) {
entry_points = {call_func->symbol};
std::vector<const sem::Function*> entry_points;
if (call_func->Declaration()->IsEntryPoint()) {
entry_points = {call_func};
} else {
entry_points = sem.Get(call_func)->AncestorEntryPoints();
entry_points = call_func->AncestorEntryPoints();
}
if (entry_points.empty()) {
@ -854,8 +854,9 @@ void Inspector::GenerateSamplerTargets() {
sampler->Declaration()->BindingPoint().group->value,
sampler->Declaration()->BindingPoint().binding->value};
for (auto entry_point : entry_points) {
const auto& ep_name = program_->Symbols().NameFor(entry_point);
for (auto* entry_point : entry_points) {
const auto& ep_name =
program_->Symbols().NameFor(entry_point->Declaration()->symbol);
(*sampler_targets_)[ep_name].add(
{sampler_binding_point, texture_binding_point});
}
@ -911,8 +912,8 @@ void Inspector::GetOriginatingResources(
// is not called. Ignore.
return;
}
for (auto* call_expr : func->CallSites()) {
callsites.add(call_expr);
for (auto* call : func->CallSites()) {
callsites.add(call->Declaration());
}
// Need to evaluate each function call with the group of
// expressions, so move on to the next expression.

View File

@ -1070,8 +1070,8 @@ const sem::Intrinsic* Impl::Match(sem::IntrinsicType intrinsic_type,
ast::StorageClass::kNone, ast::Access::kUndefined, p.usage));
}
return builder.create<sem::Intrinsic>(
intrinsic.type, intrinsic.return_type,
std::move(params), intrinsic.supported_stages, intrinsic.is_deprecated);
intrinsic.type, intrinsic.return_type, std::move(params),
intrinsic.supported_stages, intrinsic.is_deprecated);
});
}

View File

@ -289,8 +289,9 @@ TEST_F(ResolverArrayAccessorTest, EXpr_Deref_FuncBadParent) {
Func("func", {p}, ty.f32(), {Decl(idx), Decl(x), Return(x)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: cannot index type 'ptr<function, vec4<f32>>'");
EXPECT_EQ(
r()->error(),
"12:34 error: cannot index type 'ptr<function, vec4<f32>, read_write>'");
}
TEST_F(ResolverArrayAccessorTest, Exr_Deref_BadParent) {

View File

@ -102,7 +102,7 @@ TEST_F(ResolverAssignmentValidationTest,
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: cannot assign 'array<f32, len>' to 'array<f32, 4>'");
"12:34 error: cannot assign 'array<f32, 5>' to 'array<f32, 4>'");
}
TEST_F(ResolverAssignmentValidationTest,
@ -332,7 +332,7 @@ TEST_F(ResolverAssignmentValidationTest, AssignToPhony_DynamicArray_Fail) {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: cannot assign 'ref<storage, array<i32>, read>' to '_'. "
"12:34 error: cannot assign 'array<i32>' to '_'. "
"'_' can only be assigned a constructible, pointer, texture or sampler "
"type");
}

View File

@ -43,7 +43,7 @@ TEST_F(ResolverCompoundStatementTest, FunctionBlock) {
ASSERT_TRUE(s->Block()->Is<sem::FunctionBlockStatement>());
EXPECT_EQ(s->Block(), s->FindFirstParent<sem::BlockStatement>());
EXPECT_EQ(s->Block(), s->FindFirstParent<sem::FunctionBlockStatement>());
EXPECT_EQ(s->Block()->As<sem::FunctionBlockStatement>()->Function(), f);
EXPECT_EQ(s->Function()->Declaration(), f);
EXPECT_EQ(s->Block()->Parent(), nullptr);
}
@ -74,8 +74,7 @@ TEST_F(ResolverCompoundStatementTest, Block) {
EXPECT_EQ(s->Block()->Parent(),
s->FindFirstParent<sem::FunctionBlockStatement>());
ASSERT_TRUE(s->Block()->Parent()->Is<sem::FunctionBlockStatement>());
EXPECT_EQ(
s->Block()->Parent()->As<sem::FunctionBlockStatement>()->Function(), f);
EXPECT_EQ(s->Function()->Declaration(), f);
EXPECT_EQ(s->Block()->Parent()->Parent(), nullptr);
}
}
@ -118,7 +117,7 @@ TEST_F(ResolverCompoundStatementTest, Loop) {
EXPECT_TRUE(
Is<sem::FunctionBlockStatement>(s->Parent()->Parent()->Parent()));
EXPECT_EQ(s->FindFirstParent<sem::FunctionBlockStatement>()->Function(), f);
EXPECT_EQ(s->Function()->Declaration(), f);
EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(), nullptr);
}
@ -144,7 +143,7 @@ TEST_F(ResolverCompoundStatementTest, Loop) {
s->FindFirstParent<sem::FunctionBlockStatement>());
EXPECT_TRUE(Is<sem::FunctionBlockStatement>(
s->Parent()->Parent()->Parent()->Parent()));
EXPECT_EQ(s->FindFirstParent<sem::FunctionBlockStatement>()->Function(), f);
EXPECT_EQ(s->Function()->Declaration(), f);
EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent()->Parent(), nullptr);
}
@ -213,12 +212,7 @@ TEST_F(ResolverCompoundStatementTest, ForLoop) {
Is<sem::FunctionBlockStatement>(s->Block()->Parent()->Parent()));
EXPECT_EQ(s->Block()->Parent()->Parent(),
s->FindFirstParent<sem::FunctionBlockStatement>());
EXPECT_EQ(s->Block()
->Parent()
->Parent()
->As<sem::FunctionBlockStatement>()
->Function(),
f);
EXPECT_EQ(s->Function()->Declaration(), f);
EXPECT_EQ(s->Block()->Parent()->Parent()->Parent(), nullptr);
}
}

View File

@ -388,7 +388,7 @@ TEST_F(ResolverFunctionValidationTest,
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: return statement type must match its function return "
"type, returned 'u32', expected 'myf32'");
"type, returned 'u32', expected 'f32'");
}
TEST_F(ResolverFunctionValidationTest, CannotCallEntryPoint) {

View File

@ -98,11 +98,16 @@ TEST_F(ResolverPtrRefTest, DefaultPtrStorageClass) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_TRUE(TypeOf(function_ptr)->Is<sem::Pointer>());
ASSERT_TRUE(TypeOf(private_ptr)->Is<sem::Pointer>());
ASSERT_TRUE(TypeOf(workgroup_ptr)->Is<sem::Pointer>());
ASSERT_TRUE(TypeOf(uniform_ptr)->Is<sem::Pointer>());
ASSERT_TRUE(TypeOf(storage_ptr)->Is<sem::Pointer>());
ASSERT_TRUE(TypeOf(function_ptr)->Is<sem::Pointer>())
<< "function_ptr is " << TypeOf(function_ptr)->TypeInfo().name;
ASSERT_TRUE(TypeOf(private_ptr)->Is<sem::Pointer>())
<< "private_ptr is " << TypeOf(private_ptr)->TypeInfo().name;
ASSERT_TRUE(TypeOf(workgroup_ptr)->Is<sem::Pointer>())
<< "workgroup_ptr is " << TypeOf(workgroup_ptr)->TypeInfo().name;
ASSERT_TRUE(TypeOf(uniform_ptr)->Is<sem::Pointer>())
<< "uniform_ptr is " << TypeOf(uniform_ptr)->TypeInfo().name;
ASSERT_TRUE(TypeOf(storage_ptr)->Is<sem::Pointer>())
<< "storage_ptr is " << TypeOf(storage_ptr)->TypeInfo().name;
EXPECT_EQ(TypeOf(function_ptr)->As<sem::Pointer>()->Access(),
ast::Access::kReadWrite);

View File

@ -167,7 +167,7 @@ TEST_F(ResolverPtrRefValidationTest, InferredPtrAccessMismatch) {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: cannot initialize let of type "
"'ptr<storage, i32>' with value of type "
"'ptr<storage, i32, read>' with value of type "
"'ptr<storage, i32, read_write>'");
}

File diff suppressed because it is too large Load Diff

View File

@ -95,79 +95,9 @@ class Resolver {
/// Describes the context in which a variable is declared
enum class VariableKind { kParameter, kLocal, kGlobal };
/// Structure holding semantic information about a variable.
/// Used to build the sem::Variable nodes at the end of resolving.
struct VariableInfo {
VariableInfo(const ast::Variable* decl,
sem::Type* type,
const std::string& type_name,
ast::StorageClass storage_class,
ast::Access ac,
VariableKind k,
uint32_t idx);
~VariableInfo();
ast::Variable const* const declaration;
sem::Type* type;
std::string const type_name;
ast::StorageClass storage_class;
ast::Access const access;
std::vector<const ast::IdentifierExpression*> users;
sem::BindingPoint binding_point;
VariableKind kind;
uint32_t index = 0; // Parameter index, if kind == kParameter
uint16_t constant_id = 0;
};
struct IntrinsicCallInfo {
const ast::CallExpression* call;
const sem::Intrinsic* intrinsic;
};
std::set<std::pair<const sem::Struct*, ast::StorageClass>>
valid_struct_storage_layouts_;
/// Structure holding semantic information about a function.
/// Used to build the sem::Function nodes at the end of resolving.
struct FunctionInfo {
explicit FunctionInfo(const ast::Function* decl);
~FunctionInfo();
const ast::Function* const declaration;
std::vector<VariableInfo*> parameters;
utils::UniqueVector<VariableInfo*> referenced_module_vars;
utils::UniqueVector<VariableInfo*> local_referenced_module_vars;
std::vector<const ast::ReturnStatement*> return_statements;
std::vector<const ast::CallExpression*> callsites;
sem::Type* return_type = nullptr;
std::string return_type_name;
std::array<sem::WorkgroupDimension, 3> workgroup_size;
std::vector<IntrinsicCallInfo> intrinsic_calls;
// List of transitive calls this function makes
utils::UniqueVector<FunctionInfo*> transitive_calls;
// List of entry point functions that transitively call this function
utils::UniqueVector<FunctionInfo*> ancestor_entry_points;
};
/// Structure holding semantic information about an expression.
/// Used to build the sem::Expression nodes at the end of resolving.
struct ExpressionInfo {
sem::Type const* type;
std::string const type_name; // Declared type name
sem::Statement* statement;
sem::Constant constant_value;
};
/// Structure holding semantic information about a call expression to an
/// ast::Function.
/// Used to build the sem::Call nodes at the end of resolving.
struct FunctionCallInfo {
FunctionInfo* function;
sem::Statement* statement;
};
/// Structure holding semantic information about a block (i.e. scope), such as
/// parent block and variables declared in the block.
/// Used to validate variable scoping rules.
@ -231,35 +161,40 @@ class Resolver {
const ast::ExpressionList& params,
uint32_t* id);
void set_referenced_from_function_if_needed(VariableInfo* var, bool local);
//////////////////////////////////////////////////////////////////////////////
// AST and Type traversal methods
//////////////////////////////////////////////////////////////////////////////
// Expression resolving methods
// Returns the semantic node pointer on success, nullptr on failure.
sem::Expression* ArrayAccessor(const ast::ArrayAccessorExpression*);
sem::Expression* Binary(const ast::BinaryExpression*);
sem::Expression* Bitcast(const ast::BitcastExpression*);
sem::Expression* Call(const ast::CallExpression*);
sem::Expression* Constructor(const ast::ConstructorExpression*);
sem::Expression* Expression(const ast::Expression*);
sem::Function* Function(const ast::Function*);
sem::Call* FunctionCall(const ast::CallExpression*);
sem::Expression* Identifier(const ast::IdentifierExpression*);
sem::Call* IntrinsicCall(const ast::CallExpression*, sem::IntrinsicType);
sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*);
sem::Expression* UnaryOp(const ast::UnaryOpExpression*);
// Statement resolving methods
// Each return true on success, false on failure.
bool ArrayAccessor(const ast::ArrayAccessorExpression*);
bool Assignment(const ast::AssignmentStatement* a);
bool Binary(const ast::BinaryExpression*);
bool Bitcast(const ast::BitcastExpression*);
bool BlockStatement(const ast::BlockStatement*);
bool Call(const ast::CallExpression*);
bool CaseStatement(const ast::CaseStatement*);
bool Constructor(const ast::ConstructorExpression*);
bool ElseStatement(const ast::ElseStatement*);
bool Expression(const ast::Expression*);
bool ForLoopStatement(const ast::ForLoopStatement*);
bool Function(const ast::Function*);
bool FunctionCall(const ast::CallExpression* call);
bool GlobalVariable(const ast::Variable* var);
bool Identifier(const ast::IdentifierExpression*);
bool IfStatement(const ast::IfStatement*);
bool IntrinsicCall(const ast::CallExpression*, sem::IntrinsicType);
bool LoopStatement(const ast::LoopStatement*);
bool MemberAccessor(const ast::MemberAccessorExpression*);
bool Parameter(const ast::Variable* param);
bool GlobalVariable(const ast::Variable* var);
bool IfStatement(const ast::IfStatement*);
bool LoopStatement(const ast::LoopStatement*);
bool Return(const ast::ReturnStatement* ret);
bool Statement(const ast::Statement*);
bool Statements(const ast::StatementList&);
bool SwitchStatement(const ast::SwitchStatement* s);
bool UnaryOp(const ast::UnaryOpExpression*);
bool VariableDeclStatement(const ast::VariableDeclStatement*);
// AST and Type validation methods
@ -270,18 +205,16 @@ class Resolver {
uint32_t el_align,
const Source& source);
bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s);
bool ValidateAtomicVariable(const VariableInfo* info);
bool ValidateAtomicVariable(const sem::Variable* var);
bool ValidateAssignment(const ast::AssignmentStatement* a);
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type,
const bool is_input);
bool ValidateCall(const ast::CallExpression* call);
bool ValidateCallStatement(const ast::CallStatement* stmt);
bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
bool ValidateFunctionCall(const ast::CallExpression* call,
const FunctionInfo* target);
bool ValidateGlobalVariable(const VariableInfo* var);
bool ValidateCall(const sem::Call* call);
bool ValidateEntryPoint(const sem::Function* func);
bool ValidateFunction(const sem::Function* func);
bool ValidateFunctionCall(const sem::Call* call);
bool ValidateGlobalVariable(const sem::Variable* var);
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
const sem::Type* storage_type);
bool ValidateLocationDecoration(const ast::LocationDecoration* location,
@ -291,11 +224,11 @@ class Resolver {
const bool is_input = false);
bool ValidateMatrix(const sem::Matrix* ty, const Source& source);
bool ValidateFunctionParameter(const ast::Function* func,
const VariableInfo* info);
const sem::Variable* var);
bool ValidateNoDuplicateDefinition(Symbol sym,
const Source& source,
bool check_global_scope_only = false);
bool ValidateParameter(const ast::Function* func, const VariableInfo* info);
bool ValidateParameter(const ast::Function* func, const sem::Variable* var);
bool ValidateReturn(const ast::ReturnStatement* ret);
bool ValidateStatements(const ast::StatementList& stmts);
bool ValidateStorageTexture(const ast::StorageTexture* t);
@ -303,33 +236,30 @@ class Resolver {
bool ValidateStructureConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Struct* struct_type);
bool ValidateSwitch(const ast::SwitchStatement* s);
bool ValidateVariable(const VariableInfo* info);
bool ValidateVariable(const sem::Variable* var);
bool ValidateVariableConstructor(const ast::Variable* var,
ast::StorageClass storage_class,
const sem::Type* storage_type,
const std::string& type_name,
const sem::Type* rhs_type,
const std::string& rhs_type_name);
const sem::Type* rhs_type);
bool ValidateVector(const sem::Vector* ty, const Source& source);
bool ValidateVectorConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Vector* vec_type,
const std::string& type_name);
const sem::Vector* vec_type);
bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Matrix* matrix_type,
const std::string& type_name);
const sem::Matrix* matrix_type);
bool ValidateScalarConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Type* type,
const std::string& type_name);
const sem::Type* type);
bool ValidateArrayConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Array* arr_type);
bool ValidateTypeDecl(const ast::TypeDecl* named_type) const;
bool ValidateTextureIntrinsicFunction(const ast::CallExpression* ast_call,
const sem::Call* sem_call);
bool ValidateTextureIntrinsicFunction(const sem::Call* call);
bool ValidateNoDuplicateDecorations(const ast::DecorationList& decorations);
// sem::Struct is assumed to have at least one member
bool ValidateStorageClassLayout(const sem::Struct* type,
ast::StorageClass sc);
bool ValidateStorageClassLayout(const VariableInfo* info);
bool ValidateStorageClassLayout(const sem::Variable* var);
/// Resolves the WorkgroupSize for the given function
bool WorkgroupSizeFor(const ast::Function*, sem::WorkgroupSize& ws);
/// @returns the sem::Type for the ast::Type `ty`, building it if it
/// hasn't been constructed already. If an error is raised, nullptr is
@ -355,16 +285,16 @@ class Resolver {
/// raised. raised, nullptr is returned.
sem::Struct* Structure(const ast::Struct* str);
/// @returns the VariableInfo for the variable `var`, building it if it hasn't
/// been constructed already. If an error is raised, nullptr is returned.
/// @returns the semantic info for the variable `var`. If an error is raised,
/// nullptr is returned.
/// @note this method does not resolve the decorations as these are
/// context-dependent (global, local, parameter)
/// @param var the variable to create or return the `VariableInfo` for
/// @param kind what kind of variable we are declaring
/// @param index the index of the parameter, if this variable is a parameter
VariableInfo* Variable(const ast::Variable* var,
VariableKind kind,
uint32_t index = 0);
sem::Variable* Variable(const ast::Variable* var,
VariableKind kind,
uint32_t index = 0);
/// Records the storage class usage for the given type, and any transient
/// dependencies of the type. Validates that the type can be used for the
@ -389,23 +319,17 @@ class Resolver {
/// @param expr the expression
sem::Type* TypeOf(const ast::Expression* expr);
/// @returns the declared type name of the ast::Expression `expr`
/// @param expr the type name
std::string TypeNameOf(const ast::Expression* expr);
/// @returns the type name of the given semantic type, unwrapping references.
std::string TypeNameOf(const sem::Type* ty);
/// @returns the type name of the given semantic type, without unwrapping
/// references.
std::string RawTypeNameOf(const sem::Type* ty);
/// @returns the semantic type of the AST literal `lit`
/// @param lit the literal
sem::Type* TypeOf(const ast::Literal* lit);
/// Records the semantic information for the expression node with the resolved
/// type `type` and optional declared type name `type_name`.
/// @param expr the expression
/// @param type the resolved type
/// @param type_name the declared type name
void SetExprInfo(const ast::Expression* expr,
const sem::Type* type,
std::string type_name = "");
/// Assigns `stmt` to #current_statement_, #current_compound_statement_, and
/// possibly #current_block_, pushes the variable scope, then calls
/// `callback`. Before returning #current_statement_,
@ -437,16 +361,13 @@ class Resolver {
void AddNote(const std::string& msg, const Source& source) const;
template <typename CALLBACK>
void TraverseCallChain(FunctionInfo* from,
FunctionInfo* to,
void TraverseCallChain(const sem::Function* from,
const sem::Function* to,
CALLBACK&& callback) const;
//////////////////////////////////////////////////////////////////////////////
/// Constant value evaluation methods
//////////////////////////////////////////////////////////////////////////////
/// @return the Constant value of the given Expression
sem::Constant ConstantValueOf(const ast::Expression* expr);
/// Cast `Value` to `target_type`
/// @return the casted value
sem::Constant ConstantCast(const sem::Constant& value,
@ -461,29 +382,27 @@ class Resolver {
const ast::TypeConstructorExpression* type_ctor,
const sem::Type* type);
/// Sem is a helper for obtaining the semantic node for the given AST node.
template <typename SEM = sem::Info::InferFromAST,
typename AST_OR_TYPE = CastableBase>
const sem::Info::GetResultType<SEM, AST_OR_TYPE>* Sem(const AST_OR_TYPE* ast);
ProgramBuilder* const builder_;
diag::List& diagnostics_;
std::unique_ptr<IntrinsicTable> const intrinsic_table_;
ScopeStack<VariableInfo*> variable_stack_;
std::unordered_map<Symbol, FunctionInfo*> symbol_to_function_;
std::vector<FunctionInfo*> entry_points_;
ScopeStack<sem::Variable*> variable_stack_;
std::unordered_map<Symbol, sem::Function*> symbol_to_function_;
std::vector<sem::Function*> entry_points_;
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_;
std::unordered_map<const ast::Function*, FunctionInfo*> function_to_info_;
std::unordered_map<const ast::Variable*, VariableInfo*> variable_to_info_;
std::unordered_map<const ast::CallExpression*, FunctionCallInfo>
function_calls_;
std::unordered_map<const ast::Expression*, ExpressionInfo> expr_info_;
std::unordered_map<Symbol, TypeDeclInfo> named_type_info_;
std::unordered_set<const ast::Node*> marked_;
std::unordered_map<uint32_t, const VariableInfo*> constant_ids_;
std::unordered_map<uint32_t, const sem::Variable*> constant_ids_;
FunctionInfo* current_function_ = nullptr;
sem::Function* current_function_ = nullptr;
sem::Statement* current_statement_ = nullptr;
sem::CompoundStatement* current_compound_statement_ = nullptr;
sem::BlockStatement* current_block_ = nullptr;
BlockAllocator<VariableInfo> variable_infos_;
BlockAllocator<FunctionInfo> function_infos_;
};
} // namespace resolver

View File

@ -15,6 +15,7 @@
#include "src/resolver/resolver.h"
#include "src/sem/constant.h"
#include "src/utils/get_or_create.h"
namespace tint {
namespace resolver {
@ -26,46 +27,6 @@ using f32 = ProgramBuilder::f32;
} // namespace
sem::Constant Resolver::ConstantCast(const sem::Constant& value,
const sem::Type* target_elem_type) {
if (value.ElementType() == target_elem_type) {
return value;
}
sem::Constant::Scalars elems;
for (size_t i = 0; i < value.Elements().size(); ++i) {
if (target_elem_type->Is<sem::I32>()) {
elems.emplace_back(
value.WithScalarAt(i, [](auto&& s) { return static_cast<i32>(s); }));
} else if (target_elem_type->Is<sem::U32>()) {
elems.emplace_back(
value.WithScalarAt(i, [](auto&& s) { return static_cast<u32>(s); }));
} else if (target_elem_type->Is<sem::F32>()) {
elems.emplace_back(
value.WithScalarAt(i, [](auto&& s) { return static_cast<f32>(s); }));
} else if (target_elem_type->Is<sem::Bool>()) {
elems.emplace_back(
value.WithScalarAt(i, [](auto&& s) { return static_cast<bool>(s); }));
}
}
auto* target_type =
value.Type()->Is<sem::Vector>()
? builder_->create<sem::Vector>(target_elem_type,
static_cast<uint32_t>(elems.size()))
: target_elem_type;
return sem::Constant(target_type, elems);
}
sem::Constant Resolver::ConstantValueOf(const ast::Expression* expr) {
auto it = expr_info_.find(expr);
if (it != expr_info_.end()) {
return it->second.constant_value;
}
return {};
}
sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr,
const sem::Type* type) {
if (auto* e = expr->As<ast::ScalarConstructorExpression>()) {
@ -131,11 +92,11 @@ sem::Constant Resolver::EvaluateConstantValue(
// type_ctor's type.
sem::Constant::Scalars elems;
for (auto* cv : ctor_values) {
auto value = ConstantValueOf(cv);
if (!value.IsValid()) {
auto* expr = builder_->Sem().Get(cv);
if (!expr || !expr->ConstantValue()) {
return {};
}
auto cast = ConstantCast(value, elem_type);
auto cast = ConstantCast(expr->ConstantValue(), elem_type);
elems.insert(elems.end(), cast.Elements().begin(), cast.Elements().end());
}
@ -149,5 +110,37 @@ sem::Constant Resolver::EvaluateConstantValue(
return sem::Constant(type, std::move(elems));
}
sem::Constant Resolver::ConstantCast(const sem::Constant& value,
const sem::Type* target_elem_type) {
if (value.ElementType() == target_elem_type) {
return value;
}
sem::Constant::Scalars elems;
for (size_t i = 0; i < value.Elements().size(); ++i) {
if (target_elem_type->Is<sem::I32>()) {
elems.emplace_back(
value.WithScalarAt(i, [](auto&& s) { return static_cast<i32>(s); }));
} else if (target_elem_type->Is<sem::U32>()) {
elems.emplace_back(
value.WithScalarAt(i, [](auto&& s) { return static_cast<u32>(s); }));
} else if (target_elem_type->Is<sem::F32>()) {
elems.emplace_back(
value.WithScalarAt(i, [](auto&& s) { return static_cast<f32>(s); }));
} else if (target_elem_type->Is<sem::Bool>()) {
elems.emplace_back(
value.WithScalarAt(i, [](auto&& s) { return static_cast<bool>(s); }));
}
}
auto* target_type =
value.Type()->Is<sem::Vector>()
? builder_->create<sem::Vector>(target_elem_type,
static_cast<uint32_t>(elems.size()))
: target_elem_type;
return sem::Constant(target_type, elems);
}
} // namespace resolver
} // namespace tint

View File

@ -922,8 +922,8 @@ TEST_F(ResolverTest, Function_CallSites) {
auto* foo_sem = Sem().Get(foo);
ASSERT_NE(foo_sem, nullptr);
ASSERT_EQ(foo_sem->CallSites().size(), 2u);
EXPECT_EQ(foo_sem->CallSites()[0], call_1);
EXPECT_EQ(foo_sem->CallSites()[1], call_2);
EXPECT_EQ(foo_sem->CallSites()[0]->Declaration(), call_1);
EXPECT_EQ(foo_sem->CallSites()[1]->Declaration(), call_2);
auto* bar_sem = Sem().Get(bar);
ASSERT_NE(bar_sem, nullptr);
@ -1908,17 +1908,17 @@ TEST_F(ResolverTest, Function_EntryPoints_StageDecoration) {
const auto& b_eps = func_b_sem->AncestorEntryPoints();
ASSERT_EQ(2u, b_eps.size());
EXPECT_EQ(Symbols().Register("ep_1"), b_eps[0]);
EXPECT_EQ(Symbols().Register("ep_2"), b_eps[1]);
EXPECT_EQ(Symbols().Register("ep_1"), b_eps[0]->Declaration()->symbol);
EXPECT_EQ(Symbols().Register("ep_2"), b_eps[1]->Declaration()->symbol);
const auto& a_eps = func_a_sem->AncestorEntryPoints();
ASSERT_EQ(1u, a_eps.size());
EXPECT_EQ(Symbols().Register("ep_1"), a_eps[0]);
EXPECT_EQ(Symbols().Register("ep_1"), a_eps[0]->Declaration()->symbol);
const auto& c_eps = func_c_sem->AncestorEntryPoints();
ASSERT_EQ(2u, c_eps.size());
EXPECT_EQ(Symbols().Register("ep_1"), c_eps[0]);
EXPECT_EQ(Symbols().Register("ep_2"), c_eps[1]);
EXPECT_EQ(Symbols().Register("ep_1"), c_eps[0]->Declaration()->symbol);
EXPECT_EQ(Symbols().Register("ep_2"), c_eps[1]->Declaration()->symbol);
EXPECT_TRUE(ep_1_sem->AncestorEntryPoints().empty());
EXPECT_TRUE(ep_2_sem->AncestorEntryPoints().empty());

View File

@ -179,11 +179,11 @@ TEST_F(ResolverStorageClassLayoutValidationTest,
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(56:78 error: the offset of a struct member of type 'Inner' in storage class 'uniform' must be a multiple of 16 bytes, but 'inner' is currently at offset 4. Consider setting [[align(16)]] on this member
R"(56:78 error: the offset of a struct member of type '[[stride(16)]] array<f32, 10>' in storage class 'uniform' must be a multiple of 16 bytes, but 'inner' is currently at offset 4. Consider setting [[align(16)]] on this member
12:34 note: see layout of struct:
/* align(4) size(164) */ struct Outer {
/* offset( 0) align(4) size( 4) */ scalar : f32;
/* offset( 4) align(4) size(160) */ inner : Inner;
/* offset( 4) align(4) size(160) */ inner : [[stride(16)]] array<f32, 10>;
/* */ };
78:90 note: see declaration of variable)");
}
@ -351,7 +351,7 @@ TEST_F(ResolverStorageClassLayoutValidationTest,
R"(34:56 error: uniform storage requires that array elements be aligned to 16 bytes, but array stride of 'inner' is currently 8. Consider setting [[stride(16)]] on the array type
12:34 note: see layout of struct:
/* align(4) size(84) */ struct Outer {
/* offset( 0) align(4) size(80) */ inner : Inner;
/* offset( 0) align(4) size(80) */ inner : [[stride(8)]] array<f32, 10>;
/* offset(80) align(4) size( 4) */ scalar : i32;
/* */ };
78:90 note: see declaration of variable)");

View File

@ -96,7 +96,8 @@ TEST_F(ResolverStorageClassValidationTest, StorageBufferBoolAlias) {
EXPECT_EQ(
r()->error(),
R"(56:78 error: variables declared in the <storage> storage class must be of a structure type)");
R"(56:78 error: Type 'bool' cannot be used in storage class 'storage' as it is non-host-shareable
56:78 note: while instantiating variable g)");
}
TEST_F(ResolverStorageClassValidationTest, NotStorage_AccessMode) {
@ -194,7 +195,8 @@ TEST_F(ResolverStorageClassValidationTest, UniformBufferBool) {
EXPECT_EQ(
r()->error(),
R"(56:78 error: variables declared in the <uniform> storage class must be of a structure type)");
R"(56:78 error: Type 'bool' cannot be used in storage class 'uniform' as it is non-host-shareable
56:78 note: while instantiating variable g)");
}
TEST_F(ResolverStorageClassValidationTest, UniformBufferPointer) {
@ -243,7 +245,8 @@ TEST_F(ResolverStorageClassValidationTest, UniformBufferBoolAlias) {
EXPECT_EQ(
r()->error(),
R"(56:78 error: variables declared in the <uniform> storage class must be of a structure type)");
R"(56:78 error: Type 'bool' cannot be used in storage class 'uniform' as it is non-host-shareable
56:78 note: while instantiating variable g)");
}
TEST_F(ResolverStorageClassValidationTest, UniformBufferNoBlockDecoration) {

View File

@ -1553,7 +1553,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: type in vector constructor does not match vector "
"type: expected 'f32', found 'UnsignedInt'");
"type: expected 'f32', found 'u32'");
}
TEST_F(ResolverTypeConstructorValidationTest,
@ -1638,10 +1638,9 @@ struct MatrixDimensions {
uint32_t columns;
};
static std::string MatrixStr(const MatrixDimensions& dimensions,
std::string subtype = "f32") {
static std::string MatrixStr(const MatrixDimensions& dimensions) {
return "mat" + std::to_string(dimensions.columns) + "x" +
std::to_string(dimensions.rows) + "<" + subtype + ">";
std::to_string(dimensions.rows) + "<f32>";
}
using MatrixConstructorTest = ResolverTestWithParam<MatrixDimensions>;
@ -1919,9 +1918,9 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Error) {
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(), HasSubstr("12:1 error: invalid constructor for " +
MatrixStr(param, "Float32") +
"\n\n3 candidates available:"));
EXPECT_THAT(r()->error(),
HasSubstr("12:1 error: invalid constructor for " +
MatrixStr(param) + "\n\n3 candidates available:"));
}
TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Success) {

View File

@ -499,9 +499,10 @@ TEST_F(ResolverValidationTest, EXpr_MemberAccessor_FuncBadParent) {
Func("func", {p}, ty.f32(), {Decl(x), Return(x)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"error: invalid member accessor expression. Expected vector or "
"struct, got 'ptr<function, vec4<f32>>'");
EXPECT_EQ(
r()->error(),
"error: invalid member accessor expression. "
"Expected vector or struct, got 'ptr<function, vec4<f32>, read_write>'");
}
TEST_F(ResolverValidationTest,

View File

@ -120,7 +120,7 @@ TEST_F(ResolverVarLetValidationTest, LetConstructorWrongTypeViaAlias) {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(3:3 error: cannot initialize let of type 'I32' with value of type 'u32')");
R"(3:3 error: cannot initialize let of type 'i32' with value of type 'u32')");
}
TEST_F(ResolverVarLetValidationTest, VarConstructorWrongTypeViaAlias) {
@ -131,7 +131,7 @@ TEST_F(ResolverVarLetValidationTest, VarConstructorWrongTypeViaAlias) {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(3:3 error: cannot initialize var of type 'I32' with value of type 'u32')");
R"(3:3 error: cannot initialize var of type 'i32' with value of type 'u32')");
}
TEST_F(ResolverVarLetValidationTest, LetOfPtrConstructedWithRef) {
@ -147,7 +147,7 @@ TEST_F(ResolverVarLetValidationTest, LetOfPtrConstructedWithRef) {
EXPECT_EQ(
r()->error(),
R"(12:34 error: cannot initialize let of type 'ptr<function, f32>' with value of type 'f32')");
R"(12:34 error: cannot initialize let of type 'ptr<function, f32, read_write>' with value of type 'f32')");
}
TEST_F(ResolverVarLetValidationTest, LocalVarRedeclared) {

View File

@ -26,8 +26,9 @@ namespace tint {
namespace sem {
BlockStatement::BlockStatement(const ast::BlockStatement* declaration,
const CompoundStatement* parent)
: Base(declaration, parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {}
BlockStatement::~BlockStatement() = default;
@ -39,15 +40,19 @@ void BlockStatement::AddDecl(const ast::Variable* var) {
decls_.push_back(var);
}
FunctionBlockStatement::FunctionBlockStatement(const ast::Function* function)
: Base(function->body, nullptr), function_(function) {}
FunctionBlockStatement::FunctionBlockStatement(const sem::Function* function)
: Base(function->Declaration()->body, nullptr, function) {
TINT_ASSERT(Semantic, function);
}
FunctionBlockStatement::~FunctionBlockStatement() = default;
LoopBlockStatement::LoopBlockStatement(const ast::BlockStatement* declaration,
const CompoundStatement* parent)
: Base(declaration, parent) {
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {
TINT_ASSERT(Semantic, parent);
TINT_ASSERT(Semantic, function);
}
LoopBlockStatement::~LoopBlockStatement() = default;

View File

@ -40,8 +40,10 @@ class BlockStatement : public Castable<BlockStatement, CompoundStatement> {
/// Constructor
/// @param declaration the AST node for this block statement
/// @param parent the owning statement
/// @param function the owning function
BlockStatement(const ast::BlockStatement* declaration,
const CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~BlockStatement() override;
@ -67,16 +69,10 @@ class FunctionBlockStatement
public:
/// Constructor
/// @param function the owning function
explicit FunctionBlockStatement(const ast::Function* function);
explicit FunctionBlockStatement(const sem::Function* function);
/// Destructor
~FunctionBlockStatement() override;
/// @returns the function owning this block
const ast::Function* Function() const { return function_; }
private:
ast::Function const* const function_;
};
/// Holds semantic information about a loop body block or for-loop body block
@ -85,8 +81,10 @@ class LoopBlockStatement : public Castable<LoopBlockStatement, BlockStatement> {
/// Constructor
/// @param declaration the AST node for this block statement
/// @param parent the owning statement
/// @param function the owning function
LoopBlockStatement(const ast::BlockStatement* declaration,
const CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~LoopBlockStatement() override;

View File

@ -14,16 +14,21 @@
#include "src/sem/call.h"
#include <utility>
#include <vector>
TINT_INSTANTIATE_TYPEINFO(tint::sem::Call);
namespace tint {
namespace sem {
Call::Call(const ast::Expression* declaration,
Call::Call(const ast::CallExpression* declaration,
const CallTarget* target,
std::vector<const sem::Expression*> arguments,
Statement* statement)
: Base(declaration, target->ReturnType(), statement, Constant{}),
target_(target) {}
target_(target),
arguments_(std::move(arguments)) {}
Call::~Call() = default;

View File

@ -15,6 +15,8 @@
#ifndef SRC_SEM_CALL_H_
#define SRC_SEM_CALL_H_
#include <vector>
#include "src/sem/expression.h"
#include "src/sem/intrinsic.h"
@ -28,9 +30,11 @@ class Call : public Castable<Call, Expression> {
/// Constructor
/// @param declaration the AST node
/// @param target the call target
/// @param arguments the call arguments
/// @param statement the statement that owns this expression
Call(const ast::Expression* declaration,
Call(const ast::CallExpression* declaration,
const CallTarget* target,
std::vector<const sem::Expression*> arguments,
Statement* statement);
/// Destructor
@ -39,8 +43,19 @@ class Call : public Castable<Call, Expression> {
/// @return the target of the call
const CallTarget* Target() const { return target_; }
/// @return the call arguments
const std::vector<const sem::Expression*>& Arguments() const {
return arguments_;
}
/// @returns the AST node
const ast::CallExpression* Declaration() const {
return static_cast<const ast::CallExpression*>(declaration_);
}
private:
CallTarget const* const target_;
std::vector<const sem::Expression*> arguments_;
};
} // namespace sem

View File

@ -59,5 +59,7 @@ Constant::Constant(const Constant&) = default;
Constant::~Constant() = default;
Constant& Constant::operator=(const Constant& rhs) = default;
} // namespace sem
} // namespace tint

View File

@ -77,6 +77,11 @@ class Constant {
/// Destructor
~Constant();
/// Copy assignment
/// @param other the Constant to copy
/// @returns this Constant
Constant& operator=(const Constant& other);
/// @returns true if the Constant has been initialized
bool IsValid() const { return type_ != nullptr; }

View File

@ -53,8 +53,11 @@ class Expression : public Castable<Expression, Node> {
/// @returns the AST node
const ast::Expression* Declaration() const { return declaration_; }
private:
protected:
/// The AST expression node for this semantic expression
const ast::Expression* const declaration_;
private:
const sem::Type* const type_;
const Statement* const statement_;
const Constant constant_;

View File

@ -22,8 +22,9 @@ namespace tint {
namespace sem {
ForLoopStatement::ForLoopStatement(const ast::ForLoopStatement* declaration,
CompoundStatement* parent)
: Base(declaration, parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {}
ForLoopStatement::~ForLoopStatement() = default;

View File

@ -32,8 +32,10 @@ class ForLoopStatement : public Castable<ForLoopStatement, CompoundStatement> {
/// Constructor
/// @param declaration the AST node for this for-loop statement
/// @param parent the owning statement
/// @param function the owning function
ForLoopStatement(const ast::ForLoopStatement* declaration,
CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~ForLoopStatement() override;

View File

@ -28,23 +28,13 @@ TINT_INSTANTIATE_TYPEINFO(tint::sem::Function);
namespace tint {
namespace sem {
Function::Function(
const ast::Function* declaration,
Type* return_type,
std::vector<Parameter*> parameters,
std::vector<const GlobalVariable*> transitively_referenced_globals,
std::vector<const GlobalVariable*> directly_referenced_globals,
std::vector<const ast::CallExpression*> callsites,
std::vector<Symbol> ancestor_entry_points,
sem::WorkgroupSize workgroup_size)
Function::Function(const ast::Function* declaration,
Type* return_type,
std::vector<Parameter*> parameters,
sem::WorkgroupSize workgroup_size)
: Base(return_type, utils::ToConstPtrVec(parameters)),
declaration_(declaration),
workgroup_size_(std::move(workgroup_size)),
directly_referenced_globals_(std::move(directly_referenced_globals)),
transitively_referenced_globals_(
std::move(transitively_referenced_globals)),
callsites_(callsites),
ancestor_entry_points_(std::move(ancestor_entry_points)) {
workgroup_size_(std::move(workgroup_size)) {
for (auto* parameter : parameters) {
parameter->SetOwner(this);
}
@ -150,8 +140,8 @@ Function::VariableBindings Function::TransitivelyReferencedVariablesOfType(
}
bool Function::HasAncestorEntryPoint(Symbol symbol) const {
for (const auto& point : ancestor_entry_points_) {
if (point == symbol) {
for (const auto* point : ancestor_entry_points_) {
if (point->Declaration()->symbol == symbol) {
return true;
}
}

View File

@ -62,18 +62,10 @@ class Function : public Castable<Function, CallTarget> {
/// @param declaration the ast::Function
/// @param return_type the return type of the function
/// @param parameters the parameters to the function
/// @param transitively_referenced_globals the referenced module variables
/// @param directly_referenced_globals the locally referenced module
/// @param callsites the callsites of the function
/// @param ancestor_entry_points the ancestor entry points
/// @param workgroup_size the workgroup size
Function(const ast::Function* declaration,
Type* return_type,
std::vector<Parameter*> parameters,
std::vector<const GlobalVariable*> transitively_referenced_globals,
std::vector<const GlobalVariable*> directly_referenced_globals,
std::vector<const ast::CallExpression*> callsites,
std::vector<Symbol> ancestor_entry_points,
sem::WorkgroupSize workgroup_size);
/// Destructor
@ -85,22 +77,98 @@ class Function : public Castable<Function, CallTarget> {
/// @returns the workgroup size {x, y, z} for the function.
const sem::WorkgroupSize& WorkgroupSize() const { return workgroup_size_; }
/// @returns all directly referenced global variables
const utils::UniqueVector<const GlobalVariable*>& DirectlyReferencedGlobals()
const {
return directly_referenced_globals_;
}
/// Records that this function directly references the given global variable.
/// Note: Implicitly adds this global to the transtively-called globals.
/// @param global the module-scope variable
void AddDirectlyReferencedGlobal(const sem::GlobalVariable* global) {
directly_referenced_globals_.add(global);
transitively_referenced_globals_.add(global);
}
/// @returns all transitively referenced global variables
const utils::UniqueVector<const GlobalVariable*>&
TransitivelyReferencedGlobals() const {
return transitively_referenced_globals_;
}
/// @returns the list of callsites of this function
std::vector<const ast::CallExpression*> CallSites() const {
return callsites_;
/// Records that this function transitively references the given global
/// variable.
/// @param global the module-scoped variable
void AddTransitivelyReferencedGlobal(const sem::GlobalVariable* global) {
transitively_referenced_globals_.add(global);
}
/// @returns the names of the ancestor entry points
const std::vector<Symbol>& AncestorEntryPoints() const {
/// @returns the list of functions that this function transitively calls.
const utils::UniqueVector<const Function*>& TransitivelyCalledFunctions()
const {
return transitively_called_functions_;
}
/// Records that this function transitively calls `function`.
/// @param function the function this function transitively calls
void AddTransitivelyCalledFunction(const Function* function) {
transitively_called_functions_.add(function);
}
/// @returns the list of intrinsics that this function directly calls.
const utils::UniqueVector<const Intrinsic*>& DirectlyCalledIntrinsics()
const {
return directly_called_intrinsics_;
}
/// Records that this function transitively calls `intrinsic`.
/// @param intrinsic the intrinsic this function directly calls
void AddDirectlyCalledIntrinsic(const Intrinsic* intrinsic) {
directly_called_intrinsics_.add(intrinsic);
}
/// @returns the list of direct calls to functions / intrinsics made by this
/// function
std::vector<const Call*> DirectCallStatements() const {
return direct_calls_;
}
/// Adds a record of the direct function / intrinsic calls made by this
/// function
/// @param call the call
void AddDirectCall(const Call* call) { direct_calls_.emplace_back(call); }
/// @param target the target of a call
/// @returns the Call to the given CallTarget, or nullptr the target was not
/// called by this function.
const Call* FindDirectCallTo(const CallTarget* target) const {
for (auto* call : direct_calls_) {
if (call->Target() == target) {
return call;
}
}
return nullptr;
}
/// @returns the list of callsites of this function
std::vector<const Call*> CallSites() const { return callsites_; }
/// Adds a record of a callsite to this function
/// @param call the callsite
void AddCallSite(const Call* call) { callsites_.emplace_back(call); }
/// @returns the ancestor entry points
const std::vector<const Function*>& AncestorEntryPoints() const {
return ancestor_entry_points_;
}
/// Adds a record that the given entry point transitively calls this function
/// @param entry_point the entry point that transtively calls this function
void AddAncestorEntryPoint(const sem::Function* entry_point) {
ancestor_entry_points_.emplace_back(entry_point);
}
/// Retrieves any referenced location variables
/// @returns the <variable, decoration> pair.
std::vector<std::pair<const Variable*, const ast::LocationDecoration*>>
@ -174,8 +242,9 @@ class Function : public Castable<Function, CallTarget> {
utils::UniqueVector<const GlobalVariable*> transitively_referenced_globals_;
utils::UniqueVector<const Function*> transitively_called_functions_;
utils::UniqueVector<const Intrinsic*> directly_called_intrinsics_;
std::vector<const ast::CallExpression*> callsites_;
std::vector<Symbol> ancestor_entry_points_;
std::vector<const Call*> direct_calls_;
std::vector<const Call*> callsites_;
std::vector<const Function*> ancestor_entry_points_;
};
} // namespace sem

View File

@ -23,14 +23,16 @@ namespace tint {
namespace sem {
IfStatement::IfStatement(const ast::IfStatement* declaration,
CompoundStatement* parent)
: Base(declaration, parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {}
IfStatement::~IfStatement() = default;
ElseStatement::ElseStatement(const ast::ElseStatement* declaration,
CompoundStatement* parent)
: Base(declaration, parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {}
ElseStatement::~ElseStatement() = default;

View File

@ -34,7 +34,10 @@ class IfStatement : public Castable<IfStatement, CompoundStatement> {
/// Constructor
/// @param declaration the AST node for this if statement
/// @param parent the owning statement
IfStatement(const ast::IfStatement* declaration, CompoundStatement* parent);
/// @param function the owning function
IfStatement(const ast::IfStatement* declaration,
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~IfStatement() override;
@ -46,8 +49,10 @@ class ElseStatement : public Castable<ElseStatement, CompoundStatement> {
/// Constructor
/// @param declaration the AST node for this else statement
/// @param parent the owning statement
/// @param function the owning function
ElseStatement(const ast::ElseStatement* declaration,
CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~ElseStatement() override;

View File

@ -27,10 +27,18 @@ namespace sem {
/// Info holds all the resolved semantic information for a Program.
class Info {
public:
/// Placeholder type used by Get() to provide a default value for EXPLICIT_SEM
using InferFromAST = std::nullptr_t;
public:
/// Resolves to the return type of the Get() method given the desired sementic
/// type and AST type.
template <typename SEM, typename AST_OR_TYPE>
using GetResultType =
std::conditional_t<std::is_same<SEM, InferFromAST>::value,
SemanticNodeTypeFor<AST_OR_TYPE>,
SEM>;
/// Constructor
Info();
@ -50,10 +58,7 @@ class Info {
/// @returns a pointer to the semantic node if found, otherwise nullptr
template <typename SEM = InferFromAST,
typename AST_OR_TYPE = CastableBase,
typename RESULT =
std::conditional_t<std::is_same<SEM, InferFromAST>::value,
SemanticNodeTypeFor<AST_OR_TYPE>,
SEM>>
typename RESULT = GetResultType<SEM, AST_OR_TYPE>>
const RESULT* Get(const AST_OR_TYPE* node) const {
auto it = map.find(node);
if (it == map.end()) {

View File

@ -23,16 +23,22 @@ namespace tint {
namespace sem {
LoopStatement::LoopStatement(const ast::LoopStatement* declaration,
CompoundStatement* parent)
: Base(declaration, parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {
TINT_ASSERT(Semantic, parent);
TINT_ASSERT(Semantic, function);
}
LoopStatement::~LoopStatement() = default;
LoopContinuingBlockStatement::LoopContinuingBlockStatement(
const ast::BlockStatement* declaration,
const CompoundStatement* parent)
: Base(declaration, parent) {
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {
TINT_ASSERT(Semantic, parent);
TINT_ASSERT(Semantic, function);
}
LoopContinuingBlockStatement::~LoopContinuingBlockStatement() = default;

View File

@ -33,8 +33,10 @@ class LoopStatement : public Castable<LoopStatement, CompoundStatement> {
/// Constructor
/// @param declaration the AST node for this loop statement
/// @param parent the owning statement
/// @param function the owning function
LoopStatement(const ast::LoopStatement* declaration,
CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~LoopStatement() override;
@ -47,8 +49,10 @@ class LoopContinuingBlockStatement
/// Constructor
/// @param declaration the AST node for this block statement
/// @param parent the owning statement
/// @param function the owning function
LoopContinuingBlockStatement(const ast::BlockStatement* declaration,
const CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~LoopContinuingBlockStatement() override;

View File

@ -27,23 +27,18 @@ namespace tint {
namespace sem {
Statement::Statement(const ast::Statement* declaration,
const CompoundStatement* parent)
: declaration_(declaration), parent_(parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: declaration_(declaration), parent_(parent), function_(function) {}
const BlockStatement* Statement::Block() const {
return FindFirstParent<BlockStatement>();
}
const ast::Function* Statement::Function() const {
if (auto* fbs = FindFirstParent<FunctionBlockStatement>()) {
return fbs->Function();
}
return nullptr;
}
CompoundStatement::CompoundStatement(const ast::Statement* declaration,
const CompoundStatement* parent)
: Base(declaration, parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {}
CompoundStatement::~CompoundStatement() = default;

View File

@ -33,6 +33,7 @@ namespace sem {
/// Forward declaration
class CompoundStatement;
class Function;
namespace detail {
/// FindFirstParentReturn is a traits helper for determining the return type for
@ -64,7 +65,10 @@ class Statement : public Castable<Statement, Node> {
/// Constructor
/// @param declaration the AST node for this statement
/// @param parent the owning statement
Statement(const ast::Statement* declaration, const CompoundStatement* parent);
/// @param function the owning function
Statement(const ast::Statement* declaration,
const CompoundStatement* parent,
const sem::Function* function);
/// @return the AST node for this statement
const ast::Statement* Declaration() const { return declaration_; }
@ -90,11 +94,12 @@ class Statement : public Castable<Statement, Node> {
const BlockStatement* Block() const;
/// @returns the function that owns this statement
const ast::Function* Function() const;
const sem::Function* Function() const { return function_; }
private:
ast::Statement const* const declaration_;
CompoundStatement const* const parent_;
const ast::Statement* const declaration_;
const CompoundStatement* const parent_;
const sem::Function* const function_;
};
/// CompoundStatement is the base class of statements that can hold other
@ -103,9 +108,11 @@ class CompoundStatement : public Castable<Statement, Statement> {
public:
/// Constructor
/// @param declaration the AST node for this statement
/// @param parent the owning statement
/// @param statement the owning statement
/// @param function the owning function
CompoundStatement(const ast::Statement* declaration,
const CompoundStatement* parent);
const CompoundStatement* statement,
const sem::Function* function);
/// Destructor
~CompoundStatement() override;

View File

@ -23,16 +23,22 @@ namespace tint {
namespace sem {
SwitchStatement::SwitchStatement(const ast::SwitchStatement* declaration,
CompoundStatement* parent)
: Base(declaration, parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {
TINT_ASSERT(Semantic, parent);
TINT_ASSERT(Semantic, function);
}
SwitchStatement::~SwitchStatement() = default;
SwitchCaseBlockStatement::SwitchCaseBlockStatement(
const ast::BlockStatement* declaration,
const CompoundStatement* parent)
: Base(declaration, parent) {
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {
TINT_ASSERT(Semantic, parent);
TINT_ASSERT(Semantic, function);
}
SwitchCaseBlockStatement::~SwitchCaseBlockStatement() = default;

View File

@ -33,8 +33,10 @@ class SwitchStatement : public Castable<SwitchStatement, CompoundStatement> {
/// Constructor
/// @param declaration the AST node for this switch statement
/// @param parent the owning statement
/// @param function the owning function
SwitchStatement(const ast::SwitchStatement* declaration,
CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~SwitchStatement() override;
@ -47,8 +49,10 @@ class SwitchCaseBlockStatement
/// Constructor
/// @param declaration the AST node for this block statement
/// @param parent the owning statement
/// @param function the owning function
SwitchCaseBlockStatement(const ast::BlockStatement* declaration,
const CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~SwitchCaseBlockStatement() override;

View File

@ -31,19 +31,26 @@ namespace sem {
Variable::Variable(const ast::Variable* declaration,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access)
ast::Access access,
Constant constant_value)
: declaration_(declaration),
type_(type),
storage_class_(storage_class),
access_(access) {}
access_(access),
constant_value_(constant_value) {}
Variable::~Variable() = default;
LocalVariable::LocalVariable(const ast::Variable* declaration,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access)
: Base(declaration, type, storage_class, access) {}
ast::Access access,
Constant constant_value)
: Base(declaration,
type,
storage_class,
access,
std::move(constant_value)) {}
LocalVariable::~LocalVariable() = default;
@ -51,20 +58,12 @@ GlobalVariable::GlobalVariable(const ast::Variable* declaration,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access,
Constant constant_value,
sem::BindingPoint binding_point)
: Base(declaration, type, storage_class, access),
: Base(declaration, type, storage_class, access, std::move(constant_value)),
binding_point_(binding_point),
is_pipeline_constant_(false) {}
GlobalVariable::GlobalVariable(const ast::Variable* declaration,
const sem::Type* type,
uint16_t constant_id)
: Base(declaration,
type,
ast::StorageClass::kNone,
ast::Access::kReadWrite),
is_pipeline_constant_(true),
constant_id_(constant_id) {}
GlobalVariable::~GlobalVariable() = default;
@ -74,18 +73,16 @@ Parameter::Parameter(const ast::Variable* declaration,
ast::StorageClass storage_class,
ast::Access access,
const ParameterUsage usage /* = ParameterUsage::kNone */)
: Base(declaration, type, storage_class, access),
: Base(declaration, type, storage_class, access, Constant{}),
index_(index),
usage_(usage) {}
Parameter::~Parameter() = default;
VariableUser::VariableUser(const ast::IdentifierExpression* declaration,
const sem::Type* type,
Statement* statement,
sem::Variable* variable,
Constant constant_value)
: Base(declaration, type, statement, std::move(constant_value)),
sem::Variable* variable)
: Base(declaration, variable->Type(), statement, variable->ConstantValue()),
variable_(variable) {}
} // namespace sem

View File

@ -47,10 +47,12 @@ class Variable : public Castable<Variable, Node> {
/// @param type the variable type
/// @param storage_class the variable storage class
/// @param access the variable access control type
/// @param constant_value the constant value for the variable. May be invalid
Variable(const ast::Variable* declaration,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access);
ast::Access access,
Constant constant_value);
/// Destructor
~Variable() override;
@ -67,6 +69,9 @@ class Variable : public Castable<Variable, Node> {
/// @returns the access control for the variable
ast::Access Access() const { return access_; }
/// @return the constant value of this expression
const Constant& ConstantValue() const { return constant_value_; }
/// @returns the expressions that use the variable
const std::vector<const VariableUser*>& Users() const { return users_; }
@ -78,6 +83,7 @@ class Variable : public Castable<Variable, Node> {
const sem::Type* const type_;
ast::StorageClass const storage_class_;
ast::Access const access_;
const Constant constant_value_;
std::vector<const VariableUser*> users_;
};
@ -89,10 +95,12 @@ class LocalVariable : public Castable<LocalVariable, Variable> {
/// @param type the variable type
/// @param storage_class the variable storage class
/// @param access the variable access control type
/// @param constant_value the constant value for the variable. May be invalid
LocalVariable(const ast::Variable* declaration,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access);
ast::Access access,
Constant constant_value);
/// Destructor
~LocalVariable() override;
@ -101,26 +109,20 @@ class LocalVariable : public Castable<LocalVariable, Variable> {
/// GlobalVariable is a module-scope variable
class GlobalVariable : public Castable<GlobalVariable, Variable> {
public:
/// Constructor for non-overridable constants
/// Constructor
/// @param declaration the AST declaration node
/// @param type the variable type
/// @param storage_class the variable storage class
/// @param access the variable access control type
/// @param constant_value the constant value for the variable. May be invalid
/// @param binding_point the optional resource binding point of the variable
GlobalVariable(const ast::Variable* declaration,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access,
Constant constant_value,
sem::BindingPoint binding_point = {});
/// Constructor for overridable pipeline constants
/// @param declaration the AST declaration node
/// @param type the variable type
/// @param constant_id the pipeline constant ID
GlobalVariable(const ast::Variable* declaration,
const sem::Type* type,
uint16_t constant_id);
/// Destructor
~GlobalVariable() override;
@ -130,13 +132,20 @@ class GlobalVariable : public Castable<GlobalVariable, Variable> {
/// @returns the pipeline constant ID associated with the variable
uint16_t ConstantId() const { return constant_id_; }
/// @param id the constant identifier to assign to this variable
void SetConstantId(uint16_t id) {
constant_id_ = id;
is_pipeline_constant_ = true;
}
/// @returns true if this variable is an overridable pipeline constant
bool IsPipelineConstant() const { return is_pipeline_constant_; }
private:
sem::BindingPoint binding_point_;
const bool is_pipeline_constant_;
uint16_t const constant_id_ = 0;
const sem::BindingPoint binding_point_;
bool is_pipeline_constant_ = false;
uint16_t constant_id_ = 0;
};
/// Parameter is a function parameter
@ -186,15 +195,11 @@ class VariableUser : public Castable<VariableUser, Expression> {
public:
/// Constructor
/// @param declaration the AST identifier node
/// @param type the resolved type of the expression
/// @param statement the statement that owns this expression
/// @param variable the semantic variable
/// @param constant_value the constant value for the variable. May be invalid
VariableUser(const ast::IdentifierExpression* declaration,
const sem::Type* type,
Statement* statement,
sem::Variable* variable,
Constant constant_value);
sem::Variable* variable);
/// @returns the variable that this expression refers to
const sem::Variable* Variable() const { return variable_; }

View File

@ -109,8 +109,8 @@ struct ModuleScopeVarToEntryPointParam::State {
// Find all of the calls to this function that will need to be replaced.
for (auto* call : func_sem->CallSites()) {
auto* call_sem = ctx.src->Sem().Get(call);
calls_to_replace[call_sem->Stmt()->Function()].push_back(call);
calls_to_replace[call->Stmt()->Function()->Declaration()].push_back(
call->Declaration());
}
}
}
@ -268,7 +268,7 @@ struct ModuleScopeVarToEntryPointParam::State {
// Replace all uses of the module-scope variable.
// For non-entry points, dereference non-handle pointer parameters.
for (auto* user : var->Users()) {
if (user->Stmt()->Function() == func_ast) {
if (user->Stmt()->Function()->Declaration() == func_ast) {
const ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
if (is_pointer) {
// If this identifier is used by an address-of operator, just

View File

@ -39,7 +39,7 @@ var<private> a : array<f32, 3>;
let c : u32 = 1u;
fn f() {
let b : f32 = a[min(c, 2u)];
let b : f32 = a[1u];
}
)";
@ -807,7 +807,7 @@ struct S {
let c : u32 = 1u;
fn f() {
let b : f32 = s.b[min(c, (arrayLength(&(s.b)) - 1u))];
let b : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
let x : i32 = min(1, 2);
let y : u32 = arrayLength(&(s.b));
}

View File

@ -177,6 +177,7 @@
%209 = OpConstantComposite %v2float %float_1 %float_1
%_ptr_Function_v2float = OpTypePointer Function %v2float
%213 = OpConstantNull %v2float
%216 = OpConstantComposite %v2int %int_16 %int_16
%int_1 = OpConstant %int 1
%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
%_ptr_StorageBuffer_uint_0 = OpTypePointer StorageBuffer %uint
@ -359,7 +360,6 @@
%210 = OpFSub %v2float %208 %209
OpStore %floorCoord %210
%215 = OpLoad %v2int %tilePixel0Idx
%216 = OpCompositeConstruct %v2int %int_16 %int_16
%217 = OpIAdd %v2int %215 %216
%214 = OpConvertSToF %v2float %217
%218 = OpVectorTimesScalar %v2float %214 %float_2