tint: Refactor sem::Constant to be less memory-hungry

Change sem::Constant to be an interface to the constant data. Implement
this so that zero-initialized data doesn't need to allocate the full
size of the type.

This also makes usage a lot cleaner (no more flattened-list of
elements!), and gives us a clear path for supporting constant
structures if/when we want to support them.

Bug: chromium:1339558
Bug: chromium:1339561
Bug: chromium:1339580
Bug: chromium:1339597
Change-Id: Ifcd456f69aee18d5b84befa896d7b0189d68c2dd
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/94942
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
This commit is contained in:
Ben Clayton 2022-06-29 19:07:30 +00:00 committed by Dawn LUCI CQ
parent 7a64127a41
commit aa037ac489
41 changed files with 2255 additions and 1786 deletions

View File

@ -1131,7 +1131,6 @@ if (tint_build_unittests) {
"sem/atomic_test.cc", "sem/atomic_test.cc",
"sem/bool_test.cc", "sem/bool_test.cc",
"sem/builtin_test.cc", "sem/builtin_test.cc",
"sem/constant_test.cc",
"sem/depth_multisampled_texture_test.cc", "sem/depth_multisampled_texture_test.cc",
"sem/depth_texture_test.cc", "sem/depth_texture_test.cc",
"sem/expression_test.cc", "sem/expression_test.cc",

View File

@ -810,7 +810,6 @@ if(TINT_BUILD_TESTS)
sem/atomic.cc sem/atomic.cc
sem/bool_test.cc sem/bool_test.cc
sem/builtin_test.cc sem/builtin_test.cc
sem/constant_test.cc
sem/depth_multisampled_texture_test.cc sem/depth_multisampled_texture_test.cc
sem/depth_texture_test.cc sem/depth_texture_test.cc
sem/expression_test.cc sem/expression_test.cc

View File

@ -38,6 +38,7 @@ Program::Program(Program&& program)
types_(std::move(program.types_)), types_(std::move(program.types_)),
ast_nodes_(std::move(program.ast_nodes_)), ast_nodes_(std::move(program.ast_nodes_)),
sem_nodes_(std::move(program.sem_nodes_)), sem_nodes_(std::move(program.sem_nodes_)),
constant_nodes_(std::move(program.constant_nodes_)),
ast_(std::move(program.ast_)), ast_(std::move(program.ast_)),
sem_(std::move(program.sem_)), sem_(std::move(program.sem_)),
symbols_(std::move(program.symbols_)), symbols_(std::move(program.symbols_)),
@ -62,6 +63,7 @@ Program::Program(ProgramBuilder&& builder) {
types_ = std::move(builder.Types()); types_ = std::move(builder.Types());
ast_nodes_ = std::move(builder.ASTNodes()); ast_nodes_ = std::move(builder.ASTNodes());
sem_nodes_ = std::move(builder.SemNodes()); sem_nodes_ = std::move(builder.SemNodes());
constant_nodes_ = std::move(builder.ConstantNodes());
ast_ = &builder.AST(); // ast::Module is actually a heap allocation. ast_ = &builder.AST(); // ast::Module is actually a heap allocation.
sem_ = std::move(builder.Sem()); sem_ = std::move(builder.Sem());
symbols_ = std::move(builder.Symbols()); symbols_ = std::move(builder.Symbols());
@ -86,6 +88,7 @@ Program& Program::operator=(Program&& program) {
types_ = std::move(program.types_); types_ = std::move(program.types_);
ast_nodes_ = std::move(program.ast_nodes_); ast_nodes_ = std::move(program.ast_nodes_);
sem_nodes_ = std::move(program.sem_nodes_); sem_nodes_ = std::move(program.sem_nodes_);
constant_nodes_ = std::move(program.constant_nodes_);
ast_ = std::move(program.ast_); ast_ = std::move(program.ast_);
sem_ = std::move(program.sem_); sem_ = std::move(program.sem_);
symbols_ = std::move(program.symbols_); symbols_ = std::move(program.symbols_);

View File

@ -20,6 +20,7 @@
#include "src/tint/ast/function.h" #include "src/tint/ast/function.h"
#include "src/tint/program_id.h" #include "src/tint/program_id.h"
#include "src/tint/sem/constant.h"
#include "src/tint/sem/info.h" #include "src/tint/sem/info.h"
#include "src/tint/sem/type_manager.h" #include "src/tint/sem/type_manager.h"
#include "src/tint/symbol_table.h" #include "src/tint/symbol_table.h"
@ -43,6 +44,9 @@ class Program {
/// SemNodeAllocator is an alias to BlockAllocator<sem::Node> /// SemNodeAllocator is an alias to BlockAllocator<sem::Node>
using SemNodeAllocator = utils::BlockAllocator<sem::Node>; using SemNodeAllocator = utils::BlockAllocator<sem::Node>;
/// ConstantAllocator is an alias to BlockAllocator<sem::Constant>
using ConstantAllocator = utils::BlockAllocator<sem::Constant>;
/// Constructor /// Constructor
Program(); Program();
@ -160,6 +164,7 @@ class Program {
sem::Manager types_; sem::Manager types_;
ASTNodeAllocator ast_nodes_; ASTNodeAllocator ast_nodes_;
SemNodeAllocator sem_nodes_; SemNodeAllocator sem_nodes_;
ConstantAllocator constant_nodes_;
ast::Module* ast_ = nullptr; ast::Module* ast_ = nullptr;
sem::Info sem_; sem::Info sem_;
SymbolTable symbols_{id_}; SymbolTable symbols_{id_};

View File

@ -88,6 +88,7 @@
#include "src/tint/program_id.h" #include "src/tint/program_id.h"
#include "src/tint/sem/array.h" #include "src/tint/sem/array.h"
#include "src/tint/sem/bool.h" #include "src/tint/sem/bool.h"
#include "src/tint/sem/constant.h"
#include "src/tint/sem/depth_texture.h" #include "src/tint/sem/depth_texture.h"
#include "src/tint/sem/external_texture.h" #include "src/tint/sem/external_texture.h"
#include "src/tint/sem/f16.h" #include "src/tint/sem/f16.h"
@ -163,6 +164,9 @@ class ProgramBuilder {
/// SemNodeAllocator is an alias to BlockAllocator<sem::Node> /// SemNodeAllocator is an alias to BlockAllocator<sem::Node>
using SemNodeAllocator = utils::BlockAllocator<sem::Node>; using SemNodeAllocator = utils::BlockAllocator<sem::Node>;
/// ConstantAllocator is an alias to BlockAllocator<sem::Constant>
using ConstantAllocator = utils::BlockAllocator<sem::Constant>;
/// Constructor /// Constructor
ProgramBuilder(); ProgramBuilder();
@ -229,6 +233,12 @@ class ProgramBuilder {
return sem_nodes_; return sem_nodes_;
} }
/// @returns a reference to the program's semantic constant storage
ConstantAllocator& ConstantNodes() {
AssertNotMoved();
return constant_nodes_;
}
/// @returns a reference to the program's AST root Module /// @returns a reference to the program's AST root Module
ast::Module& AST() { ast::Module& AST() {
AssertNotMoved(); AssertNotMoved();
@ -332,9 +342,8 @@ class ProgramBuilder {
} }
/// Creates a new sem::Node owned by the ProgramBuilder. /// Creates a new sem::Node owned by the ProgramBuilder.
/// When the ProgramBuilder is destructed, the sem::Node will also be /// When the ProgramBuilder is destructed, the sem::Node will also be destructed.
/// destructed. /// @param args the arguments to pass to the constructor
/// @param args the arguments to pass to the type constructor
/// @returns the node pointer /// @returns the node pointer
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
traits::EnableIf<traits::IsTypeOrDerived<T, sem::Node> && traits::EnableIf<traits::IsTypeOrDerived<T, sem::Node> &&
@ -345,6 +354,16 @@ class ProgramBuilder {
return sem_nodes_.Create<T>(std::forward<ARGS>(args)...); return sem_nodes_.Create<T>(std::forward<ARGS>(args)...);
} }
/// Creates a new sem::Constant owned by the ProgramBuilder.
/// When the ProgramBuilder is destructed, the sem::Node will also be destructed.
/// @param args the arguments to pass to the constructor
/// @returns the node pointer
template <typename T, typename... ARGS>
traits::EnableIf<traits::IsTypeOrDerived<T, sem::Constant>, T>* create(ARGS&&... args) {
AssertNotMoved();
return constant_nodes_.Create<T>(std::forward<ARGS>(args)...);
}
/// Creates a new sem::Type owned by the ProgramBuilder. /// Creates a new sem::Type owned by the ProgramBuilder.
/// When the ProgramBuilder is destructed, owned ProgramBuilder and the /// When the ProgramBuilder is destructed, owned ProgramBuilder and the
/// returned`Type` will also be destructed. /// returned`Type` will also be destructed.
@ -2747,6 +2766,7 @@ class ProgramBuilder {
sem::Manager types_; sem::Manager types_;
ASTNodeAllocator ast_nodes_; ASTNodeAllocator ast_nodes_;
SemNodeAllocator sem_nodes_; SemNodeAllocator sem_nodes_;
ConstantAllocator constant_nodes_;
ast::Module* ast_; ast::Module* ast_;
sem::Info sem_; sem::Info sem_;
SymbolTable symbols_{id_}; SymbolTable symbols_{id_};

View File

@ -30,7 +30,9 @@ class Constant;
namespace tint::resolver::const_eval { namespace tint::resolver::const_eval {
/// Typedef for a constant evaluation function /// Typedef for a constant evaluation function
using Function = sem::Constant(ProgramBuilder& builder, const sem::Constant* args, size_t num_args); using Function = const sem::Constant*(ProgramBuilder& builder,
sem::Constant const* const* args,
size_t num_args);
} // namespace tint::resolver::const_eval } // namespace tint::resolver::const_eval

View File

@ -73,12 +73,61 @@ static std::ostream& operator<<(std::ostream& o, Expectation m) {
return o << "<unknown>"; return o << "<unknown>";
} }
template <typename CASE>
class MaterializeTest : public resolver::ResolverTestWithParam<CASE> {
protected:
using ProgramBuilder::FriendlyName;
void CheckTypesAndValues(const sem::Expression* expr,
const tint::sem::Type* expected_sem_ty,
const std::variant<AInt, AFloat>& expected_value) {
std::visit([&](auto v) { CheckTypesAndValuesImpl(expr, expected_sem_ty, v); },
expected_value);
}
private:
template <typename T>
void CheckTypesAndValuesImpl(const sem::Expression* expr,
const tint::sem::Type* expected_sem_ty,
T expected_value) {
EXPECT_TYPE(expr->Type(), expected_sem_ty);
auto* value = expr->ConstantValue();
ASSERT_NE(value, nullptr);
EXPECT_TYPE(expr->Type(), value->Type());
tint::Switch(
expected_sem_ty, //
[&](const sem::Vector* v) {
for (uint32_t i = 0; i < v->Width(); i++) {
auto* el = value->Index(i);
ASSERT_NE(el, nullptr);
EXPECT_TYPE(el->Type(), v->type());
EXPECT_EQ(std::get<T>(el->Value()), expected_value);
}
},
[&](const sem::Matrix* m) {
for (uint32_t c = 0; c < m->columns(); c++) {
auto* column = value->Index(c);
ASSERT_NE(column, nullptr);
EXPECT_TYPE(column->Type(), m->ColumnType());
for (uint32_t r = 0; r < m->rows(); r++) {
auto* el = column->Index(r);
ASSERT_NE(el, nullptr);
EXPECT_TYPE(el->Type(), m->type());
EXPECT_EQ(std::get<T>(el->Value()), expected_value);
}
}
},
[&](Default) { EXPECT_EQ(std::get<T>(value->Value()), expected_value); });
}
};
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// MaterializeAbstractNumericToConcreteType // MaterializeAbstractNumericToConcreteType
// Tests that an abstract-numeric will materialize to the expected concrete type // Tests that an abstract-numeric will materialize to the expected concrete type
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
namespace materialize_abstract_numeric_to_concrete_type { namespace materialize_abstract_numeric_to_concrete_type {
// How should the materialization occur? // How should the materialization occur?
enum class Method { enum class Method {
// var a : target_type = abstract_expr; // var a : target_type = abstract_expr;
@ -247,7 +296,7 @@ static std::ostream& operator<<(std::ostream& o, const Data& c) {
} }
using MaterializeAbstractNumericToConcreteType = using MaterializeAbstractNumericToConcreteType =
resolver::ResolverTestWithParam<std::tuple<Expectation, Method, Data>>; MaterializeTest<std::tuple<Expectation, Method, Data>>;
TEST_P(MaterializeAbstractNumericToConcreteType, Test) { TEST_P(MaterializeAbstractNumericToConcreteType, Test) {
// Once built-in and ops using f16 is properly supported, we'll need to enable this: // Once built-in and ops using f16 is properly supported, we'll need to enable this:
@ -323,30 +372,12 @@ TEST_P(MaterializeAbstractNumericToConcreteType, Test) {
break; break;
} }
auto check_types_and_values = [&](const sem::Expression* expr) {
auto* target_sem_ty = data.target_sem_ty(*this);
EXPECT_TYPE(expr->Type(), target_sem_ty);
EXPECT_TYPE(expr->ConstantValue().Type(), target_sem_ty);
uint32_t num_elems = 0;
const sem::Type* target_sem_el_ty = sem::Type::DeepestElementOf(target_sem_ty, &num_elems);
EXPECT_TYPE(expr->ConstantValue().ElementType(), target_sem_el_ty);
expr->ConstantValue().WithElements([&](auto&& vec) {
using VEC_TY = std::decay_t<decltype(vec)>;
using EL_TY = typename VEC_TY::value_type;
ASSERT_TRUE(std::holds_alternative<EL_TY>(data.materialized_value));
VEC_TY expected(num_elems, std::get<EL_TY>(data.materialized_value));
EXPECT_EQ(vec, expected);
});
};
switch (expectation) { switch (expectation) {
case Expectation::kMaterialize: { case Expectation::kMaterialize: {
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* materialize = Sem().Get<sem::Materialize>(abstract_expr); auto* materialize = Sem().Get<sem::Materialize>(abstract_expr);
ASSERT_NE(materialize, nullptr); ASSERT_NE(materialize, nullptr);
check_types_and_values(materialize); CheckTypesAndValues(materialize, data.target_sem_ty(*this), data.materialized_value);
break; break;
} }
case Expectation::kNoMaterialize: { case Expectation::kNoMaterialize: {
@ -354,7 +385,7 @@ TEST_P(MaterializeAbstractNumericToConcreteType, Test) {
auto* sem = Sem().Get(abstract_expr); auto* sem = Sem().Get(abstract_expr);
ASSERT_NE(sem, nullptr); ASSERT_NE(sem, nullptr);
EXPECT_FALSE(sem->Is<sem::Materialize>()); EXPECT_FALSE(sem->Is<sem::Materialize>());
check_types_and_values(sem); CheckTypesAndValues(sem, data.target_sem_ty(*this), data.materialized_value);
break; break;
} }
case Expectation::kInvalidConversion: { case Expectation::kInvalidConversion: {
@ -414,8 +445,8 @@ constexpr Method kSwitchMethods[] = {
/// Methods that do not materialize /// Methods that do not materialize
constexpr Method kNoMaterializeMethods[] = { constexpr Method kNoMaterializeMethods[] = {
Method::kPhonyAssign, Method::kPhonyAssign,
// TODO(crbug.com/tint/1504): Enable once we have abstract overloads of builtins / binary ops: // TODO(crbug.com/tint/1504): Enable once we have abstract overloads of builtins / binary
// Method::kBuiltinArg, Method::kBinaryOp, // ops: Method::kBuiltinArg, Method::kBinaryOp,
}; };
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
MaterializeScalar, MaterializeScalar,
@ -703,7 +734,7 @@ static std::ostream& operator<<(std::ostream& o, const Data& c) {
} }
using MaterializeAbstractNumericToDefaultType = using MaterializeAbstractNumericToDefaultType =
resolver::ResolverTestWithParam<std::tuple<Expectation, Method, Data>>; MaterializeTest<std::tuple<Expectation, Method, Data>>;
TEST_P(MaterializeAbstractNumericToDefaultType, Test) { TEST_P(MaterializeAbstractNumericToDefaultType, Test) {
const auto& param = GetParam(); const auto& param = GetParam();
@ -751,32 +782,14 @@ TEST_P(MaterializeAbstractNumericToDefaultType, Test) {
break; break;
} }
auto check_types_and_values = [&](const sem::Expression* expr) {
auto* expected_sem_ty = data.expected_sem_ty(*this);
EXPECT_TYPE(expr->Type(), expected_sem_ty);
EXPECT_TYPE(expr->ConstantValue().Type(), expected_sem_ty);
uint32_t num_elems = 0;
const sem::Type* expected_sem_el_ty =
sem::Type::DeepestElementOf(expected_sem_ty, &num_elems);
EXPECT_TYPE(expr->ConstantValue().ElementType(), expected_sem_el_ty);
expr->ConstantValue().WithElements([&](auto&& vec) {
using VEC_TY = std::decay_t<decltype(vec)>;
using EL_TY = typename VEC_TY::value_type;
ASSERT_TRUE(std::holds_alternative<EL_TY>(data.materialized_value));
VEC_TY expected(num_elems, std::get<EL_TY>(data.materialized_value));
EXPECT_EQ(vec, expected);
});
};
switch (expectation) { switch (expectation) {
case Expectation::kMaterialize: { case Expectation::kMaterialize: {
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
for (auto* expr : abstract_exprs) { for (auto* expr : abstract_exprs) {
auto* materialize = Sem().Get<sem::Materialize>(expr); auto* materialize = Sem().Get<sem::Materialize>(expr);
ASSERT_NE(materialize, nullptr); ASSERT_NE(materialize, nullptr);
check_types_and_values(materialize); CheckTypesAndValues(materialize, data.expected_sem_ty(*this),
data.materialized_value);
} }
break; break;
} }

View File

@ -365,13 +365,13 @@ sem::Variable* Resolver::Let(const ast::Let* v, bool is_global) {
sem::Variable* sem = nullptr; sem::Variable* sem = nullptr;
if (is_global) { if (is_global) {
sem = builder_->create<sem::GlobalVariable>(v, ty, ast::StorageClass::kNone, sem = builder_->create<sem::GlobalVariable>(
ast::Access::kUndefined, sem::Constant{}, v, ty, ast::StorageClass::kNone, ast::Access::kUndefined, /* constant_value */ nullptr,
sem::BindingPoint{}); sem::BindingPoint{});
} else { } else {
sem = builder_->create<sem::LocalVariable>(v, ty, ast::StorageClass::kNone, sem = builder_->create<sem::LocalVariable>(v, ty, ast::StorageClass::kNone,
ast::Access::kUndefined, current_statement_, ast::Access::kUndefined, current_statement_,
sem::Constant{}); /* constant_value */ nullptr);
} }
sem->SetConstructor(rhs); sem->SetConstructor(rhs);
@ -419,8 +419,8 @@ sem::Variable* Resolver::Override(const ast::Override* v) {
return nullptr; return nullptr;
} }
auto* sem = builder_->create<sem::GlobalVariable>(v, ty, ast::StorageClass::kNone, auto* sem = builder_->create<sem::GlobalVariable>(
ast::Access::kUndefined, sem::Constant{}, v, ty, ast::StorageClass::kNone, ast::Access::kUndefined, /* constant_value */ nullptr,
sem::BindingPoint{}); sem::BindingPoint{});
if (auto* id = ast::GetAttribute<ast::IdAttribute>(v->attributes)) { if (auto* id = ast::GetAttribute<ast::IdAttribute>(v->attributes)) {
@ -564,11 +564,11 @@ sem::Variable* Resolver::Var(const ast::Var* var, bool is_global) {
binding_point = {bp.group->value, bp.binding->value}; binding_point = {bp.group->value, bp.binding->value};
} }
sem = builder_->create<sem::GlobalVariable>(var, var_ty, storage_class, access, sem = builder_->create<sem::GlobalVariable>(var, var_ty, storage_class, access,
sem::Constant{}, binding_point); /* constant_value */ nullptr, binding_point);
} else { } else {
sem = builder_->create<sem::LocalVariable>(var, var_ty, storage_class, access, sem = builder_->create<sem::LocalVariable>(
current_statement_, sem::Constant{}); var, var_ty, storage_class, access, current_statement_, /* constant_value */ nullptr);
} }
sem->SetConstructor(rhs); sem->SetConstructor(rhs);
@ -916,7 +916,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
return false; return false;
} }
sem::Constant value; const sem::Constant* value = nullptr;
if (auto* user = args[i]->As<sem::VariableUser>()) { if (auto* user = args[i]->As<sem::VariableUser>()) {
// We have an variable of a module-scope constant. // We have an variable of a module-scope constant.
@ -950,12 +950,12 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
continue; continue;
} }
// validator_.Validate and set the default value for this dimension. // validator_.Validate and set the default value for this dimension.
if (value.Element<AInt>(0).value < 1) { if (value->As<AInt>() < 1) {
AddError("workgroup_size argument must be at least 1", values[i]->source); AddError("workgroup_size argument must be at least 1", values[i]->source);
return false; return false;
} }
ws[i].value = value.Element<uint32_t>(0); ws[i].value = value->As<uint32_t>();
} }
current_function_->SetWorkgroupSize(std::move(ws)); current_function_->SetWorkgroupSize(std::move(ws));
@ -1266,7 +1266,8 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) {
[&](const ast::UnaryOpExpression* unary) -> sem::Expression* { return UnaryOp(unary); }, [&](const ast::UnaryOpExpression* unary) -> sem::Expression* { return UnaryOp(unary); },
[&](const ast::PhonyExpression*) -> sem::Expression* { [&](const ast::PhonyExpression*) -> sem::Expression* {
return builder_->create<sem::Expression>(expr, builder_->create<sem::Void>(), return builder_->create<sem::Expression>(expr, builder_->create<sem::Void>(),
current_statement_, sem::Constant{}, current_statement_,
/* constant_value */ nullptr,
/* has_side_effects */ false); /* has_side_effects */ false);
}, },
[&](Default) { [&](Default) {
@ -1309,13 +1310,14 @@ const sem::Expression* Resolver::Materialize(const sem::Expression* expr,
<< ") returned invalid value"; << ") returned invalid value";
return nullptr; return nullptr;
} }
auto materialized_val = ConvertValue(std::move(expr_val), target_ty, decl->source); auto materialized_val = ConvertValue(expr_val, target_ty, decl->source);
if (!materialized_val) { if (!materialized_val) {
// ConvertValue() has already failed and raised an diagnostic error.
return nullptr; return nullptr;
} }
if (!materialized_val->IsValid()) { if (!materialized_val.Get()) {
TINT_ICE(Resolver, builder_->Diagnostics()) TINT_ICE(Resolver, builder_->Diagnostics())
<< decl->source << "ConvertValue(" << builder_->FriendlyName(expr_val.Type()) << decl->source << "ConvertValue(" << builder_->FriendlyName(expr_val->Type())
<< " -> " << builder_->FriendlyName(target_ty) << ") returned invalid value"; << " -> " << builder_->FriendlyName(target_ty) << ") returned invalid value";
return nullptr; return nullptr;
} }
@ -1678,9 +1680,9 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
} }
// If the builtin is @const, and all arguments have constant values, evaluate the builtin now. // If the builtin is @const, and all arguments have constant values, evaluate the builtin now.
sem::Constant constant; const sem::Constant* constant = nullptr;
if (builtin.const_eval_fn) { if (builtin.const_eval_fn) {
std::vector<sem::Constant> values(args.size()); std::vector<const sem::Constant*> values(args.size());
bool is_const = true; // all arguments have constant values bool is_const = true; // all arguments have constant values
for (size_t i = 0; i < values.size(); i++) { for (size_t i = 0; i < values.size(); i++) {
if (auto v = args[i]->ConstantValue()) { if (auto v = args[i]->ConstantValue()) {
@ -1757,7 +1759,7 @@ sem::Call* Resolver::FunctionCall(const ast::CallExpression* expr,
// effects. // effects.
bool has_side_effects = true; bool has_side_effects = true;
auto* call = builder_->create<sem::Call>(expr, target, std::move(args), current_statement_, auto* call = builder_->create<sem::Call>(expr, target, std::move(args), current_statement_,
sem::Constant{}, has_side_effects); /* constant_value */ nullptr, has_side_effects);
target->AddCallSite(call); target->AddCallSite(call);
@ -2226,21 +2228,21 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
return nullptr; return nullptr;
} }
auto count_val = count_sem->ConstantValue(); auto* count_val = count_sem->ConstantValue();
if (!count_val) { if (!count_val) {
AddError("array size must evaluate to a constant integer expression", AddError("array size must evaluate to a constant integer expression",
count_expr->source); count_expr->source);
return nullptr; return nullptr;
} }
if (auto* ty = count_val.Type(); !ty->is_integer_scalar()) { if (auto* ty = count_val->Type(); !ty->is_integer_scalar()) {
AddError("array size must evaluate to a constant integer expression, but is type '" + AddError("array size must evaluate to a constant integer expression, but is type '" +
builder_->FriendlyName(ty) + "'", builder_->FriendlyName(ty) + "'",
count_expr->source); count_expr->source);
return nullptr; return nullptr;
} }
count = count_val.Element<AInt>(0).value; count = count_val->As<AInt>();
if (count < 1) { if (count < 1) {
AddError("array size (" + std::to_string(count) + ") must be greater than 0", AddError("array size (" + std::to_string(count) + ") must be greater than 0",
count_expr->source); count_expr->source);

View File

@ -209,22 +209,30 @@ class Resolver {
/// These methods are called from the expression resolving methods, and so child-expression /// These methods are called from the expression resolving methods, and so child-expression
/// nodes are guaranteed to have been already resolved and any constant values calculated. /// nodes are guaranteed to have been already resolved and any constant values calculated.
//////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////
sem::Constant EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type); const sem::Constant* EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type);
sem::Constant EvaluateConstantValue(const ast::IdentifierExpression* ident, const sem::Constant* EvaluateConstantValue(const ast::IdentifierExpression* ident,
const sem::Type* type); const sem::Type* type);
sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal, const sem::Constant* EvaluateConstantValue(const ast::LiteralExpression* literal,
const sem::Type* type); const sem::Type* type);
sem::Constant EvaluateConstantValue(const ast::CallExpression* call, const sem::Type* type); const sem::Constant* EvaluateConstantValue(const ast::CallExpression* call,
sem::Constant EvaluateConstantValue(const ast::IndexAccessorExpression* call, const sem::Type* type);
const sem::Constant* EvaluateConstantValue(const ast::IndexAccessorExpression* call,
const sem::Type* type); const sem::Type* type);
/// The result type of a ConstantEvaluation method. Holds the constant value and a boolean, /// The result type of a ConstantEvaluation method.
/// which is true on success, false on an error. /// Can be one of three distinct values:
using ConstantResult = utils::Result<sem::Constant>; /// * A non-null sem::Constant pointer. Returned when a expression resolves to a creation time
/// value.
/// * A null sem::Constant pointer. Returned when a expression cannot resolve to a creation time
/// value, but is otherwise legal.
/// * `utils::Failure`. Returned when there was a resolver error. In this situation the method
/// will have already reported a diagnostic error message, and the caller should abort
/// resolving.
using ConstantResult = utils::Result<const sem::Constant*>;
/// Convert the `value` to `target_type` /// Convert the `value` to `target_type`
/// @return the converted value /// @return the converted value
ConstantResult ConvertValue(const sem::Constant& value, ConstantResult ConvertValue(const sem::Constant* value,
const sem::Type* target_type, const sem::Type* target_type,
const Source& source); const Source& source);

View File

@ -14,7 +14,6 @@
#include "src/tint/resolver/resolver.h" #include "src/tint/resolver/resolver.h"
#include <cmath>
#include <optional> #include <optional>
#include "src/tint/sem/abstract_float.h" #include "src/tint/sem/abstract_float.h"
@ -22,8 +21,6 @@
#include "src/tint/sem/constant.h" #include "src/tint/sem/constant.h"
#include "src/tint/sem/type_constructor.h" #include "src/tint/sem/type_constructor.h"
#include "src/tint/utils/compiler_macros.h" #include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/transform.h"
using namespace tint::number_suffixes; // NOLINT using namespace tint::number_suffixes; // NOLINT
@ -31,127 +28,334 @@ namespace tint::resolver {
namespace { namespace {
/// Converts and returns all the element values of `in` to the type `T`, using the converter /// TypeDispatch is a helper for calling the function `f`, passing a single zero-value argument of
/// function `CONVERTER`. /// the C++ type that corresponds to the sem::Type `type`. For example, calling `TypeDispatch()`
/// @param elements_in the vector of elements to be converted /// with a type of `sem::I32*` will call the function f with a single argument of `i32(0)`.
/// @param converter a function-like with the signature `void(TO&, FROM)` /// @returns the value returned by calling `f`.
/// @returns the elements converted to type T. /// @note `type` must be a scalar or abstract numeric type. Other types will not call `f`, and will
template <typename T, typename ELEMENTS_IN, typename CONVERTER> /// return the zero-initialized value of the return type for `f`.
sem::Constant::Elements Transform(const ELEMENTS_IN& elements_in, CONVERTER&& converter) { template <typename F>
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE); auto TypeDispatch(const sem::Type* type, F&& f) {
return utils::Transform(elements_in, [&](auto value_in) {
if constexpr (std::is_same_v<UnwrapNumber<T>, bool>) {
return AInt(value_in != 0);
} else {
T converted{};
converter(converted, value_in);
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
return AFloat(converted);
} else {
return AInt(converted);
}
}
});
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
}
/// Converts and returns all the element values of `in` to the semantic type `el_ty`, using the
/// converter function `CONVERTER`.
/// @param in the constant to convert
/// @param el_ty the target element type
/// @param converter a function-like with the signature `void(TO&, FROM)`
/// @returns the elements converted to `el_ty`
template <typename CONVERTER>
sem::Constant::Elements Transform(const sem::Constant::Elements& in,
const sem::Type* el_ty,
CONVERTER&& converter) {
return std::visit(
[&](auto&& v) {
return Switch( return Switch(
el_ty, // type, //
[&](const sem::AbstractInt*) { return Transform<AInt>(v, converter); }, [&](const sem::AbstractInt*) { return f(AInt(0)); }, //
[&](const sem::AbstractFloat*) { return Transform<AFloat>(v, converter); }, [&](const sem::AbstractFloat*) { return f(AFloat(0)); }, //
[&](const sem::I32*) { return Transform<i32>(v, converter); }, [&](const sem::I32*) { return f(i32(0)); }, //
[&](const sem::U32*) { return Transform<u32>(v, converter); }, [&](const sem::U32*) { return f(u32(0)); }, //
[&](const sem::F32*) { return Transform<f32>(v, converter); }, [&](const sem::F32*) { return f(f32(0)); }, //
[&](const sem::F16*) { return Transform<f16>(v, converter); }, [&](const sem::F16*) { return f(f16(0)); }, //
[&](const sem::Bool*) { return Transform<bool>(v, converter); }, [&](const sem::Bool*) { return f(static_cast<bool>(0)); });
[&](Default) -> sem::Constant::Elements {
diag::List diags;
TINT_UNREACHABLE(Semantic, diags)
<< "invalid element type " << el_ty->TypeInfo().name;
return {};
});
},
in);
} }
/// Converts and returns all the elements in `in` to the type `el_ty`. /// @returns `value` if `T` is not a Number, otherwise ValueOf returns the inner value of the
/// If the value does not fit in the target type, and: /// Number.
/// * the target type is an integer type, then the resulting value will be clamped to the integer's template <typename T>
/// highest or lowest value. inline auto ValueOf(T value) {
/// * the target type is an float type, then the resulting value will be either positive or if constexpr (std::is_same_v<UnwrapNumber<T>, T>) {
/// negative infinity, based on the sign of the input value. return value;
/// @param in the input elements
/// @param el_ty the target element type
/// @returns the elements converted to `el_ty`
sem::Constant::Elements ConvertElements(const sem::Constant::Elements& in, const sem::Type* el_ty) {
return Transform(in, el_ty, [](auto& el_out, auto el_in) {
using OUT = std::decay_t<decltype(el_out)>;
if (auto conv = CheckedConvert<OUT>(el_in)) {
el_out = conv.Get();
} else { } else {
return value.value;
}
}
/// @returns true if `value` is a positive zero.
template <typename T>
inline bool IsPositiveZero(T value) {
using N = UnwrapNumber<T>;
return Number<N>(value) == Number<N>(0); // Considers sign bit
}
/// Constant inherits from sem::Constant to add an private implementation method for conversion.
struct Constant : public sem::Constant {
/// Convert attempts to convert the constant value to the given type. On error, Convert()
/// creates a new diagnostic message and returns a Failure.
virtual utils::Result<const Constant*> Convert(ProgramBuilder& builder,
const sem::Type* target_ty,
const Source& source) const = 0;
};
// Forward declaration
const Constant* CreateComposite(ProgramBuilder& builder,
const sem::Type* type,
std::vector<const Constant*> elements);
/// Element holds a single scalar or abstract-numeric value.
/// Element implements the Constant interface.
template <typename T>
struct Element : Constant {
Element(const sem::Type* t, T v) : type(t), value(v) {}
~Element() override = default;
const sem::Type* Type() const override { return type; }
std::variant<std::monostate, AInt, AFloat> Value() const override {
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
return static_cast<AFloat>(value);
} else {
return static_cast<AInt>(value);
}
}
const Constant* Index(size_t) const override { return nullptr; }
bool AllZero() const override { return IsPositiveZero(value); }
bool AnyZero() const override { return IsPositiveZero(value); }
bool AllEqual() const override { return true; }
size_t Hash() const override { return utils::Hash(type, ValueOf(value)); }
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
const sem::Type* target_ty,
const Source& source) const override {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
if (target_ty == type) {
// If the types are identical, then no conversion is needed.
return this;
}
bool failed = false;
auto* res = TypeDispatch(target_ty, [&](auto zero_to) -> const Constant* {
// `T` is the source type, `value` is the source value.
// `TO` is the target type.
using TO = std::decay_t<decltype(zero_to)>;
if constexpr (std::is_same_v<TO, bool>) {
// [x -> bool]
return builder.create<Element<TO>>(target_ty, !IsPositiveZero(value));
} else if constexpr (std::is_same_v<T, bool>) {
// [bool -> x]
return builder.create<Element<TO>>(target_ty, TO(value ? 1 : 0));
} else if (auto conv = CheckedConvert<TO>(value)) {
// Conversion success
return builder.create<Element<TO>>(target_ty, conv.Get());
// --- Below this point are the failure cases ---
} else if constexpr (std::is_same_v<T, AInt> || std::is_same_v<T, AFloat>) {
// [abstract-numeric -> x] - materialization failure
std::stringstream ss;
ss << "value " << value << " cannot be represented as ";
ss << "'" << builder.FriendlyName(target_ty) << "'";
builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source);
failed = true;
} else if constexpr (IsFloatingPoint<UnwrapNumber<TO>>) {
// [x -> floating-point] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
constexpr auto kInf = std::numeric_limits<double>::infinity(); constexpr auto kInf = std::numeric_limits<double>::infinity();
switch (conv.Failure()) { switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit: case ConversionFailure::kExceedsNegativeLimit:
el_out = IsFloatingPoint<UnwrapNumber<OUT>> ? OUT(-kInf) : OUT::kLowest; return builder.create<Element<TO>>(target_ty, TO(-kInf));
break;
case ConversionFailure::kExceedsPositiveLimit: case ConversionFailure::kExceedsPositiveLimit:
el_out = IsFloatingPoint<UnwrapNumber<OUT>> ? OUT(kInf) : OUT::kHighest; return builder.create<Element<TO>>(target_ty, TO(kInf));
break; }
} else {
// [x -> integer] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit:
return builder.create<Element<TO>>(target_ty, TO(TO::kLowest));
case ConversionFailure::kExceedsPositiveLimit:
return builder.create<Element<TO>>(target_ty, TO(TO::kHighest));
} }
} }
return nullptr; // Expression is not constant.
});
if (failed) {
// A diagnostic error has been raised, and resolving should abort.
return utils::Failure;
}
return res;
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
}
sem::Type const* const type;
const T value;
};
/// Splat holds a single Constant value, duplicated as all children.
/// Splat is used for zero-initializers, 'splat' constructors, or constructors where each element is
/// identical. Splat may be of a vector, matrix or array type.
/// Splat implements the Constant interface.
struct Splat : Constant {
Splat(const sem::Type* t, const Constant* e, size_t n) : type(t), el(e), count(n) {}
~Splat() override = default;
const sem::Type* Type() const override { return type; }
std::variant<std::monostate, AInt, AFloat> Value() const override { return {}; }
const Constant* Index(size_t i) const override { return i < count ? el : nullptr; }
bool AllZero() const override { return el->AllZero(); }
bool AnyZero() const override { return el->AnyZero(); }
bool AllEqual() const override { return true; }
size_t Hash() const override { return utils::Hash(type, el->Hash(), count); }
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
const sem::Type* target_ty,
const Source& source) const override {
// Convert the single splatted element type.
auto conv_el = el->Convert(builder, sem::Type::ElementOf(target_ty), source);
if (!conv_el) {
return utils::Failure;
}
if (!conv_el.Get()) {
return nullptr;
}
return builder.create<Splat>(target_ty, conv_el.Get(), count);
}
sem::Type const* const type;
const Constant* el;
const size_t count;
};
/// Composite holds a number of mixed child Constant values.
/// Composite may be of a vector, matrix or array type.
/// If each element is the same type and value, then a Splat would be a more efficient constant
/// implementation. Use CreateComposite() to create the appropriate Constant type.
/// Composite implements the Constant interface.
struct Composite : Constant {
Composite(const sem::Type* t, std::vector<const Constant*> els, bool all_0, bool any_0)
: type(t), elements(std::move(els)), all_zero(all_0), any_zero(any_0), hash(CalcHash()) {}
~Composite() override = default;
const sem::Type* Type() const override { return type; }
std::variant<std::monostate, AInt, AFloat> Value() const override { return {}; }
const Constant* Index(size_t i) const override {
return i < elements.size() ? elements[i] : nullptr;
}
bool AllZero() const override { return all_zero; }
bool AnyZero() const override { return any_zero; }
bool AllEqual() const override { return false; /* otherwise this should be a Splat */ }
size_t Hash() const override { return hash; }
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
const sem::Type* target_ty,
const Source& source) const override {
// Convert each of the composite element types.
auto* el_ty = sem::Type::ElementOf(target_ty);
std::vector<const Constant*> conv_els;
conv_els.reserve(elements.size());
for (auto* el : elements) {
auto conv_el = el->Convert(builder, el_ty, source);
if (!conv_el) {
return utils::Failure;
}
if (!conv_el.Get()) {
return nullptr;
}
conv_els.emplace_back(conv_el.Get());
}
return CreateComposite(builder, target_ty, std::move(conv_els));
}
size_t CalcHash() {
auto h = utils::Hash(type, all_zero, any_zero);
for (auto* el : elements) {
utils::HashCombine(&h, el->Hash());
}
return h;
}
sem::Type const* const type;
const std::vector<const Constant*> elements;
const bool all_zero;
const bool any_zero;
const size_t hash;
};
/// CreateElement constructs and returns an Element<T>.
template <typename T>
const Constant* CreateElement(ProgramBuilder& builder, const sem::Type* t, T v) {
return builder.create<Element<T>>(t, v);
}
/// ZeroValue returns a Constant for the zero-value of the type `type`.
const Constant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) {
return Switch(
type, //
[&](const sem::Vector* v) -> const Constant* {
auto* zero_el = ZeroValue(builder, v->type());
return builder.create<Splat>(type, zero_el, v->Width());
},
[&](const sem::Matrix* m) -> const Constant* {
auto* zero_el = ZeroValue(builder, m->ColumnType());
return builder.create<Splat>(type, zero_el, m->columns());
},
[&](const sem::Array* a) -> const Constant* {
if (auto* zero_el = ZeroValue(builder, a->ElemType())) {
return builder.create<Splat>(type, zero_el, a->Count());
}
return nullptr;
},
[&](Default) -> const Constant* {
return TypeDispatch(type, [&](auto zero) -> const Constant* {
return CreateElement(builder, type, zero);
});
}); });
} }
/// Converts and returns all the elements in `in` to the type `el_ty`, by performing a /// Equal returns true if the constants `a` and `b` are of the same type and value.
/// `CheckedConvert` on each element value. A single error diagnostic will be raised if an element bool Equal(const sem::Constant* a, const sem::Constant* b) {
/// value cannot be represented by the target type. if (a->Hash() != b->Hash()) {
/// @param in the input elements return false;
/// @param el_ty the target element type
/// @returns the elements converted to `el_ty`, or a Failure if some elements could not be
/// represented by the target type.
utils::Result<sem::Constant::Elements> MaterializeElements(const sem::Constant::Elements& in,
const sem::Type* el_ty,
ProgramBuilder& builder,
Source source) {
std::optional<std::string> failure;
auto out = Transform(in, el_ty, [&](auto& el_out, auto el_in) {
using OUT = std::decay_t<decltype(el_out)>;
if (auto conv = CheckedConvert<OUT>(el_in)) {
el_out = conv.Get();
} else if (!failure.has_value()) {
std::stringstream ss;
ss << "value " << el_in << " cannot be represented as ";
ss << "'" << builder.FriendlyName(el_ty) << "'";
failure = ss.str();
} }
}); if (a->Type() != b->Type()) {
return false;
if (failure.has_value()) {
builder.Diagnostics().add_error(diag::System::Resolver, std::move(failure.value()), source);
return utils::Failure;
} }
return Switch(
a->Type(), //
[&](const sem::Vector* vec) {
for (size_t i = 0; i < vec->Width(); i++) {
if (!Equal(a->Index(i), b->Index(i))) {
return false;
}
}
return true;
},
[&](const sem::Matrix* mat) {
for (size_t i = 0; i < mat->columns(); i++) {
if (!Equal(a->Index(i), b->Index(i))) {
return false;
}
}
return true;
},
[&](const sem::Array* arr) {
for (size_t i = 0; i < arr->Count(); i++) {
if (!Equal(a->Index(i), b->Index(i))) {
return false;
}
}
return true;
},
[&](Default) { return a->Value() == b->Value(); });
}
return out; /// CreateComposite is used to construct a constant of a vector, matrix or array type.
/// CreateComposite examines the element values and will return either a Composite or a Splat,
/// depending on the element types and values.
const Constant* CreateComposite(ProgramBuilder& builder,
const sem::Type* type,
std::vector<const Constant*> elements) {
if (elements.size() == 0) {
return nullptr;
}
bool any_zero = false;
bool all_zero = true;
bool all_equal = true;
auto* first = elements.front();
for (auto* el : elements) {
if (!any_zero && el->AnyZero()) {
any_zero = true;
}
if (all_zero && !el->AllZero()) {
all_zero = false;
}
if (all_equal && el != first) {
if (!Equal(el, first)) {
all_equal = false;
}
}
}
if (all_equal) {
return builder.create<Splat>(type, elements[0], elements.size());
} else {
return builder.create<Composite>(type, std::move(elements), all_zero, any_zero);
}
} }
} // namespace } // namespace
sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) { const sem::Constant* Resolver::EvaluateConstantValue(const ast::Expression* expr,
const sem::Type* type) {
return Switch( return Switch(
expr, // expr, //
[&](const ast::IdentifierExpression* e) { return EvaluateConstantValue(e, type); }, [&](const ast::IdentifierExpression* e) { return EvaluateConstantValue(e, type); },
@ -160,7 +364,7 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const
[&](const ast::IndexAccessorExpression* e) { return EvaluateConstantValue(e, type); }); [&](const ast::IndexAccessorExpression* e) { return EvaluateConstantValue(e, type); });
} }
sem::Constant Resolver::EvaluateConstantValue(const ast::IdentifierExpression* ident, const sem::Constant* Resolver::EvaluateConstantValue(const ast::IdentifierExpression* ident,
const sem::Type*) { const sem::Type*) {
if (auto* sem = builder_->Sem().Get(ident)) { if (auto* sem = builder_->Sem().Get(ident)) {
return sem->ConstantValue(); return sem->ConstantValue();
@ -168,104 +372,168 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::IdentifierExpression* i
return {}; return {};
} }
sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* literal, const sem::Constant* Resolver::EvaluateConstantValue(const ast::LiteralExpression* literal,
const sem::Type* type) { const sem::Type* type) {
return Switch( return Switch(
literal, literal,
[&](const ast::BoolLiteralExpression* lit) { [&](const ast::BoolLiteralExpression* lit) {
return sem::Constant{type, {AInt(lit->value ? 1 : 0)}}; return CreateElement(*builder_, type, lit->value);
}, },
[&](const ast::IntLiteralExpression* lit) { [&](const ast::IntLiteralExpression* lit) -> const Constant* {
return sem::Constant{type, {AInt(lit->value)}}; switch (lit->suffix) {
case ast::IntLiteralExpression::Suffix::kNone:
return CreateElement(*builder_, type, AInt(lit->value));
case ast::IntLiteralExpression::Suffix::kI:
return CreateElement(*builder_, type, i32(lit->value));
case ast::IntLiteralExpression::Suffix::kU:
return CreateElement(*builder_, type, u32(lit->value));
}
return nullptr;
}, },
[&](const ast::FloatLiteralExpression* lit) { [&](const ast::FloatLiteralExpression* lit) -> const Constant* {
return sem::Constant{type, {AFloat(lit->value)}}; switch (lit->suffix) {
case ast::FloatLiteralExpression::Suffix::kNone:
return CreateElement(*builder_, type, AFloat(lit->value));
case ast::FloatLiteralExpression::Suffix::kF:
return CreateElement(*builder_, type, f32(lit->value));
case ast::FloatLiteralExpression::Suffix::kH:
return CreateElement(*builder_, type, f16(lit->value));
}
return nullptr;
}); });
} }
sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call, const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression* call,
const sem::Type* ty) { const sem::Type* ty) {
uint32_t num_elems = 0;
auto* el_ty = sem::Type::DeepestElementOf(ty, &num_elems);
if (!el_ty || num_elems == 0) {
return {};
}
// Note: we are building constant values for array types. The working group as verbally agreed // Note: we are building constant values for array types. The working group as verbally agreed
// to support constant expression arrays, but this is not (yet) part of the spec. // to support constant expression arrays, but this is not (yet) part of the spec.
// See: https://github.com/gpuweb/gpuweb/issues/3056 // See: https://github.com/gpuweb/gpuweb/issues/3056
// For zero value init, return 0s // For zero value init, return 0s
if (call->args.empty()) { if (call->args.empty()) {
return Switch( return ZeroValue(*builder_, ty);
el_ty,
[&](const sem::AbstractInt*) {
return sem::Constant(ty, std::vector(num_elems, AInt(0)));
},
[&](const sem::AbstractFloat*) {
return sem::Constant(ty, std::vector(num_elems, AFloat(0)));
},
[&](const sem::I32*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); },
[&](const sem::U32*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); },
[&](const sem::F32*) { return sem::Constant(ty, std::vector(num_elems, AFloat(0))); },
[&](const sem::F16*) { return sem::Constant(ty, std::vector(num_elems, AFloat(0))); },
[&](const sem::Bool*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); });
} }
// Build value for type_ctor from each child value by converting to type_ctor's type. uint32_t el_count = 0;
std::optional<sem::Constant::Elements> elements; auto* el_ty = sem::Type::ElementOf(ty, &el_count);
for (auto* expr : call->args) { if (!el_ty) {
auto* arg = builder_->Sem().Get(expr); return nullptr; // Target type does not support constant values
}
// value_of returns a `const Constant*` for the expression `expr`, or nullptr if the expression
// does not have a constant value.
auto value_of = [&](const ast::Expression* expr) {
return static_cast<const Constant*>(builder_->Sem().Get(expr)->ConstantValue());
};
if (call->args.size() == 1) {
// Type constructor or conversion that takes a single argument.
auto& src = call->args[0]->source;
auto* arg = value_of(call->args[0]);
if (!arg) { if (!arg) {
return {}; return nullptr; // Single argument is not constant.
}
auto value = arg->ConstantValue();
if (!value) {
return {};
} }
// Convert the elements to the desired type. if (ty->is_scalar()) { // Scalar type conversion: i32(x), u32(x), bool(x), etc
auto converted = ConvertElements(value.GetElements(), el_ty); return ConvertValue(arg, el_ty, src).Get();
if (elements.has_value()) {
// Append the converted vector to elements
std::visit(
[&](auto&& dst) {
using VEC_TY = std::decay_t<decltype(dst)>;
const auto& src = std::get<VEC_TY>(converted);
dst.insert(dst.end(), src.begin(), src.end());
},
elements.value());
} else {
elements = std::move(converted);
}
} }
if (!elements) { if (arg->Type() == el_ty) {
return {}; // Argument type matches function type. This is a splat.
auto splat = [&](size_t n) { return builder_->create<Splat>(ty, arg, n); };
return Switch(
ty, //
[&](const sem::Vector* v) { return splat(v->Width()); },
[&](const sem::Matrix* m) { return splat(m->columns()); },
[&](const sem::Array* a) { return splat(a->Count()); });
} }
return std::visit( // Argument type and function type mismatch. This is a type conversion.
[&](auto&& v) { if (auto conv = ConvertValue(arg, ty, src)) {
if (num_elems != v.size()) { return conv.Get();
if (v.size() == 1) { }
// Splat single-value initializers
for (uint32_t i = 0; i < num_elems - 1; ++i) { return nullptr;
v.emplace_back(v[0]); }
// Multiple arguments. Must be a type constructor.
std::vector<const Constant*> els; // The constant elements for the composite constant.
els.reserve(el_count);
// Helper for pushing all the argument constants to `els`.
auto push_all_args = [&] {
for (auto* expr : call->args) {
auto* arg = value_of(expr);
if (!arg) {
return;
}
els.emplace_back(arg);
}
};
Switch(
ty, // What's the target type being constructed?
[&](const sem::Vector*) {
// Vector can be constructed with a mix of scalars / abstract numerics and smaller
// vectors.
for (auto* expr : call->args) {
auto* arg = value_of(expr);
if (!arg) {
return;
}
auto* arg_ty = arg->Type();
if (auto* arg_vec = arg_ty->As<sem::Vector>()) {
// Extract out vector elements.
for (uint32_t i = 0; i < arg_vec->Width(); i++) {
auto* el = static_cast<const Constant*>(arg->Index(i));
if (!el) {
return;
}
els.emplace_back(el);
} }
} else { } else {
// Provided number of arguments does not match the required number of elements. els.emplace_back(arg);
// Validation should error here.
return sem::Constant{};
} }
} }
return sem::Constant(ty, std::move(elements.value()));
}, },
elements.value()); [&](const sem::Matrix* m) {
// Matrix can be constructed with a set of scalars / abstract numerics, or column
// vectors.
if (call->args.size() == m->columns() * m->rows()) {
// Matrix built from scalars / abstract numerics
for (uint32_t c = 0; c < m->columns(); c++) {
std::vector<const Constant*> column;
column.reserve(m->rows());
for (uint32_t r = 0; r < m->rows(); r++) {
auto* arg = value_of(call->args[r + c * m->rows()]);
if (!arg) {
return;
}
column.emplace_back(arg);
}
els.push_back(CreateComposite(*builder_, m->ColumnType(), std::move(column)));
}
} else if (call->args.size() == m->columns()) {
// Matrix built from column vectors
push_all_args();
}
},
[&](const sem::Array*) {
// Arrays must be constructed using a list of elements
push_all_args();
});
if (els.size() != el_count) {
// If the number of constant elements doesn't match the type, then something went wrong.
return nullptr;
}
// Construct and return either a Composite or Splat.
return CreateComposite(*builder_, ty, std::move(els));
} }
sem::Constant Resolver::EvaluateConstantValue(const ast::IndexAccessorExpression* accessor, const sem::Constant* Resolver::EvaluateConstantValue(const ast::IndexAccessorExpression* accessor,
const sem::Type* el_ty) { const sem::Type*) {
auto* obj_sem = builder_->Sem().Get(accessor->object); auto* obj_sem = builder_->Sem().Get(accessor->object);
if (!obj_sem) { if (!obj_sem) {
return {}; return {};
@ -282,20 +550,14 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::IndexAccessorExpression
} }
auto idx_val = idx_sem->ConstantValue(); auto idx_val = idx_sem->ConstantValue();
if (!idx_val || idx_val.ElementCount() != 1) { if (!idx_val) {
return {}; return {};
} }
AInt idx = idx_val.Element<AInt>(0);
// The immediate child element count.
uint32_t el_count = 0; uint32_t el_count = 0;
sem::Type::ElementOf(obj_val.Type(), &el_count); sem::Type::ElementOf(obj_val->Type(), &el_count);
// The total number of most-nested elements per child element type.
uint32_t step = 0;
sem::Type::DeepestElementOf(el_ty, &step);
AInt idx = idx_val->As<AInt>();
if (idx < 0 || idx >= el_count) { if (idx < 0 || idx >= el_count) {
auto clamped = std::min<AInt::type>(std::max<AInt::type>(idx, 0), el_count - 1); auto clamped = std::min<AInt::type>(std::max<AInt::type>(idx, 0), el_count - 1);
AddWarning("index " + std::to_string(idx) + " out of bounds [0.." + AddWarning("index " + std::to_string(idx) + " out of bounds [0.." +
@ -305,32 +567,20 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::IndexAccessorExpression
idx = clamped; idx = clamped;
} }
return sem::Constant{el_ty, obj_val.WithElements([&](auto&& v) { return obj_val->Index(static_cast<size_t>(idx));
using VEC = std::decay_t<decltype(v)>;
return sem::Constant::Elements(
VEC(v.begin() + (idx * step), v.begin() + (idx + 1) * step));
})};
} }
utils::Result<sem::Constant> Resolver::ConvertValue(const sem::Constant& value, utils::Result<const sem::Constant*> Resolver::ConvertValue(const sem::Constant* value,
const sem::Type* ty, const sem::Type* target_ty,
const Source& source) { const Source& source) {
if (value.Type() == ty) { if (value->Type() == target_ty) {
return value; return value;
} }
auto conv = static_cast<const Constant*>(value)->Convert(*builder_, target_ty, source);
auto* el_ty = sem::Type::DeepestElementOf(ty); if (!conv) {
if (el_ty == nullptr) {
return sem::Constant{};
}
if (value.ElementType() == el_ty) {
return sem::Constant(ty, value.GetElements());
}
if (auto res = MaterializeElements(value.GetElements(), el_ty, *builder_, source)) {
return sem::Constant(ty, std::move(res.Get()));
}
return utils::Failure; return utils::Failure;
}
return conv.Get();
} }
} // namespace tint::resolver } // namespace tint::resolver

File diff suppressed because it is too large Load Diff

View File

@ -1604,16 +1604,16 @@ bool Validator::TextureBuiltinFunction(const sem::Call* call) const {
auto& signature = builtin->Signature(); auto& signature = builtin->Signature();
auto check_arg_is_constexpr = [&](sem::ParameterUsage usage, int min, int max) { auto check_arg_is_constexpr = [&](sem::ParameterUsage usage, int min, int max) {
auto index = signature.IndexOf(usage); auto signed_index = signature.IndexOf(usage);
if (index < 0) { if (signed_index < 0) {
return true; return true;
} }
auto index = static_cast<size_t>(signed_index);
std::string name = sem::str(usage); std::string name = sem::str(usage);
auto* arg = call->Arguments()[static_cast<size_t>(index)]; auto* arg = call->Arguments()[index];
if (auto values = arg->ConstantValue()) { if (auto values = arg->ConstantValue()) {
// Assert that the constant values are of the expected type. // Assert that the constant values are of the expected type.
if (!values.Type()->IsAnyOf<sem::I32, sem::Vector>() || if (!values->Type()->is_integer_scalar_or_vector()) {
!values.ElementType()->Is<sem::I32>()) {
TINT_ICE(Resolver, diagnostics_) TINT_ICE(Resolver, diagnostics_)
<< "failed to resolve '" + func_name + "' " << name << " parameter type"; << "failed to resolve '" + func_name + "' " << name << " parameter type";
return false; return false;
@ -1631,25 +1631,26 @@ bool Validator::TextureBuiltinFunction(const sem::Call* call) const {
return ast::TraverseAction::Stop; return ast::TraverseAction::Stop;
}); });
if (is_const_expr) { if (is_const_expr) {
auto vector = if (auto* vector = builtin->Parameters()[index]->Type()->As<sem::Vector>()) {
builtin->Parameters()[static_cast<size_t>(index)]->Type()->Is<sem::Vector>(); for (size_t i = 0; i < vector->Width(); i++) {
for (size_t i = 0, n = values.ElementCount(); i < n; i++) { auto value = values->Index(i)->As<AInt>();
auto value = values.Element<AInt>(i).value;
if (value < min || value > max) { if (value < min || value > max) {
if (vector) {
AddError("each component of the " + name + AddError("each component of the " + name +
" argument must be at least " + std::to_string(min) + " argument must be at least " + std::to_string(min) +
" and at most " + std::to_string(max) + ". " + name + " and at most " + std::to_string(max) + ". " + name +
" component " + std::to_string(i) + " is " + " component " + std::to_string(i) + " is " +
std::to_string(value), std::to_string(value),
arg->Declaration()->source); arg->Declaration()->source);
} else { return false;
AddError("the " + name + " argument must be at least " +
std::to_string(min) + " and at most " +
std::to_string(max) + ". " + name + " is " +
std::to_string(value),
arg->Declaration()->source);
} }
}
} else {
auto value = values->As<AInt>();
if (value < min || value > max) {
AddError("the " + name + " argument must be at least " +
std::to_string(min) + " and at most " + std::to_string(max) +
". " + name + " is " + std::to_string(value),
arg->Declaration()->source);
return false; return false;
} }
} }

View File

@ -920,21 +920,13 @@ TEST_F(ResolverVariableTest, LocalConst_ExplicitType_Decls) {
ASSERT_TRUE(TypeOf(c_vf32)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(c_vf32)->Is<sem::Vector>());
ASSERT_TRUE(TypeOf(c_mf32)->Is<sem::Matrix>()); ASSERT_TRUE(TypeOf(c_mf32)->Is<sem::Matrix>());
EXPECT_TRUE(Sem().Get(c_i32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_i32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_u32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_u32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_f32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_f32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vi32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vi32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vu32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vu32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vf32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vf32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_mf32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_mf32)->ConstantValue()->AllZero());
EXPECT_EQ(Sem().Get(c_i32)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_u32)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_f32)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_vi32)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_vu32)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_vf32)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_mf32)->ConstantValue().ElementCount(), 9u);
} }
TEST_F(ResolverVariableTest, LocalConst_ImplicitType_Decls) { TEST_F(ResolverVariableTest, LocalConst_ImplicitType_Decls) {
@ -949,7 +941,11 @@ TEST_F(ResolverVariableTest, LocalConst_ImplicitType_Decls) {
auto* c_vai = Const("i", nullptr, Construct(ty.vec(nullptr, 3), Expr(0_a))); auto* c_vai = Const("i", nullptr, Construct(ty.vec(nullptr, 3), Expr(0_a)));
auto* c_vaf = Const("j", nullptr, Construct(ty.vec(nullptr, 3), Expr(0._a))); auto* c_vaf = Const("j", nullptr, Construct(ty.vec(nullptr, 3), Expr(0._a)));
auto* c_mf32 = Const("k", nullptr, mat3x3<f32>()); auto* c_mf32 = Const("k", nullptr, mat3x3<f32>());
auto* c_maf32 = Const("l", nullptr, Construct(ty.mat(nullptr, 3, 3), Expr(0._a))); auto* c_maf32 = Const("l", nullptr,
Construct(ty.mat(nullptr, 3, 3), //
Construct(ty.vec(nullptr, 3), Expr(0._a)),
Construct(ty.vec(nullptr, 3), Expr(0._a)),
Construct(ty.vec(nullptr, 3), Expr(0._a))));
WrapInFunction(c_i32, c_u32, c_f32, c_ai, c_af, c_vi32, c_vu32, c_vf32, c_vai, c_vaf, c_mf32, WrapInFunction(c_i32, c_u32, c_f32, c_ai, c_af, c_vi32, c_vu32, c_vf32, c_vai, c_vaf, c_mf32,
c_maf32); c_maf32);
@ -982,31 +978,18 @@ TEST_F(ResolverVariableTest, LocalConst_ImplicitType_Decls) {
ASSERT_TRUE(TypeOf(c_mf32)->Is<sem::Matrix>()); ASSERT_TRUE(TypeOf(c_mf32)->Is<sem::Matrix>());
ASSERT_TRUE(TypeOf(c_maf32)->Is<sem::Matrix>()); ASSERT_TRUE(TypeOf(c_maf32)->Is<sem::Matrix>());
EXPECT_TRUE(Sem().Get(c_i32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_i32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_u32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_u32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_f32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_f32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_ai)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_ai)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_af)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_af)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vi32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vi32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vu32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vu32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vf32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vf32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vai)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vai)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vaf)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vaf)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_mf32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_mf32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_maf32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_maf32)->ConstantValue()->AllZero());
EXPECT_EQ(Sem().Get(c_i32)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_u32)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_f32)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_ai)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_af)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_vi32)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_vu32)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_vf32)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_vai)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_vaf)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_mf32)->ConstantValue().ElementCount(), 9u);
EXPECT_EQ(Sem().Get(c_maf32)->ConstantValue().ElementCount(), 9u);
} }
TEST_F(ResolverVariableTest, LocalConst_PropagateConstValue) { TEST_F(ResolverVariableTest, LocalConst_PropagateConstValue) {
@ -1020,8 +1003,7 @@ TEST_F(ResolverVariableTest, LocalConst_PropagateConstValue) {
ASSERT_TRUE(TypeOf(c)->Is<sem::I32>()); ASSERT_TRUE(TypeOf(c)->Is<sem::I32>());
ASSERT_EQ(Sem().Get(c)->ConstantValue().ElementCount(), 1u); EXPECT_EQ(Sem().Get(c)->ConstantValue()->As<i32>(), 42_i);
EXPECT_EQ(Sem().Get(c)->ConstantValue().Element<i32>(0), 42_i);
} }
// Enable when we have @const operators implemented // Enable when we have @const operators implemented
@ -1034,8 +1016,7 @@ TEST_F(ResolverVariableTest, DISABLED_LocalConst_ConstEval) {
ASSERT_TRUE(TypeOf(c)->Is<sem::I32>()); ASSERT_TRUE(TypeOf(c)->Is<sem::I32>());
ASSERT_EQ(Sem().Get(c)->ConstantValue().ElementCount(), 1u); EXPECT_EQ(Sem().Get(c)->ConstantValue()->As<i32>(), 3_i);
EXPECT_EQ(Sem().Get(c)->ConstantValue().Element<i32>(0), 3_i);
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -1126,21 +1107,13 @@ TEST_F(ResolverVariableTest, GlobalConst_ExplicitType_Decls) {
ASSERT_TRUE(TypeOf(c_vf32)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(c_vf32)->Is<sem::Vector>());
ASSERT_TRUE(TypeOf(c_mf32)->Is<sem::Matrix>()); ASSERT_TRUE(TypeOf(c_mf32)->Is<sem::Matrix>());
EXPECT_TRUE(Sem().Get(c_i32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_i32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_u32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_u32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_f32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_f32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vi32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vi32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vu32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vu32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vf32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vf32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_mf32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_mf32)->ConstantValue()->AllZero());
EXPECT_EQ(Sem().Get(c_i32)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_u32)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_f32)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_vi32)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_vu32)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_vf32)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_mf32)->ConstantValue().ElementCount(), 9u);
} }
TEST_F(ResolverVariableTest, GlobalConst_ImplicitType_Decls) { TEST_F(ResolverVariableTest, GlobalConst_ImplicitType_Decls) {
@ -1155,7 +1128,11 @@ TEST_F(ResolverVariableTest, GlobalConst_ImplicitType_Decls) {
auto* c_vai = GlobalConst("i", nullptr, Construct(ty.vec(nullptr, 3), Expr(0_a))); auto* c_vai = GlobalConst("i", nullptr, Construct(ty.vec(nullptr, 3), Expr(0_a)));
auto* c_vaf = GlobalConst("j", nullptr, Construct(ty.vec(nullptr, 3), Expr(0._a))); auto* c_vaf = GlobalConst("j", nullptr, Construct(ty.vec(nullptr, 3), Expr(0._a)));
auto* c_mf32 = GlobalConst("k", nullptr, mat3x3<f32>()); auto* c_mf32 = GlobalConst("k", nullptr, mat3x3<f32>());
auto* c_maf32 = GlobalConst("l", nullptr, Construct(ty.mat(nullptr, 3, 3), Expr(0._a))); auto* c_maf32 = GlobalConst("l", nullptr,
Construct(ty.mat(nullptr, 3, 3), //
Construct(ty.vec(nullptr, 3), Expr(0._a)),
Construct(ty.vec(nullptr, 3), Expr(0._a)),
Construct(ty.vec(nullptr, 3), Expr(0._a))));
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -1185,31 +1162,18 @@ TEST_F(ResolverVariableTest, GlobalConst_ImplicitType_Decls) {
ASSERT_TRUE(TypeOf(c_mf32)->Is<sem::Matrix>()); ASSERT_TRUE(TypeOf(c_mf32)->Is<sem::Matrix>());
ASSERT_TRUE(TypeOf(c_maf32)->Is<sem::Matrix>()); ASSERT_TRUE(TypeOf(c_maf32)->Is<sem::Matrix>());
EXPECT_TRUE(Sem().Get(c_i32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_i32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_u32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_u32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_f32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_f32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_ai)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_ai)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_af)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_af)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vi32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vi32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vu32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vu32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vf32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vf32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vai)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vai)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vaf)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_vaf)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_mf32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_mf32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_maf32)->ConstantValue().AllZero()); EXPECT_TRUE(Sem().Get(c_maf32)->ConstantValue()->AllZero());
EXPECT_EQ(Sem().Get(c_i32)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_u32)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_f32)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_ai)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_af)->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(Sem().Get(c_vi32)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_vu32)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_vf32)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_vai)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_vaf)->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(Sem().Get(c_mf32)->ConstantValue().ElementCount(), 9u);
EXPECT_EQ(Sem().Get(c_maf32)->ConstantValue().ElementCount(), 9u);
} }
TEST_F(ResolverVariableTest, GlobalConst_PropagateConstValue) { TEST_F(ResolverVariableTest, GlobalConst_PropagateConstValue) {
@ -1221,8 +1185,7 @@ TEST_F(ResolverVariableTest, GlobalConst_PropagateConstValue) {
ASSERT_TRUE(TypeOf(c)->Is<sem::I32>()); ASSERT_TRUE(TypeOf(c)->Is<sem::I32>());
ASSERT_EQ(Sem().Get(c)->ConstantValue().ElementCount(), 1u); EXPECT_EQ(Sem().Get(c)->ConstantValue()->As<i32>(), 42_i);
EXPECT_EQ(Sem().Get(c)->ConstantValue().Element<i32>(0), 42_i);
} }
// Enable when we have @const operators implemented // Enable when we have @const operators implemented
@ -1233,8 +1196,7 @@ TEST_F(ResolverVariableTest, DISABLED_GlobalConst_ConstEval) {
ASSERT_TRUE(TypeOf(c)->Is<sem::I32>()); ASSERT_TRUE(TypeOf(c)->Is<sem::I32>());
ASSERT_EQ(Sem().Get(c)->ConstantValue().ElementCount(), 1u); EXPECT_EQ(Sem().Get(c)->ConstantValue()->As<i32>(), 3_i);
EXPECT_EQ(Sem().Get(c)->ConstantValue().Element<i32>(0), 3_i);
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -25,9 +25,9 @@ Call::Call(const ast::CallExpression* declaration,
const CallTarget* target, const CallTarget* target,
std::vector<const sem::Expression*> arguments, std::vector<const sem::Expression*> arguments,
const Statement* statement, const Statement* statement,
Constant constant, const Constant* constant,
bool has_side_effects) bool has_side_effects)
: Base(declaration, target->ReturnType(), statement, std::move(constant), has_side_effects), : Base(declaration, target->ReturnType(), statement, constant, has_side_effects),
target_(target), target_(target),
arguments_(std::move(arguments)) {} arguments_(std::move(arguments)) {}

View File

@ -38,7 +38,7 @@ class Call final : public Castable<Call, Expression> {
const CallTarget* target, const CallTarget* target,
std::vector<const sem::Expression*> arguments, std::vector<const sem::Expression*> arguments,
const Statement* statement, const Statement* statement,
Constant constant, const Constant* constant,
bool has_side_effects); bool has_side_effects);
/// Destructor /// Destructor

