ast: Remove TypeConstructorExpression

Add a new 'Target' to the ast::CallExpression, which can be either an
Identifier or Type. The Identifier may resolve to a Type, if the Type is
a structure or alias.

The Resolver now resolves the CallExpression target to one of the
following sem::CallTargets:
* sem::Function
* sem::Intrinsic
* sem::TypeConstructor
* sem::TypeCast

This change will allow us to remove the type tracking logic from the WGSL
parser, which is required for out-of-order module scope declarations.

Bug: tint:888
Bug: tint:1266
Change-Id: I696f117115a50981fd5c102a0d7764641bb755dd
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68525
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2021-11-15 20:45:50 +00:00
parent d12f48828b
commit 735dca8393
48 changed files with 2275 additions and 1446 deletions

View File

@ -314,8 +314,6 @@ libtint_source_set("libtint_core_all_src") {
"ast/texture.h", "ast/texture.h",
"ast/traverse_expressions.h", "ast/traverse_expressions.h",
"ast/type.h", "ast/type.h",
"ast/type_constructor_expression.cc",
"ast/type_constructor_expression.h",
"ast/type_decl.cc", "ast/type_decl.cc",
"ast/type_decl.h", "ast/type_decl.h",
"ast/type_name.cc", "ast/type_name.cc",
@ -408,8 +406,8 @@ libtint_source_set("libtint_core_all_src") {
"sem/storage_texture_type.h", "sem/storage_texture_type.h",
"sem/switch_statement.h", "sem/switch_statement.h",
"sem/texture_type.h", "sem/texture_type.h",
"sem/type_cast.h",
"sem/type_constructor.h", "sem/type_constructor.h",
"sem/type_conversion.h",
"sem/type.h", "sem/type.h",
"sem/type_manager.h", "sem/type_manager.h",
"sem/type_mappings.h", "sem/type_mappings.h",
@ -576,10 +574,10 @@ libtint_source_set("libtint_sem_src") {
"sem/switch_statement.h", "sem/switch_statement.h",
"sem/texture_type.cc", "sem/texture_type.cc",
"sem/texture_type.h", "sem/texture_type.h",
"sem/type_cast.cc",
"sem/type_cast.h",
"sem/type_constructor.cc", "sem/type_constructor.cc",
"sem/type_constructor.h", "sem/type_constructor.h",
"sem/type_conversion.cc",
"sem/type_conversion.h",
"sem/type.cc", "sem/type.cc",
"sem/type.h", "sem/type.h",
"sem/type_manager.cc", "sem/type_manager.cc",

View File

@ -177,8 +177,6 @@ set(TINT_LIB_SRCS
ast/texture.cc ast/texture.cc
ast/texture.h ast/texture.h
ast/traverse_expressions.h ast/traverse_expressions.h
ast/type_constructor_expression.cc
ast/type_constructor_expression.h
ast/type_name.cc ast/type_name.cc
ast/type_name.h ast/type_name.h
ast/ast_type.cc # TODO(bclayton) - rename to type.cc ast/ast_type.cc # TODO(bclayton) - rename to type.cc
@ -379,10 +377,10 @@ set(TINT_LIB_SRCS
sem/switch_statement.h sem/switch_statement.h
sem/texture_type.cc sem/texture_type.cc
sem/texture_type.h sem/texture_type.h
sem/type_cast.cc
sem/type_cast.h
sem/type_constructor.cc sem/type_constructor.cc
sem/type_constructor.h sem/type_constructor.h
sem/type_conversion.cc
sem/type_conversion.h
sem/type.cc sem/type.cc
sem/type.h sem/type.h
sem/type_manager.cc sem/type_manager.cc
@ -644,7 +642,6 @@ if(${TINT_BUILD_TESTS})
ast/test_helper.h ast/test_helper.h
ast/texture_test.cc ast/texture_test.cc
ast/traverse_expressions_test.cc ast/traverse_expressions_test.cc
ast/type_constructor_expression_test.cc
ast/u32_test.cc ast/u32_test.cc
ast/uint_literal_expression_test.cc ast/uint_literal_expression_test.cc
ast/unary_op_expression_test.cc ast/unary_op_expression_test.cc

View File

@ -21,13 +21,39 @@ TINT_INSTANTIATE_TYPEINFO(tint::ast::CallExpression);
namespace tint { namespace tint {
namespace ast { namespace ast {
namespace {
CallExpression::Target ToTarget(const IdentifierExpression* name) {
CallExpression::Target target;
target.name = name;
return target;
}
CallExpression::Target ToTarget(const Type* type) {
CallExpression::Target target;
target.type = type;
return target;
}
} // namespace
CallExpression::CallExpression(ProgramID pid, CallExpression::CallExpression(ProgramID pid,
const Source& src, const Source& src,
const IdentifierExpression* fn, const IdentifierExpression* name,
ExpressionList a) ExpressionList a)
: Base(pid, src), func(fn), args(a) { : Base(pid, src), target(ToTarget(name)), args(a) {
TINT_ASSERT(AST, func); TINT_ASSERT(AST, name);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, name, program_id);
for (auto* arg : args) {
TINT_ASSERT(AST, arg);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, arg, program_id);
}
}
CallExpression::CallExpression(ProgramID pid,
const Source& src,
const Type* type,
ExpressionList a)
: Base(pid, src), target(ToTarget(type)), args(a) {
TINT_ASSERT(AST, type);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id);
for (auto* arg : args) { for (auto* arg : args) {
TINT_ASSERT(AST, arg); TINT_ASSERT(AST, arg);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, arg, program_id); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, arg, program_id);
@ -41,9 +67,11 @@ CallExpression::~CallExpression() = default;
const CallExpression* CallExpression::Clone(CloneContext* ctx) const { const CallExpression* CallExpression::Clone(CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering // Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source); auto src = ctx->Clone(source);
auto* fn = ctx->Clone(func);
auto p = ctx->Clone(args); auto p = ctx->Clone(args);
return ctx->dst->create<CallExpression>(src, fn, p); return target.name
? ctx->dst->create<CallExpression>(src, ctx->Clone(target.name), p)
: ctx->dst->create<CallExpression>(src, ctx->Clone(target.type),
p);
} }
} // namespace ast } // namespace ast

View File

@ -21,20 +21,36 @@ namespace tint {
namespace ast { namespace ast {
// Forward declarations. // Forward declarations.
class Type;
class IdentifierExpression; class IdentifierExpression;
/// A call expression /// A call expression - represents either a:
/// * sem::Function
/// * sem::Intrinsic
/// * sem::TypeConstructor
/// * sem::TypeConversion
class CallExpression : public Castable<CallExpression, Expression> { class CallExpression : public Castable<CallExpression, Expression> {
public: public:
/// Constructor /// Constructor
/// @param program_id the identifier of the program that owns this node /// @param program_id the identifier of the program that owns this node
/// @param source the call expression source /// @param source the call expression source
/// @param func the function /// @param name the function or type name
/// @param args the arguments /// @param args the arguments
CallExpression(ProgramID program_id, CallExpression(ProgramID program_id,
const Source& source, const Source& source,
const IdentifierExpression* func, const IdentifierExpression* name,
ExpressionList args); ExpressionList args);
/// Constructor
/// @param program_id the identifier of the program that owns this node
/// @param source the call expression source
/// @param type the type
/// @param args the arguments
CallExpression(ProgramID program_id,
const Source& source,
const Type* type,
ExpressionList args);
/// Move constructor /// Move constructor
CallExpression(CallExpression&&); CallExpression(CallExpression&&);
~CallExpression() override; ~CallExpression() override;
@ -45,8 +61,19 @@ class CallExpression : public Castable<CallExpression, Expression> {
/// @return the newly cloned node /// @return the newly cloned node
const CallExpression* Clone(CloneContext* ctx) const override; const CallExpression* Clone(CloneContext* ctx) const override;
/// Target is either an identifier, or a Type.
/// One of these must be nullptr and the other a non-nullptr.
struct Target {
/// name is a function or intrinsic to call, or type name to construct or
/// cast-to
const IdentifierExpression* name = nullptr;
/// type to construct or cast-to
const Type* type = nullptr;
};
/// The target function /// The target function
const IdentifierExpression* const func; const Target target;
/// The arguments /// The arguments
const ExpressionList args; const ExpressionList args;
}; };

View File

@ -21,14 +21,15 @@ namespace {
using CallExpressionTest = TestHelper; using CallExpressionTest = TestHelper;
TEST_F(CallExpressionTest, Creation) { TEST_F(CallExpressionTest, CreationIdentifier) {
auto* func = Expr("func"); auto* func = Expr("func");
ExpressionList params; ExpressionList params;
params.push_back(Expr("param1")); params.push_back(Expr("param1"));
params.push_back(Expr("param2")); params.push_back(Expr("param2"));
auto* stmt = create<CallExpression>(func, params); auto* stmt = create<CallExpression>(func, params);
EXPECT_EQ(stmt->func, func); EXPECT_EQ(stmt->target.name, func);
EXPECT_EQ(stmt->target.type, nullptr);
const auto& vec = stmt->args; const auto& vec = stmt->args;
ASSERT_EQ(vec.size(), 2u); ASSERT_EQ(vec.size(), 2u);
@ -36,10 +37,39 @@ TEST_F(CallExpressionTest, Creation) {
EXPECT_EQ(vec[1], params[1]); EXPECT_EQ(vec[1], params[1]);
} }
TEST_F(CallExpressionTest, Creation_WithSource) { TEST_F(CallExpressionTest, CreationIdentifier_WithSource) {
auto* func = Expr("func"); auto* func = Expr("func");
auto* stmt = create<CallExpression>(Source{Source::Location{20, 2}}, func, auto* stmt = create<CallExpression>(Source{{20, 2}}, func, ExpressionList{});
ExpressionList{}); EXPECT_EQ(stmt->target.name, func);
EXPECT_EQ(stmt->target.type, nullptr);
auto src = stmt->source;
EXPECT_EQ(src.range.begin.line, 20u);
EXPECT_EQ(src.range.begin.column, 2u);
}
TEST_F(CallExpressionTest, CreationType) {
auto* type = ty.f32();
ExpressionList params;
params.push_back(Expr("param1"));
params.push_back(Expr("param2"));
auto* stmt = create<CallExpression>(type, params);
EXPECT_EQ(stmt->target.name, nullptr);
EXPECT_EQ(stmt->target.type, type);
const auto& vec = stmt->args;
ASSERT_EQ(vec.size(), 2u);
EXPECT_EQ(vec[0], params[0]);
EXPECT_EQ(vec[1], params[1]);
}
TEST_F(CallExpressionTest, CreationType_WithSource) {
auto* type = ty.f32();
auto* stmt = create<CallExpression>(Source{{20, 2}}, type, ExpressionList{});
EXPECT_EQ(stmt->target.name, nullptr);
EXPECT_EQ(stmt->target.type, type);
auto src = stmt->source; auto src = stmt->source;
EXPECT_EQ(src.range.begin.line, 20u); EXPECT_EQ(src.range.begin.line, 20u);
EXPECT_EQ(src.range.begin.column, 2u); EXPECT_EQ(src.range.begin.column, 2u);
@ -51,11 +81,21 @@ TEST_F(CallExpressionTest, IsCall) {
EXPECT_TRUE(stmt->Is<CallExpression>()); EXPECT_TRUE(stmt->Is<CallExpression>());
} }
TEST_F(CallExpressionTest, Assert_Null_Function) { TEST_F(CallExpressionTest, Assert_Null_Identifier) {
EXPECT_FATAL_FAILURE( EXPECT_FATAL_FAILURE(
{ {
ProgramBuilder b; ProgramBuilder b;
b.create<CallExpression>(nullptr, ExpressionList{}); b.create<CallExpression>(static_cast<IdentifierExpression*>(nullptr),
ExpressionList{});
},
"internal compiler error");
}
TEST_F(CallExpressionTest, Assert_Null_Type) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
b.create<CallExpression>(static_cast<Type*>(nullptr), ExpressionList{});
}, },
"internal compiler error"); "internal compiler error");
} }
@ -73,7 +113,7 @@ TEST_F(CallExpressionTest, Assert_Null_Param) {
"internal compiler error"); "internal compiler error");
} }
TEST_F(CallExpressionTest, Assert_DifferentProgramID_Function) { TEST_F(CallExpressionTest, Assert_DifferentProgramID_Identifier) {
EXPECT_FATAL_FAILURE( EXPECT_FATAL_FAILURE(
{ {
ProgramBuilder b1; ProgramBuilder b1;
@ -83,6 +123,16 @@ TEST_F(CallExpressionTest, Assert_DifferentProgramID_Function) {
"internal compiler error"); "internal compiler error");
} }
TEST_F(CallExpressionTest, Assert_DifferentProgramID_Type) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b1;
ProgramBuilder b2;
b1.create<CallExpression>(b2.ty.f32(), ExpressionList{});
},
"internal compiler error");
}
TEST_F(CallExpressionTest, Assert_DifferentProgramID_Param) { TEST_F(CallExpressionTest, Assert_DifferentProgramID_Param) {
EXPECT_FATAL_FAILURE( EXPECT_FATAL_FAILURE(
{ {

View File

@ -24,7 +24,6 @@
#include "src/ast/literal_expression.h" #include "src/ast/literal_expression.h"
#include "src/ast/member_accessor_expression.h" #include "src/ast/member_accessor_expression.h"
#include "src/ast/phony_expression.h" #include "src/ast/phony_expression.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/unary_op_expression.h" #include "src/ast/unary_op_expression.h"
#include "src/utils/reverse.h" #include "src/utils/reverse.h"
@ -113,8 +112,6 @@ bool TraverseExpressions(const ast::Expression* root,
// function name in the traversal. // function name in the traversal.
// to_visit.push_back(call->func); // to_visit.push_back(call->func);
push_list(call->args); push_list(call->args);
} else if (auto* type_ctor = expr->As<TypeConstructorExpression>()) {
push_list(type_ctor->values);
} else if (auto* member = expr->As<MemberAccessorExpression>()) { } else if (auto* member = expr->As<MemberAccessorExpression>()) {
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include the // TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
// member name in the traversal. // member name in the traversal.

View File

@ -124,31 +124,6 @@ TEST_F(TraverseExpressionsTest, DescendCallExpression) {
} }
} }
TEST_F(TraverseExpressionsTest, DescendTypeConstructorExpression) {
std::vector<const ast::Expression*> e = {Expr(1), Expr(1), Expr(1), Expr(1)};
std::vector<const ast::Expression*> c = {vec2<i32>(e[0], e[1]),
vec2<i32>(e[2], e[3])};
auto* root = vec2<i32>(c[0], c[1]);
{
std::vector<const ast::Expression*> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(
root, Diagnostics(), [&](const ast::Expression* expr) {
l2r.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(l2r, ElementsAre(root, c[0], e[0], e[1], c[1], e[2], e[3]));
}
{
std::vector<const ast::Expression*> r2l;
TraverseExpressions<TraverseOrder::RightToLeft>(
root, Diagnostics(), [&](const ast::Expression* expr) {
r2l.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(r2l, ElementsAre(root, c[1], e[3], e[2], c[0], e[1], e[0]));
}
}
// TODO(crbug.com/tint/1257): Test ignores member accessor 'member' field. // TODO(crbug.com/tint/1257): Test ignores member accessor 'member' field.
// Replace with the test below when fixed. // Replace with the test below when fixed.
TEST_F(TraverseExpressionsTest, DescendMemberIndexExpression) { TEST_F(TraverseExpressionsTest, DescendMemberIndexExpression) {

View File

@ -1,51 +0,0 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/ast/type_constructor_expression.h"
#include "src/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::ast::TypeConstructorExpression);
namespace tint {
namespace ast {
TypeConstructorExpression::TypeConstructorExpression(ProgramID pid,
const Source& src,
const ast::Type* ty,
ExpressionList vals)
: Base(pid, src), type(ty), values(std::move(vals)) {
TINT_ASSERT(AST, type);
for (auto* val : values) {
TINT_ASSERT(AST, val);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, val, program_id);
}
}
TypeConstructorExpression::TypeConstructorExpression(
TypeConstructorExpression&&) = default;
TypeConstructorExpression::~TypeConstructorExpression() = default;
const TypeConstructorExpression* TypeConstructorExpression::Clone(
CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source);
auto* ty = ctx->Clone(type);
auto vals = ctx->Clone(values);
return ctx->dst->create<TypeConstructorExpression>(src, ty, vals);
}
} // namespace ast
} // namespace tint

View File

@ -1,61 +0,0 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_AST_TYPE_CONSTRUCTOR_EXPRESSION_H_
#define SRC_AST_TYPE_CONSTRUCTOR_EXPRESSION_H_
#include <utility>
#include "src/ast/expression.h"
namespace tint {
namespace ast {
// Forward declaration
class Type;
/// A type specific constructor
class TypeConstructorExpression
: public Castable<TypeConstructorExpression, Expression> {
public:
/// Constructor
/// @param pid the identifier of the program that owns this node
/// @param src the source of this node
/// @param type the type
/// @param values the constructor values
TypeConstructorExpression(ProgramID pid,
const Source& src,
const ast::Type* type,
ExpressionList values);
/// Move constructor
TypeConstructorExpression(TypeConstructorExpression&&);
~TypeConstructorExpression() override;
/// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`.
/// @param ctx the clone context
/// @return the newly cloned node
const TypeConstructorExpression* Clone(CloneContext* ctx) const override;
/// The type
const ast::Type* const type;
/// The values
const ExpressionList values;
};
} // namespace ast
} // namespace tint
#endif // SRC_AST_TYPE_CONSTRUCTOR_EXPRESSION_H_

View File

@ -1,85 +0,0 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest-spi.h"
#include "src/ast/test_helper.h"
namespace tint {
namespace ast {
namespace {
using TypeConstructorExpressionTest = TestHelper;
TEST_F(TypeConstructorExpressionTest, Creation) {
ExpressionList expr;
expr.push_back(Expr("expr"));
auto* t = create<TypeConstructorExpression>(ty.f32(), expr);
EXPECT_TRUE(t->type->Is<ast::F32>());
ASSERT_EQ(t->values.size(), 1u);
EXPECT_EQ(t->values[0], expr[0]);
}
TEST_F(TypeConstructorExpressionTest, Creation_WithSource) {
ExpressionList expr;
expr.push_back(Expr("expr"));
auto* t = create<TypeConstructorExpression>(Source{Source::Location{20, 2}},
ty.f32(), expr);
auto src = t->source;
EXPECT_EQ(src.range.begin.line, 20u);
EXPECT_EQ(src.range.begin.column, 2u);
}
TEST_F(TypeConstructorExpressionTest, IsTypeConstructor) {
ExpressionList expr;
expr.push_back(Expr("expr"));
auto* t = create<TypeConstructorExpression>(ty.f32(), expr);
EXPECT_TRUE(t->Is<TypeConstructorExpression>());
}
TEST_F(TypeConstructorExpressionTest, Assert_Null_Type) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
b.create<TypeConstructorExpression>(nullptr, ExpressionList{b.Expr(1)});
},
"internal compiler error");
}
TEST_F(TypeConstructorExpressionTest, Assert_Null_Value) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
b.create<TypeConstructorExpression>(b.ty.i32(),
ExpressionList{nullptr});
},
"internal compiler error");
}
TEST_F(TypeConstructorExpressionTest, Assert_DifferentProgramID_Value) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b1;
ProgramBuilder b2;
b1.create<TypeConstructorExpression>(b1.ty.i32(),
ExpressionList{b2.Expr(1)});
},
"internal compiler error");
}
} // namespace
} // namespace ast
} // namespace tint

View File

@ -64,7 +64,6 @@
#include "src/ast/struct_member_offset_decoration.h" #include "src/ast/struct_member_offset_decoration.h"
#include "src/ast/struct_member_size_decoration.h" #include "src/ast/struct_member_size_decoration.h"
#include "src/ast/switch_statement.h" #include "src/ast/switch_statement.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/type_name.h" #include "src/ast/type_name.h"
#include "src/ast/u32.h" #include "src/ast/u32.h"
#include "src/ast/uint_literal_expression.h" #include "src/ast/uint_literal_expression.h"
@ -1125,35 +1124,33 @@ class ProgramBuilder {
ast::ExpressionList ExprList(ast::ExpressionList list) { return list; } ast::ExpressionList ExprList(ast::ExpressionList list) { return list; }
/// @param args the arguments for the type constructor /// @param args the arguments for the type constructor
/// @return an `ast::TypeConstructorExpression` of type `ty`, with the values /// @return an `ast::CallExpression` of type `ty`, with the values
/// of `args` converted to `ast::Expression`s using `Expr()` /// of `args` converted to `ast::Expression`s using `Expr()`
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
const ast::TypeConstructorExpression* Construct(ARGS&&... args) { const ast::CallExpression* Construct(ARGS&&... args) {
return Construct(ty.Of<T>(), std::forward<ARGS>(args)...); return Construct(ty.Of<T>(), std::forward<ARGS>(args)...);
} }
/// @param type the type to construct /// @param type the type to construct
/// @param args the arguments for the constructor /// @param args the arguments for the constructor
/// @return an `ast::TypeConstructorExpression` of `type` constructed with the /// @return an `ast::CallExpression` of `type` constructed with the
/// values `args`. /// values `args`.
template <typename... ARGS> template <typename... ARGS>
const ast::TypeConstructorExpression* Construct(const ast::Type* type, const ast::CallExpression* Construct(const ast::Type* type, ARGS&&... args) {
ARGS&&... args) { return Construct(source_, type, std::forward<ARGS>(args)...);
return create<ast::TypeConstructorExpression>(
type, ExprList(std::forward<ARGS>(args)...));
} }
/// @param source the source information /// @param source the source information
/// @param type the type to construct /// @param type the type to construct
/// @param args the arguments for the constructor /// @param args the arguments for the constructor
/// @return an `ast::TypeConstructorExpression` of `type` constructed with the /// @return an `ast::CallExpression` of `type` constructed with the
/// values `args`. /// values `args`.
template <typename... ARGS> template <typename... ARGS>
const ast::TypeConstructorExpression* Construct(const Source& source, const ast::CallExpression* Construct(const Source& source,
const ast::Type* type, const ast::Type* type,
ARGS&&... args) { ARGS&&... args) {
return create<ast::TypeConstructorExpression>( return create<ast::CallExpression>(source, type,
source, type, ExprList(std::forward<ARGS>(args)...)); ExprList(std::forward<ARGS>(args)...));
} }
/// @param expr the expression for the bitcast /// @param expr the expression for the bitcast
@ -1189,128 +1186,128 @@ class ProgramBuilder {
/// @param args the arguments for the vector constructor /// @param args the arguments for the vector constructor
/// @param type the vector type /// @param type the vector type
/// @param size the vector size /// @param size the vector size
/// @return an `ast::TypeConstructorExpression` of a `size`-element vector of /// @return an `ast::CallExpression` of a `size`-element vector of
/// type `type`, constructed with the values `args`. /// type `type`, constructed with the values `args`.
template <typename... ARGS> template <typename... ARGS>
const ast::TypeConstructorExpression* vec(const ast::Type* type, const ast::CallExpression* vec(const ast::Type* type,
uint32_t size, uint32_t size,
ARGS&&... args) { ARGS&&... args) {
return Construct(ty.vec(type, size), std::forward<ARGS>(args)...); return Construct(ty.vec(type, size), std::forward<ARGS>(args)...);
} }
/// @param args the arguments for the vector constructor /// @param args the arguments for the vector constructor
/// @return an `ast::TypeConstructorExpression` of a 2-element vector of type /// @return an `ast::CallExpression` of a 2-element vector of type
/// `T`, constructed with the values `args`. /// `T`, constructed with the values `args`.
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
const ast::TypeConstructorExpression* vec2(ARGS&&... args) { const ast::CallExpression* vec2(ARGS&&... args) {
return Construct(ty.vec2<T>(), std::forward<ARGS>(args)...); return Construct(ty.vec2<T>(), std::forward<ARGS>(args)...);
} }
/// @param args the arguments for the vector constructor /// @param args the arguments for the vector constructor
/// @return an `ast::TypeConstructorExpression` of a 3-element vector of type /// @return an `ast::CallExpression` of a 3-element vector of type
/// `T`, constructed with the values `args`. /// `T`, constructed with the values `args`.
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
const ast::TypeConstructorExpression* vec3(ARGS&&... args) { const ast::CallExpression* vec3(ARGS&&... args) {
return Construct(ty.vec3<T>(), std::forward<ARGS>(args)...); return Construct(ty.vec3<T>(), std::forward<ARGS>(args)...);
} }
/// @param args the arguments for the vector constructor /// @param args the arguments for the vector constructor
/// @return an `ast::TypeConstructorExpression` of a 4-element vector of type /// @return an `ast::CallExpression` of a 4-element vector of type
/// `T`, constructed with the values `args`. /// `T`, constructed with the values `args`.
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
const ast::TypeConstructorExpression* vec4(ARGS&&... args) { const ast::CallExpression* vec4(ARGS&&... args) {
return Construct(ty.vec4<T>(), std::forward<ARGS>(args)...); return Construct(ty.vec4<T>(), std::forward<ARGS>(args)...);
} }
/// @param args the arguments for the matrix constructor /// @param args the arguments for the matrix constructor
/// @return an `ast::TypeConstructorExpression` of a 2x2 matrix of type /// @return an `ast::CallExpression` of a 2x2 matrix of type
/// `T`, constructed with the values `args`. /// `T`, constructed with the values `args`.
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
const ast::TypeConstructorExpression* mat2x2(ARGS&&... args) { const ast::CallExpression* mat2x2(ARGS&&... args) {
return Construct(ty.mat2x2<T>(), std::forward<ARGS>(args)...); return Construct(ty.mat2x2<T>(), std::forward<ARGS>(args)...);
} }
/// @param args the arguments for the matrix constructor /// @param args the arguments for the matrix constructor
/// @return an `ast::TypeConstructorExpression` of a 2x3 matrix of type /// @return an `ast::CallExpression` of a 2x3 matrix of type
/// `T`, constructed with the values `args`. /// `T`, constructed with the values `args`.
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
const ast::TypeConstructorExpression* mat2x3(ARGS&&... args) { const ast::CallExpression* mat2x3(ARGS&&... args) {
return Construct(ty.mat2x3<T>(), std::forward<ARGS>(args)...); return Construct(ty.mat2x3<T>(), std::forward<ARGS>(args)...);
} }
/// @param args the arguments for the matrix constructor /// @param args the arguments for the matrix constructor
/// @return an `ast::TypeConstructorExpression` of a 2x4 matrix of type /// @return an `ast::CallExpression` of a 2x4 matrix of type
/// `T`, constructed with the values `args`. /// `T`, constructed with the values `args`.
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
const ast::TypeConstructorExpression* mat2x4(ARGS&&... args) { const ast::CallExpression* mat2x4(ARGS&&... args) {
return Construct(ty.mat2x4<T>(), std::forward<ARGS>(args)...); return Construct(ty.mat2x4<T>(), std::forward<ARGS>(args)...);
} }
/// @param args the arguments for the matrix constructor /// @param args the arguments for the matrix constructor
/// @return an `ast::TypeConstructorExpression` of a 3x2 matrix of type /// @return an `ast::CallExpression` of a 3x2 matrix of type
/// `T`, constructed with the values `args`. /// `T`, constructed with the values `args`.
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
const ast::TypeConstructorExpression* mat3x2(ARGS&&... args) { const ast::CallExpression* mat3x2(ARGS&&... args) {
return Construct(ty.mat3x2<T>(), std::forward<ARGS>(args)...); return Construct(ty.mat3x2<T>(), std::forward<ARGS>(args)...);
} }
/// @param args the arguments for the matrix constructor /// @param args the arguments for the matrix constructor
/// @return an `ast::TypeConstructorExpression` of a 3x3 matrix of type /// @return an `ast::CallExpression` of a 3x3 matrix of type
/// `T`, constructed with the values `args`. /// `T`, constructed with the values `args`.
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
const ast::TypeConstructorExpression* mat3x3(ARGS&&... args) { const ast::CallExpression* mat3x3(ARGS&&... args) {
return Construct(ty.mat3x3<T>(), std::forward<ARGS>(args)...); return Construct(ty.mat3x3<T>(), std::forward<ARGS>(args)...);
} }
/// @param args the arguments for the matrix constructor /// @param args the arguments for the matrix constructor
/// @return an `ast::TypeConstructorExpression` of a 3x4 matrix of type /// @return an `ast::CallExpression` of a 3x4 matrix of type
/// `T`, constructed with the values `args`. /// `T`, constructed with the values `args`.
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
const ast::TypeConstructorExpression* mat3x4(ARGS&&... args) { const ast::CallExpression* mat3x4(ARGS&&... args) {
return Construct(ty.mat3x4<T>(), std::forward<ARGS>(args)...); return Construct(ty.mat3x4<T>(), std::forward<ARGS>(args)...);
} }
/// @param args the arguments for the matrix constructor /// @param args the arguments for the matrix constructor
/// @return an `ast::TypeConstructorExpression` of a 4x2 matrix of type /// @return an `ast::CallExpression` of a 4x2 matrix of type
/// `T`, constructed with the values `args`. /// `T`, constructed with the values `args`.
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
const ast::TypeConstructorExpression* mat4x2(ARGS&&... args) { const ast::CallExpression* mat4x2(ARGS&&... args) {
return Construct(ty.mat4x2<T>(), std::forward<ARGS>(args)...); return Construct(ty.mat4x2<T>(), std::forward<ARGS>(args)...);
} }
/// @param args the arguments for the matrix constructor /// @param args the arguments for the matrix constructor
/// @return an `ast::TypeConstructorExpression` of a 4x3 matrix of type /// @return an `ast::CallExpression` of a 4x3 matrix of type
/// `T`, constructed with the values `args`. /// `T`, constructed with the values `args`.
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
const ast::TypeConstructorExpression* mat4x3(ARGS&&... args) { const ast::CallExpression* mat4x3(ARGS&&... args) {
return Construct(ty.mat4x3<T>(), std::forward<ARGS>(args)...); return Construct(ty.mat4x3<T>(), std::forward<ARGS>(args)...);
} }
/// @param args the arguments for the matrix constructor /// @param args the arguments for the matrix constructor
/// @return an `ast::TypeConstructorExpression` of a 4x4 matrix of type /// @return an `ast::CallExpression` of a 4x4 matrix of type
/// `T`, constructed with the values `args`. /// `T`, constructed with the values `args`.
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
const ast::TypeConstructorExpression* mat4x4(ARGS&&... args) { const ast::CallExpression* mat4x4(ARGS&&... args) {
return Construct(ty.mat4x4<T>(), std::forward<ARGS>(args)...); return Construct(ty.mat4x4<T>(), std::forward<ARGS>(args)...);
} }
/// @param args the arguments for the array constructor /// @param args the arguments for the array constructor
/// @return an `ast::TypeConstructorExpression` of an array with element type /// @return an `ast::CallExpression` of an array with element type
/// `T` and size `N`, constructed with the values `args`. /// `T` and size `N`, constructed with the values `args`.
template <typename T, int N, typename... ARGS> template <typename T, int N, typename... ARGS>
const ast::TypeConstructorExpression* array(ARGS&&... args) { const ast::CallExpression* array(ARGS&&... args) {
return Construct(ty.array<T, N>(), std::forward<ARGS>(args)...); return Construct(ty.array<T, N>(), std::forward<ARGS>(args)...);
} }
/// @param subtype the array element type /// @param subtype the array element type
/// @param n the array size. nullptr represents a runtime-array. /// @param n the array size. nullptr represents a runtime-array.
/// @param args the arguments for the array constructor /// @param args the arguments for the array constructor
/// @return an `ast::TypeConstructorExpression` of an array with element type /// @return an `ast::CallExpression` of an array with element type
/// `subtype`, constructed with the values `args`. /// `subtype`, constructed with the values `args`.
template <typename EXPR, typename... ARGS> template <typename EXPR, typename... ARGS>
const ast::TypeConstructorExpression* array(const ast::Type* subtype, const ast::CallExpression* array(const ast::Type* subtype,
EXPR&& n, EXPR&& n,
ARGS&&... args) { ARGS&&... args) {
return Construct(ty.array(subtype, std::forward<EXPR>(n)), return Construct(ty.array(subtype, std::forward<EXPR>(n)),
std::forward<ARGS>(args)...); std::forward<ARGS>(args)...);
} }

View File

@ -36,7 +36,7 @@ TEST_F(ParserImplTest, Statement_Call) {
ASSERT_TRUE(e->Is<ast::CallStatement>()); ASSERT_TRUE(e->Is<ast::CallStatement>());
auto* c = e->As<ast::CallStatement>()->expr; auto* c = e->As<ast::CallStatement>()->expr;
EXPECT_EQ(c->func->symbol, p->builder().Symbols().Get("a")); EXPECT_EQ(c->target.name->symbol, p->builder().Symbols().Get("a"));
EXPECT_EQ(c->args.size(), 0u); EXPECT_EQ(c->args.size(), 0u);
} }
@ -52,7 +52,7 @@ TEST_F(ParserImplTest, Statement_Call_WithParams) {
ASSERT_TRUE(e->Is<ast::CallStatement>()); ASSERT_TRUE(e->Is<ast::CallStatement>());
auto* c = e->As<ast::CallStatement>()->expr; auto* c = e->As<ast::CallStatement>()->expr;
EXPECT_EQ(c->func->symbol, p->builder().Symbols().Get("a")); EXPECT_EQ(c->target.name->symbol, p->builder().Symbols().Get("a"));
EXPECT_EQ(c->args.size(), 3u); EXPECT_EQ(c->args.size(), 3u);
EXPECT_TRUE(c->args[0]->Is<ast::IntLiteralExpression>()); EXPECT_TRUE(c->args[0]->Is<ast::IntLiteralExpression>());
@ -71,7 +71,7 @@ TEST_F(ParserImplTest, Statement_Call_WithParams_TrailingComma) {
ASSERT_TRUE(e->Is<ast::CallStatement>()); ASSERT_TRUE(e->Is<ast::CallStatement>());
auto* c = e->As<ast::CallStatement>()->expr; auto* c = e->As<ast::CallStatement>()->expr;
EXPECT_EQ(c->func->symbol, p->builder().Symbols().Get("a")); EXPECT_EQ(c->target.name->symbol, p->builder().Symbols().Get("a"));
EXPECT_EQ(c->args.size(), 2u); EXPECT_EQ(c->args.size(), 2u);
EXPECT_TRUE(c->args[0]->Is<ast::IntLiteralExpression>()); EXPECT_TRUE(c->args[0]->Is<ast::IntLiteralExpression>());

View File

@ -24,20 +24,19 @@ TEST_F(ParserImplTest, ConstExpr_TypeDecl) {
auto e = p->expect_const_expr(); auto e = p->expect_const_expr();
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(e.errored); ASSERT_FALSE(e.errored);
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>()); ASSERT_TRUE(e->Is<ast::CallExpression>());
auto* t = e->As<ast::TypeConstructorExpression>(); auto* t = e->As<ast::CallExpression>();
ASSERT_TRUE(t->type->Is<ast::Vector>()); ASSERT_TRUE(t->target.type->Is<ast::Vector>());
EXPECT_EQ(t->type->As<ast::Vector>()->width, 2u); EXPECT_EQ(t->target.type->As<ast::Vector>()->width, 2u);
ASSERT_EQ(t->values.size(), 2u); ASSERT_EQ(t->args.size(), 2u);
auto& v = t->values;
ASSERT_TRUE(v[0]->Is<ast::FloatLiteralExpression>()); ASSERT_TRUE(t->args[0]->Is<ast::FloatLiteralExpression>());
EXPECT_FLOAT_EQ(v[0]->As<ast::FloatLiteralExpression>()->value, 1.); EXPECT_FLOAT_EQ(t->args[0]->As<ast::FloatLiteralExpression>()->value, 1.);
ASSERT_TRUE(v[1]->Is<ast::FloatLiteralExpression>()); ASSERT_TRUE(t->args[1]->Is<ast::FloatLiteralExpression>());
EXPECT_FLOAT_EQ(v[1]->As<ast::FloatLiteralExpression>()->value, 2.); EXPECT_FLOAT_EQ(t->args[1]->As<ast::FloatLiteralExpression>()->value, 2.);
} }
TEST_F(ParserImplTest, ConstExpr_TypeDecl_Empty) { TEST_F(ParserImplTest, ConstExpr_TypeDecl_Empty) {
@ -45,13 +44,13 @@ TEST_F(ParserImplTest, ConstExpr_TypeDecl_Empty) {
auto e = p->expect_const_expr(); auto e = p->expect_const_expr();
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(e.errored); ASSERT_FALSE(e.errored);
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>()); ASSERT_TRUE(e->Is<ast::CallExpression>());
auto* t = e->As<ast::TypeConstructorExpression>(); auto* t = e->As<ast::CallExpression>();
ASSERT_TRUE(t->type->Is<ast::Vector>()); ASSERT_TRUE(t->target.type->Is<ast::Vector>());
EXPECT_EQ(t->type->As<ast::Vector>()->width, 2u); EXPECT_EQ(t->target.type->As<ast::Vector>()->width, 2u);
ASSERT_EQ(t->values.size(), 0u); ASSERT_EQ(t->args.size(), 0u);
} }
TEST_F(ParserImplTest, ConstExpr_TypeDecl_TrailingComma) { TEST_F(ParserImplTest, ConstExpr_TypeDecl_TrailingComma) {
@ -59,15 +58,15 @@ TEST_F(ParserImplTest, ConstExpr_TypeDecl_TrailingComma) {
auto e = p->expect_const_expr(); auto e = p->expect_const_expr();
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(e.errored); ASSERT_FALSE(e.errored);
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>()); ASSERT_TRUE(e->Is<ast::CallExpression>());
auto* t = e->As<ast::TypeConstructorExpression>(); auto* t = e->As<ast::CallExpression>();
ASSERT_TRUE(t->type->Is<ast::Vector>()); ASSERT_TRUE(t->target.type->Is<ast::Vector>());
EXPECT_EQ(t->type->As<ast::Vector>()->width, 2u); EXPECT_EQ(t->target.type->As<ast::Vector>()->width, 2u);
ASSERT_EQ(t->values.size(), 2u); ASSERT_EQ(t->args.size(), 2u);
ASSERT_TRUE(t->values[0]->Is<ast::LiteralExpression>()); ASSERT_TRUE(t->args[0]->Is<ast::LiteralExpression>());
ASSERT_TRUE(t->values[1]->Is<ast::LiteralExpression>()); ASSERT_TRUE(t->args[1]->Is<ast::LiteralExpression>());
} }
TEST_F(ParserImplTest, ConstExpr_TypeDecl_MissingRightParen) { TEST_F(ParserImplTest, ConstExpr_TypeDecl_MissingRightParen) {
@ -134,7 +133,7 @@ TEST_F(ParserImplTest, ConstExpr_RegisteredType) {
auto e = p->expect_const_expr(); auto e = p->expect_const_expr();
ASSERT_FALSE(e.errored); ASSERT_FALSE(e.errored);
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>()); ASSERT_TRUE(e->Is<ast::CallExpression>());
} }
TEST_F(ParserImplTest, ConstExpr_NotRegisteredType) { TEST_F(ParserImplTest, ConstExpr_NotRegisteredType) {

View File

@ -39,11 +39,13 @@ TEST_F(ParserImplTest, PrimaryExpression_TypeDecl) {
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error(); EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr); ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>()); ASSERT_TRUE(e->Is<ast::CallExpression>());
auto* ty = e->As<ast::TypeConstructorExpression>(); auto* call = e->As<ast::CallExpression>();
ASSERT_EQ(ty->values.size(), 4u); EXPECT_NE(call->target.type, nullptr);
const auto& val = ty->values;
ASSERT_EQ(call->args.size(), 4u);
const auto& val = call->args;
ASSERT_TRUE(val[0]->Is<ast::SintLiteralExpression>()); ASSERT_TRUE(val[0]->Is<ast::SintLiteralExpression>());
EXPECT_EQ(val[0]->As<ast::SintLiteralExpression>()->value, 1); EXPECT_EQ(val[0]->As<ast::SintLiteralExpression>()->value, 1);
@ -64,10 +66,11 @@ TEST_F(ParserImplTest, PrimaryExpression_TypeDecl_ZeroConstructor) {
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error(); EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr); ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
auto* ty = e->As<ast::TypeConstructorExpression>();
ASSERT_EQ(ty->values.size(), 0u); ASSERT_TRUE(e->Is<ast::CallExpression>());
auto* call = e->As<ast::CallExpression>();
ASSERT_EQ(call->args.size(), 0u);
} }
TEST_F(ParserImplTest, PrimaryExpression_TypeDecl_InvalidTypeDecl) { TEST_F(ParserImplTest, PrimaryExpression_TypeDecl_InvalidTypeDecl) {
@ -124,15 +127,15 @@ TEST_F(ParserImplTest, PrimaryExpression_TypeDecl_StructConstructor_Empty) {
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error(); EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr); ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
auto* constructor = e->As<ast::TypeConstructorExpression>(); ASSERT_TRUE(e->Is<ast::CallExpression>());
ASSERT_TRUE(constructor->type->Is<ast::TypeName>()); auto* call = e->As<ast::CallExpression>();
EXPECT_EQ(constructor->type->As<ast::TypeName>()->name,
ASSERT_TRUE(call->target.type->Is<ast::TypeName>());
EXPECT_EQ(call->target.type->As<ast::TypeName>()->name,
p->builder().Symbols().Get("S")); p->builder().Symbols().Get("S"));
auto values = constructor->values; ASSERT_EQ(call->args.size(), 0u);
ASSERT_EQ(values.size(), 0u);
} }
TEST_F(ParserImplTest, PrimaryExpression_TypeDecl_StructConstructor_NotEmpty) { TEST_F(ParserImplTest, PrimaryExpression_TypeDecl_StructConstructor_NotEmpty) {
@ -149,21 +152,21 @@ TEST_F(ParserImplTest, PrimaryExpression_TypeDecl_StructConstructor_NotEmpty) {
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error(); EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr); ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
auto* constructor = e->As<ast::TypeConstructorExpression>(); ASSERT_TRUE(e->Is<ast::CallExpression>());
ASSERT_TRUE(constructor->type->Is<ast::TypeName>()); auto* call = e->As<ast::CallExpression>();
EXPECT_EQ(constructor->type->As<ast::TypeName>()->name,
ASSERT_TRUE(call->target.type->Is<ast::TypeName>());
EXPECT_EQ(call->target.type->As<ast::TypeName>()->name,
p->builder().Symbols().Get("S")); p->builder().Symbols().Get("S"));
auto values = constructor->values; ASSERT_EQ(call->args.size(), 2u);
ASSERT_EQ(values.size(), 2u);
ASSERT_TRUE(values[0]->Is<ast::UintLiteralExpression>()); ASSERT_TRUE(call->args[0]->Is<ast::UintLiteralExpression>());
EXPECT_EQ(values[0]->As<ast::UintLiteralExpression>()->value, 1u); EXPECT_EQ(call->args[0]->As<ast::UintLiteralExpression>()->value, 1u);
ASSERT_TRUE(values[1]->Is<ast::FloatLiteralExpression>()); ASSERT_TRUE(call->args[1]->Is<ast::FloatLiteralExpression>());
EXPECT_EQ(values[1]->As<ast::FloatLiteralExpression>()->value, 2.f); EXPECT_EQ(call->args[1]->As<ast::FloatLiteralExpression>()->value, 2.f);
} }
TEST_F(ParserImplTest, PrimaryExpression_ConstLiteral_True) { TEST_F(ParserImplTest, PrimaryExpression_ConstLiteral_True) {
@ -225,13 +228,14 @@ TEST_F(ParserImplTest, PrimaryExpression_Cast) {
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error(); EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr); ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
auto* c = e->As<ast::TypeConstructorExpression>(); ASSERT_TRUE(e->Is<ast::CallExpression>());
ASSERT_TRUE(c->type->Is<ast::F32>()); auto* call = e->As<ast::CallExpression>();
ASSERT_EQ(c->values.size(), 1u);
ASSERT_TRUE(c->values[0]->Is<ast::IntLiteralExpression>()); ASSERT_TRUE(call->target.type->Is<ast::F32>());
ASSERT_EQ(call->args.size(), 1u);
ASSERT_TRUE(call->args[0]->Is<ast::IntLiteralExpression>());
} }
TEST_F(ParserImplTest, PrimaryExpression_Bitcast) { TEST_F(ParserImplTest, PrimaryExpression_Bitcast) {

View File

@ -97,7 +97,7 @@ TEST_F(ParserImplTest, SingularExpression_Call_Empty) {
ASSERT_TRUE(e->Is<ast::CallExpression>()); ASSERT_TRUE(e->Is<ast::CallExpression>());
auto* c = e->As<ast::CallExpression>(); auto* c = e->As<ast::CallExpression>();
EXPECT_EQ(c->func->symbol, p->builder().Symbols().Get("a")); EXPECT_EQ(c->target.name->symbol, p->builder().Symbols().Get("a"));
EXPECT_EQ(c->args.size(), 0u); EXPECT_EQ(c->args.size(), 0u);
} }
@ -113,7 +113,7 @@ TEST_F(ParserImplTest, SingularExpression_Call_WithArgs) {
ASSERT_TRUE(e->Is<ast::CallExpression>()); ASSERT_TRUE(e->Is<ast::CallExpression>());
auto* c = e->As<ast::CallExpression>(); auto* c = e->As<ast::CallExpression>();
EXPECT_EQ(c->func->symbol, p->builder().Symbols().Get("test")); EXPECT_EQ(c->target.name->symbol, p->builder().Symbols().Get("test"));
EXPECT_EQ(c->args.size(), 3u); EXPECT_EQ(c->args.size(), 3u);
EXPECT_TRUE(c->args[0]->Is<ast::IntLiteralExpression>()); EXPECT_TRUE(c->args[0]->Is<ast::IntLiteralExpression>());

View File

@ -90,7 +90,10 @@ TEST_F(ParserImplTest, VariableStmt_VariableDecl_ArrayInit) {
EXPECT_EQ(e->variable->symbol, p->builder().Symbols().Get("a")); EXPECT_EQ(e->variable->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->variable->constructor, nullptr); ASSERT_NE(e->variable->constructor, nullptr);
EXPECT_TRUE(e->variable->constructor->Is<ast::TypeConstructorExpression>()); auto* call = e->variable->constructor->As<ast::CallExpression>();
ASSERT_NE(call, nullptr);
EXPECT_EQ(call->target.name, nullptr);
EXPECT_NE(call->target.type, nullptr);
} }
TEST_F(ParserImplTest, VariableStmt_VariableDecl_ArrayInit_NoSpace) { TEST_F(ParserImplTest, VariableStmt_VariableDecl_ArrayInit_NoSpace) {
@ -105,7 +108,10 @@ TEST_F(ParserImplTest, VariableStmt_VariableDecl_ArrayInit_NoSpace) {
EXPECT_EQ(e->variable->symbol, p->builder().Symbols().Get("a")); EXPECT_EQ(e->variable->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->variable->constructor, nullptr); ASSERT_NE(e->variable->constructor, nullptr);
EXPECT_TRUE(e->variable->constructor->Is<ast::TypeConstructorExpression>()); auto* call = e->variable->constructor->As<ast::CallExpression>();
ASSERT_NE(call, nullptr);
EXPECT_EQ(call->target.name, nullptr);
EXPECT_NE(call->target.type, nullptr);
} }
TEST_F(ParserImplTest, VariableStmt_VariableDecl_VecInit) { TEST_F(ParserImplTest, VariableStmt_VariableDecl_VecInit) {
@ -120,7 +126,10 @@ TEST_F(ParserImplTest, VariableStmt_VariableDecl_VecInit) {
EXPECT_EQ(e->variable->symbol, p->builder().Symbols().Get("a")); EXPECT_EQ(e->variable->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->variable->constructor, nullptr); ASSERT_NE(e->variable->constructor, nullptr);
EXPECT_TRUE(e->variable->constructor->Is<ast::TypeConstructorExpression>()); auto* call = e->variable->constructor->As<ast::CallExpression>();
ASSERT_NE(call, nullptr);
EXPECT_EQ(call->target.name, nullptr);
EXPECT_NE(call->target.type, nullptr);
} }
TEST_F(ParserImplTest, VariableStmt_VariableDecl_VecInit_NoSpace) { TEST_F(ParserImplTest, VariableStmt_VariableDecl_VecInit_NoSpace) {
@ -135,7 +144,10 @@ TEST_F(ParserImplTest, VariableStmt_VariableDecl_VecInit_NoSpace) {
EXPECT_EQ(e->variable->symbol, p->builder().Symbols().Get("a")); EXPECT_EQ(e->variable->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->variable->constructor, nullptr); ASSERT_NE(e->variable->constructor, nullptr);
EXPECT_TRUE(e->variable->constructor->Is<ast::TypeConstructorExpression>()); auto* call = e->variable->constructor->As<ast::CallExpression>();
ASSERT_NE(call, nullptr);
EXPECT_EQ(call->target.name, nullptr);
EXPECT_NE(call->target.type, nullptr);
} }
TEST_F(ParserImplTest, VariableStmt_Let) { TEST_F(ParserImplTest, VariableStmt_Let) {

View File

@ -90,11 +90,15 @@ TEST_F(ResolverCallTest, Valid) {
args.push_back(p.create_value(*this, 0)); args.push_back(p.create_value(*this, 0));
} }
Func("foo", std::move(params), ty.f32(), {Return(1.23f)}); auto* func = Func("foo", std::move(params), ty.f32(), {Return(1.23f)});
auto* call = Call("foo", std::move(args)); auto* call_expr = Call("foo", std::move(args));
WrapInFunction(call); WrapInFunction(call_expr);
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* call = Sem().Get(call_expr);
EXPECT_NE(call, nullptr);
EXPECT_EQ(call->Target(), Sem().Get(func));
} }
} // namespace } // namespace

View File

@ -70,12 +70,15 @@
#include "src/sem/storage_texture_type.h" #include "src/sem/storage_texture_type.h"
#include "src/sem/struct.h" #include "src/sem/struct.h"
#include "src/sem/switch_statement.h" #include "src/sem/switch_statement.h"
#include "src/sem/type_constructor.h"
#include "src/sem/type_conversion.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/utils/defer.h" #include "src/utils/defer.h"
#include "src/utils/get_or_create.h" #include "src/utils/get_or_create.h"
#include "src/utils/math.h" #include "src/utils/math.h"
#include "src/utils/reverse.h" #include "src/utils/reverse.h"
#include "src/utils/scoped_assignment.h" #include "src/utils/scoped_assignment.h"
#include "src/utils/transform.h"
namespace tint { namespace tint {
namespace resolver { namespace resolver {
@ -510,8 +513,8 @@ sem::Variable* Resolver::Variable(const ast::Variable* var,
builder_->create<sem::Reference>(storage_ty, storage_class, access); builder_->create<sem::Reference>(storage_ty, storage_class, access);
} }
if (rhs && !ValidateVariableConstructor(var, storage_class, storage_ty, if (rhs && !ValidateVariableConstructorOrCast(var, storage_class, storage_ty,
rhs->Type())) { rhs->Type())) {
return nullptr; return nullptr;
} }
@ -641,10 +644,11 @@ void Resolver::AllocateOverridableConstantIds() {
} }
} }
bool Resolver::ValidateVariableConstructor(const ast::Variable* var, bool Resolver::ValidateVariableConstructorOrCast(
ast::StorageClass storage_class, const ast::Variable* var,
const sem::Type* storage_ty, ast::StorageClass storage_class,
const sem::Type* rhs_ty) { const sem::Type* storage_ty,
const sem::Type* rhs_ty) {
auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS
// Value type has to match storage type // Value type has to match storage type
@ -2369,8 +2373,6 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) {
sem_expr = Bitcast(bitcast); sem_expr = Bitcast(bitcast);
} else if (auto* call = expr->As<ast::CallExpression>()) { } else if (auto* call = expr->As<ast::CallExpression>()) {
sem_expr = Call(call); sem_expr = Call(call);
} else if (auto* ctor = expr->As<ast::TypeConstructorExpression>()) {
sem_expr = TypeConstructor(ctor);
} else if (auto* ident = expr->As<ast::IdentifierExpression>()) { } else if (auto* ident = expr->As<ast::IdentifierExpression>()) {
sem_expr = Identifier(ident); sem_expr = Identifier(ident);
} else if (auto* literal = expr->As<ast::LiteralExpression>()) { } else if (auto* literal = expr->As<ast::LiteralExpression>()) {
@ -2462,33 +2464,72 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
return builder_->create<sem::Expression>(expr, ty, current_statement_, val); return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
} }
sem::Expression* Resolver::Call(const ast::CallExpression* expr) { sem::Call* Resolver::Call(const ast::CallExpression* expr) {
auto* ident = expr->func;
Mark(ident);
auto name = builder_->Symbols().NameFor(ident->symbol);
auto intrinsic_type = sem::ParseIntrinsicType(name);
auto* call = (intrinsic_type != IntrinsicType::kNone)
? IntrinsicCall(expr, intrinsic_type)
: FunctionCall(expr);
current_function_->AddDirectCall(call);
return call;
}
sem::Call* Resolver::IntrinsicCall(const ast::CallExpression* expr,
sem::IntrinsicType intrinsic_type) {
std::vector<const sem::Expression*> args(expr->args.size()); std::vector<const sem::Expression*> args(expr->args.size());
std::vector<const sem::Type*> arg_tys(expr->args.size()); std::vector<const sem::Type*> arg_tys(args.size());
for (size_t i = 0; i < expr->args.size(); i++) { for (size_t i = 0; i < expr->args.size(); i++) {
auto* arg = Sem(expr->args[i]); auto* arg = Sem(expr->args[i]);
if (!arg) { if (!arg) {
return nullptr; return nullptr;
} }
args[i] = arg; args[i] = arg;
arg_tys[i] = arg->Type(); arg_tys[i] = args[i]->Type();
} }
auto type_ctor_or_conv = [&](const sem::Type* ty) -> sem::Call* {
// The call has resolved to a type constructor or cast.
if (args.size() == 1) {
auto* target = ty;
auto* source = args[0]->Type()->UnwrapRef();
if ((source != target) && //
((source->is_scalar() && target->is_scalar()) ||
(source->Is<sem::Vector>() && target->Is<sem::Vector>()) ||
(source->Is<sem::Matrix>() && target->Is<sem::Matrix>()))) {
// Note: Matrix types currently cannot be converted (the element type
// must only be f32). We implement this for the day we support other
// matrix element types.
return TypeConversion(expr, ty, args[0], arg_tys[0]);
}
}
return TypeConstructor(expr, ty, std::move(args), std::move(arg_tys));
};
// Resolve the target of the CallExpression to determine whether this is a
// function call, cast or type constructor expression.
if (expr->target.type) {
auto* ty = Type(expr->target.type);
if (!ty) {
return nullptr;
}
return type_ctor_or_conv(ty);
}
auto* ident = expr->target.name;
Mark(ident);
auto it = named_type_info_.find(ident->symbol);
if (it != named_type_info_.end()) {
// We have a type.
return type_ctor_or_conv(it->second.sem);
}
// Not a type, treat as a intrinsic / function call.
auto name = builder_->Symbols().NameFor(ident->symbol);
auto intrinsic_type = sem::ParseIntrinsicType(name);
auto* call = (intrinsic_type != IntrinsicType::kNone)
? IntrinsicCall(expr, intrinsic_type, std::move(args),
std::move(arg_tys))
: FunctionCall(expr, std::move(args));
current_function_->AddDirectCall(call);
return call;
}
sem::Call* Resolver::IntrinsicCall(
const ast::CallExpression* expr,
sem::IntrinsicType intrinsic_type,
const std::vector<const sem::Expression*> args,
const std::vector<const sem::Type*> arg_tys) {
auto* intrinsic = intrinsic_table_->Lookup(intrinsic_type, std::move(arg_tys), auto* intrinsic = intrinsic_table_->Lookup(intrinsic_type, std::move(arg_tys),
expr->source); expr->source);
if (!intrinsic) { if (!intrinsic) {
@ -2509,21 +2550,45 @@ sem::Call* Resolver::IntrinsicCall(const ast::CallExpression* expr,
return nullptr; return nullptr;
} }
if (!ValidateCall(call)) { if (!ValidateIntrinsicCall(call)) {
return nullptr; return nullptr;
} }
return call; return call;
} }
sem::Call* Resolver::FunctionCall(const ast::CallExpression* expr) { bool Resolver::ValidateIntrinsicCall(const sem::Call* call) {
auto* ident = expr->func; if (call->Type()->Is<sem::Void>()) {
auto name = builder_->Symbols().NameFor(ident->symbol); bool is_call_statement = false;
if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
if (call_stmt->expr == call->Declaration()) {
is_call_statement = true;
}
}
if (!is_call_statement) {
// https://gpuweb.github.io/gpuweb/wgsl/#function-call-expr
// If the called function does not return a value, a function call
// statement should be used instead.
auto* ident = call->Declaration()->target.name;
auto name = builder_->Symbols().NameFor(ident->symbol);
AddError("intrinsic '" + name + "' does not return a value",
call->Declaration()->source);
return false;
}
}
auto target_it = symbol_to_function_.find(ident->symbol); return true;
}
sem::Call* Resolver::FunctionCall(
const ast::CallExpression* expr,
const std::vector<const sem::Expression*> args) {
auto sym = expr->target.name->symbol;
auto name = builder_->Symbols().NameFor(sym);
auto target_it = symbol_to_function_.find(sym);
if (target_it == symbol_to_function_.end()) { if (target_it == symbol_to_function_.end()) {
if (current_function_ && if (current_function_ && current_function_->Declaration()->symbol == sym) {
current_function_->Declaration()->symbol == ident->symbol) {
AddError("recursion is not permitted. '" + name + AddError("recursion is not permitted. '" + name +
"' attempted to call itself.", "' attempted to call itself.",
expr->source); expr->source);
@ -2533,16 +2598,6 @@ sem::Call* Resolver::FunctionCall(const ast::CallExpression* expr) {
return nullptr; return nullptr;
} }
auto* target = target_it->second; auto* target = target_it->second;
std::vector<const sem::Expression*> args(expr->args.size());
for (size_t i = 0; i < expr->args.size(); i++) {
auto* arg = Sem(expr->args[i]);
if (!arg) {
return nullptr;
}
args[i] = arg;
}
auto* call = builder_->create<sem::Call>(expr, target, std::move(args), auto* call = builder_->create<sem::Call>(expr, target, std::move(args),
current_statement_, sem::Constant{}); current_statement_, sem::Constant{});
@ -2567,38 +2622,9 @@ sem::Call* Resolver::FunctionCall(const ast::CallExpression* expr) {
return nullptr; return nullptr;
} }
if (!ValidateCall(call)) {
return nullptr;
}
return call; return call;
} }
bool Resolver::ValidateCall(const sem::Call* call) {
if (call->Type()->Is<sem::Void>()) {
bool is_call_statement = false;
if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
if (call_stmt->expr == call->Declaration()) {
is_call_statement = true;
}
}
if (!is_call_statement) {
// https://gpuweb.github.io/gpuweb/wgsl/#function-call-expr
// If the called function does not return a value, a function call
// statement should be used instead.
auto* ident = call->Declaration()->func;
auto name = builder_->Symbols().NameFor(ident->symbol);
bool is_function = call->Target()->Is<sem::Function>();
AddError((is_function ? "function" : "intrinsic") + std::string(" '") +
name + "' does not return a value",
call->Declaration()->source);
return false;
}
}
return true;
}
bool Resolver::ValidateTextureIntrinsicFunction(const sem::Call* call) { bool Resolver::ValidateTextureIntrinsicFunction(const sem::Call* call) {
auto* intrinsic = call->Target()->As<sem::Intrinsic>(); auto* intrinsic = call->Target()->As<sem::Intrinsic>();
if (!intrinsic) { if (!intrinsic) {
@ -2623,8 +2649,7 @@ bool Resolver::ValidateTextureIntrinsicFunction(const sem::Call* call) {
bool is_const_expr = true; bool is_const_expr = true;
ast::TraverseExpressions( ast::TraverseExpressions(
arg->Declaration(), diagnostics_, [&](const ast::Expression* e) { arg->Declaration(), diagnostics_, [&](const ast::Expression* e) {
if (e->IsAnyOf<ast::LiteralExpression, if (e->IsAnyOf<ast::LiteralExpression, ast::CallExpression>()) {
ast::TypeConstructorExpression>()) {
return ast::TraverseAction::Descend; return ast::TraverseAction::Descend;
} }
is_const_expr = false; is_const_expr = false;
@ -2654,9 +2679,9 @@ bool Resolver::ValidateTextureIntrinsicFunction(const sem::Call* call) {
bool Resolver::ValidateFunctionCall(const sem::Call* call) { bool Resolver::ValidateFunctionCall(const sem::Call* call) {
auto* decl = call->Declaration(); auto* decl = call->Declaration();
auto* ident = decl->func;
auto* target = call->Target()->As<sem::Function>(); auto* target = call->Target()->As<sem::Function>();
auto name = builder_->Symbols().NameFor(ident->symbol); auto sym = decl->target.name->symbol;
auto name = builder_->Symbols().NameFor(sym);
if (target->Declaration()->IsEntryPoint()) { if (target->Declaration()->IsEntryPoint()) {
// https://www.w3.org/TR/WGSL/#function-restriction // https://www.w3.org/TR/WGSL/#function-restriction
@ -2735,40 +2760,150 @@ bool Resolver::ValidateFunctionCall(const sem::Call* call) {
} }
} }
} }
if (call->Type()->Is<sem::Void>()) {
bool is_call_statement = false;
if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
if (call_stmt->expr == call->Declaration()) {
is_call_statement = true;
}
}
if (!is_call_statement) {
// https://gpuweb.github.io/gpuweb/wgsl/#function-call-expr
// If the called function does not return a value, a function call
// statement should be used instead.
AddError("function '" + name + "' does not return a value", decl->source);
return false;
}
}
return true; return true;
} }
sem::Expression* Resolver::TypeConstructor( sem::Call* Resolver::TypeConversion(const ast::CallExpression* expr,
const ast::TypeConstructorExpression* expr) { const sem::Type* target,
auto* ty = Type(expr->type); const sem::Expression* arg,
if (!ty) { const sem::Type* source) {
// It is not valid to have a type-cast call expression inside a call
// statement.
if (current_statement_) {
if (auto* stmt =
current_statement_->Declaration()->As<ast::CallStatement>()) {
if (stmt->expr == expr) {
AddError("type cast evaluated but not used", expr->source);
return nullptr;
}
}
}
auto* call_target = utils::GetOrCreate(
type_conversions_, TypeConversionSig{target, source},
[&]() -> sem::TypeConversion* {
// Now that the argument types have been determined, make sure that they
// obey the conversion rules laid out in
// https://gpuweb.github.io/gpuweb/wgsl/#conversion-expr.
bool ok = true;
if (auto* vec_type = target->As<sem::Vector>()) {
ok = ValidateVectorConstructorOrCast(expr, vec_type);
} else if (auto* mat_type = target->As<sem::Matrix>()) {
// Note: Matrix types currently cannot be converted (the element type
// must only be f32). We implement this for the day we support other
// matrix element types.
ok = ValidateMatrixConstructorOrCast(expr, mat_type);
} else if (target->is_scalar()) {
ok = ValidateScalarConstructorOrCast(expr, target);
} else if (auto* arr_type = target->As<sem::Array>()) {
ok = ValidateArrayConstructorOrCast(expr, arr_type);
} else if (auto* struct_type = target->As<sem::Struct>()) {
ok = ValidateStructureConstructorOrCast(expr, struct_type);
} else {
AddError("type is not constructible", expr->source);
return nullptr;
}
if (!ok) {
return nullptr;
}
auto* param = builder_->create<sem::Parameter>(
nullptr, // declaration
0, // index
source->UnwrapRef(), // type
ast::StorageClass::kNone, // storage_class
ast::Access::kUndefined); // access
return builder_->create<sem::TypeConversion>(target, param);
});
if (!call_target) {
return nullptr; return nullptr;
} }
// Now that the argument types have been determined, make sure that they auto val = EvaluateConstantValue(expr, target);
// obey the constructor type rules laid out in return builder_->create<sem::Call>(expr, call_target,
// https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr. std::vector<const sem::Expression*>{arg},
bool ok = true; current_statement_, val);
if (auto* vec_type = ty->As<sem::Vector>()) { }
ok = ValidateVectorConstructor(expr, vec_type);
} else if (auto* mat_type = ty->As<sem::Matrix>()) { sem::Call* Resolver::TypeConstructor(
ok = ValidateMatrixConstructor(expr, mat_type); const ast::CallExpression* expr,
} else if (ty->is_scalar()) { const sem::Type* ty,
ok = ValidateScalarConstructor(expr, ty); const std::vector<const sem::Expression*> args,
} else if (auto* arr_type = ty->As<sem::Array>()) { const std::vector<const sem::Type*> arg_tys) {
ok = ValidateArrayConstructor(expr, arr_type); // It is not valid to have a type-constructor call expression as a call
} else if (auto* struct_type = ty->As<sem::Struct>()) { // statement.
ok = ValidateStructureConstructor(expr, struct_type); if (current_statement_) {
} else { if (auto* stmt =
AddError("type is not constructible", expr->source); current_statement_->Declaration()->As<ast::CallStatement>()) {
return nullptr; if (stmt->expr == expr) {
AddError("type constructor evaluated but not used", expr->source);
return nullptr;
}
}
} }
if (!ok) {
auto* call_target = utils::GetOrCreate(
type_ctors_, TypeConstructorSig{ty, arg_tys},
[&]() -> sem::TypeConstructor* {
// Now that the argument types have been determined, make sure that they
// obey the constructor type rules laid out in
// https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr.
bool ok = true;
if (auto* vec_type = ty->As<sem::Vector>()) {
ok = ValidateVectorConstructorOrCast(expr, vec_type);
} else if (auto* mat_type = ty->As<sem::Matrix>()) {
ok = ValidateMatrixConstructorOrCast(expr, mat_type);
} else if (ty->is_scalar()) {
ok = ValidateScalarConstructorOrCast(expr, ty);
} else if (auto* arr_type = ty->As<sem::Array>()) {
ok = ValidateArrayConstructorOrCast(expr, arr_type);
} else if (auto* struct_type = ty->As<sem::Struct>()) {
ok = ValidateStructureConstructorOrCast(expr, struct_type);
} else {
AddError("type is not constructible", expr->source);
return nullptr;
}
if (!ok) {
return nullptr;
}
return builder_->create<sem::TypeConstructor>(
ty, utils::Transform(
arg_tys,
[&](const sem::Type* t, size_t i) -> const sem::Parameter* {
return builder_->create<sem::Parameter>(
nullptr, // declaration
i, // index
t->UnwrapRef(), // type
ast::StorageClass::kNone, // storage_class
ast::Access::kUndefined); // access
}));
});
if (!call_target) {
return nullptr; return nullptr;
} }
auto val = EvaluateConstantValue(expr, ty); auto val = EvaluateConstantValue(expr, ty);
return builder_->create<sem::Expression>(expr, ty, current_statement_, val); return builder_->create<sem::Call>(expr, call_target, std::move(args),
current_statement_, val);
} }
sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) { sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
@ -2782,26 +2917,26 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
val); val);
} }
bool Resolver::ValidateStructureConstructor( bool Resolver::ValidateStructureConstructorOrCast(
const ast::TypeConstructorExpression* ctor, const ast::CallExpression* ctor,
const sem::Struct* struct_type) { const sem::Struct* struct_type) {
if (!struct_type->IsConstructible()) { if (!struct_type->IsConstructible()) {
AddError("struct constructor has non-constructible type", ctor->source); AddError("struct constructor has non-constructible type", ctor->source);
return false; return false;
} }
if (ctor->values.size() > 0) { if (ctor->args.size() > 0) {
if (ctor->values.size() != struct_type->Members().size()) { if (ctor->args.size() != struct_type->Members().size()) {
std::string fm = std::string fm =
ctor->values.size() < struct_type->Members().size() ? "few" : "many"; ctor->args.size() < struct_type->Members().size() ? "few" : "many";
AddError("struct constructor has too " + fm + " inputs: expected " + AddError("struct constructor has too " + fm + " inputs: expected " +
std::to_string(struct_type->Members().size()) + ", found " + std::to_string(struct_type->Members().size()) + ", found " +
std::to_string(ctor->values.size()), std::to_string(ctor->args.size()),
ctor->source); ctor->source);
return false; return false;
} }
for (auto* member : struct_type->Members()) { for (auto* member : struct_type->Members()) {
auto* value = ctor->values[member->Index()]; auto* value = ctor->args[member->Index()];
auto* value_ty = TypeOf(value); auto* value_ty = TypeOf(value);
if (member->Type() != value_ty->UnwrapRef()) { if (member->Type() != value_ty->UnwrapRef()) {
AddError( AddError(
@ -2817,10 +2952,9 @@ bool Resolver::ValidateStructureConstructor(
return true; return true;
} }
bool Resolver::ValidateArrayConstructor( bool Resolver::ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
const ast::TypeConstructorExpression* ctor, const sem::Array* array_type) {
const sem::Array* array_type) { auto& values = ctor->args;
auto& values = ctor->values;
auto* elem_ty = array_type->ElemType(); auto* elem_ty = array_type->ElemType();
for (auto* value : values) { for (auto* value : values) {
auto* value_ty = TypeOf(value)->UnwrapRef(); auto* value_ty = TypeOf(value)->UnwrapRef();
@ -2839,7 +2973,7 @@ bool Resolver::ValidateArrayConstructor(
return false; return false;
} else if (!elem_ty->IsConstructible()) { } else if (!elem_ty->IsConstructible()) {
AddError("array constructor has non-constructible element type", AddError("array constructor has non-constructible element type",
ctor->type->As<ast::Array>()->type->source); ctor->source);
return false; return false;
} else if (!values.empty() && (values.size() != array_type->Count())) { } else if (!values.empty() && (values.size() != array_type->Count())) {
std::string fm = values.size() < array_type->Count() ? "few" : "many"; std::string fm = values.size() < array_type->Count() ? "few" : "many";
@ -2858,10 +2992,9 @@ bool Resolver::ValidateArrayConstructor(
return true; return true;
} }
bool Resolver::ValidateVectorConstructor( bool Resolver::ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
const ast::TypeConstructorExpression* ctor, const sem::Vector* vec_type) {
const sem::Vector* vec_type) { auto& values = ctor->args;
auto& values = ctor->values;
auto* elem_ty = vec_type->type(); auto* elem_ty = vec_type->type();
size_t value_cardinality_sum = 0; size_t value_cardinality_sum = 0;
for (auto* value : values) { for (auto* value : values) {
@ -2937,10 +3070,9 @@ bool Resolver::ValidateMatrix(const sem::Matrix* ty, const Source& source) {
return true; return true;
} }
bool Resolver::ValidateMatrixConstructor( bool Resolver::ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
const ast::TypeConstructorExpression* ctor, const sem::Matrix* matrix_ty) {
const sem::Matrix* matrix_ty) { auto& values = ctor->args;
auto& values = ctor->values;
// Zero Value expression // Zero Value expression
if (values.empty()) { if (values.empty()) {
return true; return true;
@ -3000,21 +3132,20 @@ bool Resolver::ValidateMatrixConstructor(
return true; return true;
} }
bool Resolver::ValidateScalarConstructor( bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
const ast::TypeConstructorExpression* ctor, const sem::Type* ty) {
const sem::Type* ty) { if (ctor->args.size() == 0) {
if (ctor->values.size() == 0) {
return true; return true;
} }
if (ctor->values.size() > 1) { if (ctor->args.size() > 1) {
AddError("expected zero or one value in constructor, got " + AddError("expected zero or one value in constructor, got " +
std::to_string(ctor->values.size()), std::to_string(ctor->args.size()),
ctor->source); ctor->source);
return false; return false;
} }
// Validate constructor // Validate constructor
auto* value = ctor->values[0]; auto* value = ctor->args[0];
auto* value_ty = TypeOf(value)->UnwrapRef(); auto* value_ty = TypeOf(value)->UnwrapRef();
using Bool = sem::Bool; using Bool = sem::Bool;
@ -4547,5 +4678,37 @@ const sem::Info::GetResultType<SEM, AST_OR_TYPE>* Resolver::Sem(
return sem; return sem;
} }
////////////////////////////////////////////////////////////////////////////////
// Resolver::TypeConversionSig
////////////////////////////////////////////////////////////////////////////////
bool Resolver::TypeConversionSig::operator==(
const TypeConversionSig& rhs) const {
return target == rhs.target && source == rhs.source;
}
std::size_t Resolver::TypeConversionSig::Hasher::operator()(
const TypeConversionSig& sig) const {
return utils::Hash(sig.target, sig.source);
}
////////////////////////////////////////////////////////////////////////////////
// Resolver::TypeConstructorSig
////////////////////////////////////////////////////////////////////////////////
Resolver::TypeConstructorSig::TypeConstructorSig(
const sem::Type* ty,
const std::vector<const sem::Type*> params)
: type(ty), parameters(params) {}
Resolver::TypeConstructorSig::TypeConstructorSig(const TypeConstructorSig&) =
default;
Resolver::TypeConstructorSig::~TypeConstructorSig() = default;
bool Resolver::TypeConstructorSig::operator==(
const TypeConstructorSig& rhs) const {
return type == rhs.type && parameters == rhs.parameters;
}
std::size_t Resolver::TypeConstructorSig::Hasher::operator()(
const TypeConstructorSig& sig) const {
return utils::Hash(sig.type, sig.parameters);
}
} // namespace resolver } // namespace resolver
} // namespace tint } // namespace tint

View File

@ -59,6 +59,7 @@ class Array;
class Atomic; class Atomic;
class Intrinsic; class Intrinsic;
class Statement; class Statement;
class TypeConstructor;
} // namespace sem } // namespace sem
namespace resolver { namespace resolver {
@ -170,15 +171,26 @@ class Resolver {
sem::Expression* IndexAccessor(const ast::IndexAccessorExpression*); sem::Expression* IndexAccessor(const ast::IndexAccessorExpression*);
sem::Expression* Binary(const ast::BinaryExpression*); sem::Expression* Binary(const ast::BinaryExpression*);
sem::Expression* Bitcast(const ast::BitcastExpression*); sem::Expression* Bitcast(const ast::BitcastExpression*);
sem::Expression* Call(const ast::CallExpression*); sem::Call* Call(const ast::CallExpression*);
sem::Expression* Expression(const ast::Expression*); sem::Expression* Expression(const ast::Expression*);
sem::Function* Function(const ast::Function*); sem::Function* Function(const ast::Function*);
sem::Call* FunctionCall(const ast::CallExpression*); sem::Call* FunctionCall(const ast::CallExpression*,
const std::vector<const sem::Expression*> args);
sem::Expression* Identifier(const ast::IdentifierExpression*); sem::Expression* Identifier(const ast::IdentifierExpression*);
sem::Call* IntrinsicCall(const ast::CallExpression*, sem::IntrinsicType); sem::Call* IntrinsicCall(const ast::CallExpression*,
sem::IntrinsicType,
const std::vector<const sem::Expression*> args,
const std::vector<const sem::Type*> arg_tys);
sem::Expression* Literal(const ast::LiteralExpression*); sem::Expression* Literal(const ast::LiteralExpression*);
sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*); sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*);
sem::Expression* TypeConstructor(const ast::TypeConstructorExpression*); sem::Call* TypeConversion(const ast::CallExpression* expr,
const sem::Type* ty,
const sem::Expression* arg,
const sem::Type* arg_ty);
sem::Call* TypeConstructor(const ast::CallExpression* expr,
const sem::Type* ty,
const std::vector<const sem::Expression*> args,
const std::vector<const sem::Type*> arg_tys);
sem::Expression* UnaryOp(const ast::UnaryOpExpression*); sem::Expression* UnaryOp(const ast::UnaryOpExpression*);
// Statement resolving methods // Statement resolving methods
@ -211,13 +223,13 @@ class Resolver {
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type, const sem::Type* storage_type,
const bool is_input); const bool is_input);
bool ValidateCall(const sem::Call* call);
bool ValidateEntryPoint(const sem::Function* func); bool ValidateEntryPoint(const sem::Function* func);
bool ValidateFunction(const sem::Function* func); bool ValidateFunction(const sem::Function* func);
bool ValidateFunctionCall(const sem::Call* call); bool ValidateFunctionCall(const sem::Call* call);
bool ValidateGlobalVariable(const sem::Variable* var); bool ValidateGlobalVariable(const sem::Variable* var);
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco, bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
const sem::Type* storage_type); const sem::Type* storage_type);
bool ValidateIntrinsicCall(const sem::Call* call);
bool ValidateLocationDecoration(const ast::LocationDecoration* location, bool ValidateLocationDecoration(const ast::LocationDecoration* location,
const sem::Type* type, const sem::Type* type,
std::unordered_set<uint32_t>& locations, std::unordered_set<uint32_t>& locations,
@ -234,23 +246,23 @@ class Resolver {
bool ValidateStatements(const ast::StatementList& stmts); bool ValidateStatements(const ast::StatementList& stmts);
bool ValidateStorageTexture(const ast::StorageTexture* t); bool ValidateStorageTexture(const ast::StorageTexture* t);
bool ValidateStructure(const sem::Struct* str); bool ValidateStructure(const sem::Struct* str);
bool ValidateStructureConstructor(const ast::TypeConstructorExpression* ctor, bool ValidateStructureConstructorOrCast(const ast::CallExpression* ctor,
const sem::Struct* struct_type); const sem::Struct* struct_type);
bool ValidateSwitch(const ast::SwitchStatement* s); bool ValidateSwitch(const ast::SwitchStatement* s);
bool ValidateVariable(const sem::Variable* var); bool ValidateVariable(const sem::Variable* var);
bool ValidateVariableConstructor(const ast::Variable* var, bool ValidateVariableConstructorOrCast(const ast::Variable* var,
ast::StorageClass storage_class, ast::StorageClass storage_class,
const sem::Type* storage_type, const sem::Type* storage_type,
const sem::Type* rhs_type); const sem::Type* rhs_type);
bool ValidateVector(const sem::Vector* ty, const Source& source); bool ValidateVector(const sem::Vector* ty, const Source& source);
bool ValidateVectorConstructor(const ast::TypeConstructorExpression* ctor, bool ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
const sem::Vector* vec_type); const sem::Vector* vec_type);
bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor, bool ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
const sem::Matrix* matrix_type); const sem::Matrix* matrix_type);
bool ValidateScalarConstructor(const ast::TypeConstructorExpression* ctor, bool ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
const sem::Type* type); const sem::Type* type);
bool ValidateArrayConstructor(const ast::TypeConstructorExpression* ctor, bool ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
const sem::Array* arr_type); const sem::Array* arr_type);
bool ValidateTypeDecl(const ast::TypeDecl* named_type) const; bool ValidateTypeDecl(const ast::TypeDecl* named_type) const;
bool ValidateTextureIntrinsicFunction(const sem::Call* call); bool ValidateTextureIntrinsicFunction(const sem::Call* call);
bool ValidateNoDuplicateDecorations(const ast::DecorationList& decorations); bool ValidateNoDuplicateDecorations(const ast::DecorationList& decorations);
@ -378,15 +390,46 @@ class Resolver {
const sem::Type* type); const sem::Type* type);
sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal, sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal,
const sem::Type* type); const sem::Type* type);
sem::Constant EvaluateConstantValue( sem::Constant EvaluateConstantValue(const ast::CallExpression* call,
const ast::TypeConstructorExpression* type_ctor, const sem::Type* type);
const sem::Type* type);
/// Sem is a helper for obtaining the semantic node for the given AST node. /// Sem is a helper for obtaining the semantic node for the given AST node.
template <typename SEM = sem::Info::InferFromAST, template <typename SEM = sem::Info::InferFromAST,
typename AST_OR_TYPE = CastableBase> typename AST_OR_TYPE = CastableBase>
const sem::Info::GetResultType<SEM, AST_OR_TYPE>* Sem(const AST_OR_TYPE* ast); const sem::Info::GetResultType<SEM, AST_OR_TYPE>* Sem(const AST_OR_TYPE* ast);
struct TypeConversionSig {
const sem::Type* target;
const sem::Type* source;
bool operator==(const TypeConversionSig&) const;
/// Hasher provides a hash function for the TypeConversionSig
struct Hasher {
/// @param sig the TypeConversionSig to create a hash for
/// @return the hash value
std::size_t operator()(const TypeConversionSig& sig) const;
};
};
struct TypeConstructorSig {
const sem::Type* type;
const std::vector<const sem::Type*> parameters;
TypeConstructorSig(const sem::Type* ty,
const std::vector<const sem::Type*> params);
TypeConstructorSig(const TypeConstructorSig&);
~TypeConstructorSig();
bool operator==(const TypeConstructorSig&) const;
/// Hasher provides a hash function for the TypeConstructorSig
struct Hasher {
/// @param sig the TypeConstructorSig to create a hash for
/// @return the hash value
std::size_t operator()(const TypeConstructorSig& sig) const;
};
};
ProgramBuilder* const builder_; ProgramBuilder* const builder_;
diag::List& diagnostics_; diag::List& diagnostics_;
std::unique_ptr<IntrinsicTable> const intrinsic_table_; std::unique_ptr<IntrinsicTable> const intrinsic_table_;
@ -398,6 +441,14 @@ class Resolver {
std::unordered_set<const ast::Node*> marked_; std::unordered_set<const ast::Node*> marked_;
std::unordered_map<uint32_t, const sem::Variable*> constant_ids_; std::unordered_map<uint32_t, const sem::Variable*> constant_ids_;
std::unordered_map<TypeConversionSig,
sem::CallTarget*,
TypeConversionSig::Hasher>
type_conversions_;
std::unordered_map<TypeConstructorSig,
sem::CallTarget*,
TypeConstructorSig::Hasher>
type_ctors_;
sem::Function* current_function_ = nullptr; sem::Function* current_function_ = nullptr;
sem::Statement* current_statement_ = nullptr; sem::Statement* current_statement_ = nullptr;

View File

@ -15,6 +15,7 @@
#include "src/resolver/resolver.h" #include "src/resolver/resolver.h"
#include "src/sem/constant.h" #include "src/sem/constant.h"
#include "src/sem/type_constructor.h"
#include "src/utils/get_or_create.h" #include "src/utils/get_or_create.h"
namespace tint { namespace tint {
@ -32,7 +33,7 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr,
if (auto* e = expr->As<ast::LiteralExpression>()) { if (auto* e = expr->As<ast::LiteralExpression>()) {
return EvaluateConstantValue(e, type); return EvaluateConstantValue(e, type);
} }
if (auto* e = expr->As<ast::TypeConstructorExpression>()) { if (auto* e = expr->As<ast::CallExpression>()) {
return EvaluateConstantValue(e, type); return EvaluateConstantValue(e, type);
} }
return {}; return {};
@ -57,10 +58,8 @@ sem::Constant Resolver::EvaluateConstantValue(
return {}; return {};
} }
sem::Constant Resolver::EvaluateConstantValue( sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
const ast::TypeConstructorExpression* type_ctor, const sem::Type* type) {
const sem::Type* type) {
auto& ctor_values = type_ctor->values;
auto* vec = type->As<sem::Vector>(); auto* vec = type->As<sem::Vector>();
// For now, only fold scalars and vectors // For now, only fold scalars and vectors
@ -72,7 +71,7 @@ sem::Constant Resolver::EvaluateConstantValue(
int result_size = vec ? static_cast<int>(vec->Width()) : 1; int result_size = vec ? static_cast<int>(vec->Width()) : 1;
// For zero value init, return 0s // For zero value init, return 0s
if (ctor_values.empty()) { if (call->args.empty()) {
if (elem_type->Is<sem::I32>()) { if (elem_type->Is<sem::I32>()) {
return sem::Constant(type, sem::Constant::Scalars(result_size, 0)); return sem::Constant(type, sem::Constant::Scalars(result_size, 0));
} }
@ -90,12 +89,12 @@ sem::Constant Resolver::EvaluateConstantValue(
// Build value for type_ctor from each child value by casting to // Build value for type_ctor from each child value by casting to
// type_ctor's type. // type_ctor's type.
sem::Constant::Scalars elems; sem::Constant::Scalars elems;
for (auto* cv : ctor_values) { for (auto* expr : call->args) {
auto* expr = builder_->Sem().Get(cv); auto* arg = builder_->Sem().Get(expr);
if (!expr || !expr->ConstantValue()) { if (!arg || !arg->ConstantValue()) {
return {}; return {};
} }
auto cast = ConstantCast(expr->ConstantValue(), elem_type); auto cast = ConstantCast(arg->ConstantValue(), elem_type);
elems.insert(elems.end(), cast.Elements().begin(), cast.Elements().end()); elems.insert(elems.end(), cast.Elements().begin(), cast.Elements().end());
} }

View File

@ -15,6 +15,8 @@
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "src/resolver/resolver_test_helper.h" #include "src/resolver/resolver_test_helper.h"
#include "src/sem/reference_type.h" #include "src/sem/reference_type.h"
#include "src/sem/type_constructor.h"
#include "src/sem/type_conversion.h"
namespace tint { namespace tint {
namespace resolver { namespace resolver {
@ -223,68 +225,74 @@ INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
} // namespace InferTypeTest } // namespace InferTypeTest
namespace ConversionConstructorTest { namespace ConversionConstructTest {
enum class Kind {
Construct,
Conversion,
};
struct Params { struct Params {
Kind kind;
builder::ast_type_func_ptr lhs_type; builder::ast_type_func_ptr lhs_type;
builder::ast_type_func_ptr rhs_type; builder::ast_type_func_ptr rhs_type;
builder::ast_expr_func_ptr rhs_value_expr; builder::ast_expr_func_ptr rhs_value_expr;
}; };
template <typename LhsType, typename RhsType> template <typename LhsType, typename RhsType>
constexpr Params ParamsFor() { constexpr Params ParamsFor(Kind kind) {
return Params{DataType<LhsType>::AST, DataType<RhsType>::AST, return Params{kind, DataType<LhsType>::AST, DataType<RhsType>::AST,
DataType<RhsType>::Expr}; DataType<RhsType>::Expr};
} }
static constexpr Params valid_cases[] = { static constexpr Params valid_cases[] = {
// Direct init (non-conversions) // Direct init (non-conversions)
ParamsFor<bool, bool>(), // ParamsFor<bool, bool>(Kind::Construct), //
ParamsFor<i32, i32>(), // ParamsFor<i32, i32>(Kind::Construct), //
ParamsFor<u32, u32>(), // ParamsFor<u32, u32>(Kind::Construct), //
ParamsFor<f32, f32>(), // ParamsFor<f32, f32>(Kind::Construct), //
ParamsFor<vec3<bool>, vec3<bool>>(), // ParamsFor<vec3<bool>, vec3<bool>>(Kind::Construct), //
ParamsFor<vec3<i32>, vec3<i32>>(), // ParamsFor<vec3<i32>, vec3<i32>>(Kind::Construct), //
ParamsFor<vec3<u32>, vec3<u32>>(), // ParamsFor<vec3<u32>, vec3<u32>>(Kind::Construct), //
ParamsFor<vec3<f32>, vec3<f32>>(), // ParamsFor<vec3<f32>, vec3<f32>>(Kind::Construct), //
// Splat // Splat
ParamsFor<vec3<bool>, bool>(), // ParamsFor<vec3<bool>, bool>(Kind::Construct), //
ParamsFor<vec3<i32>, i32>(), // ParamsFor<vec3<i32>, i32>(Kind::Construct), //
ParamsFor<vec3<u32>, u32>(), // ParamsFor<vec3<u32>, u32>(Kind::Construct), //
ParamsFor<vec3<f32>, f32>(), // ParamsFor<vec3<f32>, f32>(Kind::Construct), //
// Conversion // Conversion
ParamsFor<bool, u32>(), // ParamsFor<bool, u32>(Kind::Conversion), //
ParamsFor<bool, i32>(), // ParamsFor<bool, i32>(Kind::Conversion), //
ParamsFor<bool, f32>(), // ParamsFor<bool, f32>(Kind::Conversion), //
ParamsFor<i32, bool>(), // ParamsFor<i32, bool>(Kind::Conversion), //
ParamsFor<i32, u32>(), // ParamsFor<i32, u32>(Kind::Conversion), //
ParamsFor<i32, f32>(), // ParamsFor<i32, f32>(Kind::Conversion), //
ParamsFor<u32, bool>(), // ParamsFor<u32, bool>(Kind::Conversion), //
ParamsFor<u32, i32>(), // ParamsFor<u32, i32>(Kind::Conversion), //
ParamsFor<u32, f32>(), // ParamsFor<u32, f32>(Kind::Conversion), //
ParamsFor<f32, bool>(), // ParamsFor<f32, bool>(Kind::Conversion), //
ParamsFor<f32, u32>(), // ParamsFor<f32, u32>(Kind::Conversion), //
ParamsFor<f32, i32>(), // ParamsFor<f32, i32>(Kind::Conversion), //
ParamsFor<vec3<bool>, vec3<u32>>(), // ParamsFor<vec3<bool>, vec3<u32>>(Kind::Conversion), //
ParamsFor<vec3<bool>, vec3<i32>>(), // ParamsFor<vec3<bool>, vec3<i32>>(Kind::Conversion), //
ParamsFor<vec3<bool>, vec3<f32>>(), // ParamsFor<vec3<bool>, vec3<f32>>(Kind::Conversion), //
ParamsFor<vec3<i32>, vec3<bool>>(), // ParamsFor<vec3<i32>, vec3<bool>>(Kind::Conversion), //
ParamsFor<vec3<i32>, vec3<u32>>(), // ParamsFor<vec3<i32>, vec3<u32>>(Kind::Conversion), //
ParamsFor<vec3<i32>, vec3<f32>>(), // ParamsFor<vec3<i32>, vec3<f32>>(Kind::Conversion), //
ParamsFor<vec3<u32>, vec3<bool>>(), // ParamsFor<vec3<u32>, vec3<bool>>(Kind::Conversion), //
ParamsFor<vec3<u32>, vec3<i32>>(), // ParamsFor<vec3<u32>, vec3<i32>>(Kind::Conversion), //
ParamsFor<vec3<u32>, vec3<f32>>(), // ParamsFor<vec3<u32>, vec3<f32>>(Kind::Conversion), //
ParamsFor<vec3<f32>, vec3<bool>>(), // ParamsFor<vec3<f32>, vec3<bool>>(Kind::Conversion), //
ParamsFor<vec3<f32>, vec3<u32>>(), // ParamsFor<vec3<f32>, vec3<u32>>(Kind::Conversion), //
ParamsFor<vec3<f32>, vec3<i32>>(), // ParamsFor<vec3<f32>, vec3<i32>>(Kind::Conversion), //
}; };
using ConversionConstructorValidTest = ResolverTestWithParam<Params>; using ConversionConstructorValidTest = ResolverTestWithParam<Params>;
@ -302,8 +310,9 @@ TEST_P(ConversionConstructorValidTest, All) {
<< FriendlyName(rhs_type) << "(<rhs value expr>))"; << FriendlyName(rhs_type) << "(<rhs value expr>))";
SCOPED_TRACE(ss.str()); SCOPED_TRACE(ss.str());
auto* a = Var("a", lhs_type1, ast::StorageClass::kNone, auto* arg = Construct(rhs_type, rhs_value_expr);
Construct(lhs_type2, Construct(rhs_type, rhs_value_expr))); auto* tc = Construct(lhs_type2, arg);
auto* a = Var("a", lhs_type1, ast::StorageClass::kNone, tc);
// Self-assign 'a' to force the expression to be resolved so we can test its // Self-assign 'a' to force the expression to be resolved so we can test its
// type below // type below
@ -311,6 +320,27 @@ TEST_P(ConversionConstructorValidTest, All) {
WrapInFunction(Decl(a), Assign(a_ident, "a")); WrapInFunction(Decl(a), Assign(a_ident, "a"));
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
switch (params.kind) {
case Kind::Construct: {
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 1u);
EXPECT_EQ(ctor->Parameters()[0]->Type(), TypeOf(arg));
break;
}
case Kind::Conversion: {
auto* conv = call->Target()->As<sem::TypeConversion>();
ASSERT_NE(conv, nullptr);
EXPECT_EQ(call->Type(), conv->ReturnType());
ASSERT_EQ(conv->Parameters().size(), 1u);
EXPECT_EQ(conv->Parameters()[0]->Type(), TypeOf(arg));
break;
}
}
} }
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest, INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
ConversionConstructorValidTest, ConversionConstructorValidTest,
@ -408,7 +438,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
"'array<f32, 4>'"); "'array<f32, 4>'");
} }
} // namespace ConversionConstructorTest } // namespace ConversionConstructTest
namespace ArrayConstructor { namespace ArrayConstructor {
@ -418,7 +448,15 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = array<u32, 10>(); auto* tc = array<u32, 10>();
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()); ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
EXPECT_TRUE(call->Type()->Is<sem::Array>());
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 0u);
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -427,7 +465,18 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = array<u32, 3>(Expr(0u), Expr(10u), Expr(20u)); auto* tc = array<u32, 3>(Expr(0u), Expr(10u), Expr(20u));
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()); ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
EXPECT_TRUE(call->Type()->Is<sem::Array>());
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 3u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::U32>());
EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::U32>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -587,6 +636,118 @@ TEST_F(ResolverTypeConstructorValidationTest,
} // namespace ArrayConstructor } // namespace ArrayConstructor
namespace ScalarConstructor {
TEST_F(ResolverTypeConstructorValidationTest, Expr_Construct_i32_Success) {
auto* expr = Construct<i32>(Expr(123));
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<sem::I32>());
auto* call = Sem().Get(expr);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 1u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
}
TEST_F(ResolverTypeConstructorValidationTest, Expr_Construct_u32_Success) {
auto* expr = Construct<u32>(Expr(123u));
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<sem::U32>());
auto* call = Sem().Get(expr);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 1u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
}
TEST_F(ResolverTypeConstructorValidationTest, Expr_Construct_f32_Success) {
auto* expr = Construct<f32>(Expr(1.23f));
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<sem::F32>());
auto* call = Sem().Get(expr);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 1u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
}
TEST_F(ResolverTypeConstructorValidationTest, Expr_Convert_f32_to_i32_Success) {
auto* expr = Construct<i32>(1.23f);
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<sem::I32>());
auto* call = Sem().Get(expr);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConversion>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 1u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
}
TEST_F(ResolverTypeConstructorValidationTest, Expr_Convert_i32_to_u32_Success) {
auto* expr = Construct<u32>(123);
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<sem::U32>());
auto* call = Sem().Get(expr);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConversion>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 1u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
}
TEST_F(ResolverTypeConstructorValidationTest, Expr_Convert_u32_to_f32_Success) {
auto* expr = Construct<f32>(123u);
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<sem::F32>());
auto* call = Sem().Get(expr);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConversion>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 1u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
}
} // namespace ScalarConstructor
namespace VectorConstructor { namespace VectorConstructor {
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -708,12 +869,19 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec2<f32>(); auto* tc = vec2<f32>();
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 0u);
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -721,12 +889,21 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec2<f32>(1.0f, 1.0f); auto* tc = vec2<f32>(1.0f, 1.0f);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 2u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::F32>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -734,12 +911,21 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec2<u32>(1u, 1u); auto* tc = vec2<u32>(1u, 1u);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::U32>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::U32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 2u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::U32>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -747,12 +933,21 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec2<i32>(1, 1); auto* tc = vec2<i32>(1, 1);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::I32>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 2u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -760,12 +955,21 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec2<bool>(true, false); auto* tc = vec2<bool>(true, false);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::Bool>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::Bool>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 2u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Bool>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::Bool>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -773,12 +977,20 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec2<f32>(vec2<f32>()); auto* tc = vec2<f32>(vec2<f32>());
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 1u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -786,12 +998,20 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec2<f32>(vec2<i32>()); auto* tc = vec2<f32>(vec2<i32>());
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConversion>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 1u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -938,12 +1158,19 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec3<f32>(); auto* tc = vec3<f32>();
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 0u);
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -951,12 +1178,22 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec3<f32>(1.0f, 1.0f, 1.0f); auto* tc = vec3<f32>(1.0f, 1.0f, 1.0f);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 3u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::F32>());
EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::F32>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -964,12 +1201,22 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec3<u32>(1u, 1u, 1u); auto* tc = vec3<u32>(1u, 1u, 1u);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::U32>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::U32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 3u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::U32>());
EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::U32>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -977,12 +1224,22 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec3<i32>(1, 1, 1); auto* tc = vec3<i32>(1, 1, 1);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::I32>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 3u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -990,12 +1247,22 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec3<bool>(true, false, true); auto* tc = vec3<bool>(true, false, true);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::Bool>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::Bool>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 3u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Bool>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::Bool>());
EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::Bool>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -1003,12 +1270,21 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec3<f32>(vec2<f32>(), 1.0f); auto* tc = vec3<f32>(vec2<f32>(), 1.0f);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 2u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::F32>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -1016,12 +1292,21 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec3<f32>(1.0f, vec2<f32>()); auto* tc = vec3<f32>(1.0f, vec2<f32>());
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 2u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::Vector>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -1029,12 +1314,20 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec3<f32>(vec3<f32>()); auto* tc = vec3<f32>(vec3<f32>());
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 1u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -1042,12 +1335,20 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec3<f32>(vec3<i32>()); auto* tc = vec3<f32>(vec3<i32>());
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>()); EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u); EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
auto* call = Sem().Get(tc);
ASSERT_NE(call, nullptr);
auto* ctor = call->Target()->As<sem::TypeConversion>();
ASSERT_NE(ctor, nullptr);
EXPECT_EQ(call->Type(), ctor->ReturnType());
ASSERT_EQ(ctor->Parameters().size(), 1u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -1248,7 +1549,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec4<f32>(); auto* tc = vec4<f32>();
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@ -1261,7 +1562,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f); auto* tc = vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@ -1274,7 +1575,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec4<u32>(1u, 1u, 1u, 1u); auto* tc = vec4<u32>(1u, 1u, 1u, 1u);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@ -1287,7 +1588,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec4<i32>(1, 1, 1, 1); auto* tc = vec4<i32>(1, 1, 1, 1);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@ -1300,7 +1601,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec4<bool>(true, false, true, false); auto* tc = vec4<bool>(true, false, true, false);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@ -1313,7 +1614,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec4<f32>(vec2<f32>(), 1.0f, 1.0f); auto* tc = vec4<f32>(vec2<f32>(), 1.0f, 1.0f);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@ -1326,7 +1627,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec4<f32>(1.0f, vec2<f32>(), 1.0f); auto* tc = vec4<f32>(1.0f, vec2<f32>(), 1.0f);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@ -1339,7 +1640,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec4<f32>(1.0f, 1.0f, vec2<f32>()); auto* tc = vec4<f32>(1.0f, 1.0f, vec2<f32>());
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@ -1352,7 +1653,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec4<f32>(vec2<f32>(), vec2<f32>()); auto* tc = vec4<f32>(vec2<f32>(), vec2<f32>());
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@ -1365,7 +1666,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec4<f32>(vec3<f32>(), 1.0f); auto* tc = vec4<f32>(vec3<f32>(), 1.0f);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@ -1378,7 +1679,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec4<f32>(1.0f, vec3<f32>()); auto* tc = vec4<f32>(1.0f, vec3<f32>());
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@ -1391,7 +1692,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec4<f32>(vec4<f32>()); auto* tc = vec4<f32>(vec4<f32>());
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@ -1404,7 +1705,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec4<f32>(vec4<i32>()); auto* tc = vec4<f32>(vec4<i32>());
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@ -1431,7 +1732,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec4<f32>(vec3<f32>(vec2<f32>(1.0f, 1.0f), 1.0f), 1.0f); auto* tc = vec4<f32>(vec3<f32>(vec2<f32>(1.0f, 1.0f), 1.0f), 1.0f);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr); ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>()); ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@ -1462,7 +1763,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec3<f32>("my_vec2", "my_f32"); auto* tc = vec3<f32>("my_vec2", "my_f32");
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -1490,7 +1791,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = Construct(Source{{12, 34}}, vec_type, 1.0f, 1.0f); auto* tc = Construct(Source{{12, 34}}, vec_type, 1.0f, 1.0f);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -1517,7 +1818,7 @@ TEST_F(ResolverTypeConstructorValidationTest,
auto* tc = vec3<f32>(Construct(Source{{12, 34}}, vec_type), 1.0f); auto* tc = vec3<f32>(Construct(Source{{12, 34}}, vec_type), 1.0f);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
} }
} // namespace VectorConstructor } // namespace VectorConstructor
@ -1728,7 +2029,7 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ZeroValue_Success) {
auto* tc = Construct(Source{{12, 40}}, matrix_type); auto* tc = Construct(Source{{12, 40}}, matrix_type);
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_P(MatrixConstructorTest, Expr_Constructor_WithColumns_Success) { TEST_P(MatrixConstructorTest, Expr_Constructor_WithColumns_Success) {
@ -1746,7 +2047,7 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_WithColumns_Success) {
auto* tc = Construct(Source{}, matrix_type, std::move(args)); auto* tc = Construct(Source{}, matrix_type, std::move(args));
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_P(MatrixConstructorTest, Expr_Constructor_WithElements_Success) { TEST_P(MatrixConstructorTest, Expr_Constructor_WithElements_Success) {
@ -1763,7 +2064,7 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_WithElements_Success) {
auto* tc = Construct(Source{}, matrix_type, std::move(args)); auto* tc = Construct(Source{}, matrix_type, std::move(args));
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Error) { TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Error) {
@ -1804,7 +2105,7 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Success) {
auto* tc = Construct(Source{}, matrix_type, std::move(args)); auto* tc = Construct(Source{}, matrix_type, std::move(args));
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
@ -1839,7 +2140,7 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentTypeAlias_Success) {
auto* tc = Construct(Source{}, matrix_type, std::move(args)); auto* tc = Construct(Source{}, matrix_type, std::move(args));
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentElementTypeAlias_Error) { TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentElementTypeAlias_Error) {
@ -1877,7 +2178,7 @@ TEST_P(MatrixConstructorTest,
auto* tc = Construct(Source{}, matrix_type, std::move(args)); auto* tc = Construct(Source{}, matrix_type, std::move(args));
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
} }
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest, INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
@ -2044,7 +2345,7 @@ TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Struct) {
auto* s = Structure("MyInputs", {m}); auto* s = Structure("MyInputs", {m});
auto* tc = Construct(Source{{12, 34}}, ty.Of(s)); auto* tc = Construct(Source{{12, 34}}, ty.Of(s));
WrapInFunction(tc); WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Struct_Empty) { TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Struct_Empty) {
@ -2055,7 +2356,7 @@ TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Struct_Empty) {
}); });
WrapInFunction(Construct(ty.Of(str))); WrapInFunction(Construct(ty.Of(str)));
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
} }
} // namespace StructConstructor } // namespace StructConstructor
@ -2070,7 +2371,7 @@ TEST_F(ResolverTypeConstructorValidationTest, NonConstructibleType_Atomic) {
TEST_F(ResolverTypeConstructorValidationTest, TEST_F(ResolverTypeConstructorValidationTest,
NonConstructibleType_AtomicArray) { NonConstructibleType_AtomicArray) {
WrapInFunction(Call( WrapInFunction(Call(
"ignore", Construct(ty.array(ty.atomic(Source{{12, 34}}, ty.i32()), 4)))); "ignore", Construct(Source{{12, 34}}, ty.array(ty.atomic(ty.i32()), 4))));
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ( EXPECT_EQ(
@ -2097,6 +2398,22 @@ TEST_F(ResolverTypeConstructorValidationTest, NonConstructibleType_Sampler) {
EXPECT_EQ(r()->error(), "12:34 error: type is not constructible"); EXPECT_EQ(r()->error(), "12:34 error: type is not constructible");
} }
TEST_F(ResolverTypeConstructorValidationTest, TypeConstructorAsStatement) {
WrapInFunction(
CallStmt(Construct(Source{{12, 34}}, ty.vec2<f32>(), 1.f, 2.f)));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: type constructor evaluated but not used");
}
TEST_F(ResolverTypeConstructorValidationTest, TypeConversionAsStatement) {
WrapInFunction(CallStmt(Construct(Source{{12, 34}}, ty.f32(), 1)));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: type cast evaluated but not used");
}
} // namespace } // namespace
} // namespace resolver } // namespace resolver
} // namespace tint } // namespace tint

View File

@ -12,17 +12,18 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "src/sem/type_cast.h" #include "src/sem/type_conversion.h"
TINT_INSTANTIATE_TYPEINFO(tint::sem::TypeCast); TINT_INSTANTIATE_TYPEINFO(tint::sem::TypeConversion);
namespace tint { namespace tint {
namespace sem { namespace sem {
TypeCast::TypeCast(const sem::Type* type, const sem::Parameter* parameter) TypeConversion::TypeConversion(const sem::Type* type,
const sem::Parameter* parameter)
: Base(type, ParameterList{parameter}) {} : Base(type, ParameterList{parameter}) {}
TypeCast::~TypeCast() = default; TypeConversion::~TypeConversion() = default;
} // namespace sem } // namespace sem
} // namespace tint } // namespace tint

View File

@ -12,24 +12,24 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef SRC_SEM_TYPE_CAST_H_ #ifndef SRC_SEM_TYPE_CONVERSION_H_
#define SRC_SEM_TYPE_CAST_H_ #define SRC_SEM_TYPE_CONVERSION_H_
#include "src/sem/call_target.h" #include "src/sem/call_target.h"
namespace tint { namespace tint {
namespace sem { namespace sem {
/// TypeCast is the CallTarget for a type cast. /// TypeConversion is the CallTarget for a type conversion (cast).
class TypeCast : public Castable<TypeCast, CallTarget> { class TypeConversion : public Castable<TypeConversion, CallTarget> {
public: public:
/// Constructor /// Constructor
/// @param type the target type of the cast /// @param type the target type of the cast
/// @param parameter the type cast parameter /// @param parameter the type cast parameter
TypeCast(const sem::Type* type, const sem::Parameter* parameter); TypeConversion(const sem::Type* type, const sem::Parameter* parameter);
/// Destructor /// Destructor
~TypeCast() override; ~TypeConversion() override;
/// @returns the cast source type /// @returns the cast source type
const sem::Type* Source() const { return Parameters()[0]->Type(); } const sem::Type* Source() const { return Parameters()[0]->Type(); }
@ -41,4 +41,4 @@ class TypeCast : public Castable<TypeCast, CallTarget> {
} // namespace sem } // namespace sem
} // namespace tint } // namespace tint
#endif // SRC_SEM_TYPE_CAST_H_ #endif // SRC_SEM_TYPE_CONVERSION_H_

View File

@ -78,7 +78,7 @@ void ExternalTextureTransform::Run(CloneContext& ctx,
// Replace the call with another that has the same parameters in // Replace the call with another that has the same parameters in
// addition to a level parameter (always zero for external // addition to a level parameter (always zero for external
// textures). // textures).
auto* exp = ctx.Clone(call_expr->func); auto* exp = ctx.Clone(call_expr->target.name);
auto* externalTextureParam = ctx.Clone(call_expr->args[0]); auto* externalTextureParam = ctx.Clone(call_expr->args[0]);
ast::ExpressionList params; ast::ExpressionList params;

View File

@ -19,7 +19,10 @@
#include <vector> #include <vector>
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/sem/call.h"
#include "src/sem/expression.h" #include "src/sem/expression.h"
#include "src/sem/type_constructor.h"
#include "src/sem/type_conversion.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::FoldConstants); TINT_INSTANTIATE_TYPEINFO(tint::transform::FoldConstants);
@ -32,26 +35,25 @@ FoldConstants::~FoldConstants() = default;
void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) { void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) {
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* { ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
auto* sem = ctx.src->Sem().Get(expr); auto* call = ctx.src->Sem().Get<sem::Call>(expr);
if (!sem) { if (!call) {
return nullptr; return nullptr;
} }
auto value = sem->ConstantValue(); auto value = call->ConstantValue();
if (!value.IsValid()) { if (!value.IsValid()) {
return nullptr; return nullptr;
} }
auto* ty = sem->Type(); auto* ty = call->Type();
auto* ctor = expr->As<ast::TypeConstructorExpression>(); if (!call->Target()->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) {
if (!ctor) {
return nullptr; return nullptr;
} }
// If original ctor expression had no init values, don't replace the // If original ctor expression had no init values, don't replace the
// expression // expression
if (ctor->values.size() == 0) { if (call->Arguments().empty()) {
return nullptr; return nullptr;
} }
@ -68,7 +70,7 @@ void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) {
// create it with 3. So what we do is construct with vec_size args, // create it with 3. So what we do is construct with vec_size args,
// except if the original vector was single-value initialized, in // except if the original vector was single-value initialized, in
// which case, we only construct with one arg again. // which case, we only construct with one arg again.
uint32_t ctor_size = (ctor->values.size() == 1) ? 1 : vec_size; uint32_t ctor_size = (call->Arguments().size() == 1) ? 1 : vec_size;
ast::ExpressionList ctors; ast::ExpressionList ctors;
for (uint32_t i = 0; i < ctor_size; ++i) { for (uint32_t i = 0; i < ctor_size; ++i) {

View File

@ -307,7 +307,8 @@ struct ModuleScopeVarToEntryPointParam::State {
// Pass the variables as pointers to any functions that need them. // Pass the variables as pointers to any functions that need them.
for (auto* call : calls_to_replace[func_ast]) { for (auto* call : calls_to_replace[func_ast]) {
auto* target = ctx.src->AST().Functions().Find(call->func->symbol); auto* target =
ctx.src->AST().Functions().Find(call->target.name->symbol);
auto* target_sem = ctx.src->Sem().Get(target); auto* target_sem = ctx.src->Sem().Get(target);
// Add new arguments for any variables that are needed by the callee. // Add new arguments for any variables that are needed by the callee.

View File

@ -19,7 +19,9 @@
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/sem/array.h" #include "src/sem/array.h"
#include "src/sem/call.h"
#include "src/sem/expression.h" #include "src/sem/expression.h"
#include "src/sem/type_constructor.h"
#include "src/utils/get_or_create.h" #include "src/utils/get_or_create.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::PadArrayElements); TINT_INSTANTIATE_TYPEINFO(tint::transform::PadArrayElements);
@ -131,26 +133,29 @@ void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) {
// Fix up array constructors so `A(1,2)` becomes // Fix up array constructors so `A(1,2)` becomes
// `A(padded(1), padded(2))` // `A(padded(1), padded(2))`
ctx.ReplaceAll([&](const ast::TypeConstructorExpression* ctor) ctx.ReplaceAll(
-> const ast::Expression* { [&](const ast::CallExpression* expr) -> const ast::Expression* {
if (auto* array = auto* call = sem.Get(expr);
tint::As<sem::Array>(sem.Get(ctor)->Type()->UnwrapRef())) { if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
if (auto p = pad(array)) { if (auto* array = ctor->ReturnType()->As<sem::Array>()) {
auto* arr_ty = p(); if (auto p = pad(array)) {
auto el_typename = arr_ty->type->As<ast::TypeName>()->name; auto* arr_ty = p();
auto el_typename = arr_ty->type->As<ast::TypeName>()->name;
ast::ExpressionList args; ast::ExpressionList args;
args.reserve(ctor->values.size()); args.reserve(call->Arguments().size());
for (auto* arg : ctor->values) { for (auto* arg : call->Arguments()) {
args.emplace_back(ctx.dst->Construct( auto* val = ctx.Clone(arg->Declaration());
ctx.dst->create<ast::TypeName>(el_typename), ctx.Clone(arg))); args.emplace_back(ctx.dst->Construct(
ctx.dst->create<ast::TypeName>(el_typename), val));
}
return ctx.dst->Construct(arr_ty, args);
}
}
} }
return nullptr;
return ctx.dst->Construct(arr_ty, args); });
}
}
return nullptr;
});
ctx.Clone(); ctx.Clone();
} }

View File

@ -18,8 +18,10 @@
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/sem/block_statement.h" #include "src/sem/block_statement.h"
#include "src/sem/call.h"
#include "src/sem/expression.h" #include "src/sem/expression.h"
#include "src/sem/statement.h" #include "src/sem/statement.h"
#include "src/sem/type_constructor.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::PromoteInitializersToConstVar); TINT_INSTANTIATE_TYPEINFO(tint::transform::PromoteInitializersToConstVar);
@ -50,14 +52,12 @@ void PromoteInitializersToConstVar::Run(CloneContext& ctx,
// pointer can be passed to the parent's constructor. // pointer can be passed to the parent's constructor.
for (auto* src_node : ctx.src->ASTNodes().Objects()) { for (auto* src_node : ctx.src->ASTNodes().Objects()) {
if (auto* src_init = src_node->As<ast::TypeConstructorExpression>()) { if (auto* src_init = src_node->As<ast::CallExpression>()) {
auto* src_sem_expr = ctx.src->Sem().Get(src_init); auto* call = ctx.src->Sem().Get(src_init);
if (!src_sem_expr) { if (!call->Target()->Is<sem::TypeConstructor>()) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "ast::TypeConstructorExpression has no semantic expression node";
continue; continue;
} }
auto* src_sem_stmt = src_sem_expr->Stmt(); auto* src_sem_stmt = call->Stmt();
if (!src_sem_stmt) { if (!src_sem_stmt) {
// Expression is outside of a statement. This usually means the // Expression is outside of a statement. This usually means the
// expression is part of a global (module-scope) constant declaration. // expression is part of a global (module-scope) constant declaration.
@ -76,12 +76,12 @@ void PromoteInitializersToConstVar::Run(CloneContext& ctx,
} }
} }
auto* src_ty = src_sem_expr->Type(); auto* src_ty = call->Type();
if (src_ty->IsAnyOf<sem::Array, sem::Struct>()) { if (src_ty->IsAnyOf<sem::Array, sem::Struct>()) {
// Create a new symbol for the constant // Create a new symbol for the constant
auto dst_symbol = ctx.dst->Sym(); auto dst_symbol = ctx.dst->Sym();
// Clone the type // Clone the type
auto* dst_ty = ctx.Clone(src_init->type); auto* dst_ty = CreateASTTypeFor(ctx, call->Type());
// Clone the initializer // Clone the initializer
auto* dst_init = ctx.Clone(src_init); auto* dst_init = ctx.Clone(src_init);
// Construct the constant that holds the hoisted initializer // Construct the constant that holds the hoisted initializer

View File

@ -30,7 +30,7 @@ fn main() {
var f1 : f32 = 2.0; var f1 : f32 = 2.0;
var f2 : f32 = 3.0; var f2 : f32 = 3.0;
var f3 : f32 = 4.0; var f3 : f32 = 4.0;
var i : f32 = array<f32, 4>(f0, f1, f2, f3)[2]; var i : f32 = array<f32, 4u>(f0, f1, f2, f3)[2];
} }
)"; )";
@ -41,7 +41,7 @@ fn main() {
var f1 : f32 = 2.0; var f1 : f32 = 2.0;
var f2 : f32 = 3.0; var f2 : f32 = 3.0;
var f3 : f32 = 4.0; var f3 : f32 = 4.0;
let tint_symbol : array<f32, 4> = array<f32, 4>(f0, f1, f2, f3); let tint_symbol : array<f32, 4u> = array<f32, 4u>(f0, f1, f2, f3);
var i : f32 = tint_symbol[2]; var i : f32 = tint_symbol[2];
} }
)"; )";
@ -88,16 +88,16 @@ TEST_F(PromoteInitializersToConstVarTest, ArrayInArrayArray) {
auto* src = R"( auto* src = R"(
[[stage(compute), workgroup_size(1)]] [[stage(compute), workgroup_size(1)]]
fn main() { fn main() {
var i : f32 = array<array<f32, 2>, 2>(array<f32, 2>(1.0, 2.0), array<f32, 2>(3.0, 4.0))[0][1]; var i : f32 = array<array<f32, 2u>, 2u>(array<f32, 2u>(1.0, 2.0), array<f32, 2u>(3.0, 4.0))[0][1];
} }
)"; )";
auto* expect = R"( auto* expect = R"(
[[stage(compute), workgroup_size(1)]] [[stage(compute), workgroup_size(1)]]
fn main() { fn main() {
let tint_symbol : array<f32, 2> = array<f32, 2>(1.0, 2.0); let tint_symbol : array<f32, 2u> = array<f32, 2u>(1.0, 2.0);
let tint_symbol_1 : array<f32, 2> = array<f32, 2>(3.0, 4.0); let tint_symbol_1 : array<f32, 2u> = array<f32, 2u>(3.0, 4.0);
let tint_symbol_2 : array<array<f32, 2>, 2> = array<array<f32, 2>, 2>(tint_symbol, tint_symbol_1); let tint_symbol_2 : array<array<f32, 2u>, 2u> = array<array<f32, 2u>, 2u>(tint_symbol, tint_symbol_1);
var i : f32 = tint_symbol_2[0][1]; var i : f32 = tint_symbol_2[0][1];
} }
)"; )";
@ -165,12 +165,12 @@ struct S1 {
}; };
struct S2 { struct S2 {
a : array<S1, 3>; a : array<S1, 3u>;
}; };
[[stage(compute), workgroup_size(1)]] [[stage(compute), workgroup_size(1)]]
fn main() { fn main() {
var x : i32 = S2(array<S1, 3>(S1(1), S1(2), S1(3))).a[1].a; var x : i32 = S2(array<S1, 3u>(S1(1), S1(2), S1(3))).a[1].a;
} }
)"; )";
@ -180,7 +180,7 @@ struct S1 {
}; };
struct S2 { struct S2 {
a : array<S1, 3>; a : array<S1, 3u>;
}; };
[[stage(compute), workgroup_size(1)]] [[stage(compute), workgroup_size(1)]]
@ -188,7 +188,7 @@ fn main() {
let tint_symbol : S1 = S1(1); let tint_symbol : S1 = S1(1);
let tint_symbol_1 : S1 = S1(2); let tint_symbol_1 : S1 = S1(2);
let tint_symbol_2 : S1 = S1(3); let tint_symbol_2 : S1 = S1(3);
let tint_symbol_3 : array<S1, 3> = array<S1, 3>(tint_symbol, tint_symbol_1, tint_symbol_2); let tint_symbol_3 : array<S1, 3u> = array<S1, 3u>(tint_symbol, tint_symbol_1, tint_symbol_2);
let tint_symbol_4 : S2 = S2(tint_symbol_3); let tint_symbol_4 : S2 = S2(tint_symbol_3);
var x : i32 = tint_symbol_4.a[1].a; var x : i32 = tint_symbol_4.a[1].a;
} }
@ -209,11 +209,11 @@ struct S {
[[stage(compute), workgroup_size(1)]] [[stage(compute), workgroup_size(1)]]
fn main() { fn main() {
var local_arr : array<f32, 4> = array<f32, 4>(0.0, 1.0, 2.0, 3.0); var local_arr : array<f32, 4u> = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
var local_str : S = S(1, 2.0, 3); var local_str : S = S(1, 2.0, 3);
} }
let module_arr : array<f32, 4> = array<f32, 4>(0.0, 1.0, 2.0, 3.0); let module_arr : array<f32, 4u> = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
let module_str : S = S(1, 2.0, 3); let module_str : S = S(1, 2.0, 3);
)"; )";

View File

@ -1285,7 +1285,7 @@ Output Renamer::Run(const Program* in, const DataMap& inputs) {
continue; continue;
} }
if (sem->Target()->Is<sem::Intrinsic>()) { if (sem->Target()->Is<sem::Intrinsic>()) {
preserve.emplace(call->func); preserve.emplace(call->target.name);
} }
} }
} }

View File

@ -17,7 +17,9 @@
#include <utility> #include <utility>
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/sem/call.h"
#include "src/sem/expression.h" #include "src/sem/expression.h"
#include "src/sem/type_constructor.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeScalarMatrixConstructors); TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeScalarMatrixConstructors);
@ -33,38 +35,44 @@ VectorizeScalarMatrixConstructors::~VectorizeScalarMatrixConstructors() =
void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) { DataMap&) {
ctx.ReplaceAll([&](const ast::TypeConstructorExpression* constructor) ctx.ReplaceAll(
-> const ast::TypeConstructorExpression* { [&](const ast::CallExpression* expr) -> const ast::CallExpression* {
// Check if this is a matrix constructor with scalar arguments. auto* call = ctx.src->Sem().Get(expr);
auto* mat_type = ctx.src->Sem().Get(constructor->type)->As<sem::Matrix>(); auto* ty_ctor = call->Target()->As<sem::TypeConstructor>();
if (!mat_type) { if (!ty_ctor) {
return nullptr; return nullptr;
} }
if (constructor->values.size() == 0) { // Check if this is a matrix constructor with scalar arguments.
return nullptr; auto* mat_type = call->Type()->As<sem::Matrix>();
} if (!mat_type) {
if (!ctx.src->Sem().Get(constructor->values[0])->Type()->is_scalar()) { return nullptr;
return nullptr; }
}
// Build a list of vector expressions for each column. auto& args = call->Arguments();
ast::ExpressionList columns; if (args.size() == 0) {
for (uint32_t c = 0; c < mat_type->columns(); c++) { return nullptr;
// Build a list of scalar expressions for each value in the column. }
ast::ExpressionList row_values; if (!args[0]->Type()->is_scalar()) {
for (uint32_t r = 0; r < mat_type->rows(); r++) { return nullptr;
row_values.push_back( }
ctx.Clone(constructor->values[c * mat_type->rows() + r]));
}
// Construct the column vector. // Build a list of vector expressions for each column.
auto* col = ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()), ast::ExpressionList columns;
mat_type->rows(), row_values); for (uint32_t c = 0; c < mat_type->columns(); c++) {
columns.push_back(col); // Build a list of scalar expressions for each value in the column.
} ast::ExpressionList row_values;
for (uint32_t r = 0; r < mat_type->rows(); r++) {
row_values.push_back(
ctx.Clone(args[c * mat_type->rows() + r]->Declaration()));
}
return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns); // Construct the column vector.
}); auto* col = ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()),
mat_type->rows(), row_values);
columns.push_back(col);
}
return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns);
});
ctx.Clone(); ctx.Clone();
} }

