From 1a14f2093c61084776f662418f42e1cc932d85d4 Mon Sep 17 00:00:00 2001
From: David Neto <dneto@google.com>
Date: Mon, 14 Jun 2021 19:42:27 +0000
Subject: [PATCH] spirv-reader: pipeline IO: handle sample_mask input

It's an array in Vulkan SPIR-V, but a scalar u32 in WGSL.
Handle signedness change.

Note that input variables can't have an initializer, so that
doesn't need to be handled.

Bug: tint:508
Change-Id: I7cf4228b31f9c42e4e4436d78cbb1eb0c8196cd5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/54482
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
---
 src/reader/spirv/function.cc                  |  35 +++-
 .../spirv/parser_impl_module_var_test.cc      | 150 +++++++++++++++++-
 2 files changed, 174 insertions(+), 11 deletions(-)

diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index fb63dc7d78..5022c6f354 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -685,6 +685,15 @@ struct LoopStatementBuilder
   ast::BlockStatement* continuing = nullptr;
 };
 
+/// @param decos a list of parsed decorations
+/// @returns true if the decorations include a SampleMask builtin
+bool HasBuiltinSampleMask(const ast::DecorationList& decos) {
+  if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(decos)) {
+    return builtin->value() == ast::Builtin::kSampleMask;
+  }
+  return false;
+}
+
 }  // namespace
 
 BlockInfo::BlockInfo(const spvtools::opt::BasicBlock& bb)