View File

@ -14,102 +14,10 @@
#include "src/tint/sem/constant.h" #include "src/tint/sem/constant.h"
#include <cmath>
#include <utility>
#include "src/tint/debug.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/type.h"
namespace tint::sem { namespace tint::sem {
namespace { Constant::Constant() = default;
size_t CountElements(const Constant::Elements& elements) {
return std::visit([](auto&& vec) { return vec.size(); }, elements);
}
template <typename T>
bool IsNegativeFloat(T value) {
(void)value;
if constexpr (IsFloatingPoint<T>) {
return std::signbit(value);
} else {
return false;
}
}
} // namespace
Constant::Constant() {}
Constant::Constant(const sem::Type* ty, Elements els)
: type_(ty), elem_type_(CheckElemType(ty, CountElements(els))), elems_(std::move(els)) {}
Constant::Constant(const sem::Type* ty, AInts vec) : Constant(ty, Elements{std::move(vec)}) {}
Constant::Constant(const sem::Type* ty, AFloats vec) : Constant(ty, Elements{std::move(vec)}) {}
Constant::Constant(const Constant&) = default;
Constant::~Constant() = default; Constant::~Constant() = default;
Constant& Constant::operator=(const Constant& rhs) = default;
bool Constant::AnyZero() const {
return WithElements([&](auto&& vec) {
using T = typename std::decay_t<decltype(vec)>::value_type;
for (auto el : vec) {
if (el == T(0) && !IsNegativeFloat(el.value)) {
return true;
}
}
return false;
});
}
bool Constant::AllZero(size_t start, size_t end) const {
return WithElements([&](auto&& vec) {
using T = typename std::decay_t<decltype(vec)>::value_type;
for (size_t i = start; i < end; i++) {
auto el = vec[i];
if (el != T(0) || IsNegativeFloat(el.value)) {
return false;
}
}
return true;
});
}
bool Constant::AllEqual(size_t start, size_t end) const {
return WithElements([&](auto&& vec) {
if (!vec.empty()) {
auto value = vec[start];
bool float_sign = IsNegativeFloat(vec[start].value);
for (size_t i = start + 1; i < end; i++) {
if (vec[i] != value || float_sign != IsNegativeFloat(vec[i].value)) {
return false;
}
}
}
return true;
});
}
const Type* Constant::CheckElemType(const sem::Type* ty, size_t num_elements) {
diag::List diag;
uint32_t count = 0;
auto* el_ty = Type::DeepestElementOf(ty, &count);
if (!el_ty) {
TINT_ICE(Semantic, diag) << "Unsupported sem::Constant type: " << ty->TypeInfo().name;
return nullptr;
}
if (num_elements != count) {
TINT_ICE(Semantic, diag) << "sem::Constant() type <-> element mismatch. type: '"
<< ty->TypeInfo().name << "' provided: " << num_elements
<< " require: " << count;
}
TINT_ASSERT(Semantic, el_ty->is_abstract_or_scalar());
return el_ty;
}
} // namespace tint::sem } // namespace tint::sem

View File

@ -15,10 +15,7 @@
#ifndef SRC_TINT_SEM_CONSTANT_H_ #ifndef SRC_TINT_SEM_CONSTANT_H_
#define SRC_TINT_SEM_CONSTANT_H_ #define SRC_TINT_SEM_CONSTANT_H_
#include <ostream>
#include <utility>
#include <variant> #include <variant>
#include <vector>
#include "src/tint/number.h" #include "src/tint/number.h"
@ -29,166 +26,52 @@ class Type;
namespace tint::sem { namespace tint::sem {
/// A Constant holds a compile-time evaluated expression value, expressed as a flattened list of /// Constant is the interface to a compile-time evaluated expression value.
/// element values. The expression type may be of an abstract-numeric, scalar, vector or matrix
/// type. Constant holds the element values in either a vector of abstract-integer (AInt) or
/// abstract-float (AFloat), depending on the element type.
class Constant { class Constant {
public: public:
/// AInts is a vector of AInt, used to hold elements of the WGSL types: /// Constructor
/// * abstract-integer
/// * i32
/// * u32
/// * bool (0 or 1)
using AInts = std::vector<AInt>;
/// AFloats is a vector of AFloat, used to hold elements of the WGSL types:
/// * abstract-float
/// * f32
/// * f16
using AFloats = std::vector<AFloat>;
/// Elements is either a vector of AInts or AFloats
using Elements = std::variant<AInts, AFloats>;
/// Helper that resolves to either AInt or AFloat based on the element type T.
template <typename T>
using ElementFor = std::conditional_t<IsFloatingPoint<UnwrapNumber<T>>, AFloat, AInt>;
/// Helper that resolves to either AInts or AFloats based on the element type T.
template <typename T>
using ElementVectorFor = std::conditional_t<IsFloatingPoint<UnwrapNumber<T>>, AFloats, AInts>;
/// Constructs an invalid Constant
Constant(); Constant();
/// Constructs a Constant of the given type and element values
/// @param ty the Constant type
/// @param els the Constant element values
Constant(const sem::Type* ty, Elements els);
/// Constructs a Constant of the given type and element values
/// @param ty the Constant type
/// @param vec the Constant element values
Constant(const sem::Type* ty, AInts vec);
/// Constructs a Constant of the given type and element values
/// @param ty the Constant type
/// @param vec the Constant element values
Constant(const sem::Type* ty, AFloats vec);
/// Constructs a Constant of the given type and element values
/// @param ty the Constant type
/// @param els the Constant element values
template <typename T>
Constant(const sem::Type* ty, std::initializer_list<T> els);
/// Copy constructor
Constant(const Constant&);
/// Destructor /// Destructor
~Constant(); virtual ~Constant();
/// Copy assignment /// @returns the type of the constant
/// @param other the Constant to copy virtual const sem::Type* Type() const = 0;
/// @returns this Constant
Constant& operator=(const Constant& other);
/// @returns true if the Constant has been initialized /// @returns the value of this Constant, if this constant is of a scalar value or abstract
bool IsValid() const { return type_ != nullptr; } /// numeric, otherwsie std::monostate.
virtual std::variant<std::monostate, AInt, AFloat> Value() const = 0;
/// @return true if the Constant has been initialized /// @returns the child constant element with the given index, or nullptr if the constant has no
operator bool() const { return IsValid(); } /// children, or the index is out of bounds.
virtual const Constant* Index(size_t) const = 0;
/// @returns the type of the Constant /// @returns true if child elements of this constant are positive-zero valued.
const sem::Type* Type() const { return type_; } virtual bool AllZero() const = 0;
/// @returns the number of elements /// @returns true if any child elements of this constant are positive-zero valued.
size_t ElementCount() const { virtual bool AnyZero() const = 0;
return std::visit([](auto&& v) { return v.size(); }, elems_);
}
/// @returns the flattened element type of the Constant /// @returns true if all child elements of this constant have the same value and type.
const sem::Type* ElementType() const { return elem_type_; } virtual bool AllEqual() const = 0;
/// @returns the constant's flattened elements /// @returns a hash of the constant.
const Elements& GetElements() const { return elems_; } virtual size_t Hash() const = 0;
/// WithElements calls the function `f` with the vector of elements as either AFloats or AInts /// @returns the value of the constant as the given scalar or abstract value.
/// @param f a function-like with the signature `R(auto&&)`.
/// @returns the result of calling `f`.
template <typename F>
auto WithElements(F&& f) const {
return std::visit(std::forward<F>(f), elems_);
}
/// WithElements calls the function `f` with the element vector as either AFloats or AInts
/// @param f a function-like with the signature `R(auto&&)`.
/// @returns the result of calling `f`.
template <typename F>
auto WithElements(F&& f) {
return std::visit(std::forward<F>(f), elems_);
}
/// @returns the elements as a vector of AInt
inline const AInts& IElements() const { return std::get<AInts>(elems_); }
/// @returns the elements as a vector of AFloat
inline const AFloats& FElements() const { return std::get<AFloats>(elems_); }
/// @returns true if any element is positive zero
bool AnyZero() const;
/// @returns true if all elements are positive zero
bool AllZero() const { return AllZero(0, ElementCount()); }
/// @returns true if all elements are the same value, with the same sign-bit.
bool AllEqual() const { return AllEqual(0, ElementCount()); }
/// @param start the first element index
/// @param end one past the last element index
/// @returns true if all elements between `[start, end)` are zero
bool AllZero(size_t start, size_t end) const;
/// @param start the first element index
/// @param end one past the last element index
/// @returns true if all elements between `[start, end)` are the same value
bool AllEqual(size_t start, size_t end) const;
/// @param index the index of the element
/// @return the element at `index`, which must be of type `T`.
template <typename T> template <typename T>
T Element(size_t index) const; T As() const {
return std::visit(
private: [](auto v) {
/// Checks that the provided type matches the number of expected elements. if constexpr (std::is_same_v<decltype(v), std::monostate>) {
/// @returns the element type of `ty`. return T(0);
const sem::Type* CheckElemType(const sem::Type* ty, size_t num_elements);
const sem::Type* type_ = nullptr;
const sem::Type* elem_type_ = nullptr;
Elements elems_;
};
template <typename T>
Constant::Constant(const sem::Type* ty, std::initializer_list<T> els)
: type_(ty), elem_type_(CheckElemType(type_, els.size())) {
ElementVectorFor<T> elements;
elements.reserve(els.size());
for (auto el : els) {
elements.emplace_back(ElementFor<T>(el));
}
elems_ = Elements{std::move(elements)};
}
template <typename T>
T Constant::Element(size_t index) const {
if constexpr (std::is_same_v<ElementVectorFor<T>, AFloats>) {
return static_cast<T>(FElements()[index].value);
} else { } else {
return static_cast<T>(IElements()[index].value); return static_cast<T>(v);
} }
} },
Value());
}
};
} // namespace tint::sem } // namespace tint::sem

View File

@ -1,540 +0,0 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/sem/constant.h"
#include <gmock/gmock.h>
#include "src/tint/sem/abstract_float.h"
#include "src/tint/sem/abstract_int.h"
#include "src/tint/sem/test_helper.h"
using namespace tint::number_suffixes; // NOLINT
namespace tint::sem {
namespace {
struct ConstantTest : public TestHelper {
const sem::Array* Array(uint32_t n, const sem::Type* el_ty) {
return create<sem::Array>(el_ty,
/* count */ n,
/* align */ 16u,
/* size */ 4u * n,
/* stride */ 16u * n,
/* implicit_stride */ 16u * n);
}
};
TEST_F(ConstantTest, ConstructorInitializerList) {
{
auto i = AInt(AInt::kHighest);
Constant c(create<AbstractInt>(), {i});
c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(i)); });
}
{
auto i = i32(i32::kHighest);
Constant c(create<I32>(), {i});
c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(i)); });
}
{
auto i = u32(u32::kHighest);
Constant c(create<U32>(), {i});
c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(i)); });
}
{
Constant c(create<Bool>(), {false});
c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(0_a)); });
}
{
Constant c(create<Bool>(), {true});
c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1_a)); });
}
{
auto f = AFloat(AFloat::kHighest);
Constant c(create<AbstractFloat>(), {f});
c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(f)); });
}
{
auto f = f32(f32::kHighest);
Constant c(create<F32>(), {f});
c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(f)); });
}
{
auto f = f16(f16::kHighest);
Constant c(create<F16>(), {f});
c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(f)); });
}
}
TEST_F(ConstantTest, Element_ai) {
auto* ty = create<AbstractInt>();
Constant c(ty, {1_a});
EXPECT_EQ(c.Element<AInt>(0), 1_a);
EXPECT_EQ(c.ElementCount(), 1u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), ty);
}
TEST_F(ConstantTest, Element_i32) {
auto* ty = create<I32>();
Constant c(ty, {1_a});
EXPECT_EQ(c.Element<i32>(0), 1_i);
EXPECT_EQ(c.ElementCount(), 1u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), ty);
}
TEST_F(ConstantTest, Element_u32) {
auto* ty = create<U32>();
Constant c(ty, {1_a});
EXPECT_EQ(c.Element<u32>(0), 1_u);
EXPECT_EQ(c.ElementCount(), 1u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), ty);
}
TEST_F(ConstantTest, Element_bool) {
auto* ty = create<Bool>();
Constant c(ty, {true});
EXPECT_EQ(c.Element<bool>(0), true);
EXPECT_EQ(c.ElementCount(), 1u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), ty);
}
TEST_F(ConstantTest, Element_af) {
auto* ty = create<AbstractFloat>();
Constant c(ty, {1.0_a});
EXPECT_EQ(c.Element<AFloat>(0), 1.0_a);
EXPECT_EQ(c.ElementCount(), 1u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), ty);
}
TEST_F(ConstantTest, Element_f32) {
auto* ty = create<F32>();
Constant c(ty, {1.0_a});
EXPECT_EQ(c.Element<f32>(0), 1.0_f);
EXPECT_EQ(c.ElementCount(), 1u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), ty);
}
TEST_F(ConstantTest, Element_f16) {
auto* ty = create<F16>();
Constant c(ty, {1.0_a});
EXPECT_EQ(c.Element<f16>(0), 1.0_h);
EXPECT_EQ(c.ElementCount(), 1u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), ty);
}
TEST_F(ConstantTest, Element_vec3_ai) {
auto* el_ty = create<AbstractInt>();
auto* ty = create<Vector>(el_ty, 3u);
Constant c(ty, {1_a, 2_a, 3_a});
EXPECT_EQ(c.Element<AInt>(0), 1_a);
EXPECT_EQ(c.Element<AInt>(1), 2_a);
EXPECT_EQ(c.Element<AInt>(2), 3_a);
EXPECT_EQ(c.ElementCount(), 3u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_vec3_i32) {
auto* el_ty = create<I32>();
auto* ty = create<Vector>(el_ty, 3u);
Constant c(ty, {1_a, 2_a, 3_a});
EXPECT_EQ(c.Element<i32>(0), 1_i);
EXPECT_EQ(c.Element<i32>(1), 2_i);
EXPECT_EQ(c.Element<i32>(2), 3_i);
EXPECT_EQ(c.ElementCount(), 3u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_vec3_u32) {
auto* el_ty = create<U32>();
auto* ty = create<Vector>(el_ty, 3u);
Constant c(ty, {1_a, 2_a, 3_a});
EXPECT_EQ(c.Element<u32>(0), 1_u);
EXPECT_EQ(c.Element<u32>(1), 2_u);
EXPECT_EQ(c.Element<u32>(2), 3_u);
EXPECT_EQ(c.ElementCount(), 3u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_vec3_bool) {
auto* el_ty = create<Bool>();
auto* ty = create<Vector>(el_ty, 2u);
Constant c(ty, {true, false});
EXPECT_EQ(c.Element<bool>(0), true);
EXPECT_EQ(c.Element<bool>(1), false);
EXPECT_EQ(c.ElementCount(), 2u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_vec3_af) {
auto* el_ty = create<AbstractFloat>();
auto* ty = create<Vector>(el_ty, 3u);
Constant c(ty, {1.0_a, 2.0_a, 3.0_a});
EXPECT_EQ(c.Element<AFloat>(0), 1.0_a);
EXPECT_EQ(c.Element<AFloat>(1), 2.0_a);
EXPECT_EQ(c.Element<AFloat>(2), 3.0_a);
EXPECT_EQ(c.ElementCount(), 3u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_vec3_f32) {
auto* el_ty = create<F32>();
auto* ty = create<Vector>(el_ty, 3u);
Constant c(ty, {1.0_a, 2.0_a, 3.0_a});
EXPECT_EQ(c.Element<f32>(0), 1.0_f);
EXPECT_EQ(c.Element<f32>(1), 2.0_f);
EXPECT_EQ(c.Element<f32>(2), 3.0_f);
EXPECT_EQ(c.ElementCount(), 3u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_vec3_f16) {
auto* el_ty = create<F16>();
auto* ty = create<Vector>(el_ty, 3u);
Constant c(ty, {1.0_a, 2.0_a, 3.0_a});
EXPECT_EQ(c.Element<f16>(0), 1.0_h);
EXPECT_EQ(c.Element<f16>(1), 2.0_h);
EXPECT_EQ(c.Element<f16>(2), 3.0_h);
EXPECT_EQ(c.ElementCount(), 3u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_mat2x3_af) {
auto* el_ty = create<AbstractFloat>();
auto* ty = create<Matrix>(create<Vector>(el_ty, 3u), 2u);
Constant c(ty, {1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a});
EXPECT_EQ(c.Element<AFloat>(0), 1.0_a);
EXPECT_EQ(c.Element<AFloat>(1), 2.0_a);
EXPECT_EQ(c.Element<AFloat>(2), 3.0_a);
EXPECT_EQ(c.Element<AFloat>(3), 4.0_a);
EXPECT_EQ(c.Element<AFloat>(4), 5.0_a);
EXPECT_EQ(c.Element<AFloat>(5), 6.0_a);
EXPECT_EQ(c.ElementCount(), 6u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_mat2x3_f32) {
auto* el_ty = create<F32>();
auto* ty = create<Matrix>(create<Vector>(el_ty, 3u), 2u);
Constant c(ty, {1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a});
EXPECT_EQ(c.Element<f32>(0), 1.0_f);
EXPECT_EQ(c.Element<f32>(1), 2.0_f);
EXPECT_EQ(c.Element<f32>(2), 3.0_f);
EXPECT_EQ(c.Element<f32>(3), 4.0_f);
EXPECT_EQ(c.Element<f32>(4), 5.0_f);
EXPECT_EQ(c.Element<f32>(5), 6.0_f);
EXPECT_EQ(c.ElementCount(), 6u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_mat2x3_f16) {
auto* el_ty = create<F16>();
auto* ty = create<Matrix>(create<Vector>(el_ty, 3u), 2u);
Constant c(ty, {1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a});
EXPECT_EQ(c.Element<f16>(0), 1.0_h);
EXPECT_EQ(c.Element<f16>(1), 2.0_h);
EXPECT_EQ(c.Element<f16>(2), 3.0_h);
EXPECT_EQ(c.Element<f16>(3), 4.0_h);
EXPECT_EQ(c.Element<f16>(4), 5.0_h);
EXPECT_EQ(c.Element<f16>(5), 6.0_h);
EXPECT_EQ(c.ElementCount(), 6u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_arr_vec3_ai) {
auto* el_ty = create<AbstractInt>();
auto* ty = Array(2, create<Vector>(el_ty, 3u));
Constant c(ty, {1_a, 2_a, 3_a, 4_a, 5_a, 6_a});
EXPECT_EQ(c.Element<AInt>(0), 1_a);
EXPECT_EQ(c.Element<AInt>(1), 2_a);
EXPECT_EQ(c.Element<AInt>(2), 3_a);
EXPECT_EQ(c.Element<AInt>(3), 4_a);
EXPECT_EQ(c.Element<AInt>(4), 5_a);
EXPECT_EQ(c.Element<AInt>(5), 6_a);
EXPECT_EQ(c.ElementCount(), 6u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_arr_vec3_i32) {
auto* el_ty = create<I32>();
auto* ty = Array(2, create<Vector>(el_ty, 3u));
Constant c(ty, {1_a, 2_a, 3_a, 4_a, 5_a, 6_a});
EXPECT_EQ(c.Element<i32>(0), 1_i);
EXPECT_EQ(c.Element<i32>(1), 2_i);
EXPECT_EQ(c.Element<i32>(2), 3_i);
EXPECT_EQ(c.Element<i32>(3), 4_i);
EXPECT_EQ(c.Element<i32>(4), 5_i);
EXPECT_EQ(c.Element<i32>(5), 6_i);
EXPECT_EQ(c.ElementCount(), 6u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_arr_vec3_u32) {
auto* el_ty = create<U32>();
auto* ty = Array(2, create<Vector>(el_ty, 3u));
Constant c(ty, {1_a, 2_a, 3_a, 4_a, 5_a, 6_a});
EXPECT_EQ(c.Element<u32>(0), 1_u);
EXPECT_EQ(c.Element<u32>(1), 2_u);
EXPECT_EQ(c.Element<u32>(2), 3_u);
EXPECT_EQ(c.Element<u32>(3), 4_u);
EXPECT_EQ(c.Element<u32>(4), 5_u);
EXPECT_EQ(c.Element<u32>(5), 6_u);
EXPECT_EQ(c.ElementCount(), 6u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_arr_vec3_bool) {
auto* el_ty = create<Bool>();
auto* ty = Array(2, create<Vector>(el_ty, 2u));
Constant c(ty, {true, false, false, true});
EXPECT_EQ(c.Element<bool>(0), true);
EXPECT_EQ(c.Element<bool>(1), false);
EXPECT_EQ(c.Element<bool>(2), false);
EXPECT_EQ(c.Element<bool>(3), true);
EXPECT_EQ(c.ElementCount(), 4u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_arr_vec3_af) {
auto* el_ty = create<AbstractFloat>();
auto* ty = Array(2, create<Vector>(el_ty, 3u));
Constant c(ty, {1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a});
EXPECT_EQ(c.Element<AFloat>(0), 1.0_a);
EXPECT_EQ(c.Element<AFloat>(1), 2.0_a);
EXPECT_EQ(c.Element<AFloat>(2), 3.0_a);
EXPECT_EQ(c.Element<AFloat>(3), 4.0_a);
EXPECT_EQ(c.Element<AFloat>(4), 5.0_a);
EXPECT_EQ(c.Element<AFloat>(5), 6.0_a);
EXPECT_EQ(c.ElementCount(), 6u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_arr_vec3_f32) {
auto* el_ty = create<F32>();
auto* ty = Array(2, create<Vector>(el_ty, 3u));
Constant c(ty, {1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a});
EXPECT_EQ(c.Element<f32>(0), 1.0_f);
EXPECT_EQ(c.Element<f32>(1), 2.0_f);
EXPECT_EQ(c.Element<f32>(2), 3.0_f);
EXPECT_EQ(c.Element<f32>(3), 4.0_f);
EXPECT_EQ(c.Element<f32>(4), 5.0_f);
EXPECT_EQ(c.Element<f32>(5), 6.0_f);
EXPECT_EQ(c.ElementCount(), 6u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_arr_vec3_f16) {
auto* el_ty = create<F16>();
auto* ty = Array(2, create<Vector>(el_ty, 3u));
Constant c(ty, {1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a});
EXPECT_EQ(c.Element<f16>(0), 1.0_h);
EXPECT_EQ(c.Element<f16>(1), 2.0_h);
EXPECT_EQ(c.Element<f16>(2), 3.0_h);
EXPECT_EQ(c.Element<f16>(3), 4.0_h);
EXPECT_EQ(c.Element<f16>(4), 5.0_h);
EXPECT_EQ(c.Element<f16>(5), 6.0_h);
EXPECT_EQ(c.ElementCount(), 6u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, Element_arr_arr_mat2x3_f32) {
auto* el_ty = create<F32>();
auto* ty = Array(2, Array(2, create<Matrix>(create<Vector>(el_ty, 3u), 2u)));
Constant c(ty, {
1.0_a, 2.0_a, 3.0_a, //
4.0_a, 5.0_a, 6.0_a, //
7.0_a, 8.0_a, 9.0_a, //
10.0_a, 11.0_a, 12.0_a, //
13.0_a, 14.0_a, 15.0_a, //
16.0_a, 17.0_a, 18.0_a, //
19.0_a, 20.0_a, 21.0_a, //
22.0_a, 23.0_a, 24.0_a, //
});
for (size_t i = 0; i < 24; i++) {
EXPECT_EQ(c.Element<f32>(i), f32(i + 1));
}
EXPECT_EQ(c.ElementCount(), 24u);
EXPECT_TYPE(c.Type(), ty);
EXPECT_TYPE(c.ElementType(), el_ty);
}
TEST_F(ConstantTest, AnyZero) {
auto* vec3_ai = create<Vector>(create<AbstractInt>(), 3u);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AnyZero(), false);
EXPECT_EQ(Constant(vec3_ai, {0_a, 2_a, 3_a}).AnyZero(), true);
EXPECT_EQ(Constant(vec3_ai, {1_a, 0_a, 3_a}).AnyZero(), true);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 0_a}).AnyZero(), true);
EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 0_a}).AnyZero(), true);
auto* vec3_af = create<Vector>(create<AbstractFloat>(), 3u);
EXPECT_EQ(Constant(vec3_af, {1._a, 2._a, 3._a}).AnyZero(), false);
EXPECT_EQ(Constant(vec3_af, {0._a, 2._a, 3._a}).AnyZero(), true);
EXPECT_EQ(Constant(vec3_af, {1._a, 0._a, 3._a}).AnyZero(), true);
EXPECT_EQ(Constant(vec3_af, {1._a, 2._a, 0._a}).AnyZero(), true);
EXPECT_EQ(Constant(vec3_af, {0._a, 0._a, 0._a}).AnyZero(), true);
EXPECT_EQ(Constant(vec3_af, {1._a, -2._a, 3._a}).AnyZero(), false);
EXPECT_EQ(Constant(vec3_af, {0._a, -2._a, 3._a}).AnyZero(), true);
EXPECT_EQ(Constant(vec3_af, {1._a, -0._a, 3._a}).AnyZero(), false);
EXPECT_EQ(Constant(vec3_af, {1._a, -2._a, 0._a}).AnyZero(), true);
EXPECT_EQ(Constant(vec3_af, {0._a, -0._a, 0._a}).AnyZero(), true);
EXPECT_EQ(Constant(vec3_af, {-0._a, -0._a, -0._a}).AnyZero(), false);
}
TEST_F(ConstantTest, AllZero) {
auto* vec3_ai = create<Vector>(create<AbstractInt>(), 3u);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_ai, {0_a, 2_a, 3_a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 0_a, 3_a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 0_a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 0_a}).AllZero(), true);
auto* vec3_af = create<Vector>(create<AbstractFloat>(), 3u);
EXPECT_EQ(Constant(vec3_af, {1._a, 2._a, 3._a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_af, {0._a, 2._a, 3._a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 0._a, 3._a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 2._a, 0._a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_af, {0._a, 0._a, 0._a}).AllZero(), true);
EXPECT_EQ(Constant(vec3_af, {1._a, -2._a, 3._a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_af, {0._a, -2._a, 3._a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_af, {1._a, -0._a, 3._a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_af, {1._a, -2._a, 0._a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_af, {0._a, -0._a, 0._a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_af, {-0._a, -0._a, -0._a}).AllZero(), false);
}
TEST_F(ConstantTest, AllEqual) {
auto* vec3_ai = create<Vector>(create<AbstractInt>(), 3u);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllEqual(), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 3_a}).AllEqual(), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 3_a, 3_a}).AllEqual(), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 1_a}).AllEqual(), true);
EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 2_a}).AllEqual(), true);
EXPECT_EQ(Constant(vec3_ai, {3_a, 3_a, 3_a}).AllEqual(), true);
EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 0_a}).AllEqual(), true);
auto* vec3_af = create<Vector>(create<AbstractFloat>(), 3u);
EXPECT_EQ(Constant(vec3_af, {1._a, 2._a, 3._a}).AllEqual(), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 1._a, 3._a}).AllEqual(), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 3._a, 3._a}).AllEqual(), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 1._a, 1._a}).AllEqual(), true);
EXPECT_EQ(Constant(vec3_af, {2._a, 2._a, 2._a}).AllEqual(), true);
EXPECT_EQ(Constant(vec3_af, {3._a, 3._a, 3._a}).AllEqual(), true);
EXPECT_EQ(Constant(vec3_af, {0._a, 0._a, 0._a}).AllEqual(), true);
EXPECT_EQ(Constant(vec3_af, {0._a, -0._a, 0._a}).AllEqual(), false);
}
TEST_F(ConstantTest, AllZeroRange) {
auto* vec3_ai = create<Vector>(create<AbstractInt>(), 3u);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_ai, {0_a, 2_a, 3_a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 0_a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 3_a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_ai, {0_a, 2_a, 0_a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 0_a, 0_a}).AllZero(1, 3), true);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllZero(0, 2), false);
EXPECT_EQ(Constant(vec3_ai, {0_a, 2_a, 3_a}).AllZero(0, 2), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllZero(0, 2), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 0_a}).AllZero(0, 2), false);
EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 3_a}).AllZero(0, 2), true);
EXPECT_EQ(Constant(vec3_ai, {0_a, 2_a, 0_a}).AllZero(0, 2), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 0_a, 0_a}).AllZero(0, 2), false);
auto* vec3_af = create<Vector>(create<AbstractFloat>(), 3u);
EXPECT_EQ(Constant(vec3_af, {1._a, 2._a, 3._a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {0._a, 2._a, 3._a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 2._a, 3._a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 2._a, 0._a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {0._a, 0._a, 3._a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {0._a, 2._a, 0._a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 0._a, 0._a}).AllZero(1, 3), true);
EXPECT_EQ(Constant(vec3_af, {1._a, -0._a, 0._a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 0._a, -0._a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {1._a, -0._a, -0._a}).AllZero(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 2._a, 3._a}).AllZero(0, 2), false);
EXPECT_EQ(Constant(vec3_af, {0._a, 2._a, 3._a}).AllZero(0, 2), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 2._a, 3._a}).AllZero(0, 2), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 2._a, 0._a}).AllZero(0, 2), false);
EXPECT_EQ(Constant(vec3_af, {0._a, 0._a, 3._a}).AllZero(0, 2), true);
EXPECT_EQ(Constant(vec3_af, {-0._a, 0._a, 1._a}).AllZero(0, 2), false);
EXPECT_EQ(Constant(vec3_af, {0._a, -0._a, 1._a}).AllZero(0, 2), false);
EXPECT_EQ(Constant(vec3_af, {-0._a, -0._a, 1._a}).AllZero(0, 2), false);
EXPECT_EQ(Constant(vec3_af, {0._a, 2._a, 0._a}).AllZero(0, 2), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 0._a, 0._a}).AllZero(0, 2), false);
}
TEST_F(ConstantTest, AllEqualRange) {
auto* vec3_ai = create<Vector>(create<AbstractInt>(), 3u);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 3_a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 3_a, 3_a}).AllEqual(1, 3), true);
EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 1_a}).AllEqual(1, 3), true);
EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 2_a}).AllEqual(1, 3), true);
EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 3_a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 0_a, 0_a}).AllEqual(1, 3), true);
EXPECT_EQ(Constant(vec3_ai, {0_a, 1_a, 0_a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 1_a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 0_a}).AllEqual(1, 3), true);
auto* vec3_af = create<Vector>(create<AbstractFloat>(), 3u);
EXPECT_EQ(Constant(vec3_af, {1._a, 2._a, 3._a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 1._a, 3._a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 3._a, 3._a}).AllEqual(1, 3), true);
EXPECT_EQ(Constant(vec3_af, {1._a, 1._a, 1._a}).AllEqual(1, 3), true);
EXPECT_EQ(Constant(vec3_af, {2._a, 2._a, 2._a}).AllEqual(1, 3), true);
EXPECT_EQ(Constant(vec3_af, {2._a, 2._a, 3._a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {1._a, 0._a, 0._a}).AllEqual(1, 3), true);
EXPECT_EQ(Constant(vec3_af, {0._a, 1._a, 0._a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {0._a, 0._a, 1._a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {0._a, 0._a, 0._a}).AllEqual(1, 3), true);
EXPECT_EQ(Constant(vec3_af, {1._a, -0._a, 0._a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {0._a, -1._a, 0._a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {0._a, -0._a, 1._a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {0._a, -0._a, 0._a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_af, {0._a, -0._a, -0._a}).AllEqual(1, 3), true);
EXPECT_EQ(Constant(vec3_af, {-0._a, -0._a, -0._a}).AllEqual(1, 3), true);
}
} // namespace
} // namespace tint::sem

View File

@ -25,7 +25,7 @@ namespace tint::sem {
Expression::Expression(const ast::Expression* declaration, Expression::Expression(const ast::Expression* declaration,
const sem::Type* type, const sem::Type* type,
const Statement* statement, const Statement* statement,
Constant constant, const Constant* constant,
bool has_side_effects, bool has_side_effects,
const Variable* source_var /* = nullptr */) const Variable* source_var /* = nullptr */)
: declaration_(declaration), : declaration_(declaration),

View File

@ -41,7 +41,7 @@ class Expression : public Castable<Expression, Node> {
Expression(const ast::Expression* declaration, Expression(const ast::Expression* declaration,
const sem::Type* type, const sem::Type* type,
const Statement* statement, const Statement* statement,
Constant constant, const Constant* constant,
bool has_side_effects, bool has_side_effects,
const Variable* source_var = nullptr); const Variable* source_var = nullptr);
@ -58,7 +58,7 @@ class Expression : public Castable<Expression, Node> {
const Statement* Stmt() const { return statement_; } const Statement* Stmt() const { return statement_; }
/// @return the constant value of this expression /// @return the constant value of this expression
const Constant& ConstantValue() const { return constant_; } const Constant* ConstantValue() const { return constant_; }
/// Returns the variable or parameter that this expression derives from. /// Returns the variable or parameter that this expression derives from.
/// For reference and pointer expressions, this will either be the originating /// For reference and pointer expressions, this will either be the originating
@ -88,7 +88,7 @@ class Expression : public Castable<Expression, Node> {
private: private:
const sem::Type* const type_; const sem::Type* const type_;
const Statement* const statement_; const Statement* const statement_;
const Constant constant_; const Constant* const constant_;
sem::Behaviors behaviors_{sem::Behavior::kNext}; sem::Behaviors behaviors_{sem::Behavior::kNext};
const bool has_side_effects_; const bool has_side_effects_;
}; };

View File

@ -23,13 +23,30 @@ using namespace tint::number_suffixes; // NOLINT
namespace tint::sem { namespace tint::sem {
namespace { namespace {
class MockConstant : public sem::Constant {
public:
explicit MockConstant(const sem::Type* ty) : type(ty) {}
~MockConstant() override {}
const sem::Type* Type() const override { return type; }
std::variant<std::monostate, AInt, AFloat> Value() const override { return {}; }
const Constant* Index(size_t) const override { return {}; }
bool AllZero() const override { return {}; }
bool AnyZero() const override { return {}; }
bool AllEqual() const override { return {}; }
size_t Hash() const override { return 0; }
private:
const sem::Type* type;
};
using ExpressionTest = TestHelper; using ExpressionTest = TestHelper;
TEST_F(ExpressionTest, UnwrapMaterialize) { TEST_F(ExpressionTest, UnwrapMaterialize) {
MockConstant c(create<I32>());
auto* a = create<Expression>(/* declaration */ nullptr, create<I32>(), /* statement */ nullptr, auto* a = create<Expression>(/* declaration */ nullptr, create<I32>(), /* statement */ nullptr,
Constant{}, /* constant_value */ nullptr,
/* has_side_effects */ false, /* source_var */ nullptr); /* has_side_effects */ false, /* source_var */ nullptr);
auto* b = create<Materialize>(a, /* statement */ nullptr, Constant{create<I32>(), {1_a}}); auto* b = create<Materialize>(a, /* statement */ nullptr, &c);
EXPECT_EQ(a, a->UnwrapMaterialize()); EXPECT_EQ(a, a->UnwrapMaterialize());
EXPECT_EQ(a, b->UnwrapMaterialize()); EXPECT_EQ(a, b->UnwrapMaterialize());

View File

@ -27,7 +27,7 @@ IndexAccessorExpression::IndexAccessorExpression(const ast::IndexAccessorExpress
const Expression* object, const Expression* object,
const Expression* index, const Expression* index,
const Statement* statement, const Statement* statement,
Constant constant, const Constant* constant,
bool has_side_effects, bool has_side_effects,
const Variable* source_var /* = nullptr */) const Variable* source_var /* = nullptr */)
: Base(declaration, type, statement, constant, has_side_effects, source_var), : Base(declaration, type, statement, constant, has_side_effects, source_var),

View File

@ -43,7 +43,7 @@ class IndexAccessorExpression final : public Castable<IndexAccessorExpression, E
const Expression* object, const Expression* object,
const Expression* index, const Expression* index,
const Statement* statement, const Statement* statement,
Constant constant, const Constant* constant,
bool has_side_effects, bool has_side_effects,
const Variable* source_var = nullptr); const Variable* source_var = nullptr);

View File

@ -17,19 +17,16 @@
TINT_INSTANTIATE_TYPEINFO(tint::sem::Materialize); TINT_INSTANTIATE_TYPEINFO(tint::sem::Materialize);
namespace tint::sem { namespace tint::sem {
Materialize::Materialize(const Expression* expr,
Materialize::Materialize(const Expression* expr, const Statement* statement, Constant constant) const Statement* statement,
const Constant* constant)
: Base(/* declaration */ expr->Declaration(), : Base(/* declaration */ expr->Declaration(),
/* type */ constant.Type(), /* type */ constant->Type(),
/* statement */ statement, /* statement */ statement,
/* constant */ constant, /* constant */ constant,
/* has_side_effects */ false, /* has_side_effects */ false,
/* source_var */ expr->SourceVariable()), /* source_var */ expr->SourceVariable()),
expr_(expr) { expr_(expr) {}
// Materialize nodes only wrap compile-time expressions, and so the Materialize expression must
// have a constant value.
TINT_ASSERT(Semantic, constant.IsValid());
}
Materialize::~Materialize() = default; Materialize::~Materialize() = default;

View File

@ -31,7 +31,7 @@ class Materialize final : public Castable<Materialize, Expression> {
/// @param expr the inner expression, being materialized /// @param expr the inner expression, being materialized
/// @param statement the statement that owns this expression /// @param statement the statement that owns this expression
/// @param constant the constant value of this expression /// @param constant the constant value of this expression
Materialize(const Expression* expr, const Statement* statement, Constant constant); Materialize(const Expression* expr, const Statement* statement, const Constant* constant);
/// Destructor /// Destructor
~Materialize() override; ~Materialize() override;

View File

@ -29,8 +29,7 @@ MemberAccessorExpression::MemberAccessorExpression(const ast::MemberAccessorExpr
const Expression* object, const Expression* object,
bool has_side_effects, bool has_side_effects,
const Variable* source_var /* = nullptr */) const Variable* source_var /* = nullptr */)
: Base(declaration, type, statement, Constant{}, has_side_effects, source_var), : Base(declaration, type, statement, nullptr, has_side_effects, source_var), object_(object) {}
object_(object) {}
MemberAccessorExpression::~MemberAccessorExpression() = default; MemberAccessorExpression::~MemberAccessorExpression() = default;

View File

@ -33,7 +33,7 @@ Variable::Variable(const ast::Variable* declaration,
const sem::Type* type, const sem::Type* type,
ast::StorageClass storage_class, ast::StorageClass storage_class,
ast::Access access, ast::Access access,
Constant constant_value) const Constant* constant_value)
: declaration_(declaration), : declaration_(declaration),
type_(type), type_(type),
storage_class_(storage_class), storage_class_(storage_class),
@ -47,9 +47,8 @@ LocalVariable::LocalVariable(const ast::Variable* declaration,
ast::StorageClass storage_class, ast::StorageClass storage_class,
ast::Access access, ast::Access access,
const sem::Statement* statement, const sem::Statement* statement,
Constant constant_value) const Constant* constant_value)
: Base(declaration, type, storage_class, access, std::move(constant_value)), : Base(declaration, type, storage_class, access, constant_value), statement_(statement) {}
statement_(statement) {}
LocalVariable::~LocalVariable() = default; LocalVariable::~LocalVariable() = default;
@ -57,9 +56,9 @@ GlobalVariable::GlobalVariable(const ast::Variable* declaration,
const sem::Type* type, const sem::Type* type,
ast::StorageClass storage_class, ast::StorageClass storage_class,
ast::Access access, ast::Access access,
Constant constant_value, const Constant* constant_value,
sem::BindingPoint binding_point) sem::BindingPoint binding_point)
: Base(declaration, type, storage_class, access, std::move(constant_value)), : Base(declaration, type, storage_class, access, constant_value),
binding_point_(binding_point) {} binding_point_(binding_point) {}
GlobalVariable::~GlobalVariable() = default; GlobalVariable::~GlobalVariable() = default;
@ -70,7 +69,7 @@ Parameter::Parameter(const ast::Parameter* declaration,
ast::StorageClass storage_class, ast::StorageClass storage_class,
ast::Access access, ast::Access access,
const ParameterUsage usage /* = ParameterUsage::kNone */) const ParameterUsage usage /* = ParameterUsage::kNone */)
: Base(declaration, type, storage_class, access, Constant{}), index_(index), usage_(usage) {} : Base(declaration, type, storage_class, access, nullptr), index_(index), usage_(usage) {}
Parameter::~Parameter() = default; Parameter::~Parameter() = default;

View File

@ -52,7 +52,7 @@ class Variable : public Castable<Variable, Node> {
const sem::Type* type, const sem::Type* type,
ast::StorageClass storage_class, ast::StorageClass storage_class,
ast::Access access, ast::Access access,
Constant constant_value); const Constant* constant_value);
/// Destructor /// Destructor
~Variable() override; ~Variable() override;
@ -70,7 +70,7 @@ class Variable : public Castable<Variable, Node> {
ast::Access Access() const { return access_; } ast::Access Access() const { return access_; }
/// @return the constant value of this expression /// @return the constant value of this expression
const Constant& ConstantValue() const { return constant_value_; } const Constant* ConstantValue() const { return constant_value_; }
/// @returns the variable constructor expression, or nullptr if the variable /// @returns the variable constructor expression, or nullptr if the variable
/// does not have one. /// does not have one.
@ -91,7 +91,7 @@ class Variable : public Castable<Variable, Node> {
const sem::Type* const type_; const sem::Type* const type_;
const ast::StorageClass storage_class_; const ast::StorageClass storage_class_;
const ast::Access access_; const ast::Access access_;
const Constant constant_value_; const Constant* constant_value_;
const Expression* constructor_ = nullptr; const Expression* constructor_ = nullptr;
std::vector<const VariableUser*> users_; std::vector<const VariableUser*> users_;
}; };
@ -111,7 +111,7 @@ class LocalVariable final : public Castable<LocalVariable, Variable> {
ast::StorageClass storage_class, ast::StorageClass storage_class,
ast::Access access, ast::Access access,
const sem::Statement* statement, const sem::Statement* statement,
Constant constant_value); const Constant* constant_value);
/// Destructor /// Destructor
~LocalVariable() override; ~LocalVariable() override;
@ -145,7 +145,7 @@ class GlobalVariable final : public Castable<GlobalVariable, Variable> {
const sem::Type* type, const sem::Type* type,
ast::StorageClass storage_class, ast::StorageClass storage_class,
ast::Access access, ast::Access access,
Constant constant_value, const Constant* constant_value,
sem::BindingPoint binding_point = {}); sem::BindingPoint binding_point = {});
/// Destructor /// Destructor

View File

@ -46,7 +46,7 @@ class LocalizeStructArrayAssignment::State {
expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) { expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
// Indexing using a runtime value? // Indexing using a runtime value?
auto* idx_sem = ctx.src->Sem().Get(ia->index); auto* idx_sem = ctx.src->Sem().Get(ia->index);
if (!idx_sem->ConstantValue().IsValid()) { if (!idx_sem->ConstantValue()) {
// Indexing a member access expr? // Indexing a member access expr?
if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) { if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
// That accesses an array? // That accesses an array?

View File

@ -275,7 +275,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
if (auto* sem_e = sem.Get(e)) { if (auto* sem_e = sem.Get(e)) {
if (auto* var_user = sem_e->As<sem::VariableUser>()) { if (auto* var_user = sem_e->As<sem::VariableUser>()) {
// Don't hoist constants. // Don't hoist constants.
if (var_user->ConstantValue().IsValid()) { if (var_user->ConstantValue()) {
return false; return false;
} }
// Don't hoist read-only variables as they cannot receive // Don't hoist read-only variables as they cannot receive

View File

@ -120,17 +120,18 @@ struct Robustness::State {
return nullptr; return nullptr;
} }
if (auto idx_constant = idx_sem->ConstantValue()) { if (auto* idx_constant = idx_sem->ConstantValue()) {
// Constant value index // Constant value index
if (idx_constant.Type()->Is<sem::I32>()) { auto val = std::get<AInt>(idx_constant->Value());
idx.i32 = static_cast<int32_t>(idx_constant.Element<AInt>(0).value); if (idx_constant->Type()->Is<sem::I32>()) {
idx.i32 = static_cast<int32_t>(val);
idx.is_signed = true; idx.is_signed = true;
} else if (idx_constant.Type()->Is<sem::U32>()) { } else if (idx_constant->Type()->Is<sem::U32>()) {
idx.u32 = static_cast<uint32_t>(idx_constant.Element<AInt>(0).value); idx.u32 = static_cast<uint32_t>(val);
idx.is_signed = false; idx.is_signed = false;
} else { } else {
TINT_ICE(Transform, b.Diagnostics()) << "unsupported constant value for accessor " TINT_ICE(Transform, b.Diagnostics()) << "unsupported constant value for accessor "
<< idx_constant.Type()->TypeInfo().name; << idx_constant->Type()->TypeInfo().name;
return nullptr; return nullptr;
} }
} else { } else {

View File

@ -358,8 +358,8 @@ struct ZeroInitWorkgroupMemory::State {
continue; continue;
} }
auto* sem = ctx.src->Sem().Get(expr); auto* sem = ctx.src->Sem().Get(expr);
if (auto c = sem->ConstantValue()) { if (auto* c = sem->ConstantValue()) {
workgroup_size_const *= c.Element<AInt>(0).value; workgroup_size_const *= c->As<AInt>();
continue; continue;
} }
// Constant value could not be found. Build expression instead. // Constant value could not be found. Build expression instead.

View File

@ -59,7 +59,7 @@ const sem::Expression* Zero(ProgramBuilder& b, const sem::Type* ty, const sem::S
<< "unsupported vector element type: " << ty->TypeInfo().name; << "unsupported vector element type: " << ty->TypeInfo().name;
return nullptr; return nullptr;
} }
auto* sem = b.create<sem::Expression>(expr, ty, stmt, sem::Constant{}, auto* sem = b.create<sem::Expression>(expr, ty, stmt, /* constant_value */ nullptr,
/* has_side_effects */ false); /* has_side_effects */ false);
b.Sem().Add(expr, sem); b.Sem().Add(expr, sem);
return sem; return sem;
@ -139,7 +139,7 @@ const sem::Call* AppendVector(ProgramBuilder* b,
ast::StorageClass::kNone, ast::Access::kUndefined)); ast::StorageClass::kNone, ast::Access::kUndefined));
auto* scalar_cast_sem = b->create<sem::Call>( auto* scalar_cast_sem = b->create<sem::Call>(
scalar_cast_ast, scalar_cast_target, std::vector<const sem::Expression*>{scalar_sem}, scalar_cast_ast, scalar_cast_target, std::vector<const sem::Expression*>{scalar_sem},
statement, sem::Constant{}, /* has_side_effects */ false); statement, /* constant_value */ nullptr, /* has_side_effects */ false);
b->Sem().Add(scalar_cast_ast, scalar_cast_sem); b->Sem().Add(scalar_cast_ast, scalar_cast_sem);
packed.emplace_back(scalar_cast_sem); packed.emplace_back(scalar_cast_sem);
} else { } else {
@ -158,7 +158,7 @@ const sem::Call* AppendVector(ProgramBuilder* b,
ast::Access::kUndefined); ast::Access::kUndefined);
})); }));
auto* constructor_sem = b->create<sem::Call>(constructor_ast, constructor_target, packed, auto* constructor_sem = b->create<sem::Call>(constructor_ast, constructor_target, packed,
statement, sem::Constant{}, statement, /* constant_value */ nullptr,
/* has_side_effects */ false); /* has_side_effects */ false);
b->Sem().Add(constructor_ast, constructor_sem); b->Sem().Add(constructor_ast, constructor_sem);
return constructor_sem; return constructor_sem;