View File

@ -715,8 +715,8 @@ struct State {
LoadPrimitive(array_base, primitive_offset, buffer, base_format)); LoadPrimitive(array_base, primitive_offset, buffer, base_format));
} }
return ctx.dst->Construct( return ctx.dst->Construct(ctx.dst->create<ast::Vector>(base_type, count),
ctx.dst->create<ast::Vector>(base_type, count), std::move(expr_list)); std::move(expr_list));
} }
/// Process a non-struct entry point parameter. /// Process a non-struct entry point parameter.

View File

@ -18,8 +18,11 @@
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/sem/array.h" #include "src/sem/array.h"
#include "src/sem/call.h"
#include "src/sem/expression.h" #include "src/sem/expression.h"
#include "src/sem/type_constructor.h"
#include "src/utils/get_or_create.h" #include "src/utils/get_or_create.h"
#include "src/utils/transform.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::WrapArraysInStructs); TINT_INSTANTIATE_TYPEINFO(tint::transform::WrapArraysInStructs);
@ -74,21 +77,28 @@ void WrapArraysInStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) {
}); });
// Fix up array constructors so `A(1,2)` becomes `tint_array_wrapper(A(1,2))` // Fix up array constructors so `A(1,2)` becomes `tint_array_wrapper(A(1,2))`
ctx.ReplaceAll([&](const ast::TypeConstructorExpression* ctor) ctx.ReplaceAll(
-> const ast::Expression* { [&](const ast::CallExpression* expr) -> const ast::Expression* {
if (auto* array = if (auto* call = sem.Get(expr)) {
::tint::As<sem::Array>(sem.Get(ctor)->Type()->UnwrapRef())) { if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
if (auto w = wrapper(array)) { if (auto* array = ctor->ReturnType()->As<sem::Array>()) {
// Wrap the array type constructor with another constructor for if (auto w = wrapper(array)) {
// the wrapper // Wrap the array type constructor with another constructor for
auto* wrapped_array_ty = ctx.Clone(ctor->type); // the wrapper
auto* array_ty = w.array_type(ctx); auto* wrapped_array_ty = ctx.dst->ty.type_name(w.wrapper_name);
auto* arr_ctor = ctx.dst->Construct(array_ty, ctx.Clone(ctor->values)); auto* array_ty = w.array_type(ctx);
return ctx.dst->Construct(wrapped_array_ty, arr_ctor); auto args = utils::Transform(
} call->Arguments(), [&](const tint::sem::Expression* s) {
} return ctx.Clone(s->Declaration());
return nullptr; });
}); auto* arr_ctor = ctx.dst->Construct(array_ty, args);
return ctx.dst->Construct(wrapped_array_ty, arr_ctor);
}
}
}
}
return nullptr;
});
ctx.Clone(); ctx.Clone();
} }

