Clean up the ScopeStack interface

There's no need for the ScopeStack to include 'global' information. This
is easily obtainable from the element type.
Replace the get-by-reference, with a simpler return value.

Change-Id: Ic6f4c0f656a2019417d68ffb3fe85ba8343ad15e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68403
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton 2021-11-04 22:29:22 +00:00
parent 307eff0da4
commit 89d8b2b7a5
4 changed files with 85 additions and 167 deletions

View File

@ -656,7 +656,7 @@ bool Resolver::GlobalVariable(const ast::Variable* var) {
if (!info) {
return false;
}
variable_stack_.set_global(var->symbol, info);
variable_stack_.Set(var->symbol, info);
if (!var->is_const && info->storage_class == ast::StorageClass::kNone) {
AddError("global variables must have a storage class", var->source);
@ -1769,7 +1769,7 @@ bool Resolver::Function(const ast::Function* func) {
TINT_SCOPED_ASSIGNMENT(current_function_, info);
variable_stack_.push_scope();
variable_stack_.Push();
uint32_t parameter_index = 0;
std::unordered_map<Symbol, Source> parameter_names;
for (auto* param : func->params) {
@ -1798,7 +1798,7 @@ bool Resolver::Function(const ast::Function* func) {
return false;
}
variable_stack_.set(param->symbol, param_info);
variable_stack_.Set(param->symbol, param_info);
info->parameters.emplace_back(param_info);
if (!ApplyStorageClassUsageToType(param->declared_storage_class,
@ -1887,7 +1887,7 @@ bool Resolver::Function(const ast::Function* func) {
return false;
}
}
variable_stack_.pop_scope();
variable_stack_.Pop();
for (auto* deco : func->decorations) {
Mark(deco);
@ -1952,9 +1952,8 @@ bool Resolver::Function(const ast::Function* func) {
if (auto* ident = expr->As<ast::IdentifierExpression>()) {
// We have an identifier of a module-scope constant.
VariableInfo* var = nullptr;
if (!variable_stack_.get(ident->symbol, &var) ||
!(var->declaration->is_const)) {
VariableInfo* var = variable_stack_.Get(ident->symbol);
if (!var || !(var->declaration->is_const)) {
AddError(kErrBadType, expr->source);
return false;
}
@ -2635,8 +2634,8 @@ bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
if (param->declaration->type->Is<ast::Pointer>()) {
auto is_valid = false;
if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
VariableInfo* var;
if (!variable_stack_.get(ident_expr->symbol, &var)) {
VariableInfo* var = variable_stack_.Get(ident_expr->symbol);
if (!var) {
TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
return false;
}
@ -2647,8 +2646,8 @@ bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
if (unary->op == ast::UnaryOp::kAddressOf) {
if (auto* ident_unary =
unary->expr->As<ast::IdentifierExpression>()) {
VariableInfo* var;
if (!variable_stack_.get(ident_unary->symbol, &var)) {
VariableInfo* var = variable_stack_.Get(ident_unary->symbol);
if (!var) {
TINT_ICE(Resolver, diagnostics_)
<< "failed to resolve identifier";
return false;
@ -2987,8 +2986,7 @@ bool Resolver::ValidateScalarConstructor(
bool Resolver::Identifier(const ast::IdentifierExpression* expr) {
auto symbol = expr->symbol;
VariableInfo* var;
if (variable_stack_.get(symbol, &var)) {
if (VariableInfo* var = variable_stack_.Get(symbol)) {
SetExprInfo(expr, var->type, var->type_name);
var->users.push_back(expr);
@ -3485,7 +3483,7 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
}
}
variable_stack_.set(var->symbol, info);
variable_stack_.Set(var->symbol, info);
if (current_block_) { // Not all statements are inside a block
current_block_->AddDecl(var);
}
@ -3909,9 +3907,8 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
if (auto* ident = count_expr->As<ast::IdentifierExpression>()) {
// Make sure the identifier is a non-overridable module-scope constant.
VariableInfo* var = nullptr;
bool is_global = false;
if (!variable_stack_.get(ident->symbol, &var, &is_global) || !is_global ||
VariableInfo* var = variable_stack_.Get(ident->symbol);
if (!var || var->kind != VariableKind::kGlobal ||
!var->declaration->is_const) {
AddError("array size identifier must be a module-scope constant",
size_source);
@ -4489,8 +4486,7 @@ bool Resolver::ValidateAssignment(const ast::AssignmentStatement* a) {
auto const* lhs_type = TypeOf(a->lhs);
if (auto* ident = a->lhs->As<ast::IdentifierExpression>()) {
VariableInfo* var;
if (variable_stack_.get(ident->symbol, &var)) {
if (VariableInfo* var = variable_stack_.Get(ident->symbol)) {
if (var->kind == VariableKind::kParameter) {
AddError("cannot assign to function parameter", a->lhs->source);
AddNote("'" + builder_->Symbols().NameFor(ident->symbol) +
@ -4542,10 +4538,8 @@ bool Resolver::ValidateNoDuplicateDefinition(Symbol sym,
const Source& source,
bool check_global_scope_only) {
if (check_global_scope_only) {
bool is_global = false;
VariableInfo* var;
if (variable_stack_.get(sym, &var, &is_global)) {
if (is_global) {
if (VariableInfo* var = variable_stack_.Get(sym)) {
if (var->kind == VariableKind::kGlobal) {
AddError("redefinition of '" + builder_->Symbols().NameFor(sym) + "'",
source);
AddNote("previous definition is here", var->declaration->source);
@ -4560,8 +4554,7 @@ bool Resolver::ValidateNoDuplicateDefinition(Symbol sym,
return false;
}
} else {
VariableInfo* var;
if (variable_stack_.get(sym, &var)) {
if (VariableInfo* var = variable_stack_.Get(sym)) {
AddError("redefinition of '" + builder_->Symbols().NameFor(sym) + "'",
source);
AddNote("previous definition is here", var->declaration->source);
@ -4635,10 +4628,10 @@ bool Resolver::Scope(sem::CompoundStatement* stmt, F&& callback) {
current_statement_ = stmt;
current_compound_statement_ = stmt;
current_block_ = stmt->As<sem::BlockStatement>();
variable_stack_.push_scope();
variable_stack_.Push();
TINT_DEFER({
TINT_DEFER(variable_stack_.pop_scope());
TINT_DEFER(variable_stack_.Pop());
current_block_ = prev_current_block;
current_compound_statement_ = prev_current_compound_statement;
current_statement_ = prev_current_statement;

View File

@ -36,60 +36,33 @@ class ScopeStack {
~ScopeStack() = default;
/// Push a new scope on to the stack
void push_scope() { stack_.push_back({}); }
void Push() { stack_.push_back({}); }
/// Pop the scope off the top of the stack
void pop_scope() {
void Pop() {
if (stack_.size() > 1) {
stack_.pop_back();
}
}
/// Set a global variable in the stack
/// Assigns the value into the top most scope of the stack
/// @param symbol the symbol of the variable
/// @param val the value
void set_global(const Symbol& symbol, T val) { stack_[0][symbol] = val; }
void Set(const Symbol& symbol, T val) { stack_.back()[symbol] = val; }
/// Sets variable into the top most scope of the stack
/// @param symbol the symbol of the variable
/// @param val the value
void set(const Symbol& symbol, T val) { stack_.back()[symbol] = val; }
/// Checks for the given `symbol` in the stack
/// Retrieves a value from the stack
/// @param symbol the symbol to look for
/// @returns true if the stack contains `symbol`
bool has(const Symbol& symbol) const { return get(symbol, nullptr); }
/// Retrieves a given variable from the stack
/// @param symbol the symbol to look for
/// @param ret where to place the value
/// @returns true if the symbol was successfully found, false otherwise
bool get(const Symbol& symbol, T* ret) const {
return get(symbol, ret, nullptr);
}
/// Retrieves a given variable from the stack
/// @param symbol the symbol to look for
/// @param ret where to place the value
/// @param is_global set true if the symbol references a global variable
/// otherwise unchanged
/// @returns true if the symbol was successfully found, false otherwise
bool get(const Symbol& symbol, T* ret, bool* is_global) const {
/// @returns the variable, or the zero initializer if the value was not found
T Get(const Symbol& symbol) const {
for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
auto& map = *iter;
auto val = map.find(symbol);
if (val != map.end()) {
if (ret) {
*ret = val->second;
}
if (is_global && iter == stack_.rend() - 1) {
*is_global = true;
}
return true;
return val->second;
}
}
return false;
return T{};
}
private:

View File

@ -21,78 +21,32 @@ namespace {
class ScopeStackTest : public ProgramBuilder, public testing::Test {};
TEST_F(ScopeStackTest, Global) {
TEST_F(ScopeStackTest, Get) {
ScopeStack<uint32_t> s;
Symbol sym(1, ID());
s.set_global(sym, 5);
Symbol a(1, ID());
Symbol b(3, ID());
s.Push();
s.Set(a, 5u);
s.Set(b, 10u);
uint32_t val = 0;
EXPECT_TRUE(s.get(sym, &val));
EXPECT_EQ(val, 5u);
}
EXPECT_EQ(s.Get(a), 5u);
EXPECT_EQ(s.Get(b), 10u);
TEST_F(ScopeStackTest, Global_SetWithPointer) {
auto* v = Var("my_var", ty.f32(), ast::StorageClass::kNone);
ScopeStack<const ast::Variable*> s;
s.set_global(v->symbol, v);
s.Push();
const ast::Variable* v2 = nullptr;
EXPECT_TRUE(s.get(v->symbol, &v2));
EXPECT_EQ(v2->symbol, v->symbol);
}
s.Set(a, 15u);
EXPECT_EQ(s.Get(a), 15u);
EXPECT_EQ(s.Get(b), 10u);
TEST_F(ScopeStackTest, Global_CanNotPop) {
ScopeStack<uint32_t> s;
Symbol sym(1, ID());
s.set_global(sym, 5);
s.pop_scope();
uint32_t val = 0;
EXPECT_TRUE(s.get(sym, &val));
EXPECT_EQ(val, 5u);
}
TEST_F(ScopeStackTest, Scope) {
ScopeStack<uint32_t> s;
Symbol sym(1, ID());
s.push_scope();
s.set(sym, 5);
uint32_t val = 0;
EXPECT_TRUE(s.get(sym, &val));
EXPECT_EQ(val, 5u);
s.Pop();
EXPECT_EQ(s.Get(a), 5u);
EXPECT_EQ(s.Get(b), 10u);
}
TEST_F(ScopeStackTest, Get_MissingSymbol) {
ScopeStack<uint32_t> s;
Symbol sym(1, ID());
uint32_t ret = 0;
EXPECT_FALSE(s.get(sym, &ret));
EXPECT_EQ(ret, 0u);
}
TEST_F(ScopeStackTest, Has) {
ScopeStack<uint32_t> s;
Symbol sym(1, ID());
Symbol sym2(2, ID());
s.set_global(sym2, 3);
s.push_scope();
s.set(sym, 5);
EXPECT_TRUE(s.has(sym));
EXPECT_TRUE(s.has(sym2));
}
TEST_F(ScopeStackTest, ReturnsScopeBeforeGlobalFirst) {
ScopeStack<uint32_t> s;
Symbol sym(1, ID());
s.set_global(sym, 3);
s.push_scope();
s.set(sym, 5);
uint32_t ret;
EXPECT_TRUE(s.get(sym, &ret));
EXPECT_EQ(ret, 5u);
EXPECT_EQ(s.Get(sym), 0u);
}
} // namespace

View File

@ -46,6 +46,7 @@
#include "src/transform/simplify.h"
#include "src/transform/vectorize_scalar_matrix_constructors.h"
#include "src/transform/zero_init_workgroup_memory.h"
#include "src/utils/defer.h"
#include "src/utils/get_or_create.h"
#include "src/writer/append_vector.h"
@ -468,8 +469,8 @@ bool Builder::GenerateEntryPoint(const ast::Function* func, uint32_t id) {
continue;
}
uint32_t var_id;
if (!scope_stack_.get(var->Declaration()->symbol, &var_id)) {
uint32_t var_id = scope_stack_.Get(var->Declaration()->symbol);
if (var_id == 0) {
error_ = "unable to find ID for global variable: " +
builder_.Symbols().NameFor(var->Declaration()->symbol);
return false;
@ -613,7 +614,8 @@ bool Builder::GenerateFunction(const ast::Function* func_ast) {
return false;
}
scope_stack_.push_scope();
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
auto definition_inst = Instruction{
spv::Op::OpFunction,
@ -636,7 +638,7 @@ bool Builder::GenerateFunction(const ast::Function* func_ast) {
params.push_back(Instruction{spv::Op::OpFunctionParameter,
{Operand::Int(param_type_id), param_op}});
scope_stack_.set(param->Declaration()->symbol, param_id);
scope_stack_.Set(param->Declaration()->symbol, param_id);
}
push_function(Function{definition_inst, result_op(), std::move(params)});
@ -656,8 +658,6 @@ bool Builder::GenerateFunction(const ast::Function* func_ast) {
}
}
scope_stack_.pop_scope();
func_symbol_to_id_[func_ast->symbol] = func_id;
return true;
@ -706,7 +706,7 @@ bool Builder::GenerateFunctionVariable(const ast::Variable* var) {
error_ = "missing constructor for constant";
return false;
}
scope_stack_.set(var->symbol, init_id);
scope_stack_.Set(var->symbol, init_id);
spirv_id_to_variable_[init_id] = var;
return true;
}
@ -740,7 +740,7 @@ bool Builder::GenerateFunctionVariable(const ast::Variable* var) {
}
}
scope_stack_.set(var->symbol, var_id);
scope_stack_.Set(var->symbol, var_id);
spirv_id_to_variable_[var_id] = var;
return true;
@ -803,7 +803,7 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
{Operand::Int(init_id),
Operand::String(builder_.Symbols().NameFor(var->symbol))});
scope_stack_.set_global(var->symbol, init_id);
scope_stack_.Set(var->symbol, init_id);
spirv_id_to_variable_[init_id] = var;
return true;
}
@ -898,7 +898,7 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
}
}
scope_stack_.set_global(var->symbol, var_id);
scope_stack_.Set(var->symbol, var_id);
spirv_id_to_variable_[var_id] = var;
return true;
}
@ -1173,14 +1173,12 @@ uint32_t Builder::GenerateAccessorExpression(const ast::Expression* expr) {
uint32_t Builder::GenerateIdentifierExpression(
const ast::IdentifierExpression* expr) {
uint32_t val = 0;
if (scope_stack_.get(expr->symbol, &val)) {
return val;
uint32_t val = scope_stack_.Get(expr->symbol);
if (val == 0) {
error_ = "unable to find variable with identifier: " +
builder_.Symbols().NameFor(expr->symbol);
}
error_ = "unable to find variable with identifier: " +
builder_.Symbols().NameFor(expr->symbol);
return 0;
return val;
}
uint32_t Builder::GenerateLoadIfNeeded(const sem::Type* type, uint32_t id) {
@ -2231,10 +2229,9 @@ uint32_t Builder::GenerateBinaryExpression(const ast::BinaryExpression* expr) {
}
bool Builder::GenerateBlockStatement(const ast::BlockStatement* stmt) {
scope_stack_.push_scope();
auto result = GenerateBlockStatementWithoutScoping(stmt);
scope_stack_.pop_scope();
return result;
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
return GenerateBlockStatementWithoutScoping(stmt);
}
bool Builder::GenerateBlockStatementWithoutScoping(
@ -3736,34 +3733,35 @@ bool Builder::GenerateLoopStatement(const ast::LoopStatement* stmt) {
// We need variables from the body to be visible in the continuing block, so
// manage scope outside of GenerateBlockStatement.
scope_stack_.push_scope();
{
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
if (!GenerateBlockStatementWithoutScoping(stmt->body)) {
return false;
}
// We only branch if the last element of the body didn't already branch.
if (!LastIsTerminator(stmt->body)) {
if (!push_function_inst(spv::Op::OpBranch,
{Operand::Int(continue_block_id)})) {
if (!GenerateBlockStatementWithoutScoping(stmt->body)) {
return false;
}
}
if (!GenerateLabel(continue_block_id)) {
return false;
}
if (stmt->continuing && !stmt->continuing->Empty()) {
continuing_stack_.emplace_back(stmt->continuing->Last(), loop_header_id,
merge_block_id);
if (!GenerateBlockStatementWithoutScoping(stmt->continuing)) {
// We only branch if the last element of the body didn't already branch.
if (!LastIsTerminator(stmt->body)) {
if (!push_function_inst(spv::Op::OpBranch,
{Operand::Int(continue_block_id)})) {
return false;
}
}
if (!GenerateLabel(continue_block_id)) {
return false;
}
continuing_stack_.pop_back();
if (stmt->continuing && !stmt->continuing->Empty()) {
continuing_stack_.emplace_back(stmt->continuing->Last(), loop_header_id,
merge_block_id);
if (!GenerateBlockStatementWithoutScoping(stmt->continuing)) {
return false;
}
continuing_stack_.pop_back();
}
}
scope_stack_.pop_scope();
// Generate the backedge.
TINT_ASSERT(Writer, !backedge_stack_.empty());
const Backedge& backedge = backedge_stack_.back();