View File

@ -1332,7 +1332,7 @@ bool GeneratorImpl::EmitBarrierCall(std::ostream& out, const sem::Builtin* built
const ast::Expression* GeneratorImpl::CreateF32Zero(const sem::Statement* stmt) { const ast::Expression* GeneratorImpl::CreateF32Zero(const sem::Statement* stmt) {
auto* zero = builder_.Expr(0_f); auto* zero = builder_.Expr(0_f);
auto* f32 = builder_.create<sem::F32>(); auto* f32 = builder_.create<sem::F32>();
auto* sem_zero = builder_.create<sem::Expression>(zero, f32, stmt, sem::Constant{}, auto* sem_zero = builder_.create<sem::Expression>(zero, f32, stmt, /* constant_value */ nullptr,
/* has_side_effects */ false); /* has_side_effects */ false);
builder_.Sem().Add(zero, sem_zero); builder_.Sem().Add(zero, sem_zero);
return zero; return zero;
@ -1771,7 +1771,7 @@ bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
bool GeneratorImpl::EmitExpression(std::ostream& out, const ast::Expression* expr) { bool GeneratorImpl::EmitExpression(std::ostream& out, const ast::Expression* expr) {
if (auto* sem = builder_.Sem().Get(expr)) { if (auto* sem = builder_.Sem().Get(expr)) {
if (auto constant = sem->ConstantValue()) { if (auto* constant = sem->ConstantValue()) {
return EmitConstant(out, constant); return EmitConstant(out, constant);
} }
} }
@ -2214,31 +2214,23 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
return true; return true;
} }
bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant& constant) { bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constant) {
return EmitConstantRange(out, constant, constant.Type(), 0, constant.ElementCount());
}
bool GeneratorImpl::EmitConstantRange(std::ostream& out,
const sem::Constant& constant,
const sem::Type* range_ty,
size_t start,
size_t end) {
return Switch( return Switch(
range_ty, // constant->Type(), //
[&](const sem::Bool*) { [&](const sem::Bool*) {
out << (constant.Element<AInt>(start) ? "true" : "false"); out << (constant->As<AInt>() ? "true" : "false");
return true; return true;
}, },
[&](const sem::F32*) { [&](const sem::F32*) {
PrintF32(out, static_cast<float>(constant.Element<AFloat>(start))); PrintF32(out, constant->As<float>());
return true; return true;
}, },
[&](const sem::I32*) { [&](const sem::I32*) {
out << constant.Element<AInt>(start).value; out << constant->As<AInt>();
return true; return true;
}, },
[&](const sem::U32*) { [&](const sem::U32*) {
out << constant.Element<AInt>(start).value << "u"; out << constant->As<AInt>() << "u";
return true; return true;
}, },
[&](const sem::Vector* v) { [&](const sem::Vector* v) {
@ -2248,15 +2240,15 @@ bool GeneratorImpl::EmitConstantRange(std::ostream& out,
ScopedParen sp(out); ScopedParen sp(out);
if (constant.AllEqual(start, end)) { if (constant->AllEqual()) {
return EmitConstantRange(out, constant, v->type(), start, start + 1); return EmitConstant(out, constant->Index(0));
} }
for (size_t i = start; i < end; i++) { for (size_t i = 0; i < v->Width(); i++) {
if (i > start) { if (i > 0) {
out << ", "; out << ", ";
} }
if (!EmitConstantRange(out, constant, v->type(), i, i + 1u)) { if (!EmitConstant(out, constant->Index(i))) {
return false; return false;
} }
} }
@ -2273,9 +2265,7 @@ bool GeneratorImpl::EmitConstantRange(std::ostream& out,
if (column_idx > 0) { if (column_idx > 0) {
out << ", "; out << ", ";
} }
size_t col_start = m->rows() * column_idx; if (!EmitConstant(out, constant->Index(column_idx))) {
size_t col_end = col_start + m->rows();
if (!EmitConstantRange(out, constant, m->ColumnType(), col_start, col_end)) {
return false; return false;
} }
} }
@ -2288,15 +2278,11 @@ bool GeneratorImpl::EmitConstantRange(std::ostream& out,
ScopedParen sp(out); ScopedParen sp(out);
auto* el_ty = a->ElemType(); for (size_t i = 0; i < a->Count(); i++) {
if (i > 0) {
uint32_t step = 0;
sem::Type::DeepestElementOf(el_ty, &step);
for (size_t i = start; i < end; i += step) {
if (i > start) {
out << ", "; out << ", ";
} }
if (!EmitConstantRange(out, constant, el_ty, i, i + step)) { if (!EmitConstant(out, constant->Index(i))) {
return false; return false;
} }
} }
@ -2306,7 +2292,7 @@ bool GeneratorImpl::EmitConstantRange(std::ostream& out,
[&](Default) { [&](Default) {
diagnostics_.add_error( diagnostics_.add_error(
diag::System::Writer, diag::System::Writer,
"unhandled constant type: " + builder_.FriendlyName(constant.Type())); "unhandled constant type: " + builder_.FriendlyName(constant->Type()));
return false; return false;
}); });
} }

View File

@ -346,19 +346,7 @@ class GeneratorImpl : public TextGenerator {
/// @param out the output stream /// @param out the output stream
/// @param constant the constant value to emit /// @param constant the constant value to emit
/// @returns true if the constant value was successfully emitted /// @returns true if the constant value was successfully emitted
bool EmitConstant(std::ostream& out, const sem::Constant& constant); bool EmitConstant(std::ostream& out, const sem::Constant* constant);
/// Handles emitting a sub-range of a constant value
/// @param out the output stream
/// @param constant the constant value to emit
/// @param range_ty the sub-range type
/// @param start the element index for the first element
/// @param end the element index for one past the last element
/// @returns true if the constant value was successfully emitted
bool EmitConstantRange(std::ostream& out,
const sem::Constant& constant,
const sem::Type* range_ty,
size_t start,
size_t end);
/// Handles a literal /// Handles a literal
/// @param out the output stream /// @param out the output stream
/// @param lit the literal to emit /// @param lit the literal to emit

View File

@ -615,8 +615,7 @@ bool GeneratorImpl::EmitAssign(const ast::AssignmentStatement* stmt) {
if (auto* mat = TypeOf(lhs_sub_access->object)->UnwrapRef()->As<sem::Matrix>()) { if (auto* mat = TypeOf(lhs_sub_access->object)->UnwrapRef()->As<sem::Matrix>()) {
auto* rhs_col_idx_sem = builder_.Sem().Get(lhs_access->index); auto* rhs_col_idx_sem = builder_.Sem().Get(lhs_access->index);
auto* rhs_row_idx_sem = builder_.Sem().Get(lhs_sub_access->index); auto* rhs_row_idx_sem = builder_.Sem().Get(lhs_sub_access->index);
if (!rhs_col_idx_sem->ConstantValue().IsValid() || if (!rhs_col_idx_sem->ConstantValue() || !rhs_row_idx_sem->ConstantValue()) {
!rhs_row_idx_sem->ConstantValue().IsValid()) {
return EmitDynamicMatrixScalarAssignment(stmt, mat); return EmitDynamicMatrixScalarAssignment(stmt, mat);
} }
} }
@ -626,7 +625,7 @@ bool GeneratorImpl::EmitAssign(const ast::AssignmentStatement* stmt) {
const auto* lhs_access_type = TypeOf(lhs_access->object)->UnwrapRef(); const auto* lhs_access_type = TypeOf(lhs_access->object)->UnwrapRef();
if (auto* mat = lhs_access_type->As<sem::Matrix>()) { if (auto* mat = lhs_access_type->As<sem::Matrix>()) {
auto* lhs_index_sem = builder_.Sem().Get(lhs_access->index); auto* lhs_index_sem = builder_.Sem().Get(lhs_access->index);
if (!lhs_index_sem->ConstantValue().IsValid()) { if (!lhs_index_sem->ConstantValue()) {
return EmitDynamicMatrixVectorAssignment(stmt, mat); return EmitDynamicMatrixVectorAssignment(stmt, mat);
} }
} }
@ -634,7 +633,7 @@ bool GeneratorImpl::EmitAssign(const ast::AssignmentStatement* stmt) {
// indices // indices
if (auto* vec = lhs_access_type->As<sem::Vector>()) { if (auto* vec = lhs_access_type->As<sem::Vector>()) {
auto* rhs_sem = builder_.Sem().Get(lhs_access->index); auto* rhs_sem = builder_.Sem().Get(lhs_access->index);
if (!rhs_sem->ConstantValue().IsValid()) { if (!rhs_sem->ConstantValue()) {
return EmitDynamicVectorAssignment(stmt, vec); return EmitDynamicVectorAssignment(stmt, vec);
} }
} }
@ -654,28 +653,30 @@ bool GeneratorImpl::EmitAssign(const ast::AssignmentStatement* stmt) {
bool GeneratorImpl::EmitExpressionOrOneIfZero(std::ostream& out, const ast::Expression* expr) { bool GeneratorImpl::EmitExpressionOrOneIfZero(std::ostream& out, const ast::Expression* expr) {
// For constants, replace literal 0 with 1. // For constants, replace literal 0 with 1.
if (const auto& val = builder_.Sem().Get(expr)->ConstantValue()) { if (const auto* val = builder_.Sem().Get(expr)->ConstantValue()) {
if (!val.AnyZero()) { if (!val->AnyZero()) {
return EmitExpression(out, expr); return EmitExpression(out, expr);
} }
if (val.Type()->IsAnyOf<sem::I32, sem::U32>()) { auto* ty = val->Type();
return EmitValue(out, val.Type(), 1);
if (ty->IsAnyOf<sem::I32, sem::U32>()) {
return EmitValue(out, ty, 1);
} }
if (auto* vec = val.Type()->As<sem::Vector>()) { if (auto* vec = ty->As<sem::Vector>()) {
auto* elem_ty = vec->type(); auto* elem_ty = vec->type();
if (!EmitType(out, val.Type(), ast::StorageClass::kNone, ast::Access::kUndefined, "")) { if (!EmitType(out, ty, ast::StorageClass::kNone, ast::Access::kUndefined, "")) {
return false; return false;
} }
out << "("; out << "(";
for (size_t i = 0; i < val.ElementCount(); ++i) { for (size_t i = 0; i < vec->Width(); ++i) {
if (i != 0) { if (i != 0) {
out << ", "; out << ", ";
} }
auto s = val.Element<AInt>(i).value; auto s = val->Index(i)->As<AInt>();
if (!EmitValue(out, elem_ty, (s == 0) ? 1 : static_cast<int>(s))) { if (!EmitValue(out, elem_ty, (s == 0) ? 1 : static_cast<int>(s))) {
return false; return false;
} }
@ -1181,9 +1182,9 @@ bool GeneratorImpl::EmitUniformBufferAccess(
// If true, use scalar_offset_value, otherwise use scalar_offset_expr // If true, use scalar_offset_value, otherwise use scalar_offset_expr
bool scalar_offset_constant = false; bool scalar_offset_constant = false;
if (auto val = offset_arg->ConstantValue()) { if (auto* val = offset_arg->ConstantValue()) {
TINT_ASSERT(Writer, val.Type()->Is<sem::U32>()); TINT_ASSERT(Writer, val->Type()->Is<sem::U32>());
scalar_offset_value = static_cast<uint32_t>(val.Element<AInt>(0).value); scalar_offset_value = static_cast<uint32_t>(std::get<AInt>(val->Value()));
scalar_offset_value /= 4; // bytes -> scalar index scalar_offset_value /= 4; // bytes -> scalar index
scalar_offset_constant = true; scalar_offset_constant = true;
} }
@ -2337,7 +2338,7 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out,
case sem::BuiltinType::kTextureGather: case sem::BuiltinType::kTextureGather:
out << ".Gather"; out << ".Gather";
if (builtin->Parameters()[0]->Usage() == sem::ParameterUsage::kComponent) { if (builtin->Parameters()[0]->Usage() == sem::ParameterUsage::kComponent) {
switch (call->Arguments()[0]->ConstantValue().Element<AInt>(0).value) { switch (call->Arguments()[0]->ConstantValue()->As<AInt>()) {
case 0: case 0:
out << "Red"; out << "Red";
break; break;
@ -2384,7 +2385,8 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out,
auto* i32 = builder_.create<sem::I32>(); auto* i32 = builder_.create<sem::I32>();
auto* zero = builder_.Expr(0_i); auto* zero = builder_.Expr(0_i);
auto* stmt = builder_.Sem().Get(vector)->Stmt(); auto* stmt = builder_.Sem().Get(vector)->Stmt();
builder_.Sem().Add(zero, builder_.create<sem::Expression>(zero, i32, stmt, sem::Constant{}, builder_.Sem().Add(
zero, builder_.create<sem::Expression>(zero, i32, stmt, /* constant_value */ nullptr,
/* has_side_effects */ false)); /* has_side_effects */ false));
auto* packed = AppendVector(&builder_, vector, zero); auto* packed = AppendVector(&builder_, vector, zero);
return EmitExpression(out, packed->Declaration()); return EmitExpression(out, packed->Declaration());
@ -2614,7 +2616,7 @@ bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
bool GeneratorImpl::EmitExpression(std::ostream& out, const ast::Expression* expr) { bool GeneratorImpl::EmitExpression(std::ostream& out, const ast::Expression* expr) {
if (auto* sem = builder_.Sem().Get(expr)) { if (auto* sem = builder_.Sem().Get(expr)) {
if (auto constant = sem->ConstantValue()) { if (auto* constant = sem->ConstantValue()) {
return EmitConstant(out, constant); return EmitConstant(out, constant);
} }
} }
@ -3109,43 +3111,35 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
return true; return true;
} }
bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant& constant) { bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constant) {
return EmitConstantRange(out, constant, constant.Type(), 0, constant.ElementCount());
}
bool GeneratorImpl::EmitConstantRange(std::ostream& out,
const sem::Constant& constant,
const sem::Type* range_ty,
size_t start,
size_t end) {
return Switch( return Switch(
range_ty, // constant->Type(), //
[&](const sem::Bool*) { [&](const sem::Bool*) {
out << (constant.Element<AInt>(start) ? "true" : "false"); out << (constant->As<AInt>() ? "true" : "false");
return true; return true;
}, },
[&](const sem::F32*) { [&](const sem::F32*) {
PrintF32(out, static_cast<float>(constant.Element<AFloat>(start))); PrintF32(out, constant->As<float>());
return true; return true;
}, },
[&](const sem::I32*) { [&](const sem::I32*) {
out << constant.Element<AInt>(start).value; out << constant->As<AInt>();
return true; return true;
}, },
[&](const sem::U32*) { [&](const sem::U32*) {
out << constant.Element<AInt>(start).value << "u"; out << constant->As<AInt>() << "u";
return true; return true;
}, },
[&](const sem::Vector* v) { [&](const sem::Vector* v) {
if (constant.AllEqual(start, end)) { if (constant->AllEqual()) {
{ {
ScopedParen sp(out); ScopedParen sp(out);
if (!EmitConstantRange(out, constant, v->type(), start, start + 1)) { if (!EmitConstant(out, constant->Index(0))) {
return false; return false;
} }
} }
out << "."; out << ".";
for (size_t i = start; i < end; i++) { for (size_t i = 0; i < v->Width(); i++) {
out << "x"; out << "x";
} }
return true; return true;
@ -3157,11 +3151,11 @@ bool GeneratorImpl::EmitConstantRange(std::ostream& out,
ScopedParen sp(out); ScopedParen sp(out);
for (size_t i = start; i < end; i++) { for (size_t i = 0; i < v->Width(); i++) {
if (i > start) { if (i > 0) {
out << ", "; out << ", ";
} }
if (!EmitConstantRange(out, constant, v->type(), i, i + 1u)) { if (!EmitConstant(out, constant->Index(i))) {
return false; return false;
} }
} }
@ -3174,20 +3168,18 @@ bool GeneratorImpl::EmitConstantRange(std::ostream& out,
ScopedParen sp(out); ScopedParen sp(out);
for (size_t column_idx = 0; column_idx < m->columns(); column_idx++) { for (size_t i = 0; i < m->columns(); i++) {
if (column_idx > 0) { if (i > 0) {
out << ", "; out << ", ";
} }
size_t col_start = m->rows() * column_idx; if (!EmitConstant(out, constant->Index(i))) {
size_t col_end = col_start + m->rows();
if (!EmitConstantRange(out, constant, m->ColumnType(), col_start, col_end)) {
return false; return false;
} }
} }
return true; return true;
}, },
[&](const sem::Array* a) { [&](const sem::Array* a) {
if (constant.AllZero(start, end)) { if (constant->AllZero()) {
out << "("; out << "(";
if (!EmitType(out, a, ast::StorageClass::kNone, ast::Access::kUndefined, "")) { if (!EmitType(out, a, ast::StorageClass::kNone, ast::Access::kUndefined, "")) {
return false; return false;
@ -3199,15 +3191,11 @@ bool GeneratorImpl::EmitConstantRange(std::ostream& out,
out << "{"; out << "{";
TINT_DEFER(out << "}"); TINT_DEFER(out << "}");
auto* el_ty = a->ElemType(); for (size_t i = 0; i < a->Count(); i++) {
if (i > 0) {
uint32_t step = 0;
sem::Type::DeepestElementOf(el_ty, &step);
for (size_t i = start; i < end; i += step) {
if (i > start) {
out << ", "; out << ", ";
} }
if (!EmitConstantRange(out, constant, el_ty, i, i + step)) { if (!EmitConstant(out, constant->Index(i))) {
return false; return false;
} }
} }
@ -3217,7 +3205,7 @@ bool GeneratorImpl::EmitConstantRange(std::ostream& out,
[&](Default) { [&](Default) {
diagnostics_.add_error( diagnostics_.add_error(
diag::System::Writer, diag::System::Writer,
"unhandled constant type: " + builder_.FriendlyName(constant.Type())); "unhandled constant type: " + builder_.FriendlyName(constant->Type()));
return false; return false;
}); });
} }

View File

@ -93,7 +93,8 @@ class GeneratorImpl : public TextGenerator {
/// @param stmt the statement to emit /// @param stmt the statement to emit
/// @returns true if the statement was emitted successfully /// @returns true if the statement was emitted successfully
bool EmitAssign(const ast::AssignmentStatement* stmt); bool EmitAssign(const ast::AssignmentStatement* stmt);
/// Emits code such that if `expr` is zero, it emits one, else `expr` /// Emits code such that if `expr` is zero, it emits one, else `expr`.
/// Used to avoid divide-by-zeros by substituting constant zeros with ones.
/// @param out the output of the expression stream /// @param out the output of the expression stream
/// @param expr the expression /// @param expr the expression
/// @returns true if the expression was emitted, false otherwise /// @returns true if the expression was emitted, false otherwise
@ -342,19 +343,7 @@ class GeneratorImpl : public TextGenerator {
/// @param out the output stream /// @param out the output stream
/// @param constant the constant value to emit /// @param constant the constant value to emit
/// @returns true if the constant value was successfully emitted /// @returns true if the constant value was successfully emitted
bool EmitConstant(std::ostream& out, const sem::Constant& constant); bool EmitConstant(std::ostream& out, const sem::Constant* constant);
/// Handles emitting a sub-range of a constant value
/// @param out the output stream
/// @param constant the constant value to emit
/// @param range_ty the sub-range type
/// @param start the element index for the first element
/// @param end the element index for one past the last element
/// @returns true if the constant value was successfully emitted
bool EmitConstantRange(std::ostream& out,
const sem::Constant& constant,
const sem::Type* range_ty,
size_t start,
size_t end);
/// Handles a literal /// Handles a literal
/// @param out the output stream /// @param out the output stream
/// @param lit the literal to emit /// @param lit the literal to emit

View File

@ -1201,7 +1201,7 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out,
break; // Other texture dimensions don't have an offset break; // Other texture dimensions don't have an offset
} }
} }
auto c = component->ConstantValue().Element<AInt>(0); auto c = component->ConstantValue()->As<AInt>();
switch (c.value) { switch (c.value) {
case 0: case 0:
out << "component::x"; out << "component::x";
@ -1594,31 +1594,23 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
}); });
} }
bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant& constant) { bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constant) {
return EmitConstantRange(out, constant, constant.Type(), 0, constant.ElementCount());
}
bool GeneratorImpl::EmitConstantRange(std::ostream& out,
const sem::Constant& constant,
const sem::Type* range_ty,
size_t start,
size_t end) {
return Switch( return Switch(
range_ty, // constant->Type(), //
[&](const sem::Bool*) { [&](const sem::Bool*) {
out << (constant.Element<AInt>(start) ? "true" : "false"); out << (constant->As<AInt>() ? "true" : "false");
return true; return true;
}, },
[&](const sem::F32*) { [&](const sem::F32*) {
PrintF32(out, static_cast<float>(constant.Element<AFloat>(start))); PrintF32(out, constant->As<float>());
return true; return true;
}, },
[&](const sem::I32*) { [&](const sem::I32*) {
PrintI32(out, static_cast<int32_t>(constant.Element<AInt>(start).value)); PrintI32(out, constant->As<int32_t>());
return true; return true;
}, },
[&](const sem::U32*) { [&](const sem::U32*) {
out << constant.Element<AInt>(start).value << "u"; out << constant->As<AInt>() << "u";
return true; return true;
}, },
[&](const sem::Vector* v) { [&](const sem::Vector* v) {
@ -1628,18 +1620,18 @@ bool GeneratorImpl::EmitConstantRange(std::ostream& out,
ScopedParen sp(out); ScopedParen sp(out);
if (constant.AllEqual(start, end)) { if (constant->AllEqual()) {
if (!EmitConstantRange(out, constant, v->type(), start, start + 1)) { if (!EmitConstant(out, constant->Index(0))) {
return false; return false;
} }
return true; return true;
} }
for (size_t i = start; i < end; i++) { for (size_t i = 0; i < v->Width(); i++) {
if (i > start) { if (i > 0) {
out << ", "; out << ", ";
} }
if (!EmitConstantRange(out, constant, v->type(), i, i + 1u)) { if (!EmitConstant(out, constant->Index(i))) {
return false; return false;
} }
} }
@ -1652,13 +1644,11 @@ bool GeneratorImpl::EmitConstantRange(std::ostream& out,
ScopedParen sp(out); ScopedParen sp(out);
for (size_t column_idx = 0; column_idx < m->columns(); column_idx++) { for (size_t i = 0; i < m->columns(); i++) {
if (column_idx > 0) { if (i > 0) {
out << ", "; out << ", ";
} }
size_t col_start = m->rows() * column_idx; if (!EmitConstant(out, constant->Index(i))) {
size_t col_end = col_start + m->rows();
if (!EmitConstantRange(out, constant, m->ColumnType(), col_start, col_end)) {
return false; return false;
} }
} }
@ -1669,7 +1659,7 @@ bool GeneratorImpl::EmitConstantRange(std::ostream& out,
return false; return false;
} }
if (constant.AllZero(start, end)) { if (constant->AllZero()) {
out << "{}"; out << "{}";
return true; return true;
} }
@ -1677,15 +1667,11 @@ bool GeneratorImpl::EmitConstantRange(std::ostream& out,
out << "{"; out << "{";
TINT_DEFER(out << "}"); TINT_DEFER(out << "}");
auto* el_ty = a->ElemType(); for (size_t i = 0; i < a->Count(); i++) {
if (i > 0) {
uint32_t step = 0;
sem::Type::DeepestElementOf(el_ty, &step);
for (size_t i = start; i < end; i += step) {
if (i > start) {
out << ", "; out << ", ";
} }
if (!EmitConstantRange(out, constant, el_ty, i, i + step)) { if (!EmitConstant(out, constant->Index(i))) {
return false; return false;
} }
} }
@ -1695,7 +1681,7 @@ bool GeneratorImpl::EmitConstantRange(std::ostream& out,
[&](Default) { [&](Default) {
diagnostics_.add_error( diagnostics_.add_error(
diag::System::Writer, diag::System::Writer,
"unhandled constant type: " + builder_.FriendlyName(constant.Type())); "unhandled constant type: " + builder_.FriendlyName(constant->Type()));
return false; return false;
}); });
} }

View File

@ -256,19 +256,7 @@ class GeneratorImpl : public TextGenerator {
/// @param out the output stream /// @param out the output stream
/// @param constant the constant value to emit /// @param constant the constant value to emit
/// @returns true if the constant value was successfully emitted /// @returns true if the constant value was successfully emitted
bool EmitConstant(std::ostream& out, const sem::Constant& constant); bool EmitConstant(std::ostream& out, const sem::Constant* constant);
/// Handles emitting a sub-range of a constant value
/// @param out the output stream
/// @param constant the constant value to emit
/// @param range_ty the sub-range type
/// @param start the element index for the first element
/// @param end the element index for one past the last element
/// @returns true if the constant value was successfully emitted
bool EmitConstantRange(std::ostream& out,
const sem::Constant& constant,
const sem::Type* range_ty,
size_t start,
size_t end);
/// Handles a literal /// Handles a literal
/// @param out the output of the expression stream /// @param out the output of the expression stream
/// @param lit the literal to emit /// @param lit the literal to emit

View File

@ -959,7 +959,7 @@ bool Builder::GenerateIndexAccessor(const ast::IndexAccessorExpression* expr, Ac
Operand(result_type_id), Operand(result_type_id),
extract, extract,
Operand(info->source_id), Operand(info->source_id),
Operand(idx_constval.Element<uint32_t>(0)), Operand(idx_constval->As<uint32_t>()),
})) { })) {
return false; return false;
} }
@ -1703,20 +1703,14 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
return GenerateConstantIfNeeded(constant); return GenerateConstantIfNeeded(constant);
} }
uint32_t Builder::GenerateConstantIfNeeded(const sem::Constant& constant) { uint32_t Builder::GenerateConstantIfNeeded(const sem::Constant* constant) {
return GenerateConstantRangeIfNeeded(constant, constant.Type(), 0, constant.ElementCount()); if (constant->AllZero()) {
} return GenerateConstantNullIfNeeded(constant->Type());
uint32_t Builder::GenerateConstantRangeIfNeeded(const sem::Constant& constant,
const sem::Type* range_ty,
size_t start,
size_t end) {
if (constant.AllZero(start, end)) {
return GenerateConstantNullIfNeeded(range_ty);
} }
auto* ty = constant->Type();
auto composite = [&](const sem::Type* el_ty) -> uint32_t { auto composite = [&](size_t el_count) -> uint32_t {
auto type_id = GenerateTypeIfNeeded(range_ty); auto type_id = GenerateTypeIfNeeded(ty);
if (!type_id) { if (!type_id) {
return 0; return 0;
} }
@ -1724,14 +1718,12 @@ uint32_t Builder::GenerateConstantRangeIfNeeded(const sem::Constant& constant,
static constexpr size_t kOpsResultIdx = 1; // operand index of the result static constexpr size_t kOpsResultIdx = 1; // operand index of the result
std::vector<Operand> ops; std::vector<Operand> ops;
ops.reserve(end - start + 2); ops.reserve(el_count + 2);
ops.emplace_back(type_id); ops.emplace_back(type_id);
ops.push_back(Operand(0u)); // Placeholder for the result ID ops.push_back(Operand(0u)); // Placeholder for the result ID
uint32_t step = 0; for (size_t i = 0; i < el_count; i++) {
sem::Type::DeepestElementOf(el_ty, &step); auto id = GenerateConstantIfNeeded(constant->Index(i));
for (size_t i = start; i < end; i += step) {
auto id = GenerateConstantRangeIfNeeded(constant, el_ty, i, i + step);
if (!id) { if (!id) {
return 0; return 0;
} }
@ -1749,28 +1741,28 @@ uint32_t Builder::GenerateConstantRangeIfNeeded(const sem::Constant& constant,
}; };
return Switch( return Switch(
range_ty, // ty, //
[&](const sem::Bool*) { [&](const sem::Bool*) {
bool val = constant.Element<AInt>(start); bool val = constant->As<bool>();
return GenerateConstantIfNeeded(ScalarConstant::Bool(val)); return GenerateConstantIfNeeded(ScalarConstant::Bool(val));
}, },
[&](const sem::F32*) { [&](const sem::F32*) {
auto val = f32(constant.Element<AFloat>(start)); auto val = constant->As<f32>();
return GenerateConstantIfNeeded(ScalarConstant::F32(val.value)); return GenerateConstantIfNeeded(ScalarConstant::F32(val.value));
}, },
[&](const sem::I32*) { [&](const sem::I32*) {
auto val = i32(constant.Element<AInt>(start)); auto val = constant->As<i32>();
return GenerateConstantIfNeeded(ScalarConstant::I32(val.value)); return GenerateConstantIfNeeded(ScalarConstant::I32(val.value));
}, },
[&](const sem::U32*) { [&](const sem::U32*) {
auto val = u32(constant.Element<AInt>(start)); auto val = constant->As<u32>();
return GenerateConstantIfNeeded(ScalarConstant::U32(val.value)); return GenerateConstantIfNeeded(ScalarConstant::U32(val.value));
}, },
[&](const sem::Vector* v) { return composite(v->type()); }, [&](const sem::Vector* v) { return composite(v->Width()); },
[&](const sem::Matrix* m) { return composite(m->ColumnType()); }, [&](const sem::Matrix* m) { return composite(m->columns()); },
[&](const sem::Array* a) { return composite(a->ElemType()); }, [&](const sem::Array* a) { return composite(a->Count()); },
[&](Default) { [&](Default) {
error_ = "unhandled constant type: " + builder_.FriendlyName(constant.Type()); error_ = "unhandled constant type: " + builder_.FriendlyName(ty);
return false; return false;
}); });
} }

View File

@ -554,18 +554,7 @@ class Builder {
/// Generates a constant value if needed /// Generates a constant value if needed
/// @param constant the constant to generate. /// @param constant the constant to generate.
/// @returns the ID on success or 0 on failure /// @returns the ID on success or 0 on failure
uint32_t GenerateConstantIfNeeded(const sem::Constant& constant); uint32_t GenerateConstantIfNeeded(const sem::Constant* constant);
/// Handles emitting a sub-range of a constant value
/// @param constant the constant value to emit
/// @param range_ty the sub-range type
/// @param start the element index for the first element
/// @param end the element index for one past the last element
/// @returns true if the constant value was successfully emitted
uint32_t GenerateConstantRangeIfNeeded(const sem::Constant& constant,
const sem::Type* range_ty,
size_t start,
size_t end);
/// Generates a scalar constant if needed /// Generates a scalar constant if needed
/// @param constant the constant to generate. /// @param constant the constant to generate.