diff --git a/src/ast/function.cc b/src/ast/function.cc
index c03694fba9..5348b6297c 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -64,8 +64,8 @@ std::tuple<uint32_t, uint32_t, uint32_t> Function::workgroup_size() const {
 
 ast::PipelineStage Function::pipeline_stage() const {
   for (auto* deco : decorations_) {
-    if (deco->IsStage()) {
-      return deco->AsStage()->value();
+    if (auto* stage = deco->As<StageDecoration>()) {
+      return stage->value();
     }
   }
   return ast::PipelineStage::kNone;
diff --git a/src/ast/function_decoration.cc b/src/ast/function_decoration.cc
index 2f8fe6fa77..0b746befa6 100644
--- a/src/ast/function_decoration.cc
+++ b/src/ast/function_decoration.cc
@@ -16,7 +16,6 @@
 
 #include <assert.h>
 
-#include "src/ast/stage_decoration.h"
 #include "src/ast/workgroup_decoration.h"
 
 namespace tint {
@@ -32,19 +31,10 @@ DecorationKind FunctionDecoration::GetKind() const {
   return Kind;
 }
 
-bool FunctionDecoration::IsStage() const {
-  return false;
-}
-
 bool FunctionDecoration::IsWorkgroup() const {
   return false;
 }
 
-const StageDecoration* FunctionDecoration::AsStage() const {
-  assert(IsStage());
-  return static_cast<const StageDecoration*>(this);
-}
-
 const WorkgroupDecoration* FunctionDecoration::AsWorkgroup() const {
   assert(IsWorkgroup());
   return static_cast<const WorkgroupDecoration*>(this);
diff --git a/src/ast/function_decoration.h b/src/ast/function_decoration.h
index 4fb50ab413..fe1a9dd3a9 100644
--- a/src/ast/function_decoration.h
+++ b/src/ast/function_decoration.h
@@ -24,7 +24,6 @@
 namespace tint {
 namespace ast {
 
-class StageDecoration;
 class WorkgroupDecoration;
 
 /// A decoration attached to a function
@@ -38,13 +37,9 @@ class FunctionDecoration : public Castable<FunctionDecoration, Decoration> {
   /// @return the decoration kind
   DecorationKind GetKind() const override;
 
-  /// @returns true if this is a stage decoration
-  virtual bool IsStage() const;
   /// @returns true if this is a workgroup decoration
   virtual bool IsWorkgroup() const;
 
-  /// @returns the decoration as a stage decoration
-  const StageDecoration* AsStage() const;
   /// @returns the decoration as a workgroup decoration
   const WorkgroupDecoration* AsWorkgroup() const;
 
diff --git a/src/ast/stage_decoration.cc b/src/ast/stage_decoration.cc
index ecb923f731..4469c01178 100644
--- a/src/ast/stage_decoration.cc
+++ b/src/ast/stage_decoration.cc
@@ -22,10 +22,6 @@ StageDecoration::StageDecoration(ast::PipelineStage stage, const Source& source)
 
 StageDecoration::~StageDecoration() = default;
 
-bool StageDecoration::IsStage() const {
-  return true;
-}
-
 void StageDecoration::to_str(std::ostream& out, size_t indent) const {
   make_indent(out, indent);
   out << "StageDecoration{" << stage_ << "}" << std::endl;
diff --git a/src/ast/stage_decoration.h b/src/ast/stage_decoration.h
index 0846b6df20..e70e94a898 100644
--- a/src/ast/stage_decoration.h
+++ b/src/ast/stage_decoration.h
@@ -30,9 +30,6 @@ class StageDecoration : public Castable<StageDecoration, FunctionDecoration> {
   StageDecoration(ast::PipelineStage stage, const Source& source);
   ~StageDecoration() override;
 
-  /// @returns true if this is a stage decoration
-  bool IsStage() const override;
-
   /// @returns the stage
   ast::PipelineStage value() const { return stage_; }
 
diff --git a/src/ast/stage_decoration_test.cc b/src/ast/stage_decoration_test.cc
index c60aff351c..5adb97ee20 100644
--- a/src/ast/stage_decoration_test.cc
+++ b/src/ast/stage_decoration_test.cc
@@ -30,9 +30,10 @@ TEST_F(StageDecorationTest, Creation_1param) {
 }
 
 TEST_F(StageDecorationTest, Is) {
-  StageDecoration d{ast::PipelineStage::kFragment, Source{}};
-  EXPECT_FALSE(d.IsWorkgroup());
-  EXPECT_TRUE(d.IsStage());
+  StageDecoration sd{ast::PipelineStage::kFragment, Source{}};
+  Decoration* d = &sd;
+  EXPECT_FALSE(sd.IsWorkgroup());
+  EXPECT_TRUE(d->Is<ast::StageDecoration>());
 }
 
 TEST_F(StageDecorationTest, ToStr) {
diff --git a/src/ast/workgroup_decoration_test.cc b/src/ast/workgroup_decoration_test.cc
index 72d9934715..1d722f37f4 100644
--- a/src/ast/workgroup_decoration_test.cc
+++ b/src/ast/workgroup_decoration_test.cc
@@ -16,6 +16,7 @@
 
 #include <sstream>
 
+#include "src/ast/stage_decoration.h"
 #include "src/ast/test_helper.h"
 
 namespace tint {
@@ -57,9 +58,10 @@ TEST_F(WorkgroupDecorationTest, Creation_3param) {
 }
 
 TEST_F(WorkgroupDecorationTest, Is) {
-  WorkgroupDecoration d{2, 4, 6, Source{}};
-  EXPECT_TRUE(d.IsWorkgroup());
-  EXPECT_FALSE(d.IsStage());
+  WorkgroupDecoration wd{2, 4, 6, Source{}};
+  Decoration* d = &wd;
+  EXPECT_TRUE(wd.IsWorkgroup());
+  EXPECT_FALSE(d->Is<StageDecoration>());
 }
 
 TEST_F(WorkgroupDecorationTest, ToStr) {
diff --git a/src/reader/wgsl/parser_impl_function_decoration_test.cc b/src/reader/wgsl/parser_impl_function_decoration_test.cc
index 966937bc97..0e39a34c95 100644
--- a/src/reader/wgsl/parser_impl_function_decoration_test.cc
+++ b/src/reader/wgsl/parser_impl_function_decoration_test.cc
@@ -259,8 +259,9 @@ TEST_F(ParserImplTest, FunctionDecoration_Stage) {
   ASSERT_FALSE(p->has_error());
   auto* func_deco = deco.value->As<ast::FunctionDecoration>();
   ASSERT_NE(func_deco, nullptr);
-  ASSERT_TRUE(func_deco->IsStage());
-  EXPECT_EQ(func_deco->AsStage()->value(), ast::PipelineStage::kCompute);
+  ASSERT_TRUE(func_deco->Is<ast::StageDecoration>());
+  EXPECT_EQ(func_deco->As<ast::StageDecoration>()->value(),
+            ast::PipelineStage::kCompute);
 }
 
 TEST_F(ParserImplTest, FunctionDecoration_Stage_MissingValue) {
diff --git a/src/validator/validator_impl.cc b/src/validator/validator_impl.cc
index 8c8d59304d..6567f372b1 100644
--- a/src/validator/validator_impl.cc
+++ b/src/validator/validator_impl.cc
@@ -23,6 +23,7 @@
 #include "src/ast/int_literal.h"
 #include "src/ast/intrinsic.h"
 #include "src/ast/sint_literal.h"
+#include "src/ast/stage_decoration.h"
 #include "src/ast/struct.h"
 #include "src/ast/switch_statement.h"
 #include "src/ast/type/array_type.h"
@@ -174,7 +175,7 @@ bool ValidatorImpl::ValidateEntryPoint(const ast::FunctionList& funcs) {
       }
       auto stage_deco_count = 0;
       for (auto* deco : func->decorations()) {
-        if (deco->IsStage()) {
+        if (deco->Is<ast::StageDecoration>()) {
           stage_deco_count++;
         }
       }
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index e713aa6580..ad77a7cd78 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -354,8 +354,8 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) {
       out_ << "workgroup_size(" << std::to_string(x) << ", "
            << std::to_string(y) << ", " << std::to_string(z) << ")";
     }
-    if (deco->IsStage()) {
-      out_ << "stage(" << deco->AsStage()->value() << ")";
+    if (auto* stage = deco->As<ast::StageDecoration>()) {
+      out_ << "stage(" << stage->value() << ")";
     }
     out_ << "]]" << std::endl;
   }