View File

@ -15,34 +15,66 @@
#include "src/writer/append_vector.h" #include "src/writer/append_vector.h"
#include <utility> #include <utility>
#include <vector>
#include "src/sem/call.h"
#include "src/sem/expression.h" #include "src/sem/expression.h"
#include "src/sem/type_constructor.h"
#include "src/sem/type_conversion.h"
#include "src/utils/transform.h"
namespace tint { namespace tint {
namespace writer { namespace writer {
namespace { namespace {
const ast::TypeConstructorExpression* AsVectorConstructor( struct VectorConstructorInfo {
ProgramBuilder* b, const sem::Call* call = nullptr;
const ast::Expression* expr) { const sem::TypeConstructor* ctor = nullptr;
if (auto* constructor = expr->As<ast::TypeConstructorExpression>()) { operator bool() const { return call != nullptr; }
if (b->TypeOf(constructor)->Is<sem::Vector>()) { };
return constructor; VectorConstructorInfo AsVectorConstructor(const sem::Expression* expr) {
if (auto* call = expr->As<sem::Call>()) {
if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
if (ctor->ReturnType()->Is<sem::Vector>()) {
return {call, ctor};
}
} }
} }
return nullptr; return {};
}
const sem::Expression* Zero(ProgramBuilder& b,
const sem::Type* ty,
const sem::Statement* stmt) {
const ast::Expression* expr = nullptr;
if (ty->Is<sem::I32>()) {
expr = b.Expr(0);
} else if (ty->Is<sem::U32>()) {
expr = b.Expr(0u);
} else if (ty->Is<sem::F32>()) {
expr = b.Expr(0.0f);
} else if (ty->Is<sem::Bool>()) {
expr = b.Expr(false);
} else {
TINT_UNREACHABLE(Writer, b.Diagnostics())
<< "unsupported vector element type: " << ty->TypeInfo().name;
return nullptr;
}
auto* sem = b.create<sem::Expression>(expr, ty, stmt, sem::Constant{});
b.Sem().Add(expr, sem);
return sem;
} }
} // namespace } // namespace
const ast::TypeConstructorExpression* AppendVector( const sem::Call* AppendVector(ProgramBuilder* b,
ProgramBuilder* b, const ast::Expression* vector_ast,
const ast::Expression* vector, const ast::Expression* scalar_ast) {
const ast::Expression* scalar) {
uint32_t packed_size; uint32_t packed_size;
const sem::Type* packed_el_sem_ty; const sem::Type* packed_el_sem_ty;
auto* vector_sem = b->Sem().Get(vector); auto* vector_sem = b->Sem().Get(vector_ast);
auto* scalar_sem = b->Sem().Get(scalar_ast);
auto* vector_ty = vector_sem->Type()->UnwrapRef(); auto* vector_ty = vector_sem->Type()->UnwrapRef();
if (auto* vec = vector_ty->As<sem::Vector>()) { if (auto* vec = vector_ty->As<sem::Vector>()) {
packed_size = vec->Width() + 1; packed_size = vec->Width() + 1;
@ -52,15 +84,15 @@ const ast::TypeConstructorExpression* AppendVector(
packed_el_sem_ty = vector_ty; packed_el_sem_ty = vector_ty;
} }
const ast::Type* packed_el_ty = nullptr; const ast::Type* packed_el_ast_ty = nullptr;
if (packed_el_sem_ty->Is<sem::I32>()) { if (packed_el_sem_ty->Is<sem::I32>()) {
packed_el_ty = b->create<ast::I32>(); packed_el_ast_ty = b->create<ast::I32>();
} else if (packed_el_sem_ty->Is<sem::U32>()) { } else if (packed_el_sem_ty->Is<sem::U32>()) {
packed_el_ty = b->create<ast::U32>(); packed_el_ast_ty = b->create<ast::U32>();
} else if (packed_el_sem_ty->Is<sem::F32>()) { } else if (packed_el_sem_ty->Is<sem::F32>()) {
packed_el_ty = b->create<ast::F32>(); packed_el_ast_ty = b->create<ast::F32>();
} else if (packed_el_sem_ty->Is<sem::Bool>()) { } else if (packed_el_sem_ty->Is<sem::Bool>()) {
packed_el_ty = b->create<ast::Bool>(); packed_el_ast_ty = b->create<ast::Bool>();
} else { } else {
TINT_UNREACHABLE(Writer, b->Diagnostics()) TINT_UNREACHABLE(Writer, b->Diagnostics())
<< "unsupported vector element type: " << "unsupported vector element type: "
@ -69,7 +101,7 @@ const ast::TypeConstructorExpression* AppendVector(
auto* statement = vector_sem->Stmt(); auto* statement = vector_sem->Stmt();
auto* packed_ty = b->create<ast::Vector>(packed_el_ty, packed_size); auto* packed_ast_ty = b->create<ast::Vector>(packed_el_ast_ty, packed_size);
auto* packed_sem_ty = b->create<sem::Vector>(packed_el_sem_ty, packed_size); auto* packed_sem_ty = b->create<sem::Vector>(packed_el_sem_ty, packed_size);
// If the coordinates are already passed in a vector constructor, with only // If the coordinates are already passed in a vector constructor, with only
@ -80,61 +112,61 @@ const ast::TypeConstructorExpression* AppendVector(
// The other cases for a nested vector constructor are when it is used // The other cases for a nested vector constructor are when it is used
// to convert a vector of a different type, e.g. vec2<i32>(vec2<u32>()). // to convert a vector of a different type, e.g. vec2<i32>(vec2<u32>()).
// In that case, preserve the original argument, or you'll get a type error. // In that case, preserve the original argument, or you'll get a type error.
ast::ExpressionList packed;
if (auto* vc = AsVectorConstructor(b, vector)) { std::vector<const sem::Expression*> packed;
const auto num_supplied = vc->values.size(); if (auto vc = AsVectorConstructor(vector_sem)) {
const auto num_supplied = vc.call->Arguments().size();
if (num_supplied == 0) { if (num_supplied == 0) {
// Zero-value vector constructor. Populate with zeros // Zero-value vector constructor. Populate with zeros
auto buildZero = [&]() -> const ast::LiteralExpression* {
if (packed_el_sem_ty->Is<sem::I32>()) {
return b->Expr(0);
} else if (packed_el_sem_ty->Is<sem::U32>()) {
return b->Expr(0u);
} else if (packed_el_sem_ty->Is<sem::F32>()) {
return b->Expr(0.0f);
} else if (packed_el_sem_ty->Is<sem::Bool>()) {
return b->Expr(false);
} else {
TINT_UNREACHABLE(Writer, b->Diagnostics())
<< "unsupported vector element type: "
<< packed_el_sem_ty->TypeInfo().name;
}
return nullptr;
};
for (uint32_t i = 0; i < packed_size - 1; i++) { for (uint32_t i = 0; i < packed_size - 1; i++) {
auto* zero = buildZero(); auto* zero = Zero(*b, packed_el_sem_ty, statement);
b->Sem().Add(
zero, b->create<sem::Expression>(zero, packed_el_sem_ty, statement,
sem::Constant{}));
packed.emplace_back(zero); packed.emplace_back(zero);
} }
} else if (num_supplied + 1 == packed_size) { } else if (num_supplied + 1 == packed_size) {
// All vector components were supplied as scalars. Pass them through. // All vector components were supplied as scalars. Pass them through.
packed = vc->values; packed = vc.call->Arguments();
} }
} }
if (packed.empty()) { if (packed.empty()) {
// The special cases didn't occur. Use the vector argument as-is. // The special cases didn't occur. Use the vector argument as-is.
packed.emplace_back(vector); packed.emplace_back(vector_sem);
} }
if (packed_el_sem_ty != b->TypeOf(scalar)->UnwrapRef()) {
if (packed_el_sem_ty != scalar_sem->Type()->UnwrapRef()) {
// Cast scalar to the vector element type // Cast scalar to the vector element type
auto* scalar_cast = b->Construct(packed_el_ty, scalar); auto* scalar_cast_ast = b->Construct(packed_el_ast_ty, scalar_ast);
b->Sem().Add(scalar_cast, auto* scalar_cast_target = b->create<sem::TypeConversion>(
b->create<sem::Expression>(scalar_cast, packed_el_sem_ty, packed_el_sem_ty,
statement, sem::Constant{})); b->create<sem::Parameter>(nullptr, 0, scalar_sem->Type()->UnwrapRef(),
packed.emplace_back(scalar_cast); ast::StorageClass::kNone,
ast::Access::kUndefined));
auto* scalar_cast_sem =
b->create<sem::Call>(scalar_cast_ast, scalar_cast_target,
std::vector<const sem::Expression*>{scalar_sem},
statement, sem::Constant{});
b->Sem().Add(scalar_cast_ast, scalar_cast_sem);
packed.emplace_back(scalar_cast_sem);
} else { } else {
packed.emplace_back(scalar); packed.emplace_back(scalar_sem);
} }
auto* constructor = b->Construct(packed_ty, std::move(packed)); auto* constructor_ast = b->Construct(
b->Sem().Add(constructor, packed_ast_ty, utils::Transform(packed, [&](const sem::Expression* expr) {
b->create<sem::Expression>(constructor, packed_sem_ty, statement, return expr->Declaration();
sem::Constant{})); }));
auto* constructor_target = b->create<sem::TypeConstructor>(
return constructor; packed_sem_ty,
utils::Transform(packed,
[&](const tint::sem::Expression* arg,
size_t i) -> const sem::Parameter* {
return b->create<sem::Parameter>(
nullptr, i, arg->Type()->UnwrapRef(),
ast::StorageClass::kNone, ast::Access::kUndefined);
}));
auto* constructor_sem = b->create<sem::Call>(
constructor_ast, constructor_target, packed, statement, sem::Constant{});
b->Sem().Add(constructor_ast, constructor_sem);
return constructor_sem;
} }
} // namespace writer } // namespace writer

View File

@ -20,8 +20,8 @@
namespace tint { namespace tint {
namespace ast { namespace ast {
class CallExpression;
class Expression; class Expression;
class TypeConstructorExpression;
} // namespace ast } // namespace ast
namespace writer { namespace writer {
@ -36,10 +36,9 @@ namespace writer {
/// @param scalar the scalar to append to the vector. Must be a scalar. /// @param scalar the scalar to append to the vector. Must be a scalar.
/// @returns a vector expression containing the elements of `vector` followed by /// @returns a vector expression containing the elements of `vector` followed by
/// the single element of `scalar` cast to the `vector` element type. /// the single element of `scalar` cast to the `vector` element type.
const ast::TypeConstructorExpression* AppendVector( const sem::Call* AppendVector(ProgramBuilder* builder,
ProgramBuilder* builder, const ast::Expression* vector,
const ast::Expression* vector, const ast::Expression* scalar);
const ast::Expression* scalar);
} // namespace writer } // namespace writer
} // namespace tint } // namespace tint

View File

@ -15,6 +15,7 @@
#include "src/writer/append_vector.h" #include "src/writer/append_vector.h"
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/resolver/resolver.h" #include "src/resolver/resolver.h"
#include "src/sem/type_constructor.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
@ -24,6 +25,7 @@ namespace {
class AppendVectorTest : public ::testing::Test, public ProgramBuilder {}; class AppendVectorTest : public ::testing::Test, public ProgramBuilder {};
// AppendVector(vec2<i32>(1, 2), 3) -> vec3<i32>(1, 2, 3)
TEST_F(AppendVectorTest, Vec2i32_i32) { TEST_F(AppendVectorTest, Vec2i32_i32) {
auto* scalar_1 = Expr(1); auto* scalar_1 = Expr(1);
auto* scalar_2 = Expr(2); auto* scalar_2 = Expr(2);
@ -34,15 +36,36 @@ TEST_F(AppendVectorTest, Vec2i32_i32) {
resolver::Resolver resolver(this); resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error(); ASSERT_TRUE(resolver.Resolve()) << resolver.error();
auto* vec_123 = AppendVector(this, vec_12, scalar_3) auto* append = AppendVector(this, vec_12, scalar_3);
->As<ast::TypeConstructorExpression>();
auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr); ASSERT_NE(vec_123, nullptr);
ASSERT_EQ(vec_123->values.size(), 3u); ASSERT_EQ(vec_123->args.size(), 3u);
EXPECT_EQ(vec_123->values[0], scalar_1); EXPECT_EQ(vec_123->args[0], scalar_1);
EXPECT_EQ(vec_123->values[1], scalar_2); EXPECT_EQ(vec_123->args[1], scalar_2);
EXPECT_EQ(vec_123->values[2], scalar_3); EXPECT_EQ(vec_123->args[2], scalar_3);
auto* call = Sem().Get(vec_123);
ASSERT_NE(call, nullptr);
ASSERT_EQ(call->Arguments().size(), 3u);
EXPECT_EQ(call->Arguments()[0], Sem().Get(scalar_1));
EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_2));
EXPECT_EQ(call->Arguments()[2], Sem().Get(scalar_3));
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(ctor->ReturnType(), call->Type());
ASSERT_EQ(ctor->Parameters().size(), 3u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
} }
// AppendVector(vec2<i32>(1, 2), 3u) -> vec3<i32>(1, 2, i32(3u))
TEST_F(AppendVectorTest, Vec2i32_u32) { TEST_F(AppendVectorTest, Vec2i32_u32) {
auto* scalar_1 = Expr(1); auto* scalar_1 = Expr(1);
auto* scalar_2 = Expr(2); auto* scalar_2 = Expr(2);
@ -53,19 +76,41 @@ TEST_F(AppendVectorTest, Vec2i32_u32) {
resolver::Resolver resolver(this); resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error(); ASSERT_TRUE(resolver.Resolve()) << resolver.error();
auto* vec_123 = AppendVector(this, vec_12, scalar_3) auto* append = AppendVector(this, vec_12, scalar_3);
->As<ast::TypeConstructorExpression>();
auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr); ASSERT_NE(vec_123, nullptr);
ASSERT_EQ(vec_123->values.size(), 3u); ASSERT_EQ(vec_123->args.size(), 3u);
EXPECT_EQ(vec_123->values[0], scalar_1); EXPECT_EQ(vec_123->args[0], scalar_1);
EXPECT_EQ(vec_123->values[1], scalar_2); EXPECT_EQ(vec_123->args[1], scalar_2);
auto* u32_to_i32 = vec_123->values[2]->As<ast::TypeConstructorExpression>(); auto* u32_to_i32 = vec_123->args[2]->As<ast::CallExpression>();
ASSERT_NE(u32_to_i32, nullptr); ASSERT_NE(u32_to_i32, nullptr);
EXPECT_TRUE(u32_to_i32->type->Is<ast::I32>()); EXPECT_TRUE(u32_to_i32->target.type->Is<ast::I32>());
ASSERT_EQ(u32_to_i32->values.size(), 1u); ASSERT_EQ(u32_to_i32->args.size(), 1u);
EXPECT_EQ(u32_to_i32->values[0], scalar_3); EXPECT_EQ(u32_to_i32->args[0], scalar_3);
auto* call = Sem().Get(vec_123);
ASSERT_NE(call, nullptr);
ASSERT_EQ(call->Arguments().size(), 3u);
EXPECT_EQ(call->Arguments()[0], Sem().Get(scalar_1));
EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_2));
EXPECT_EQ(call->Arguments()[2], Sem().Get(u32_to_i32));
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(ctor->ReturnType(), call->Type());
ASSERT_EQ(ctor->Parameters().size(), 3u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
} }
// AppendVector(vec2<i32>(vec2<u32>(1u, 2u)), 3u) ->
// vec3<i32>(vec2<i32>(vec2<u32>(1u, 2u)), i32(3u))
TEST_F(AppendVectorTest, Vec2i32FromVec2u32_u32) { TEST_F(AppendVectorTest, Vec2i32FromVec2u32_u32) {
auto* scalar_1 = Expr(1u); auto* scalar_1 = Expr(1u);
auto* scalar_2 = Expr(2u); auto* scalar_2 = Expr(2u);
@ -77,26 +122,45 @@ TEST_F(AppendVectorTest, Vec2i32FromVec2u32_u32) {
resolver::Resolver resolver(this); resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error(); ASSERT_TRUE(resolver.Resolve()) << resolver.error();
auto* vec_123 = AppendVector(this, vec_12, scalar_3) auto* append = AppendVector(this, vec_12, scalar_3);
->As<ast::TypeConstructorExpression>();
ASSERT_NE(vec_123, nullptr);
ASSERT_EQ(vec_123->values.size(), 2u);
auto* v2u32_to_v2i32 =
vec_123->values[0]->As<ast::TypeConstructorExpression>();
ASSERT_NE(v2u32_to_v2i32, nullptr);
ASSERT_TRUE(v2u32_to_v2i32->type->Is<ast::Vector>());
EXPECT_EQ(v2u32_to_v2i32->type->As<ast::Vector>()->width, 2u);
EXPECT_TRUE(v2u32_to_v2i32->type->As<ast::Vector>()->type->Is<ast::I32>());
EXPECT_EQ(v2u32_to_v2i32->values.size(), 1u);
EXPECT_EQ(v2u32_to_v2i32->values[0], uvec_12);
auto* u32_to_i32 = vec_123->values[1]->As<ast::TypeConstructorExpression>(); auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr);
ASSERT_EQ(vec_123->args.size(), 2u);
auto* v2u32_to_v2i32 = vec_123->args[0]->As<ast::CallExpression>();
ASSERT_NE(v2u32_to_v2i32, nullptr);
ASSERT_TRUE(v2u32_to_v2i32->target.type->Is<ast::Vector>());
EXPECT_EQ(v2u32_to_v2i32->target.type->As<ast::Vector>()->width, 2u);
EXPECT_TRUE(
v2u32_to_v2i32->target.type->As<ast::Vector>()->type->Is<ast::I32>());
EXPECT_EQ(v2u32_to_v2i32->args.size(), 1u);
EXPECT_EQ(v2u32_to_v2i32->args[0], uvec_12);
auto* u32_to_i32 = vec_123->args[1]->As<ast::CallExpression>();
ASSERT_NE(u32_to_i32, nullptr); ASSERT_NE(u32_to_i32, nullptr);
EXPECT_TRUE(u32_to_i32->type->Is<ast::I32>()); EXPECT_TRUE(u32_to_i32->target.type->Is<ast::I32>());
ASSERT_EQ(u32_to_i32->values.size(), 1u); ASSERT_EQ(u32_to_i32->args.size(), 1u);
EXPECT_EQ(u32_to_i32->values[0], scalar_3); EXPECT_EQ(u32_to_i32->args[0], scalar_3);
auto* call = Sem().Get(vec_123);
ASSERT_NE(call, nullptr);
ASSERT_EQ(call->Arguments().size(), 2u);
EXPECT_EQ(call->Arguments()[0], Sem().Get(vec_12));
EXPECT_EQ(call->Arguments()[1], Sem().Get(u32_to_i32));
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(ctor->ReturnType(), call->Type());
ASSERT_EQ(ctor->Parameters().size(), 2u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
} }
// AppendVector(vec2<i32>(1, 2), 3.0f) -> vec3<i32>(1, 2, i32(3.0f))
TEST_F(AppendVectorTest, Vec2i32_f32) { TEST_F(AppendVectorTest, Vec2i32_f32) {
auto* scalar_1 = Expr(1); auto* scalar_1 = Expr(1);
auto* scalar_2 = Expr(2); auto* scalar_2 = Expr(2);
@ -107,40 +171,84 @@ TEST_F(AppendVectorTest, Vec2i32_f32) {
resolver::Resolver resolver(this); resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error(); ASSERT_TRUE(resolver.Resolve()) << resolver.error();
auto* vec_123 = AppendVector(this, vec_12, scalar_3) auto* append = AppendVector(this, vec_12, scalar_3);
->As<ast::TypeConstructorExpression>();
auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr); ASSERT_NE(vec_123, nullptr);
ASSERT_EQ(vec_123->values.size(), 3u); ASSERT_EQ(vec_123->args.size(), 3u);
EXPECT_EQ(vec_123->values[0], scalar_1); EXPECT_EQ(vec_123->args[0], scalar_1);
EXPECT_EQ(vec_123->values[1], scalar_2); EXPECT_EQ(vec_123->args[1], scalar_2);
auto* f32_to_i32 = vec_123->values[2]->As<ast::TypeConstructorExpression>(); auto* f32_to_i32 = vec_123->args[2]->As<ast::CallExpression>();
ASSERT_NE(f32_to_i32, nullptr); ASSERT_NE(f32_to_i32, nullptr);
EXPECT_TRUE(f32_to_i32->type->Is<ast::I32>()); EXPECT_TRUE(f32_to_i32->target.type->Is<ast::I32>());
ASSERT_EQ(f32_to_i32->values.size(), 1u); ASSERT_EQ(f32_to_i32->args.size(), 1u);
EXPECT_EQ(f32_to_i32->values[0], scalar_3); EXPECT_EQ(f32_to_i32->args[0], scalar_3);
auto* call = Sem().Get(vec_123);
ASSERT_NE(call, nullptr);
ASSERT_EQ(call->Arguments().size(), 3u);
EXPECT_EQ(call->Arguments()[0], Sem().Get(scalar_1));
EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_2));
EXPECT_EQ(call->Arguments()[2], Sem().Get(f32_to_i32));
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(ctor->ReturnType(), call->Type());
ASSERT_EQ(ctor->Parameters().size(), 3u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
} }
// AppendVector(vec3<i32>(1, 2, 3), 4) -> vec4<i32>(1, 2, 3, 4)
TEST_F(AppendVectorTest, Vec3i32_i32) { TEST_F(AppendVectorTest, Vec3i32_i32) {
auto* scalar_1 = Expr(1); auto* scalar_1 = Expr(1);
auto* scalar_2 = Expr(2); auto* scalar_2 = Expr(2);
auto* scalar_3 = Expr(3); auto* scalar_3 = Expr(3);
auto* scalar_4 = Expr(3); auto* scalar_4 = Expr(4);
auto* vec_123 = vec3<i32>(scalar_1, scalar_2, scalar_3); auto* vec_123 = vec3<i32>(scalar_1, scalar_2, scalar_3);
WrapInFunction(vec_123, scalar_4); WrapInFunction(vec_123, scalar_4);
resolver::Resolver resolver(this); resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error(); ASSERT_TRUE(resolver.Resolve()) << resolver.error();
auto* vec_1234 = AppendVector(this, vec_123, scalar_4) auto* append = AppendVector(this, vec_123, scalar_4);
->As<ast::TypeConstructorExpression>();
auto* vec_1234 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_1234, nullptr); ASSERT_NE(vec_1234, nullptr);
ASSERT_EQ(vec_1234->values.size(), 4u); ASSERT_EQ(vec_1234->args.size(), 4u);
EXPECT_EQ(vec_1234->values[0], scalar_1); EXPECT_EQ(vec_1234->args[0], scalar_1);
EXPECT_EQ(vec_1234->values[1], scalar_2); EXPECT_EQ(vec_1234->args[1], scalar_2);
EXPECT_EQ(vec_1234->values[2], scalar_3); EXPECT_EQ(vec_1234->args[2], scalar_3);
EXPECT_EQ(vec_1234->values[3], scalar_4); EXPECT_EQ(vec_1234->args[3], scalar_4);
auto* call = Sem().Get(vec_1234);
ASSERT_NE(call, nullptr);
ASSERT_EQ(call->Arguments().size(), 4u);
EXPECT_EQ(call->Arguments()[0], Sem().Get(scalar_1));
EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_2));
EXPECT_EQ(call->Arguments()[2], Sem().Get(scalar_3));
EXPECT_EQ(call->Arguments()[3], Sem().Get(scalar_4));
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 4u);
EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(ctor->ReturnType(), call->Type());
ASSERT_EQ(ctor->Parameters().size(), 4u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[3]->Type()->Is<sem::I32>());
} }
// AppendVector(vec_12, 3) -> vec3<i32>(vec_12, 3)
TEST_F(AppendVectorTest, Vec2i32Var_i32) { TEST_F(AppendVectorTest, Vec2i32Var_i32) {
Global("vec_12", ty.vec2<i32>(), ast::StorageClass::kPrivate); Global("vec_12", ty.vec2<i32>(), ast::StorageClass::kPrivate);
auto* vec_12 = Expr("vec_12"); auto* vec_12 = Expr("vec_12");
@ -150,14 +258,33 @@ TEST_F(AppendVectorTest, Vec2i32Var_i32) {
resolver::Resolver resolver(this); resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error(); ASSERT_TRUE(resolver.Resolve()) << resolver.error();
auto* vec_123 = AppendVector(this, vec_12, scalar_3) auto* append = AppendVector(this, vec_12, scalar_3);
->As<ast::TypeConstructorExpression>();
auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr); ASSERT_NE(vec_123, nullptr);
ASSERT_EQ(vec_123->values.size(), 2u); ASSERT_EQ(vec_123->args.size(), 2u);
EXPECT_EQ(vec_123->values[0], vec_12); EXPECT_EQ(vec_123->args[0], vec_12);
EXPECT_EQ(vec_123->values[1], scalar_3); EXPECT_EQ(vec_123->args[1], scalar_3);
auto* call = Sem().Get(vec_123);
ASSERT_NE(call, nullptr);
ASSERT_EQ(call->Arguments().size(), 2u);
EXPECT_EQ(call->Arguments()[0], Sem().Get(vec_12));
EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_3));
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(ctor->ReturnType(), call->Type());
ASSERT_EQ(ctor->Parameters().size(), 2u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
} }
// AppendVector(1, 2, scalar_3) -> vec3<i32>(1, 2, scalar_3)
TEST_F(AppendVectorTest, Vec2i32_i32Var) { TEST_F(AppendVectorTest, Vec2i32_i32Var) {
Global("scalar_3", ty.i32(), ast::StorageClass::kPrivate); Global("scalar_3", ty.i32(), ast::StorageClass::kPrivate);
auto* scalar_1 = Expr(1); auto* scalar_1 = Expr(1);
@ -169,15 +296,36 @@ TEST_F(AppendVectorTest, Vec2i32_i32Var) {
resolver::Resolver resolver(this); resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error(); ASSERT_TRUE(resolver.Resolve()) << resolver.error();
auto* vec_123 = AppendVector(this, vec_12, scalar_3) auto* append = AppendVector(this, vec_12, scalar_3);
->As<ast::TypeConstructorExpression>();
auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr); ASSERT_NE(vec_123, nullptr);
ASSERT_EQ(vec_123->values.size(), 3u); ASSERT_EQ(vec_123->args.size(), 3u);
EXPECT_EQ(vec_123->values[0], scalar_1); EXPECT_EQ(vec_123->args[0], scalar_1);
EXPECT_EQ(vec_123->values[1], scalar_2); EXPECT_EQ(vec_123->args[1], scalar_2);
EXPECT_EQ(vec_123->values[2], scalar_3); EXPECT_EQ(vec_123->args[2], scalar_3);
auto* call = Sem().Get(vec_123);
ASSERT_NE(call, nullptr);
ASSERT_EQ(call->Arguments().size(), 3u);
EXPECT_EQ(call->Arguments()[0], Sem().Get(scalar_1));
EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_2));
EXPECT_EQ(call->Arguments()[2], Sem().Get(scalar_3));
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(ctor->ReturnType(), call->Type());
ASSERT_EQ(ctor->Parameters().size(), 3u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
} }
// AppendVector(vec_12, scalar_3) -> vec3<i32>(vec_12, scalar_3)
TEST_F(AppendVectorTest, Vec2i32Var_i32Var) { TEST_F(AppendVectorTest, Vec2i32Var_i32Var) {
Global("vec_12", ty.vec2<i32>(), ast::StorageClass::kPrivate); Global("vec_12", ty.vec2<i32>(), ast::StorageClass::kPrivate);
Global("scalar_3", ty.i32(), ast::StorageClass::kPrivate); Global("scalar_3", ty.i32(), ast::StorageClass::kPrivate);
@ -188,14 +336,33 @@ TEST_F(AppendVectorTest, Vec2i32Var_i32Var) {
resolver::Resolver resolver(this); resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error(); ASSERT_TRUE(resolver.Resolve()) << resolver.error();
auto* vec_123 = AppendVector(this, vec_12, scalar_3) auto* append = AppendVector(this, vec_12, scalar_3);
->As<ast::TypeConstructorExpression>();
auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr); ASSERT_NE(vec_123, nullptr);
ASSERT_EQ(vec_123->values.size(), 2u); ASSERT_EQ(vec_123->args.size(), 2u);
EXPECT_EQ(vec_123->values[0], vec_12); EXPECT_EQ(vec_123->args[0], vec_12);
EXPECT_EQ(vec_123->values[1], scalar_3); EXPECT_EQ(vec_123->args[1], scalar_3);
auto* call = Sem().Get(vec_123);
ASSERT_NE(call, nullptr);
ASSERT_EQ(call->Arguments().size(), 2u);
EXPECT_EQ(call->Arguments()[0], Sem().Get(vec_12));
EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_3));
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(ctor->ReturnType(), call->Type());
ASSERT_EQ(ctor->Parameters().size(), 2u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
} }
// AppendVector(vec_12, scalar_3) -> vec3<i32>(vec_12, i32(scalar_3))
TEST_F(AppendVectorTest, Vec2i32Var_f32Var) { TEST_F(AppendVectorTest, Vec2i32Var_f32Var) {
Global("vec_12", ty.vec2<i32>(), ast::StorageClass::kPrivate); Global("vec_12", ty.vec2<i32>(), ast::StorageClass::kPrivate);
Global("scalar_3", ty.f32(), ast::StorageClass::kPrivate); Global("scalar_3", ty.f32(), ast::StorageClass::kPrivate);
@ -206,18 +373,37 @@ TEST_F(AppendVectorTest, Vec2i32Var_f32Var) {
resolver::Resolver resolver(this); resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error(); ASSERT_TRUE(resolver.Resolve()) << resolver.error();
auto* vec_123 = AppendVector(this, vec_12, scalar_3) auto* append = AppendVector(this, vec_12, scalar_3);
->As<ast::TypeConstructorExpression>();
auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr); ASSERT_NE(vec_123, nullptr);
ASSERT_EQ(vec_123->values.size(), 2u); ASSERT_EQ(vec_123->args.size(), 2u);
EXPECT_EQ(vec_123->values[0], vec_12); EXPECT_EQ(vec_123->args[0], vec_12);
auto* f32_to_i32 = vec_123->values[1]->As<ast::TypeConstructorExpression>(); auto* f32_to_i32 = vec_123->args[1]->As<ast::CallExpression>();
ASSERT_NE(f32_to_i32, nullptr); ASSERT_NE(f32_to_i32, nullptr);
EXPECT_TRUE(f32_to_i32->type->Is<ast::I32>()); EXPECT_TRUE(f32_to_i32->target.type->Is<ast::I32>());
ASSERT_EQ(f32_to_i32->values.size(), 1u); ASSERT_EQ(f32_to_i32->args.size(), 1u);
EXPECT_EQ(f32_to_i32->values[0], scalar_3); EXPECT_EQ(f32_to_i32->args[0], scalar_3);
auto* call = Sem().Get(vec_123);
ASSERT_NE(call, nullptr);
ASSERT_EQ(call->Arguments().size(), 2u);
EXPECT_EQ(call->Arguments()[0], Sem().Get(vec_12));
EXPECT_EQ(call->Arguments()[1], Sem().Get(f32_to_i32));
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(ctor->ReturnType(), call->Type());
ASSERT_EQ(ctor->Parameters().size(), 2u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
} }
// AppendVector(vec_12, scalar_3) -> vec3<bool>(vec_12, scalar_3)
TEST_F(AppendVectorTest, Vec2boolVar_boolVar) { TEST_F(AppendVectorTest, Vec2boolVar_boolVar) {
Global("vec_12", ty.vec2<bool>(), ast::StorageClass::kPrivate); Global("vec_12", ty.vec2<bool>(), ast::StorageClass::kPrivate);
Global("scalar_3", ty.bool_(), ast::StorageClass::kPrivate); Global("scalar_3", ty.bool_(), ast::StorageClass::kPrivate);
@ -228,14 +414,33 @@ TEST_F(AppendVectorTest, Vec2boolVar_boolVar) {
resolver::Resolver resolver(this); resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error(); ASSERT_TRUE(resolver.Resolve()) << resolver.error();
auto* vec_123 = AppendVector(this, vec_12, scalar_3) auto* append = AppendVector(this, vec_12, scalar_3);
->As<ast::TypeConstructorExpression>();
auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr); ASSERT_NE(vec_123, nullptr);
ASSERT_EQ(vec_123->values.size(), 2u); ASSERT_EQ(vec_123->args.size(), 2u);
EXPECT_EQ(vec_123->values[0], vec_12); EXPECT_EQ(vec_123->args[0], vec_12);
EXPECT_EQ(vec_123->values[1], scalar_3); EXPECT_EQ(vec_123->args[1], scalar_3);
auto* call = Sem().Get(vec_123);
ASSERT_NE(call, nullptr);
ASSERT_EQ(call->Arguments().size(), 2u);
EXPECT_EQ(call->Arguments()[0], Sem().Get(vec_12));
EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_3));
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::Bool>());
EXPECT_EQ(ctor->ReturnType(), call->Type());
ASSERT_EQ(ctor->Parameters().size(), 2u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::Bool>());
} }
// AppendVector(vec3<i32>(), 4) -> vec3<bool>(0, 0, 0, 4)
TEST_F(AppendVectorTest, ZeroVec3i32_i32) { TEST_F(AppendVectorTest, ZeroVec3i32_i32) {
auto* scalar = Expr(4); auto* scalar = Expr(4);
auto* vec000 = vec3<i32>(); auto* vec000 = vec3<i32>();
@ -244,16 +449,38 @@ TEST_F(AppendVectorTest, ZeroVec3i32_i32) {
resolver::Resolver resolver(this); resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error(); ASSERT_TRUE(resolver.Resolve()) << resolver.error();
auto* vec_0004 = auto* append = AppendVector(this, vec000, scalar);
AppendVector(this, vec000, scalar)->As<ast::TypeConstructorExpression>();
auto* vec_0004 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_0004, nullptr); ASSERT_NE(vec_0004, nullptr);
ASSERT_EQ(vec_0004->values.size(), 4u); ASSERT_EQ(vec_0004->args.size(), 4u);
for (size_t i = 0; i < 3; i++) { for (size_t i = 0; i < 3; i++) {
auto* literal = As<ast::SintLiteralExpression>(vec_0004->values[i]); auto* literal = As<ast::SintLiteralExpression>(vec_0004->args[i]);
ASSERT_NE(literal, nullptr); ASSERT_NE(literal, nullptr);
EXPECT_EQ(literal->value, 0); EXPECT_EQ(literal->value, 0);
} }
EXPECT_EQ(vec_0004->values[3], scalar); EXPECT_EQ(vec_0004->args[3], scalar);
auto* call = Sem().Get(vec_0004);
ASSERT_NE(call, nullptr);
ASSERT_EQ(call->Arguments().size(), 4u);
EXPECT_EQ(call->Arguments()[0], Sem().Get(vec_0004->args[0]));
EXPECT_EQ(call->Arguments()[1], Sem().Get(vec_0004->args[1]));
EXPECT_EQ(call->Arguments()[2], Sem().Get(vec_0004->args[2]));
EXPECT_EQ(call->Arguments()[3], Sem().Get(scalar));
auto* ctor = call->Target()->As<sem::TypeConstructor>();
ASSERT_NE(ctor, nullptr);
ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 4u);
EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(ctor->ReturnType(), call->Type());
ASSERT_EQ(ctor->Parameters().size(), 4u);
EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
EXPECT_TRUE(ctor->Parameters()[3]->Type()->Is<sem::I32>());
} }
} // namespace } // namespace

