diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index fbf443c303..057433ae58 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -410,6 +410,7 @@ libtint_source_set("libtint_core_all_src") { "sem/if_statement.h", "sem/info.h", "sem/loop_statement.h", + "sem/materialize.h", "sem/matrix.h", "sem/module.h", "sem/multisampled_texture.h", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 85ec3dbab7..c194f845fe 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -774,6 +774,7 @@ if(TINT_BUILD_TESTS) resolver/intrinsic_table_test.cc resolver/is_host_shareable_test.cc resolver/is_storeable_test.cc + resolver/materialize_test.cc resolver/pipeline_overridable_constant_test.cc resolver/ptr_ref_test.cc resolver/ptr_ref_validation_test.cc diff --git a/src/tint/resolver/intrinsic_table.h b/src/tint/resolver/intrinsic_table.h index 5a8985fea0..312f952205 100644 --- a/src/tint/resolver/intrinsic_table.h +++ b/src/tint/resolver/intrinsic_table.h @@ -43,17 +43,17 @@ class IntrinsicTable { struct UnaryOperator { /// The result type of the unary operator const sem::Type* result; - /// The type of the arg of the unary operator - const sem::Type* arg; + /// The type of the parameter of the unary operator + const sem::Type* parameter; }; /// BinaryOperator describes a resolved binary operator struct BinaryOperator { /// The result type of the binary operator const sem::Type* result; - /// The type of LHS of the binary operator + /// The type of LHS parameter of the binary operator const sem::Type* lhs; - /// The type of RHS of the binary operator + /// The type of RHS parameter of the binary operator const sem::Type* rhs; }; diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc new file mode 100644 index 0000000000..e2e1a02ff1 --- /dev/null +++ b/src/tint/resolver/materialize_test.cc @@ -0,0 +1,391 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/sem/materialize.h" + +#include "src/tint/resolver/resolver.h" +#include "src/tint/resolver/resolver_test_helper.h" +#include "src/tint/sem/test_helper.h" + +#include "gmock/gmock.h" + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::resolver { +namespace { + +using AFloatV = builder::vec<3, AFloat>; +using AFloatM = builder::mat<3, 2, AFloat>; +using AIntV = builder::vec<3, AInt>; +using f32V = builder::vec<3, f32>; +using f16V = builder::vec<3, f16>; +using i32V = builder::vec<3, i32>; +using u32V = builder::vec<3, u32>; +using f32M = builder::mat<3, 2, f32>; + +//////////////////////////////////////////////////////////////////////////////// +// MaterializeTests +//////////////////////////////////////////////////////////////////////////////// +namespace MaterializeTests { + +// How should the materialization occur? +enum class Method { + // var a : T = literal; + kVar, + + // let a : T = literal; + kLet, + + // fn F(v : T) {} + // fn x() { + // F(literal); + // } + kFnArg, + + // min(target_expr, literal); + kBuiltinArg, + + // fn F() : T { + // return literal; + // } + kReturn, + + // array(literal); + kArray, + + // struct S { + // v : T + // }; + // fn x() { + // _ = S(literal) + // } + kStruct, + + // target_expr + literal + kBinaryOp, + + // switch (literal) { + // case target_expr: {} + // default: {} + // } + kSwitchCond, + + // switch (target_expr) { + // case literal: {} + // default: {} + // } + kSwitchCase, + + // switch (literal) { + // case 123: {} + // case target_expr: {} + // default: {} + // } + kSwitchCondWithAbstractCase, + + // switch (target_expr) { + // case 123: {} + // case literal: {} + // default: {} + // } + kSwitchCaseWithAbstractCase, +}; + +static std::ostream& operator<<(std::ostream& o, Method m) { + switch (m) { + case Method::kVar: + return o << "var"; + case Method::kLet: + return o << "let"; + case Method::kFnArg: + return o << "fn-arg"; + case Method::kBuiltinArg: + return o << "builtin-arg"; + case Method::kReturn: + return o << "return"; + case Method::kArray: + return o << "array"; + case Method::kStruct: + return o << "struct"; + case Method::kBinaryOp: + return o << "binary-op"; + case Method::kSwitchCond: + return o << "switch-cond"; + case Method::kSwitchCase: + return o << "switch-case"; + case Method::kSwitchCondWithAbstractCase: + return o << "switch-cond-with-abstract"; + case Method::kSwitchCaseWithAbstractCase: + return o << "switch-case-with-abstract"; + } + return o << ""; +} + +struct Data { + std::string target_type_name; + builder::ast_type_func_ptr target_ast_ty; + builder::sem_type_func_ptr target_sem_ty; + builder::ast_expr_func_ptr target_expr; + std::string literal_type_name; + builder::ast_expr_func_ptr literal_value; + std::variant materialized_value; +}; + +template +Data Types(MATERIALIZED_TYPE materialized_value = 0_a) { + return { + builder::DataType::Name(), // + builder::DataType::AST, // + builder::DataType::Sem, // + builder::DataType::Expr, // + builder::DataType::Name(), // + builder::DataType::Expr, // + materialized_value, + }; +} + +static std::ostream& operator<<(std::ostream& o, const Data& c) { + return o << "[" << c.target_type_name << " <- " << c.literal_type_name << "]"; +} + +enum class Expectation { + kMaterialize, + kNoMaterialize, + kInvalidCast, +}; + +static std::ostream& operator<<(std::ostream& o, Expectation m) { + switch (m) { + case Expectation::kMaterialize: + return o << "pass"; + case Expectation::kNoMaterialize: + return o << "no-materialize"; + case Expectation::kInvalidCast: + return o << "invalid-cast"; + } + return o << ""; +} + +using MaterializeAbstractNumeric = + resolver::ResolverTestWithParam>; + +TEST_P(MaterializeAbstractNumeric, Test) { + // Once F16 is properly supported, we'll need to enable this: + // Enable(ast::Extension::kF16); + + const auto& param = GetParam(); + const auto& expectation = std::get<0>(param); + const auto& method = std::get<1>(param); + const auto& data = std::get<2>(param); + + auto target_ty = [&] { return data.target_ast_ty(*this); }; + auto target_expr = [&] { return data.target_expr(*this, 42); }; + auto* literal = data.literal_value(*this, 1); + switch (method) { + case Method::kVar: + WrapInFunction(Decl(Var("a", target_ty(), literal))); + break; + case Method::kLet: + WrapInFunction(Decl(Let("a", target_ty(), literal))); + break; + case Method::kFnArg: + Func("F", {Param("P", target_ty())}, ty.void_(), {}); + WrapInFunction(CallStmt(Call("F", literal))); + break; + case Method::kBuiltinArg: + WrapInFunction(CallStmt(Call("min", target_expr(), literal))); + break; + case Method::kReturn: + Func("F", {}, target_ty(), {Return(literal)}); + break; + case Method::kArray: + WrapInFunction(Construct(ty.array(target_ty(), 1_i), literal)); + break; + case Method::kStruct: + Structure("S", {Member("v", target_ty())}); + WrapInFunction(Construct(ty.type_name("S"), literal)); + break; + case Method::kBinaryOp: + WrapInFunction(Add(target_expr(), literal)); + break; + case Method::kSwitchCond: + WrapInFunction(Switch(literal, // + Case(target_expr()->As()), // + DefaultCase())); + break; + case Method::kSwitchCase: + WrapInFunction(Switch(target_expr(), // + Case(literal->As()), // + DefaultCase())); + break; + case Method::kSwitchCondWithAbstractCase: + WrapInFunction(Switch(literal, // + Case(Expr(123_a)), // + Case(target_expr()->As()), // + DefaultCase())); + break; + case Method::kSwitchCaseWithAbstractCase: + WrapInFunction(Switch(target_expr(), // + Case(Expr(123_a)), // + Case(literal->As()), // + DefaultCase())); + break; + } + + auto check_types_and_values = [&](const sem::Expression* expr) { + auto* target_sem_ty = data.target_sem_ty(*this); + + EXPECT_TYPE(expr->Type(), target_sem_ty); + EXPECT_TYPE(expr->ConstantValue().Type(), target_sem_ty); + + uint32_t num_elems = 0; + const sem::Type* target_sem_el_ty = sem::Type::ElementOf(target_sem_ty, &num_elems); + EXPECT_TYPE(expr->ConstantValue().ElementType(), target_sem_el_ty); + std::visit( + [&](auto&& v) { + EXPECT_EQ(expr->ConstantValue().Elements(), sem::Constant::Scalars(num_elems, {v})); + }, + data.materialized_value); + }; + + switch (expectation) { + case Expectation::kMaterialize: { + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* materialize = Sem().Get(literal); + ASSERT_NE(materialize, nullptr); + check_types_and_values(materialize); + break; + } + case Expectation::kNoMaterialize: { + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* sem = Sem().Get(literal); + ASSERT_NE(sem, nullptr); + EXPECT_FALSE(sem->Is()); + check_types_and_values(sem); + break; + } + case Expectation::kInvalidCast: { + ASSERT_FALSE(r()->Resolve()); + std::string expect; + switch (method) { + case Method::kBuiltinArg: + expect = "error: no matching call to min(" + data.target_type_name + ", " + + data.literal_type_name + ")"; + break; + case Method::kBinaryOp: + expect = "error: no matching overload for operator + (" + + data.target_type_name + ", " + data.literal_type_name + ")"; + break; + default: + expect = "error: cannot convert value of type '" + data.literal_type_name + + "' to type '" + data.target_type_name + "'"; + break; + } + EXPECT_THAT(r()->error(), testing::StartsWith(expect)); + break; + } + } +} + +// TODO(crbug.com/tint/1504): Test for abstract-numeric values not fitting in materialized types. + +INSTANTIATE_TEST_SUITE_P(MaterializeScalar, + MaterializeAbstractNumeric, // + testing::Combine(testing::Values(Expectation::kMaterialize), // + testing::Values(Method::kLet, // + Method::kVar, // + Method::kFnArg, // + Method::kBuiltinArg, // + Method::kReturn, // + Method::kArray, // + Method::kStruct, // + Method::kBinaryOp), // + testing::Values(Types(1_a), // + Types(1_a), // + Types(1.0_a) // + /* Types(1.0_a), */ // + /* Types(1.0_a), */))); + +INSTANTIATE_TEST_SUITE_P(MaterializeVector, + MaterializeAbstractNumeric, // + testing::Combine(testing::Values(Expectation::kMaterialize), // + testing::Values(Method::kLet, // + Method::kVar, // + Method::kFnArg, // + Method::kBuiltinArg, // + Method::kReturn, // + Method::kArray, // + Method::kStruct, // + Method::kBinaryOp), // + testing::Values(Types(1_a), // + Types(1_a), // + Types(1.0_a) // + /* Types(1.0_a), */ // + /* Types(1.0_a), */))); + +INSTANTIATE_TEST_SUITE_P(MaterializeMatrix, + MaterializeAbstractNumeric, // + testing::Combine(testing::Values(Expectation::kMaterialize), // + testing::Values(Method::kLet, // + Method::kVar, // + Method::kFnArg, // + Method::kReturn, // + Method::kArray, // + Method::kStruct, // + Method::kBinaryOp), // + testing::Values(Types(1.0_a) // + /* Types(1.0_a), */ // + ))); + +INSTANTIATE_TEST_SUITE_P(MaterializeSwitch, + MaterializeAbstractNumeric, // + testing::Combine(testing::Values(Expectation::kMaterialize), // + testing::Values(Method::kSwitchCond, // + Method::kSwitchCase, // + Method::kSwitchCondWithAbstractCase, // + Method::kSwitchCaseWithAbstractCase), // + testing::Values(Types(1_a), // + Types(1_a)))); + +// TODO(crbug.com/tint/1504): Enable once we have abstract overloads of builtins / binary ops. +INSTANTIATE_TEST_SUITE_P(DISABLED_NoMaterialize, + MaterializeAbstractNumeric, // + testing::Combine(testing::Values(Expectation::kNoMaterialize), // + testing::Values(Method::kBuiltinArg, // + Method::kBinaryOp), // + testing::Values(Types(1_a), // + Types(1.0_a), // + Types(1_a), // + Types(1.0_a), // + Types(1.0_a)))); +INSTANTIATE_TEST_SUITE_P(InvalidCast, + MaterializeAbstractNumeric, // + testing::Combine(testing::Values(Expectation::kInvalidCast), // + testing::Values(Method::kLet, // + Method::kVar, // + Method::kFnArg, // + Method::kBuiltinArg, // + Method::kReturn, // + Method::kArray, // + Method::kStruct, // + Method::kBinaryOp), // + testing::Values(Types(), // + Types(), // + Types(), // + Types()))); + +} // namespace MaterializeTests + +} // namespace +} // namespace tint::resolver diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 270acf1777..9d3d7a4360 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -62,6 +62,7 @@ #include "src/tint/sem/function.h" #include "src/tint/sem/if_statement.h" #include "src/tint/sem/loop_statement.h" +#include "src/tint/sem/materialize.h" #include "src/tint/sem/member_accessor_expression.h" #include "src/tint/sem/module.h" #include "src/tint/sem/multisampled_texture.h" @@ -318,7 +319,11 @@ sem::Variable* Resolver::Variable(const ast::Variable* var, // Does the variable have a constructor? if (var->constructor) { - rhs = Expression(var->constructor); + auto* ctor = Expression(var->constructor); + if (!ctor) { + return nullptr; + } + rhs = Materialize(ctor, storage_ty); if (!rhs) { return nullptr; } @@ -1100,6 +1105,83 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) { return nullptr; } +const sem::Expression* Resolver::Materialize(const sem::Expression* expr, + const sem::Type* target_type /* = nullptr */) { + // Helper for actually creating the the materialize node, performing the constant cast, updating + // the ast -> sem binding, and performing validation. + auto materialize = [&](const sem::Type* target_ty) -> sem::Materialize* { + auto expr_val = EvaluateConstantValue(expr->Declaration(), expr->Type()); + if (!expr_val.IsValid()) { + TINT_ICE(Resolver, builder_->Diagnostics()) + << expr->Declaration()->source + << " EvaluateConstantValue() returned invalid value for materialized " + "value of type: " + << (expr->Type() ? expr->Type()->FriendlyName(builder_->Symbols()) : ""); + return nullptr; + } + auto materialized_val = ConstantCast(expr_val, target_ty); + auto* m = builder_->create(expr, current_statement_, materialized_val); + m->Behaviors() = expr->Behaviors(); + builder_->Sem().Replace(expr->Declaration(), m); + return validator_.Materialize(m) ? m : nullptr; + }; + + // Helpers for constructing semantic types + auto i32 = [&] { return builder_->create(); }; + auto f32 = [&] { return builder_->create(); }; + auto i32v = [&](uint32_t width) { return builder_->create(i32(), width); }; + auto f32v = [&](uint32_t width) { return builder_->create(f32(), width); }; + auto f32m = [&](uint32_t columns, uint32_t rows) { + return builder_->create(f32v(columns), rows); + }; + + // Type dispatch based on the expression type + return Switch( + expr->Type(), // + [&](const sem::AbstractInt*) { return materialize(target_type ? target_type : i32()); }, + [&](const sem::AbstractFloat*) { return materialize(target_type ? target_type : f32()); }, + [&](const sem::Vector* v) { + return Switch( + v->type(), // + [&](const sem::AbstractInt*) { + return materialize(target_type ? target_type : i32v(v->Width())); + }, + [&](const sem::AbstractFloat*) { + return materialize(target_type ? target_type : f32v(v->Width())); + }, + [&](Default) { return expr; }); + }, + [&](const sem::Matrix* m) { + return Switch( + m->type(), // + [&](const sem::AbstractFloat*) { + return materialize(target_type ? target_type : f32m(m->columns(), m->rows())); + }, + [&](Default) { return expr; }); + }, + [&](Default) { return expr; }); +} + +bool Resolver::MaterializeArguments(std::vector& args, + const sem::CallTarget* target) { + for (size_t i = 0, n = std::min(args.size(), target->Parameters().size()); i < n; i++) { + const auto* param_ty = target->Parameters()[i]->Type(); + if (ShouldMaterializeArgument(param_ty)) { + auto* materialized = Materialize(args[i], param_ty); + if (!materialized) { + return false; + } + args[i] = materialized; + } + } + return true; +} + +bool Resolver::ShouldMaterializeArgument(const sem::Type* parameter_ty) const { + const auto* param_el_ty = sem::Type::ElementOf(parameter_ty); + return param_el_ty && !param_el_ty->Is(); +} + sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* expr) { auto* idx = sem_.Get(expr->index); auto* obj = sem_.Get(expr->object); @@ -1192,6 +1274,9 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { if (!call_target) { return nullptr; } + if (!MaterializeArguments(args, call_target)) { + return nullptr; + } auto value = EvaluateConstantValue(expr, call_target->ReturnType()); return builder_->create(expr, call_target, std::move(args), current_statement_, value, has_side_effects); @@ -1227,6 +1312,9 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { } return builder_->create(arr, std::move(params)); }); + if (!MaterializeArguments(args, call_target)) { + return nullptr; + } auto value = EvaluateConstantValue(expr, call_target->ReturnType()); return builder_->create(expr, call_target, std::move(args), current_statement_, value, has_side_effects); @@ -1246,6 +1334,9 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { } return builder_->create(str, std::move(params)); }); + if (!MaterializeArguments(args, call_target)) { + return nullptr; + } auto value = EvaluateConstantValue(expr, call_target->ReturnType()); return builder_->create(expr, call_target, std::move(args), current_statement_, value, has_side_effects); @@ -1368,6 +1459,10 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr, } } + if (!MaterializeArguments(args, builtin)) { + return nullptr; + } + if (builtin->IsDeprecated()) { AddWarning("use of deprecated builtin", expr->source); } @@ -1425,6 +1520,10 @@ sem::Call* Resolver::FunctionCall(const ast::CallExpression* expr, auto sym = expr->target.name->symbol; auto name = builder_->Symbols().NameFor(sym); + if (!MaterializeArguments(args, target)) { + return nullptr; + } + // TODO(crbug.com/tint/1420): For now, assume all function calls have side // effects. bool has_side_effects = true; @@ -1715,6 +1814,18 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { if (!op.result) { return nullptr; } + if (ShouldMaterializeArgument(op.lhs)) { + lhs = Materialize(lhs, op.lhs); + if (!lhs) { + return nullptr; + } + } + if (ShouldMaterializeArgument(op.rhs)) { + rhs = Materialize(rhs, op.rhs); + if (!rhs) { + return nullptr; + } + } auto val = EvaluateConstantValue(expr, op.result); bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects(); @@ -1775,10 +1886,17 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { break; default: { - ty = intrinsic_table_->Lookup(unary->op, expr_ty, unary->source).result; - if (!ty) { + auto op = intrinsic_table_->Lookup(unary->op, expr_ty, unary->source); + if (!op.result) { return nullptr; } + if (ShouldMaterializeArgument(op.parameter)) { + expr = Materialize(expr, op.parameter); + if (!expr) { + return nullptr; + } + } + ty = op.result; break; } } @@ -2118,7 +2236,11 @@ sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) { const sem::Type* value_ty = nullptr; if (auto* value = stmt->value) { - auto* expr = Expression(value); + const auto* expr = Expression(value); + if (!expr) { + return false; + } + expr = Materialize(expr, current_function_->ReturnType()); if (!expr) { return false; } @@ -2141,22 +2263,54 @@ sem::SwitchStatement* Resolver::SwitchStatement(const ast::SwitchStatement* stmt return StatementScope(stmt, sem, [&] { auto& behaviors = sem->Behaviors(); - auto* cond = Expression(stmt->condition); + const auto* cond = Expression(stmt->condition); if (!cond) { return false; } behaviors = cond->Behaviors() - sem::Behavior::kNext; + auto* cond_ty = cond->Type()->UnwrapRef(); + + utils::UniqueVector types; + types.add(cond_ty); + + std::vector cases; + cases.reserve(stmt->body.size()); for (auto* case_stmt : stmt->body) { Mark(case_stmt); auto* c = CaseStatement(case_stmt); if (!c) { return false; } + for (auto* expr : c->Selectors()) { + types.add(expr->Type()->UnwrapRef()); + } + cases.emplace_back(c); behaviors.Add(c->Behaviors()); sem->Cases().emplace_back(c); } + // Determine the common type across all selectors and the switch expression + // This must materialize to an integer scalar (non-abstract). + auto* common_ty = sem::Type::Common(types.data(), types.size()); + if (!common_ty || !common_ty->is_integer_scalar()) { + // No common type found or the common type was abstract. + // Pick i32 and let validation deal with any mismatches. + common_ty = builder_->create(); + } + cond = Materialize(cond, common_ty); + if (!cond) { + return false; + } + for (auto* c : cases) { + for (auto*& sel : c->Selectors()) { // Note: pointer reference + sel = Materialize(sel, common_ty); + if (!sel) { + return false; + } + } + } + if (behaviors.Contains(sem::Behavior::kBreak)) { behaviors.Add(sem::Behavior::kNext); } diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 1725072a76..b03bb32c52 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -198,6 +198,30 @@ class Resolver { sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*); sem::Expression* UnaryOp(const ast::UnaryOpExpression*); + /// If `expr` is not of an abstract-numeric type, then Materialize() will just return `expr`. + /// If `expr` is of an abstract-numeric type: + /// * Materialize will create and return a sem::Materialize node wrapping `expr`. + /// * The AST -> Sem binding will be updated to point to the new sem::Materialize node. + /// * The sem::Materialize node will have a new concrete type, which will be `target_type` if + /// not nullptr, otherwise: + /// * a type with the element type of `i32` (e.g. `i32`, `vec2`) if `expr` has a + /// element type of abstract-integer... + /// * ... or a type with the element type of `f32` (e.g. `f32`, vec3`, `mat2x3`) + /// if `expr` has a element type of abstract-float. + /// * The sem::Materialize constant value will be the value of `expr` value-converted to the + /// materialized type. + const sem::Expression* Materialize(const sem::Expression* expr, + const sem::Type* target_type = nullptr); + + /// Materializes all the arguments in `args` to the parameter types of `target`. + /// @returns true on success, false on failure. + bool MaterializeArguments(std::vector& args, + const sem::CallTarget* target); + + /// @returns true if an argument of an abstract numeric type, passed to a parameter of type + /// `parameter_ty` should be materialized. + bool ShouldMaterializeArgument(const sem::Type* parameter_ty) const; + // Statement resolving methods // Each return true on success, false on failure. sem::Statement* AssignmentStatement(const ast::AssignmentStatement*); diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index 9f698e6a86..fdd74e9576 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -57,6 +57,7 @@ #include "src/tint/sem/function.h" #include "src/tint/sem/if_statement.h" #include "src/tint/sem/loop_statement.h" +#include "src/tint/sem/materialize.h" #include "src/tint/sem/member_accessor_expression.h" #include "src/tint/sem/multisampled_texture.h" #include "src/tint/sem/pointer.h" @@ -276,6 +277,19 @@ bool Validator::StorageTexture(const ast::StorageTexture* t) const { return true; } +bool Validator::Materialize(const sem::Materialize* m) const { + auto* from = m->Expr()->Type(); + auto* to = m->Type(); + + if (sem::Type::ConversionRank(from, to) == sem::Type::kNoConversion) { + AddError("cannot convert value of type '" + sem_.TypeNameOf(from) + "' to type '" + + sem_.TypeNameOf(to) + "'", + m->Expr()->Declaration()->source); + return false; + } + return true; +} + bool Validator::VariableConstructorOrCast(const ast::Variable* var, ast::StorageClass storage_class, const sem::Type* storage_ty, diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h index a8c18d5894..b30fdc72a1 100644 --- a/src/tint/resolver/validator.h +++ b/src/tint/resolver/validator.h @@ -54,6 +54,7 @@ class CaseStatement; class ForLoopStatement; class IfStatement; class LoopStatement; +class Materialize; class Statement; class SwitchStatement; class TypeConstructor; @@ -275,6 +276,11 @@ class Validator { /// @returns true on success, false otherwise. bool LoopStatement(const sem::LoopStatement* stmt) const; + /// Validates a materialize of an abstract numeric value + /// @param m the materialize to validate + /// @returns true on success, false otherwise + bool Materialize(const sem::Materialize* m) const; + /// Validates a matrix /// @param ty the matrix to validate /// @param source the source of the matrix diff --git a/src/tint/sem/info.h b/src/tint/sem/info.h index 66b2cd531f..41321cff57 100644 --- a/src/tint/sem/info.h +++ b/src/tint/sem/info.h @@ -70,8 +70,7 @@ class Info { return As(it->second); } - /// Add registers the semantic node `sem_node` for the AST or type node - /// `node`. + /// Add registers the semantic node `sem_node` for the AST or type node `node`. /// @param node the AST or type node /// @param sem_node the semantic node template @@ -81,6 +80,14 @@ class Info { map_.emplace(node, sem_node); } + /// Replace replaces any existing semantic node `sem_node` for the AST or type node `node`. + /// @param node the AST or type node + /// @param sem_node the new semantic node + template + void Replace(const AST_OR_TYPE* node, const SemanticNodeTypeFor* sem_node) { + map_[node] = sem_node; + } + /// Wrap returns a new Info created with the contents of `inner`. /// The Info returned by Wrap is intended to temporarily extend the contents /// of an existing immutable Info. diff --git a/test/tint/BUILD.gn b/test/tint/BUILD.gn index 2ef120246c..4f1619f1ee 100644 --- a/test/tint/BUILD.gn +++ b/test/tint/BUILD.gn @@ -257,6 +257,7 @@ tint_unittests_source_set("tint_unittests_resolver_src") { "../../src/tint/resolver/intrinsic_table_test.cc", "../../src/tint/resolver/is_host_shareable_test.cc", "../../src/tint/resolver/is_storeable_test.cc", + "../../src/tint/resolver/materialize_test.cc", "../../src/tint/resolver/pipeline_overridable_constant_test.cc", "../../src/tint/resolver/ptr_ref_test.cc", "../../src/tint/resolver/ptr_ref_validation_test.cc",