From 916b40811144d66f89202ebe45e75072c795272c Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Mon, 1 Jun 2020 18:56:34 +0000 Subject: [PATCH] [spirv-writer] Add switch support This CL adds switch support to the SPIR-V writer. Bug: tint:5 Change-Id: I8a6ad40cb2d344c87abdf842194b60afb1b4c96e Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22165 Reviewed-by: David Neto --- BUILD.gn | 1 + src/CMakeLists.txt | 1 + src/ast/case_statement.cc | 3 + src/ast/case_statement.h | 4 + src/writer/spirv/binary_writer.cc | 1 - src/writer/spirv/builder.cc | 88 ++++- src/writer/spirv/builder.h | 4 + src/writer/spirv/builder_if_test.cc | 1 - src/writer/spirv/builder_switch_test.cc | 437 ++++++++++++++++++++++++ 9 files changed, 536 insertions(+), 4 deletions(-) create mode 100644 src/writer/spirv/builder_switch_test.cc diff --git a/BUILD.gn b/BUILD.gn index 73f3454c44..2adbf02b6e 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -692,6 +692,7 @@ source_set("tint_unittests_spv_writer_src") { "src/writer/spirv/builder_literal_test.cc", "src/writer/spirv/builder_loop_test.cc", "src/writer/spirv/builder_return_test.cc", + "src/writer/spirv/builder_switch_test.cc", "src/writer/spirv/builder_test.cc", "src/writer/spirv/builder_type_test.cc", "src/writer/spirv/builder_unary_op_expression_test.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2c4745a9c0..e731adb36d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -429,6 +429,7 @@ if(${TINT_BUILD_SPV_WRITER}) writer/spirv/builder_literal_test.cc writer/spirv/builder_loop_test.cc writer/spirv/builder_return_test.cc + writer/spirv/builder_switch_test.cc writer/spirv/builder_test.cc writer/spirv/builder_type_test.cc writer/spirv/builder_unary_op_expression_test.cc diff --git a/src/ast/case_statement.cc b/src/ast/case_statement.cc index 098921776d..7eb4878426 100644 --- a/src/ast/case_statement.cc +++ b/src/ast/case_statement.cc @@ -19,6 +19,9 @@ namespace ast { CaseStatement::CaseStatement() : Statement() {} +CaseStatement::CaseStatement(StatementList body) + : Statement(), body_(std::move(body)) {} + CaseStatement::CaseStatement(CaseSelectorList conditions, StatementList body) : Statement(), conditions_(std::move(conditions)), body_(std::move(body)) {} diff --git a/src/ast/case_statement.h b/src/ast/case_statement.h index 11dcd33989..b5b13d6b61 100644 --- a/src/ast/case_statement.h +++ b/src/ast/case_statement.h @@ -36,6 +36,10 @@ class CaseStatement : public Statement { /// Constructor CaseStatement(); /// Constructor + /// Creates a default case statement + /// @param body the case body + explicit CaseStatement(StatementList body); + /// Constructor /// @param conditions the case conditions /// @param body the case body CaseStatement(CaseSelectorList conditions, StatementList body); diff --git a/src/writer/spirv/binary_writer.cc b/src/writer/spirv/binary_writer.cc index 81d8fac2e3..294bae8fe5 100644 --- a/src/writer/spirv/binary_writer.cc +++ b/src/writer/spirv/binary_writer.cc @@ -52,7 +52,6 @@ void BinaryWriter::WriteHeader(uint32_t bound) { void BinaryWriter::process_instruction(const Instruction& inst) { out_.push_back(inst.word_length() << 16 | static_cast(inst.opcode())); - for (const auto& op : inst.operands()) { process_op(op); } diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 084a50423c..d633f9c0e9 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -25,6 +25,7 @@ #include "src/ast/bool_literal.h" #include "src/ast/builtin_decoration.h" #include "src/ast/call_expression.h" +#include "src/ast/case_statement.h" #include "src/ast/cast_expression.h" #include "src/ast/constructor_expression.h" #include "src/ast/decorated_variable.h" @@ -43,6 +44,7 @@ #include "src/ast/struct.h" #include "src/ast/struct_member.h" #include "src/ast/struct_member_offset_decoration.h" +#include "src/ast/switch_statement.h" #include "src/ast/type/array_type.h" #include "src/ast/type/matrix_type.h" #include "src/ast/type/pointer_type.h" @@ -235,7 +237,7 @@ bool Builder::GenerateAssignStatement(ast::AssignmentStatement* assign) { bool Builder::GenerateBreakStatement(ast::BreakStatement*) { if (merge_stack_.empty()) { - error_ = "Attempted to break with a merge block"; + error_ = "Attempted to break without a merge block"; return false; } push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_stack_.back())}); @@ -244,7 +246,7 @@ bool Builder::GenerateBreakStatement(ast::BreakStatement*) { bool Builder::GenerateContinueStatement(ast::ContinueStatement*) { if (continue_stack_.empty()) { - error_ = "Attempted to continue with a continue block"; + error_ = "Attempted to continue without a continue block"; return false; } push_function_inst(spv::Op::OpBranch, {Operand::Int(continue_stack_.back())}); @@ -1330,6 +1332,85 @@ bool Builder::GenerateIfStatement(ast::IfStatement* stmt) { return true; } +bool Builder::GenerateSwitchStatement(ast::SwitchStatement* stmt) { + auto merge_block = result_op(); + auto merge_block_id = merge_block.to_i(); + + merge_stack_.push_back(merge_block_id); + + auto cond_id = GenerateExpression(stmt->condition()); + if (cond_id == 0) { + return false; + } + cond_id = GenerateLoadIfNeeded(stmt->condition()->result_type(), cond_id); + + auto default_block = result_op(); + auto default_block_id = default_block.to_i(); + + std::vector params = {Operand::Int(cond_id), + Operand::Int(default_block_id)}; + + std::vector case_ids; + for (const auto& item : stmt->body()) { + if (item->IsDefault()) { + case_ids.push_back(default_block_id); + continue; + } + + auto block = result_op(); + auto block_id = block.to_i(); + + case_ids.push_back(block_id); + for (const auto& selector : item->conditions()) { + if (!selector->IsInt()) { + error_ = "expected integer literal for switch case label"; + return false; + } + + params.push_back(Operand::Int(selector->AsInt()->value())); + params.push_back(Operand::Int(block_id)); + } + } + + push_function_inst(spv::Op::OpSelectionMerge, + {Operand::Int(merge_block_id), + Operand::Int(SpvSelectionControlMaskNone)}); + push_function_inst(spv::Op::OpSwitch, params); + + bool generated_default = false; + auto& body = stmt->body(); + // We output the case statements in order they were entered in the original + // source. Each fallthrough goes to the next case entry, so is a forward + // branch, otherwise the branch is to the merge block which comes after + // the switch statement. + for (uint32_t i = 0; i < body.size(); i++) { + auto& item = body[i]; + + if (item->IsDefault()) { + generated_default = true; + } + + push_function_inst(spv::Op::OpLabel, {Operand::Int(case_ids[i])}); + if (!GenerateStatementList(item->body())) { + return false; + } + + if (!LastIsTerminator(item->body())) { + push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)}); + } + } + + if (!generated_default) { + push_function_inst(spv::Op::OpLabel, {Operand::Int(default_block_id)}); + push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)}); + } + + merge_stack_.pop_back(); + + push_function_inst(spv::Op::OpLabel, {Operand::Int(merge_block_id)}); + return true; +} + bool Builder::GenerateReturnStatement(ast::ReturnStatement* stmt) { if (stmt->has_value()) { auto val_id = GenerateExpression(stmt->value()); @@ -1422,6 +1503,9 @@ bool Builder::GenerateStatement(ast::Statement* stmt) { if (stmt->IsReturn()) { return GenerateReturnStatement(stmt->AsReturn()); } + if (stmt->IsSwitch()) { + return GenerateSwitchStatement(stmt->AsSwitch()); + } if (stmt->IsVariableDecl()) { return GenerateVariableDeclStatement(stmt->AsVariableDecl()); } diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index a8433e65fa..2a6fd65eac 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -298,6 +298,10 @@ class Builder { /// @param stmt the statement to generate /// @returns true on success, false otherwise bool GenerateReturnStatement(ast::ReturnStatement* stmt); + /// Generates a switch statement + /// @param stmt the statement to generate + /// @returns ture on success, false otherwise + bool GenerateSwitchStatement(ast::SwitchStatement* stmt); /// Generates a conditional section merge block /// @param cond the condition /// @param true_body the statements making up the true block diff --git a/src/writer/spirv/builder_if_test.cc b/src/writer/spirv/builder_if_test.cc index ef123bb690..e7109ab07d 100644 --- a/src/writer/spirv/builder_if_test.cc +++ b/src/writer/spirv/builder_if_test.cc @@ -97,7 +97,6 @@ TEST_F(BuilderTest, If_WithStatements) { ast::Module mod; TypeDeterminer td(&ctx, &mod); td.RegisterVariableForTesting(var.get()); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); Builder b(&mod); diff --git a/src/writer/spirv/builder_switch_test.cc b/src/writer/spirv/builder_switch_test.cc new file mode 100644 index 0000000000..1284345765 --- /dev/null +++ b/src/writer/spirv/builder_switch_test.cc @@ -0,0 +1,437 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gtest/gtest.h" +#include "src/ast/assignment_statement.h" +#include "src/ast/bool_literal.h" +#include "src/ast/break_statement.h" +#include "src/ast/case_statement.h" +#include "src/ast/identifier_expression.h" +#include "src/ast/if_statement.h" +#include "src/ast/int_literal.h" +#include "src/ast/scalar_constructor_expression.h" +#include "src/ast/switch_statement.h" +#include "src/ast/type/bool_type.h" +#include "src/ast/type/i32_type.h" +#include "src/context.h" +#include "src/type_determiner.h" +#include "src/writer/spirv/builder.h" +#include "src/writer/spirv/spv_dump.h" + +namespace tint { +namespace writer { +namespace spirv { +namespace { + +using BuilderTest = testing::Test; + +TEST_F(BuilderTest, Switch_Empty) { + ast::type::I32Type i32; + + // switch (1) { + // } + auto cond = std::make_unique( + std::make_unique(&i32, 1)); + + ast::SwitchStatement expr(std::move(cond), ast::CaseStatementList{}); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + + EXPECT_TRUE(b.GenerateSwitchStatement(&expr)) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%3 = OpConstant %2 1 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(OpSelectionMerge %1 None +OpSwitch %3 %4 +%4 = OpLabel +OpBranch %1 +%1 = OpLabel +)"); +} + +TEST_F(BuilderTest, Switch_WithCase) { + ast::type::I32Type i32; + + // switch(a) { + // case 1: + // v = 1; + // case 2: + // v = 2; + // } + + auto v = + std::make_unique("v", ast::StorageClass::kPrivate, &i32); + auto a = + std::make_unique("a", ast::StorageClass::kPrivate, &i32); + + ast::StatementList case_1_body; + case_1_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 1)))); + + ast::StatementList case_2_body; + case_2_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 2)))); + + ast::CaseSelectorList selector_1; + selector_1.push_back(std::make_unique(&i32, 1)); + + ast::CaseSelectorList selector_2; + selector_2.push_back(std::make_unique(&i32, 2)); + + ast::CaseStatementList cases; + cases.push_back( + std::make_unique(std::move(selector_1), std::move(case_1_body))); + cases.push_back(std::make_unique( + std::move(selector_2), std::move(case_2_body))); + + ast::SwitchStatement expr(std::make_unique("a"), + std::move(cases)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(v.get()); + td.RegisterVariableForTesting(a.get()); + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + ast::Function func("a_func", {}, &i32); + + Builder b(&mod); + ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error(); + ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error(); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + + EXPECT_TRUE(b.GenerateSwitchStatement(&expr)) << b.error(); + + EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v" +OpName %5 "a" +OpName %7 "a_func" +%3 = OpTypeInt 32 1 +%2 = OpTypePointer Private %3 +%4 = OpConstantNull %3 +%1 = OpVariable %2 Private %4 +%5 = OpVariable %2 Private %4 +%6 = OpTypeFunction %3 +%14 = OpConstant %3 1 +%15 = OpConstant %3 2 +%7 = OpFunction %3 None %6 +%8 = OpLabel +%10 = OpLoad %3 %5 +OpSelectionMerge %9 None +OpSwitch %10 %11 1 %12 2 %13 +%12 = OpLabel +OpStore %1 %14 +OpBranch %9 +%13 = OpLabel +OpStore %1 %15 +OpBranch %9 +%11 = OpLabel +OpBranch %9 +%9 = OpLabel +OpFunctionEnd +)"); +} + +TEST_F(BuilderTest, Switch_WithDefault) { + ast::type::I32Type i32; + + // switch(true) { + // default: + // v = 1; + // } + + auto v = + std::make_unique("v", ast::StorageClass::kPrivate, &i32); + auto a = + std::make_unique("a", ast::StorageClass::kPrivate, &i32); + + ast::StatementList default_body; + default_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 1)))); + + ast::CaseStatementList cases; + cases.push_back( + std::make_unique(std::move(default_body))); + + ast::SwitchStatement expr(std::make_unique("a"), + std::move(cases)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(v.get()); + td.RegisterVariableForTesting(a.get()); + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + ast::Function func("a_func", {}, &i32); + + Builder b(&mod); + ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error(); + ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error(); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + + EXPECT_TRUE(b.GenerateSwitchStatement(&expr)) << b.error(); + + EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v" +OpName %5 "a" +OpName %7 "a_func" +%3 = OpTypeInt 32 1 +%2 = OpTypePointer Private %3 +%4 = OpConstantNull %3 +%1 = OpVariable %2 Private %4 +%5 = OpVariable %2 Private %4 +%6 = OpTypeFunction %3 +%12 = OpConstant %3 1 +%7 = OpFunction %3 None %6 +%8 = OpLabel +%10 = OpLoad %3 %5 +OpSelectionMerge %9 None +OpSwitch %10 %11 +%11 = OpLabel +OpStore %1 %12 +OpBranch %9 +%9 = OpLabel +OpFunctionEnd +)"); +} + +TEST_F(BuilderTest, Switch_WithCaseAndDefault) { + ast::type::I32Type i32; + + // switch(a) { + // case 1: + // v = 1; + // case 2, 3: + // v = 2; + // default: + // v = 3; + // } + + auto v = + std::make_unique("v", ast::StorageClass::kPrivate, &i32); + auto a = + std::make_unique("a", ast::StorageClass::kPrivate, &i32); + + ast::StatementList case_1_body; + case_1_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 1)))); + + ast::StatementList case_2_body; + case_2_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 2)))); + + ast::StatementList default_body; + default_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 3)))); + + ast::CaseSelectorList selector_1; + selector_1.push_back(std::make_unique(&i32, 1)); + + ast::CaseSelectorList selector_2; + selector_2.push_back(std::make_unique(&i32, 2)); + selector_2.push_back(std::make_unique(&i32, 3)); + + ast::CaseStatementList cases; + cases.push_back(std::make_unique( + std::move(selector_1), std::move(case_1_body))); + cases.push_back(std::make_unique( + std::move(selector_2), std::move(case_2_body))); + cases.push_back( + std::make_unique(std::move(default_body))); + + ast::SwitchStatement expr(std::make_unique("a"), + std::move(cases)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(v.get()); + td.RegisterVariableForTesting(a.get()); + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + ast::Function func("a_func", {}, &i32); + + Builder b(&mod); + ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error(); + ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error(); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + + EXPECT_TRUE(b.GenerateSwitchStatement(&expr)) << b.error(); + + EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v" +OpName %5 "a" +OpName %7 "a_func" +%3 = OpTypeInt 32 1 +%2 = OpTypePointer Private %3 +%4 = OpConstantNull %3 +%1 = OpVariable %2 Private %4 +%5 = OpVariable %2 Private %4 +%6 = OpTypeFunction %3 +%14 = OpConstant %3 1 +%15 = OpConstant %3 2 +%16 = OpConstant %3 3 +%7 = OpFunction %3 None %6 +%8 = OpLabel +%10 = OpLoad %3 %5 +OpSelectionMerge %9 None +OpSwitch %10 %11 1 %12 2 %13 3 %13 +%12 = OpLabel +OpStore %1 %14 +OpBranch %9 +%13 = OpLabel +OpStore %1 %15 +OpBranch %9 +%11 = OpLabel +OpStore %1 %16 +OpBranch %9 +%9 = OpLabel +OpFunctionEnd +)"); +} + +TEST_F(BuilderTest, DISABLED_Switch_CaseWithFallthrough) { + // switch (a) { + // case 1: + // v = 1; + // fallthrough; + // case 2: + // v = 2; + // } + FAIL(); +} + +// TODO(dsinclair): Implement when parsing is handled for multi-value +// case labels. +TEST_F(BuilderTest, DISABLED_Switch_CaseWithMulitpleLabels) { + // switch (a) { + // case 1, 2, 3: + // v = 1; + // } + FAIL(); +} + +TEST_F(BuilderTest, Switch_WithNestedBreak) { + ast::type::I32Type i32; + ast::type::BoolType bool_type; + + // switch (a) { + // case 1: + // if (true) { + // break; + // } + // v = 1; + // } + + auto v = + std::make_unique("v", ast::StorageClass::kPrivate, &i32); + auto a = + std::make_unique("a", ast::StorageClass::kPrivate, &i32); + + ast::StatementList if_body; + if_body.push_back(std::make_unique()); + + ast::StatementList case_1_body; + case_1_body.push_back(std::make_unique( + std::make_unique( + std::make_unique(&bool_type, true)), + std::move(if_body))); + + case_1_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 1)))); + + ast::CaseSelectorList selector_1; + selector_1.push_back(std::make_unique(&i32, 1)); + + ast::CaseStatementList cases; + cases.push_back(std::make_unique( + std::move(selector_1), std::move(case_1_body))); + + ast::SwitchStatement expr(std::make_unique("a"), + std::move(cases)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(v.get()); + td.RegisterVariableForTesting(a.get()); + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + ast::Function func("a_func", {}, &i32); + + Builder b(&mod); + ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error(); + ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error(); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + + EXPECT_TRUE(b.GenerateSwitchStatement(&expr)) << b.error(); + + EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v" +OpName %5 "a" +OpName %7 "a_func" +%3 = OpTypeInt 32 1 +%2 = OpTypePointer Private %3 +%4 = OpConstantNull %3 +%1 = OpVariable %2 Private %4 +%5 = OpVariable %2 Private %4 +%6 = OpTypeFunction %3 +%13 = OpTypeBool +%14 = OpConstantTrue %13 +%17 = OpConstant %3 1 +%7 = OpFunction %3 None %6 +%8 = OpLabel +%10 = OpLoad %3 %5 +OpSelectionMerge %9 None +OpSwitch %10 %11 1 %12 +%12 = OpLabel +OpSelectionMerge %15 None +OpBranchConditional %14 %16 %15 +%16 = OpLabel +OpBranch %9 +%15 = OpLabel +OpStore %1 %17 +OpBranch %9 +%11 = OpLabel +OpBranch %9 +%9 = OpLabel +OpFunctionEnd +)"); +} + +} // namespace +} // namespace spirv +} // namespace writer +} // namespace tint