diff --git a/src/tint/resolver/attribute_validation_test.cc b/src/tint/resolver/attribute_validation_test.cc index 2cb8905ec7..36221e204d 100644 --- a/src/tint/resolver/attribute_validation_test.cc +++ b/src/tint/resolver/attribute_validation_test.cc @@ -895,7 +895,8 @@ TEST_P(ArrayStrideTest, All) { << ", should_pass: " << params.should_pass; SCOPED_TRACE(ss.str()); - auto* arr = ty.array(Source{{12, 34}}, el_ty, 4_u, params.stride); + auto* arr = + ty.array(el_ty, 4_u, {create(Source{{12, 34}}, params.stride)}); GlobalVar("myarray", arr, ast::StorageClass::kPrivate); @@ -906,7 +907,7 @@ TEST_P(ArrayStrideTest, All) { EXPECT_EQ(r()->error(), "12:34 error: arrays decorated with the stride attribute must " "have a stride that is at least the size of the element type, " - "and be a multiple of the element type's alignment value."); + "and be a multiple of the element type's alignment value"); } } diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index b1961c3129..7eb4d28fad 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -2259,102 +2259,128 @@ sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) { } sem::Array* Resolver::Array(const ast::Array* arr) { - auto source = arr->source; - - auto* elem_type = Type(arr->type); - if (!elem_type) { + if (!arr->type) { + AddError("missing array element type", arr->source.End()); return nullptr; } - if (!validator_.IsPlain(elem_type)) { // Check must come before GetDefaultAlignAndSize() - AddError(sem_.TypeNameOf(elem_type) + " cannot be used as an element type of an array", - source); - return nullptr; - } - - uint32_t el_align = elem_type->Align(); - uint32_t el_size = elem_type->Size(); - - if (!validator_.NoDuplicateAttributes(arr->attributes)) { + auto* el_ty = Type(arr->type); + if (!el_ty) { return nullptr; } // Look for explicit stride via @stride(n) attribute uint32_t explicit_stride = 0; - for (auto* attr : arr->attributes) { + if (!ArrayAttributes(arr->attributes, el_ty, explicit_stride)) { + return nullptr; + } + + uint32_t el_count = 0; // sem::Array uses a size of 0 for a runtime-sized array. + + // Evaluate the constant array size expression. + if (auto* count_expr = arr->count) { + if (auto count = ArrayCount(count_expr)) { + el_count = count.Get(); + } else { + return nullptr; + } + } + + auto* out = Array(arr->source, el_ty, el_count, explicit_stride); + if (out == nullptr) { + return nullptr; + } + + if (el_ty->Is()) { + atomic_composite_info_.emplace(out, arr->type->source); + } else { + auto found = atomic_composite_info_.find(el_ty); + if (found != atomic_composite_info_.end()) { + atomic_composite_info_.emplace(out, found->second); + } + } + + return out; +} + +utils::Result Resolver::ArrayCount(const ast::Expression* count_expr) { + // Evaluate the constant array size expression. + const auto* count_sem = Materialize(Expression(count_expr)); + if (!count_sem) { + return utils::Failure; + } + + auto* count_val = count_sem->ConstantValue(); + if (!count_val) { + AddError("array size must evaluate to a constant integer expression", count_expr->source); + return utils::Failure; + } + + if (auto* ty = count_val->Type(); !ty->is_integer_scalar()) { + AddError("array size must evaluate to a constant integer expression, but is type '" + + builder_->FriendlyName(ty) + "'", + count_expr->source); + return utils::Failure; + } + + int64_t count = count_val->As(); + if (count < 1) { + AddError("array size (" + std::to_string(count) + ") must be greater than 0", + count_expr->source); + return utils::Failure; + } + + return static_cast(count); +} + +bool Resolver::ArrayAttributes(const ast::AttributeList& attributes, + const sem::Type* el_ty, + uint32_t& explicit_stride) { + if (!validator_.NoDuplicateAttributes(attributes)) { + return false; + } + + for (auto* attr : attributes) { Mark(attr); if (auto* sd = attr->As()) { explicit_stride = sd->stride; - if (!validator_.ArrayStrideAttribute(sd, el_size, el_align, source)) { - return nullptr; + if (!validator_.ArrayStrideAttribute(sd, el_ty->Size(), el_ty->Align())) { + return false; } continue; } AddError("attribute is not valid for array types", attr->source); - return nullptr; + return false; } - // Calculate implicit stride - uint64_t implicit_stride = utils::RoundUp(el_align, el_size); + return true; +} +sem::Array* Resolver::Array(const Source& source, + const sem::Type* el_ty, + uint32_t el_count, + uint32_t explicit_stride) { + uint32_t el_align = el_ty->Align(); + uint32_t el_size = el_ty->Size(); + uint64_t implicit_stride = el_size ? utils::RoundUp(el_align, el_size) : 0; uint64_t stride = explicit_stride ? explicit_stride : implicit_stride; - int64_t count = 0; // sem::Array uses a size of 0 for a runtime-sized array. - - // Evaluate the constant array size expression. - if (auto* count_expr = arr->count) { - const auto* count_sem = Materialize(Expression(count_expr)); - if (!count_sem) { - return nullptr; - } - - auto* count_val = count_sem->ConstantValue(); - if (!count_val) { - AddError("array size must evaluate to a constant integer expression", - count_expr->source); - return nullptr; - } - - if (auto* ty = count_val->Type(); !ty->is_integer_scalar()) { - AddError("array size must evaluate to a constant integer expression, but is type '" + - builder_->FriendlyName(ty) + "'", - count_expr->source); - return nullptr; - } - - count = count_val->As(); - if (count < 1) { - AddError("array size (" + std::to_string(count) + ") must be greater than 0", - count_expr->source); - return nullptr; - } - } - - auto size = std::max(static_cast(count), 1u) * stride; + auto size = std::max(el_count, 1u) * stride; if (size > std::numeric_limits::max()) { std::stringstream msg; msg << "array size (0x" << std::hex << size << ") must not exceed 0xffffffff bytes"; - AddError(msg.str(), arr->source); + AddError(msg.str(), source); return nullptr; } - auto* out = builder_->create( - elem_type, static_cast(count), el_align, static_cast(size), - static_cast(stride), static_cast(implicit_stride)); + auto* out = builder_->create(el_ty, el_count, el_align, static_cast(size), + static_cast(stride), + static_cast(implicit_stride)); if (!validator_.Array(out, source)) { return nullptr; } - if (elem_type->Is()) { - atomic_composite_info_.emplace(out, arr->type->source); - } else { - auto found = atomic_composite_info_.find(elem_type); - if (found != atomic_composite_info_.end()) { - atomic_composite_info_.emplace(out, found->second); - } - } - return out; } diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index ccd3895eaf..5d38fef74b 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -286,14 +286,38 @@ class Resolver { /// @returns the resolved semantic type sem::Type* TypeDecl(const ast::TypeDecl* named_type); - /// Builds and returns the semantic information for the array `arr`. - /// This method does not mark the ast::Array node, nor attach the generated - /// semantic information to the AST node. - /// @returns the semantic Array information, or nullptr if an error is - /// raised. + /// Builds and returns the semantic information for the AST array `arr`. + /// This method does not mark the ast::Array node, nor attach the generated semantic information + /// to the AST node. + /// @returns the semantic Array information, or nullptr if an error is raised. /// @param arr the Array to get semantic information for sem::Array* Array(const ast::Array* arr); + /// Resolves and validates the expression used as the count parameter of an array. + /// @param count_expr the expression used as the second template parameter to an array<>. + /// @returns the number of elements in the array. + utils::Result ArrayCount(const ast::Expression* count_expr); + + /// Resolves and validates the attributes on an array. + /// @param attributes the attributes on the array type. + /// @param el_ty the element type of the array. + /// @param explicit_stride assigned the specified stride of the array in bytes. + /// @returns true on success, false on failure + bool ArrayAttributes(const ast::AttributeList& attributes, + const sem::Type* el_ty, + uint32_t& explicit_stride); + + /// Builds and returns the semantic information for an array. + /// @returns the semantic Array information, or nullptr if an error is raised. + /// @param source the source of the array declaration + /// @param el_ty the Array element type + /// @param el_count the number of elements in the array. Zero means runtime-sized. + /// @param explicit_stride the explicit byte stride of the array. Zero means implicit stride. + sem::Array* Array(const Source& source, + const sem::Type* el_ty, + uint32_t el_count, + uint32_t explicit_stride); + /// Builds and returns the semantic information for the alias `alias`. /// This method does not mark the ast::Alias node, nor attach the generated /// semantic information to the AST node. diff --git a/src/tint/resolver/type_constructor_validation_test.cc b/src/tint/resolver/type_constructor_validation_test.cc index 0fb01bc90d..aa0d65a331 100644 --- a/src/tint/resolver/type_constructor_validation_test.cc +++ b/src/tint/resolver/type_constructor_validation_test.cc @@ -527,9 +527,7 @@ TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Array_type_Mismat WrapInFunction(tc); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - "12:34 error: type in array constructor does not match array type: " - "expected 'u32', found 'f32'"); + EXPECT_EQ(r()->error(), R"(12:34 error: 'f32' cannot be used to construct an array of 'u32')"); } TEST_F(ResolverTypeConstructorValidationTest, @@ -539,9 +537,7 @@ TEST_F(ResolverTypeConstructorValidationTest, WrapInFunction(tc); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - "12:34 error: type in array constructor does not match array type: " - "expected 'f32', found 'i32'"); + EXPECT_EQ(r()->error(), R"(12:34 error: 'i32' cannot be used to construct an array of 'f32')"); } TEST_F(ResolverTypeConstructorValidationTest, @@ -552,9 +548,7 @@ TEST_F(ResolverTypeConstructorValidationTest, WrapInFunction(tc); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - "12:34 error: type in array constructor does not match array type: " - "expected 'u32', found 'i32'"); + EXPECT_EQ(r()->error(), R"(12:34 error: 'i32' cannot be used to construct an array of 'u32')"); } TEST_F(ResolverTypeConstructorValidationTest, @@ -564,8 +558,7 @@ TEST_F(ResolverTypeConstructorValidationTest, WrapInFunction(tc); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: type in array constructor does not match array type: " - "expected 'i32', found 'vec2'"); + R"(12:34 error: 'vec2' cannot be used to construct an array of 'i32')"); } TEST_F(ResolverTypeConstructorValidationTest, @@ -579,8 +572,7 @@ TEST_F(ResolverTypeConstructorValidationTest, EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: type in array constructor does not match array type: " - "expected 'vec3', found 'vec3'"); + R"(12:34 error: 'vec3' cannot be used to construct an array of 'vec3')"); } TEST_F(ResolverTypeConstructorValidationTest, @@ -594,8 +586,7 @@ TEST_F(ResolverTypeConstructorValidationTest, EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: type in array constructor does not match array type: " - "expected 'vec3', found 'vec3'"); + R"(12:34 error: 'vec3' cannot be used to construct an array of 'vec3')"); } TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_ArrayOfArray_SubElemSizeMismatch) { @@ -607,9 +598,9 @@ TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_ArrayOfArray_SubE WrapInFunction(t); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - "12:34 error: type in array constructor does not match array type: " - "expected 'array', found 'array'"); + EXPECT_EQ( + r()->error(), + R"(12:34 error: 'array' cannot be used to construct an array of 'array')"); } TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_ArrayOfArray_SubElemTypeMismatch) { @@ -621,9 +612,9 @@ TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_ArrayOfArray_SubE WrapInFunction(t); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - "12:34 error: type in array constructor does not match array type: " - "expected 'array', found 'array'"); + EXPECT_EQ( + r()->error(), + R"(12:34 error: 'array' cannot be used to construct an array of 'array')"); } TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Array_TooFewElements) { @@ -656,7 +647,7 @@ TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Array_Runtime) { WrapInFunction(tc); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "error: cannot init a runtime-sized array"); + EXPECT_EQ(r()->error(), "error: cannot construct a runtime-sized array"); } TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Array_RuntimeZeroValue) { @@ -665,7 +656,7 @@ TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Array_RuntimeZero WrapInFunction(tc); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "error: cannot init a runtime-sized array"); + EXPECT_EQ(r()->error(), "error: cannot construct a runtime-sized array"); } } // namespace ArrayConstructor diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index 940e154a3d..d854447f0f 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -1875,17 +1875,16 @@ bool Validator::ArrayConstructor(const ast::CallExpression* ctor, for (auto* value : values) { auto* value_ty = sem_.TypeOf(value)->UnwrapRef(); if (value_ty != elem_ty) { - AddError( - "type in array constructor does not match array type: " - "expected '" + - sem_.TypeNameOf(elem_ty) + "', found '" + sem_.TypeNameOf(value_ty) + "'", - value->source); + AddError("'" + sem_.TypeNameOf(value_ty) + + "' cannot be used to construct an array of '" + sem_.TypeNameOf(elem_ty) + + "'", + value->source); return false; } } if (array_type->IsRuntimeSized()) { - AddError("cannot init a runtime-sized array", ctor->source); + AddError("cannot construct a runtime-sized array", ctor->source); return false; } else if (!elem_ty->IsConstructible()) { AddError("array constructor has non-constructible element type", ctor->source); @@ -2011,6 +2010,11 @@ bool Validator::PipelineStages(const std::vector& entry_points) bool Validator::Array(const sem::Array* arr, const Source& source) const { auto* el_ty = arr->ElemType(); + if (!IsPlain(el_ty)) { + AddError(sem_.TypeNameOf(el_ty) + " cannot be used as an element type of an array", source); + return false; + } + if (!IsFixedFootprint(el_ty)) { AddError("an array element type cannot contain a runtime-sized array", source); return false; @@ -2020,8 +2024,7 @@ bool Validator::Array(const sem::Array* arr, const Source& source) const { bool Validator::ArrayStrideAttribute(const ast::StrideAttribute* attr, uint32_t el_size, - uint32_t el_align, - const Source& source) const { + uint32_t el_align) const { auto stride = attr->stride; bool is_valid_stride = (stride >= el_size) && (stride >= el_align) && (stride % el_align == 0); if (!is_valid_stride) { @@ -2032,8 +2035,8 @@ bool Validator::ArrayStrideAttribute(const ast::StrideAttribute* attr, AddError( "arrays decorated with the stride attribute must have a stride " "that is at least the size of the element type, and be a multiple " - "of the element type's alignment value.", - source); + "of the element type's alignment value", + attr->source); return false; } return true; diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h index e9f9ceac4c..385e020f83 100644 --- a/src/tint/resolver/validator.h +++ b/src/tint/resolver/validator.h @@ -131,12 +131,10 @@ class Validator { /// @param attr the stride attribute to validate /// @param el_size the element size /// @param el_align the element alignment - /// @param source the source of the attribute /// @returns true on success, false otherwise bool ArrayStrideAttribute(const ast::StrideAttribute* attr, uint32_t el_size, - uint32_t el_align, - const Source& source) const; + uint32_t el_align) const; /// Validates an atomic /// @param a the atomic ast node to validate