diff --git a/BUILD.gn b/BUILD.gn index 8ec5e25746..3f2df43529 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -975,6 +975,7 @@ source_set("tint_unittests_spv_reader_src") { source_set("tint_unittests_spv_writer_src") { sources = [ + "src/transform/msl_test.cc", "src/transform/spirv_test.cc", "src/writer/spirv/binary_writer_test.cc", "src/writer/spirv/builder_accessor_expression_test.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b18564cbe9..009319436f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -638,6 +638,7 @@ if(${TINT_BUILD_TESTS}) if(${TINT_BUILD_SPV_WRITER}) list(APPEND TINT_TEST_SRCS + transform/msl_test.cc transform/spirv_test.cc writer/spirv/binary_writer_test.cc writer/spirv/builder_accessor_expression_test.cc diff --git a/src/transform/msl.cc b/src/transform/msl.cc index bb6d8e9a61..a089fda847 100644 --- a/src/transform/msl.cc +++ b/src/transform/msl.cc @@ -20,13 +20,254 @@ namespace tint { namespace transform { +namespace { +const char* kReservedKeywords[] = {"access", + "alignas", + "alignof", + "and", + "and_eq", + "array", + "array_ref", + "as_type", + "asm", + "atomic", + "atomic_bool", + "atomic_int", + "atomic_uint", + "auto", + "bitand", + "bitor", + "bool", + "bool2", + "bool3", + "bool4", + "break", + "buffer", + "case", + "catch", + "char", + "char16_t", + "char2", + "char3", + "char32_t", + "char4", + "class", + "compl", + "const", + "const_cast", + "const_reference", + "constant", + "constexpr", + "continue", + "decltype", + "default", + "delete", + "depth2d", + "depth2d_array", + "depth2d_ms", + "depth2d_ms_array", + "depthcube", + "depthcube_array", + "device", + "discard_fragment", + "do", + "double", + "dynamic_cast", + "else", + "enum", + "explicit", + "extern", + "false", + "final", + "float", + "float2", + "float2x2", + "float2x3", + "float2x4", + "float3", + "float3x2", + "float3x3", + "float3x4", + "float4", + "float4x2", + "float4x3", + "float4x4", + "for", + "fragment", + "friend", + "goto", + "half", + "half2", + "half2x2", + "half2x3", + "half2x4", + "half3", + "half3x2", + "half3x3", + "half3x4", + "half4", + "half4x2", + "half4x3", + "half4x4", + "if", + "imageblock", + "inline", + "int", + "int16_t", + "int2", + "int3", + "int32_t", + "int4", + "int64_t", + "int8_t", + "kernel", + "long", + "long2", + "long3", + "long4", + "main", + "metal", + "mutable", + "namespace", + "new", + "noexcept", + "not", + "not_eq", + "nullptr", + "operator", + "or", + "or_eq", + "override", + "packed_bool2", + "packed_bool3", + "packed_bool4", + "packed_char2", + "packed_char3", + "packed_char4", + "packed_float2", + "packed_float3", + "packed_float4", + "packed_half2", + "packed_half3", + "packed_half4", + "packed_int2", + "packed_int3", + "packed_int4", + "packed_short2", + "packed_short3", + "packed_short4", + "packed_uchar2", + "packed_uchar3", + "packed_uchar4", + "packed_uint2", + "packed_uint3", + "packed_uint4", + "packed_ushort2", + "packed_ushort3", + "packed_ushort4", + "patch_control_point", + "private", + "protected", + "ptrdiff_t", + "public", + "r16snorm", + "r16unorm", + "r8unorm", + "reference", + "register", + "reinterpret_cast", + "return", + "rg11b10f", + "rg16snorm", + "rg16unorm", + "rg8snorm", + "rg8unorm", + "rgb10a2", + "rgb9e5", + "rgba16snorm", + "rgba16unorm", + "rgba8snorm", + "rgba8unorm", + "sampler", + "short", + "short2", + "short3", + "short4", + "signed", + "size_t", + "sizeof", + "srgba8unorm", + "static", + "static_assert", + "static_cast", + "struct", + "switch", + "template", + "texture", + "texture1d", + "texture1d_array", + "texture2d", + "texture2d_array", + "texture2d_ms", + "texture2d_ms_array", + "texture3d", + "texture_buffer", + "texturecube", + "texturecube_array", + "this", + "thread", + "thread_local", + "threadgroup", + "threadgroup_imageblock", + "throw", + "true", + "try", + "typedef", + "typeid", + "typename", + "uchar", + "uchar2", + "uchar3", + "uchar4", + "uint", + "uint16_t", + "uint2", + "uint3", + "uint32_t", + "uint4", + "uint64_t", + "uint8_t", + "ulong2", + "ulong3", + "ulong4", + "uniform", + "union", + "unsigned", + "ushort", + "ushort2", + "ushort3", + "ushort4", + "using", + "vec", + "vertex", + "virtual", + "void", + "volatile", + "wchar_t", + "while", + "xor", + "xor_eq"}; +} // namespace Msl::Msl() = default; Msl::~Msl() = default; Transform::Output Msl::Run(const Program* in) { ProgramBuilder out; - CloneContext(&out, in).Clone(); + CloneContext ctx(&out, in); + RenameReservedKeywords(&ctx, kReservedKeywords); + ctx.Clone(); + return Output{Program(std::move(out))}; } diff --git a/src/transform/msl_test.cc b/src/transform/msl_test.cc new file mode 100644 index 0000000000..5390ab243c --- /dev/null +++ b/src/transform/msl_test.cc @@ -0,0 +1,334 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/transform/msl.h" + +#include +#include +#include +#include + +#include "src/transform/test_helper.h" + +namespace tint { +namespace transform { +namespace { + +using MslReservedKeywordTest = TransformTestWithParam; + +TEST_F(MslReservedKeywordTest, Basic) { + auto* src = R"( +struct class { + delete : i32; +}; + +[[stage(fragment)]] +fn main() -> void { + var foo : i32; + var half : f32; + var half1 : f32; + var half2 : f32; + var _tint_half2 : f32; +} +)"; + + auto* expect = R"( +struct _tint_class { + _tint_delete : i32; +}; + +[[stage(fragment)]] +fn _tint_main() -> void { + var foo : i32; + var _tint_half : f32; + var half1 : f32; + var _tint_half2_0 : f32; + var _tint_half2 : f32; +} +)"; + + auto got = Transform(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_P(MslReservedKeywordTest, Keywords) { + auto keyword = GetParam(); + + auto src = R"( +[[stage(fragment)]] +fn main() -> void { + var )" + keyword + + R"( : i32; +} +)"; + + auto expect = R"( +[[stage(fragment)]] +fn _tint_main() -> void { + var _tint_)" + keyword + + R"( : i32; +} +)"; + + auto got = Transform(src); + + EXPECT_EQ(expect, str(got)); +} +INSTANTIATE_TEST_SUITE_P(MslReservedKeywordTest, + MslReservedKeywordTest, + testing::Values( + // c++14 spec + "alignas", + "alignof", + "and", + "and_eq", + // "asm", // Also reserved in WGSL + "auto", + "bitand", + "bitor", + // "bool", // Also used in WGSL + // "break", // Also used in WGSL + // "case", // Also used in WGSL + "catch", + "char", + "char16_t", + "char32_t", + "class", + "compl", + // "const", // Also used in WGSL + "const_cast", + "constexpr", + // "continue", // Also used in WGSL + "decltype", + // "default", // Also used in WGSL + "delete", + // "do", // Also used in WGSL + "double", + "dynamic_cast", + // "else", // Also used in WGSL + // "enum", // Also used in WGSL + "explicit", + "extern", + // "false", // Also used in WGSL + "final", + "float", + // "for", // Also used in WGSL + "friend", + "goto", + // "if", // Also used in WGSL + "inline", + "int", + "long", + "mutable", + "namespace", + "new", + "noexcept", + "not", + "not_eq", + "nullptr", + "operator", + "or", + "or_eq", + "override", + // "private", // Also used in WGSL + "protected", + "public", + "register", + "reinterpret_cast", + // "return", // Also used in WGSL + "short", + "signed", + "sizeof", + "static", + "static_assert", + "static_cast", + // "struct", // Also used in WGSL + // "switch", // Also used in WGSL + "template", + "this", + "thread_local", + "throw", + // "true", // Also used in WGSL + "try", + // "typedef", // Also used in WGSL + "typeid", + "typename", + "union", + "unsigned", + "using", + "virtual", + // "void", // Also used in WGSL + "volatile", + "wchar_t", + "while", + "xor", + "xor_eq", + + // Metal Spec + "access", + // "array", // Also used in WGSL + "array_ref", + "as_type", + "atomic", + "atomic_bool", + "atomic_int", + "atomic_uint", + "bool2", + "bool3", + "bool4", + "buffer", + "char2", + "char3", + "char4", + "const_reference", + "constant", + "depth2d", + "depth2d_array", + "depth2d_ms", + "depth2d_ms_array", + "depthcube", + "depthcube_array", + "device", + "discard_fragment", + "float2", + "float2x2", + "float2x3", + "float2x4", + "float3", + "float3x2", + "float3x3", + "float3x4", + "float4", + "float4x2", + "float4x3", + "float4x4", + "fragment", + "half", + "half2", + "half2x2", + "half2x3", + "half2x4", + "half3", + "half3x2", + "half3x3", + "half3x4", + "half4", + "half4x2", + "half4x3", + "half4x4", + "imageblock", + "int16_t", + "int2", + "int3", + "int32_t", + "int4", + "int64_t", + "int8_t", + "kernel", + "long2", + "long3", + "long4", + "main", // No functions called main + "metal", // The namespace + "packed_bool2", + "packed_bool3", + "packed_bool4", + "packed_char2", + "packed_char3", + "packed_char4", + "packed_float2", + "packed_float3", + "packed_float4", + "packed_half2", + "packed_half3", + "packed_half4", + "packed_int2", + "packed_int3", + "packed_int4", + "packed_short2", + "packed_short3", + "packed_short4", + "packed_uchar2", + "packed_uchar3", + "packed_uchar4", + "packed_uint2", + "packed_uint3", + "packed_uint4", + "packed_ushort2", + "packed_ushort3", + "packed_ushort4", + "patch_control_point", + "ptrdiff_t", + "r16snorm", + "r16unorm", + // "r8unorm", // Also used in WGSL + "reference", + "rg11b10f", + "rg16snorm", + "rg16unorm", + // "rg8snorm", // Also used in WGSL + // "rg8unorm", // Also used in WGSL + "rgb10a2", + "rgb9e5", + "rgba16snorm", + "rgba16unorm", + // "rgba8snorm", // Also used in WGSL + // "rgba8unorm", // Also used in WGSL + // "sampler", // Also used in WGSL + "short2", + "short3", + "short4", + "size_t", + "srgba8unorm", + "texture", + "texture1d", + "texture1d_array", + "texture2d", + "texture2d_array", + "texture2d_ms", + "texture2d_ms_array", + "texture3d", + "texture_buffer", + "texturecube", + "texturecube_array", + "thread", + "threadgroup", + "threadgroup_imageblock", + "uchar", + "uchar2", + "uchar3", + "uchar4", + "uint", + "uint16_t", + "uint2", + "uint3", + "uint32_t", + "uint4", + "uint64_t", + "uint8_t", + "ulong2", + "ulong3", + "ulong4", + // "uniform", // Also used in WGSL + "ushort", + "ushort2", + "ushort3", + "ushort4", + "vec", + "vertex")); + +} // namespace +} // namespace transform +} // namespace tint diff --git a/src/transform/test_helper.h b/src/transform/test_helper.h index ae78d51ec8..191f1d89b1 100644 --- a/src/transform/test_helper.h +++ b/src/transform/test_helper.h @@ -32,7 +32,8 @@ namespace tint { namespace transform { /// Helper class for testing transforms -class TransformTest : public testing::Test { +template +class TransformTestBase : public BASE { public: /// Transforms and returns the WGSL source `in`, transformed using /// `transforms`. @@ -42,8 +43,11 @@ class TransformTest : public testing::Test { Transform::Output Transform( std::string in, std::vector> transforms) { - Source::File file("test", in); - auto program = reader::wgsl::Parse(&file); + auto file = std::make_unique("test", in); + auto program = reader::wgsl::Parse(file.get()); + + // Keep this pointer alive after Transform() returns + files_.emplace_back(std::move(file)); if (!program.IsValid()) { return Transform::Output(std::move(program)); @@ -108,8 +112,16 @@ class TransformTest : public testing::Test { } return "\n" + res + "\n"; } + + private: + std::vector> files_; }; +using TransformTest = TransformTestBase; + +template +using TransformTestWithParam = TransformTestBase>; + } // namespace transform } // namespace tint diff --git a/src/transform/transform.cc b/src/transform/transform.cc index 36dbc8f7de..e0ab4c9db7 100644 --- a/src/transform/transform.cc +++ b/src/transform/transform.cc @@ -14,6 +14,8 @@ #include "src/transform/transform.h" +#include + #include "src/ast/block_statement.h" #include "src/ast/function.h" #include "src/clone_context.h" @@ -63,5 +65,23 @@ ast::Function* Transform::CloneWithStatementsAtStart( body, decos); } +void Transform::RenameReservedKeywords(CloneContext* ctx, + const char* names[], + size_t count) { + ctx->ReplaceAll([=](Symbol in) { + auto name_in = ctx->src->Symbols().NameFor(in); + if (!std::binary_search(names, names + count, name_in)) { + return ctx->dst->Symbols().Register(name_in); + } + // Create a new unique name + auto base_name = "_tint_" + name_in; + auto name_out = base_name; + for (int i = 0; ctx->src->Symbols().Get(name_out).IsValid(); i++) { + name_out = base_name + "_" + std::to_string(i); + } + return ctx->dst->Symbols().Register(name_out); + }); +} + } // namespace transform } // namespace tint diff --git a/src/transform/transform.h b/src/transform/transform.h index 47a950acf5..c82bcf65a3 100644 --- a/src/transform/transform.h +++ b/src/transform/transform.h @@ -154,6 +154,25 @@ class Transform { CloneContext* ctx, ast::Function* in, ast::StatementList statements); + + /// Registers a symbol renamer on `ctx` for any symbol that is found in the + /// list of reserved identifiers. + /// @param ctx the clone context + /// @param names the lexicographically sorted list of reserved identifiers + /// @param count the number of identifiers in the array `names` + static void RenameReservedKeywords(CloneContext* ctx, + const char* names[], + size_t count); + + /// Registers a symbol renamer on `ctx` for any symbol that is found in the + /// list of reserved identifiers. + /// @param ctx the clone context + /// @param names the lexicographically sorted list of reserved identifiers + template + static void RenameReservedKeywords(CloneContext* ctx, + const char* (&names)[N]) { + RenameReservedKeywords(ctx, names, N); + } }; } // namespace transform