@@ -973,11 +982,17 @@ bool FunctionEmitter::EmitEntryPointAsWrapper() {
     // variable.
     ast::Expression* param_value =
         create<ast::IdentifierExpression>(source, param_sym);
-    if (forced_store_type != store_type) {
-      // Insert a bitcast if needed.
-      const auto cast_name = namer_.MakeDerivedName(param_name + "_cast");
-      const auto cast_sym = builder_.Symbols().Register(cast_name);
-
+    if (HasBuiltinSampleMask(param_decos)) {
+      // In Vulkan SPIR-V, the sample mask is an array. In WGSL it's a scalar.
+      // Use the first element only.
+      param_value = create<ast::ArrayAccessorExpression>(
+          source, param_value, parser_impl_.MakeNullValue(ty_.I32()));
+      if (store_type->As<Array>()->type->IsSignedScalarOrVector()) {
+        // sample_mask is unsigned in WGSL. Bitcast it.
+        param_value = create<ast::BitcastExpression>(
+            source, ty_.I32()->Build(builder_), param_value);
+      }
+    } else if (forced_store_type != store_type) {
       // The parameter will have the WGSL type, but we need to add
       // a bitcast to the variable store type.
       param_value = create<ast::BitcastExpression>(
@@ -1046,9 +1061,15 @@ bool FunctionEmitter::EmitEntryPointAsWrapper() {
           std::move(out_decos));
       return_members.push_back(return_member);
 
+      ast::Expression* return_member_value =
+          create<ast::IdentifierExpression>(source, var_sym);
+      if (forced_store_type != store_type) {
+        // We need to cast from the variable store type to the member type.
+        return_member_value = create<ast::BitcastExpression>(
+            source, forced_store_type->Build(builder_), return_member_value);
+      }
       // Save the expression.
-      return_exprs.push_back(
-          create<ast::IdentifierExpression>(source, var_sym));
+      return_exprs.push_back(return_member_value);
     }
 
     // Create and register the result type.
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc
index f40b5f3f28..a755f13d2c 100644
--- a/src/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -4470,13 +4470,155 @@ TEST_F(SpvModuleScopeVarParserTest,
 
 // SampleMask is an array in Vulkan SPIR-V, but a scalar in WGSL.
 TEST_F(SpvModuleScopeVarParserTest,
-       DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_U) {}
+       EntryPointWrapping_BuiltinVar_SampleMask_In_Unsigned) {
+  // SampleMask is u32 in WGSL.
+  // Use unsigned array element in Vulkan.
+  const auto assembly = CommonCapabilities() + R"(
+     OpEntryPoint Fragment %main "main" %1
+     OpExecutionMode %main OriginUpperLeft
+     OpDecorate %1 BuiltIn SampleMask
+)" + CommonTypes() +
+                        R"(
+     %arr = OpTypeArray %uint %uint_1
+     %ptr_ty = OpTypePointer Input %arr
+     %1 = OpVariable %ptr_ty Input
+
+     %main = OpFunction %void None %voidfn
+     %entry = OpLabel
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+
+  // TODO(crbug.com/tint/508): Remove this when everything is converted
+  // to HLSL style pipeline IO.
+  p->SetHLSLStylePipelineIO();
+
+  ASSERT_TRUE(p->Parse()) << p->error() << assembly;
+  EXPECT_TRUE(p->error().empty());
+  const auto got = p->program().to_str();
+  const std::string expected = R"(Module{
+  Variable{
+    x_1
+    private
+    undefined
+    __array__u32_1
+  }
+  Function main_1 -> __void
+  ()
+  {
+    Return{}
+  }
+  Function main -> __void
+  StageDecoration{fragment}
+  (
+    VariableConst{
+      Decorations{
+        BuiltinDecoration{sample_mask}
+      }
+      x_1_param
+      none
+      undefined
+      __u32
+    }
+  )
+  {
+    Assignment{
+      Identifier[not set]{x_1}
+      ArrayAccessor[not set]{
+        Identifier[not set]{x_1_param}
+        ScalarConstructor[not set]{0}
+      }
+    }
+    Call[not set]{
+      Identifier[not set]{main_1}
+      (
+      )
+    }
+  }
+}
+)";
+  EXPECT_EQ(got, expected) << got;
+}
+
 TEST_F(SpvModuleScopeVarParserTest,
-       DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_U_Initializer) {}
+       EntryPointWrapping_BuiltinVar_SampleMask_In_Signed) {
+  // SampleMask is u32 in WGSL.
+  // Use signed array element in Vulkan.
+  const auto assembly = CommonCapabilities() + R"(
+     OpEntryPoint Fragment %main "main" %1
+     OpExecutionMode %main OriginUpperLeft
+     OpDecorate %1 BuiltIn SampleMask
+)" + CommonTypes() +
+                        R"(
+     %arr = OpTypeArray %int %uint_1
+     %ptr_ty = OpTypePointer Input %arr
+     %1 = OpVariable %ptr_ty Input
+
+     %main = OpFunction %void None %voidfn
+     %entry = OpLabel
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+
+  // TODO(crbug.com/tint/508): Remove this when everything is converted
+  // to HLSL style pipeline IO.
+  p->SetHLSLStylePipelineIO();
+
+  ASSERT_TRUE(p->Parse()) << p->error() << assembly;
+  EXPECT_TRUE(p->error().empty());
+  const auto got = p->program().to_str();
+  const std::string expected = R"(Module{
+  Variable{
+    x_1
+    private
+    undefined
+    __array__i32_1
+  }
+  Function main_1 -> __void
+  ()
+  {
+    Return{}
+  }
+  Function main -> __void
+  StageDecoration{fragment}
+  (
+    VariableConst{
+      Decorations{
+        BuiltinDecoration{sample_mask}
+      }
+      x_1_param
+      none
+      undefined
+      __u32
+    }
+  )
+  {
+    Assignment{
+      Identifier[not set]{x_1}
+      Bitcast[not set]<__i32>{
+        ArrayAccessor[not set]{
+          Identifier[not set]{x_1_param}
+          ScalarConstructor[not set]{0}
+        }
+      }
+    }
+    Call[not set]{
+      Identifier[not set]{main_1}
+      (
+      )
+    }
+  }
+}
+)";
+  EXPECT_EQ(got, expected) << got;
+}
+
 TEST_F(SpvModuleScopeVarParserTest,
-       DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_S) {}
+       DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_Out_U) {}
 TEST_F(SpvModuleScopeVarParserTest,
-       DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_S_Initializer) {}
+       DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_Out_S) {}
 
 // TODO(dneto): pipeline IO: flatten structures, and distribute locations