diff --git a/generator/templates/dawn/native/api_StreamImpl.cpp b/generator/templates/dawn/native/api_StreamImpl.cpp index 82b64043c2..e6617aacb9 100644 --- a/generator/templates/dawn/native/api_StreamImpl.cpp +++ b/generator/templates/dawn/native/api_StreamImpl.cpp @@ -25,7 +25,15 @@ namespace {{native_namespace}} { // -// Cache key writers for wgpu structures used in caching. +// Streaming readers for wgpu structures. +// +{% macro render_reader(member) %} + {%- set name = member.name.camelCase() -%} + DAWN_TRY(StreamOut(source, &t->{{name}})); +{% endmacro %} + +// +// Streaming writers for wgpu structures. // {% macro render_writer(member) %} {%- set name = member.name.camelCase() -%} @@ -38,31 +46,50 @@ namespace {{native_namespace}} { {% endif %} {% endmacro %} -{# Helper macro to render writers. Should be used in a call block to provide additional custom +{# Helper macro to render readers and writers. Should be used in a call block to provide additional custom handling when necessary. The optional `omit` field can be used to omit fields that are either handled in the custom code, or unnecessary in the serialized output. Example: - {% call render_cache_key_writer("struct name", omits=["omit field"]) %} + {% call render_streaming_impl("struct name", writer=true, reader=false, omits=["omit field"]) %} // Custom C++ code to handle special types/members that are hard to generate code for {% endcall %} + One day we should probably make the generator smart enough to generate everything it can + instead of manually adding streaming implementations here. #} -{% macro render_cache_key_writer(json_type, omits=[]) %} +{% macro render_streaming_impl(json_type, writer, reader, omits=[]) %} {%- set cpp_type = types[json_type].name.CamelCase() -%} - template <> - void stream::Stream<{{cpp_type}}>::Write(stream::Sink* sink, const {{cpp_type}}& t) { - {{ caller() }} - {% for member in types[json_type].members %} - {%- if not member.name.get() in omits %} - {{render_writer(member)}} - {%- endif %} - {% endfor %} - } + {% if reader %} + template <> + MaybeError stream::Stream<{{cpp_type}}>::Read(stream::Source* source, {{cpp_type}}* t) { + {{ caller() }} + {% for member in types[json_type].members %} + {% if not member.name.get() in omits %} + {{render_reader(member)}} + {% endif %} + {% endfor %} + return {}; + } + {% endif %} + {% if writer %} + template <> + void stream::Stream<{{cpp_type}}>::Write(stream::Sink* sink, const {{cpp_type}}& t) { + {{ caller() }} + {% for member in types[json_type].members %} + {% if not member.name.get() in omits %} + {{render_writer(member)}} + {% endif %} + {% endfor %} + } + {% endif %} {% endmacro %} -{% call render_cache_key_writer("adapter properties") %} +{% call render_streaming_impl("adapter properties", true, false) %} {% endcall %} -{% call render_cache_key_writer("dawn cache device descriptor") %} +{% call render_streaming_impl("dawn cache device descriptor", true, false) %} +{% endcall %} + +{% call render_streaming_impl("extent 3D", true, true) %} {% endcall %} } // namespace {{native_namespace}} diff --git a/src/dawn/native/Limits.cpp b/src/dawn/native/Limits.cpp index ef285b2a4d..215609c126 100644 --- a/src/dawn/native/Limits.cpp +++ b/src/dawn/native/Limits.cpp @@ -215,4 +215,22 @@ Limits ApplyLimitTiers(Limits limits) { return limits; } +#define DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT(type, name) \ + { result.name = limits.name; } +#define DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT(MEMBERS) \ + MEMBERS(DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT) 
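Aside: the new DAWN_INTERNAL_LIMITS_* macros above use the X-macro idiom, in which a single X(type, name) list is expanded several times to generate both the member declarations and the per-member code. A stand-alone sketch of the pattern (EXAMPLE_LIMIT_MEMBERS, ExampleLimits and CopyForCompilation are illustrative names, not part of the patch):

#include <cstdint>

// One X(type, name) list drives every expansion.
#define EXAMPLE_LIMIT_MEMBERS(X)          \
    X(uint32_t, maxComputeWorkgroupSizeX) \
    X(uint32_t, maxComputeInvocationsPerWorkgroup)

// Expanded once to declare the members (DAWN_VISITABLE_MEMBERS in Limits.h also
// generates a VisitAll helper from the same list)...
struct ExampleLimits {
#define EXAMPLE_DECLARE(type, name) type name;
    EXAMPLE_LIMIT_MEMBERS(EXAMPLE_DECLARE)
#undef EXAMPLE_DECLARE
};

// ...and expanded again to generate one assignment per member, which is what
// DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT does for LimitsForCompilationRequest::Create.
#define EXAMPLE_ASSIGN(type, name) result.name = limits.name;
ExampleLimits CopyForCompilation(const ExampleLimits& limits) {
    ExampleLimits result;
    EXAMPLE_LIMIT_MEMBERS(EXAMPLE_ASSIGN)
    return result;
}
#undef EXAMPLE_ASSIGN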
+LimitsForCompilationRequest LimitsForCompilationRequest::Create(const Limits& limits) { + LimitsForCompilationRequest result; + DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT(LIMITS_FOR_COMPILATION_REQUEST_MEMBERS) + return result; +} +#undef DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT +#undef DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT + +template <> +void stream::Stream::Write(Sink* s, + const LimitsForCompilationRequest& t) { + t.VisitAll([&](const auto&... members) { StreamIn(s, members...); }); +} + } // namespace dawn::native diff --git a/src/dawn/native/Limits.h b/src/dawn/native/Limits.h index cc81742053..8091a31cd2 100644 --- a/src/dawn/native/Limits.h +++ b/src/dawn/native/Limits.h @@ -16,6 +16,7 @@ #define SRC_DAWN_NATIVE_LIMITS_H_ #include "dawn/native/Error.h" +#include "dawn/native/VisitableMembers.h" #include "dawn/native/dawn_platform.h" namespace dawn::native { @@ -38,6 +39,20 @@ MaybeError ValidateLimits(const Limits& supportedLimits, const Limits& requiredL // Returns a copy of |limits| where limit tiers are applied. Limits ApplyLimitTiers(Limits limits); +// If there are new limit member needed at shader compilation time +// Simply append a new X(type, name) here. +#define LIMITS_FOR_COMPILATION_REQUEST_MEMBERS(X) \ + X(uint32_t, maxComputeWorkgroupSizeX) \ + X(uint32_t, maxComputeWorkgroupSizeY) \ + X(uint32_t, maxComputeWorkgroupSizeZ) \ + X(uint32_t, maxComputeInvocationsPerWorkgroup) \ + X(uint32_t, maxComputeWorkgroupStorageSize) + +struct LimitsForCompilationRequest { + static LimitsForCompilationRequest Create(const Limits& limits); + DAWN_VISITABLE_MEMBERS(LIMITS_FOR_COMPILATION_REQUEST_MEMBERS) +}; + } // namespace dawn::native #endif // SRC_DAWN_NATIVE_LIMITS_H_ diff --git a/src/dawn/native/ObjectContentHasher.h b/src/dawn/native/ObjectContentHasher.h index 4211fb3514..40ba602649 100644 --- a/src/dawn/native/ObjectContentHasher.h +++ b/src/dawn/native/ObjectContentHasher.h @@ -15,7 +15,9 @@ #ifndef SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_ #define SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_ +#include #include +#include #include #include "dawn/common/HashUtils.h" @@ -60,6 +62,13 @@ class ObjectContentHasher { } }; + template + struct RecordImpl> { + static constexpr void Call(ObjectContentHasher* recorder, const std::map& map) { + recorder->RecordIterable>(map); + } + }; + template constexpr void RecordIterable(const IteratorT& iterable) { for (auto it = iterable.begin(); it != iterable.end(); ++it) { @@ -67,6 +76,14 @@ class ObjectContentHasher { } } + template + struct RecordImpl> { + static constexpr void Call(ObjectContentHasher* recorder, const std::pair& pair) { + recorder->Record(pair.first); + recorder->Record(pair.second); + } + }; + size_t mContentHash = 0; }; diff --git a/src/dawn/native/Pipeline.cpp b/src/dawn/native/Pipeline.cpp index 50a4b9c105..dc55f637d1 100644 --- a/src/dawn/native/Pipeline.cpp +++ b/src/dawn/native/Pipeline.cpp @@ -58,12 +58,6 @@ MaybeError ValidateProgrammableStage(DeviceBase* device, DAWN_TRY(ValidateCompatibilityWithPipelineLayout(device, metadata, layout)); } - if (constantCount > 0u && device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs)) { - return DAWN_VALIDATION_ERROR( - "Pipeline overridable constants are disallowed because they are partially " - "implemented."); - } - // Validate if overridable constants exist in shader module // pipelineBase is not yet constructed at this moment so iterate constants from descriptor size_t numUninitializedConstants = metadata.uninitializedOverrides.size(); @@ -233,6 +227,7 @@ size_t 
PipelineBase::ComputeContentHash() { for (SingleShaderStage stage : IterateStages(mStageMask)) { recorder.Record(mStages[stage].module->GetContentHash()); recorder.Record(mStages[stage].entryPoint); + recorder.Record(mStages[stage].constants); } return recorder.GetContentHash(); @@ -248,7 +243,14 @@ bool PipelineBase::EqualForCache(const PipelineBase* a, const PipelineBase* b) { for (SingleShaderStage stage : IterateStages(a->mStageMask)) { // The module is deduplicated so it can be compared by pointer. if (a->mStages[stage].module.Get() != b->mStages[stage].module.Get() || - a->mStages[stage].entryPoint != b->mStages[stage].entryPoint) { + a->mStages[stage].entryPoint != b->mStages[stage].entryPoint || + a->mStages[stage].constants.size() != b->mStages[stage].constants.size()) { + return false; + } + + // If the constants.size are the same, we still need to compare the key and value. + if (!std::equal(a->mStages[stage].constants.begin(), a->mStages[stage].constants.end(), + b->mStages[stage].constants.begin())) { return false; } } diff --git a/src/dawn/native/Pipeline.h b/src/dawn/native/Pipeline.h index 2d5b6dfc65..adf8f8d107 100644 --- a/src/dawn/native/Pipeline.h +++ b/src/dawn/native/Pipeline.h @@ -40,9 +40,6 @@ MaybeError ValidateProgrammableStage(DeviceBase* device, const PipelineLayoutBase* layout, SingleShaderStage stage); -// Use map to make sure constant keys are sorted for creating shader cache keys -using PipelineConstantEntries = std::map; - struct ProgrammableStage { Ref module; std::string entryPoint; diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp index 6769f7e4ff..e527634b2e 100644 --- a/src/dawn/native/ShaderModule.cpp +++ b/src/dawn/native/ShaderModule.cpp @@ -20,7 +20,6 @@ #include "absl/strings/str_format.h" #include "dawn/common/BitSetIterator.h" #include "dawn/common/Constants.h" -#include "dawn/common/HashUtils.h" #include "dawn/native/BindGroupLayout.h" #include "dawn/native/ChainUtils_autogen.h" #include "dawn/native/CompilationMessages.h" @@ -511,7 +510,6 @@ ResultOrError> ReflectEntryPointUsingTint( const DeviceBase* device, tint::inspector::Inspector* inspector, const tint::inspector::EntryPoint& entryPoint) { - const CombinedLimits& limits = device->GetLimits(); constexpr uint32_t kMaxInterStageShaderLocation = kMaxInterStageShaderVariables - 1; std::unique_ptr metadata = std::make_unique(); @@ -528,10 +526,6 @@ ResultOrError> ReflectEntryPointUsingTint( })() if (!entryPoint.overrides.empty()) { - DAWN_INVALID_IF(device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs), - "Pipeline overridable constants are disallowed because they " - "are partially implemented."); - const auto& name2Id = inspector->GetNamedOverrideIds(); const auto& id2Scalar = inspector->GetOverrideDefaultValues(); @@ -553,10 +547,10 @@ ResultOrError> ReflectEntryPointUsingTint( UNREACHABLE(); } } - EntryPointMetadata::Override override = {id.value, FromTintOverrideType(c.type), + EntryPointMetadata::Override override = {id, FromTintOverrideType(c.type), c.is_initialized, defaultValue}; - std::string identifier = c.is_id_specified ? std::to_string(override.id) : c.name; + std::string identifier = c.is_id_specified ? 
std::to_string(override.id.value) : c.name; metadata->overrides[identifier] = override; if (!c.is_initialized) { @@ -575,39 +569,6 @@ ResultOrError> ReflectEntryPointUsingTint( DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage)); if (metadata->stage == SingleShaderStage::Compute) { - auto workgroup_size = entryPoint.workgroup_size; - DAWN_INVALID_IF( - !workgroup_size.has_value(), - "TODO(crbug.com/dawn/1504): Dawn does not currently support @workgroup_size " - "attributes using override-expressions"); - DelayedInvalidIf(workgroup_size->x > limits.v1.maxComputeWorkgroupSizeX || - workgroup_size->y > limits.v1.maxComputeWorkgroupSizeY || - workgroup_size->z > limits.v1.maxComputeWorkgroupSizeZ, - "Entry-point uses workgroup_size(%u, %u, %u) that exceeds the " - "maximum allowed (%u, %u, %u).", - workgroup_size->x, workgroup_size->y, workgroup_size->z, - limits.v1.maxComputeWorkgroupSizeX, limits.v1.maxComputeWorkgroupSizeY, - limits.v1.maxComputeWorkgroupSizeZ); - - // Dimensions have already been validated against their individual limits above. - // Cast to uint64_t to avoid overflow in this multiplication. - uint64_t numInvocations = - static_cast(workgroup_size->x) * workgroup_size->y * workgroup_size->z; - DelayedInvalidIf(numInvocations > limits.v1.maxComputeInvocationsPerWorkgroup, - "The total number of workgroup invocations (%u) exceeds the " - "maximum allowed (%u).", - numInvocations, limits.v1.maxComputeInvocationsPerWorkgroup); - - const size_t workgroupStorageSize = inspector->GetWorkgroupStorageSize(entryPoint.name); - DelayedInvalidIf(workgroupStorageSize > limits.v1.maxComputeWorkgroupStorageSize, - "The total use of workgroup storage (%u bytes) is larger than " - "the maximum allowed (%u bytes).", - workgroupStorageSize, limits.v1.maxComputeWorkgroupStorageSize); - - metadata->localWorkgroupSize.x = workgroup_size->x; - metadata->localWorkgroupSize.y = workgroup_size->y; - metadata->localWorkgroupSize.z = workgroup_size->z; - metadata->usesNumWorkgroups = entryPoint.num_workgroups_used; } @@ -883,6 +844,46 @@ MaybeError ReflectShaderUsingTint(const DeviceBase* device, } } // anonymous namespace +ResultOrError ValidateComputeStageWorkgroupSize( + const tint::Program& program, + const char* entryPointName, + const LimitsForCompilationRequest& limits) { + tint::inspector::Inspector inspector(&program); + // At this point the entry point must exist and must have workgroup size values. 
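+    // The callers run this helper only after the SubstituteOverride transform, so an
+    // override-expression used in @workgroup_size has already been replaced by a concrete
+    // value and the inspector can report the size.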
+ tint::inspector::EntryPoint entryPoint = inspector.GetEntryPoint(entryPointName); + ASSERT(entryPoint.workgroup_size.has_value()); + const tint::inspector::WorkgroupSize& workgroup_size = entryPoint.workgroup_size.value(); + + DAWN_INVALID_IF(workgroup_size.x < 1 || workgroup_size.y < 1 || workgroup_size.z < 1, + "Entry-point uses workgroup_size(%u, %u, %u) that are below the " + "minimum allowed (1, 1, 1).", + workgroup_size.x, workgroup_size.y, workgroup_size.z); + + DAWN_INVALID_IF(workgroup_size.x > limits.maxComputeWorkgroupSizeX || + workgroup_size.y > limits.maxComputeWorkgroupSizeY || + workgroup_size.z > limits.maxComputeWorkgroupSizeZ, + "Entry-point uses workgroup_size(%u, %u, %u) that exceeds the " + "maximum allowed (%u, %u, %u).", + workgroup_size.x, workgroup_size.y, workgroup_size.z, + limits.maxComputeWorkgroupSizeX, limits.maxComputeWorkgroupSizeY, + limits.maxComputeWorkgroupSizeZ); + + uint64_t numInvocations = + static_cast(workgroup_size.x) * workgroup_size.y * workgroup_size.z; + DAWN_INVALID_IF(numInvocations > limits.maxComputeInvocationsPerWorkgroup, + "The total number of workgroup invocations (%u) exceeds the " + "maximum allowed (%u).", + numInvocations, limits.maxComputeInvocationsPerWorkgroup); + + const size_t workgroupStorageSize = inspector.GetWorkgroupStorageSize(entryPointName); + DAWN_INVALID_IF(workgroupStorageSize > limits.maxComputeWorkgroupStorageSize, + "The total use of workgroup storage (%u bytes) is larger than " + "the maximum allowed (%u bytes).", + workgroupStorageSize, limits.maxComputeWorkgroupStorageSize); + + return Extent3D{workgroup_size.x, workgroup_size.y, workgroup_size.z}; +} + ShaderModuleParseResult::ShaderModuleParseResult() = default; ShaderModuleParseResult::~ShaderModuleParseResult() = default; @@ -1200,11 +1201,4 @@ MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult return {}; } -size_t PipelineLayoutEntryPointPairHashFunc::operator()( - const PipelineLayoutEntryPointPair& pair) const { - size_t hash = 0; - HashCombine(&hash, pair.first, pair.second); - return hash; -} - } // namespace dawn::native diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h index 170b388000..b04e9ed7e1 100644 --- a/src/dawn/native/ShaderModule.h +++ b/src/dawn/native/ShaderModule.h @@ -33,10 +33,12 @@ #include "dawn/native/Format.h" #include "dawn/native/Forward.h" #include "dawn/native/IntegerTypes.h" +#include "dawn/native/Limits.h" #include "dawn/native/ObjectBase.h" #include "dawn/native/PerStage.h" #include "dawn/native/VertexFormat.h" #include "dawn/native/dawn_platform.h" +#include "tint/override_id.h" namespace tint { @@ -76,10 +78,8 @@ enum class InterpolationSampling { Sample, }; -using PipelineLayoutEntryPointPair = std::pair; -struct PipelineLayoutEntryPointPairHashFunc { - size_t operator()(const PipelineLayoutEntryPointPair& pair) const; -}; +// Use map to make sure constant keys are sorted for creating shader cache keys +using PipelineConstantEntries = std::map; // A map from name to EntryPointMetadata. 
using EntryPointMetadataTable = @@ -108,6 +108,13 @@ MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device, const EntryPointMetadata& entryPoint, const PipelineLayoutBase* layout); +// Return extent3D with workgroup size dimension info if it is valid +// width = x, height = y, depthOrArrayLength = z +ResultOrError ValidateComputeStageWorkgroupSize( + const tint::Program& program, + const char* entryPointName, + const LimitsForCompilationRequest& limits); + RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint, const PipelineLayoutBase* layout); ResultOrError RunTransforms(tint::transform::Transform* transform, @@ -204,14 +211,12 @@ struct EntryPointMetadata { std::bitset usedInterStageVariables; std::array interStageVariables; - // The local workgroup size declared for a compute entry point (or 0s otehrwise). - Origin3D localWorkgroupSize; - // The shader stage for this binding. SingleShaderStage stage; struct Override { - uint32_t id; + tint::OverrideId id; + // Match tint::inspector::Override::Type // Bool is defined as a macro on linux X11 and cannot compile enum class Type { Boolean, Float32, Uint32, Int32 } type; @@ -273,6 +278,7 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject { bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const; }; + // This returns tint program before running transforms. const tint::Program* GetTintProgram() const; void APIGetCompilationInfo(wgpu::CompilationInfoCallback callback, void* userdata); diff --git a/src/dawn/native/StreamImplTint.cpp b/src/dawn/native/StreamImplTint.cpp index deeaa0f11a..5f077ec129 100644 --- a/src/dawn/native/StreamImplTint.cpp +++ b/src/dawn/native/StreamImplTint.cpp @@ -65,6 +65,23 @@ void stream::Stream::Write( StreamInTintObject(cfg, sink); } +// static +template <> +void stream::Stream::Write( + stream::Sink* sink, + const tint::transform::SubstituteOverride::Config& cfg) { + StreamInTintObject(cfg, sink); +} + +// static +template <> +void stream::Stream::Write(stream::Sink* sink, const tint::OverrideId& id) { + // TODO(tint:1640): fix the include build issues and use StreamInTintObject instead. 
+ static_assert(offsetof(tint::OverrideId, value) == 0, + "Please update serialization for tint::OverrideId"); + StreamIn(sink, id.value); +} + // static template <> void stream::Stream::Write( diff --git a/src/dawn/native/TintUtils.cpp b/src/dawn/native/TintUtils.cpp index 24d4cc6afb..a2cf5dc4e7 100644 --- a/src/dawn/native/TintUtils.cpp +++ b/src/dawn/native/TintUtils.cpp @@ -16,6 +16,7 @@ #include "dawn/native/BindGroupLayout.h" #include "dawn/native/Device.h" +#include "dawn/native/Pipeline.h" #include "dawn/native/PipelineLayout.h" #include "dawn/native/RenderPipeline.h" @@ -183,6 +184,21 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig( return cfg; } +tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig( + const ProgrammableStage& stage) { + const EntryPointMetadata& metadata = *stage.metadata; + const auto& constants = stage.constants; + + tint::transform::SubstituteOverride::Config cfg; + + for (const auto& [key, value] : constants) { + const auto& o = metadata.overrides.at(key); + cfg.map.insert({o.id, value}); + } + + return cfg; +} + } // namespace dawn::native namespace tint::sem { diff --git a/src/dawn/native/TintUtils.h b/src/dawn/native/TintUtils.h index 7c03881502..d6bea1d88f 100644 --- a/src/dawn/native/TintUtils.h +++ b/src/dawn/native/TintUtils.h @@ -26,6 +26,7 @@ namespace dawn::native { class DeviceBase; class PipelineLayoutBase; +struct ProgrammableStage; class RenderPipelineBase; // Indicates that for the lifetime of this object tint internal compiler errors should be @@ -47,6 +48,9 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig( const std::string_view& entryPoint, BindGroupIndex pullingBufferBindingSet); +tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig( + const ProgrammableStage& stage); + } // namespace dawn::native namespace tint::sem { diff --git a/src/dawn/native/d3d12/RenderPipelineD3D12.cpp b/src/dawn/native/d3d12/RenderPipelineD3D12.cpp index b1da726691..80bc6fce9c 100644 --- a/src/dawn/native/d3d12/RenderPipelineD3D12.cpp +++ b/src/dawn/native/d3d12/RenderPipelineD3D12.cpp @@ -351,8 +351,6 @@ MaybeError RenderPipeline::Initialize() { D3D12_GRAPHICS_PIPELINE_STATE_DESC descriptorD3D12 = {}; - PerStage pipelineStages = GetAllStages(); - PerStage shaders; shaders[SingleShaderStage::Vertex] = &descriptorD3D12.VS; shaders[SingleShaderStage::Fragment] = &descriptorD3D12.PS; @@ -360,8 +358,9 @@ MaybeError RenderPipeline::Initialize() { PerStage compiledShader; for (auto stage : IterateStages(GetStageMask())) { - DAWN_TRY_ASSIGN(compiledShader[stage], ToBackend(pipelineStages[stage].module) - ->Compile(pipelineStages[stage], stage, + const ProgrammableStage& programmableStage = GetStage(stage); + DAWN_TRY_ASSIGN(compiledShader[stage], ToBackend(programmableStage.module) + ->Compile(programmableStage, stage, ToBackend(GetLayout()), compileFlags)); *shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode(); } diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp index 9e1b9e9a11..bb0077aa29 100644 --- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp @@ -65,100 +65,35 @@ namespace dawn::native::d3d12 { namespace { -// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough -std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) { - std::ostringstream out; - out.precision(n); - out << std::fixed << v; - 
return out.str(); -} - -std::string GetHLSLValueString(EntryPointMetadata::Override::Type dawnType, - const OverrideScalar* entry, - double value = 0) { - switch (dawnType) { - case EntryPointMetadata::Override::Type::Boolean: - return std::to_string(entry ? entry->b : static_cast(value)); - case EntryPointMetadata::Override::Type::Float32: - return FloatToStringWithPrecision(entry ? entry->f32 : static_cast(value)); - case EntryPointMetadata::Override::Type::Int32: - return std::to_string(entry ? entry->i32 : static_cast(value)); - case EntryPointMetadata::Override::Type::Uint32: - return std::to_string(entry ? entry->u32 : static_cast(value)); - default: - UNREACHABLE(); - } -} - -constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_"; - -using DefineStrings = std::vector>; - -DefineStrings GetOverridableConstantsDefines( - const PipelineConstantEntries& pipelineConstantEntries, - const EntryPointMetadata::OverridesMap& shaderEntryPointConstants) { - DefineStrings defineStrings; - std::unordered_set overriddenConstants; - - // Set pipeline overridden values - for (const auto& [name, value] : pipelineConstantEntries) { - overriddenConstants.insert(name); - - // This is already validated so `name` must exist - const auto& moduleConstant = shaderEntryPointConstants.at(name); - - defineStrings.emplace_back( - kSpecConstantPrefix + std::to_string(static_cast(moduleConstant.id)), - GetHLSLValueString(moduleConstant.type, nullptr, value)); - } - - // Set shader initialized default values - for (const auto& iter : shaderEntryPointConstants) { - const std::string& name = iter.first; - if (overriddenConstants.count(name) != 0) { - // This constant already has overridden value - continue; - } - - const auto& moduleConstant = shaderEntryPointConstants.at(name); - - // Uninitialized default values are okay since they ar only defined to pass - // compilation but not used - defineStrings.emplace_back( - kSpecConstantPrefix + std::to_string(static_cast(moduleConstant.id)), - GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue)); - } - return defineStrings; -} - enum class Compiler { FXC, DXC }; -#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \ - X(const tint::Program*, inputProgram) \ - X(std::string_view, entryPointName) \ - X(SingleShaderStage, stage) \ - X(uint32_t, shaderModel) \ - X(uint32_t, compileFlags) \ - X(Compiler, compiler) \ - X(uint64_t, compilerVersion) \ - X(std::wstring_view, dxcShaderProfile) \ - X(std::string_view, fxcShaderProfile) \ - X(pD3DCompile, d3dCompile) \ - X(IDxcLibrary*, dxcLibrary) \ - X(IDxcCompiler*, dxcCompiler) \ - X(uint32_t, firstIndexOffsetShaderRegister) \ - X(uint32_t, firstIndexOffsetRegisterSpace) \ - X(bool, usesNumWorkgroups) \ - X(uint32_t, numWorkgroupsShaderRegister) \ - X(uint32_t, numWorkgroupsRegisterSpace) \ - X(DefineStrings, defineStrings) \ - X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \ - X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform) \ - X(tint::transform::BindingRemapper::BindingPoints, remappedBindingPoints) \ - X(tint::transform::BindingRemapper::AccessControls, remappedAccessControls) \ - X(bool, disableSymbolRenaming) \ - X(bool, isRobustnessEnabled) \ - X(bool, disableWorkgroupInit) \ +#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \ + X(const tint::Program*, inputProgram) \ + X(std::string_view, entryPointName) \ + X(SingleShaderStage, stage) \ + X(uint32_t, shaderModel) \ + X(uint32_t, compileFlags) \ + X(Compiler, compiler) \ + X(uint64_t, compilerVersion) \ + 
X(std::wstring_view, dxcShaderProfile) \ + X(std::string_view, fxcShaderProfile) \ + X(pD3DCompile, d3dCompile) \ + X(IDxcLibrary*, dxcLibrary) \ + X(IDxcCompiler*, dxcCompiler) \ + X(uint32_t, firstIndexOffsetShaderRegister) \ + X(uint32_t, firstIndexOffsetRegisterSpace) \ + X(bool, usesNumWorkgroups) \ + X(uint32_t, numWorkgroupsShaderRegister) \ + X(uint32_t, numWorkgroupsRegisterSpace) \ + X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \ + X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform) \ + X(tint::transform::BindingRemapper::BindingPoints, remappedBindingPoints) \ + X(tint::transform::BindingRemapper::AccessControls, remappedAccessControls) \ + X(std::optional, substituteOverrideConfig) \ + X(LimitsForCompilationRequest, limits) \ + X(bool, disableSymbolRenaming) \ + X(bool, isRobustnessEnabled) \ + X(bool, disableWorkgroupInit) \ X(bool, dumpShaders) #define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \ @@ -170,8 +105,7 @@ enum class Compiler { FXC, DXC }; X(std::string_view, fxcShaderProfile) \ X(pD3DCompile, d3dCompile) \ X(IDxcLibrary*, dxcLibrary) \ - X(IDxcCompiler*, dxcCompiler) \ - X(DefineStrings, defineStrings) + X(IDxcCompiler*, dxcCompiler) DAWN_SERIALIZABLE(struct, HlslCompilationRequest, HLSL_COMPILATION_REQUEST_MEMBERS){}; #undef HLSL_COMPILATION_REQUEST_MEMBERS @@ -255,25 +189,11 @@ ResultOrError> CompileShaderDXC(const D3DBytecodeCompilationReq std::vector arguments = GetDXCArguments(r.compileFlags, r.hasShaderFloat16Feature); - // Build defines for overridable constants - std::vector> defineStrings; - defineStrings.reserve(r.defineStrings.size()); - for (const auto& [name, value] : r.defineStrings) { - defineStrings.emplace_back(UTF8ToWStr(name.c_str()), UTF8ToWStr(value.c_str())); - } - - std::vector dxcDefines; - dxcDefines.reserve(defineStrings.size()); - for (const auto& [name, value] : defineStrings) { - dxcDefines.push_back({name.c_str(), value.c_str()}); - } - ComPtr result; - DAWN_TRY(CheckHRESULT( - r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(), - r.dxcShaderProfile.data(), arguments.data(), arguments.size(), - dxcDefines.data(), dxcDefines.size(), nullptr, &result), - "DXC compile")); + DAWN_TRY(CheckHRESULT(r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(), + r.dxcShaderProfile.data(), arguments.data(), + arguments.size(), nullptr, 0, nullptr, &result), + "DXC compile")); HRESULT hr; DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status")); @@ -359,20 +279,7 @@ ResultOrError> CompileShaderFXC(const D3DBytecodeCompilationReq ComPtr compiledShader; ComPtr errors; - // Build defines for overridable constants - const D3D_SHADER_MACRO* pDefines = nullptr; - std::vector fxcDefines; - if (r.defineStrings.size() > 0) { - fxcDefines.reserve(r.defineStrings.size() + 1); - for (const auto& [name, value] : r.defineStrings) { - fxcDefines.push_back({name.c_str(), value.c_str()}); - } - // d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array - fxcDefines.push_back({nullptr, nullptr}); - pDefines = fxcDefines.data(); - } - - DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, pDefines, + DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr, nullptr, entryPointName.c_str(), r.fxcShaderProfile.data(), r.compileFlags, 0, &compiledShader, &errors)), "D3D compile failed with: %s", static_cast(errors->GetBufferPointer())); @@ -420,6 +327,14 @@ ResultOrError TranslateToHLSL( 
tint::transform::Renamer::Target::kHlslKeywords); } + if (r.substituteOverrideConfig) { + // This needs to run after SingleEntryPoint transform to get rid of overrides not used for + // the current entry point. + transformManager.Add(); + transformInputs.Add( + std::move(r.substituteOverrideConfig).value()); + } + // D3D12 registers like `t3` and `c3` have the same bindingOffset number in // the remapping but should not be considered a collision because they have // different types. @@ -450,6 +365,13 @@ ResultOrError TranslateToHLSL( return DAWN_VALIDATION_ERROR("Transform output missing renamer data."); } + if (r.stage == SingleShaderStage::Compute) { + // Validate workgroup size after program runs transforms. + Extent3D _; + DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize( + transformedProgram, remappedEntryPointName->data(), r.limits)); + } + if (r.stage == SingleShaderStage::Vertex) { if (auto* data = transformOutputs.Get()) { *usesVertexOrInstanceIndex = data->has_vertex_or_instance_index; @@ -555,8 +477,7 @@ ResultOrError ShaderModule::Compile(const ProgrammableStage& pro req.bytecode.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16); req.bytecode.compileFlags = compileFlags; - req.bytecode.defineStrings = - GetOverridableConstantsDefines(programmableStage.constants, entryPoint.overrides); + if (device->IsToggleEnabled(Toggle::UseDXC)) { req.bytecode.compiler = Compiler::DXC; req.bytecode.dxcLibrary = device->GetDxcLibrary().Get(); @@ -645,6 +566,11 @@ ResultOrError ShaderModule::Compile(const ProgrammableStage& pro } } + std::optional substituteOverrideConfig; + if (!programmableStage.metadata->overrides.empty()) { + substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage); + } + req.hlsl.inputProgram = GetTintProgram(); req.hlsl.entryPointName = programmableStage.entryPoint.c_str(); req.hlsl.stage = stage; @@ -657,6 +583,10 @@ ResultOrError ShaderModule::Compile(const ProgrammableStage& pro req.hlsl.remappedAccessControls = std::move(remappedAccessControls); req.hlsl.newBindingsMap = BuildExternalTextureTransformBindings(layout); req.hlsl.arrayLengthFromUniform = std::move(arrayLengthFromUniform); + req.hlsl.substituteOverrideConfig = std::move(substituteOverrideConfig); + + const CombinedLimits& limits = device->GetLimits(); + req.hlsl.limits = LimitsForCompilationRequest::Create(limits.v1); CacheResult compiledShader; DAWN_TRY_LOAD_OR_RUN(compiledShader, device, std::move(req), CompiledShader::FromBlob, diff --git a/src/dawn/native/metal/ComputePipelineMTL.mm b/src/dawn/native/metal/ComputePipelineMTL.mm index 855cd7b4bf..3ed84dcc13 100644 --- a/src/dawn/native/metal/ComputePipelineMTL.mm +++ b/src/dawn/native/metal/ComputePipelineMTL.mm @@ -40,8 +40,9 @@ MaybeError ComputePipeline::Initialize() { const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute); ShaderModule::MetalFunctionData computeData; - DAWN_TRY(CreateMTLFunction(computeStage, SingleShaderStage::Compute, ToBackend(GetLayout()), - &computeData)); + DAWN_TRY(ToBackend(computeStage.module.Get()) + ->CreateFunction(SingleShaderStage::Compute, computeStage, ToBackend(GetLayout()), + &computeData)); NSError* error = nullptr; mMtlComputePipelineState.Acquire( @@ -53,8 +54,7 @@ MaybeError ComputePipeline::Initialize() { ASSERT(mMtlComputePipelineState != nil); // Copy over the local workgroup size as it is passed to dispatch explicitly in Metal - Origin3D localSize = GetStage(SingleShaderStage::Compute).metadata->localWorkgroupSize; - 
mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z); + mLocalWorkgroupSize = computeData.localWorkgroupSize; mRequiresStorageBufferLength = computeData.needsStorageBufferLength; mWorkgroupAllocations = std::move(computeData.workgroupAllocations); diff --git a/src/dawn/native/metal/RenderPipelineMTL.mm b/src/dawn/native/metal/RenderPipelineMTL.mm index 6e5afd54fe..9b9ccb4295 100644 --- a/src/dawn/native/metal/RenderPipelineMTL.mm +++ b/src/dawn/native/metal/RenderPipelineMTL.mm @@ -340,8 +340,9 @@ MaybeError RenderPipeline::Initialize() { const PerStage& allStages = GetAllStages(); const ProgrammableStage& vertexStage = allStages[wgpu::ShaderStage::Vertex]; ShaderModule::MetalFunctionData vertexData; - DAWN_TRY(CreateMTLFunction(vertexStage, SingleShaderStage::Vertex, ToBackend(GetLayout()), - &vertexData, 0xFFFFFFFF, this)); + DAWN_TRY(ToBackend(vertexStage.module.Get()) + ->CreateFunction(SingleShaderStage::Vertex, vertexStage, ToBackend(GetLayout()), + &vertexData, 0xFFFFFFFF, this)); descriptorMTL.vertexFunction = vertexData.function.Get(); if (vertexData.needsStorageBufferLength) { @@ -351,8 +352,9 @@ MaybeError RenderPipeline::Initialize() { if (GetStageMask() & wgpu::ShaderStage::Fragment) { const ProgrammableStage& fragmentStage = allStages[wgpu::ShaderStage::Fragment]; ShaderModule::MetalFunctionData fragmentData; - DAWN_TRY(CreateMTLFunction(fragmentStage, SingleShaderStage::Fragment, - ToBackend(GetLayout()), &fragmentData, GetSampleMask())); + DAWN_TRY(ToBackend(fragmentStage.module.Get()) + ->CreateFunction(SingleShaderStage::Fragment, fragmentStage, + ToBackend(GetLayout()), &fragmentData, GetSampleMask())); descriptorMTL.fragmentFunction = fragmentData.function.Get(); if (fragmentData.needsStorageBufferLength) { diff --git a/src/dawn/native/metal/ShaderModuleMTL.h b/src/dawn/native/metal/ShaderModuleMTL.h index 27f1def213..d79dbb3a72 100644 --- a/src/dawn/native/metal/ShaderModuleMTL.h +++ b/src/dawn/native/metal/ShaderModuleMTL.h @@ -25,6 +25,10 @@ #import +namespace dawn::native { +struct ProgrammableStage; +} + namespace dawn::native::metal { class Device; @@ -42,15 +46,13 @@ class ShaderModule final : public ShaderModuleBase { NSPRef> function; bool needsStorageBufferLength; std::vector workgroupAllocations; + MTLSize localWorkgroupSize; }; - // MTLFunctionConstantValues needs @available tag to compile - // Use id (like void*) in function signature as workaround and do static cast inside - MaybeError CreateFunction(const char* entryPointName, - SingleShaderStage stage, + MaybeError CreateFunction(SingleShaderStage stage, + const ProgrammableStage& programmableStage, const PipelineLayout* layout, MetalFunctionData* out, - id constantValues = nil, uint32_t sampleMask = 0xFFFFFFFF, const RenderPipeline* renderPipeline = nullptr); diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm index 5d4196b8af..b0f59a3b7e 100644 --- a/src/dawn/native/metal/ShaderModuleMTL.mm +++ b/src/dawn/native/metal/ShaderModuleMTL.mm @@ -35,17 +35,20 @@ namespace { using OptionalVertexPullingTransformConfig = std::optional; -#define MSL_COMPILATION_REQUEST_MEMBERS(X) \ - X(const tint::Program*, inputProgram) \ - X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \ - X(tint::transform::MultiplanarExternalTexture::BindingsMap, externalTextureBindings) \ - X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig) \ - X(std::string, entryPointName) \ - X(uint32_t, sampleMask) \ - X(bool, emitVertexPointSize) \ - 
X(bool, isRobustnessEnabled) \ - X(bool, disableSymbolRenaming) \ - X(bool, disableWorkgroupInit) \ +#define MSL_COMPILATION_REQUEST_MEMBERS(X) \ + X(SingleShaderStage, stage) \ + X(const tint::Program*, inputProgram) \ + X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \ + X(tint::transform::MultiplanarExternalTexture::BindingsMap, externalTextureBindings) \ + X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig) \ + X(std::optional, substituteOverrideConfig) \ + X(LimitsForCompilationRequest, limits) \ + X(std::string, entryPointName) \ + X(uint32_t, sampleMask) \ + X(bool, emitVertexPointSize) \ + X(bool, isRobustnessEnabled) \ + X(bool, disableSymbolRenaming) \ + X(bool, disableWorkgroupInit) \ X(CacheKey::UnsafeUnkeyedValue, tracePlatform) DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS); @@ -53,12 +56,13 @@ DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS); using WorkgroupAllocations = std::vector; -#define MSL_COMPILATION_MEMBERS(X) \ - X(std::string, msl) \ - X(std::string, remappedEntryPointName) \ - X(bool, needsStorageBufferLength) \ - X(bool, hasInvariantAttribute) \ - X(WorkgroupAllocations, workgroupAllocations) +#define MSL_COMPILATION_MEMBERS(X) \ + X(std::string, msl) \ + X(std::string, remappedEntryPointName) \ + X(bool, needsStorageBufferLength) \ + X(bool, hasInvariantAttribute) \ + X(WorkgroupAllocations, workgroupAllocations) \ + X(Extent3D, localWorkgroupSize) DAWN_SERIALIZABLE(struct, MslCompilation, MSL_COMPILATION_MEMBERS){}; #undef MSL_COMPILATION_MEMBERS @@ -92,13 +96,14 @@ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult, namespace { -ResultOrError> TranslateToMSL(DeviceBase* device, - const tint::Program* inputProgram, - const char* entryPointName, - SingleShaderStage stage, - const PipelineLayout* layout, - uint32_t sampleMask, - const RenderPipeline* renderPipeline) { +ResultOrError> TranslateToMSL( + DeviceBase* device, + const ProgrammableStage& programmableStage, + SingleShaderStage stage, + const PipelineLayout* layout, + ShaderModule::MetalFunctionData* out, + uint32_t sampleMask, + const RenderPipeline* renderPipeline) { ScopedTintICEHandler scopedICEHandler(device); std::ostringstream errorStream; @@ -137,7 +142,7 @@ ResultOrError> TranslateToMSL(DeviceBase* device, if (stage == SingleShaderStage::Vertex && device->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) { vertexPullingTransformConfig = BuildVertexPullingTransformConfig( - *renderPipeline, entryPointName, kPullingBufferBindingSet); + *renderPipeline, programmableStage.entryPoint.c_str(), kPullingBufferBindingSet); for (VertexBufferSlot slot : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) { uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot); @@ -152,12 +157,19 @@ ResultOrError> TranslateToMSL(DeviceBase* device, } } + std::optional substituteOverrideConfig; + if (!programmableStage.metadata->overrides.empty()) { + substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage); + } + MslCompilationRequest req = {}; - req.inputProgram = inputProgram; + req.stage = stage; + req.inputProgram = programmableStage.module->GetTintProgram(); req.bindingPoints = std::move(bindingPoints); req.externalTextureBindings = std::move(externalTextureBindings); req.vertexPullingTransformConfig = std::move(vertexPullingTransformConfig); - req.entryPointName = entryPointName; + req.substituteOverrideConfig = std::move(substituteOverrideConfig); 
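+    // The request doubles as the persistent cache key via DAWN_TRY_LOAD_OR_RUN below, so
+    // the substituted override values (and, a few lines down, the limits) are recorded in
+    // it; otherwise compilations with different pipeline constants could collide in the
+    // cache. SubstituteOverride rewrites each WGSL declaration such as
+    //     override wg_size : u32 = 64;
+    // into a plain constant holding the pipeline-supplied value (or the default), which is
+    // why MTLFunctionConstantValues is no longer needed when creating the MTLFunction.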
+ req.entryPointName = programmableStage.entryPoint.c_str(); req.sampleMask = sampleMask; req.emitVertexPointSize = stage == SingleShaderStage::Vertex && @@ -166,6 +178,9 @@ ResultOrError> TranslateToMSL(DeviceBase* device, req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming); req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform()); + const CombinedLimits& limits = device->GetLimits(); + req.limits = LimitsForCompilationRequest::Create(limits.v1); + CacheResult mslCompilation; DAWN_TRY_LOAD_OR_RUN( mslCompilation, device, std::move(req), MslCompilation::FromBlob, @@ -190,6 +205,14 @@ ResultOrError> TranslateToMSL(DeviceBase* device, std::move(r.vertexPullingTransformConfig).value()); } + if (r.substituteOverrideConfig) { + // This needs to run after SingleEntryPoint transform to get rid of overrides not + // used for the current entry point. + transformManager.Add(); + transformInputs.Add( + std::move(r.substituteOverrideConfig).value()); + } + if (r.isRobustnessEnabled) { transformManager.Add(); } @@ -230,6 +253,13 @@ ResultOrError> TranslateToMSL(DeviceBase* device, return DAWN_VALIDATION_ERROR("Transform output missing renamer data."); } + Extent3D localSize{0, 0, 0}; + if (r.stage == SingleShaderStage::Compute) { + // Validate workgroup size after program runs transforms. + DAWN_TRY_ASSIGN(localSize, ValidateComputeStageWorkgroupSize( + program, remappedEntryPointName.data(), r.limits)); + } + tint::writer::msl::Options options; options.buffer_size_ubo_index = kBufferLengthBufferSlot; options.fixed_sample_mask = r.sampleMask; @@ -258,6 +288,7 @@ ResultOrError> TranslateToMSL(DeviceBase* device, result.needs_storage_buffer_sizes, result.has_invariant_attribute, std::move(workgroupAllocations), + localSize, }}; }); @@ -272,11 +303,10 @@ ResultOrError> TranslateToMSL(DeviceBase* device, } // namespace -MaybeError ShaderModule::CreateFunction(const char* entryPointName, - SingleShaderStage stage, +MaybeError ShaderModule::CreateFunction(SingleShaderStage stage, + const ProgrammableStage& programmableStage, const PipelineLayout* layout, ShaderModule::MetalFunctionData* out, - id constantValuesPointer, uint32_t sampleMask, const RenderPipeline* renderPipeline) { TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleMTL::CreateFunction"); @@ -284,16 +314,21 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName, ASSERT(!IsError()); ASSERT(out); + const char* entryPointName = programmableStage.entryPoint.c_str(); + // Vertex stages must specify a renderPipeline if (stage == SingleShaderStage::Vertex) { ASSERT(renderPipeline != nullptr); } CacheResult mslCompilation; - DAWN_TRY_ASSIGN(mslCompilation, TranslateToMSL(GetDevice(), GetTintProgram(), entryPointName, - stage, layout, sampleMask, renderPipeline)); + DAWN_TRY_ASSIGN(mslCompilation, TranslateToMSL(GetDevice(), programmableStage, stage, layout, + out, sampleMask, renderPipeline)); out->needsStorageBufferLength = mslCompilation->needsStorageBufferLength; out->workgroupAllocations = std::move(mslCompilation->workgroupAllocations); + out->localWorkgroupSize = MTLSizeMake(mslCompilation->localWorkgroupSize.width, + mslCompilation->localWorkgroupSize.height, + mslCompilation->localWorkgroupSize.depthOrArrayLayers); NSRef mslSource = AcquireNSRef([[NSString alloc] initWithUTF8String:mslCompilation->msl.c_str()]); @@ -327,25 +362,7 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName, { TRACE_EVENT0(GetDevice()->GetPlatform(), General, 
"MTLLibrary::newFunctionWithName"); - if (constantValuesPointer != nil) { - if (@available(macOS 10.12, *)) { - MTLFunctionConstantValues* constantValues = constantValuesPointer; - out->function = AcquireNSPRef([*library newFunctionWithName:name.Get() - constantValues:constantValues - error:&error]); - if (error != nullptr) { - if (error.code != MTLLibraryErrorCompileWarning) { - return DAWN_VALIDATION_ERROR("Function compile error: %s", - [error.localizedDescription UTF8String]); - } - } - ASSERT(out->function != nil); - } else { - UNREACHABLE(); - } - } else { - out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]); - } + out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]); } if (BlobCache* cache = GetDevice()->GetBlobCache()) { diff --git a/src/dawn/native/metal/UtilsMetal.h b/src/dawn/native/metal/UtilsMetal.h index 9ee31bd618..418c4a6864 100644 --- a/src/dawn/native/metal/UtilsMetal.h +++ b/src/dawn/native/metal/UtilsMetal.h @@ -68,15 +68,6 @@ void EnsureDestinationTextureInitialized(CommandRecordingContext* commandContext MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect); -// Helper function to create function with constant values wrapped in -// if available branch -MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage, - SingleShaderStage singleShaderStage, - PipelineLayout* pipelineLayout, - ShaderModule::MetalFunctionData* functionData, - uint32_t sampleMask = 0xFFFFFFFF, - const RenderPipeline* renderPipeline = nullptr); - // Allow use MTLStoreActionStoreAndMultismapleResolve because the logic in the backend is // first to compute what the "best" Metal render pass descriptor is, then fix it up if we // are not on macOS 10.12 (i.e. the EmulateStoreAndMSAAResolve toggle is on). 
diff --git a/src/dawn/native/metal/UtilsMetal.mm b/src/dawn/native/metal/UtilsMetal.mm index df122f3e3e..4eb5fb6e81 100644 --- a/src/dawn/native/metal/UtilsMetal.mm +++ b/src/dawn/native/metal/UtilsMetal.mm @@ -323,104 +323,6 @@ MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect) { return MTLBlitOptionNone; } -MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage, - SingleShaderStage singleShaderStage, - PipelineLayout* pipelineLayout, - ShaderModule::MetalFunctionData* functionData, - uint32_t sampleMask, - const RenderPipeline* renderPipeline) { - ShaderModule* shaderModule = ToBackend(programmableStage.module.Get()); - const char* shaderEntryPoint = programmableStage.entryPoint.c_str(); - const auto& entryPointMetadata = programmableStage.module->GetEntryPoint(shaderEntryPoint); - if (entryPointMetadata.overrides.size() == 0) { - DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout, - functionData, nil, sampleMask, renderPipeline)); - return {}; - } - - if (@available(macOS 10.12, *)) { - // MTLFunctionConstantValues can only be created within the if available branch - NSRef constantValues = - AcquireNSRef([MTLFunctionConstantValues new]); - - std::unordered_set overriddenConstants; - - auto switchType = [&](EntryPointMetadata::Override::Type dawnType, - MTLDataType* type, OverrideScalar* entry, - double value = 0) { - switch (dawnType) { - case EntryPointMetadata::Override::Type::Boolean: - *type = MTLDataTypeBool; - if (entry) { - entry->b = static_cast(value); - } - break; - case EntryPointMetadata::Override::Type::Float32: - *type = MTLDataTypeFloat; - if (entry) { - entry->f32 = static_cast(value); - } - break; - case EntryPointMetadata::Override::Type::Int32: - *type = MTLDataTypeInt; - if (entry) { - entry->i32 = static_cast(value); - } - break; - case EntryPointMetadata::Override::Type::Uint32: - *type = MTLDataTypeUInt; - if (entry) { - entry->u32 = static_cast(value); - } - break; - default: - UNREACHABLE(); - } - }; - - for (const auto& [name, value] : programmableStage.constants) { - overriddenConstants.insert(name); - - // This is already validated so `name` must exist - const auto& moduleConstant = entryPointMetadata.overrides.at(name); - - MTLDataType type; - OverrideScalar entry{}; - - switchType(moduleConstant.type, &type, &entry, value); - - [constantValues.Get() setConstantValue:&entry type:type atIndex:moduleConstant.id]; - } - - // Set shader initialized default values because MSL function_constant - // has no default value - for (const std::string& name : entryPointMetadata.initializedOverrides) { - if (overriddenConstants.count(name) != 0) { - // This constant already has overridden value - continue; - } - - // Must exist because it is validated - const auto& moduleConstant = entryPointMetadata.overrides.at(name); - ASSERT(moduleConstant.isInitialized); - MTLDataType type; - - switchType(moduleConstant.type, &type, nullptr); - - [constantValues.Get() setConstantValue:&moduleConstant.defaultValue - type:type - atIndex:moduleConstant.id]; - } - - DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout, - functionData, constantValues.Get(), sampleMask, - renderPipeline)); - } else { - UNREACHABLE(); - } - return {}; -} - MaybeError EncodeMetalRenderPass(Device* device, CommandRecordingContext* commandContext, MTLRenderPassDescriptor* mtlRenderPass, diff --git a/src/dawn/native/stream/Stream.cpp b/src/dawn/native/stream/Stream.cpp index beb3823da0..eba4f5386c 
100644 --- a/src/dawn/native/stream/Stream.cpp +++ b/src/dawn/native/stream/Stream.cpp @@ -16,6 +16,8 @@ #include +#include "dawn/native/Limits.h" + namespace dawn::native::stream { template <> diff --git a/src/dawn/native/vulkan/ComputePipelineVk.cpp b/src/dawn/native/vulkan/ComputePipelineVk.cpp index ad6fbf8410..ea8c59a6bf 100644 --- a/src/dawn/native/vulkan/ComputePipelineVk.cpp +++ b/src/dawn/native/vulkan/ComputePipelineVk.cpp @@ -61,16 +61,12 @@ MaybeError ComputePipeline::Initialize() { ShaderModule::ModuleAndSpirv moduleAndSpirv; DAWN_TRY_ASSIGN(moduleAndSpirv, - module->GetHandleAndSpirv(computeStage.entryPoint.c_str(), layout)); + module->GetHandleAndSpirv(SingleShaderStage::Compute, computeStage, layout)); createInfo.stage.module = moduleAndSpirv.module; createInfo.stage.pName = computeStage.entryPoint.c_str(); - std::vector specializationDataEntries; - std::vector specializationMapEntries; - VkSpecializationInfo specializationInfo{}; - createInfo.stage.pSpecializationInfo = GetVkSpecializationInfo( - computeStage, &specializationInfo, &specializationDataEntries, &specializationMapEntries); + createInfo.stage.pSpecializationInfo = nullptr; PNextChainBuilder stageExtChain(&createInfo.stage); diff --git a/src/dawn/native/vulkan/RenderPipelineVk.cpp b/src/dawn/native/vulkan/RenderPipelineVk.cpp index 372abb515b..830e012437 100644 --- a/src/dawn/native/vulkan/RenderPipelineVk.cpp +++ b/src/dawn/native/vulkan/RenderPipelineVk.cpp @@ -341,9 +341,6 @@ MaybeError RenderPipeline::Initialize() { // There are at most 2 shader stages in render pipeline, i.e. vertex and fragment std::array shaderStages; - std::array, 2> specializationDataEntriesPerStages; - std::array, 2> specializationMapEntriesPerStages; - std::array specializationInfoPerStages; uint32_t stageCount = 0; for (auto stage : IterateStages(this->GetStageMask())) { @@ -354,7 +351,7 @@ MaybeError RenderPipeline::Initialize() { ShaderModule::ModuleAndSpirv moduleAndSpirv; DAWN_TRY_ASSIGN(moduleAndSpirv, - module->GetHandleAndSpirv(programmableStage.entryPoint.c_str(), layout)); + module->GetHandleAndSpirv(stage, programmableStage, layout)); shaderStage.module = moduleAndSpirv.module; shaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; @@ -379,11 +376,6 @@ MaybeError RenderPipeline::Initialize() { } } - shaderStage.pSpecializationInfo = - GetVkSpecializationInfo(programmableStage, &specializationInfoPerStages[stageCount], - &specializationDataEntriesPerStages[stageCount], - &specializationMapEntriesPerStages[stageCount]); - DAWN_ASSERT(stageCount < 2); shaderStages[stageCount] = shaderStage; stageCount++; diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp index 5477d84784..6588627d96 100644 --- a/src/dawn/native/vulkan/ShaderModuleVk.cpp +++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp @@ -60,13 +60,35 @@ class ShaderModule::Spirv : private Blob { namespace dawn::native::vulkan { +bool TransformedShaderModuleCacheKey::operator==( + const TransformedShaderModuleCacheKey& other) const { + if (layout != other.layout || entryPoint != other.entryPoint || + constants.size() != other.constants.size()) { + return false; + } + if (!std::equal(constants.begin(), constants.end(), other.constants.begin())) { + return false; + } + return true; +} + +size_t TransformedShaderModuleCacheKeyHashFunc::operator()( + const TransformedShaderModuleCacheKey& key) const { + size_t hash = 0; + HashCombine(&hash, key.layout, key.entryPoint); + for (const auto& entry : key.constants) { + 
HashCombine(&hash, entry.first, entry.second); + } + return hash; +} + class ShaderModule::ConcurrentTransformedShaderModuleCache { public: explicit ConcurrentTransformedShaderModuleCache(Device* device); ~ConcurrentTransformedShaderModuleCache(); - std::optional Find(const PipelineLayoutEntryPointPair& key); - ModuleAndSpirv AddOrGet(const PipelineLayoutEntryPointPair& key, + std::optional Find(const TransformedShaderModuleCacheKey& key); + ModuleAndSpirv AddOrGet(const TransformedShaderModuleCacheKey& key, VkShaderModule module, Spirv&& spirv); @@ -75,7 +97,9 @@ class ShaderModule::ConcurrentTransformedShaderModuleCache { Device* mDevice; std::mutex mMutex; - std::unordered_map + std::unordered_map mTransformedShaderModuleCache; }; @@ -92,7 +116,7 @@ ShaderModule::ConcurrentTransformedShaderModuleCache::~ConcurrentTransformedShad std::optional ShaderModule::ConcurrentTransformedShaderModuleCache::Find( - const PipelineLayoutEntryPointPair& key) { + const TransformedShaderModuleCacheKey& key) { std::lock_guard lock(mMutex); auto iter = mTransformedShaderModuleCache.find(key); if (iter != mTransformedShaderModuleCache.end()) { @@ -106,7 +130,7 @@ ShaderModule::ConcurrentTransformedShaderModuleCache::Find( } ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet( - const PipelineLayoutEntryPointPair& key, + const TransformedShaderModuleCacheKey& key, VkShaderModule module, Spirv&& spirv) { ASSERT(module != VK_NULL_HANDLE); @@ -168,20 +192,24 @@ void ShaderModule::DestroyImpl() { ShaderModule::~ShaderModule() = default; -#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \ - X(const tint::Program*, inputProgram) \ - X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \ - X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \ - X(std::string_view, entryPointName) \ - X(bool, disableWorkgroupInit) \ - X(bool, useZeroInitializeWorkgroupMemoryExtension) \ +#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \ + X(SingleShaderStage, stage) \ + X(const tint::Program*, inputProgram) \ + X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \ + X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \ + X(std::optional, substituteOverrideConfig) \ + X(LimitsForCompilationRequest, limits) \ + X(std::string_view, entryPointName) \ + X(bool, disableWorkgroupInit) \ + X(bool, useZeroInitializeWorkgroupMemoryExtension) \ X(CacheKey::UnsafeUnkeyedValue, tracePlatform) DAWN_MAKE_CACHE_REQUEST(SpirvCompilationRequest, SPIRV_COMPILATION_REQUEST_MEMBERS); #undef SPIRV_COMPILATION_REQUEST_MEMBERS ResultOrError ShaderModule::GetHandleAndSpirv( - const char* entryPointName, + SingleShaderStage stage, + const ProgrammableStage& programmableStage, const PipelineLayout* layout) { TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleVk::GetHandleAndSpirv"); @@ -191,7 +219,8 @@ ResultOrError ShaderModule::GetHandleAndSpirv( ScopedTintICEHandler scopedICEHandler(GetDevice()); // Check to see if we have the handle and spirv cached already. 
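+    // The cache key now also includes the pipeline-overridable constants, since different
+    // override values produce different transformed SPIR-V for the same layout and entry
+    // point.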
- auto cacheKey = std::make_pair(layout, entryPointName); + auto cacheKey = TransformedShaderModuleCacheKey{layout, programmableStage.entryPoint.c_str(), + programmableStage.constants}; auto handleAndSpirv = mTransformedShaderModuleCache->Find(cacheKey); if (handleAndSpirv.has_value()) { return std::move(*handleAndSpirv); @@ -204,7 +233,8 @@ ResultOrError ShaderModule::GetHandleAndSpirv( using BindingPoint = tint::transform::BindingPoint; BindingRemapper::BindingPoints bindingPoints; - const BindingInfoArray& moduleBindingInfo = GetEntryPoint(entryPointName).bindings; + const BindingInfoArray& moduleBindingInfo = + GetEntryPoint(programmableStage.entryPoint.c_str()).bindings; for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group)); @@ -238,16 +268,26 @@ ResultOrError ShaderModule::GetHandleAndSpirv( } } + std::optional substituteOverrideConfig; + if (!programmableStage.metadata->overrides.empty()) { + substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage); + } + #if TINT_BUILD_SPV_WRITER SpirvCompilationRequest req = {}; + req.stage = stage; req.inputProgram = GetTintProgram(); req.bindingPoints = std::move(bindingPoints); req.newBindingsMap = std::move(newBindingsMap); - req.entryPointName = entryPointName; + req.entryPointName = programmableStage.entryPoint; req.disableWorkgroupInit = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit); req.useZeroInitializeWorkgroupMemoryExtension = GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension); req.tracePlatform = UnsafeUnkeyedValue(GetDevice()->GetPlatform()); + req.substituteOverrideConfig = std::move(substituteOverrideConfig); + + const CombinedLimits& limits = GetDevice()->GetLimits(); + req.limits = LimitsForCompilationRequest::Create(limits.v1); CacheResult spirv; DAWN_TRY_LOAD_OR_RUN( @@ -270,12 +310,27 @@ ResultOrError ShaderModule::GetHandleAndSpirv( transformInputs.Add( r.newBindingsMap); } + if (r.substituteOverrideConfig) { + // This needs to run after SingleEntryPoint transform to get rid of overrides not + // used for the current entry point. + transformManager.Add(); + transformInputs.Add( + std::move(r.substituteOverrideConfig).value()); + } tint::Program program; { TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "RunTransforms"); DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, r.inputProgram, transformInputs, nullptr, nullptr)); } + + if (r.stage == SingleShaderStage::Compute) { + // Validate workgroup size after program runs transforms. 
+ Extent3D _; + DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize( + program, r.entryPointName.data(), r.limits)); + } + tint::writer::spirv::Options options; options.emit_vertex_point_size = true; options.disable_workgroup_init = r.disableWorkgroupInit; diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h index c0741eff6e..7c6a5dff46 100644 --- a/src/dawn/native/vulkan/ShaderModuleVk.h +++ b/src/dawn/native/vulkan/ShaderModuleVk.h @@ -18,14 +18,32 @@ #include #include #include +#include #include #include +#include "dawn/common/HashUtils.h" #include "dawn/common/vulkan_platform.h" #include "dawn/native/Error.h" #include "dawn/native/ShaderModule.h" -namespace dawn::native::vulkan { +namespace dawn::native { + +struct ProgrammableStage; + +namespace vulkan { + +struct TransformedShaderModuleCacheKey { + const PipelineLayoutBase* layout; + std::string entryPoint; + PipelineConstantEntries constants; + + bool operator==(const TransformedShaderModuleCacheKey& other) const; +}; + +struct TransformedShaderModuleCacheKeyHashFunc { + size_t operator()(const TransformedShaderModuleCacheKey& key) const; +}; class Device; class PipelineLayout; @@ -44,7 +62,8 @@ class ShaderModule final : public ShaderModuleBase { ShaderModuleParseResult* parseResult, OwnedCompilationMessages* compilationMessages); - ResultOrError GetHandleAndSpirv(const char* entryPointName, + ResultOrError GetHandleAndSpirv(SingleShaderStage stage, + const ProgrammableStage& programmableStage, const PipelineLayout* layout); private: @@ -59,6 +78,8 @@ class ShaderModule final : public ShaderModuleBase { std::unique_ptr mTransformedShaderModuleCache; }; -} // namespace dawn::native::vulkan +} // namespace vulkan + +} // namespace dawn::native #endif // SRC_DAWN_NATIVE_VULKAN_SHADERMODULEVK_H_ diff --git a/src/dawn/native/vulkan/UtilsVulkan.cpp b/src/dawn/native/vulkan/UtilsVulkan.cpp index ca8b4052ca..aa02ec35e5 100644 --- a/src/dawn/native/vulkan/UtilsVulkan.cpp +++ b/src/dawn/native/vulkan/UtilsVulkan.cpp @@ -258,60 +258,4 @@ std::string GetDeviceDebugPrefixFromDebugName(const char* debugName) { return std::string(debugName, length); } -VkSpecializationInfo* GetVkSpecializationInfo( - const ProgrammableStage& programmableStage, - VkSpecializationInfo* specializationInfo, - std::vector* specializationDataEntries, - std::vector* specializationMapEntries) { - ASSERT(specializationInfo); - ASSERT(specializationDataEntries); - ASSERT(specializationMapEntries); - - if (programmableStage.constants.size() == 0) { - return nullptr; - } - - const EntryPointMetadata& entryPointMetaData = - programmableStage.module->GetEntryPoint(programmableStage.entryPoint); - - for (const auto& pipelineConstant : programmableStage.constants) { - const std::string& identifier = pipelineConstant.first; - double value = pipelineConstant.second; - - // This is already validated so `identifier` must exist - const auto& moduleConstant = entryPointMetaData.overrides.at(identifier); - - specializationMapEntries->push_back(VkSpecializationMapEntry{ - moduleConstant.id, - static_cast(specializationDataEntries->size() * sizeof(OverrideScalar)), - sizeof(OverrideScalar)}); - - OverrideScalar entry{}; - switch (moduleConstant.type) { - case EntryPointMetadata::Override::Type::Boolean: - entry.b = static_cast(value); - break; - case EntryPointMetadata::Override::Type::Float32: - entry.f32 = static_cast(value); - break; - case EntryPointMetadata::Override::Type::Int32: - entry.i32 = static_cast(value); - break; - case 
EntryPointMetadata::Override::Type::Uint32: - entry.u32 = static_cast(value); - break; - default: - UNREACHABLE(); - } - specializationDataEntries->push_back(entry); - } - - specializationInfo->mapEntryCount = static_cast(specializationMapEntries->size()); - specializationInfo->pMapEntries = specializationMapEntries->data(); - specializationInfo->dataSize = specializationDataEntries->size() * sizeof(OverrideScalar); - specializationInfo->pData = specializationDataEntries->data(); - - return specializationInfo; -} - } // namespace dawn::native::vulkan diff --git a/src/dawn/native/vulkan/UtilsVulkan.h b/src/dawn/native/vulkan/UtilsVulkan.h index 3e2748e173..5ccdf0adf6 100644 --- a/src/dawn/native/vulkan/UtilsVulkan.h +++ b/src/dawn/native/vulkan/UtilsVulkan.h @@ -144,15 +144,6 @@ void SetDebugName(Device* device, std::string GetNextDeviceDebugPrefix(); std::string GetDeviceDebugPrefixFromDebugName(const char* debugName); -// Returns nullptr or &specializationInfo -// specializationInfo, specializationDataEntries, specializationMapEntries needs to -// be alive at least until VkSpecializationInfo is passed into Vulkan Create*Pipelines -VkSpecializationInfo* GetVkSpecializationInfo( - const ProgrammableStage& programmableStage, - VkSpecializationInfo* specializationInfo, - std::vector* specializationDataEntries, - std::vector* specializationMapEntries); - } // namespace dawn::native::vulkan #endif // SRC_DAWN_NATIVE_VULKAN_UTILSVULKAN_H_ diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn index eb1ed6c186..2955a8b94a 100644 --- a/src/dawn/tests/BUILD.gn +++ b/src/dawn/tests/BUILD.gn @@ -265,6 +265,7 @@ dawn_test("dawn_unittests") { "unittests/native/CreatePipelineAsyncTaskTests.cpp", "unittests/native/DestroyObjectTests.cpp", "unittests/native/DeviceCreationTests.cpp", + "unittests/native/ObjectContentHasherTests.cpp", "unittests/native/StreamTests.cpp", "unittests/validation/BindGroupValidationTests.cpp", "unittests/validation/BufferValidationTests.cpp", @@ -489,6 +490,7 @@ source_set("end2end_tests_sources") { "end2end/ScissorTests.cpp", "end2end/ShaderFloat16Tests.cpp", "end2end/ShaderTests.cpp", + "end2end/ShaderValidationTests.cpp", "end2end/StorageTextureTests.cpp", "end2end/SubresourceRenderAttachmentTests.cpp", "end2end/Texture3DTests.cpp", diff --git a/src/dawn/tests/end2end/ObjectCachingTests.cpp b/src/dawn/tests/end2end/ObjectCachingTests.cpp index cda1fca1ea..aeebd30c3e 100644 --- a/src/dawn/tests/end2end/ObjectCachingTests.cpp +++ b/src/dawn/tests/end2end/ObjectCachingTests.cpp @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "dawn/tests/DawnTest.h" #include "dawn/utils/ComboRenderPipelineDescriptor.h" @@ -158,6 +160,46 @@ TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnShaderModule) { EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire()); } +// Test that ComputePipeline are correctly deduplicated wrt. 
their constants override values +TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnOverrides) { + wgpu::ShaderModule module = utils::CreateShaderModule(device, R"( + override x: u32 = 1u; + var i : u32; + @compute @workgroup_size(x) fn main() { + i = 0u; + })"); + + wgpu::PipelineLayout layout = utils::MakeBasicPipelineLayout(device, nullptr); + + wgpu::ComputePipelineDescriptor desc; + desc.compute.entryPoint = "main"; + desc.layout = layout; + desc.compute.module = module; + + std::vector constants{{nullptr, "x", 16}}; + desc.compute.constantCount = constants.size(); + desc.compute.constants = constants.data(); + wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&desc); + + std::vector sameConstants{{nullptr, "x", 16}}; + desc.compute.constantCount = sameConstants.size(); + desc.compute.constants = sameConstants.data(); + wgpu::ComputePipeline samePipeline = device.CreateComputePipeline(&desc); + + desc.compute.constantCount = 0; + desc.compute.constants = nullptr; + wgpu::ComputePipeline otherPipeline1 = device.CreateComputePipeline(&desc); + + std::vector otherConstants{{nullptr, "x", 4}}; + desc.compute.constantCount = otherConstants.size(); + desc.compute.constants = otherConstants.data(); + wgpu::ComputePipeline otherPipeline2 = device.CreateComputePipeline(&desc); + + EXPECT_NE(pipeline.Get(), otherPipeline1.Get()); + EXPECT_NE(pipeline.Get(), otherPipeline2.Get()); + EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire()); +} + // Test that ComputePipeline are correctly deduplicated wrt. their layout TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) { wgpu::BindGroupLayout bgl = utils::MakeBindGroupLayout( @@ -303,6 +345,48 @@ TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnFragmentModule) { EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire()); } +// Test that Renderpipelines are correctly deduplicated wrt. 
their constants override values +TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnOverrides) { + wgpu::ShaderModule module = utils::CreateShaderModule(device, R"( + override a: f32 = 1.0; + @vertex fn vertexMain() -> @builtin(position) vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); + } + @fragment fn fragmentMain() -> @location(0) vec4 { + return vec4(0.0, 0.0, 0.0, a); + })"); + + utils::ComboRenderPipelineDescriptor desc; + desc.vertex.module = module; + desc.vertex.entryPoint = "vertexMain"; + desc.cFragment.module = module; + desc.cFragment.entryPoint = "fragmentMain"; + desc.cTargets[0].writeMask = wgpu::ColorWriteMask::None; + + std::vector constants{{nullptr, "a", 0.5}}; + desc.cFragment.constantCount = constants.size(); + desc.cFragment.constants = constants.data(); + wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&desc); + + std::vector sameConstants{{nullptr, "a", 0.5}}; + desc.cFragment.constantCount = sameConstants.size(); + desc.cFragment.constants = sameConstants.data(); + wgpu::RenderPipeline samePipeline = device.CreateRenderPipeline(&desc); + + std::vector otherConstants{{nullptr, "a", 1.0}}; + desc.cFragment.constantCount = otherConstants.size(); + desc.cFragment.constants = otherConstants.data(); + wgpu::RenderPipeline otherPipeline1 = device.CreateRenderPipeline(&desc); + + desc.cFragment.constantCount = 0; + desc.cFragment.constants = nullptr; + wgpu::RenderPipeline otherPipeline2 = device.CreateRenderPipeline(&desc); + + EXPECT_NE(pipeline.Get(), otherPipeline1.Get()); + EXPECT_NE(pipeline.Get(), otherPipeline2.Get()); + EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire()); +} + // Test that Samplers are correctly deduplicated. TEST_P(ObjectCachingTest, SamplerDeduplication) { wgpu::SamplerDescriptor samplerDesc; diff --git a/src/dawn/tests/end2end/ShaderTests.cpp b/src/dawn/tests/end2end/ShaderTests.cpp index 82159953ae..45ba077516 100644 --- a/src/dawn/tests/end2end/ShaderTests.cpp +++ b/src/dawn/tests/end2end/ShaderTests.cpp @@ -471,6 +471,127 @@ struct Buf { EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), buffer, 0, kCount); } +// Test one shader shared by two pipelines with different constants overridden +TEST_P(ShaderTests, OverridableConstantsSharedShader) { + DAWN_TEST_UNSUPPORTED_IF(IsOpenGL()); + DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES()); + + std::vector expected1{1}; + wgpu::Buffer buffer1 = CreateBuffer(expected1.size()); + std::vector expected2{2}; + wgpu::Buffer buffer2 = CreateBuffer(expected2.size()); + + std::string shader = R"( +override a: u32; + +struct Buf { + data : array +} + +@group(0) @binding(0) var buf : Buf; + +@compute @workgroup_size(1) fn main() { + buf.data[0] = a; +})"; + + std::vector constants1; + constants1.push_back({nullptr, "a", 1}); + std::vector constants2; + constants2.push_back({nullptr, "a", 2}); + + wgpu::ComputePipeline pipeline1 = CreateComputePipeline(shader, "main", &constants1); + wgpu::ComputePipeline pipeline2 = CreateComputePipeline(shader, "main", &constants2); + + wgpu::BindGroup bindGroup1 = + utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}}); + wgpu::BindGroup bindGroup2 = + utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}}); + + wgpu::CommandBuffer commands; + { + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(pipeline1); + pass.SetBindGroup(0, bindGroup1); + pass.DispatchWorkgroups(1); + pass.SetPipeline(pipeline2); + pass.SetBindGroup(0, 
bindGroup2); + pass.DispatchWorkgroups(1); + pass.End(); + + commands = encoder.Finish(); + } + + queue.Submit(1, &commands); + + EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, expected1.size()); + EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, expected2.size()); +} + +// Test overridable constants work with workgroup size +TEST_P(ShaderTests, OverridableConstantsWorkgroupSize) { + DAWN_TEST_UNSUPPORTED_IF(IsOpenGL()); + DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES()); + + std::string shader = R"( +override x: u32; + +struct Buf { + data : array +} + +@group(0) @binding(0) var buf : Buf; + +@compute @workgroup_size(x) fn main( + @builtin(local_invocation_id) local_invocation_id : vec3 +) { + if (local_invocation_id.x >= x - 1) { + buf.data[0] = local_invocation_id.x + 1; + } +})"; + + const uint32_t workgroup_size_x_1 = 16u; + const uint32_t workgroup_size_x_2 = 64u; + + std::vector expected1{workgroup_size_x_1}; + wgpu::Buffer buffer1 = CreateBuffer(expected1.size()); + std::vector expected2{workgroup_size_x_2}; + wgpu::Buffer buffer2 = CreateBuffer(expected2.size()); + + std::vector constants1; + constants1.push_back({nullptr, "x", static_cast(workgroup_size_x_1)}); + std::vector constants2; + constants2.push_back({nullptr, "x", static_cast(workgroup_size_x_2)}); + + wgpu::ComputePipeline pipeline1 = CreateComputePipeline(shader, "main", &constants1); + wgpu::ComputePipeline pipeline2 = CreateComputePipeline(shader, "main", &constants2); + + wgpu::BindGroup bindGroup1 = + utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}}); + wgpu::BindGroup bindGroup2 = + utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}}); + + wgpu::CommandBuffer commands; + { + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(pipeline1); + pass.SetBindGroup(0, bindGroup1); + pass.DispatchWorkgroups(1); + pass.SetPipeline(pipeline2); + pass.SetBindGroup(0, bindGroup2); + pass.DispatchWorkgroups(1); + pass.End(); + + commands = encoder.Finish(); + } + + queue.Submit(1, &commands); + + EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, expected1.size()); + EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, expected2.size()); +} + // Test overridable constants with numeric identifiers TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) { DAWN_TEST_UNSUPPORTED_IF(IsOpenGL()); @@ -596,6 +717,7 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) { std::string shader = R"( @id(1001) override c1: u32; @id(1002) override c2: u32; +@id(1003) override c3: u32; struct Buf { data : array @@ -611,7 +733,7 @@ struct Buf { buf.data[0] = c2; } -@compute @workgroup_size(1) fn main3() { +@compute @workgroup_size(c3) fn main3() { buf.data[0] = 3u; } )"; @@ -620,6 +742,8 @@ struct Buf { constants1.push_back({nullptr, "1001", 1}); std::vector constants2; constants2.push_back({nullptr, "1002", 2}); + std::vector constants3; + constants3.push_back({nullptr, "1003", 1}); wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, shader.c_str()); @@ -640,6 +764,8 @@ struct Buf { wgpu::ComputePipelineDescriptor csDesc3; csDesc3.compute.module = shaderModule; csDesc3.compute.entryPoint = "main3"; + csDesc3.compute.constants = constants3.data(); + csDesc3.compute.constantCount = constants3.size(); wgpu::ComputePipeline pipeline3 = device.CreateComputePipeline(&csDesc3); wgpu::BindGroup bindGroup1 = @@ -765,8 +891,6 @@ TEST_P(ShaderTests, 
ConflictingBindingsDueToTransformOrder) { device.CreateRenderPipeline(&desc); } -// TODO(tint:1155): Test overridable constants used for workgroup size - DAWN_INSTANTIATE_TEST(ShaderTests, D3D12Backend(), MetalBackend(), diff --git a/src/dawn/tests/end2end/ShaderValidationTests.cpp b/src/dawn/tests/end2end/ShaderValidationTests.cpp new file mode 100644 index 0000000000..c8b03f31db --- /dev/null +++ b/src/dawn/tests/end2end/ShaderValidationTests.cpp @@ -0,0 +1,395 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "dawn/tests/DawnTest.h" +#include "dawn/utils/ComboRenderPipelineDescriptor.h" +#include "dawn/utils/WGPUHelpers.h" + +// The compute shader workgroup size is settled at compute pipeline creation time. +// The validation code in dawn is in each backend (not including Null backend) thus this test needs +// to be as part of a dawn_end2end_tests instead of the dawn_unittests +// TODO(dawn:1504): Add support for GL backend. +class WorkgroupSizeValidationTest : public DawnTest { + public: + wgpu::ShaderModule SetUpShaderWithValidDefaultValueConstants() { + return utils::CreateShaderModule(device, R"( +override x: u32 = 1u; +override y: u32 = 1u; +override z: u32 = 1u; + +@compute @workgroup_size(x, y, z) fn main() { + _ = 0u; +})"); + } + + wgpu::ShaderModule SetUpShaderWithZeroDefaultValueConstants() { + return utils::CreateShaderModule(device, R"( +override x: u32 = 0u; +override y: u32 = 0u; +override z: u32 = 0u; + +@compute @workgroup_size(x, y, z) fn main() { + _ = 0u; +})"); + } + + wgpu::ShaderModule SetUpShaderWithOutOfLimitsDefaultValueConstants() { + return utils::CreateShaderModule(device, R"( +override x: u32 = 1u; +override y: u32 = 1u; +override z: u32 = 9999u; + +@compute @workgroup_size(x, y, z) fn main() { + _ = 0u; +})"); + } + + wgpu::ShaderModule SetUpShaderWithUninitializedConstants() { + return utils::CreateShaderModule(device, R"( +override x: u32; +override y: u32; +override z: u32; + +@compute @workgroup_size(x, y, z) fn main() { + _ = 0u; +})"); + } + + wgpu::ShaderModule SetUpShaderWithPartialConstants() { + return utils::CreateShaderModule(device, R"( +override x: u32; + +@compute @workgroup_size(x, 1, 1) fn main() { + _ = 0u; +})"); + } + + void TestCreatePipeline(const wgpu::ShaderModule& module) { + wgpu::ComputePipelineDescriptor csDesc; + csDesc.compute.module = module; + csDesc.compute.entryPoint = "main"; + csDesc.compute.constants = nullptr; + csDesc.compute.constantCount = 0; + wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc); + } + + void TestCreatePipeline(const wgpu::ShaderModule& module, + const std::vector& constants) { + wgpu::ComputePipelineDescriptor csDesc; + csDesc.compute.module = module; + csDesc.compute.entryPoint = "main"; + csDesc.compute.constants = constants.data(); + csDesc.compute.constantCount = constants.size(); + wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc); + } + + void 
TestInitializedWithZero(const wgpu::ShaderModule& module) { + std::vector constants{ + {nullptr, "x", 0}, {nullptr, "y", 0}, {nullptr, "z", 0}}; + TestCreatePipeline(module, constants); + } + + void TestInitializedWithOutOfLimitValue(const wgpu::ShaderModule& module) { + std::vector constants{ + {nullptr, "x", + static_cast(GetSupportedLimits().limits.maxComputeWorkgroupSizeX + 1)}, + {nullptr, "y", 1}, + {nullptr, "z", 1}}; + TestCreatePipeline(module, constants); + } + + void TestInitializedWithValidValue(const wgpu::ShaderModule& module) { + std::vector constants{ + {nullptr, "x", 4}, {nullptr, "y", 4}, {nullptr, "z", 4}}; + TestCreatePipeline(module, constants); + } + + void TestInitializedPartially(const wgpu::ShaderModule& module) { + std::vector constants{{nullptr, "y", 4}}; + TestCreatePipeline(module, constants); + } + + wgpu::Buffer buffer; +}; + +// Test workgroup size validation with fixed values. +TEST_P(WorkgroupSizeValidationTest, WithFixedValues) { + auto CheckShaderWithWorkgroupSize = [this](bool success, uint32_t x, uint32_t y, uint32_t z) { + std::ostringstream ss; + ss << "@compute @workgroup_size(" << x << "," << y << "," << z << ") fn main() {}"; + + wgpu::ComputePipelineDescriptor desc; + desc.compute.entryPoint = "main"; + desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str()); + + if (success) { + device.CreateComputePipeline(&desc); + } else { + ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); + } + }; + + wgpu::Limits supportedLimits = GetSupportedLimits().limits; + + CheckShaderWithWorkgroupSize(true, 1, 1, 1); + CheckShaderWithWorkgroupSize(true, supportedLimits.maxComputeWorkgroupSizeX, 1, 1); + CheckShaderWithWorkgroupSize(true, 1, supportedLimits.maxComputeWorkgroupSizeY, 1); + CheckShaderWithWorkgroupSize(true, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ); + + CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX + 1, 1, 1); + CheckShaderWithWorkgroupSize(false, 1, supportedLimits.maxComputeWorkgroupSizeY + 1, 1); + CheckShaderWithWorkgroupSize(false, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ + 1); + + // No individual dimension exceeds its limit, but the combined size should definitely exceed the + // total invocation limit. + DAWN_ASSERT(supportedLimits.maxComputeWorkgroupSizeX * + supportedLimits.maxComputeWorkgroupSizeY * + supportedLimits.maxComputeWorkgroupSizeZ > + supportedLimits.maxComputeInvocationsPerWorkgroup); + CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX, + supportedLimits.maxComputeWorkgroupSizeY, + supportedLimits.maxComputeWorkgroupSizeZ); +} + +// Test workgroup size validation with fixed values (storage size limits validation). 
+TEST_P(WorkgroupSizeValidationTest, WithFixedValuesStorageSizeLimits) { + wgpu::Limits supportedLimits = GetSupportedLimits().limits; + + constexpr uint32_t kVec4Size = 16; + const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size; + constexpr uint32_t kMat4Size = 64; + const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size; + + auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count, + uint32_t mat4_count) { + std::ostringstream ss; + std::ostringstream body; + if (vec4_count > 0) { + ss << "var vec4_data: array, " << vec4_count << ">;"; + body << "_ = vec4_data;"; + } + if (mat4_count > 0) { + ss << "var mat4_data: array, " << mat4_count << ">;"; + body << "_ = mat4_data;"; + } + ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }"; + + wgpu::ComputePipelineDescriptor desc; + desc.compute.entryPoint = "main"; + desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str()); + + if (success) { + device.CreateComputePipeline(&desc); + } else { + ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); + } + }; + + CheckPipelineWithWorkgroupStorage(true, 1, 1); + CheckPipelineWithWorkgroupStorage(true, maxVec4Count, 0); + CheckPipelineWithWorkgroupStorage(true, 0, maxMat4Count); + CheckPipelineWithWorkgroupStorage(true, maxVec4Count - 4, 1); + CheckPipelineWithWorkgroupStorage(true, 4, maxMat4Count - 1); + + CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0); + CheckPipelineWithWorkgroupStorage(false, maxVec4Count - 3, 1); + CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1); + CheckPipelineWithWorkgroupStorage(false, 4, maxMat4Count); +} + +// Test workgroup size validation with valid overrides default values. +TEST_P(WorkgroupSizeValidationTest, OverridesWithValidDefault) { + wgpu::ShaderModule module = SetUpShaderWithValidDefaultValueConstants(); + { + // Valid default + TestCreatePipeline(module); + } + { + // Error: invalid value (zero) + ASSERT_DEVICE_ERROR(TestInitializedWithZero(module)); + } + { + // Error: invalid value (out of device limits) + ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module)); + } + { + // Valid: initialized partially + TestInitializedPartially(module); + } + { + // Valid + TestInitializedWithValidValue(module); + } +} + +// Test workgroup size validation with zero as the overrides default values. +TEST_P(WorkgroupSizeValidationTest, OverridesWithZeroDefault) { + // Error: zero is detected as invalid at shader creation time + ASSERT_DEVICE_ERROR(SetUpShaderWithZeroDefaultValueConstants()); +} + +// Test workgroup size validation with out-of-limits overrides default values. +TEST_P(WorkgroupSizeValidationTest, OverridesWithOutOfLimitsDefault) { + wgpu::ShaderModule module = SetUpShaderWithOutOfLimitsDefaultValueConstants(); + { + // Error: invalid default + ASSERT_DEVICE_ERROR(TestCreatePipeline(module)); + } + { + // Error: invalid value (zero) + ASSERT_DEVICE_ERROR(TestInitializedWithZero(module)); + } + { + // Error: invalid value (out of device limits) + ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module)); + } + { + // Error: initialized partially + ASSERT_DEVICE_ERROR(TestInitializedPartially(module)); + } + { + // Valid + TestInitializedWithValidValue(module); + } +} + +// Test workgroup size validation without overrides default values specified. 
+TEST_P(WorkgroupSizeValidationTest, OverridesWithUninitialized) {
+    wgpu::ShaderModule module = SetUpShaderWithUninitializedConstants();
+    {
+        // Error: uninitialized
+        ASSERT_DEVICE_ERROR(TestCreatePipeline(module));
+    }
+    {
+        // Error: invalid value (zero)
+        ASSERT_DEVICE_ERROR(TestInitializedWithZero(module));
+    }
+    {
+        // Error: invalid value (out of device limits)
+        ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module));
+    }
+    {
+        // Error: initialized partially
+        ASSERT_DEVICE_ERROR(TestInitializedPartially(module));
+    }
+    {
+        // Valid
+        TestInitializedWithValidValue(module);
+    }
+}
+
+// Test workgroup size validation when only some of the dimensions are overridable constants.
+TEST_P(WorkgroupSizeValidationTest, PartialOverrides) {
+    wgpu::ShaderModule module = SetUpShaderWithPartialConstants();
+    {
+        // Error: uninitialized
+        ASSERT_DEVICE_ERROR(TestCreatePipeline(module));
+    }
+    {
+        // Error: invalid value (zero)
+        std::vector<wgpu::ConstantEntry> constants{{nullptr, "x", 0}};
+        ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
+    }
+    {
+        // Error: invalid value (out of device limits)
+        std::vector<wgpu::ConstantEntry> constants{
+            {nullptr, "x",
+             static_cast<double>(GetSupportedLimits().limits.maxComputeWorkgroupSizeX + 1)}};
+        ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
+    }
+    {
+        // Valid
+        std::vector<wgpu::ConstantEntry> constants{{nullptr, "x", 16}};
+        TestCreatePipeline(module, constants);
+    }
+}
+
+// Test workgroup size validation after the size is overridden with invalid values.
+TEST_P(WorkgroupSizeValidationTest, ValidationAfterOverride) {
+    wgpu::ShaderModule module = SetUpShaderWithUninitializedConstants();
+    wgpu::Limits supportedLimits = GetSupportedLimits().limits;
+    {
+        // Error: exceed maxComputeWorkgroupSizeZ
+        std::vector<wgpu::ConstantEntry> constants{
+            {nullptr, "x", 1},
+            {nullptr, "y", 1},
+            {nullptr, "z", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeZ + 1)}};
+        ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
+    }
+    {
+        // Error: exceed maxComputeInvocationsPerWorkgroup
+        DAWN_ASSERT(supportedLimits.maxComputeWorkgroupSizeX *
+                        supportedLimits.maxComputeWorkgroupSizeY *
+                        supportedLimits.maxComputeWorkgroupSizeZ >
+                    supportedLimits.maxComputeInvocationsPerWorkgroup);
+        std::vector<wgpu::ConstantEntry> constants{
+            {nullptr, "x", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeX)},
+            {nullptr, "y", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeY)},
+            {nullptr, "z", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeZ)}};
+        ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
+    }
+}
+
+// Test workgroup size validation after the size is overridden with invalid values (storage size
+// limits validation).
+// TODO(tint:1660): re-enable after override can be used as array size.
+TEST_P(WorkgroupSizeValidationTest, DISABLED_ValidationAfterOverrideStorageSize) { + wgpu::Limits supportedLimits = GetSupportedLimits().limits; + + constexpr uint32_t kVec4Size = 16; + const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size; + constexpr uint32_t kMat4Size = 64; + const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size; + + auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count, + uint32_t mat4_count) { + std::ostringstream ss; + std::ostringstream body; + ss << "override a: u32;"; + ss << "override b: u32;"; + if (vec4_count > 0) { + ss << "var vec4_data: array, a>;"; + body << "_ = vec4_data;"; + } + if (mat4_count > 0) { + ss << "var mat4_data: array, b>;"; + body << "_ = mat4_data;"; + } + ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }"; + + wgpu::ComputePipelineDescriptor desc; + desc.compute.entryPoint = "main"; + desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str()); + + std::vector constants{{nullptr, "a", static_cast(vec4_count)}, + {nullptr, "b", static_cast(mat4_count)}}; + desc.compute.constants = constants.data(); + desc.compute.constantCount = constants.size(); + + if (success) { + device.CreateComputePipeline(&desc); + } else { + ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); + } + }; + + CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0); + CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1); +} + +DAWN_INSTANTIATE_TEST(WorkgroupSizeValidationTest, D3D12Backend(), MetalBackend(), VulkanBackend()); diff --git a/src/dawn/tests/unittests/native/ObjectContentHasherTests.cpp b/src/dawn/tests/unittests/native/ObjectContentHasherTests.cpp new file mode 100644 index 0000000000..bf926fe304 --- /dev/null +++ b/src/dawn/tests/unittests/native/ObjectContentHasherTests.cpp @@ -0,0 +1,81 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
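The unit tests that follow exercise ObjectContentHasher on std::pair, std::vector and std::map values; the map case matters because a programmable stage's override constants are carried as a name-to-double map that now has to participate in content hashing for pipeline deduplication. A minimal usage sketch, assuming only the Record/GetContentHash API and the std::map support exercised in these tests:

    dawn::native::ObjectContentHasher hasher;
    // e.g. the override constants of a compute stage
    hasher.Record(std::map<std::string, double>{{"x", 16.0}, {"y", 1.0}});
    size_t contentHash = hasher.GetContentHash();

Order is significant: as the HashCombine test below demonstrates, recording the same values in a different order yields a different hash.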
+ +#include +#include +#include +#include + +#include "dawn/native/ObjectContentHasher.h" +#include "dawn/tests/DawnNativeTest.h" + +namespace dawn::native { + +namespace { + +class ObjectContentHasherTests : public DawnNativeTest {}; + +#define EXPECT_IF_HASH_EQ(eq, a, b) \ + do { \ + ObjectContentHasher ra, rb; \ + ra.Record(a); \ + rb.Record(b); \ + EXPECT_EQ(eq, ra.GetContentHash() == rb.GetContentHash()); \ + } while (0) + +TEST(ObjectContentHasherTests, Pair) { + EXPECT_IF_HASH_EQ(true, (std::pair{"a", 1}), + (std::pair{"a", 1})); + EXPECT_IF_HASH_EQ(false, (std::pair{1, "a"}), + (std::pair{"a", 1})); + EXPECT_IF_HASH_EQ(false, (std::pair{"a", 1}), + (std::pair{"a", 2})); + EXPECT_IF_HASH_EQ(false, (std::pair{"a", 1}), + (std::pair{"b", 1})); +} + +TEST(ObjectContentHasherTests, Vector) { + EXPECT_IF_HASH_EQ(true, (std::vector{0, 1}), (std::vector{0, 1})); + EXPECT_IF_HASH_EQ(false, (std::vector{0, 1}), (std::vector{0, 1, 2})); + EXPECT_IF_HASH_EQ(false, (std::vector{0, 1}), (std::vector{1, 0})); + EXPECT_IF_HASH_EQ(false, (std::vector{0, 1}), (std::vector{})); + EXPECT_IF_HASH_EQ(false, (std::vector{0, 1}), (std::vector{0, 1})); +} + +TEST(ObjectContentHasherTests, Map) { + EXPECT_IF_HASH_EQ(true, (std::map{{"a", 1}, {"b", 2}}), + (std::map{{"b", 2}, {"a", 1}})); + EXPECT_IF_HASH_EQ(false, (std::map{{"a", 1}, {"b", 2}}), + (std::map{{"a", 2}, {"b", 1}})); + EXPECT_IF_HASH_EQ(false, (std::map{{"a", 1}, {"b", 2}}), + (std::map{{"a", 1}, {"b", 2}, {"c", 1}})); + EXPECT_IF_HASH_EQ(false, (std::map{{"a", 1}, {"b", 2}}), + (std::map{})); +} + +TEST(ObjectContentHasherTests, HashCombine) { + ObjectContentHasher ra, rb; + + ra.Record(std::vector{0, 1}); + ra.Record(std::map{{"a", 1}, {"b", 2}}); + + rb.Record(std::map{{"a", 1}, {"b", 2}}); + rb.Record(std::vector{0, 1}); + + EXPECT_NE(ra.GetContentHash(), rb.GetContentHash()); +} + +} // namespace + +} // namespace dawn::native diff --git a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp index 17e25f96ee..1520bb8275 100644 --- a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp +++ b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp @@ -450,87 +450,6 @@ TEST_F(ShaderModuleValidationTest, MaximumInterStageShaderComponents) { } } -// Tests that we validate workgroup size limits. 
-TEST_F(ShaderModuleValidationTest, ComputeWorkgroupSizeLimits) { - auto CheckShaderWithWorkgroupSize = [this](bool success, uint32_t x, uint32_t y, uint32_t z) { - std::ostringstream ss; - ss << "@compute @workgroup_size(" << x << "," << y << "," << z << ") fn main() {}"; - - wgpu::ComputePipelineDescriptor desc; - desc.compute.entryPoint = "main"; - desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str()); - - if (success) { - device.CreateComputePipeline(&desc); - } else { - ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); - } - }; - - wgpu::Limits supportedLimits = GetSupportedLimits().limits; - - CheckShaderWithWorkgroupSize(true, 1, 1, 1); - CheckShaderWithWorkgroupSize(true, supportedLimits.maxComputeWorkgroupSizeX, 1, 1); - CheckShaderWithWorkgroupSize(true, 1, supportedLimits.maxComputeWorkgroupSizeY, 1); - CheckShaderWithWorkgroupSize(true, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ); - - CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX + 1, 1, 1); - CheckShaderWithWorkgroupSize(false, 1, supportedLimits.maxComputeWorkgroupSizeY + 1, 1); - CheckShaderWithWorkgroupSize(false, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ + 1); - - // No individual dimension exceeds its limit, but the combined size should definitely exceed the - // total invocation limit. - CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX, - supportedLimits.maxComputeWorkgroupSizeY, - supportedLimits.maxComputeWorkgroupSizeZ); -} - -// Tests that we validate workgroup storage size limits. -TEST_F(ShaderModuleValidationTest, ComputeWorkgroupStorageSizeLimits) { - wgpu::Limits supportedLimits = GetSupportedLimits().limits; - - constexpr uint32_t kVec4Size = 16; - const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size; - constexpr uint32_t kMat4Size = 64; - const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size; - - auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count, - uint32_t mat4_count) { - std::ostringstream ss; - std::ostringstream body; - if (vec4_count > 0) { - ss << "var vec4_data: array, " << vec4_count << ">;"; - body << "_ = vec4_data;"; - } - if (mat4_count > 0) { - ss << "var mat4_data: array, " << mat4_count << ">;"; - body << "_ = mat4_data;"; - } - ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }"; - - wgpu::ComputePipelineDescriptor desc; - desc.compute.entryPoint = "main"; - desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str()); - - if (success) { - device.CreateComputePipeline(&desc); - } else { - ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); - } - }; - - CheckPipelineWithWorkgroupStorage(true, 1, 1); - CheckPipelineWithWorkgroupStorage(true, maxVec4Count, 0); - CheckPipelineWithWorkgroupStorage(true, 0, maxMat4Count); - CheckPipelineWithWorkgroupStorage(true, maxVec4Count - 4, 1); - CheckPipelineWithWorkgroupStorage(true, 4, maxMat4Count - 1); - - CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0); - CheckPipelineWithWorkgroupStorage(false, maxVec4Count - 3, 1); - CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1); - CheckPipelineWithWorkgroupStorage(false, 4, maxMat4Count); -} - // Test that numeric ID must be unique TEST_F(ShaderModuleValidationTest, OverridableConstantsNumericIDConflicts) { ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"( diff --git a/src/dawn/tests/unittests/validation/UnsafeAPIValidationTests.cpp 
b/src/dawn/tests/unittests/validation/UnsafeAPIValidationTests.cpp index 10d7a0049e..171e037bee 100644 --- a/src/dawn/tests/unittests/validation/UnsafeAPIValidationTests.cpp +++ b/src/dawn/tests/unittests/validation/UnsafeAPIValidationTests.cpp @@ -37,46 +37,6 @@ class UnsafeAPIValidationTest : public ValidationTest { } }; -// Check that pipeline overridable constants are disallowed as part of unsafe APIs. -// TODO(dawn:1041) Remove when implementation for all backend is added -TEST_F(UnsafeAPIValidationTest, PipelineOverridableConstants) { - // Create the placeholder compute pipeline. - wgpu::ComputePipelineDescriptor pipelineDescBase; - pipelineDescBase.compute.entryPoint = "main"; - - // Control case: shader without overridable constant is allowed. - { - wgpu::ComputePipelineDescriptor pipelineDesc = pipelineDescBase; - pipelineDesc.compute.module = - utils::CreateShaderModule(device, "@compute @workgroup_size(1) fn main() {}"); - - device.CreateComputePipeline(&pipelineDesc); - } - - // Error case: shader with overridable constant with default value - { - ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"( -@id(1000) override c0: u32 = 1u; -@id(1000) override c1: u32; - -@compute @workgroup_size(1) fn main() { - _ = c0; - _ = c1; -})")); - } - - // Error case: pipeline stage with constant entry is disallowed - { - wgpu::ComputePipelineDescriptor pipelineDesc = pipelineDescBase; - pipelineDesc.compute.module = - utils::CreateShaderModule(device, "@compute @workgroup_size(1) fn main() {}"); - std::vector constants{{nullptr, "c", 1u}}; - pipelineDesc.compute.constants = constants.data(); - pipelineDesc.compute.constantCount = constants.size(); - ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&pipelineDesc)); - } -} - class UnsafeQueryAPIValidationTest : public ValidationTest { protected: WGPUDevice CreateTestDevice(dawn::native::Adapter dawnAdapter) override { diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc index 2f8d09ac02..4dd3c75d50 100644 --- a/src/tint/inspector/inspector.cc +++ b/src/tint/inspector/inspector.cc @@ -133,6 +133,110 @@ Inspector::Inspector(const Program* program) : program_(program) {} Inspector::~Inspector() = default; +EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) { + EntryPoint entry_point; + TINT_ASSERT(Inspector, func != nullptr); + TINT_ASSERT(Inspector, func->IsEntryPoint()); + + auto* sem = program_->Sem().Get(func); + + entry_point.name = program_->Symbols().NameFor(func->symbol); + entry_point.remapped_name = program_->Symbols().NameFor(func->symbol); + + switch (func->PipelineStage()) { + case ast::PipelineStage::kCompute: { + entry_point.stage = PipelineStage::kCompute; + + auto wgsize = sem->WorkgroupSize(); + if (!wgsize[0].overridable_const && !wgsize[1].overridable_const && + !wgsize[2].overridable_const) { + entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value, wgsize[2].value}; + } + break; + } + case ast::PipelineStage::kFragment: { + entry_point.stage = PipelineStage::kFragment; + break; + } + case ast::PipelineStage::kVertex: { + entry_point.stage = PipelineStage::kVertex; + break; + } + default: { + TINT_UNREACHABLE(Inspector, diagnostics_) + << "invalid pipeline stage for entry point '" << entry_point.name << "'"; + break; + } + } + + for (auto* param : sem->Parameters()) { + AddEntryPointInOutVariables(program_->Symbols().NameFor(param->Declaration()->symbol), + param->Type(), param->Declaration()->attributes, + entry_point.input_variables); + + 
entry_point.input_position_used |= ContainsBuiltin( + ast::BuiltinValue::kPosition, param->Type(), param->Declaration()->attributes); + entry_point.front_facing_used |= ContainsBuiltin( + ast::BuiltinValue::kFrontFacing, param->Type(), param->Declaration()->attributes); + entry_point.sample_index_used |= ContainsBuiltin( + ast::BuiltinValue::kSampleIndex, param->Type(), param->Declaration()->attributes); + entry_point.input_sample_mask_used |= ContainsBuiltin( + ast::BuiltinValue::kSampleMask, param->Type(), param->Declaration()->attributes); + entry_point.num_workgroups_used |= ContainsBuiltin( + ast::BuiltinValue::kNumWorkgroups, param->Type(), param->Declaration()->attributes); + } + + if (!sem->ReturnType()->Is()) { + AddEntryPointInOutVariables("", sem->ReturnType(), func->return_type_attributes, + entry_point.output_variables); + + entry_point.output_sample_mask_used = ContainsBuiltin( + ast::BuiltinValue::kSampleMask, sem->ReturnType(), func->return_type_attributes); + } + + for (auto* var : sem->TransitivelyReferencedGlobals()) { + auto* decl = var->Declaration(); + + auto name = program_->Symbols().NameFor(decl->symbol); + + auto* global = var->As(); + if (global && global->Declaration()->Is()) { + Override override; + override.name = name; + override.id = global->OverrideId(); + auto* type = var->Type(); + TINT_ASSERT(Inspector, type->is_scalar()); + if (type->is_bool_scalar_or_vector()) { + override.type = Override::Type::kBool; + } else if (type->is_float_scalar()) { + override.type = Override::Type::kFloat32; + } else if (type->is_signed_integer_scalar()) { + override.type = Override::Type::kInt32; + } else if (type->is_unsigned_integer_scalar()) { + override.type = Override::Type::kUint32; + } else { + TINT_UNREACHABLE(Inspector, diagnostics_); + } + + override.is_initialized = global->Declaration()->constructor; + override.is_id_specified = + ast::HasAttribute(global->Declaration()->attributes); + + entry_point.overrides.push_back(override); + } + } + + return entry_point; +} + +EntryPoint Inspector::GetEntryPoint(const std::string& entry_point_name) { + auto* func = FindEntryPointByName(entry_point_name); + if (!func) { + return EntryPoint(); + } + return GetEntryPoint(func); +} + std::vector Inspector::GetEntryPoints() { std::vector result; @@ -141,97 +245,7 @@ std::vector Inspector::GetEntryPoints() { continue; } - auto* sem = program_->Sem().Get(func); - - EntryPoint entry_point; - entry_point.name = program_->Symbols().NameFor(func->symbol); - entry_point.remapped_name = program_->Symbols().NameFor(func->symbol); - - switch (func->PipelineStage()) { - case ast::PipelineStage::kCompute: { - entry_point.stage = PipelineStage::kCompute; - - auto wgsize = sem->WorkgroupSize(); - if (!wgsize[0].overridable_const && !wgsize[1].overridable_const && - !wgsize[2].overridable_const) { - entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value, - wgsize[2].value}; - } - break; - } - case ast::PipelineStage::kFragment: { - entry_point.stage = PipelineStage::kFragment; - break; - } - case ast::PipelineStage::kVertex: { - entry_point.stage = PipelineStage::kVertex; - break; - } - default: { - TINT_UNREACHABLE(Inspector, diagnostics_) - << "invalid pipeline stage for entry point '" << entry_point.name << "'"; - break; - } - } - - for (auto* param : sem->Parameters()) { - AddEntryPointInOutVariables(program_->Symbols().NameFor(param->Declaration()->symbol), - param->Type(), param->Declaration()->attributes, - entry_point.input_variables); - - entry_point.input_position_used |= 
ContainsBuiltin( - ast::BuiltinValue::kPosition, param->Type(), param->Declaration()->attributes); - entry_point.front_facing_used |= ContainsBuiltin( - ast::BuiltinValue::kFrontFacing, param->Type(), param->Declaration()->attributes); - entry_point.sample_index_used |= ContainsBuiltin( - ast::BuiltinValue::kSampleIndex, param->Type(), param->Declaration()->attributes); - entry_point.input_sample_mask_used |= ContainsBuiltin( - ast::BuiltinValue::kSampleMask, param->Type(), param->Declaration()->attributes); - entry_point.num_workgroups_used |= ContainsBuiltin( - ast::BuiltinValue::kNumWorkgroups, param->Type(), param->Declaration()->attributes); - } - - if (!sem->ReturnType()->Is()) { - AddEntryPointInOutVariables("", sem->ReturnType(), func->return_type_attributes, - entry_point.output_variables); - - entry_point.output_sample_mask_used = ContainsBuiltin( - ast::BuiltinValue::kSampleMask, sem->ReturnType(), func->return_type_attributes); - } - - for (auto* var : sem->TransitivelyReferencedGlobals()) { - auto* decl = var->Declaration(); - - auto name = program_->Symbols().NameFor(decl->symbol); - - auto* global = var->As(); - if (global && global->Declaration()->Is()) { - Override override; - override.name = name; - override.id = global->OverrideId(); - auto* type = var->Type(); - TINT_ASSERT(Inspector, type->is_scalar()); - if (type->is_bool_scalar_or_vector()) { - override.type = Override::Type::kBool; - } else if (type->is_float_scalar()) { - override.type = Override::Type::kFloat32; - } else if (type->is_signed_integer_scalar()) { - override.type = Override::Type::kInt32; - } else if (type->is_unsigned_integer_scalar()) { - override.type = Override::Type::kUint32; - } else { - TINT_UNREACHABLE(Inspector, diagnostics_); - } - - override.is_initialized = global->Declaration()->constructor; - override.is_id_specified = - ast::HasAttribute(global->Declaration()->attributes); - - entry_point.overrides.push_back(override); - } - } - - result.push_back(std::move(entry_point)); + result.push_back(GetEntryPoint(func)); } return result; diff --git a/src/tint/inspector/inspector.h b/src/tint/inspector/inspector.h index f3fe27008a..684852ce28 100644 --- a/src/tint/inspector/inspector.h +++ b/src/tint/inspector/inspector.h @@ -55,6 +55,10 @@ class Inspector { /// @returns vector of entry point information std::vector GetEntryPoints(); + /// @param entry_point name of the entry point to get information about + /// @returns the entry point information + EntryPoint GetEntryPoint(const std::string& entry_point); + /// @returns map of override identifier to initial value std::map GetOverrideDefaultValues(); @@ -230,6 +234,10 @@ class Inspector { /// whenever a set of expressions are resolved to globals. template void GetOriginatingResources(std::array exprs, F&& cb); + + /// @param func the function of the entry point. 
Must be non-nullptr and true for IsEntryPoint()
+    /// @returns the entry point information
+    EntryPoint GetEntryPoint(const tint::ast::Function* func);
+};
 }  // namespace tint::inspector
diff --git a/src/tint/transform/substitute_override.h b/src/tint/transform/substitute_override.h
index 9ea315d92d..940e11d2fa 100644
--- a/src/tint/transform/substitute_override.h
+++ b/src/tint/transform/substitute_override.h
@@ -20,6 +20,7 @@
 #include "tint/override_id.h"
+#include "src/tint/reflection.h"
 #include "src/tint/transform/transform.h"
 namespace tint::transform {
@@ -63,6 +64,9 @@ class SubstituteOverride final : public Castable<SubstituteOverride, Transform>
         /// The value is always a double coming into the transform and will be
         /// converted to the correct type through an initializer.
         std::unordered_map<OverrideId, double> map;
+
+        /// Reflect the fields of this class so that it can be used by tint::ForeachField()
+        TINT_REFLECT(map);
     };
     /// Constructor
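A hedged sketch of how this Config is fed from the API side may help tie the pieces together. Dawn's BuildSubstituteOverridesTransformConfig (referenced in the Vulkan backend changes above) conceptually walks the stage's constant entries, resolves each name or numeric id to the override's OverrideId, and stores the raw double; the exact metadata field shapes below are assumptions based on the removed specialization-constant code, not the actual helper:

    const EntryPointMetadata& metadata =
        programmableStage.module->GetEntryPoint(programmableStage.entryPoint);
    tint::transform::SubstituteOverride::Config cfg;
    for (const auto& [name, value] : programmableStage.constants) {
        // `value` stays a double; the transform narrows it to the override's
        // declared type when it rewrites the program.
        cfg.map[metadata.overrides.at(name).id] = value;
    }

The TINT_REFLECT(map) line added above is what lets this Config be visited field-by-field, presumably so it can be recorded as part of the SpirvCompilationRequest cache key defined earlier in the patch.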