diff --git a/BUILD.gn b/BUILD.gn index 42f91ecfe1..0eaa74c3be 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -800,6 +800,7 @@ source_set("tint_unittests_spv_writer_src") { "src/writer/spirv/builder_as_expression_test.cc", "src/writer/spirv/builder_assign_test.cc", "src/writer/spirv/builder_binary_expression_test.cc", + "src/writer/spirv/builder_block_test.cc", "src/writer/spirv/builder_call_test.cc", "src/writer/spirv/builder_cast_expression_test.cc", "src/writer/spirv/builder_constructor_expression_test.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5181674ee1..efa3aeb146 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -454,6 +454,7 @@ if(${TINT_BUILD_SPV_WRITER}) writer/spirv/builder_as_expression_test.cc writer/spirv/builder_assign_test.cc writer/spirv/builder_binary_expression_test.cc + writer/spirv/builder_block_test.cc writer/spirv/builder_call_test.cc writer/spirv/builder_cast_expression_test.cc writer/spirv/builder_constructor_expression_test.cc diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 0de445d265..fbbc5d5e41 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -22,6 +22,7 @@ #include "src/ast/as_expression.h" #include "src/ast/assignment_statement.h" #include "src/ast/binary_expression.h" +#include "src/ast/block_statement.h" #include "src/ast/break_statement.h" #include "src/ast/call_expression.h" #include "src/ast/call_statement.h" @@ -239,6 +240,19 @@ bool TypeDeterminer::DetermineFunction(ast::Function* func) { return true; } +bool TypeDeterminer::DetermineStatements(const ast::BlockStatement* stmts) { + for (const auto& stmt : *stmts) { + if (!DetermineVariableStorageClass(stmt.get())) { + return false; + } + + if (!DetermineResultType(stmt.get())) { + return false; + } + } + return true; +} + bool TypeDeterminer::DetermineStatements(const ast::StatementList& stmts) { for (const auto& stmt : stmts) { if (!DetermineVariableStorageClass(stmt.get())) { @@ -282,6 +296,9 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) { auto* a = stmt->AsAssign(); return DetermineResultType(a->lhs()) && DetermineResultType(a->rhs()); } + if (stmt->IsBlock()) { + return DetermineStatements(stmt->AsBlock()); + } if (stmt->IsBreak()) { return true; } @@ -347,7 +364,8 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) { return DetermineResultType(v->variable()->constructor()); } - set_error(stmt->source(), "unknown statement type for type determination"); + set_error(stmt->source(), + "unknown statement type for type determination: " + stmt->str()); return false; } diff --git a/src/type_determiner.h b/src/type_determiner.h index ea9b9fcff8..d7dc0d1c75 100644 --- a/src/type_determiner.h +++ b/src/type_determiner.h @@ -64,6 +64,10 @@ class TypeDeterminer { /// Determines type information for a set of statements /// @param stmts the statements to check /// @returns true if the determination was successful + bool DetermineStatements(const ast::BlockStatement* stmts); + /// Determines type information for a set of statements + /// @param stmts the statements to check + /// @returns true if the determination was successful bool DetermineStatements(const ast::StatementList& stmts); /// Determines type information for a statement /// @param stmt the statement to check diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index d238ab6f74..075eed5e5a 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -24,6 +24,7 @@ #include "src/ast/as_expression.h" #include "src/ast/assignment_statement.h" #include "src/ast/binary_expression.h" +#include "src/ast/block_statement.h" #include "src/ast/break_statement.h" #include "src/ast/call_expression.h" #include "src/ast/call_statement.h" @@ -63,7 +64,7 @@ namespace { class FakeStmt : public ast::Statement { public: bool IsValid() const override { return true; } - void to_str(std::ostream&, size_t) const override {} + void to_str(std::ostream& out, size_t) const override { out << "Fake"; } }; class FakeExpr : public ast::Expression { @@ -97,7 +98,8 @@ TEST_F(TypeDeterminerTest, Error_WithEmptySource) { s.set_source(Source{0, 0}); EXPECT_FALSE(td()->DetermineResultType(&s)); - EXPECT_EQ(td()->error(), "unknown statement type for type determination"); + EXPECT_EQ(td()->error(), + "unknown statement type for type determination: Fake"); } TEST_F(TypeDeterminerTest, Stmt_Error_Unknown) { @@ -106,7 +108,7 @@ TEST_F(TypeDeterminerTest, Stmt_Error_Unknown) { EXPECT_FALSE(td()->DetermineResultType(&s)); EXPECT_EQ(td()->error(), - "2:30: unknown statement type for type determination"); + "2:30: unknown statement type for type determination: Fake"); } TEST_F(TypeDeterminerTest, Stmt_Assign) { @@ -158,6 +160,29 @@ TEST_F(TypeDeterminerTest, Stmt_Case) { EXPECT_TRUE(rhs_ptr->result_type()->IsF32()); } +TEST_F(TypeDeterminerTest, Stmt_Block) { + ast::type::I32Type i32; + ast::type::F32Type f32; + + auto lhs = std::make_unique( + std::make_unique(&i32, 2)); + auto* lhs_ptr = lhs.get(); + + auto rhs = std::make_unique( + std::make_unique(&f32, 2.3f)); + auto* rhs_ptr = rhs.get(); + + ast::BlockStatement block; + block.append(std::make_unique(std::move(lhs), + std::move(rhs))); + + EXPECT_TRUE(td()->DetermineResultType(&block)); + ASSERT_NE(lhs_ptr->result_type(), nullptr); + ASSERT_NE(rhs_ptr->result_type(), nullptr); + EXPECT_TRUE(lhs_ptr->result_type()->IsI32()); + EXPECT_TRUE(rhs_ptr->result_type()->IsF32()); +} + TEST_F(TypeDeterminerTest, Stmt_Else) { ast::type::I32Type i32; ast::type::F32Type f32; diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 6f18c5c0dc..94914ce357 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -23,6 +23,7 @@ #include "src/ast/assignment_statement.h" #include "src/ast/binary_expression.h" #include "src/ast/binding_decoration.h" +#include "src/ast/block_statement.h" #include "src/ast/bool_literal.h" #include "src/ast/builtin_decoration.h" #include "src/ast/call_expression.h" @@ -1338,6 +1339,18 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { return result_id; } +bool Builder::GenerateBlockStatement(ast::BlockStatement* stmt) { + scope_stack_.push_scope(); + for (const auto& block_stmt : *stmt) { + if (!GenerateStatement(block_stmt.get())) { + return false; + } + } + scope_stack_.pop_scope(); + + return true; +} + uint32_t Builder::GenerateCallExpression(ast::CallExpression* expr) { if (!expr->func()->IsIdentifier()) { error_ = "invalid function name"; @@ -1807,6 +1820,9 @@ bool Builder::GenerateStatement(ast::Statement* stmt) { if (stmt->IsAssign()) { return GenerateAssignStatement(stmt->AsAssign()); } + if (stmt->IsBlock()) { + return GenerateBlockStatement(stmt->AsBlock()); + } if (stmt->IsBreak()) { return GenerateBreakStatement(stmt->AsBreak()); } @@ -1839,7 +1855,7 @@ bool Builder::GenerateStatement(ast::Statement* stmt) { return GenerateVariableDeclStatement(stmt->AsVariableDecl()); } - error_ = "Unknown statement"; + error_ = "Unknown statement: " + stmt->str(); return false; } diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index 50e63fabb3..1e62b3a91e 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -167,6 +167,10 @@ class Builder { /// @param assign the statement to generate /// @returns true if the statement was successfully generated bool GenerateAssignStatement(ast::AssignmentStatement* assign); + /// Generates a block statement + /// @param stmt the statement to generate + /// @returns true if the statement was successfully generated + bool GenerateBlockStatement(ast::BlockStatement* stmt); /// Generates a break statement /// @param stmt the statement to generate /// @returns true if the statement was successfully generated diff --git a/src/writer/spirv/builder_block_test.cc b/src/writer/spirv/builder_block_test.cc new file mode 100644 index 0000000000..b13ca1da0e --- /dev/null +++ b/src/writer/spirv/builder_block_test.cc @@ -0,0 +1,101 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gtest/gtest.h" +#include "src/ast/assignment_statement.h" +#include "src/ast/block_statement.h" +#include "src/ast/float_literal.h" +#include "src/ast/identifier_expression.h" +#include "src/ast/scalar_constructor_expression.h" +#include "src/ast/type/f32_type.h" +#include "src/ast/variable_decl_statement.h" +#include "src/context.h" +#include "src/type_determiner.h" +#include "src/writer/spirv/builder.h" +#include "src/writer/spirv/spv_dump.h" + +namespace tint { +namespace writer { +namespace spirv { +namespace { + +using BuilderTest = testing::Test; + +TEST_F(BuilderTest, Block) { + ast::type::F32Type f32; + + // Note, this test uses shadow variables which aren't allowed in WGSL but + // serves to prove the block code is pushing new scopes as needed. + ast::BlockStatement outer; + + outer.append(std::make_unique( + std::make_unique("var", ast::StorageClass::kFunction, + &f32))); + outer.append(std::make_unique( + std::make_unique("var"), + std::make_unique( + std::make_unique(&f32, 1.0f)))); + + auto inner = std::make_unique(); + inner->append(std::make_unique( + std::make_unique("var", ast::StorageClass::kFunction, + &f32))); + inner->append(std::make_unique( + std::make_unique("var"), + std::make_unique( + std::make_unique(&f32, 2.0f)))); + + outer.append(std::move(inner)); + outer.append(std::make_unique( + std::make_unique("var"), + std::make_unique( + std::make_unique(&f32, 3.0f)))); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&outer)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_TRUE(b.GenerateStatement(&outer)) << b.error(); + EXPECT_FALSE(b.has_error()); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32 +%2 = OpTypePointer Function %3 +%4 = OpConstantNull %3 +%5 = OpConstant %3 1 +%7 = OpConstant %3 2 +%8 = OpConstant %3 3 +)"); + + EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), + R"(%1 = OpVariable %2 Function %4 +%6 = OpVariable %2 Function %4 +)"); + + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(OpStore %1 %5 +OpStore %6 %7 +OpStore %1 %8 +)"); +} + +} // namespace +} // namespace spirv +} // namespace writer +} // namespace tint