View File

@ -41,6 +41,8 @@
#include "src/sem/statement.h" #include "src/sem/statement.h"
#include "src/sem/storage_texture_type.h" #include "src/sem/storage_texture_type.h"
#include "src/sem/struct.h" #include "src/sem/struct.h"
#include "src/sem/type_constructor.h"
#include "src/sem/type_conversion.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/transform/calculate_array_length.h" #include "src/transform/calculate_array_length.h"
#include "src/transform/glsl.h" #include "src/transform/glsl.h"
@ -358,85 +360,49 @@ bool GeneratorImpl::EmitBreak(const ast::BreakStatement*) {
bool GeneratorImpl::EmitCall(std::ostream& out, bool GeneratorImpl::EmitCall(std::ostream& out,
const ast::CallExpression* expr) { const ast::CallExpression* expr) {
const auto& args = expr->args;
auto* ident = expr->func;
auto* call = builder_.Sem().Get(expr); auto* call = builder_.Sem().Get(expr);
auto* target = call->Target(); auto* target = call->Target();
if (auto* func = target->As<sem::Function>()) { if (auto* func = target->As<sem::Function>()) {
if (ast::HasDecoration< return EmitFunctionCall(out, call, func);
transform::CalculateArrayLength::BufferSizeIntrinsic>(
func->Declaration()->decorations)) {
// Special function generated by the CalculateArrayLength transform for
// calling X.GetDimensions(Y)
if (!EmitExpression(out, args[0])) {
return false;
}
out << ".GetDimensions(";
if (!EmitExpression(out, args[1])) {
return false;
}
out << ")";
return true;
}
} }
if (auto* intrinsic = target->As<sem::Intrinsic>()) {
if (auto* intrinsic = call->Target()->As<sem::Intrinsic>()) { return EmitIntrinsicCall(out, call, intrinsic);
if (intrinsic->IsTexture()) {
return EmitTextureCall(out, expr, intrinsic);
} else if (intrinsic->Type() == sem::IntrinsicType::kSelect) {
return EmitSelectCall(out, expr);
} else if (intrinsic->Type() == sem::IntrinsicType::kDot) {
return EmitDotCall(out, expr, intrinsic);
} else if (intrinsic->Type() == sem::IntrinsicType::kModf) {
return EmitModfCall(out, expr, intrinsic);
} else if (intrinsic->Type() == sem::IntrinsicType::kFrexp) {
return EmitFrexpCall(out, expr, intrinsic);
} else if (intrinsic->Type() == sem::IntrinsicType::kIsNormal) {
return EmitIsNormalCall(out, expr, intrinsic);
} else if (intrinsic->Type() == sem::IntrinsicType::kIgnore) {
return EmitExpression(out, expr->args[0]); // [DEPRECATED]
} else if (intrinsic->IsDataPacking()) {
return EmitDataPackingCall(out, expr, intrinsic);
} else if (intrinsic->IsDataUnpacking()) {
return EmitDataUnpackingCall(out, expr, intrinsic);
} else if (intrinsic->IsBarrier()) {
return EmitBarrierCall(out, intrinsic);
} else if (intrinsic->IsAtomic()) {
return EmitWorkgroupAtomicCall(out, expr, intrinsic);
}
auto name = generate_builtin_name(intrinsic);
if (name.empty()) {
return false;
}
out << name << "(";
bool first = true;
for (auto* arg : args) {
if (!first) {
out << ", ";
}
first = false;
if (!EmitExpression(out, arg)) {
return false;
}
}
out << ")";
return true;
} }
if (auto* cast = target->As<sem::TypeConversion>()) {
return EmitTypeConversion(out, call, cast);
}
if (auto* ctor = target->As<sem::TypeConstructor>()) {
return EmitTypeConstructor(out, call, ctor);
}
TINT_ICE(Writer, diagnostics_)
<< "unhandled call target: " << target->TypeInfo().name;
return false;
}
bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
const sem::Call* call,
const sem::Function* func) {
const auto& args = call->Arguments();
auto* decl = call->Declaration();
auto* ident = decl->target.name;
auto name = builder_.Symbols().NameFor(ident->symbol); auto name = builder_.Symbols().NameFor(ident->symbol);
auto caller_sym = ident->symbol; auto caller_sym = ident->symbol;
auto* func = builder_.AST().Functions().Find(ident->symbol); if (ast::HasDecoration<transform::CalculateArrayLength::BufferSizeIntrinsic>(
if (func == nullptr) { func->Declaration()->decorations)) {
diagnostics_.add_error(diag::System::Writer, // Special function generated by the CalculateArrayLength transform for
"Unable to find function: " + // calling X.GetDimensions(Y)
builder_.Symbols().NameFor(ident->symbol)); if (!EmitExpression(out, args[0]->Declaration())) {
return false; return false;
}
out << ".GetDimensions(";
if (!EmitExpression(out, args[1]->Declaration())) {
return false;
}
out << ")";
return true;
} }
out << name << "("; out << name << "(";
@ -448,13 +414,141 @@ bool GeneratorImpl::EmitCall(std::ostream& out,
} }
first = false; first = false;
if (!EmitExpression(out, arg)) { if (!EmitExpression(out, arg->Declaration())) {
return false; return false;
} }
} }
out << ")"; out << ")";
return true;
}
bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
const sem::Call* call,
const sem::Intrinsic* intrinsic) {
auto* expr = call->Declaration();
if (intrinsic->IsTexture()) {
return EmitTextureCall(out, expr, intrinsic);
}
if (intrinsic->Type() == sem::IntrinsicType::kSelect) {
return EmitSelectCall(out, expr);
}
if (intrinsic->Type() == sem::IntrinsicType::kDot) {
return EmitDotCall(out, expr, intrinsic);
}
if (intrinsic->Type() == sem::IntrinsicType::kModf) {
return EmitModfCall(out, expr, intrinsic);
}
if (intrinsic->Type() == sem::IntrinsicType::kFrexp) {
return EmitFrexpCall(out, expr, intrinsic);
}
if (intrinsic->Type() == sem::IntrinsicType::kIsNormal) {
return EmitIsNormalCall(out, expr, intrinsic);
}
if (intrinsic->Type() == sem::IntrinsicType::kIgnore) {
return EmitExpression(out, expr->args[0]); // [DEPRECATED]
}
if (intrinsic->IsDataPacking()) {
return EmitDataPackingCall(out, expr, intrinsic);
}
if (intrinsic->IsDataUnpacking()) {
return EmitDataUnpackingCall(out, expr, intrinsic);
}
if (intrinsic->IsBarrier()) {
return EmitBarrierCall(out, intrinsic);
}
if (intrinsic->IsAtomic()) {
return EmitWorkgroupAtomicCall(out, expr, intrinsic);
}
auto name = generate_builtin_name(intrinsic);
if (name.empty()) {
return false;
}
out << name << "(";
bool first = true;
for (auto* arg : call->Arguments()) {
if (!first) {
out << ", ";
}
first = false;
if (!EmitExpression(out, arg->Declaration())) {
return false;
}
}
out << ")";
return true;
}
bool GeneratorImpl::EmitTypeConversion(std::ostream& out,
const sem::Call* call,
const sem::TypeConversion* conv) {
if (!EmitType(out, conv->Target(), ast::StorageClass::kNone,
ast::Access::kReadWrite, "")) {
return false;
}
out << "(";
if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
return false;
}
out << ")";
return true;
}
bool GeneratorImpl::EmitTypeConstructor(std::ostream& out,
const sem::Call* call,
const sem::TypeConstructor* ctor) {
auto* type = ctor->ReturnType();
// If the type constructor is empty then we need to construct with the zero
// value for all components.
if (call->Arguments().empty()) {
return EmitZeroValue(out, type);
}
// For single-value vector initializers, swizzle the scalar to the right
// vector dimension using .x
const bool is_single_value_vector_init =
type->is_scalar_vector() && call->Arguments().size() == 1 &&
call->Arguments()[0]->Type()->UnwrapRef()->is_scalar();
auto it = structure_builders_.find(As<sem::Struct>(type));
if (it != structure_builders_.end()) {
out << it->second << "(";
} else {
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
"")) {
return false;
}
out << "(";
}
if (is_single_value_vector_init) {
out << "(";
}
bool first = true;
for (auto* arg : call->Arguments()) {
if (!first) {
out << ", ";
}
first = false;
if (!EmitExpression(out, arg->Declaration())) {
return false;
}
}
if (is_single_value_vector_init) {
out << ")." << std::string(type->As<sem::Vector>()->Width(), 'x');
}
out << ")";
return true; return true;
} }
@ -1148,13 +1242,13 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out,
builder_.Sem().Add(zero, builder_.create<sem::Expression>(zero, i32, stmt, builder_.Sem().Add(zero, builder_.create<sem::Expression>(zero, i32, stmt,
sem::Constant{})); sem::Constant{}));
auto* packed = AppendVector(&builder_, vector, zero); auto* packed = AppendVector(&builder_, vector, zero);
return EmitExpression(out, packed); return EmitExpression(out, packed->Declaration());
}; };
auto emit_vector_appended_with_level = [&](const ast::Expression* vector) { auto emit_vector_appended_with_level = [&](const ast::Expression* vector) {
if (auto* level = arg(Usage::kLevel)) { if (auto* level = arg(Usage::kLevel)) {
auto* packed = AppendVector(&builder_, vector, level); auto* packed = AppendVector(&builder_, vector, level);
return EmitExpression(out, packed); return EmitExpression(out, packed->Declaration());
} }
return emit_vector_appended_with_i32_zero(vector); return emit_vector_appended_with_i32_zero(vector);
}; };
@ -1164,11 +1258,11 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out,
auto* packed = AppendVector(&builder_, param_coords, array_index); auto* packed = AppendVector(&builder_, param_coords, array_index);
if (pack_level_in_coords) { if (pack_level_in_coords) {
// Then mip level needs to be appended to the coordinates. // Then mip level needs to be appended to the coordinates.
if (!emit_vector_appended_with_level(packed)) { if (!emit_vector_appended_with_level(packed->Declaration())) {
return false; return false;
} }
} else { } else {
if (!EmitExpression(out, packed)) { if (!EmitExpression(out, packed->Declaration())) {
return false; return false;
} }
} }
@ -1347,58 +1441,6 @@ bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
return true; return true;
} }
bool GeneratorImpl::EmitTypeConstructor(
std::ostream& out,
const ast::TypeConstructorExpression* expr) {
auto* type = TypeOf(expr)->UnwrapRef();
// If the type constructor is empty then we need to construct with the zero
// value for all components.
if (expr->values.empty()) {
return EmitZeroValue(out, type);
}
// For single-value vector initializers, swizzle the scalar to the right
// vector dimension using .x
const bool is_single_value_vector_init =
type->is_scalar_vector() && expr->values.size() == 1 &&
TypeOf(expr->values[0])->UnwrapRef()->is_scalar();
auto it = structure_builders_.find(As<sem::Struct>(type));
if (it != structure_builders_.end()) {
out << it->second << "(";
} else {
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
"")) {
return false;
}
out << "(";
}
if (is_single_value_vector_init) {
out << "(";
}
bool first = true;
for (auto* e : expr->values) {
if (!first) {
out << ", ";
}
first = false;
if (!EmitExpression(out, e)) {
return false;
}
}
if (is_single_value_vector_init) {
out << ")." << std::string(type->As<sem::Vector>()->Width(), 'x');
}
out << ")";
return true;
}
bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) { bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
if (!emit_continuing_()) { if (!emit_continuing_()) {
return false; return false;
@ -1428,9 +1470,6 @@ bool GeneratorImpl::EmitExpression(std::ostream& out,
if (auto* c = expr->As<ast::CallExpression>()) { if (auto* c = expr->As<ast::CallExpression>()) {
return EmitCall(out, c); return EmitCall(out, c);
} }
if (auto* c = expr->As<ast::TypeConstructorExpression>()) {
return EmitTypeConstructor(out, c);
}
if (auto* i = expr->As<ast::IdentifierExpression>()) { if (auto* i = expr->As<ast::IdentifierExpression>()) {
return EmitIdentifier(out, i); return EmitIdentifier(out, i);
} }

View File

@ -43,6 +43,8 @@ namespace tint {
namespace sem { namespace sem {
class Call; class Call;
class Intrinsic; class Intrinsic;
class TypeConstructor;
class TypeConversion;
} // namespace sem } // namespace sem
namespace writer { namespace writer {
@ -100,6 +102,38 @@ class GeneratorImpl : public TextGenerator {
/// @param expr the call expression /// @param expr the call expression
/// @returns true if the call expression is emitted /// @returns true if the call expression is emitted
bool EmitCall(std::ostream& out, const ast::CallExpression* expr); bool EmitCall(std::ostream& out, const ast::CallExpression* expr);
/// Handles generating a function call expression
/// @param out the output of the expression stream
/// @param call the call expression
/// @param function the function being called
/// @returns true if the expression is emitted
bool EmitFunctionCall(std::ostream& out,
const sem::Call* call,
const sem::Function* function);
/// Handles generating an intrinsic call expression
/// @param out the output of the expression stream
/// @param call the call expression
/// @param intrinsic the intrinsic being called
/// @returns true if the expression is emitted
bool EmitIntrinsicCall(std::ostream& out,
const sem::Call* call,
const sem::Intrinsic* intrinsic);
/// Handles generating a type conversion expression
/// @param out the output of the expression stream
/// @param call the call expression
/// @param conv the type conversion
/// @returns true if the expression is emitted
bool EmitTypeConversion(std::ostream& out,
const sem::Call* call,
const sem::TypeConversion* conv);
/// Handles generating a type constructor expression
/// @param out the output of the expression stream
/// @param call the call expression
/// @param ctor the type constructor
/// @returns true if the expression is emitted
bool EmitTypeConstructor(std::ostream& out,
const sem::Call* call,
const sem::TypeConstructor* ctor);
/// Handles generating a barrier intrinsic call /// Handles generating a barrier intrinsic call
/// @param out the output of the expression stream /// @param out the output of the expression stream
/// @param intrinsic the semantic information for the barrier intrinsic /// @param intrinsic the semantic information for the barrier intrinsic
@ -192,12 +226,6 @@ class GeneratorImpl : public TextGenerator {
/// @param stmt the discard statement /// @param stmt the discard statement
/// @returns true if the statement was successfully emitted /// @returns true if the statement was successfully emitted
bool EmitDiscard(const ast::DiscardStatement* stmt); bool EmitDiscard(const ast::DiscardStatement* stmt);
/// Handles emitting a type constructor
/// @param out the output of the expression stream
/// @param expr the type constructor expression
/// @returns true if the constructor is emitted
bool EmitTypeConstructor(std::ostream& out,
const ast::TypeConstructorExpression* expr);
/// Handles a continue statement /// Handles a continue statement
/// @param stmt the statement to emit /// @param stmt the statement to emit
/// @returns true if the statement was emitted successfully /// @returns true if the statement was emitted successfully

View File

@ -41,6 +41,8 @@
#include "src/sem/statement.h" #include "src/sem/statement.h"
#include "src/sem/storage_texture_type.h" #include "src/sem/storage_texture_type.h"
#include "src/sem/struct.h" #include "src/sem/struct.h"
#include "src/sem/type_constructor.h"
#include "src/sem/type_conversion.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/transform/add_empty_entry_point.h" #include "src/transform/add_empty_entry_point.h"
#include "src/transform/calculate_array_length.h" #include "src/transform/calculate_array_length.h"
@ -499,7 +501,7 @@ bool GeneratorImpl::EmitBinary(std::ostream& out,
case ast::BinaryOp::kDivide: case ast::BinaryOp::kDivide:
out << "/"; out << "/";
if (auto val = program_->Sem().Get(expr->rhs)->ConstantValue()) { if (auto val = builder_.Sem().Get(expr->rhs)->ConstantValue()) {
// Integer divide by zero is a DXC compile error, and undefined behavior // Integer divide by zero is a DXC compile error, and undefined behavior
// in WGSL. Replace the 0 with 1. // in WGSL. Replace the 0 with 1.
if (val.Type()->Is<sem::I32>() && val.Elements()[0].i32 == 0) { if (val.Type()->Is<sem::I32>() && val.Elements()[0].i32 == 0) {
@ -559,117 +561,209 @@ bool GeneratorImpl::EmitBreak(const ast::BreakStatement*) {
bool GeneratorImpl::EmitCall(std::ostream& out, bool GeneratorImpl::EmitCall(std::ostream& out,
const ast::CallExpression* expr) { const ast::CallExpression* expr) {
const auto& args = expr->args;
auto* ident = expr->func;
auto* call = builder_.Sem().Get(expr); auto* call = builder_.Sem().Get(expr);
auto* target = call->Target(); auto* target = call->Target();
if (auto* func = target->As<sem::Function>()) { if (auto* func = target->As<sem::Function>()) {
if (ast::HasDecoration< return EmitFunctionCall(out, call, func);
transform::CalculateArrayLength::BufferSizeIntrinsic>(
func->Declaration()->decorations)) {
// Special function generated by the CalculateArrayLength transform for
// calling X.GetDimensions(Y)
if (!EmitExpression(out, args[0])) {
return false;
}
out << ".GetDimensions(";
if (!EmitExpression(out, args[1])) {
return false;
}
out << ")";
return true;
}
if (auto* intrinsic =
ast::GetDecoration<transform::DecomposeMemoryAccess::Intrinsic>(
func->Declaration()->decorations)) {
switch (intrinsic->storage_class) {
case ast::StorageClass::kUniform:
return EmitUniformBufferAccess(out, expr, intrinsic);
case ast::StorageClass::kStorage:
return EmitStorageBufferAccess(out, expr, intrinsic);
default:
TINT_UNREACHABLE(Writer, diagnostics_)
<< "unsupported DecomposeMemoryAccess::Intrinsic storage class:"
<< intrinsic->storage_class;
return false;
}
}
} }
if (auto* intrinsic = target->As<sem::Intrinsic>()) {
return EmitIntrinsicCall(out, call, intrinsic);
}
if (auto* conv = target->As<sem::TypeConversion>()) {
return EmitTypeConversion(out, call, conv);
}
if (auto* ctor = target->As<sem::TypeConstructor>()) {
return EmitTypeConstructor(out, call, ctor);
}
TINT_ICE(Writer, diagnostics_)
<< "unhandled call target: " << target->TypeInfo().name;
return false;
}
if (auto* intrinsic = call->Target()->As<sem::Intrinsic>()) { bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
if (intrinsic->IsTexture()) { const sem::Call* call,
return EmitTextureCall(out, expr, intrinsic); const sem::Function* func) {
} else if (intrinsic->Type() == sem::IntrinsicType::kSelect) { auto* expr = call->Declaration();
return EmitSelectCall(out, expr);
} else if (intrinsic->Type() == sem::IntrinsicType::kModf) { if (ast::HasDecoration<transform::CalculateArrayLength::BufferSizeIntrinsic>(
return EmitModfCall(out, expr, intrinsic); func->Declaration()->decorations)) {
} else if (intrinsic->Type() == sem::IntrinsicType::kFrexp) { // Special function generated by the CalculateArrayLength transform for
return EmitFrexpCall(out, expr, intrinsic); // calling X.GetDimensions(Y)
} else if (intrinsic->Type() == sem::IntrinsicType::kIsNormal) { if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
return EmitIsNormalCall(out, expr, intrinsic);
} else if (intrinsic->Type() == sem::IntrinsicType::kIgnore) {
return EmitExpression(out, expr->args[0]); // [DEPRECATED]
} else if (intrinsic->IsDataPacking()) {
return EmitDataPackingCall(out, expr, intrinsic);
} else if (intrinsic->IsDataUnpacking()) {
return EmitDataUnpackingCall(out, expr, intrinsic);
} else if (intrinsic->IsBarrier()) {
return EmitBarrierCall(out, intrinsic);
} else if (intrinsic->IsAtomic()) {
return EmitWorkgroupAtomicCall(out, expr, intrinsic);
}
auto name = generate_builtin_name(intrinsic);
if (name.empty()) {
return false; return false;
} }
out << ".GetDimensions(";
out << name << "("; if (!EmitExpression(out, call->Arguments()[1]->Declaration())) {
return false;
bool first = true;
for (auto* arg : args) {
if (!first) {
out << ", ";
}
first = false;
if (!EmitExpression(out, arg)) {
return false;
}
} }
out << ")"; out << ")";
return true; return true;
} }
auto name = builder_.Symbols().NameFor(ident->symbol); if (auto* intrinsic =
auto caller_sym = ident->symbol; ast::GetDecoration<transform::DecomposeMemoryAccess::Intrinsic>(
func->Declaration()->decorations)) {
switch (intrinsic->storage_class) {
case ast::StorageClass::kUniform:
return EmitUniformBufferAccess(out, expr, intrinsic);
case ast::StorageClass::kStorage:
return EmitStorageBufferAccess(out, expr, intrinsic);
default:
TINT_UNREACHABLE(Writer, diagnostics_)
<< "unsupported DecomposeMemoryAccess::Intrinsic storage class:"
<< intrinsic->storage_class;
return false;
}
}
auto* func = builder_.AST().Functions().Find(ident->symbol); out << builder_.Symbols().NameFor(func->Declaration()->symbol) << "(";
if (func == nullptr) {
diagnostics_.add_error(diag::System::Writer, bool first = true;
"Unable to find function: " + for (auto* arg : call->Arguments()) {
builder_.Symbols().NameFor(ident->symbol)); if (!first) {
out << ", ";
}
first = false;
if (!EmitExpression(out, arg->Declaration())) {
return false;
}
}
out << ")";
return true;
}
bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
const sem::Call* call,
const sem::Intrinsic* intrinsic) {
auto* expr = call->Declaration();
if (intrinsic->IsTexture()) {
return EmitTextureCall(out, expr, intrinsic);
}
if (intrinsic->Type() == sem::IntrinsicType::kSelect) {
return EmitSelectCall(out, expr);
}
if (intrinsic->Type() == sem::IntrinsicType::kModf) {
return EmitModfCall(out, expr, intrinsic);
}
if (intrinsic->Type() == sem::IntrinsicType::kFrexp) {
return EmitFrexpCall(out, expr, intrinsic);
}
if (intrinsic->Type() == sem::IntrinsicType::kIsNormal) {
return EmitIsNormalCall(out, expr, intrinsic);
}
if (intrinsic->Type() == sem::IntrinsicType::kIgnore) {
return EmitExpression(out, expr->args[0]); // [DEPRECATED]
}
if (intrinsic->IsDataPacking()) {
return EmitDataPackingCall(out, expr, intrinsic);
}
if (intrinsic->IsDataUnpacking()) {
return EmitDataUnpackingCall(out, expr, intrinsic);
}
if (intrinsic->IsBarrier()) {
return EmitBarrierCall(out, intrinsic);
}
if (intrinsic->IsAtomic()) {
return EmitWorkgroupAtomicCall(out, expr, intrinsic);
}
auto name = generate_builtin_name(intrinsic);
if (name.empty()) {
return false; return false;
} }
out << name << "("; out << name << "(";
bool first = true; bool first = true;
for (auto* arg : args) { for (auto* arg : call->Arguments()) {
if (!first) { if (!first) {
out << ", "; out << ", ";
} }
first = false; first = false;
if (!EmitExpression(out, arg)) { if (!EmitExpression(out, arg->Declaration())) {
return false; return false;
} }
} }
out << ")"; out << ")";
return true;
}
bool GeneratorImpl::EmitTypeConversion(std::ostream& out,
const sem::Call* call,
const sem::TypeConversion* conv) {
if (!EmitType(out, conv->Target(), ast::StorageClass::kNone,
ast::Access::kReadWrite, "")) {
return false;
}
out << "(";
if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
return false;
}
out << ")";
return true;
}
bool GeneratorImpl::EmitTypeConstructor(std::ostream& out,
const sem::Call* call,
const sem::TypeConstructor* ctor) {
auto* type = call->Type();
// If the type constructor is empty then we need to construct with the zero
// value for all components.
if (call->Arguments().empty()) {
return EmitZeroValue(out, type);
}
bool brackets = type->IsAnyOf<sem::Array, sem::Struct>();
// For single-value vector initializers, swizzle the scalar to the right
// vector dimension using .x
const bool is_single_value_vector_init =
type->is_scalar_vector() && call->Arguments().size() == 1 &&
ctor->Parameters()[0]->Type()->is_scalar();
auto it = structure_builders_.find(As<sem::Struct>(type));
if (it != structure_builders_.end()) {
out << it->second << "(";
brackets = false;
} else if (brackets) {
out << "{";
} else {
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
"")) {
return false;
}
out << "(";
}
if (is_single_value_vector_init) {
out << "(";
}
bool first = true;
for (auto* e : call->Arguments()) {
if (!first) {
out << ", ";
}
first = false;
if (!EmitExpression(out, e->Declaration())) {
return false;
}
}
if (is_single_value_vector_init) {
out << ")." << std::string(type->As<sem::Vector>()->Width(), 'x');
}
out << (brackets ? "}" : ")");
return true; return true;
} }
@ -1892,13 +1986,13 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out,
builder_.Sem().Add(zero, builder_.create<sem::Expression>(zero, i32, stmt, builder_.Sem().Add(zero, builder_.create<sem::Expression>(zero, i32, stmt,
sem::Constant{})); sem::Constant{}));
auto* packed = AppendVector(&builder_, vector, zero); auto* packed = AppendVector(&builder_, vector, zero);
return EmitExpression(out, packed); return EmitExpression(out, packed->Declaration());
}; };
auto emit_vector_appended_with_level = [&](const ast::Expression* vector) { auto emit_vector_appended_with_level = [&](const ast::Expression* vector) {
if (auto* level = arg(Usage::kLevel)) { if (auto* level = arg(Usage::kLevel)) {
auto* packed = AppendVector(&builder_, vector, level); auto* packed = AppendVector(&builder_, vector, level);
return EmitExpression(out, packed); return EmitExpression(out, packed->Declaration());
} }
return emit_vector_appended_with_i32_zero(vector); return emit_vector_appended_with_i32_zero(vector);
}; };
@ -1908,11 +2002,11 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out,
auto* packed = AppendVector(&builder_, param_coords, array_index); auto* packed = AppendVector(&builder_, param_coords, array_index);
if (pack_level_in_coords) { if (pack_level_in_coords) {
// Then mip level needs to be appended to the coordinates. // Then mip level needs to be appended to the coordinates.
if (!emit_vector_appended_with_level(packed)) { if (!emit_vector_appended_with_level(packed->Declaration())) {
return false; return false;
} }
} else { } else {
if (!EmitExpression(out, packed)) { if (!EmitExpression(out, packed->Declaration())) {
return false; return false;
} }
} }
@ -2112,63 +2206,6 @@ bool GeneratorImpl::EmitCase(const ast::SwitchStatement* s, size_t case_idx) {
return true; return true;
} }
bool GeneratorImpl::EmitTypeConstructor(
std::ostream& out,
const ast::TypeConstructorExpression* expr) {
auto* type = TypeOf(expr)->UnwrapRef();
// If the type constructor is empty then we need to construct with the zero
// value for all components.
if (expr->values.empty()) {
return EmitZeroValue(out, type);
}
bool brackets = type->IsAnyOf<sem::Array, sem::Struct>();
// For single-value vector initializers, swizzle the scalar to the right
// vector dimension using .x
const bool is_single_value_vector_init =
type->is_scalar_vector() && expr->values.size() == 1 &&
TypeOf(expr->values[0])->UnwrapRef()->is_scalar();
auto it = structure_builders_.find(As<sem::Struct>(type));
if (it != structure_builders_.end()) {
out << it->second << "(";
brackets = false;
} else if (brackets) {
out << "{";
} else {
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
"")) {
return false;
}
out << "(";
}
if (is_single_value_vector_init) {
out << "(";
}
bool first = true;
for (auto* e : expr->values) {
if (!first) {
out << ", ";
}
first = false;
if (!EmitExpression(out, e)) {
return false;
}
}
if (is_single_value_vector_init) {
out << ")." << std::string(type->As<sem::Vector>()->Width(), 'x');
}
out << (brackets ? "}" : ")");
return true;
}
bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) { bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
if (!emit_continuing_()) { if (!emit_continuing_()) {
return false; return false;
@ -2198,9 +2235,6 @@ bool GeneratorImpl::EmitExpression(std::ostream& out,
if (auto* c = expr->As<ast::CallExpression>()) { if (auto* c = expr->As<ast::CallExpression>()) {
return EmitCall(out, c); return EmitCall(out, c);
} }
if (auto* c = expr->As<ast::TypeConstructorExpression>()) {
return EmitTypeConstructor(out, c);
}
if (auto* i = expr->As<ast::IdentifierExpression>()) { if (auto* i = expr->As<ast::IdentifierExpression>()) {
return EmitIdentifier(out, i); return EmitIdentifier(out, i);
} }

View File

@ -44,6 +44,8 @@ namespace tint {
namespace sem { namespace sem {
class Call; class Call;
class Intrinsic; class Intrinsic;
class TypeConstructor;
class TypeConversion;
} // namespace sem } // namespace sem
namespace writer { namespace writer {
@ -116,6 +118,38 @@ class GeneratorImpl : public TextGenerator {
/// @param expr the call expression /// @param expr the call expression
/// @returns true if the call expression is emitted /// @returns true if the call expression is emitted
bool EmitCall(std::ostream& out, const ast::CallExpression* expr); bool EmitCall(std::ostream& out, const ast::CallExpression* expr);
/// Handles generating a function call expression
/// @param out the output of the expression stream
/// @param call the call expression
/// @param function the function being called
/// @returns true if the expression is emitted
bool EmitFunctionCall(std::ostream& out,
const sem::Call* call,
const sem::Function* function);
/// Handles generating an intrinsic call expression
/// @param out the output of the expression stream
/// @param call the call expression
/// @param intrinsic the intrinsic being called
/// @returns true if the expression is emitted
bool EmitIntrinsicCall(std::ostream& out,
const sem::Call* call,
const sem::Intrinsic* intrinsic);
/// Handles generating a type conversion expression
/// @param out the output of the expression stream
/// @param call the call expression
/// @param conv the type conversion
/// @returns true if the expression is emitted
bool EmitTypeConversion(std::ostream& out,
const sem::Call* call,
const sem::TypeConversion* conv);
/// Handles generating a type constructor expression
/// @param out the output of the expression stream
/// @param call the call expression
/// @param ctor the type constructor
/// @returns true if the expression is emitted
bool EmitTypeConstructor(std::ostream& out,
const sem::Call* call,
const sem::TypeConstructor* ctor);
/// Handles generating a call expression to a /// Handles generating a call expression to a
/// transform::DecomposeMemoryAccess::Intrinsic for a uniform buffer /// transform::DecomposeMemoryAccess::Intrinsic for a uniform buffer
/// @param out the output of the expression stream /// @param out the output of the expression stream
@ -221,12 +255,6 @@ class GeneratorImpl : public TextGenerator {
/// @param stmt the discard statement /// @param stmt the discard statement
/// @returns true if the statement was successfully emitted /// @returns true if the statement was successfully emitted
bool EmitDiscard(const ast::DiscardStatement* stmt); bool EmitDiscard(const ast::DiscardStatement* stmt);
/// Handles emitting a type constructor
/// @param out the output of the expression stream
/// @param expr the type constructor expression
/// @returns true if the constructor is emitted
bool EmitTypeConstructor(std::ostream& out,
const ast::TypeConstructorExpression* expr);
/// Handles a continue statement /// Handles a continue statement
/// @param stmt the statement to emit /// @param stmt the statement to emit
/// @returns true if the statement was emitted successfully /// @returns true if the statement was emitted successfully

View File

@ -51,6 +51,8 @@
#include "src/sem/sampled_texture_type.h" #include "src/sem/sampled_texture_type.h"
#include "src/sem/storage_texture_type.h" #include "src/sem/storage_texture_type.h"
#include "src/sem/struct.h" #include "src/sem/struct.h"
#include "src/sem/type_constructor.h"
#include "src/sem/type_conversion.h"
#include "src/sem/u32_type.h" #include "src/sem/u32_type.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/sem/vector_type.h" #include "src/sem/vector_type.h"
@ -242,10 +244,9 @@ bool GeneratorImpl::EmitIndexAccessor(
std::ostream& out, std::ostream& out,
const ast::IndexAccessorExpression* expr) { const ast::IndexAccessorExpression* expr) {
bool paren_lhs = bool paren_lhs =
!expr->object !expr->object->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression, ast::IdentifierExpression,
ast::IdentifierExpression, ast::MemberAccessorExpression, ast::MemberAccessorExpression>();
ast::TypeConstructorExpression>();
if (paren_lhs) { if (paren_lhs) {
out << "("; out << "(";
@ -496,43 +497,53 @@ bool GeneratorImpl::EmitBreak(const ast::BreakStatement*) {
bool GeneratorImpl::EmitCall(std::ostream& out, bool GeneratorImpl::EmitCall(std::ostream& out,
const ast::CallExpression* expr) { const ast::CallExpression* expr) {
auto* ident = expr->func;
auto* call = program_->Sem().Get(expr); auto* call = program_->Sem().Get(expr);
if (auto* intrinsic = call->Target()->As<sem::Intrinsic>()) { auto* target = call->Target();
return EmitIntrinsicCall(out, expr, intrinsic);
if (auto* func = target->As<sem::Function>()) {
return EmitFunctionCall(out, call, func);
}
if (auto* intrinsic = target->As<sem::Intrinsic>()) {
return EmitIntrinsicCall(out, call, intrinsic);
}
if (auto* conv = target->As<sem::TypeConversion>()) {
return EmitTypeConversion(out, call, conv);
}
if (auto* ctor = target->As<sem::TypeConstructor>()) {
return EmitTypeConstructor(out, call, ctor);
} }
auto* func = program_->AST().Functions().Find(ident->symbol); TINT_ICE(Writer, diagnostics_)
if (func == nullptr) { << "unhandled call target: " << target->TypeInfo().name;
diagnostics_.add_error(diag::System::Writer, return false;
"Unable to find function: " + }
program_->Symbols().NameFor(ident->symbol));
return false;
}
bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
const sem::Call* call,
const sem::Function*) {
auto* ident = call->Declaration()->target.name;
out << program_->Symbols().NameFor(ident->symbol) << "("; out << program_->Symbols().NameFor(ident->symbol) << "(";
bool first = true; bool first = true;
const auto& args = expr->args; for (auto* arg : call->Arguments()) {
for (auto* arg : args) {
if (!first) { if (!first) {
out << ", "; out << ", ";
} }
first = false; first = false;
if (!EmitExpression(out, arg)) { if (!EmitExpression(out, arg->Declaration())) {
return false; return false;
} }
} }
out << ")"; out << ")";
return true; return true;
} }
bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out, bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
const ast::CallExpression* expr, const sem::Call* call,
const sem::Intrinsic* intrinsic) { const sem::Intrinsic* intrinsic) {
auto* expr = call->Declaration();
if (intrinsic->IsAtomic()) { if (intrinsic->IsAtomic()) {
return EmitAtomicCall(out, expr, intrinsic); return EmitAtomicCall(out, expr, intrinsic);
} }
@ -634,6 +645,64 @@ bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
return true; return true;
} }
bool GeneratorImpl::EmitTypeConversion(std::ostream& out,
const sem::Call* call,
const sem::TypeConversion* conv) {
if (!EmitType(out, conv->Target(), "")) {
return false;
}
out << "(";
if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
return false;
}
out << ")";
return true;
}
bool GeneratorImpl::EmitTypeConstructor(std::ostream& out,
const sem::Call* call,
const sem::TypeConstructor* ctor) {
auto* type = ctor->ReturnType();
if (type->IsAnyOf<sem::Array, sem::Struct>()) {
out << "{";
} else {
if (!EmitType(out, type, "")) {
return false;
}
out << "(";
}
int i = 0;
for (auto* arg : call->Arguments()) {
if (i > 0) {
out << ", ";
}
if (auto* struct_ty = type->As<sem::Struct>()) {
// Emit field designators for structures to account for padding members.
auto* member = struct_ty->Members()[i]->Declaration();
auto name = program_->Symbols().NameFor(member->symbol);
out << "." << name << "=";
}
if (!EmitExpression(out, arg->Declaration())) {
return false;
}
i++;
}
if (type->IsAnyOf<sem::Array, sem::Struct>()) {
out << "}";
} else {
out << ")";
}
return true;
}
bool GeneratorImpl::EmitAtomicCall(std::ostream& out, bool GeneratorImpl::EmitAtomicCall(std::ostream& out,
const ast::CallExpression* expr, const ast::CallExpression* expr,
const sem::Intrinsic* intrinsic) { const sem::Intrinsic* intrinsic) {
@ -762,10 +831,9 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out,
// accessor used for the function calls. // accessor used for the function calls.
auto texture_expr = [&]() { auto texture_expr = [&]() {
bool paren_lhs = bool paren_lhs =
!texture !texture->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression, ast::IdentifierExpression,
ast::IdentifierExpression, ast::MemberAccessorExpression, ast::MemberAccessorExpression>();
ast::TypeConstructorExpression>();
if (paren_lhs) { if (paren_lhs) {
out << "("; out << "(";
} }
@ -1300,48 +1368,6 @@ bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
return true; return true;
} }
bool GeneratorImpl::EmitTypeConstructor(
std::ostream& out,
const ast::TypeConstructorExpression* expr) {
auto* type = TypeOf(expr)->UnwrapRef();
if (type->IsAnyOf<sem::Array, sem::Struct>()) {
out << "{";
} else {
if (!EmitType(out, type, "")) {
return false;
}
out << "(";
}
int i = 0;
for (auto* e : expr->values) {
if (i > 0) {
out << ", ";
}
if (auto* struct_ty = type->As<sem::Struct>()) {
// Emit field designators for structures to account for padding members.
auto* member = struct_ty->Members()[i]->Declaration();
auto name = program_->Symbols().NameFor(member->symbol);
out << "." << name << "=";
}
if (!EmitExpression(out, e)) {
return false;
}
i++;
}
if (type->IsAnyOf<sem::Array, sem::Struct>()) {
out << "}";
} else {
out << ")";
}
return true;
}
bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) { bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
if (type->Is<sem::Bool>()) { if (type->Is<sem::Bool>()) {
out << "false"; out << "false";
@ -1426,9 +1452,6 @@ bool GeneratorImpl::EmitExpression(std::ostream& out,
if (auto* c = expr->As<ast::CallExpression>()) { if (auto* c = expr->As<ast::CallExpression>()) {
return EmitCall(out, c); return EmitCall(out, c);
} }
if (auto* c = expr->As<ast::TypeConstructorExpression>()) {
return EmitTypeConstructor(out, c);
}
if (auto* i = expr->As<ast::IdentifierExpression>()) { if (auto* i = expr->As<ast::IdentifierExpression>()) {
return EmitIdentifier(out, i); return EmitIdentifier(out, i);
} }
@ -1899,11 +1922,9 @@ bool GeneratorImpl::EmitMemberAccessor(
std::ostream& out, std::ostream& out,
const ast::MemberAccessorExpression* expr) { const ast::MemberAccessorExpression* expr) {
auto write_lhs = [&] { auto write_lhs = [&] {
bool paren_lhs = bool paren_lhs = !expr->structure->IsAnyOf<
!expr->structure ast::IndexAccessorExpression, ast::CallExpression,
->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression, ast::IdentifierExpression, ast::MemberAccessorExpression>();
ast::IdentifierExpression, ast::MemberAccessorExpression,
ast::TypeConstructorExpression>();
if (paren_lhs) { if (paren_lhs) {
out << "("; out << "(";
} }

View File

@ -33,7 +33,6 @@
#include "src/ast/member_accessor_expression.h" #include "src/ast/member_accessor_expression.h"
#include "src/ast/return_statement.h" #include "src/ast/return_statement.h"
#include "src/ast/switch_statement.h" #include "src/ast/switch_statement.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/unary_op_expression.h" #include "src/ast/unary_op_expression.h"
#include "src/program.h" #include "src/program.h"
#include "src/scope_stack.h" #include "src/scope_stack.h"
@ -46,6 +45,8 @@ namespace tint {
namespace sem { namespace sem {
class Call; class Call;
class Intrinsic; class Intrinsic;
class TypeConstructor;
class TypeConversion;
} // namespace sem } // namespace sem
namespace writer { namespace writer {
@ -130,12 +131,36 @@ class GeneratorImpl : public TextGenerator {
bool EmitCall(std::ostream& out, const ast::CallExpression* expr); bool EmitCall(std::ostream& out, const ast::CallExpression* expr);
/// Handles generating an intrinsic call expression /// Handles generating an intrinsic call expression
/// @param out the output of the expression stream /// @param out the output of the expression stream
/// @param expr the call expression /// @param call the call expression
/// @param intrinsic the intrinsic being called /// @param intrinsic the intrinsic being called
/// @returns true if the call expression is emitted /// @returns true if the call expression is emitted
bool EmitIntrinsicCall(std::ostream& out, bool EmitIntrinsicCall(std::ostream& out,
const ast::CallExpression* expr, const sem::Call* call,
const sem::Intrinsic* intrinsic); const sem::Intrinsic* intrinsic);
/// Handles generating a type conversion expression
/// @param out the output of the expression stream
/// @param call the call expression
/// @param conv the type conversion
/// @returns true if the expression is emitted
bool EmitTypeConversion(std::ostream& out,
const sem::Call* call,
const sem::TypeConversion* conv);
/// Handles generating a type constructor
/// @param out the output of the expression stream
/// @param call the call expression
/// @param ctor the type constructor
/// @returns true if the constructor is emitted
bool EmitTypeConstructor(std::ostream& out,
const sem::Call* call,
const sem::TypeConstructor* ctor);
/// Handles generating a function call
/// @param out the output of the expression stream
/// @param call the call expression
/// @param func the target function
/// @returns true if the call is emitted
bool EmitFunctionCall(std::ostream& out,
const sem::Call* call,
const sem::Function* func);
/// Handles generating a call to an atomic function (`atomicAdd`, /// Handles generating a call to an atomic function (`atomicAdd`,
/// `atomicMax`, etc) /// `atomicMax`, etc)
/// @param out the output of the expression stream /// @param out the output of the expression stream
@ -293,12 +318,6 @@ class GeneratorImpl : public TextGenerator {
/// @param str the struct to generate /// @param str the struct to generate
/// @returns true if the struct is emitted /// @returns true if the struct is emitted
bool EmitStructType(TextBuffer* buffer, const sem::Struct* str); bool EmitStructType(TextBuffer* buffer, const sem::Struct* str);
/// Handles emitting a type constructor
/// @param out the output of the expression stream
/// @param expr the type constructor expression
/// @returns true if the constructor is emitted
bool EmitTypeConstructor(std::ostream& out,
const ast::TypeConstructorExpression* expr);
/// Handles a unary op expression /// Handles a unary op expression
/// @param out the output of the expression stream /// @param out the output of the expression stream
/// @param expr the expression to emit /// @param expr the expression to emit

View File

@ -22,6 +22,7 @@
#include "src/ast/fallthrough_statement.h" #include "src/ast/fallthrough_statement.h"
#include "src/ast/internal_decoration.h" #include "src/ast/internal_decoration.h"
#include "src/ast/override_decoration.h" #include "src/ast/override_decoration.h"
#include "src/ast/traverse_expressions.h"
#include "src/sem/array.h" #include "src/sem/array.h"
#include "src/sem/atomic_type.h" #include "src/sem/atomic_type.h"
#include "src/sem/call.h" #include "src/sem/call.h"
@ -33,7 +34,10 @@
#include "src/sem/multisampled_texture_type.h" #include "src/sem/multisampled_texture_type.h"
#include "src/sem/reference_type.h" #include "src/sem/reference_type.h"
#include "src/sem/sampled_texture_type.h" #include "src/sem/sampled_texture_type.h"
#include "src/sem/statement.h"
#include "src/sem/struct.h" #include "src/sem/struct.h"
#include "src/sem/type_constructor.h"
#include "src/sem/type_conversion.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/sem/vector_type.h" #include "src/sem/vector_type.h"
#include "src/transform/add_empty_entry_point.h" #include "src/transform/add_empty_entry_point.h"
@ -577,9 +581,6 @@ uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
if (auto* c = expr->As<ast::CallExpression>()) { if (auto* c = expr->As<ast::CallExpression>()) {
return GenerateCallExpression(c); return GenerateCallExpression(c);
} }
if (auto* c = expr->As<ast::TypeConstructorExpression>()) {
return GenerateConstructorExpression(nullptr, c);
}
if (auto* i = expr->As<ast::IdentifierExpression>()) { if (auto* i = expr->As<ast::IdentifierExpression>()) {
return GenerateIdentifierExpression(i); return GenerateIdentifierExpression(i);
} }
@ -1259,80 +1260,44 @@ uint32_t Builder::GenerateConstructorExpression(const ast::Variable* var,
if (auto* literal = expr->As<ast::LiteralExpression>()) { if (auto* literal = expr->As<ast::LiteralExpression>()) {
return GenerateLiteralIfNeeded(var, literal); return GenerateLiteralIfNeeded(var, literal);
} }
if (auto* type = expr->As<ast::TypeConstructorExpression>()) { if (auto* call = builder_.Sem().Get<sem::Call>(expr)) {
return GenerateTypeConstructorExpression(var, type); if (call->Target()->IsAnyOf<sem::TypeConstructor, sem::TypeConversion>()) {
return GenerateTypeConstructorOrConversion(call, var);
}
} }
error_ = "unknown constructor expression"; error_ = "unknown constructor expression";
return 0; return 0;
} }
bool Builder::is_constructor_const(const ast::Expression* expr, bool Builder::IsConstructorConst(const ast::Expression* expr) {
bool is_global_init) { bool is_const = true;
if (expr->Is<ast::LiteralExpression>()) { ast::TraverseExpressions(expr, builder_.Diagnostics(),
return true; [&](const ast::Expression* e) {
} if (e->Is<ast::LiteralExpression>()) {
return ast::TraverseAction::Descend;
}
if (auto* ce = e->As<ast::CallExpression>()) {
auto* call = builder_.Sem().Get(ce);
if (call->Target()->Is<sem::TypeConstructor>()) {
return ast::TraverseAction::Descend;
}
}
auto* tc = expr->As<ast::TypeConstructorExpression>(); is_const = false;
if (!tc) { return ast::TraverseAction::Stop;
return false; });
} return is_const;
auto* result_type = TypeOf(tc)->UnwrapRef();
for (size_t i = 0; i < tc->values.size(); ++i) {
auto* e = tc->values[i];
if (!e->IsAnyOf<ast::TypeConstructorExpression, ast::LiteralExpression>()) {
if (is_global_init) {
error_ = "constructor must be a constant expression";
return false;
}
return false;
}
if (!is_constructor_const(e, is_global_init)) {
return false;
}
if (has_error()) {
return false;
}
auto* lit = e->As<ast::LiteralExpression>();
if (result_type->Is<sem::Vector>() && lit == nullptr) {
return false;
}
// This should all be handled by |is_constructor_const| call above
if (lit == nullptr) {
continue;
}
const sem::Type* subtype = result_type->UnwrapRef();
if (auto* vec = subtype->As<sem::Vector>()) {
subtype = vec->type();
} else if (auto* mat = subtype->As<sem::Matrix>()) {
subtype = mat->type();
} else if (auto* arr = subtype->As<sem::Array>()) {
subtype = arr->ElemType();
} else if (auto* str = subtype->As<sem::Struct>()) {
subtype = str->Members()[i]->Type();
}
if (subtype != TypeOf(lit)->UnwrapRef()) {
return false;
}
}
return true;
} }
uint32_t Builder::GenerateTypeConstructorExpression( uint32_t Builder::GenerateTypeConstructorOrConversion(
const ast::Variable* var, const sem::Call* call,
const ast::TypeConstructorExpression* init) { const ast::Variable* var) {
auto& args = call->Arguments();
auto* global_var = builder_.Sem().Get<sem::GlobalVariable>(var); auto* global_var = builder_.Sem().Get<sem::GlobalVariable>(var);
auto* result_type = call->Type();
auto& values = init->values;
auto* result_type = TypeOf(init);
// Generate the zero initializer if there are no values provided. // Generate the zero initializer if there are no values provided.
if (values.empty()) { if (args.empty()) {
if (global_var && global_var->IsOverridable()) { if (global_var && global_var->IsOverridable()) {
auto constant_id = global_var->ConstantId(); auto constant_id = global_var->ConstantId();
if (result_type->Is<sem::I32>()) { if (result_type->Is<sem::I32>()) {
@ -1356,10 +1321,10 @@ uint32_t Builder::GenerateTypeConstructorExpression(
} }
std::ostringstream out; std::ostringstream out;
out << "__const_" << init->type->FriendlyName(builder_.Symbols()) << "_"; out << "__const_" << result_type->FriendlyName(builder_.Symbols()) << "_";
result_type = result_type->UnwrapRef(); result_type = result_type->UnwrapRef();
bool constructor_is_const = is_constructor_const(init, global_var); bool constructor_is_const = IsConstructorConst(call->Declaration());
if (has_error()) { if (has_error()) {
return 0; return 0;
} }
@ -1368,7 +1333,7 @@ uint32_t Builder::GenerateTypeConstructorExpression(
if (auto* res_vec = result_type->As<sem::Vector>()) { if (auto* res_vec = result_type->As<sem::Vector>()) {
if (res_vec->type()->is_scalar()) { if (res_vec->type()->is_scalar()) {
auto* value_type = TypeOf(values[0])->UnwrapRef(); auto* value_type = args[0]->Type()->UnwrapRef();
if (auto* val_vec = value_type->As<sem::Vector>()) { if (auto* val_vec = value_type->As<sem::Vector>()) {
if (val_vec->type()->is_scalar()) { if (val_vec->type()->is_scalar()) {
can_cast_or_copy = res_vec->Width() == val_vec->Width(); can_cast_or_copy = res_vec->Width() == val_vec->Width();
@ -1378,7 +1343,8 @@ uint32_t Builder::GenerateTypeConstructorExpression(
} }
if (can_cast_or_copy) { if (can_cast_or_copy) {
return GenerateCastOrCopyOrPassthrough(result_type, values[0], global_var); return GenerateCastOrCopyOrPassthrough(result_type, args[0]->Declaration(),
global_var);
} }
auto type_id = GenerateTypeIfNeeded(result_type); auto type_id = GenerateTypeIfNeeded(result_type);
@ -1394,19 +1360,18 @@ uint32_t Builder::GenerateTypeConstructorExpression(
} }
OperandList ops; OperandList ops;
for (auto* e : values) { for (auto* e : args) {
uint32_t id = 0; uint32_t id = 0;
if (constructor_is_const) { id = GenerateExpression(e->Declaration());
id = GenerateConstructorExpression(nullptr, e); if (id == 0) {
} else { return 0;
id = GenerateExpression(e);
id = GenerateLoadIfNeeded(TypeOf(e), id);
} }
id = GenerateLoadIfNeeded(e->Type(), id);
if (id == 0) { if (id == 0) {
return 0; return 0;
} }
auto* value_type = TypeOf(e)->UnwrapRef(); auto* value_type = e->Type()->UnwrapRef();
// If the result and value types are the same we can just use the object. // If the result and value types are the same we can just use the object.
// If the result is not a vector then we should have validated that the // If the result is not a vector then we should have validated that the
// value type is a correctly sized vector so we can just use it directly. // value type is a correctly sized vector so we can just use it directly.
@ -1421,7 +1386,8 @@ uint32_t Builder::GenerateTypeConstructorExpression(
// Both scalars, but not the same type so we need to generate a conversion // Both scalars, but not the same type so we need to generate a conversion
// of the value. // of the value.
if (value_type->is_scalar() && result_type->is_scalar()) { if (value_type->is_scalar() && result_type->is_scalar()) {
id = GenerateCastOrCopyOrPassthrough(result_type, values[0], global_var); id = GenerateCastOrCopyOrPassthrough(result_type, args[0]->Declaration(),
global_var);
out << "_" << id; out << "_" << id;
ops.push_back(Operand::Int(id)); ops.push_back(Operand::Int(id));
continue; continue;
@ -1483,9 +1449,9 @@ uint32_t Builder::GenerateTypeConstructorExpression(
} }
// For a single-value vector initializer, splat the initializer value. // For a single-value vector initializer, splat the initializer value.
auto* const init_result_type = TypeOf(init)->UnwrapRef(); auto* const init_result_type = call->Type()->UnwrapRef();
if (values.size() == 1 && init_result_type->is_scalar_vector() && if (args.size() == 1 && init_result_type->is_scalar_vector() &&
TypeOf(values[0])->UnwrapRef()->is_scalar()) { args[0]->Type()->UnwrapRef()->is_scalar()) {
size_t vec_size = init_result_type->As<sem::Vector>()->Width(); size_t vec_size = init_result_type->As<sem::Vector>()->Width();
for (size_t i = 0; i < (vec_size - 1); ++i) { for (size_t i = 0; i < (vec_size - 1); ++i) {
ops.push_back(ops[0]); ops.push_back(ops[0]);
@ -2232,14 +2198,29 @@ bool Builder::GenerateBlockStatementWithoutScoping(
} }
uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) { uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
auto* ident = expr->func;
auto* call = builder_.Sem().Get(expr); auto* call = builder_.Sem().Get(expr);
auto* target = call->Target(); auto* target = call->Target();
if (auto* intrinsic = target->As<sem::Intrinsic>()) {
return GenerateIntrinsic(expr, intrinsic);
}
auto type_id = GenerateTypeIfNeeded(target->ReturnType()); if (auto* func = target->As<sem::Function>()) {
return GenerateFunctionCall(call, func);
}
if (auto* intrinsic = target->As<sem::Intrinsic>()) {
return GenerateIntrinsicCall(call, intrinsic);
}
if (target->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) {
return GenerateTypeConstructorOrConversion(call, nullptr);
}
TINT_ICE(Writer, builder_.Diagnostics())
<< "unhandled call target: " << target->TypeInfo().name;
return false;
}
uint32_t Builder::GenerateFunctionCall(const sem::Call* call,
const sem::Function*) {
auto* expr = call->Declaration();
auto* ident = expr->target.name;
auto type_id = GenerateTypeIfNeeded(call->Type());
if (type_id == 0) { if (type_id == 0) {
return 0; return 0;
} }
@ -2278,8 +2259,8 @@ uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
return result_id; return result_id;
} }
uint32_t Builder::GenerateIntrinsic(const ast::CallExpression* call, uint32_t Builder::GenerateIntrinsicCall(const sem::Call* call,
const sem::Intrinsic* intrinsic) { const sem::Intrinsic* intrinsic) {
auto result = result_op(); auto result = result_op();
auto result_id = result.to_i(); auto result_id = result.to_i();
@ -2323,15 +2304,15 @@ uint32_t Builder::GenerateIntrinsic(const ast::CallExpression* call,
// and loads it if necessary. Returns 0 on error. // and loads it if necessary. Returns 0 on error.
auto get_arg_as_value_id = [&](size_t i, auto get_arg_as_value_id = [&](size_t i,
bool generate_load = true) -> uint32_t { bool generate_load = true) -> uint32_t {
auto* arg = call->args[i]; auto* arg = call->Arguments()[i];
auto* param = intrinsic->Parameters()[i]; auto* param = intrinsic->Parameters()[i];
auto val_id = GenerateExpression(arg); auto val_id = GenerateExpression(arg->Declaration());
if (val_id == 0) { if (val_id == 0) {
return 0; return 0;
} }
if (generate_load && !param->Type()->Is<sem::Pointer>()) { if (generate_load && !param->Type()->Is<sem::Pointer>()) {
val_id = GenerateLoadIfNeeded(TypeOf(arg), val_id); val_id = GenerateLoadIfNeeded(arg->Type(), val_id);
} }
return val_id; return val_id;
}; };
@ -2364,13 +2345,8 @@ uint32_t Builder::GenerateIntrinsic(const ast::CallExpression* call,
op = spv::Op::OpAll; op = spv::Op::OpAll;
break; break;
case IntrinsicType::kArrayLength: { case IntrinsicType::kArrayLength: {
if (call->args.empty()) { auto* address_of =
error_ = "missing param for runtime array length"; call->Arguments()[0]->Declaration()->As<ast::UnaryOpExpression>();
return 0;
}
auto* arg = call->args[0];
auto* address_of = arg->As<ast::UnaryOpExpression>();
if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) { if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
error_ = "arrayLength() expected pointer to member access, got " + error_ = "arrayLength() expected pointer to member access, got " +
std::string(address_of->TypeInfo().name); std::string(address_of->TypeInfo().name);
@ -2695,7 +2671,7 @@ uint32_t Builder::GenerateIntrinsic(const ast::CallExpression* call,
return 0; return 0;
} }
for (size_t i = 0; i < call->args.size(); i++) { for (size_t i = 0; i < call->Arguments().size(); i++) {
if (auto val_id = get_arg_as_value_id(i)) { if (auto val_id = get_arg_as_value_id(i)) {
params.emplace_back(Operand::Int(val_id)); params.emplace_back(Operand::Int(val_id));
} else { } else {
@ -2710,22 +2686,22 @@ uint32_t Builder::GenerateIntrinsic(const ast::CallExpression* call,
return result_id; return result_id;
} }
bool Builder::GenerateTextureIntrinsic(const ast::CallExpression* call, bool Builder::GenerateTextureIntrinsic(const sem::Call* call,
const sem::Intrinsic* intrinsic, const sem::Intrinsic* intrinsic,
Operand result_type, Operand result_type,
Operand result_id) { Operand result_id) {
using Usage = sem::ParameterUsage; using Usage = sem::ParameterUsage;
auto& signature = intrinsic->Signature(); auto& signature = intrinsic->Signature();
auto arguments = call->args; auto& arguments = call->Arguments();
// Generates the given expression, returning the operand ID // Generates the given expression, returning the operand ID
auto gen = [&](const ast::Expression* expr) { auto gen = [&](const sem::Expression* expr) {
auto val_id = GenerateExpression(expr); auto val_id = GenerateExpression(expr->Declaration());
if (val_id == 0) { if (val_id == 0) {
return Operand::Int(0); return Operand::Int(0);
} }
val_id = GenerateLoadIfNeeded(TypeOf(expr), val_id); val_id = GenerateLoadIfNeeded(expr->Type(), val_id);
return Operand::Int(val_id); return Operand::Int(val_id);
}; };
@ -2751,7 +2727,7 @@ bool Builder::GenerateTextureIntrinsic(const ast::CallExpression* call,
TINT_ICE(Writer, builder_.Diagnostics()) << "missing texture argument"; TINT_ICE(Writer, builder_.Diagnostics()) << "missing texture argument";
} }
auto* texture_type = TypeOf(texture)->UnwrapRef()->As<sem::Texture>(); auto* texture_type = texture->Type()->UnwrapRef()->As<sem::Texture>();
auto op = spv::Op::OpNop; auto op = spv::Op::OpNop;
@ -2819,7 +2795,7 @@ bool Builder::GenerateTextureIntrinsic(const ast::CallExpression* call,
} else { } else {
// Assign post_emission to swizzle the result of the call to // Assign post_emission to swizzle the result of the call to
// OpImageQuerySize[Lod]. // OpImageQuerySize[Lod].
auto* element_type = ElementTypeOf(TypeOf(call)); auto* element_type = ElementTypeOf(call->Type());
auto spirv_result = result_op(); auto spirv_result = result_op();
auto* spirv_result_type = auto* spirv_result_type =
builder_.create<sem::Vector>(element_type, spirv_result_width); builder_.create<sem::Vector>(element_type, spirv_result_width);
@ -2856,8 +2832,9 @@ bool Builder::GenerateTextureIntrinsic(const ast::CallExpression* call,
auto append_coords_to_spirv_params = [&]() -> bool { auto append_coords_to_spirv_params = [&]() -> bool {
if (auto* array_index = arg(Usage::kArrayIndex)) { if (auto* array_index = arg(Usage::kArrayIndex)) {
// Array index needs to be appended to the coordinates. // Array index needs to be appended to the coordinates.
auto* packed = AppendVector(&builder_, arg(Usage::kCoords), array_index); auto* packed = AppendVector(&builder_, arg(Usage::kCoords)->Declaration(),
auto param = GenerateTypeConstructorExpression(nullptr, packed); array_index->Declaration());
auto param = GenerateExpression(packed->Declaration());
if (param == 0) { if (param == 0) {
return false; return false;
} }
@ -3026,7 +3003,7 @@ bool Builder::GenerateTextureIntrinsic(const ast::CallExpression* call,
return false; return false;
} }
auto level = Operand::Int(0); auto level = Operand::Int(0);
if (TypeOf(arg(Usage::kLevel))->Is<sem::I32>()) { if (arg(Usage::kLevel)->Type()->UnwrapRef()->Is<sem::I32>()) {
// Depth textures have i32 parameters for the level, but SPIR-V expects // Depth textures have i32 parameters for the level, but SPIR-V expects
// F32. Cast. // F32. Cast.
auto f32_type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>()); auto f32_type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
@ -3156,7 +3133,7 @@ bool Builder::GenerateControlBarrierIntrinsic(const sem::Intrinsic* intrinsic) {
}); });
} }
bool Builder::GenerateAtomicIntrinsic(const ast::CallExpression* call, bool Builder::GenerateAtomicIntrinsic(const sem::Call* call,
const sem::Intrinsic* intrinsic, const sem::Intrinsic* intrinsic,
Operand result_type, Operand result_type,
Operand result_id) { Operand result_id) {
@ -3193,18 +3170,18 @@ bool Builder::GenerateAtomicIntrinsic(const ast::CallExpression* call,
return false; return false;
} }
uint32_t pointer_id = GenerateExpression(call->args[0]); uint32_t pointer_id = GenerateExpression(call->Arguments()[0]->Declaration());
if (pointer_id == 0) { if (pointer_id == 0) {
return false; return false;
} }
uint32_t value_id = 0; uint32_t value_id = 0;
if (call->args.size() > 1) { if (call->Arguments().size() > 1) {
value_id = GenerateExpression(call->args.back()); value_id = GenerateExpression(call->Arguments().back()->Declaration());
if (value_id == 0) { if (value_id == 0) {
return false; return false;
} }
value_id = GenerateLoadIfNeeded(TypeOf(call->args.back()), value_id); value_id = GenerateLoadIfNeeded(call->Arguments().back()->Type(), value_id);
if (value_id == 0) { if (value_id == 0) {
return false; return false;
} }
@ -3308,12 +3285,12 @@ bool Builder::GenerateAtomicIntrinsic(const ast::CallExpression* call,
value, value,
}); });
case sem::IntrinsicType::kAtomicCompareExchangeWeak: { case sem::IntrinsicType::kAtomicCompareExchangeWeak: {
auto comparator = GenerateExpression(call->args[1]); auto comparator = GenerateExpression(call->Arguments()[1]->Declaration());
if (comparator == 0) { if (comparator == 0) {
return false; return false;
} }
auto* value_sem_type = TypeOf(call->args[2]); auto* value_sem_type = TypeOf(call->Arguments()[2]->Declaration());
auto value_type = GenerateTypeIfNeeded(value_sem_type); auto value_type = GenerateTypeIfNeeded(value_sem_type);
if (value_type == 0) { if (value_type == 0) {

View File

@ -46,6 +46,8 @@ namespace tint {
namespace sem { namespace sem {
class Call; class Call;
class Reference; class Reference;
class TypeConstructor;
class TypeConversion;
} // namespace sem } // namespace sem
namespace writer { namespace writer {
@ -341,13 +343,6 @@ class Builder {
/// @returns the ID of the expression or 0 on failure. /// @returns the ID of the expression or 0 on failure.
uint32_t GenerateConstructorExpression(const ast::Variable* var, uint32_t GenerateConstructorExpression(const ast::Variable* var,
const ast::Expression* expr); const ast::Expression* expr);
/// Generates a type constructor expression
/// @param var the variable generated for, nullptr if no variable associated.
/// @param init the expression to generate
/// @returns the ID of the expression or 0 on failure.
uint32_t GenerateTypeConstructorExpression(
const ast::Variable* var,
const ast::TypeConstructorExpression* init);
/// Generates a literal constant if needed /// Generates a literal constant if needed
/// @param var the variable generated for, nullptr if no variable associated. /// @param var the variable generated for, nullptr if no variable associated.
/// @param lit the literal to generate /// @param lit the literal to generate
@ -371,12 +366,24 @@ class Builder {
/// @param expr the expression to generate /// @param expr the expression to generate
/// @returns the expression ID on success or 0 otherwise /// @returns the expression ID on success or 0 otherwise
uint32_t GenerateCallExpression(const ast::CallExpression* expr); uint32_t GenerateCallExpression(const ast::CallExpression* expr);
/// Generates an intrinsic call /// Handles generating a function call expression
/// @param call the call expression /// @param call the call expression
/// @param intrinsic the semantic information for the intrinsic /// @param function the function being called
/// @returns the expression ID on success or 0 otherwise /// @returns the expression ID on success or 0 otherwise
uint32_t GenerateIntrinsic(const ast::CallExpression* call, uint32_t GenerateFunctionCall(const sem::Call* call,
const sem::Intrinsic* intrinsic); const sem::Function* function);
/// Handles generating an intrinsic call expression
/// @param call the call expression
/// @param intrinsic the intrinsic being called
/// @returns the expression ID on success or 0 otherwise
uint32_t GenerateIntrinsicCall(const sem::Call* call,
const sem::Intrinsic* intrinsic);
/// Handles generating a type constructor or type conversion expression
/// @param call the call expression
/// @param var the variable that is being initialized. May be null.
/// @returns the expression ID on success or 0 otherwise
uint32_t GenerateTypeConstructorOrConversion(const sem::Call* call,
const ast::Variable* var);
/// Generates a texture intrinsic call. Emits an error and returns false if /// Generates a texture intrinsic call. Emits an error and returns false if
/// we're currently outside a function. /// we're currently outside a function.
/// @param call the call expression /// @param call the call expression
@ -385,7 +392,7 @@ class Builder {
/// @param result_id result identifier operand of the texture instruction /// @param result_id result identifier operand of the texture instruction
/// parameters /// parameters
/// @returns true on success /// @returns true on success
bool GenerateTextureIntrinsic(const ast::CallExpression* call, bool GenerateTextureIntrinsic(const sem::Call* call,
const sem::Intrinsic* intrinsic, const sem::Intrinsic* intrinsic,
spirv::Operand result_type, spirv::Operand result_type,
spirv::Operand result_id); spirv::Operand result_id);
@ -399,7 +406,7 @@ class Builder {
/// @param result_type result type operand of the texture instruction /// @param result_type result type operand of the texture instruction
/// @param result_id result identifier operand of the texture instruction /// @param result_id result identifier operand of the texture instruction
/// @returns true on success /// @returns true on success
bool GenerateAtomicIntrinsic(const ast::CallExpression* call, bool GenerateAtomicIntrinsic(const sem::Call* call,
const sem::Intrinsic* intrinsic, const sem::Intrinsic* intrinsic,
Operand result_type, Operand result_type,
Operand result_id); Operand result_id);
@ -536,9 +543,8 @@ class Builder {
/// Determines if the given type constructor is created from constant values /// Determines if the given type constructor is created from constant values
/// @param expr the expression to check /// @param expr the expression to check
/// @param is_global_init if this is a global initializer
/// @returns true if the constructor is constant /// @returns true if the constructor is constant
bool is_constructor_const(const ast::Expression* expr, bool is_global_init); bool IsConstructorConst(const ast::Expression* expr);
private: private:
/// @returns an Operand with a new result ID in it. Increments the next_id_ /// @returns an Operand with a new result ID in it. Increments the next_id_

View File

@ -165,20 +165,6 @@ TEST_F(SpvBuilderConstructorTest, Vector_Bitcast_Params) {
)"); )");
} }
TEST_F(SpvBuilderConstructorTest, Type_NonConst_Value_Fails) {
auto* rel = create<ast::BinaryExpression>(ast::BinaryOp::kAdd, Expr(3.0f),
Expr(3.0f));
auto* t = vec2<f32>(1.0f, rel);
auto* g = Global("g", ty.vec2<f32>(), t, ast::StorageClass::kPrivate);
spirv::Builder& b = Build();
EXPECT_EQ(b.GenerateConstructorExpression(g, t), 0u);
EXPECT_TRUE(b.has_error());
EXPECT_EQ(b.error(), R"(constructor must be a constant expression)");
}
TEST_F(SpvBuilderConstructorTest, Type_Bool_With_Bool) { TEST_F(SpvBuilderConstructorTest, Type_Bool_With_Bool) {
auto* cast = Construct<bool>(true); auto* cast = Construct<bool>(true);
WrapInFunction(cast); WrapInFunction(cast);
@ -668,6 +654,36 @@ TEST_F(SpvBuilderConstructorTest, Type_Vec4_With_Vec4) {
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"()"); EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"()");
} }
TEST_F(SpvBuilderConstructorTest, Type_ModuleScope_F32_With_F32) {
auto* ctor = Construct<f32>(2.0f);
GlobalConst("g", ty.f32(), ctor);
spirv::Builder& b = SanitizeAndBuild();
ASSERT_TRUE(b.Build());
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
%2 = OpConstant %1 2
%4 = OpTypeVoid
%3 = OpTypeFunction %4
)");
Validate(b);
}
TEST_F(SpvBuilderConstructorTest, Type_ModuleScope_U32_With_F32) {
auto* ctor = Construct<u32>(1.5f);
GlobalConst("g", ty.u32(), ctor);
spirv::Builder& b = SanitizeAndBuild();
ASSERT_TRUE(b.Build());
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
%2 = OpConstant %1 1
%4 = OpTypeVoid
%3 = OpTypeFunction %4
)");
Validate(b);
}
TEST_F(SpvBuilderConstructorTest, Type_ModuleScope_Vec2_With_F32) { TEST_F(SpvBuilderConstructorTest, Type_ModuleScope_Vec2_With_F32) {
auto* cast = vec2<f32>(2.0f); auto* cast = vec2<f32>(2.0f);
auto* g = Global("g", ty.vec2<f32>(), cast, ast::StorageClass::kPrivate); auto* g = Global("g", ty.vec2<f32>(), cast, ast::StorageClass::kPrivate);
@ -1689,27 +1705,10 @@ TEST_F(SpvBuilderConstructorTest,
spirv::Builder& b = Build(); spirv::Builder& b = Build();
EXPECT_TRUE(b.is_constructor_const(t, true)); EXPECT_TRUE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error()); EXPECT_FALSE(b.has_error());
} }
TEST_F(SpvBuilderConstructorTest, IsConstructorConst_GlobalVector_WithIdent) {
// vec3<f32>(a, b, c) -> false -- ERROR
Global("a", ty.f32(), ast::StorageClass::kPrivate);
Global("b", ty.f32(), ast::StorageClass::kPrivate);
Global("c", ty.f32(), ast::StorageClass::kPrivate);
auto* t = vec3<f32>("a", "b", "c");
WrapInFunction(t);
spirv::Builder& b = Build();
EXPECT_FALSE(b.is_constructor_const(t, true));
EXPECT_TRUE(b.has_error());
EXPECT_EQ(b.error(), "constructor must be a constant expression");
}
TEST_F(SpvBuilderConstructorTest, TEST_F(SpvBuilderConstructorTest,
IsConstructorConst_GlobalArrayWithAllConstConstructors) { IsConstructorConst_GlobalArrayWithAllConstConstructors) {
// array<vec3<f32>, 2>(vec3<f32>(1.0, 2.0, 3.0), vec3<f32>(1.0, 2.0, 3.0)) // array<vec3<f32>, 2>(vec3<f32>(1.0, 2.0, 3.0), vec3<f32>(1.0, 2.0, 3.0))
@ -1720,7 +1719,7 @@ TEST_F(SpvBuilderConstructorTest,
spirv::Builder& b = Build(); spirv::Builder& b = Build();
EXPECT_TRUE(b.is_constructor_const(t, true)); EXPECT_TRUE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error()); EXPECT_FALSE(b.has_error());
} }
@ -1733,12 +1732,12 @@ TEST_F(SpvBuilderConstructorTest,
spirv::Builder& b = Build(); spirv::Builder& b = Build();
EXPECT_FALSE(b.is_constructor_const(t, true)); EXPECT_TRUE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error()); EXPECT_FALSE(b.has_error());
} }
TEST_F(SpvBuilderConstructorTest, TEST_F(SpvBuilderConstructorTest,
IsConstructorConst_GlobalWithTypeCastConstructor) { IsConstructorConst_GlobalWithTypeConversionConstructor) {
// vec2<f32>(f32(1), f32(2)) -> false // vec2<f32>(f32(1), f32(2)) -> false
auto* t = vec2<f32>(Construct<f32>(1), Construct<f32>(2)); auto* t = vec2<f32>(Construct<f32>(1), Construct<f32>(2));
@ -1746,7 +1745,7 @@ TEST_F(SpvBuilderConstructorTest,
spirv::Builder& b = Build(); spirv::Builder& b = Build();
EXPECT_FALSE(b.is_constructor_const(t, true)); EXPECT_FALSE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error()); EXPECT_FALSE(b.has_error());
} }
@ -1759,7 +1758,7 @@ TEST_F(SpvBuilderConstructorTest,
spirv::Builder& b = Build(); spirv::Builder& b = Build();
EXPECT_TRUE(b.is_constructor_const(t, false)); EXPECT_TRUE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error()); EXPECT_FALSE(b.has_error());
} }
@ -1775,7 +1774,7 @@ TEST_F(SpvBuilderConstructorTest, IsConstructorConst_Vector_WithIdent) {
spirv::Builder& b = Build(); spirv::Builder& b = Build();
EXPECT_FALSE(b.is_constructor_const(t, false)); EXPECT_FALSE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error()); EXPECT_FALSE(b.has_error());
} }
@ -1792,12 +1791,12 @@ TEST_F(SpvBuilderConstructorTest,
spirv::Builder& b = Build(); spirv::Builder& b = Build();
EXPECT_TRUE(b.is_constructor_const(t, false)); EXPECT_TRUE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error()); EXPECT_FALSE(b.has_error());
} }
TEST_F(SpvBuilderConstructorTest, TEST_F(SpvBuilderConstructorTest,
IsConstructorConst_VectorWithTypeCastConstConstructors) { IsConstructorConst_VectorWithTypeConversionConstConstructors) {
// vec2<f32>(f32(1), f32(2)) -> false // vec2<f32>(f32(1), f32(2)) -> false
auto* t = vec2<f32>(Construct<f32>(1), Construct<f32>(2)); auto* t = vec2<f32>(Construct<f32>(1), Construct<f32>(2));
@ -1805,7 +1804,7 @@ TEST_F(SpvBuilderConstructorTest,
spirv::Builder& b = Build(); spirv::Builder& b = Build();
EXPECT_FALSE(b.is_constructor_const(t, false)); EXPECT_FALSE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error()); EXPECT_FALSE(b.has_error());
} }
@ -1815,7 +1814,7 @@ TEST_F(SpvBuilderConstructorTest, IsConstructorConst_BitCastScalars) {
spirv::Builder& b = Build(); spirv::Builder& b = Build();
EXPECT_FALSE(b.is_constructor_const(t, false)); EXPECT_FALSE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error()); EXPECT_FALSE(b.has_error());
} }
@ -1830,7 +1829,7 @@ TEST_F(SpvBuilderConstructorTest, IsConstructorConst_Struct) {
spirv::Builder& b = Build(); spirv::Builder& b = Build();
EXPECT_TRUE(b.is_constructor_const(t, false)); EXPECT_TRUE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error()); EXPECT_FALSE(b.has_error());
} }
@ -1849,7 +1848,7 @@ TEST_F(SpvBuilderConstructorTest,
spirv::Builder& b = Build(); spirv::Builder& b = Build();
EXPECT_FALSE(b.is_constructor_const(t, false)); EXPECT_FALSE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error()); EXPECT_FALSE(b.has_error());
} }

View File

@ -134,9 +134,6 @@ bool GeneratorImpl::EmitExpression(std::ostream& out,
if (auto* l = expr->As<ast::LiteralExpression>()) { if (auto* l = expr->As<ast::LiteralExpression>()) {
return EmitLiteral(out, l); return EmitLiteral(out, l);
} }
if (auto* c = expr->As<ast::TypeConstructorExpression>()) {
return EmitTypeConstructor(out, c);
}
if (auto* m = expr->As<ast::MemberAccessorExpression>()) { if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
return EmitMemberAccessor(out, m); return EmitMemberAccessor(out, m);
} }
@ -156,10 +153,9 @@ bool GeneratorImpl::EmitIndexAccessor(
std::ostream& out, std::ostream& out,
const ast::IndexAccessorExpression* expr) { const ast::IndexAccessorExpression* expr) {
bool paren_lhs = bool paren_lhs =
!expr->object !expr->object->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression, ast::IdentifierExpression,
ast::IdentifierExpression, ast::MemberAccessorExpression, ast::MemberAccessorExpression>();
ast::TypeConstructorExpression>();
if (paren_lhs) { if (paren_lhs) {
out << "("; out << "(";
} }
@ -183,10 +179,9 @@ bool GeneratorImpl::EmitMemberAccessor(
std::ostream& out, std::ostream& out,
const ast::MemberAccessorExpression* expr) { const ast::MemberAccessorExpression* expr) {
bool paren_lhs = bool paren_lhs =
!expr->structure !expr->structure->IsAnyOf<ast::IndexAccessorExpression,
->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression, ast::CallExpression, ast::IdentifierExpression,
ast::IdentifierExpression, ast::MemberAccessorExpression, ast::MemberAccessorExpression>();
ast::TypeConstructorExpression>();
if (paren_lhs) { if (paren_lhs) {
out << "("; out << "(";
} }
@ -220,7 +215,17 @@ bool GeneratorImpl::EmitBitcast(std::ostream& out,
bool GeneratorImpl::EmitCall(std::ostream& out, bool GeneratorImpl::EmitCall(std::ostream& out,
const ast::CallExpression* expr) { const ast::CallExpression* expr) {
if (!EmitExpression(out, expr->func)) { if (expr->target.name) {
if (!EmitExpression(out, expr->target.name)) {
return false;
}
} else if (expr->target.type) {
if (!EmitType(out, expr->target.type)) {
return false;
}
} else {
TINT_ICE(Writer, diagnostics_)
<< "CallExpression target had neither a name or type";
return false; return false;
} }
out << "("; out << "(";
@ -243,31 +248,6 @@ bool GeneratorImpl::EmitCall(std::ostream& out,
return true; return true;
} }
bool GeneratorImpl::EmitTypeConstructor(
std::ostream& out,
const ast::TypeConstructorExpression* expr) {
if (!EmitType(out, expr->type)) {
return false;
}
out << "(";
bool first = true;
for (auto* e : expr->values) {
if (!first) {
out << ", ";
}
first = false;
if (!EmitExpression(out, e)) {
return false;
}
}
out << ")";
return true;
}
bool GeneratorImpl::EmitLiteral(std::ostream& out, bool GeneratorImpl::EmitLiteral(std::ostream& out,
const ast::LiteralExpression* lit) { const ast::LiteralExpression* lit) {
if (auto* bl = lit->As<ast::BoolLiteralExpression>()) { if (auto* bl = lit->As<ast::BoolLiteralExpression>()) {

View File

@ -31,7 +31,6 @@
#include "src/ast/member_accessor_expression.h" #include "src/ast/member_accessor_expression.h"
#include "src/ast/return_statement.h" #include "src/ast/return_statement.h"
#include "src/ast/switch_statement.h" #include "src/ast/switch_statement.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/unary_op_expression.h" #include "src/ast/unary_op_expression.h"
#include "src/program.h" #include "src/program.h"
#include "src/sem/storage_texture_type.h" #include "src/sem/storage_texture_type.h"
@ -183,12 +182,6 @@ class GeneratorImpl : public TextGenerator {
/// @param access the access to generate /// @param access the access to generate
/// @returns true if the access is emitted /// @returns true if the access is emitted
bool EmitAccess(std::ostream& out, const ast::Access access); bool EmitAccess(std::ostream& out, const ast::Access access);
/// Handles emitting a type constructor
/// @param out the output of the expression stream
/// @param expr the type constructor expression
/// @returns true if the constructor is emitted
bool EmitTypeConstructor(std::ostream& out,
const ast::TypeConstructorExpression* expr);
/// Handles a unary op expression /// Handles a unary op expression
/// @param out the output of the expression stream /// @param out the output of the expression stream
/// @param expr the expression to emit /// @param expr the expression to emit

View File

@ -203,7 +203,6 @@ tint_unittests_source_set("tint_unittests_ast_src") {
"../src/ast/test_helper.h", "../src/ast/test_helper.h",
"../src/ast/texture_test.cc", "../src/ast/texture_test.cc",
"../src/ast/traverse_expressions_test.cc", "../src/ast/traverse_expressions_test.cc",
"../src/ast/type_constructor_expression_test.cc",
"../src/ast/u32_test.cc", "../src/ast/u32_test.cc",
"../src/ast/uint_literal_expression_test.cc", "../src/ast/uint_literal_expression_test.cc",
"../src/ast/unary_op_expression_test.cc", "../src/ast/unary_op_expression_test.cc",