diff --git a/BUILD.gn b/BUILD.gn index a1ed133638..ea3b9799a1 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -243,8 +243,6 @@ source_set("libtint_core_src") { "src/ast/call_statement.h", "src/ast/case_statement.cc", "src/ast/case_statement.h", - "src/ast/cast_expression.cc", - "src/ast/cast_expression.h", "src/ast/constructor_expression.cc", "src/ast/constructor_expression.h", "src/ast/continue_statement.cc", @@ -697,7 +695,6 @@ source_set("tint_unittests_core_src") { "src/ast/call_expression_test.cc", "src/ast/call_statement_test.cc", "src/ast/case_statement_test.cc", - "src/ast/cast_expression_test.cc", "src/ast/continue_statement_test.cc", "src/ast/decorated_variable_test.cc", "src/ast/discard_statement_test.cc", @@ -831,7 +828,6 @@ source_set("tint_unittests_spv_writer_src") { "src/writer/spirv/builder_bitcast_expression_test.cc", "src/writer/spirv/builder_block_test.cc", "src/writer/spirv/builder_call_test.cc", - "src/writer/spirv/builder_cast_expression_test.cc", "src/writer/spirv/builder_constructor_expression_test.cc", "src/writer/spirv/builder_discard_test.cc", "src/writer/spirv/builder_format_conversion_test.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7c16a03ca9..b2d4885eab 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -64,8 +64,6 @@ set(TINT_LIB_SRCS ast/call_statement.h ast/case_statement.cc ast/case_statement.h - ast/cast_expression.cc - ast/cast_expression.h ast/constructor_expression.cc ast/constructor_expression.h ast/continue_statement.cc @@ -306,7 +304,6 @@ set(TINT_TEST_SRCS ast/call_expression_test.cc ast/call_statement_test.cc ast/case_statement_test.cc - ast/cast_expression_test.cc ast/continue_statement_test.cc ast/discard_statement_test.cc ast/decorated_variable_test.cc @@ -489,7 +486,6 @@ if(${TINT_BUILD_SPV_WRITER}) writer/spirv/builder_bitcast_expression_test.cc writer/spirv/builder_block_test.cc writer/spirv/builder_call_test.cc - writer/spirv/builder_cast_expression_test.cc writer/spirv/builder_constructor_expression_test.cc writer/spirv/builder_discard_test.cc writer/spirv/builder_format_conversion_test.cc diff --git a/src/ast/cast_expression.cc b/src/ast/cast_expression.cc deleted file mode 100644 index 327c37d0e7..0000000000 --- a/src/ast/cast_expression.cc +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2020 The Tint Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "src/ast/cast_expression.h" - -namespace tint { -namespace ast { - -CastExpression::CastExpression() : Expression() {} - -CastExpression::CastExpression(type::Type* type, - std::unique_ptr expr) - : Expression(), type_(type), expr_(std::move(expr)) {} - -CastExpression::CastExpression(const Source& source, - type::Type* type, - std::unique_ptr expr) - : Expression(source), type_(type), expr_(std::move(expr)) {} - -CastExpression::CastExpression(CastExpression&&) = default; - -CastExpression::~CastExpression() = default; - -bool CastExpression::IsCast() const { - return true; -} - -bool CastExpression::IsValid() const { - if (expr_ == nullptr || !expr_->IsValid()) - return false; - return type_ != nullptr; -} - -void CastExpression::to_str(std::ostream& out, size_t indent) const { - make_indent(out, indent); - out << "Cast<" << type_->type_name() << ">(" << std::endl; - expr_->to_str(out, indent + 2); - make_indent(out, indent); - out << ")" << std::endl; -} - -} // namespace ast -} // namespace tint diff --git a/src/ast/cast_expression.h b/src/ast/cast_expression.h deleted file mode 100644 index b1db7fed04..0000000000 --- a/src/ast/cast_expression.h +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2020 The Tint Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef SRC_AST_CAST_EXPRESSION_H_ -#define SRC_AST_CAST_EXPRESSION_H_ - -#include -#include - -#include "src/ast/expression.h" -#include "src/ast/literal.h" -#include "src/ast/type/type.h" - -namespace tint { -namespace ast { - -/// A cast expression -class CastExpression : public Expression { - public: - /// Constructor - CastExpression(); - /// Constructor - /// @param type the type - /// @param expr the expr - CastExpression(type::Type* type, std::unique_ptr expr); - /// Constructor - /// @param source the cast expression source - /// @param type the type - /// @param expr the expr - CastExpression(const Source& source, - type::Type* type, - std::unique_ptr expr); - /// Move constructor - CastExpression(CastExpression&&); - ~CastExpression() override; - - /// Sets the type - /// @param type the type - void set_type(type::Type* type) { type_ = std::move(type); } - /// @returns the left side expression - type::Type* type() const { return type_; } - - /// Sets the expr - /// @param expr the expression - void set_expr(std::unique_ptr expr) { expr_ = std::move(expr); } - /// @returns the expression - Expression* expr() const { return expr_.get(); } - - /// @returns true if this is a cast expression - bool IsCast() const override; - - /// @returns true if the node is valid - bool IsValid() const override; - - /// Writes a representation of the node to the output stream - /// @param out the stream to write to - /// @param indent number of spaces to indent the node when writing - void to_str(std::ostream& out, size_t indent) const override; - - private: - CastExpression(const CastExpression&) = delete; - - type::Type* type_ = nullptr; - std::unique_ptr expr_; -}; - -} // namespace ast -} // namespace tint - -#endif // SRC_AST_CAST_EXPRESSION_H_ diff --git a/src/ast/cast_expression_test.cc b/src/ast/cast_expression_test.cc deleted file mode 100644 index 4fe7566cfa..0000000000 --- a/src/ast/cast_expression_test.cc +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2020 The Tint Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "src/ast/cast_expression.h" - -#include "gtest/gtest.h" -#include "src/ast/identifier_expression.h" -#include "src/ast/type/f32_type.h" - -namespace tint { -namespace ast { -namespace { - -using CastExpressionTest = testing::Test; - -TEST_F(CastExpressionTest, Creation) { - type::F32Type f32; - auto expr = std::make_unique("expr"); - auto* expr_ptr = expr.get(); - - CastExpression c(&f32, std::move(expr)); - EXPECT_EQ(c.type(), &f32); - EXPECT_EQ(c.expr(), expr_ptr); -} - -TEST_F(CastExpressionTest, Creation_withSource) { - type::F32Type f32; - auto expr = std::make_unique("expr"); - - CastExpression c(Source{20, 2}, &f32, std::move(expr)); - auto src = c.source(); - EXPECT_EQ(src.line, 20u); - EXPECT_EQ(src.column, 2u); -} - -TEST_F(CastExpressionTest, IsCast) { - CastExpression c; - EXPECT_TRUE(c.IsCast()); -} - -TEST_F(CastExpressionTest, IsValid) { - type::F32Type f32; - auto expr = std::make_unique("expr"); - - CastExpression c(&f32, std::move(expr)); - EXPECT_TRUE(c.IsValid()); -} - -TEST_F(CastExpressionTest, IsValid_MissingType) { - auto expr = std::make_unique("expr"); - - CastExpression c; - c.set_expr(std::move(expr)); - EXPECT_FALSE(c.IsValid()); -} - -TEST_F(CastExpressionTest, IsValid_MissingExpression) { - type::F32Type f32; - - CastExpression c; - c.set_type(&f32); - EXPECT_FALSE(c.IsValid()); -} - -TEST_F(CastExpressionTest, IsValid_InvalidExpression) { - type::F32Type f32; - auto expr = std::make_unique(""); - CastExpression c(&f32, std::move(expr)); - EXPECT_FALSE(c.IsValid()); -} - -TEST_F(CastExpressionTest, ToStr) { - type::F32Type f32; - auto expr = std::make_unique("expr"); - - CastExpression c(Source{20, 2}, &f32, std::move(expr)); - std::ostringstream out; - c.to_str(out, 2); - EXPECT_EQ(out.str(), R"( Cast<__f32>( - Identifier{expr} - ) -)"); -} - -} // namespace -} // namespace ast -} // namespace tint diff --git a/src/ast/expression.cc b/src/ast/expression.cc index 2e1ad978a2..f58c138a8b 100644 --- a/src/ast/expression.cc +++ b/src/ast/expression.cc @@ -20,7 +20,6 @@ #include "src/ast/binary_expression.h" #include "src/ast/bitcast_expression.h" #include "src/ast/call_expression.h" -#include "src/ast/cast_expression.h" #include "src/ast/constructor_expression.h" #include "src/ast/identifier_expression.h" #include "src/ast/member_accessor_expression.h" @@ -96,11 +95,6 @@ const CallExpression* Expression::AsCall() const { return static_cast(this); } -const CastExpression* Expression::AsCast() const { - assert(IsCast()); - return static_cast(this); -} - const ConstructorExpression* Expression::AsConstructor() const { assert(IsConstructor()); return static_cast(this); @@ -141,11 +135,6 @@ CallExpression* Expression::AsCall() { return static_cast(this); } -CastExpression* Expression::AsCast() { - assert(IsCast()); - return static_cast(this); -} - ConstructorExpression* Expression::AsConstructor() { assert(IsConstructor()); return static_cast(this); diff --git a/src/ast/expression.h b/src/ast/expression.h index b69e9da767..dc8ce58674 100644 --- a/src/ast/expression.h +++ b/src/ast/expression.h @@ -28,7 +28,6 @@ class ArrayAccessorExpression; class BinaryExpression; class BitcastExpression; class CallExpression; -class CastExpression; class IdentifierExpression; class ConstructorExpression; class MemberAccessorExpression; @@ -70,8 +69,6 @@ class Expression : public Node { const BitcastExpression* AsBitcast() const; /// @returns the expression as a call const CallExpression* AsCall() const; - /// @returns the expression as a cast - const CastExpression* AsCast() const; /// @returns the expression as an identifier const IdentifierExpression* AsIdentifier() const; /// @returns the expression as an constructor @@ -89,8 +86,6 @@ class Expression : public Node { BitcastExpression* AsBitcast(); /// @returns the expression as a call CallExpression* AsCall(); - /// @returns the expression as a cast - CastExpression* AsCast(); /// @returns the expression as an identifier IdentifierExpression* AsIdentifier(); /// @returns the expression as an constructor diff --git a/src/ast/type/type.cc b/src/ast/type/type.cc index a16301480f..7565d737ce 100644 --- a/src/ast/type/type.cc +++ b/src/ast/type/type.cc @@ -109,6 +109,10 @@ bool Type::IsVoid() const { return false; } +bool Type::is_scalar() { + return is_float_scalar() || is_integer_scalar() || IsBool(); +} + bool Type::is_float_scalar() { return IsF32(); } diff --git a/src/ast/type/type.h b/src/ast/type/type.h index e506c4d101..5d0dbe7076 100644 --- a/src/ast/type/type.h +++ b/src/ast/type/type.h @@ -90,6 +90,8 @@ class Type { /// @returns the unwrapped type Type* UnwrapAliasPtrAlias(); + /// @returns true if this type is a scalar + bool is_scalar(); /// @returns true if this type is a float scalar bool is_float_scalar(); /// @returns true if this type is a float matrix diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index bc59df1aa9..963a33d274 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -36,7 +36,6 @@ #include "src/ast/call_expression.h" #include "src/ast/call_statement.h" #include "src/ast/case_statement.h" -#include "src/ast/cast_expression.h" #include "src/ast/continue_statement.h" #include "src/ast/discard_statement.h" #include "src/ast/else_statement.h" @@ -3463,8 +3462,11 @@ TypedExpression FunctionEmitter::MakeNumericConversion( return {}; } - TypedExpression result(expr_type, std::make_unique( - expr_type, std::move(arg_expr.expr))); + ast::ExpressionList params; + params.push_back(std::move(arg_expr.expr)); + TypedExpression result(expr_type, + std::make_unique( + expr_type, std::move(params))); if (requested_type == expr_type) { return result; diff --git a/src/reader/spirv/function_conversion_test.cc b/src/reader/spirv/function_conversion_test.cc index 37525b9a36..24a6a7a869 100644 --- a/src/reader/spirv/function_conversion_test.cc +++ b/src/reader/spirv/function_conversion_test.cc @@ -243,9 +243,10 @@ TEST_F(SpvUnaryConversionTest, ConvertSToF_Scalar_FromSigned) { none __f32 { - Cast<__f32>( + TypeConstructor{ + __f32 Identifier{x_30} - ) + } } })")) << ToString(fe.ast_body()); @@ -269,11 +270,12 @@ TEST_F(SpvUnaryConversionTest, ConvertSToF_Scalar_FromUnsigned) { none __f32 { - Cast<__f32>( + TypeConstructor{ + __f32 Bitcast<__i32>{ Identifier{x_30} } - ) + } } })")) << ToString(fe.ast_body()); @@ -297,9 +299,10 @@ TEST_F(SpvUnaryConversionTest, ConvertSToF_Vector_FromSigned) { none __vec_2__f32 { - Cast<__vec_2__f32>( + TypeConstructor{ + __vec_2__f32 Identifier{x_30} - ) + } } })")) << ToString(fe.ast_body()); @@ -323,11 +326,12 @@ TEST_F(SpvUnaryConversionTest, ConvertSToF_Vector_FromUnsigned) { none __vec_2__f32 { - Cast<__vec_2__f32>( + TypeConstructor{ + __vec_2__f32 Bitcast<__vec_2__i32>{ Identifier{x_30} } - ) + } } })")) << ToString(fe.ast_body()); @@ -384,11 +388,12 @@ TEST_F(SpvUnaryConversionTest, ConvertUToF_Scalar_FromSigned) { none __f32 { - Cast<__f32>( + TypeConstructor{ + __f32 Bitcast<__u32>{ Identifier{x_30} } - ) + } } })")) << ToString(fe.ast_body()); @@ -412,9 +417,10 @@ TEST_F(SpvUnaryConversionTest, ConvertUToF_Scalar_FromUnsigned) { none __f32 { - Cast<__f32>( + TypeConstructor{ + __f32 Identifier{x_30} - ) + } } })")) << ToString(fe.ast_body()); @@ -438,11 +444,12 @@ TEST_F(SpvUnaryConversionTest, ConvertUToF_Vector_FromSigned) { none __vec_2__f32 { - Cast<__vec_2__f32>( + TypeConstructor{ + __vec_2__f32 Bitcast<__vec_2__u32>{ Identifier{x_30} } - ) + } } })")) << ToString(fe.ast_body()); @@ -466,9 +473,10 @@ TEST_F(SpvUnaryConversionTest, ConvertUToF_Vector_FromUnsigned) { none __vec_2__f32 { - Cast<__vec_2__f32>( + TypeConstructor{ + __vec_2__f32 Identifier{x_30} - ) + } } })")) << ToString(fe.ast_body()); @@ -526,9 +534,10 @@ TEST_F(SpvUnaryConversionTest, ConvertFToS_Scalar_ToSigned) { none __i32 { - Cast<__i32>( + TypeConstructor{ + __i32 Identifier{x_30} - ) + } } })")) << ToString(fe.ast_body()); @@ -553,9 +562,10 @@ TEST_F(SpvUnaryConversionTest, ConvertFToS_Scalar_ToUnsigned) { __u32 { Bitcast<__u32>{ - Cast<__i32>( + TypeConstructor{ + __i32 Identifier{x_30} - ) + } } } })")) @@ -580,9 +590,10 @@ TEST_F(SpvUnaryConversionTest, ConvertFToS_Vector_ToSigned) { none __vec_2__i32 { - Cast<__vec_2__i32>( + TypeConstructor{ + __vec_2__i32 Identifier{x_30} - ) + } } })")) << ToString(fe.ast_body()); @@ -607,9 +618,10 @@ TEST_F(SpvUnaryConversionTest, ConvertFToS_Vector_ToUnsigned) { __vec_2__u32 { Bitcast<__vec_2__u32>{ - Cast<__vec_2__i32>( + TypeConstructor{ + __vec_2__i32 Identifier{x_30} - ) + } } } })")) @@ -669,9 +681,10 @@ TEST_F(SpvUnaryConversionTest, ConvertFToU_Scalar_ToSigned) { __i32 { Bitcast<__i32>{ - Cast<__u32>( + TypeConstructor{ + __u32 Identifier{x_30} - ) + } } } })")) @@ -696,9 +709,10 @@ TEST_F(SpvUnaryConversionTest, ConvertFToU_Scalar_ToUnsigned) { none __u32 { - Cast<__u32>( + TypeConstructor{ + __u32 Identifier{x_30} - ) + } } })")) << ToString(fe.ast_body()); @@ -723,9 +737,10 @@ TEST_F(SpvUnaryConversionTest, ConvertFToU_Vector_ToSigned) { __vec_2__i32 { Bitcast<__vec_2__i32>{ - Cast<__vec_2__u32>( + TypeConstructor{ + __vec_2__u32 Identifier{x_30} - ) + } } } })")) @@ -750,9 +765,10 @@ TEST_F(SpvUnaryConversionTest, ConvertFToU_Vector_ToUnsigned) { none __vec_2__u32 { - Cast<__vec_2__u32>( + TypeConstructor{ + __vec_2__u32 Identifier{x_30} - ) + } } })")) << ToString(fe.ast_body()); diff --git a/src/reader/wgsl/lexer.cc b/src/reader/wgsl/lexer.cc index 87b678f9a5..42019fa2a3 100644 --- a/src/reader/wgsl/lexer.cc +++ b/src/reader/wgsl/lexer.cc @@ -485,8 +485,6 @@ Token Lexer::check_keyword(const Source& source, const std::string& str) { return {Token::Type::kBuiltin, source, "builtin"}; if (str == "case") return {Token::Type::kCase, source, "case"}; - if (str == "cast") - return {Token::Type::kCast, source, "cast"}; if (str == "compute") return {Token::Type::kCompute, source, "compute"}; if (str == "const") diff --git a/src/reader/wgsl/lexer_test.cc b/src/reader/wgsl/lexer_test.cc index c3b456ef51..81e27e4e7b 100644 --- a/src/reader/wgsl/lexer_test.cc +++ b/src/reader/wgsl/lexer_test.cc @@ -420,7 +420,6 @@ INSTANTIATE_TEST_SUITE_P( TokenData{"break", Token::Type::kBreak}, TokenData{"builtin", Token::Type::kBuiltin}, TokenData{"case", Token::Type::kCase}, - TokenData{"cast", Token::Type::kCast}, TokenData{"compute", Token::Type::kCompute}, TokenData{"const", Token::Type::kConst}, TokenData{"continue", Token::Type::kContinue}, diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc index 1616ea8344..26044d41a3 100644 --- a/src/reader/wgsl/parser_impl.cc +++ b/src/reader/wgsl/parser_impl.cc @@ -26,7 +26,6 @@ #include "src/ast/builtin_decoration.h" #include "src/ast/call_expression.h" #include "src/ast/case_statement.h" -#include "src/ast/cast_expression.h" #include "src/ast/continue_statement.h" #include "src/ast/decorated_variable.h" #include "src/ast/discard_statement.h" @@ -2767,8 +2766,7 @@ std::unique_ptr ParserImpl::continuing_stmt() { // | type_decl PAREN_LEFT argument_expression_list* PAREN_RIGHT // | const_literal // | paren_rhs_stmt -// | CAST LESS_THAN type_decl GREATER_THAN paren_rhs_stmt -// | AS LESS_THAN type_decl GREATER_THAN paren_rhs_stmt +// | BITCAST LESS_THAN type_decl GREATER_THAN paren_rhs_stmt std::unique_ptr ParserImpl::primary_expression() { auto t = peek(); auto source = t.source(); @@ -2790,14 +2788,14 @@ std::unique_ptr ParserImpl::primary_expression() { return paren; } - if (t.IsCast() || t.IsBitcast()) { + if (t.IsBitcast()) { auto src = t; next(); // Consume the peek t = next(); if (!t.IsLessThan()) { - set_error(t, "missing < for " + src.to_name() + " expression"); + set_error(t, "missing < for bitcast expression"); return nullptr; } @@ -2805,13 +2803,13 @@ std::unique_ptr ParserImpl::primary_expression() { if (has_error()) return nullptr; if (type == nullptr) { - set_error(peek(), "missing type for " + src.to_name() + " expression"); + set_error(peek(), "missing type for bitcast expression"); return nullptr; } t = next(); if (!t.IsGreaterThan()) { - set_error(t, "missing > for " + src.to_name() + " expression"); + set_error(t, "missing > for bitcast expression"); return nullptr; } @@ -2823,14 +2821,8 @@ std::unique_ptr ParserImpl::primary_expression() { return nullptr; } - if (src.IsCast()) { - return std::make_unique(source, type, - std::move(params)); - } else { - return std::make_unique(source, type, - std::move(params)); - } - + return std::make_unique(source, type, + std::move(params)); } else if (t.IsIdentifier()) { next(); // Consume the peek diff --git a/src/reader/wgsl/parser_impl_primary_expression_test.cc b/src/reader/wgsl/parser_impl_primary_expression_test.cc index bd459e130e..49737e4d9a 100644 --- a/src/reader/wgsl/parser_impl_primary_expression_test.cc +++ b/src/reader/wgsl/parser_impl_primary_expression_test.cc @@ -16,7 +16,6 @@ #include "src/ast/array_accessor_expression.h" #include "src/ast/bitcast_expression.h" #include "src/ast/bool_literal.h" -#include "src/ast/cast_expression.h" #include "src/ast/identifier_expression.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/sint_literal.h" @@ -170,73 +169,19 @@ TEST_F(ParserImplTest, PrimaryExpression_ParenExpr_InvalidExpr) { TEST_F(ParserImplTest, PrimaryExpression_Cast) { auto* f32_type = tm()->Get(std::make_unique()); - auto* p = parser("cast(1)"); + auto* p = parser("f32(1)"); auto e = p->primary_expression(); ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_NE(e, nullptr); - ASSERT_TRUE(e->IsCast()); + ASSERT_TRUE(e->IsConstructor()); + ASSERT_TRUE(e->AsConstructor()->IsTypeConstructor()); - auto* c = e->AsCast(); + auto* c = e->AsConstructor()->AsTypeConstructor(); ASSERT_EQ(c->type(), f32_type); + ASSERT_EQ(c->values().size(), 1u); - ASSERT_TRUE(c->expr()->IsConstructor()); - ASSERT_TRUE(c->expr()->AsConstructor()->IsScalarConstructor()); -} - -TEST_F(ParserImplTest, PrimaryExpression_Cast_MissingGreaterThan) { - auto* p = parser("castprimary_expression(); - ASSERT_TRUE(p->has_error()); - ASSERT_EQ(e, nullptr); - EXPECT_EQ(p->error(), "1:9: missing > for cast expression"); -} - -TEST_F(ParserImplTest, PrimaryExpression_Cast_MissingType) { - auto* p = parser("cast<>(1)"); - auto e = p->primary_expression(); - ASSERT_TRUE(p->has_error()); - ASSERT_EQ(e, nullptr); - EXPECT_EQ(p->error(), "1:6: missing type for cast expression"); -} - -TEST_F(ParserImplTest, PrimaryExpression_Cast_InvalidType) { - auto* p = parser("cast(1)"); - auto e = p->primary_expression(); - ASSERT_TRUE(p->has_error()); - ASSERT_EQ(e, nullptr); - EXPECT_EQ(p->error(), "1:6: unknown type alias 'invalid'"); -} - -TEST_F(ParserImplTest, PrimaryExpression_Cast_MissingLeftParen) { - auto* p = parser("cast1)"); - auto e = p->primary_expression(); - ASSERT_TRUE(p->has_error()); - ASSERT_EQ(e, nullptr); - EXPECT_EQ(p->error(), "1:10: expected ("); -} - -TEST_F(ParserImplTest, PrimaryExpression_Cast_MissingRightParen) { - auto* p = parser("cast(1"); - auto e = p->primary_expression(); - ASSERT_TRUE(p->has_error()); - ASSERT_EQ(e, nullptr); - EXPECT_EQ(p->error(), "1:12: expected )"); -} - -TEST_F(ParserImplTest, PrimaryExpression_Cast_MissingExpression) { - auto* p = parser("cast()"); - auto e = p->primary_expression(); - ASSERT_TRUE(p->has_error()); - ASSERT_EQ(e, nullptr); - EXPECT_EQ(p->error(), "1:11: unable to parse expression"); -} - -TEST_F(ParserImplTest, PrimaryExpression_Cast_InvalidExpression) { - auto* p = parser("cast(if (a) {})"); - auto e = p->primary_expression(); - ASSERT_TRUE(p->has_error()); - ASSERT_EQ(e, nullptr); - EXPECT_EQ(p->error(), "1:11: unable to parse expression"); + ASSERT_TRUE(c->values()[0]->IsConstructor()); + ASSERT_TRUE(c->values()[0]->AsConstructor()->IsScalarConstructor()); } TEST_F(ParserImplTest, PrimaryExpression_Bitcast) { diff --git a/src/reader/wgsl/token.cc b/src/reader/wgsl/token.cc index f836c0cae0..76d88e9e30 100644 --- a/src/reader/wgsl/token.cc +++ b/src/reader/wgsl/token.cc @@ -121,8 +121,6 @@ std::string Token::TypeToName(Type type) { return "builtin"; case Token::Type::kCase: return "case"; - case Token::Type::kCast: - return "cast"; case Token::Type::kCompute: return "compute"; case Token::Type::kConst: diff --git a/src/reader/wgsl/token.h b/src/reader/wgsl/token.h index ebd39e6929..184d1d94af 100644 --- a/src/reader/wgsl/token.h +++ b/src/reader/wgsl/token.h @@ -132,8 +132,6 @@ class Token { kBuiltin, /// A 'case' kCase, - /// A 'cast' - kCast, /// A 'compute' kCompute, /// A 'const' @@ -509,8 +507,6 @@ class Token { bool IsBuiltin() const { return type_ == Type::kBuiltin; } /// @returns true if token is a 'case' bool IsCase() const { return type_ == Type::kCase; } - /// @returns true if token is a 'cast' - bool IsCast() const { return type_ == Type::kCast; } /// @returns true if token is a 'sampler_comparison' bool IsComparisonSampler() const { return type_ == Type::kComparisonSampler; } /// @returns true if token is a 'compute' diff --git a/src/type_determiner.cc b/src/type_determiner.cc index dedf1f04cb..5a412a4d1d 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -26,7 +26,6 @@ #include "src/ast/call_expression.h" #include "src/ast/call_statement.h" #include "src/ast/case_statement.h" -#include "src/ast/cast_expression.h" #include "src/ast/continue_statement.h" #include "src/ast/else_statement.h" #include "src/ast/identifier_expression.h" @@ -305,9 +304,6 @@ bool TypeDeterminer::DetermineResultType(ast::Expression* expr) { if (expr->IsCall()) { return DetermineCall(expr->AsCall()); } - if (expr->IsCast()) { - return DetermineCast(expr->AsCast()); - } if (expr->IsConstructor()) { return DetermineConstructor(expr->AsConstructor()); } @@ -737,15 +733,6 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident, return true; } -bool TypeDeterminer::DetermineCast(ast::CastExpression* expr) { - if (!DetermineResultType(expr->expr())) { - return false; - } - - expr->set_result_type(expr->type()); - return true; -} - bool TypeDeterminer::DetermineConstructor(ast::ConstructorExpression* expr) { if (expr->IsTypeConstructor()) { auto* ty = expr->AsTypeConstructor(); diff --git a/src/type_determiner.h b/src/type_determiner.h index 6ba38339cb..482a439b8e 100644 --- a/src/type_determiner.h +++ b/src/type_determiner.h @@ -30,7 +30,6 @@ class ArrayAccessorExpression; class BinaryExpression; class BitcastExpression; class CallExpression; -class CastExpression; class ConstructorExpression; class Function; class IdentifierExpression; @@ -120,7 +119,6 @@ class TypeDeterminer { bool DetermineBinary(ast::BinaryExpression* expr); bool DetermineBitcast(ast::BitcastExpression* expr); bool DetermineCall(ast::CallExpression* expr); - bool DetermineCast(ast::CastExpression* expr); bool DetermineConstructor(ast::ConstructorExpression* expr); bool DetermineIdentifier(ast::IdentifierExpression* expr); bool DetermineIntrinsic(ast::IdentifierExpression* name, diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 9a6bd20d9f..d60b376f2f 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -29,7 +29,6 @@ #include "src/ast/call_expression.h" #include "src/ast/call_statement.h" #include "src/ast/case_statement.h" -#include "src/ast/cast_expression.h" #include "src/ast/continue_statement.h" #include "src/ast/else_statement.h" #include "src/ast/float_literal.h" @@ -686,8 +685,11 @@ TEST_F(TypeDeterminerTest, Expr_Call_Intrinsic) { TEST_F(TypeDeterminerTest, Expr_Cast) { ast::type::F32Type f32; - ast::CastExpression cast(&f32, - std::make_unique("name")); + + ast::ExpressionList params; + params.push_back(std::make_unique("name")); + + ast::TypeConstructorExpression cast(&f32, std::move(params)); EXPECT_TRUE(td()->DetermineResultType(&cast)); ASSERT_NE(cast.result_type(), nullptr); diff --git a/src/validator_test.cc b/src/validator_test.cc index b97b58ade4..62ada225af 100644 --- a/src/validator_test.cc +++ b/src/validator_test.cc @@ -24,7 +24,6 @@ #include "src/ast/call_expression.h" #include "src/ast/call_statement.h" #include "src/ast/case_statement.h" -#include "src/ast/cast_expression.h" #include "src/ast/continue_statement.h" #include "src/ast/else_statement.h" #include "src/ast/float_literal.h" diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 68a33cac19..6e9189457c 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -24,7 +24,6 @@ #include "src/ast/call_expression.h" #include "src/ast/call_statement.h" #include "src/ast/case_statement.h" -#include "src/ast/cast_expression.h" #include "src/ast/decorated_variable.h" #include "src/ast/else_statement.h" #include "src/ast/float_literal.h" @@ -724,21 +723,6 @@ bool GeneratorImpl::EmitBuiltinName(std::ostream&, return true; } -bool GeneratorImpl::EmitCast(std::ostream& pre, - std::ostream& out, - ast::CastExpression* expr) { - if (!EmitType(out, expr->type(), "")) { - return false; - } - - out << "("; - if (!EmitExpression(pre, out, expr->expr())) { - return false; - } - out << ")"; - return true; -} - bool GeneratorImpl::EmitCase(std::ostream& out, ast::CaseStatement* stmt) { make_indent(out); @@ -868,9 +852,6 @@ bool GeneratorImpl::EmitExpression(std::ostream& pre, if (expr->IsCall()) { return EmitCall(pre, out, expr->AsCall()); } - if (expr->IsCast()) { - return EmitCast(pre, out, expr->AsCast()); - } if (expr->IsConstructor()) { return EmitConstructor(pre, out, expr->AsConstructor()); } diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h index e30c363256..1e231a6af0 100644 --- a/src/writer/hlsl/generator_impl.h +++ b/src/writer/hlsl/generator_impl.h @@ -128,14 +128,6 @@ class GeneratorImpl { /// @param stmt the statement /// @returns true if the statment was emitted successfully bool EmitCase(std::ostream& out, ast::CaseStatement* stmt); - /// Handles generating a cast expression - /// @param pre the preamble for the expression stream - /// @param out the output of the expression stream - /// @param expr the cast expression - /// @returns true if the cast was emitted - bool EmitCast(std::ostream& pre, - std::ostream& out, - ast::CastExpression* expr); /// Handles generating constructor expressions /// @param pre the preamble for the expression stream /// @param out the output of the expression stream diff --git a/src/writer/hlsl/generator_impl_cast_test.cc b/src/writer/hlsl/generator_impl_cast_test.cc index e646e07f8e..b92b1045d1 100644 --- a/src/writer/hlsl/generator_impl_cast_test.cc +++ b/src/writer/hlsl/generator_impl_cast_test.cc @@ -14,11 +14,11 @@ #include -#include "src/ast/cast_expression.h" #include "src/ast/identifier_expression.h" #include "src/ast/module.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/vector_type.h" +#include "src/ast/type_constructor_expression.h" #include "src/writer/hlsl/test_helper.h" namespace tint { @@ -30,8 +30,11 @@ using HlslGeneratorImplTest_Cast = TestHelper; TEST_F(HlslGeneratorImplTest_Cast, EmitExpression_Cast_Scalar) { ast::type::F32Type f32; - auto id = std::make_unique("id"); - ast::CastExpression cast(&f32, std::move(id)); + + ast::ExpressionList params; + params.push_back(std::make_unique("id")); + + ast::TypeConstructorExpression cast(&f32, std::move(params)); ASSERT_TRUE(gen().EmitExpression(pre(), out(), &cast)) << gen().error(); EXPECT_EQ(result(), "float(id)"); @@ -41,8 +44,10 @@ TEST_F(HlslGeneratorImplTest_Cast, EmitExpression_Cast_Vector) { ast::type::F32Type f32; ast::type::VectorType vec3(&f32, 3); - auto id = std::make_unique("id"); - ast::CastExpression cast(&vec3, std::move(id)); + ast::ExpressionList params; + params.push_back(std::make_unique("id")); + + ast::TypeConstructorExpression cast(&vec3, std::move(params)); ASSERT_TRUE(gen().EmitExpression(pre(), out(), &cast)) << gen().error(); EXPECT_EQ(result(), "vector(id)"); diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index 45773b9557..6f26ffa2c0 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -24,7 +24,6 @@ #include "src/ast/call_expression.h" #include "src/ast/call_statement.h" #include "src/ast/case_statement.h" -#include "src/ast/cast_expression.h" #include "src/ast/continue_statement.h" #include "src/ast/decorated_variable.h" #include "src/ast/else_statement.h" @@ -744,19 +743,6 @@ bool GeneratorImpl::EmitCase(ast::CaseStatement* stmt) { return true; } -bool GeneratorImpl::EmitCast(ast::CastExpression* expr) { - if (!EmitType(expr->type(), "")) { - return false; - } - - out_ << "("; - if (!EmitExpression(expr->expr())) { - return false; - } - out_ << ")"; - return true; -} - bool GeneratorImpl::EmitConstructor(ast::ConstructorExpression* expr) { if (expr->IsScalarConstructor()) { return EmitScalarConstructor(expr->AsScalarConstructor()); @@ -992,9 +978,6 @@ bool GeneratorImpl::EmitExpression(ast::Expression* expr) { if (expr->IsCall()) { return EmitCall(expr->AsCall()); } - if (expr->IsCast()) { - return EmitCast(expr->AsCast()); - } if (expr->IsConstructor()) { return EmitConstructor(expr->AsConstructor()); } diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h index cc458d2625..f33a143425 100644 --- a/src/writer/msl/generator_impl.h +++ b/src/writer/msl/generator_impl.h @@ -98,10 +98,6 @@ class GeneratorImpl : public TextGenerator { /// @param stmt the statement /// @returns true if the statement was emitted successfully bool EmitCase(ast::CaseStatement* stmt); - /// Handles generating a cast expression - /// @param expr the cast expression - /// @returns true if the cast was emitted - bool EmitCast(ast::CastExpression* expr); /// Handles generating constructor expressions /// @param expr the constructor expression /// @returns true if the expression was emitted diff --git a/src/writer/msl/generator_impl_cast_test.cc b/src/writer/msl/generator_impl_cast_test.cc index f7603352cb..aea7840768 100644 --- a/src/writer/msl/generator_impl_cast_test.cc +++ b/src/writer/msl/generator_impl_cast_test.cc @@ -15,11 +15,11 @@ #include #include "gtest/gtest.h" -#include "src/ast/cast_expression.h" #include "src/ast/identifier_expression.h" #include "src/ast/module.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/vector_type.h" +#include "src/ast/type_constructor_expression.h" #include "src/writer/msl/generator_impl.h" namespace tint { @@ -31,8 +31,11 @@ using MslGeneratorImplTest = testing::Test; TEST_F(MslGeneratorImplTest, EmitExpression_Cast_Scalar) { ast::type::F32Type f32; - auto id = std::make_unique("id"); - ast::CastExpression cast(&f32, std::move(id)); + + ast::ExpressionList params; + params.push_back(std::make_unique("id")); + + ast::TypeConstructorExpression cast(&f32, std::move(params)); ast::Module m; GeneratorImpl g(&m); @@ -44,8 +47,10 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Cast_Vector) { ast::type::F32Type f32; ast::type::VectorType vec3(&f32, 3); - auto id = std::make_unique("id"); - ast::CastExpression cast(&vec3, std::move(id)); + ast::ExpressionList params; + params.push_back(std::make_unique("id")); + + ast::TypeConstructorExpression cast(&vec3, std::move(params)); ast::Module m; GeneratorImpl g(&m); diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index b2fcdf11e0..f879befeaa 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -30,7 +30,6 @@ #include "src/ast/call_expression.h" #include "src/ast/call_statement.h" #include "src/ast/case_statement.h" -#include "src/ast/cast_expression.h" #include "src/ast/constructor_expression.h" #include "src/ast/decorated_variable.h" #include "src/ast/else_statement.h" @@ -465,9 +464,6 @@ uint32_t Builder::GenerateExpression(ast::Expression* expr) { if (expr->IsCall()) { return GenerateCallExpression(expr->AsCall()); } - if (expr->IsCast()) { - return GenerateCastExpression(expr->AsCast()); - } if (expr->IsConstructor()) { return GenerateConstructorExpression(expr->AsConstructor(), false); } @@ -1054,32 +1050,20 @@ uint32_t Builder::GenerateConstructorExpression( uint32_t Builder::GenerateTypeConstructorExpression( ast::TypeConstructorExpression* init, bool is_global_init) { - auto type_id = GenerateTypeIfNeeded(init->type()); - if (type_id == 0) { - return 0; - } + auto& values = init->values(); // Generate the zero initializer if there are no values provided. - if (init->values().empty()) { + if (values.empty()) { ast::NullLiteral nl(init->type()->UnwrapPtrIfNeeded()); return GenerateLiteralIfNeeded(&nl); } - auto* result_type = init->type()->UnwrapPtrIfNeeded(); - if (result_type->IsVector()) { - result_type = result_type->AsVector()->type(); - } else if (result_type->IsArray()) { - result_type = result_type->AsArray()->type(); - } else if (result_type->IsMatrix()) { - result_type = result_type->AsMatrix()->type(); - } - std::ostringstream out; out << "__const"; OperandList ops; bool constructor_is_const = true; - for (const auto& e : init->values()) { + for (const auto& e : values) { if (!e->IsConstructor()) { if (is_global_init) { error_ = "constructor must be a constant expression"; @@ -1089,9 +1073,37 @@ uint32_t Builder::GenerateTypeConstructorExpression( } } + auto* result_type = init->type()->UnwrapAliasPtrAlias(); + + bool can_cast_or_copy = result_type->is_scalar(); + if (result_type->IsVector() && result_type->AsVector()->type()->is_scalar()) { + auto* value_type = values[0]->result_type()->UnwrapAliasPtrAlias(); + can_cast_or_copy = + (value_type->IsVector() && + value_type->AsVector()->type()->is_scalar() && + result_type->AsVector()->size() == value_type->AsVector()->size()); + } + if (can_cast_or_copy) { + return GenerateCastOrCopy(result_type, values[0].get()); + } + + auto type_id = GenerateTypeIfNeeded(init->type()); + if (type_id == 0) { + return 0; + } + bool result_is_constant_composite = constructor_is_const; bool result_is_spec_composite = false; - for (const auto& e : init->values()) { + + if (result_type->IsVector()) { + result_type = result_type->AsVector()->type(); + } else if (result_type->IsArray()) { + result_type = result_type->AsArray()->type(); + } else if (result_type->IsMatrix()) { + result_type = result_type->AsMatrix()->type(); + } + + for (const auto& e : values) { uint32_t id = 0; if (constructor_is_const) { id = GenerateConstructorExpression(e->AsConstructor(), is_global_init); @@ -1104,54 +1116,54 @@ uint32_t Builder::GenerateTypeConstructorExpression( } auto* value_type = e->result_type()->UnwrapPtrIfNeeded(); + if (result_type == value_type) { + out << "_" << id; + ops.push_back(Operand::Int(id)); + continue; + } + + // Both scalars, but not the same type so we need to generate a conversion + // of the value. + if (value_type->is_scalar() && result_type->is_scalar()) { + id = GenerateCastOrCopy(result_type, values[0].get()); + out << "_" << id; + ops.push_back(Operand::Int(id)); + continue; + } // When handling vectors as the values there a few cases to take into // consideration: // 1. Module scoped vec3(vec2(1, 2), 3) -> OpSpecConstantOp - // 2. Function scoped vec3(vec2(1, 2), 3) -> OpCompositeExtract + // 2. Function scoped vec3(vec2(1, 2), 3) -> OpCompositeExtract // 3. Either array, 1>(vec3(1, 2, 3)) -> use the ID. + // -> handled above + // + // For cases 1 and 2, if the type is different we also may need to insert + // a type cast. if (value_type->IsVector()) { auto* vec = value_type->AsVector(); auto* vec_type = vec->type(); - // If the value we want is the same as what we have, use it directly. - // This maps to case 3. - if (result_type == value_type) { - out << "_" << id; - ops.push_back(Operand::Int(id)); - } else if (!is_global_init) { - // A non-global initializer. Case 2. - auto value_type_id = GenerateTypeIfNeeded(vec_type); - if (value_type_id == 0) { - return 0; - } + auto value_type_id = GenerateTypeIfNeeded(vec_type); + if (value_type_id == 0) { + return 0; + } - for (uint32_t i = 0; i < vec->size(); ++i) { - auto extract = result_op(); - auto extract_id = extract.to_i(); + for (uint32_t i = 0; i < vec->size(); ++i) { + auto extract = result_op(); + auto extract_id = extract.to_i(); + if (!is_global_init) { + // A non-global initializer. Case 2. push_function_inst(spv::Op::OpCompositeExtract, {Operand::Int(value_type_id), extract, Operand::Int(id), Operand::Int(i)}); - out << "_" << extract_id; - ops.push_back(Operand::Int(extract_id)); - // We no longer have a constant composite, but have to do a // composite construction as these calls are inside a function. result_is_constant_composite = false; - } - } else { - // A global initializer, must use OpSpecConstantOp. Case 1. - auto value_type_id = GenerateTypeIfNeeded(vec_type); - if (value_type_id == 0) { - return 0; - } - - for (uint32_t i = 0; i < vec->size(); ++i) { - auto extract = result_op(); - auto extract_id = extract.to_i(); - + } else { + // A global initializer, must use OpSpecConstantOp. Case 1. auto idx_id = GenerateU32Literal(i); if (idx_id == 0) { return 0; @@ -1161,15 +1173,15 @@ uint32_t Builder::GenerateTypeConstructorExpression( Operand::Int(SpvOpCompositeExtract), Operand::Int(id), Operand::Int(idx_id)}); - out << "_" << extract_id; - ops.push_back(Operand::Int(extract_id)); - result_is_spec_composite = true; } + + out << "_" << extract_id; + ops.push_back(Operand::Int(extract_id)); } } else { - out << "_" << id; - ops.push_back(Operand::Int(id)); + error_ = "Unhandled type cast value type"; + return 0; } } @@ -1192,9 +1204,69 @@ uint32_t Builder::GenerateTypeConstructorExpression( } else { push_function_inst(spv::Op::OpCompositeConstruct, ops); } + return result.to_i(); } +uint32_t Builder::GenerateCastOrCopy(ast::type::Type* to_type, + ast::Expression* from_expr) { + auto result = result_op(); + auto result_id = result.to_i(); + + auto result_type_id = GenerateTypeIfNeeded(to_type); + if (result_type_id == 0) { + return 0; + } + + auto val_id = GenerateExpression(from_expr); + if (val_id == 0) { + return 0; + } + val_id = GenerateLoadIfNeeded(from_expr->result_type(), val_id); + + auto* from_type = from_expr->result_type()->UnwrapPtrIfNeeded(); + + spv::Op op = spv::Op::OpNop; + if ((from_type->IsI32() && to_type->IsF32()) || + (from_type->is_signed_integer_vector() && to_type->is_float_vector())) { + op = spv::Op::OpConvertSToF; + } else if ((from_type->IsU32() && to_type->IsF32()) || + (from_type->is_unsigned_integer_vector() && + to_type->is_float_vector())) { + op = spv::Op::OpConvertUToF; + } else if ((from_type->IsF32() && to_type->IsI32()) || + (from_type->is_float_vector() && + to_type->is_signed_integer_vector())) { + op = spv::Op::OpConvertFToS; + } else if ((from_type->IsF32() && to_type->IsU32()) || + (from_type->is_float_vector() && + to_type->is_unsigned_integer_vector())) { + op = spv::Op::OpConvertFToU; + } else if ((from_type->IsBool() && to_type->IsBool()) || + (from_type->IsU32() && to_type->IsU32()) || + (from_type->IsI32() && to_type->IsI32()) || + (from_type->IsF32() && to_type->IsF32())) { + op = spv::Op::OpCopyObject; + } else if ((from_type->IsI32() && to_type->IsU32()) || + (from_type->IsU32() && to_type->IsI32()) || + (from_type->is_signed_integer_vector() && + to_type->is_unsigned_integer_vector()) || + (from_type->is_unsigned_integer_vector() && + to_type->is_integer_scalar_or_vector())) { + op = spv::Op::OpBitcast; + } + if (op == spv::Op::OpNop) { + error_ = "unable to determine conversion type for cast, from: " + + from_type->type_name() + " to: " + to_type->type_name(); + return 0; + } + + push_function_inst( + op, {Operand::Int(result_type_id), result, Operand::Int(val_id)}); + + return result_id; +} + uint32_t Builder::GenerateLiteralIfNeeded(ast::Literal* lit) { auto type_id = GenerateTypeIfNeeded(lit->type()); if (type_id == 0) { @@ -1726,70 +1798,6 @@ uint32_t Builder::GenerateBitcastExpression(ast::BitcastExpression* expr) { return result_id; } -uint32_t Builder::GenerateCastExpression(ast::CastExpression* cast) { - auto result = result_op(); - auto result_id = result.to_i(); - - auto result_type_id = GenerateTypeIfNeeded(cast->result_type()); - if (result_type_id == 0) { - return 0; - } - - auto val_id = GenerateExpression(cast->expr()); - if (val_id == 0) { - return 0; - } - val_id = GenerateLoadIfNeeded(cast->expr()->result_type(), val_id); - - auto* to_type = cast->result_type()->UnwrapPtrIfNeeded(); - auto* from_type = cast->expr()->result_type()->UnwrapPtrIfNeeded(); - - spv::Op op = spv::Op::OpNop; - if ((from_type->IsI32() && to_type->IsF32()) || - (from_type->is_signed_integer_vector() && to_type->is_float_vector())) { - op = spv::Op::OpConvertSToF; - } else if ((from_type->IsU32() && to_type->IsF32()) || - (from_type->is_unsigned_integer_vector() && - to_type->is_float_vector())) { - op = spv::Op::OpConvertUToF; - } else if ((from_type->IsF32() && to_type->IsI32()) || - (from_type->is_float_vector() && - to_type->is_signed_integer_vector())) { - op = spv::Op::OpConvertFToS; - } else if ((from_type->IsF32() && to_type->IsU32()) || - (from_type->is_float_vector() && - to_type->is_unsigned_integer_vector())) { - op = spv::Op::OpConvertFToU; - } else if ((from_type->IsU32() && to_type->IsU32()) || - (from_type->IsI32() && to_type->IsI32()) || - (from_type->IsF32() && to_type->IsF32()) || - (from_type->is_unsigned_integer_vector() && - to_type->is_unsigned_integer_vector()) || - (from_type->is_signed_integer_vector() && - to_type->is_signed_integer_vector()) || - (from_type->is_float_vector() && to_type->is_float_vector())) { - op = spv::Op::OpCopyObject; - } else if ((from_type->IsI32() && to_type->IsU32()) || - (from_type->IsU32() && to_type->IsI32()) || - (from_type->is_signed_integer_vector() && - to_type->is_unsigned_integer_vector()) || - (from_type->is_unsigned_integer_vector() && - to_type->is_integer_scalar_or_vector())) { - op = spv::Op::OpBitcast; - } - - if (op == spv::Op::OpNop) { - error_ = "unable to determine conversion type for cast, from: " + - from_type->type_name() + " to: " + to_type->type_name(); - return 0; - } - - push_function_inst( - op, {Operand::Int(result_type_id), result, Operand::Int(val_id)}); - - return result_id; -} - bool Builder::GenerateConditionalBlock( ast::Expression* cond, const ast::BlockStatement* true_body, diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index a0c33d64ab..ccfce3912a 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -306,10 +306,12 @@ class Builder { uint32_t GenerateSampledImage(ast::type::Type* texture_type, Operand texture_operand, Operand sampler_operand); - /// Generates a cast expression - /// @param expr the expression to generate + /// Generates a cast or object copy for the expression result + /// @param to_type the type we're casting too + /// @param from_expr the expression to cast /// @returns the expression ID on success or 0 otherwise - uint32_t GenerateCastExpression(ast::CastExpression* expr); + uint32_t GenerateCastOrCopy(ast::type::Type* to_type, + ast::Expression* from_expr); /// Generates a loop statement /// @param stmt the statement to generate /// @returns true on successful generation diff --git a/src/writer/spirv/builder_cast_expression_test.cc b/src/writer/spirv/builder_cast_expression_test.cc deleted file mode 100644 index 675bb203fd..0000000000 --- a/src/writer/spirv/builder_cast_expression_test.cc +++ /dev/null @@ -1,554 +0,0 @@ -// Copyright 2020 The Tint Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "gtest/gtest.h" -#include "src/ast/cast_expression.h" -#include "src/ast/float_literal.h" -#include "src/ast/identifier_expression.h" -#include "src/ast/module.h" -#include "src/ast/scalar_constructor_expression.h" -#include "src/ast/sint_literal.h" -#include "src/ast/type/f32_type.h" -#include "src/ast/type/i32_type.h" -#include "src/ast/type/u32_type.h" -#include "src/ast/type/vector_type.h" -#include "src/ast/uint_literal.h" -#include "src/context.h" -#include "src/type_determiner.h" -#include "src/writer/spirv/builder.h" -#include "src/writer/spirv/spv_dump.h" - -namespace tint { -namespace writer { -namespace spirv { -namespace { - -using BuilderTest = testing::Test; - -TEST_F(BuilderTest, Cast_FloatToU32) { - ast::type::U32Type u32; - ast::type::F32Type f32; - - ast::CastExpression cast(&u32, - std::make_unique( - std::make_unique(&f32, 2.4))); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0 -%3 = OpTypeFloat 32 -%4 = OpConstant %3 2.4000001 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%1 = OpConvertFToU %2 %4 -)"); -} - -TEST_F(BuilderTest, Cast_FloatToI32) { - ast::type::I32Type i32; - ast::type::F32Type f32; - - ast::CastExpression cast(&i32, - std::make_unique( - std::make_unique(&f32, 2.4))); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 -%3 = OpTypeFloat 32 -%4 = OpConstant %3 2.4000001 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%1 = OpConvertFToS %2 %4 -)"); -} - -TEST_F(BuilderTest, Cast_I32ToFloat) { - ast::type::I32Type i32; - ast::type::F32Type f32; - - ast::CastExpression cast(&f32, - std::make_unique( - std::make_unique(&i32, 2))); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 -%3 = OpTypeInt 32 1 -%4 = OpConstant %3 2 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%1 = OpConvertSToF %2 %4 -)"); -} - -TEST_F(BuilderTest, Cast_U32ToFloat) { - ast::type::U32Type u32; - ast::type::F32Type f32; - - ast::CastExpression cast(&f32, - std::make_unique( - std::make_unique(&u32, 2))); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 -%3 = OpTypeInt 32 0 -%4 = OpConstant %3 2 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%1 = OpConvertUToF %2 %4 -)"); -} - -TEST_F(BuilderTest, Cast_WithLoad) { - ast::type::F32Type f32; - ast::type::I32Type i32; - - // var i : i32 = 1; - // cast(i); - auto var = - std::make_unique("i", ast::StorageClass::kPrivate, &i32); - - ast::CastExpression cast(&f32, - std::make_unique("i")); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - td.RegisterVariableForTesting(var.get()); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); - EXPECT_EQ(b.GenerateCastExpression(&cast), 5u) << b.error(); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1 -%2 = OpTypePointer Private %3 -%4 = OpConstantNull %3 -%1 = OpVariable %2 Private %4 -%6 = OpTypeFloat 32 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%7 = OpLoad %3 %1 -%5 = OpConvertSToF %6 %7 -)"); -} - -TEST_F(BuilderTest, Cast_WithAlias) { - ast::type::I32Type i32; - ast::type::F32Type f32; - - // type Int = i32 - // cast(1.f) - - ast::type::AliasType alias("Int", &i32); - - ast::CastExpression cast(&alias, - std::make_unique( - std::make_unique(&f32, 2.3))); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 -%3 = OpTypeFloat 32 -%4 = OpConstant %3 2.29999995 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%1 = OpConvertFToS %2 %4 -)"); -} - -TEST_F(BuilderTest, Cast_I32ToU32) { - ast::type::U32Type u32; - ast::type::I32Type i32; - - ast::CastExpression cast(&u32, - std::make_unique( - std::make_unique(&i32, 2))); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0 -%3 = OpTypeInt 32 1 -%4 = OpConstant %3 2 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%1 = OpBitcast %2 %4 -)"); -} - -TEST_F(BuilderTest, Cast_U32ToI32) { - ast::type::U32Type u32; - ast::type::I32Type i32; - - ast::CastExpression cast(&i32, - std::make_unique( - std::make_unique(&u32, 2))); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 -%3 = OpTypeInt 32 0 -%4 = OpConstant %3 2 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%1 = OpBitcast %2 %4 -)"); -} - -TEST_F(BuilderTest, Cast_I32ToI32) { - ast::type::I32Type i32; - - ast::CastExpression cast(&i32, - std::make_unique( - std::make_unique(&i32, 2))); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 -%3 = OpConstant %2 2 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%1 = OpCopyObject %2 %3 -)"); -} - -TEST_F(BuilderTest, Cast_U32ToU32) { - ast::type::U32Type u32; - - ast::CastExpression cast(&u32, - std::make_unique( - std::make_unique(&u32, 2))); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0 -%3 = OpConstant %2 2 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%1 = OpCopyObject %2 %3 -)"); -} - -TEST_F(BuilderTest, Cast_F32ToF32) { - ast::type::F32Type f32; - - ast::CastExpression cast(&f32, - std::make_unique( - std::make_unique(&f32, 2.0))); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 -%3 = OpConstant %2 2 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%1 = OpCopyObject %2 %3 -)"); -} - -TEST_F(BuilderTest, Cast_Vectors_I32_to_F32) { - ast::type::I32Type i32; - ast::type::VectorType ivec3(&i32, 3); - ast::type::F32Type f32; - ast::type::VectorType fvec3(&f32, 3); - - auto var = - std::make_unique("i", ast::StorageClass::kPrivate, &ivec3); - - ast::CastExpression cast(&fvec3, - std::make_unique("i")); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - td.RegisterVariableForTesting(var.get()); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); - EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error(); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1 -%3 = OpTypeVector %4 3 -%2 = OpTypePointer Private %3 -%5 = OpConstantNull %3 -%1 = OpVariable %2 Private %5 -%8 = OpTypeFloat 32 -%7 = OpTypeVector %8 3 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%9 = OpLoad %3 %1 -%6 = OpConvertSToF %7 %9 -)"); -} - -TEST_F(BuilderTest, Cast_Vectors_U32_to_F32) { - ast::type::U32Type u32; - ast::type::VectorType uvec3(&u32, 3); - ast::type::F32Type f32; - ast::type::VectorType fvec3(&f32, 3); - - auto var = - std::make_unique("i", ast::StorageClass::kPrivate, &uvec3); - - ast::CastExpression cast(&fvec3, - std::make_unique("i")); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - td.RegisterVariableForTesting(var.get()); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); - EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error(); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0 -%3 = OpTypeVector %4 3 -%2 = OpTypePointer Private %3 -%5 = OpConstantNull %3 -%1 = OpVariable %2 Private %5 -%8 = OpTypeFloat 32 -%7 = OpTypeVector %8 3 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%9 = OpLoad %3 %1 -%6 = OpConvertUToF %7 %9 -)"); -} - -TEST_F(BuilderTest, Cast_Vectors_F32_to_I32) { - ast::type::I32Type i32; - ast::type::VectorType ivec3(&i32, 3); - ast::type::F32Type f32; - ast::type::VectorType fvec3(&f32, 3); - - auto var = - std::make_unique("i", ast::StorageClass::kPrivate, &fvec3); - - ast::CastExpression cast(&ivec3, - std::make_unique("i")); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - td.RegisterVariableForTesting(var.get()); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); - EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error(); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 -%3 = OpTypeVector %4 3 -%2 = OpTypePointer Private %3 -%5 = OpConstantNull %3 -%1 = OpVariable %2 Private %5 -%8 = OpTypeInt 32 1 -%7 = OpTypeVector %8 3 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%9 = OpLoad %3 %1 -%6 = OpConvertFToS %7 %9 -)"); -} - -TEST_F(BuilderTest, Cast_Vectors_F32_to_U32) { - ast::type::U32Type u32; - ast::type::VectorType uvec3(&u32, 3); - ast::type::F32Type f32; - ast::type::VectorType fvec3(&f32, 3); - - auto var = - std::make_unique("i", ast::StorageClass::kPrivate, &fvec3); - - ast::CastExpression cast(&uvec3, - std::make_unique("i")); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - td.RegisterVariableForTesting(var.get()); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); - EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error(); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 -%3 = OpTypeVector %4 3 -%2 = OpTypePointer Private %3 -%5 = OpConstantNull %3 -%1 = OpVariable %2 Private %5 -%8 = OpTypeInt 32 0 -%7 = OpTypeVector %8 3 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%9 = OpLoad %3 %1 -%6 = OpConvertFToU %7 %9 -)"); -} - -TEST_F(BuilderTest, Cast_Vectors_U32_to_U32) { - ast::type::U32Type u32; - ast::type::VectorType uvec3(&u32, 3); - - auto var = - std::make_unique("i", ast::StorageClass::kPrivate, &uvec3); - - ast::CastExpression cast(&uvec3, - std::make_unique("i")); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - td.RegisterVariableForTesting(var.get()); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); - EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error(); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0 -%3 = OpTypeVector %4 3 -%2 = OpTypePointer Private %3 -%5 = OpConstantNull %3 -%1 = OpVariable %2 Private %5 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%7 = OpLoad %3 %1 -%6 = OpCopyObject %3 %7 -)"); -} - -TEST_F(BuilderTest, Cast_Vectors_I32_to_U32) { - ast::type::U32Type u32; - ast::type::VectorType uvec3(&u32, 3); - ast::type::I32Type i32; - ast::type::VectorType ivec3(&i32, 3); - - auto var = - std::make_unique("i", ast::StorageClass::kPrivate, &ivec3); - - ast::CastExpression cast(&uvec3, - std::make_unique("i")); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - td.RegisterVariableForTesting(var.get()); - ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); - EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error(); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1 -%3 = OpTypeVector %4 3 -%2 = OpTypePointer Private %3 -%5 = OpConstantNull %3 -%1 = OpVariable %2 Private %5 -%8 = OpTypeInt 32 0 -%7 = OpTypeVector %8 3 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%9 = OpLoad %3 %1 -%6 = OpBitcast %7 %9 -)"); -} - -} // namespace -} // namespace spirv -} // namespace writer -} // namespace tint diff --git a/src/writer/spirv/builder_constructor_expression_test.cc b/src/writer/spirv/builder_constructor_expression_test.cc index 3a16281fef..f1468db00e 100644 --- a/src/writer/spirv/builder_constructor_expression_test.cc +++ b/src/writer/spirv/builder_constructor_expression_test.cc @@ -18,12 +18,25 @@ #include "spirv/unified1/spirv.h" #include "spirv/unified1/spirv.hpp11" #include "src/ast/binary_expression.h" +#include "src/ast/bool_literal.h" #include "src/ast/float_literal.h" #include "src/ast/identifier_expression.h" +#include "src/ast/member_accessor_expression.h" #include "src/ast/scalar_constructor_expression.h" +#include "src/ast/sint_literal.h" +#include "src/ast/struct.h" +#include "src/ast/struct_decoration.h" +#include "src/ast/struct_member.h" +#include "src/ast/type/array_type.h" +#include "src/ast/type/bool_type.h" #include "src/ast/type/f32_type.h" +#include "src/ast/type/i32_type.h" +#include "src/ast/type/matrix_type.h" +#include "src/ast/type/struct_type.h" +#include "src/ast/type/u32_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" +#include "src/ast/uint_literal.h" #include "src/context.h" #include "src/type_determiner.h" #include "src/writer/spirv/builder.h" @@ -82,31 +95,40 @@ TEST_F(BuilderTest, Constructor_Type) { )"); } -TEST_F(BuilderTest, Constructor_Type_ZeroInit) { +TEST_F(BuilderTest, Constructor_Type_WithAlias) { + ast::type::I32Type i32; ast::type::F32Type f32; - ast::type::VectorType vec(&f32, 2); - ast::ExpressionList vals; - ast::TypeConstructorExpression t(&vec, std::move(vals)); + // type Int = i32 + // cast(1.f) + + ast::type::AliasType alias("Int", &i32); + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.3))); + + ast::TypeConstructorExpression cast(&alias, std::move(params)); Context ctx; ast::Module mod; TypeDeterminer td(&ctx, &mod); - EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); Builder b(&mod); b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 1u); - EXPECT_EQ(b.GenerateConstructorExpression(&t, false), 3u); - ASSERT_FALSE(b.has_error()) << b.error(); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 -%1 = OpTypeVector %2 2 -%3 = OpConstantNull %1 + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%3 = OpTypeFloat 32 +%4 = OpConstant %3 2.29999995 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpConvertFToS %2 %4 )"); } -TEST_F(BuilderTest, Constructor_Type_NonConstructorParam) { +TEST_F(BuilderTest, Constructor_Type_IdentifierExpression_Param) { ast::type::F32Type f32; ast::type::VectorType vec(&f32, 2); @@ -130,7 +152,7 @@ TEST_F(BuilderTest, Constructor_Type_NonConstructorParam) { b.push_function(Function{}); ASSERT_TRUE(b.GenerateFunctionVariable(var.get())) << b.error(); - EXPECT_EQ(b.GenerateConstructorExpression(&t, false), 8u); + EXPECT_EQ(b.GenerateExpression(&t), 8u); ASSERT_FALSE(b.has_error()) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32 @@ -149,81 +171,7 @@ TEST_F(BuilderTest, Constructor_Type_NonConstructorParam) { )"); } -TEST_F(BuilderTest, Constructor_Type_NonConstVector) { - ast::type::F32Type f32; - ast::type::VectorType vec2(&f32, 2); - ast::type::VectorType vec4(&f32, 4); - - auto var = std::make_unique( - "ident", ast::StorageClass::kFunction, &vec2); - - ast::ExpressionList vals; - vals.push_back(std::make_unique( - std::make_unique(&f32, 1.0f))); - vals.push_back(std::make_unique( - std::make_unique(&f32, 1.0f))); - vals.push_back(std::make_unique("ident")); - - ast::TypeConstructorExpression t(&vec4, std::move(vals)); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - td.RegisterVariableForTesting(var.get()); - EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); - - Builder b(&mod); - b.push_function(Function{}); - ASSERT_TRUE(b.GenerateFunctionVariable(var.get())) << b.error(); - - EXPECT_EQ(b.GenerateConstructorExpression(&t, false), 11u); - ASSERT_FALSE(b.has_error()) << b.error(); - - EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 -%3 = OpTypeVector %4 2 -%2 = OpTypePointer Function %3 -%5 = OpConstantNull %3 -%6 = OpTypeVector %4 4 -%7 = OpConstant %4 1 -)"); - EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), - R"(%1 = OpVariable %2 Function %5 -)"); - - EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%8 = OpLoad %3 %1 -%9 = OpCompositeExtract %4 %8 0 -%10 = OpCompositeExtract %4 %8 1 -%11 = OpCompositeConstruct %6 %7 %7 %9 %10 -)"); -} - -TEST_F(BuilderTest, Constructor_Type_Dedups) { - ast::type::F32Type f32; - ast::type::VectorType vec(&f32, 3); - - ast::ExpressionList vals; - vals.push_back(std::make_unique( - std::make_unique(&f32, 1.0f))); - vals.push_back(std::make_unique( - std::make_unique(&f32, 1.0f))); - vals.push_back(std::make_unique( - std::make_unique(&f32, 3.0f))); - - ast::TypeConstructorExpression t(&vec, std::move(vals)); - - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); - EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); - - Builder b(&mod); - EXPECT_EQ(b.GenerateConstructorExpression(&t, true), 5u); - EXPECT_EQ(b.GenerateConstructorExpression(&t, true), 5u); - ASSERT_FALSE(b.has_error()) << b.error(); -} - -TEST_F(BuilderTest, Constructor_NonConst_Type_Fails) { +TEST_F(BuilderTest, Constructor_Type_NonConst_Value_Fails) { ast::type::F32Type f32; ast::type::VectorType vec(&f32, 2); auto rel = std::make_unique( @@ -247,6 +195,1607 @@ TEST_F(BuilderTest, Constructor_NonConst_Type_Fails) { EXPECT_EQ(b.error(), R"(constructor must be a constant expression)"); } +TEST_F(BuilderTest, Constructor_Type_Bool_With_Bool) { + ast::type::BoolType bool_type; + + ast::ExpressionList vals; + vals.push_back(std::make_unique( + std::make_unique(&bool_type, true))); + + ast::TypeConstructorExpression t(&bool_type, std::move(vals)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&t)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + + EXPECT_EQ(b.GenerateExpression(&t), 1u); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool +%3 = OpConstantTrue %2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpCopyObject %2 %3 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_I32_With_I32) { + ast::type::I32Type i32; + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&i32, 2))); + + ast::TypeConstructorExpression cast(&i32, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%3 = OpConstant %2 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpCopyObject %2 %3 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_U32_With_U32) { + ast::type::U32Type u32; + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&u32, 2))); + + ast::TypeConstructorExpression cast(&u32, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0 +%3 = OpConstant %2 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpCopyObject %2 %3 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_F32_With_F32) { + ast::type::F32Type f32; + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::TypeConstructorExpression cast(&f32, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%3 = OpConstant %2 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpCopyObject %2 %3 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Vec2_With_F32_F32) { + ast::type::F32Type f32; + ast::type::VectorType vec(&f32, 2); + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::TypeConstructorExpression cast(&vec, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 4u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 2 +%3 = OpConstant %2 2 +%4 = OpConstantComposite %1 %3 %3 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Vec3_With_F32_F32_F32) { + ast::type::F32Type f32; + ast::type::VectorType vec(&f32, 3); + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::TypeConstructorExpression cast(&vec, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 4u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 3 +%3 = OpConstant %2 2 +%4 = OpConstantComposite %1 %3 %3 %3 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Vec3_With_F32_Vec2) { + ast::type::F32Type f32; + ast::type::VectorType vec2(&f32, 2); + ast::type::VectorType vec3(&f32, 3); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + &vec2, std::move(vec_params))); + + ast::TypeConstructorExpression cast(&vec3, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 8u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 3 +%3 = OpConstant %2 2 +%4 = OpTypeVector %2 2 +%5 = OpConstantComposite %4 %3 %3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%6 = OpCompositeExtract %2 %5 0 +%7 = OpCompositeExtract %2 %5 1 +%8 = OpCompositeConstruct %1 %3 %6 %7 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Vec3_With_Vec2_F32) { + ast::type::F32Type f32; + ast::type::VectorType vec2(&f32, 2); + ast::type::VectorType vec3(&f32, 3); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + &vec2, std::move(vec_params))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::TypeConstructorExpression cast(&vec3, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 8u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 3 +%3 = OpTypeVector %2 2 +%4 = OpConstant %2 2 +%5 = OpConstantComposite %3 %4 %4 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%6 = OpCompositeExtract %2 %5 0 +%7 = OpCompositeExtract %2 %5 1 +%8 = OpCompositeConstruct %1 %6 %7 %4 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Vec4_With_F32_F32_F32_F32) { + ast::type::F32Type f32; + ast::type::VectorType vec(&f32, 4); + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::TypeConstructorExpression cast(&vec, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 4u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 4 +%3 = OpConstant %2 2 +%4 = OpConstantComposite %1 %3 %3 %3 %3 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Vec4_With_F32_F32_Vec2) { + ast::type::F32Type f32; + ast::type::VectorType vec2(&f32, 2); + ast::type::VectorType vec4(&f32, 4); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + &vec2, std::move(vec_params))); + + ast::TypeConstructorExpression cast(&vec4, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 8u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 4 +%3 = OpConstant %2 2 +%4 = OpTypeVector %2 2 +%5 = OpConstantComposite %4 %3 %3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%6 = OpCompositeExtract %2 %5 0 +%7 = OpCompositeExtract %2 %5 1 +%8 = OpCompositeConstruct %1 %3 %3 %6 %7 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Vec4_With_F32_Vec2_F32) { + ast::type::F32Type f32; + ast::type::VectorType vec2(&f32, 2); + ast::type::VectorType vec4(&f32, 4); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + &vec2, std::move(vec_params))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::TypeConstructorExpression cast(&vec4, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 8u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 4 +%3 = OpConstant %2 2 +%4 = OpTypeVector %2 2 +%5 = OpConstantComposite %4 %3 %3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%6 = OpCompositeExtract %2 %5 0 +%7 = OpCompositeExtract %2 %5 1 +%8 = OpCompositeConstruct %1 %3 %6 %7 %3 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Vec4_With_Vec2_F32_F32) { + ast::type::F32Type f32; + ast::type::VectorType vec2(&f32, 2); + ast::type::VectorType vec4(&f32, 4); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + &vec2, std::move(vec_params))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::TypeConstructorExpression cast(&vec4, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 8u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 4 +%3 = OpTypeVector %2 2 +%4 = OpConstant %2 2 +%5 = OpConstantComposite %3 %4 %4 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%6 = OpCompositeExtract %2 %5 0 +%7 = OpCompositeExtract %2 %5 1 +%8 = OpCompositeConstruct %1 %6 %7 %4 %4 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Vec4_With_Vec2_Vec2) { + ast::type::F32Type f32; + ast::type::VectorType vec2(&f32, 2); + ast::type::VectorType vec4(&f32, 4); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList vec2_params; + vec2_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec2_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + &vec2, std::move(vec_params))); + params.push_back(std::make_unique( + &vec2, std::move(vec2_params))); + + ast::TypeConstructorExpression cast(&vec4, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 10u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 4 +%3 = OpTypeVector %2 2 +%4 = OpConstant %2 2 +%5 = OpConstantComposite %3 %4 %4 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%6 = OpCompositeExtract %2 %5 0 +%7 = OpCompositeExtract %2 %5 1 +%8 = OpCompositeExtract %2 %5 0 +%9 = OpCompositeExtract %2 %5 1 +%10 = OpCompositeConstruct %1 %6 %7 %8 %9 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Vec4_With_F32_Vec3) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + ast::type::VectorType vec4(&f32, 4); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + &vec3, std::move(vec_params))); + + ast::TypeConstructorExpression cast(&vec4, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 9u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 4 +%3 = OpConstant %2 2 +%4 = OpTypeVector %2 3 +%5 = OpConstantComposite %4 %3 %3 %3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%6 = OpCompositeExtract %2 %5 0 +%7 = OpCompositeExtract %2 %5 1 +%8 = OpCompositeExtract %2 %5 2 +%9 = OpCompositeConstruct %1 %3 %6 %7 %8 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Vec4_With_Vec3_F32) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + ast::type::VectorType vec4(&f32, 4); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + &vec3, std::move(vec_params))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::TypeConstructorExpression cast(&vec4, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 9u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 4 +%3 = OpTypeVector %2 3 +%4 = OpConstant %2 2 +%5 = OpConstantComposite %3 %4 %4 %4 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%6 = OpCompositeExtract %2 %5 0 +%7 = OpCompositeExtract %2 %5 1 +%8 = OpCompositeExtract %2 %5 2 +%9 = OpCompositeConstruct %1 %6 %7 %8 %4 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec3_With_F32_Vec2) { + ast::type::F32Type f32; + ast::type::VectorType vec2(&f32, 2); + ast::type::VectorType vec3(&f32, 3); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + &vec2, std::move(vec_params))); + + ast::TypeConstructorExpression cast(&vec3, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 3 +%3 = OpConstant %2 2 +%4 = OpTypeVector %2 2 +%5 = OpConstantComposite %4 %3 %3 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 0 +%6 = OpSpecConstantOp %2 CompositeExtract %5 8 +%10 = OpConstant %7 1 +%9 = OpSpecConstantOp %2 CompositeExtract %5 10 +%11 = OpSpecConstantComposite %1 %3 %6 %9 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec3_With_Vec2_F32) { + ast::type::F32Type f32; + ast::type::VectorType vec2(&f32, 2); + ast::type::VectorType vec3(&f32, 3); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + &vec2, std::move(vec_params))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::TypeConstructorExpression cast(&vec3, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 3 +%3 = OpTypeVector %2 2 +%4 = OpConstant %2 2 +%5 = OpConstantComposite %3 %4 %4 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 0 +%6 = OpSpecConstantOp %2 CompositeExtract %5 8 +%10 = OpConstant %7 1 +%9 = OpSpecConstantOp %2 CompositeExtract %5 10 +%11 = OpSpecConstantComposite %1 %6 %9 %4 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec4_With_F32_F32_Vec2) { + ast::type::F32Type f32; + ast::type::VectorType vec2(&f32, 2); + ast::type::VectorType vec4(&f32, 4); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + &vec2, std::move(vec_params))); + + ast::TypeConstructorExpression cast(&vec4, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 4 +%3 = OpConstant %2 2 +%4 = OpTypeVector %2 2 +%5 = OpConstantComposite %4 %3 %3 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 0 +%6 = OpSpecConstantOp %2 CompositeExtract %5 8 +%10 = OpConstant %7 1 +%9 = OpSpecConstantOp %2 CompositeExtract %5 10 +%11 = OpSpecConstantComposite %1 %3 %3 %6 %9 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec4_With_F32_Vec2_F32) { + ast::type::F32Type f32; + ast::type::VectorType vec2(&f32, 2); + ast::type::VectorType vec4(&f32, 4); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + &vec2, std::move(vec_params))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::TypeConstructorExpression cast(&vec4, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 4 +%3 = OpConstant %2 2 +%4 = OpTypeVector %2 2 +%5 = OpConstantComposite %4 %3 %3 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 0 +%6 = OpSpecConstantOp %2 CompositeExtract %5 8 +%10 = OpConstant %7 1 +%9 = OpSpecConstantOp %2 CompositeExtract %5 10 +%11 = OpSpecConstantComposite %1 %3 %6 %9 %3 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec4_With_Vec2_F32_F32) { + ast::type::F32Type f32; + ast::type::VectorType vec2(&f32, 2); + ast::type::VectorType vec4(&f32, 4); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + &vec2, std::move(vec_params))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::TypeConstructorExpression cast(&vec4, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 4 +%3 = OpTypeVector %2 2 +%4 = OpConstant %2 2 +%5 = OpConstantComposite %3 %4 %4 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 0 +%6 = OpSpecConstantOp %2 CompositeExtract %5 8 +%10 = OpConstant %7 1 +%9 = OpSpecConstantOp %2 CompositeExtract %5 10 +%11 = OpSpecConstantComposite %1 %6 %9 %4 %4 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec4_With_Vec2_Vec2) { + ast::type::F32Type f32; + ast::type::VectorType vec2(&f32, 2); + ast::type::VectorType vec4(&f32, 4); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList vec2_params; + vec2_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec2_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + &vec2, std::move(vec_params))); + params.push_back(std::make_unique( + &vec2, std::move(vec2_params))); + + ast::TypeConstructorExpression cast(&vec4, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 13u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 4 +%3 = OpTypeVector %2 2 +%4 = OpConstant %2 2 +%5 = OpConstantComposite %3 %4 %4 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 0 +%6 = OpSpecConstantOp %2 CompositeExtract %5 8 +%10 = OpConstant %7 1 +%9 = OpSpecConstantOp %2 CompositeExtract %5 10 +%11 = OpSpecConstantOp %2 CompositeExtract %5 8 +%12 = OpSpecConstantOp %2 CompositeExtract %5 10 +%13 = OpSpecConstantComposite %1 %6 %9 %11 %12 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec4_With_F32_Vec3) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + ast::type::VectorType vec4(&f32, 4); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + params.push_back(std::make_unique( + &vec3, std::move(vec_params))); + + ast::TypeConstructorExpression cast(&vec4, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 13u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 4 +%3 = OpConstant %2 2 +%4 = OpTypeVector %2 3 +%5 = OpConstantComposite %4 %3 %3 %3 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 0 +%6 = OpSpecConstantOp %2 CompositeExtract %5 8 +%10 = OpConstant %7 1 +%9 = OpSpecConstantOp %2 CompositeExtract %5 10 +%12 = OpConstant %7 2 +%11 = OpSpecConstantOp %2 CompositeExtract %5 12 +%13 = OpSpecConstantComposite %1 %3 %6 %9 %11 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ModuleScope_Vec4_With_Vec3_F32) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + ast::type::VectorType vec4(&f32, 4); + + ast::ExpressionList vec_params; + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + vec_params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::ExpressionList params; + params.push_back(std::make_unique( + &vec3, std::move(vec_params))); + params.push_back(std::make_unique( + std::make_unique(&f32, 2.0))); + + ast::TypeConstructorExpression cast(&vec4, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 13u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 4 +%3 = OpTypeVector %2 3 +%4 = OpConstant %2 2 +%5 = OpConstantComposite %3 %4 %4 %4 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 0 +%6 = OpSpecConstantOp %2 CompositeExtract %5 8 +%10 = OpConstant %7 1 +%9 = OpSpecConstantOp %2 CompositeExtract %5 10 +%12 = OpConstant %7 2 +%11 = OpSpecConstantOp %2 CompositeExtract %5 12 +%13 = OpSpecConstantComposite %1 %6 %9 %11 %4 +)"); +} + +TEST_F(BuilderTest, DISABLED_Constructor_Type_Mat2x2_With_Vec2_Vec2) { + FAIL(); +} + +TEST_F(BuilderTest, DISABLED_Constructor_Type_Mat3x2_With_Vec2_Vec2_Vec2) { + FAIL(); +} + +TEST_F(BuilderTest, DISABLED_Constructor_Type_Mat4x2_With_Vec2_Vec2_Vec2_Vec2) { + FAIL(); +} + +TEST_F(BuilderTest, DISABLED_Constructor_Type_Mat2x3_With_Vec3_Vec3) { + FAIL(); +} + +TEST_F(BuilderTest, DISABLED_Constructor_Type_Mat3x3_With_Vec3_Vec3_Vec3) { + FAIL(); +} + +TEST_F(BuilderTest, DISABLED_Constructor_Type_Mat4x3_With_Vec3_Vec3_Vec3_Vec3) { + FAIL(); +} + +TEST_F(BuilderTest, DISABLED_Constructor_Type_Mat2x4_With_Vec4_Vec4) { + FAIL(); +} + +TEST_F(BuilderTest, DISABLED_Constructor_Type_Mat3x4_With_Vec4_Vec4_Vec4) { + FAIL(); +} + +TEST_F(BuilderTest, DISABLED_Constructor_Type_Mat4x4_With_Vec4_Vec4_Vec4_Vec4) { + FAIL(); +} + +TEST_F(BuilderTest, + DISABLED_Constructor_Type_ModuleScope_Mat2x2_With_Vec2_Vec2) { + FAIL(); +} + +TEST_F(BuilderTest, + DISABLED_Constructor_Type_ModuleScope_Mat3x2_With_Vec2_Vec2_Vec2) { + FAIL(); +} + +TEST_F(BuilderTest, + DISABLED_Constructor_Type_ModuleScope_Mat4x2_With_Vec2_Vec2_Vec2_Vec2) { + FAIL(); +} + +TEST_F(BuilderTest, + DISABLED_Constructor_Type_ModuleScope_Mat2x3_With_Vec3_Vec3) { + FAIL(); +} + +TEST_F(BuilderTest, + DISABLED_Constructor_Type_ModuleScope_Mat3x3_With_Vec3_Vec3_Vec3) { + FAIL(); +} + +TEST_F(BuilderTest, + DISABLED_Constructor_Type_ModuleScope_Mat4x3_With_Vec3_Vec3_Vec3_Vec3) { + FAIL(); +} + +TEST_F(BuilderTest, + DISABLED_Constructor_Type_ModuleScope_Mat2x4_With_Vec4_Vec4) { + FAIL(); +} + +TEST_F(BuilderTest, + DISABLED_Constructor_Type_ModuleScope_Mat3x4_With_Vec4_Vec4_Vec4) { + FAIL(); +} + +TEST_F(BuilderTest, + DISABLED_Constructor_Type_ModuleScope_Mat4x4_With_Vec4_Vec4_Vec4_Vec4) { + FAIL(); +} + +TEST_F(BuilderTest, DISABLED_Constructor_Type_Array_5_F32) { + FAIL(); +} + +TEST_F(BuilderTest, DISABLED_Constructor_Type_Array_5_Vec3) { + FAIL(); +} + +TEST_F(BuilderTest, DISABLED_Constructor_Type_ModuleScope_Array_5_Vec3) { + FAIL(); +} + +TEST_F(BuilderTest, DISABLED_Constructor_Type_Struct) { + FAIL(); +} + +TEST_F(BuilderTest, DISABLED_Constructor_Type_ModuleScope_Struct_With_Vec2) { + FAIL(); +} + +TEST_F(BuilderTest, Constructor_Type_ZeroInit_F32) { + ast::type::F32Type f32; + + ast::ExpressionList vals; + ast::TypeConstructorExpression t(&f32, std::move(vals)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + + EXPECT_EQ(b.GenerateExpression(&t), 2u); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32 +%2 = OpConstantNull %1 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ZeroInit_I32) { + ast::type::I32Type i32; + + ast::ExpressionList vals; + ast::TypeConstructorExpression t(&i32, std::move(vals)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + + EXPECT_EQ(b.GenerateExpression(&t), 2u); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1 +%2 = OpConstantNull %1 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ZeroInit_U32) { + ast::type::U32Type u32; + + ast::ExpressionList vals; + ast::TypeConstructorExpression t(&u32, std::move(vals)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + + EXPECT_EQ(b.GenerateExpression(&t), 2u); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0 +%2 = OpConstantNull %1 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ZeroInit_Bool) { + ast::type::BoolType bool_type; + + ast::ExpressionList vals; + ast::TypeConstructorExpression t(&bool_type, std::move(vals)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + + EXPECT_EQ(b.GenerateExpression(&t), 2u); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool +%2 = OpConstantNull %1 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ZeroInit_Vector) { + ast::type::I32Type i32; + ast::type::VectorType vec(&i32, 2); + + ast::ExpressionList vals; + ast::TypeConstructorExpression t(&vec, std::move(vals)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + + EXPECT_EQ(b.GenerateExpression(&t), 3u); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%1 = OpTypeVector %2 2 +%3 = OpConstantNull %1 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ZeroInit_Matrix) { + ast::type::F32Type f32; + ast::type::MatrixType mat(&f32, 2, 4); + + ast::ExpressionList vals; + ast::TypeConstructorExpression t(&mat, std::move(vals)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + + EXPECT_EQ(b.GenerateExpression(&t), 4u); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32 +%2 = OpTypeVector %3 2 +%1 = OpTypeMatrix %2 4 +%4 = OpConstantNull %1 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ZeroInit_Array) { + ast::type::I32Type i32; + ast::type::ArrayType ary(&i32, 2); + + ast::ExpressionList vals; + ast::TypeConstructorExpression t(&ary, std::move(vals)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + + EXPECT_EQ(b.GenerateExpression(&t), 5u); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%3 = OpTypeInt 32 0 +%4 = OpConstant %3 2 +%1 = OpTypeArray %2 %4 +%5 = OpConstantNull %1 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_ZeroInit_Struct) { + ast::type::F32Type f32; + + ast::StructMemberDecorationList decos; + ast::StructMemberList members; + members.push_back( + std::make_unique("a", &f32, std::move(decos))); + + auto s = std::make_unique(ast::StructDecoration::kNone, + std::move(members)); + ast::type::StructType s_type(std::move(s)); + s_type.set_name("my_struct"); + + ast::ExpressionList vals; + ast::TypeConstructorExpression t(&s_type, std::move(vals)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + + EXPECT_EQ(b.GenerateExpression(&t), 3u); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeStruct %2 +%3 = OpConstantNull %1 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Convert_U32_To_I32) { + ast::type::U32Type u32; + ast::type::I32Type i32; + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&u32, 2))); + + ast::TypeConstructorExpression cast(&i32, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%3 = OpTypeInt 32 0 +%4 = OpConstant %3 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpBitcast %2 %4 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Convert_I32_To_U32) { + ast::type::U32Type u32; + ast::type::I32Type i32; + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&i32, 2))); + + ast::TypeConstructorExpression cast(&u32, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0 +%3 = OpTypeInt 32 1 +%4 = OpConstant %3 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpBitcast %2 %4 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Convert_F32_To_I32) { + ast::type::I32Type i32; + ast::type::F32Type f32; + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.4))); + + ast::TypeConstructorExpression cast(&i32, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%3 = OpTypeFloat 32 +%4 = OpConstant %3 2.4000001 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpConvertFToS %2 %4 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Convert_F32_To_U32) { + ast::type::U32Type u32; + ast::type::F32Type f32; + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 2.4))); + + ast::TypeConstructorExpression cast(&u32, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0 +%3 = OpTypeFloat 32 +%4 = OpConstant %3 2.4000001 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpConvertFToU %2 %4 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Convert_I32_To_F32) { + ast::type::I32Type i32; + ast::type::F32Type f32; + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&i32, 2))); + + ast::TypeConstructorExpression cast(&f32, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%3 = OpTypeInt 32 1 +%4 = OpConstant %3 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpConvertSToF %2 %4 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Convert_U32_To_F32) { + ast::type::U32Type u32; + ast::type::F32Type f32; + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&u32, 2))); + + ast::TypeConstructorExpression cast(&f32, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%3 = OpTypeInt 32 0 +%4 = OpConstant %3 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpConvertUToF %2 %4 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Convert_Vectors_U32_to_I32) { + ast::type::U32Type u32; + ast::type::VectorType uvec3(&u32, 3); + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + + auto var = + std::make_unique("i", ast::StorageClass::kPrivate, &uvec3); + + ast::ExpressionList params; + params.push_back(std::make_unique("i")); + + ast::TypeConstructorExpression cast(&ivec3, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateExpression(&cast), 6u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Private %3 +%5 = OpConstantNull %3 +%1 = OpVariable %2 Private %5 +%8 = OpTypeInt 32 1 +%7 = OpTypeVector %8 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%9 = OpLoad %3 %1 +%6 = OpBitcast %7 %9 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Convert_Vectors_F32_to_I32) { + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + ast::type::F32Type f32; + ast::type::VectorType fvec3(&f32, 3); + + auto var = + std::make_unique("i", ast::StorageClass::kPrivate, &fvec3); + + ast::ExpressionList params; + params.push_back(std::make_unique("i")); + + ast::TypeConstructorExpression cast(&ivec3, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateExpression(&cast), 6u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Private %3 +%5 = OpConstantNull %3 +%1 = OpVariable %2 Private %5 +%8 = OpTypeInt 32 1 +%7 = OpTypeVector %8 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%9 = OpLoad %3 %1 +%6 = OpConvertFToS %7 %9 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Convert_Vectors_I32_to_U32) { + ast::type::U32Type u32; + ast::type::VectorType uvec3(&u32, 3); + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + + auto var = + std::make_unique("i", ast::StorageClass::kPrivate, &ivec3); + + ast::ExpressionList params; + params.push_back(std::make_unique("i")); + + ast::TypeConstructorExpression cast(&uvec3, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateExpression(&cast), 6u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Private %3 +%5 = OpConstantNull %3 +%1 = OpVariable %2 Private %5 +%8 = OpTypeInt 32 0 +%7 = OpTypeVector %8 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%9 = OpLoad %3 %1 +%6 = OpBitcast %7 %9 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Convert_Vectors_F32_to_U32) { + ast::type::U32Type u32; + ast::type::VectorType uvec3(&u32, 3); + ast::type::F32Type f32; + ast::type::VectorType fvec3(&f32, 3); + + auto var = + std::make_unique("i", ast::StorageClass::kPrivate, &fvec3); + + ast::ExpressionList params; + params.push_back(std::make_unique("i")); + + ast::TypeConstructorExpression cast(&uvec3, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateExpression(&cast), 6u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Private %3 +%5 = OpConstantNull %3 +%1 = OpVariable %2 Private %5 +%8 = OpTypeInt 32 0 +%7 = OpTypeVector %8 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%9 = OpLoad %3 %1 +%6 = OpConvertFToU %7 %9 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Convert_Vectors_I32_to_F32) { + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + ast::type::F32Type f32; + ast::type::VectorType fvec3(&f32, 3); + + auto var = + std::make_unique("i", ast::StorageClass::kPrivate, &ivec3); + + ast::ExpressionList params; + params.push_back(std::make_unique("i")); + + ast::TypeConstructorExpression cast(&fvec3, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateExpression(&cast), 6u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Private %3 +%5 = OpConstantNull %3 +%1 = OpVariable %2 Private %5 +%8 = OpTypeFloat 32 +%7 = OpTypeVector %8 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%9 = OpLoad %3 %1 +%6 = OpConvertSToF %7 %9 +)"); +} + +TEST_F(BuilderTest, Constructor_Type_Convert_Vectors_U32_to_F32) { + ast::type::U32Type u32; + ast::type::VectorType uvec3(&u32, 3); + ast::type::F32Type f32; + ast::type::VectorType fvec3(&f32, 3); + + auto var = + std::make_unique("i", ast::StorageClass::kPrivate, &uvec3); + + ast::ExpressionList params; + params.push_back(std::make_unique("i")); + + ast::TypeConstructorExpression cast(&fvec3, std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateExpression(&cast), 6u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Private %3 +%5 = OpConstantNull %3 +%1 = OpVariable %2 Private %5 +%8 = OpTypeFloat 32 +%7 = OpTypeVector %8 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%9 = OpLoad %3 %1 +%6 = OpConvertUToF %7 %9 +)"); +} + } // namespace } // namespace spirv } // namespace writer diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index 4389ae1c93..e13fe6b179 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -29,7 +29,6 @@ #include "src/ast/call_expression.h" #include "src/ast/call_statement.h" #include "src/ast/case_statement.h" -#include "src/ast/cast_expression.h" #include "src/ast/constructor_expression.h" #include "src/ast/continue_statement.h" #include "src/ast/decorated_variable.h" @@ -186,9 +185,6 @@ bool GeneratorImpl::EmitExpression(ast::Expression* expr) { if (expr->IsCall()) { return EmitCall(expr->AsCall()); } - if (expr->IsCast()) { - return EmitCast(expr->AsCast()); - } if (expr->IsIdentifier()) { return EmitIdentifier(expr->AsIdentifier()); } @@ -269,21 +265,6 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) { return true; } -bool GeneratorImpl::EmitCast(ast::CastExpression* expr) { - out_ << "cast<"; - if (!EmitType(expr->type())) { - return false; - } - - out_ << ">("; - if (!EmitExpression(expr->expr())) { - return false; - } - - out_ << ")"; - return true; -} - bool GeneratorImpl::EmitConstructor(ast::ConstructorExpression* expr) { if (expr->IsScalarConstructor()) { return EmitScalarConstructor(expr->AsScalarConstructor()); diff --git a/src/writer/wgsl/generator_impl.h b/src/writer/wgsl/generator_impl.h index 5805f09260..0a9e6f4056 100644 --- a/src/writer/wgsl/generator_impl.h +++ b/src/writer/wgsl/generator_impl.h @@ -98,10 +98,6 @@ class GeneratorImpl : public TextGenerator { /// @param stmt the statement /// @returns true if the statment was emitted successfully bool EmitCase(ast::CaseStatement* stmt); - /// Handles generating a cast expression - /// @param expr the cast expression - /// @returns true if the cast was emitted - bool EmitCast(ast::CastExpression* expr); /// Handles generating a scalar constructor /// @param expr the scalar constructor expression /// @returns true if the scalar constructor is emitted diff --git a/src/writer/wgsl/generator_impl_cast_test.cc b/src/writer/wgsl/generator_impl_cast_test.cc index e388337fe9..ee81a66af3 100644 --- a/src/writer/wgsl/generator_impl_cast_test.cc +++ b/src/writer/wgsl/generator_impl_cast_test.cc @@ -15,9 +15,9 @@ #include #include "gtest/gtest.h" -#include "src/ast/cast_expression.h" #include "src/ast/identifier_expression.h" #include "src/ast/type/f32_type.h" +#include "src/ast/type_constructor_expression.h" #include "src/writer/wgsl/generator_impl.h" namespace tint { @@ -29,12 +29,15 @@ using WgslGeneratorImplTest = testing::Test; TEST_F(WgslGeneratorImplTest, EmitExpression_Cast) { ast::type::F32Type f32; - auto id = std::make_unique("id"); - ast::CastExpression cast(&f32, std::move(id)); + + ast::ExpressionList params; + params.push_back(std::make_unique("id")); + + ast::TypeConstructorExpression cast(&f32, std::move(params)); GeneratorImpl g; ASSERT_TRUE(g.EmitExpression(&cast)) << g.error(); - EXPECT_EQ(g.result(), "cast(id)"); + EXPECT_EQ(g.result(), "f32(id)"); } } // namespace