diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 16cbeafc0e..58aaf57a7d 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -1802,33 +1802,51 @@ bool Resolver::Function(ast::Function* func) { if (auto* workgroup = ast::GetDecoration(func->decorations())) { auto values = workgroup->values(); - auto is_i32 = false; - auto is_less_than_one = true; + auto any_i32 = false; + auto any_u32 = false; for (int i = 0; i < 3; i++) { // Each argument to this decoration can either be a literal, an // identifier for a module-scope constants, or nullptr if not specified. - if (!values[i]) { + auto* expr = values[i]; + if (!expr) { // Not specified, just use the default. continue; } - Mark(values[i]); + Mark(expr); + if (!Expression(expr)) { + return false; + } - uint32_t value = 0; - if (auto* ident = values[i]->As()) { + constexpr const char* kErrBadType = + "workgroup_size parameter must be either literal or module-scope " + "constant of type i32 or u32"; + constexpr const char* kErrInconsistentType = + "workgroup_size parameters must be of the same type, either i32 " + "or u32"; + + auto* ty = TypeOf(expr); + bool is_i32 = ty->UnwrapRef()->Is(); + bool is_u32 = ty->UnwrapRef()->Is(); + if (!is_i32 && !is_u32) { + AddError(kErrBadType, expr->source()); + return false; + } + + any_i32 = any_i32 || is_i32; + any_u32 = any_u32 || is_u32; + if (any_i32 && any_u32) { + AddError(kErrInconsistentType, expr->source()); + return false; + } + + if (auto* ident = expr->As()) { // We have an identifier of a module-scope constant. - if (!Identifier(ident)) { - return false; - } - - VariableInfo* var; + VariableInfo* var = nullptr; if (!variable_stack_.get(ident->symbol(), &var) || - !(var->declaration->is_const() && var->type->is_integer_scalar())) { - AddError( - "workgroup_size parameter must be either literal or module-scope " - "constant of type i32 or u32", - values[i]->source()); + !(var->declaration->is_const())) { + AddError(kErrBadType, expr->source()); return false; } @@ -1838,75 +1856,30 @@ bool Resolver::Function(ast::Function* func) { info->workgroup_size[i].overridable_const = var->declaration; } - auto* constructor = var->declaration->constructor(); - if (constructor) { - // Resolve the constructor expression to use as the default value. - auto val = ConstantValueOf(constructor); - if (!val.IsValid() || !val.Type()->is_integer_scalar()) { - TINT_ICE(Resolver, diagnostics_) - << "failed to resolve workgroup_size constant value"; - return false; - } - - if (i == 0) { - is_i32 = val.Type()->Is(); - } else { - if (is_i32 != val.Type()->Is()) { - AddError( - "workgroup_size parameters must be of the same type, " - "either i32 or u32", - values[i]->source()); - return false; - } - } - is_less_than_one = - is_i32 ? val.Elements()[0].i32 < 1 : val.Elements()[0].u32 < 1; - - value = is_i32 ? static_cast(val.Elements()[0].i32) - : val.Elements()[0].u32; - } else { + expr = var->declaration->constructor(); + if (!expr) { // No constructor means this value must be overriden by the user. info->workgroup_size[i].value = 0; continue; } - } else if (auto* scalar = - values[i]->As()) { - // We have a literal. - Mark(scalar->literal()); - auto* literal = scalar->literal()->As(); - if (!literal) { - AddError( - "workgroup_size parameter must be either literal or module-scope " - "constant of type i32 or u32", - values[i]->source()); - return false; - } - - if (i == 0) { - is_i32 = literal->Is(); - } else { - if (literal->Is() != is_i32) { - AddError( - "workgroup_size parameters must be of the same type, " - "either i32 or u32", - values[i]->source()); - return false; - } - } - - is_less_than_one = - is_i32 ? literal->value_as_i32() < 1 : literal->value_as_u32() < 1; - value = is_i32 ? static_cast(literal->value_as_i32()) - : literal->value_as_u32(); } + auto val = ConstantValueOf(expr); + if (!val) { + TINT_ICE(Resolver, diagnostics_) + << "could not resolve constant workgroup_size constant value"; + continue; + } // Validate and set the default value for this dimension. - if (is_less_than_one) { + if (is_i32 ? val.Elements()[0].i32 < 1 : val.Elements()[0].u32 < 1) { AddError("workgroup_size parameter must be at least 1", values[i]->source()); return false; } - info->workgroup_size[i].value = value; + + info->workgroup_size[i].value = + is_i32 ? static_cast(val.Elements()[0].i32) + : val.Elements()[0].u32; } }