From 89d8b2b7a5053a01c7030538e75aec9fa48265d4 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Thu, 4 Nov 2021 22:29:22 +0000 Subject: [PATCH] Clean up the ScopeStack interface There's no need for the ScopeStack to include 'global' information. This is easily obtainable from the element type. Replace the get-by-reference, with a simpler return value. Change-Id: Ic6f4c0f656a2019417d68ffb3fe85ba8343ad15e Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68403 Kokoro: Kokoro Reviewed-by: James Price --- src/resolver/resolver.cc | 47 ++++++++++------------ src/scope_stack.h | 47 +++++----------------- src/scope_stack_test.cc | 78 ++++++++---------------------------- src/writer/spirv/builder.cc | 80 ++++++++++++++++++------------------- 4 files changed, 85 insertions(+), 167 deletions(-) diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index fb1b25dd10..6836f12ff5 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -656,7 +656,7 @@ bool Resolver::GlobalVariable(const ast::Variable* var) { if (!info) { return false; } - variable_stack_.set_global(var->symbol, info); + variable_stack_.Set(var->symbol, info); if (!var->is_const && info->storage_class == ast::StorageClass::kNone) { AddError("global variables must have a storage class", var->source); @@ -1769,7 +1769,7 @@ bool Resolver::Function(const ast::Function* func) { TINT_SCOPED_ASSIGNMENT(current_function_, info); - variable_stack_.push_scope(); + variable_stack_.Push(); uint32_t parameter_index = 0; std::unordered_map parameter_names; for (auto* param : func->params) { @@ -1798,7 +1798,7 @@ bool Resolver::Function(const ast::Function* func) { return false; } - variable_stack_.set(param->symbol, param_info); + variable_stack_.Set(param->symbol, param_info); info->parameters.emplace_back(param_info); if (!ApplyStorageClassUsageToType(param->declared_storage_class, @@ -1887,7 +1887,7 @@ bool Resolver::Function(const ast::Function* func) { return false; } } - variable_stack_.pop_scope(); + variable_stack_.Pop(); for (auto* deco : func->decorations) { Mark(deco); @@ -1952,9 +1952,8 @@ bool Resolver::Function(const ast::Function* func) { if (auto* ident = expr->As()) { // We have an identifier of a module-scope constant. - VariableInfo* var = nullptr; - if (!variable_stack_.get(ident->symbol, &var) || - !(var->declaration->is_const)) { + VariableInfo* var = variable_stack_.Get(ident->symbol); + if (!var || !(var->declaration->is_const)) { AddError(kErrBadType, expr->source); return false; } @@ -2635,8 +2634,8 @@ bool Resolver::ValidateFunctionCall(const ast::CallExpression* call, if (param->declaration->type->Is()) { auto is_valid = false; if (auto* ident_expr = arg_expr->As()) { - VariableInfo* var; - if (!variable_stack_.get(ident_expr->symbol, &var)) { + VariableInfo* var = variable_stack_.Get(ident_expr->symbol); + if (!var) { TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier"; return false; } @@ -2647,8 +2646,8 @@ bool Resolver::ValidateFunctionCall(const ast::CallExpression* call, if (unary->op == ast::UnaryOp::kAddressOf) { if (auto* ident_unary = unary->expr->As()) { - VariableInfo* var; - if (!variable_stack_.get(ident_unary->symbol, &var)) { + VariableInfo* var = variable_stack_.Get(ident_unary->symbol); + if (!var) { TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier"; return false; @@ -2987,8 +2986,7 @@ bool Resolver::ValidateScalarConstructor( bool Resolver::Identifier(const ast::IdentifierExpression* expr) { auto symbol = expr->symbol; - VariableInfo* var; - if (variable_stack_.get(symbol, &var)) { + if (VariableInfo* var = variable_stack_.Get(symbol)) { SetExprInfo(expr, var->type, var->type_name); var->users.push_back(expr); @@ -3485,7 +3483,7 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) { } } - variable_stack_.set(var->symbol, info); + variable_stack_.Set(var->symbol, info); if (current_block_) { // Not all statements are inside a block current_block_->AddDecl(var); } @@ -3909,9 +3907,8 @@ sem::Array* Resolver::Array(const ast::Array* arr) { if (auto* ident = count_expr->As()) { // Make sure the identifier is a non-overridable module-scope constant. - VariableInfo* var = nullptr; - bool is_global = false; - if (!variable_stack_.get(ident->symbol, &var, &is_global) || !is_global || + VariableInfo* var = variable_stack_.Get(ident->symbol); + if (!var || var->kind != VariableKind::kGlobal || !var->declaration->is_const) { AddError("array size identifier must be a module-scope constant", size_source); @@ -4489,8 +4486,7 @@ bool Resolver::ValidateAssignment(const ast::AssignmentStatement* a) { auto const* lhs_type = TypeOf(a->lhs); if (auto* ident = a->lhs->As()) { - VariableInfo* var; - if (variable_stack_.get(ident->symbol, &var)) { + if (VariableInfo* var = variable_stack_.Get(ident->symbol)) { if (var->kind == VariableKind::kParameter) { AddError("cannot assign to function parameter", a->lhs->source); AddNote("'" + builder_->Symbols().NameFor(ident->symbol) + @@ -4542,10 +4538,8 @@ bool Resolver::ValidateNoDuplicateDefinition(Symbol sym, const Source& source, bool check_global_scope_only) { if (check_global_scope_only) { - bool is_global = false; - VariableInfo* var; - if (variable_stack_.get(sym, &var, &is_global)) { - if (is_global) { + if (VariableInfo* var = variable_stack_.Get(sym)) { + if (var->kind == VariableKind::kGlobal) { AddError("redefinition of '" + builder_->Symbols().NameFor(sym) + "'", source); AddNote("previous definition is here", var->declaration->source); @@ -4560,8 +4554,7 @@ bool Resolver::ValidateNoDuplicateDefinition(Symbol sym, return false; } } else { - VariableInfo* var; - if (variable_stack_.get(sym, &var)) { + if (VariableInfo* var = variable_stack_.Get(sym)) { AddError("redefinition of '" + builder_->Symbols().NameFor(sym) + "'", source); AddNote("previous definition is here", var->declaration->source); @@ -4635,10 +4628,10 @@ bool Resolver::Scope(sem::CompoundStatement* stmt, F&& callback) { current_statement_ = stmt; current_compound_statement_ = stmt; current_block_ = stmt->As(); - variable_stack_.push_scope(); + variable_stack_.Push(); TINT_DEFER({ - TINT_DEFER(variable_stack_.pop_scope()); + TINT_DEFER(variable_stack_.Pop()); current_block_ = prev_current_block; current_compound_statement_ = prev_current_compound_statement; current_statement_ = prev_current_statement; diff --git a/src/scope_stack.h b/src/scope_stack.h index a1619c4de8..f8ddadc151 100644 --- a/src/scope_stack.h +++ b/src/scope_stack.h @@ -36,60 +36,33 @@ class ScopeStack { ~ScopeStack() = default; /// Push a new scope on to the stack - void push_scope() { stack_.push_back({}); } + void Push() { stack_.push_back({}); } /// Pop the scope off the top of the stack - void pop_scope() { + void Pop() { if (stack_.size() > 1) { stack_.pop_back(); } } - /// Set a global variable in the stack + /// Assigns the value into the top most scope of the stack /// @param symbol the symbol of the variable /// @param val the value - void set_global(const Symbol& symbol, T val) { stack_[0][symbol] = val; } + void Set(const Symbol& symbol, T val) { stack_.back()[symbol] = val; } - /// Sets variable into the top most scope of the stack - /// @param symbol the symbol of the variable - /// @param val the value - void set(const Symbol& symbol, T val) { stack_.back()[symbol] = val; } - - /// Checks for the given `symbol` in the stack + /// Retrieves a value from the stack /// @param symbol the symbol to look for - /// @returns true if the stack contains `symbol` - bool has(const Symbol& symbol) const { return get(symbol, nullptr); } - - /// Retrieves a given variable from the stack - /// @param symbol the symbol to look for - /// @param ret where to place the value - /// @returns true if the symbol was successfully found, false otherwise - bool get(const Symbol& symbol, T* ret) const { - return get(symbol, ret, nullptr); - } - - /// Retrieves a given variable from the stack - /// @param symbol the symbol to look for - /// @param ret where to place the value - /// @param is_global set true if the symbol references a global variable - /// otherwise unchanged - /// @returns true if the symbol was successfully found, false otherwise - bool get(const Symbol& symbol, T* ret, bool* is_global) const { + /// @returns the variable, or the zero initializer if the value was not found + T Get(const Symbol& symbol) const { for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) { auto& map = *iter; auto val = map.find(symbol); - if (val != map.end()) { - if (ret) { - *ret = val->second; - } - if (is_global && iter == stack_.rend() - 1) { - *is_global = true; - } - return true; + return val->second; } } - return false; + + return T{}; } private: diff --git a/src/scope_stack_test.cc b/src/scope_stack_test.cc index 0006347933..f96f0d165f 100644 --- a/src/scope_stack_test.cc +++ b/src/scope_stack_test.cc @@ -21,78 +21,32 @@ namespace { class ScopeStackTest : public ProgramBuilder, public testing::Test {}; -TEST_F(ScopeStackTest, Global) { +TEST_F(ScopeStackTest, Get) { ScopeStack s; - Symbol sym(1, ID()); - s.set_global(sym, 5); + Symbol a(1, ID()); + Symbol b(3, ID()); + s.Push(); + s.Set(a, 5u); + s.Set(b, 10u); - uint32_t val = 0; - EXPECT_TRUE(s.get(sym, &val)); - EXPECT_EQ(val, 5u); -} + EXPECT_EQ(s.Get(a), 5u); + EXPECT_EQ(s.Get(b), 10u); -TEST_F(ScopeStackTest, Global_SetWithPointer) { - auto* v = Var("my_var", ty.f32(), ast::StorageClass::kNone); - ScopeStack s; - s.set_global(v->symbol, v); + s.Push(); - const ast::Variable* v2 = nullptr; - EXPECT_TRUE(s.get(v->symbol, &v2)); - EXPECT_EQ(v2->symbol, v->symbol); -} + s.Set(a, 15u); + EXPECT_EQ(s.Get(a), 15u); + EXPECT_EQ(s.Get(b), 10u); -TEST_F(ScopeStackTest, Global_CanNotPop) { - ScopeStack s; - Symbol sym(1, ID()); - s.set_global(sym, 5); - s.pop_scope(); - - uint32_t val = 0; - EXPECT_TRUE(s.get(sym, &val)); - EXPECT_EQ(val, 5u); -} - -TEST_F(ScopeStackTest, Scope) { - ScopeStack s; - Symbol sym(1, ID()); - s.push_scope(); - s.set(sym, 5); - - uint32_t val = 0; - EXPECT_TRUE(s.get(sym, &val)); - EXPECT_EQ(val, 5u); + s.Pop(); + EXPECT_EQ(s.Get(a), 5u); + EXPECT_EQ(s.Get(b), 10u); } TEST_F(ScopeStackTest, Get_MissingSymbol) { ScopeStack s; Symbol sym(1, ID()); - uint32_t ret = 0; - EXPECT_FALSE(s.get(sym, &ret)); - EXPECT_EQ(ret, 0u); -} - -TEST_F(ScopeStackTest, Has) { - ScopeStack s; - Symbol sym(1, ID()); - Symbol sym2(2, ID()); - s.set_global(sym2, 3); - s.push_scope(); - s.set(sym, 5); - - EXPECT_TRUE(s.has(sym)); - EXPECT_TRUE(s.has(sym2)); -} - -TEST_F(ScopeStackTest, ReturnsScopeBeforeGlobalFirst) { - ScopeStack s; - Symbol sym(1, ID()); - s.set_global(sym, 3); - s.push_scope(); - s.set(sym, 5); - - uint32_t ret; - EXPECT_TRUE(s.get(sym, &ret)); - EXPECT_EQ(ret, 5u); + EXPECT_EQ(s.Get(sym), 0u); } } // namespace diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 1ca479f9ba..71dc5417a3 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -46,6 +46,7 @@ #include "src/transform/simplify.h" #include "src/transform/vectorize_scalar_matrix_constructors.h" #include "src/transform/zero_init_workgroup_memory.h" +#include "src/utils/defer.h" #include "src/utils/get_or_create.h" #include "src/writer/append_vector.h" @@ -468,8 +469,8 @@ bool Builder::GenerateEntryPoint(const ast::Function* func, uint32_t id) { continue; } - uint32_t var_id; - if (!scope_stack_.get(var->Declaration()->symbol, &var_id)) { + uint32_t var_id = scope_stack_.Get(var->Declaration()->symbol); + if (var_id == 0) { error_ = "unable to find ID for global variable: " + builder_.Symbols().NameFor(var->Declaration()->symbol); return false; @@ -613,7 +614,8 @@ bool Builder::GenerateFunction(const ast::Function* func_ast) { return false; } - scope_stack_.push_scope(); + scope_stack_.Push(); + TINT_DEFER(scope_stack_.Pop()); auto definition_inst = Instruction{ spv::Op::OpFunction, @@ -636,7 +638,7 @@ bool Builder::GenerateFunction(const ast::Function* func_ast) { params.push_back(Instruction{spv::Op::OpFunctionParameter, {Operand::Int(param_type_id), param_op}}); - scope_stack_.set(param->Declaration()->symbol, param_id); + scope_stack_.Set(param->Declaration()->symbol, param_id); } push_function(Function{definition_inst, result_op(), std::move(params)}); @@ -656,8 +658,6 @@ bool Builder::GenerateFunction(const ast::Function* func_ast) { } } - scope_stack_.pop_scope(); - func_symbol_to_id_[func_ast->symbol] = func_id; return true; @@ -706,7 +706,7 @@ bool Builder::GenerateFunctionVariable(const ast::Variable* var) { error_ = "missing constructor for constant"; return false; } - scope_stack_.set(var->symbol, init_id); + scope_stack_.Set(var->symbol, init_id); spirv_id_to_variable_[init_id] = var; return true; } @@ -740,7 +740,7 @@ bool Builder::GenerateFunctionVariable(const ast::Variable* var) { } } - scope_stack_.set(var->symbol, var_id); + scope_stack_.Set(var->symbol, var_id); spirv_id_to_variable_[var_id] = var; return true; @@ -803,7 +803,7 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) { {Operand::Int(init_id), Operand::String(builder_.Symbols().NameFor(var->symbol))}); - scope_stack_.set_global(var->symbol, init_id); + scope_stack_.Set(var->symbol, init_id); spirv_id_to_variable_[init_id] = var; return true; } @@ -898,7 +898,7 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) { } } - scope_stack_.set_global(var->symbol, var_id); + scope_stack_.Set(var->symbol, var_id); spirv_id_to_variable_[var_id] = var; return true; } @@ -1173,14 +1173,12 @@ uint32_t Builder::GenerateAccessorExpression(const ast::Expression* expr) { uint32_t Builder::GenerateIdentifierExpression( const ast::IdentifierExpression* expr) { - uint32_t val = 0; - if (scope_stack_.get(expr->symbol, &val)) { - return val; + uint32_t val = scope_stack_.Get(expr->symbol); + if (val == 0) { + error_ = "unable to find variable with identifier: " + + builder_.Symbols().NameFor(expr->symbol); } - - error_ = "unable to find variable with identifier: " + - builder_.Symbols().NameFor(expr->symbol); - return 0; + return val; } uint32_t Builder::GenerateLoadIfNeeded(const sem::Type* type, uint32_t id) { @@ -2231,10 +2229,9 @@ uint32_t Builder::GenerateBinaryExpression(const ast::BinaryExpression* expr) { } bool Builder::GenerateBlockStatement(const ast::BlockStatement* stmt) { - scope_stack_.push_scope(); - auto result = GenerateBlockStatementWithoutScoping(stmt); - scope_stack_.pop_scope(); - return result; + scope_stack_.Push(); + TINT_DEFER(scope_stack_.Pop()); + return GenerateBlockStatementWithoutScoping(stmt); } bool Builder::GenerateBlockStatementWithoutScoping( @@ -3736,34 +3733,35 @@ bool Builder::GenerateLoopStatement(const ast::LoopStatement* stmt) { // We need variables from the body to be visible in the continuing block, so // manage scope outside of GenerateBlockStatement. - scope_stack_.push_scope(); + { + scope_stack_.Push(); + TINT_DEFER(scope_stack_.Pop()); - if (!GenerateBlockStatementWithoutScoping(stmt->body)) { - return false; - } - - // We only branch if the last element of the body didn't already branch. - if (!LastIsTerminator(stmt->body)) { - if (!push_function_inst(spv::Op::OpBranch, - {Operand::Int(continue_block_id)})) { + if (!GenerateBlockStatementWithoutScoping(stmt->body)) { return false; } - } - if (!GenerateLabel(continue_block_id)) { - return false; - } - if (stmt->continuing && !stmt->continuing->Empty()) { - continuing_stack_.emplace_back(stmt->continuing->Last(), loop_header_id, - merge_block_id); - if (!GenerateBlockStatementWithoutScoping(stmt->continuing)) { + // We only branch if the last element of the body didn't already branch. + if (!LastIsTerminator(stmt->body)) { + if (!push_function_inst(spv::Op::OpBranch, + {Operand::Int(continue_block_id)})) { + return false; + } + } + + if (!GenerateLabel(continue_block_id)) { return false; } - continuing_stack_.pop_back(); + if (stmt->continuing && !stmt->continuing->Empty()) { + continuing_stack_.emplace_back(stmt->continuing->Last(), loop_header_id, + merge_block_id); + if (!GenerateBlockStatementWithoutScoping(stmt->continuing)) { + return false; + } + continuing_stack_.pop_back(); + } } - scope_stack_.pop_scope(); - // Generate the backedge. TINT_ASSERT(Writer, !backedge_stack_.empty()); const Backedge& backedge = backedge_stack_.back();