diff --git a/src/ast/function.cc b/src/ast/function.cc index c03694fba9..5348b6297c 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -64,8 +64,8 @@ std::tuple<uint32_t, uint32_t, uint32_t> Function::workgroup_size() const { ast::PipelineStage Function::pipeline_stage() const { for (auto* deco : decorations_) { - if (deco->IsStage()) { - return deco->AsStage()->value(); + if (auto* stage = deco->As<StageDecoration>()) { + return stage->value(); } } return ast::PipelineStage::kNone; diff --git a/src/ast/function_decoration.cc b/src/ast/function_decoration.cc index 2f8fe6fa77..0b746befa6 100644 --- a/src/ast/function_decoration.cc +++ b/src/ast/function_decoration.cc @@ -16,7 +16,6 @@ #include <assert.h> -#include "src/ast/stage_decoration.h" #include "src/ast/workgroup_decoration.h" namespace tint { @@ -32,19 +31,10 @@ DecorationKind FunctionDecoration::GetKind() const { return Kind; } -bool FunctionDecoration::IsStage() const { - return false; -} - bool FunctionDecoration::IsWorkgroup() const { return false; } -const StageDecoration* FunctionDecoration::AsStage() const { - assert(IsStage()); - return static_cast<const StageDecoration*>(this); -} - const WorkgroupDecoration* FunctionDecoration::AsWorkgroup() const { assert(IsWorkgroup()); return static_cast<const WorkgroupDecoration*>(this); diff --git a/src/ast/function_decoration.h b/src/ast/function_decoration.h index 4fb50ab413..fe1a9dd3a9 100644 --- a/src/ast/function_decoration.h +++ b/src/ast/function_decoration.h @@ -24,7 +24,6 @@ namespace tint { namespace ast { -class StageDecoration; class WorkgroupDecoration; /// A decoration attached to a function @@ -38,13 +37,9 @@ class FunctionDecoration : public Castable<FunctionDecoration, Decoration> { /// @return the decoration kind DecorationKind GetKind() const override; - /// @returns true if this is a stage decoration - virtual bool IsStage() const; /// @returns true if this is a workgroup decoration virtual bool IsWorkgroup() const; - /// @returns the decoration as a stage decoration - const StageDecoration* AsStage() const; /// @returns the decoration as a workgroup decoration const WorkgroupDecoration* AsWorkgroup() const; diff --git a/src/ast/stage_decoration.cc b/src/ast/stage_decoration.cc index ecb923f731..4469c01178 100644 --- a/src/ast/stage_decoration.cc +++ b/src/ast/stage_decoration.cc @@ -22,10 +22,6 @@ StageDecoration::StageDecoration(ast::PipelineStage stage, const Source& source) StageDecoration::~StageDecoration() = default; -bool StageDecoration::IsStage() const { - return true; -} - void StageDecoration::to_str(std::ostream& out, size_t indent) const { make_indent(out, indent); out << "StageDecoration{" << stage_ << "}" << std::endl; diff --git a/src/ast/stage_decoration.h b/src/ast/stage_decoration.h index 0846b6df20..e70e94a898 100644 --- a/src/ast/stage_decoration.h +++ b/src/ast/stage_decoration.h @@ -30,9 +30,6 @@ class StageDecoration : public Castable<StageDecoration, FunctionDecoration> { StageDecoration(ast::PipelineStage stage, const Source& source); ~StageDecoration() override; - /// @returns true if this is a stage decoration - bool IsStage() const override; - /// @returns the stage ast::PipelineStage value() const { return stage_; } diff --git a/src/ast/stage_decoration_test.cc b/src/ast/stage_decoration_test.cc index c60aff351c..5adb97ee20 100644 --- a/src/ast/stage_decoration_test.cc +++ b/src/ast/stage_decoration_test.cc @@ -30,9 +30,10 @@ TEST_F(StageDecorationTest, Creation_1param) { } TEST_F(StageDecorationTest, Is) { - StageDecoration d{ast::PipelineStage::kFragment, Source{}}; - EXPECT_FALSE(d.IsWorkgroup()); - EXPECT_TRUE(d.IsStage()); + StageDecoration sd{ast::PipelineStage::kFragment, Source{}}; + Decoration* d = &sd; + EXPECT_FALSE(sd.IsWorkgroup()); + EXPECT_TRUE(d->Is<ast::StageDecoration>()); } TEST_F(StageDecorationTest, ToStr) { diff --git a/src/ast/workgroup_decoration_test.cc b/src/ast/workgroup_decoration_test.cc index 72d9934715..1d722f37f4 100644 --- a/src/ast/workgroup_decoration_test.cc +++ b/src/ast/workgroup_decoration_test.cc @@ -16,6 +16,7 @@ #include <sstream> +#include "src/ast/stage_decoration.h" #include "src/ast/test_helper.h" namespace tint { @@ -57,9 +58,10 @@ TEST_F(WorkgroupDecorationTest, Creation_3param) { } TEST_F(WorkgroupDecorationTest, Is) { - WorkgroupDecoration d{2, 4, 6, Source{}}; - EXPECT_TRUE(d.IsWorkgroup()); - EXPECT_FALSE(d.IsStage()); + WorkgroupDecoration wd{2, 4, 6, Source{}}; + Decoration* d = &wd; + EXPECT_TRUE(wd.IsWorkgroup()); + EXPECT_FALSE(d->Is<StageDecoration>()); } TEST_F(WorkgroupDecorationTest, ToStr) { diff --git a/src/reader/wgsl/parser_impl_function_decoration_test.cc b/src/reader/wgsl/parser_impl_function_decoration_test.cc index 966937bc97..0e39a34c95 100644 --- a/src/reader/wgsl/parser_impl_function_decoration_test.cc +++ b/src/reader/wgsl/parser_impl_function_decoration_test.cc @@ -259,8 +259,9 @@ TEST_F(ParserImplTest, FunctionDecoration_Stage) { ASSERT_FALSE(p->has_error()); auto* func_deco = deco.value->As<ast::FunctionDecoration>(); ASSERT_NE(func_deco, nullptr); - ASSERT_TRUE(func_deco->IsStage()); - EXPECT_EQ(func_deco->AsStage()->value(), ast::PipelineStage::kCompute); + ASSERT_TRUE(func_deco->Is<ast::StageDecoration>()); + EXPECT_EQ(func_deco->As<ast::StageDecoration>()->value(), + ast::PipelineStage::kCompute); } TEST_F(ParserImplTest, FunctionDecoration_Stage_MissingValue) { diff --git a/src/validator/validator_impl.cc b/src/validator/validator_impl.cc index 8c8d59304d..6567f372b1 100644 --- a/src/validator/validator_impl.cc +++ b/src/validator/validator_impl.cc @@ -23,6 +23,7 @@ #include "src/ast/int_literal.h" #include "src/ast/intrinsic.h" #include "src/ast/sint_literal.h" +#include "src/ast/stage_decoration.h" #include "src/ast/struct.h" #include "src/ast/switch_statement.h" #include "src/ast/type/array_type.h" @@ -174,7 +175,7 @@ bool ValidatorImpl::ValidateEntryPoint(const ast::FunctionList& funcs) { } auto stage_deco_count = 0; for (auto* deco : func->decorations()) { - if (deco->IsStage()) { + if (deco->Is<ast::StageDecoration>()) { stage_deco_count++; } } diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index e713aa6580..ad77a7cd78 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -354,8 +354,8 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) { out_ << "workgroup_size(" << std::to_string(x) << ", " << std::to_string(y) << ", " << std::to_string(z) << ")"; } - if (deco->IsStage()) { - out_ << "stage(" << deco->AsStage()->value() << ")"; + if (auto* stage = deco->As<ast::StageDecoration>()) { + out_ << "stage(" << stage->value() << ")"; } out_ << "]]" << std::endl; }