Use SubstituteOverride transform to implement overrides

Remove the old backend specific implementation for
overrides. Use tint SubstituteOverride transform to replace
overrides with const expressions and use the updated program
at pipeline creation time.

This CL also adds support for overrides used as workgroup size
and related tests. Workgroup size validation now happens
in backend code and at compute pipeline creation time.

Bug: dawn:1504
Change-Id: I7df1fe9c3e358caa23235eacd6d13ba0b2998aec
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99821
Commit-Queue: Shrek Shao <shrekshao@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
shrekshao 2022-09-07 20:09:54 +00:00 committed by Dawn LUCI CQ
parent 23cf74c30e
commit 145337f309
36 changed files with 1246 additions and 698 deletions

View File

@ -25,7 +25,15 @@
namespace {{native_namespace}} { namespace {{native_namespace}} {
// //
// Cache key writers for wgpu structures used in caching. // Streaming readers for wgpu structures.
//
{% macro render_reader(member) %}
{%- set name = member.name.camelCase() -%}
DAWN_TRY(StreamOut(source, &t->{{name}}));
{% endmacro %}
//
// Streaming writers for wgpu structures.
// //
{% macro render_writer(member) %} {% macro render_writer(member) %}
{%- set name = member.name.camelCase() -%} {%- set name = member.name.camelCase() -%}
@ -38,31 +46,50 @@ namespace {{native_namespace}} {
{% endif %} {% endif %}
{% endmacro %} {% endmacro %}
{# Helper macro to render writers. Should be used in a call block to provide additional custom {# Helper macro to render readers and writers. Should be used in a call block to provide additional custom
handling when necessary. The optional `omit` field can be used to omit fields that are either handling when necessary. The optional `omit` field can be used to omit fields that are either
handled in the custom code, or unnecessary in the serialized output. handled in the custom code, or unnecessary in the serialized output.
Example: Example:
{% call render_cache_key_writer("struct name", omits=["omit field"]) %} {% call render_streaming_impl("struct name", writer=true, reader=false, omits=["omit field"]) %}
// Custom C++ code to handle special types/members that are hard to generate code for // Custom C++ code to handle special types/members that are hard to generate code for
{% endcall %} {% endcall %}
One day we should probably make the generator smart enough to generate everything it can
instead of manually adding streaming implementations here.
#} #}
{% macro render_cache_key_writer(json_type, omits=[]) %} {% macro render_streaming_impl(json_type, writer, reader, omits=[]) %}
{%- set cpp_type = types[json_type].name.CamelCase() -%} {%- set cpp_type = types[json_type].name.CamelCase() -%}
template <> {% if reader %}
void stream::Stream<{{cpp_type}}>::Write(stream::Sink* sink, const {{cpp_type}}& t) { template <>
{{ caller() }} MaybeError stream::Stream<{{cpp_type}}>::Read(stream::Source* source, {{cpp_type}}* t) {
{% for member in types[json_type].members %} {{ caller() }}
{%- if not member.name.get() in omits %} {% for member in types[json_type].members %}
{{render_writer(member)}} {% if not member.name.get() in omits %}
{%- endif %} {{render_reader(member)}}
{% endfor %} {% endif %}
} {% endfor %}
return {};
}
{% endif %}
{% if writer %}
template <>
void stream::Stream<{{cpp_type}}>::Write(stream::Sink* sink, const {{cpp_type}}& t) {
{{ caller() }}
{% for member in types[json_type].members %}
{% if not member.name.get() in omits %}
{{render_writer(member)}}
{% endif %}
{% endfor %}
}
{% endif %}
{% endmacro %} {% endmacro %}
{% call render_cache_key_writer("adapter properties") %} {% call render_streaming_impl("adapter properties", true, false) %}
{% endcall %} {% endcall %}
{% call render_cache_key_writer("dawn cache device descriptor") %} {% call render_streaming_impl("dawn cache device descriptor", true, false) %}
{% endcall %}
{% call render_streaming_impl("extent 3D", true, true) %}
{% endcall %} {% endcall %}
} // namespace {{native_namespace}} } // namespace {{native_namespace}}

View File

@ -215,4 +215,22 @@ Limits ApplyLimitTiers(Limits limits) {
return limits; return limits;
} }
#define DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT(type, name) \
{ result.name = limits.name; }
#define DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT(MEMBERS) \
MEMBERS(DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT)
LimitsForCompilationRequest LimitsForCompilationRequest::Create(const Limits& limits) {
LimitsForCompilationRequest result;
DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT(LIMITS_FOR_COMPILATION_REQUEST_MEMBERS)
return result;
}
#undef DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT
#undef DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT
template <>
void stream::Stream<LimitsForCompilationRequest>::Write(Sink* s,
const LimitsForCompilationRequest& t) {
t.VisitAll([&](const auto&... members) { StreamIn(s, members...); });
}
} // namespace dawn::native } // namespace dawn::native

View File

@ -16,6 +16,7 @@
#define SRC_DAWN_NATIVE_LIMITS_H_ #define SRC_DAWN_NATIVE_LIMITS_H_
#include "dawn/native/Error.h" #include "dawn/native/Error.h"
#include "dawn/native/VisitableMembers.h"
#include "dawn/native/dawn_platform.h" #include "dawn/native/dawn_platform.h"
namespace dawn::native { namespace dawn::native {
@ -38,6 +39,20 @@ MaybeError ValidateLimits(const Limits& supportedLimits, const Limits& requiredL
// Returns a copy of |limits| where limit tiers are applied. // Returns a copy of |limits| where limit tiers are applied.
Limits ApplyLimitTiers(Limits limits); Limits ApplyLimitTiers(Limits limits);
// If there is a new limit member needed at shader compilation time,
// simply append a new X(type, name) here.
#define LIMITS_FOR_COMPILATION_REQUEST_MEMBERS(X) \
X(uint32_t, maxComputeWorkgroupSizeX) \
X(uint32_t, maxComputeWorkgroupSizeY) \
X(uint32_t, maxComputeWorkgroupSizeZ) \
X(uint32_t, maxComputeInvocationsPerWorkgroup) \
X(uint32_t, maxComputeWorkgroupStorageSize)
struct LimitsForCompilationRequest {
static LimitsForCompilationRequest Create(const Limits& limits);
DAWN_VISITABLE_MEMBERS(LIMITS_FOR_COMPILATION_REQUEST_MEMBERS)
};
} // namespace dawn::native } // namespace dawn::native
#endif // SRC_DAWN_NATIVE_LIMITS_H_ #endif // SRC_DAWN_NATIVE_LIMITS_H_

View File

@ -15,7 +15,9 @@
#ifndef SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_ #ifndef SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_
#define SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_ #define SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_
#include <map>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "dawn/common/HashUtils.h" #include "dawn/common/HashUtils.h"
@ -60,6 +62,13 @@ class ObjectContentHasher {
} }
}; };
template <typename T, typename E>
struct RecordImpl<std::map<T, E>> {
static constexpr void Call(ObjectContentHasher* recorder, const std::map<T, E>& map) {
recorder->RecordIterable<std::map<T, E>>(map);
}
};
template <typename IteratorT> template <typename IteratorT>
constexpr void RecordIterable(const IteratorT& iterable) { constexpr void RecordIterable(const IteratorT& iterable) {
for (auto it = iterable.begin(); it != iterable.end(); ++it) { for (auto it = iterable.begin(); it != iterable.end(); ++it) {
@ -67,6 +76,14 @@ class ObjectContentHasher {
} }
} }
template <typename T, typename E>
struct RecordImpl<std::pair<T, E>> {
static constexpr void Call(ObjectContentHasher* recorder, const std::pair<T, E>& pair) {
recorder->Record(pair.first);
recorder->Record(pair.second);
}
};
size_t mContentHash = 0; size_t mContentHash = 0;
}; };

View File

@ -58,12 +58,6 @@ MaybeError ValidateProgrammableStage(DeviceBase* device,
DAWN_TRY(ValidateCompatibilityWithPipelineLayout(device, metadata, layout)); DAWN_TRY(ValidateCompatibilityWithPipelineLayout(device, metadata, layout));
} }
if (constantCount > 0u && device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs)) {
return DAWN_VALIDATION_ERROR(
"Pipeline overridable constants are disallowed because they are partially "
"implemented.");
}
// Validate if overridable constants exist in shader module // Validate if overridable constants exist in shader module
// pipelineBase is not yet constructed at this moment so iterate constants from descriptor // pipelineBase is not yet constructed at this moment so iterate constants from descriptor
size_t numUninitializedConstants = metadata.uninitializedOverrides.size(); size_t numUninitializedConstants = metadata.uninitializedOverrides.size();
@ -233,6 +227,7 @@ size_t PipelineBase::ComputeContentHash() {
for (SingleShaderStage stage : IterateStages(mStageMask)) { for (SingleShaderStage stage : IterateStages(mStageMask)) {
recorder.Record(mStages[stage].module->GetContentHash()); recorder.Record(mStages[stage].module->GetContentHash());
recorder.Record(mStages[stage].entryPoint); recorder.Record(mStages[stage].entryPoint);
recorder.Record(mStages[stage].constants);
} }
return recorder.GetContentHash(); return recorder.GetContentHash();
@ -248,7 +243,14 @@ bool PipelineBase::EqualForCache(const PipelineBase* a, const PipelineBase* b) {
for (SingleShaderStage stage : IterateStages(a->mStageMask)) { for (SingleShaderStage stage : IterateStages(a->mStageMask)) {
// The module is deduplicated so it can be compared by pointer. // The module is deduplicated so it can be compared by pointer.
if (a->mStages[stage].module.Get() != b->mStages[stage].module.Get() || if (a->mStages[stage].module.Get() != b->mStages[stage].module.Get() ||
a->mStages[stage].entryPoint != b->mStages[stage].entryPoint) { a->mStages[stage].entryPoint != b->mStages[stage].entryPoint ||
a->mStages[stage].constants.size() != b->mStages[stage].constants.size()) {
return false;
}
// If the constants.size are the same, we still need to compare the key and value.
if (!std::equal(a->mStages[stage].constants.begin(), a->mStages[stage].constants.end(),
b->mStages[stage].constants.begin())) {
return false; return false;
} }
} }

View File

@ -40,9 +40,6 @@ MaybeError ValidateProgrammableStage(DeviceBase* device,
const PipelineLayoutBase* layout, const PipelineLayoutBase* layout,
SingleShaderStage stage); SingleShaderStage stage);
// Use map to make sure constant keys are sorted for creating shader cache keys
using PipelineConstantEntries = std::map<std::string, double>;
struct ProgrammableStage { struct ProgrammableStage {
Ref<ShaderModuleBase> module; Ref<ShaderModuleBase> module;
std::string entryPoint; std::string entryPoint;

View File

@ -20,7 +20,6 @@
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "dawn/common/BitSetIterator.h" #include "dawn/common/BitSetIterator.h"
#include "dawn/common/Constants.h" #include "dawn/common/Constants.h"
#include "dawn/common/HashUtils.h"
#include "dawn/native/BindGroupLayout.h" #include "dawn/native/BindGroupLayout.h"
#include "dawn/native/ChainUtils_autogen.h" #include "dawn/native/ChainUtils_autogen.h"
#include "dawn/native/CompilationMessages.h" #include "dawn/native/CompilationMessages.h"
@ -511,7 +510,6 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
const DeviceBase* device, const DeviceBase* device,
tint::inspector::Inspector* inspector, tint::inspector::Inspector* inspector,
const tint::inspector::EntryPoint& entryPoint) { const tint::inspector::EntryPoint& entryPoint) {
const CombinedLimits& limits = device->GetLimits();
constexpr uint32_t kMaxInterStageShaderLocation = kMaxInterStageShaderVariables - 1; constexpr uint32_t kMaxInterStageShaderLocation = kMaxInterStageShaderVariables - 1;
std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>(); std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
@ -528,10 +526,6 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
})() })()
if (!entryPoint.overrides.empty()) { if (!entryPoint.overrides.empty()) {
DAWN_INVALID_IF(device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs),
"Pipeline overridable constants are disallowed because they "
"are partially implemented.");
const auto& name2Id = inspector->GetNamedOverrideIds(); const auto& name2Id = inspector->GetNamedOverrideIds();
const auto& id2Scalar = inspector->GetOverrideDefaultValues(); const auto& id2Scalar = inspector->GetOverrideDefaultValues();
@ -553,10 +547,10 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
UNREACHABLE(); UNREACHABLE();
} }
} }
EntryPointMetadata::Override override = {id.value, FromTintOverrideType(c.type), EntryPointMetadata::Override override = {id, FromTintOverrideType(c.type),
c.is_initialized, defaultValue}; c.is_initialized, defaultValue};
std::string identifier = c.is_id_specified ? std::to_string(override.id) : c.name; std::string identifier = c.is_id_specified ? std::to_string(override.id.value) : c.name;
metadata->overrides[identifier] = override; metadata->overrides[identifier] = override;
if (!c.is_initialized) { if (!c.is_initialized) {
@ -575,39 +569,6 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage)); DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
if (metadata->stage == SingleShaderStage::Compute) { if (metadata->stage == SingleShaderStage::Compute) {
auto workgroup_size = entryPoint.workgroup_size;
DAWN_INVALID_IF(
!workgroup_size.has_value(),
"TODO(crbug.com/dawn/1504): Dawn does not currently support @workgroup_size "
"attributes using override-expressions");
DelayedInvalidIf(workgroup_size->x > limits.v1.maxComputeWorkgroupSizeX ||
workgroup_size->y > limits.v1.maxComputeWorkgroupSizeY ||
workgroup_size->z > limits.v1.maxComputeWorkgroupSizeZ,
"Entry-point uses workgroup_size(%u, %u, %u) that exceeds the "
"maximum allowed (%u, %u, %u).",
workgroup_size->x, workgroup_size->y, workgroup_size->z,
limits.v1.maxComputeWorkgroupSizeX, limits.v1.maxComputeWorkgroupSizeY,
limits.v1.maxComputeWorkgroupSizeZ);
// Dimensions have already been validated against their individual limits above.
// Cast to uint64_t to avoid overflow in this multiplication.
uint64_t numInvocations =
static_cast<uint64_t>(workgroup_size->x) * workgroup_size->y * workgroup_size->z;
DelayedInvalidIf(numInvocations > limits.v1.maxComputeInvocationsPerWorkgroup,
"The total number of workgroup invocations (%u) exceeds the "
"maximum allowed (%u).",
numInvocations, limits.v1.maxComputeInvocationsPerWorkgroup);
const size_t workgroupStorageSize = inspector->GetWorkgroupStorageSize(entryPoint.name);
DelayedInvalidIf(workgroupStorageSize > limits.v1.maxComputeWorkgroupStorageSize,
"The total use of workgroup storage (%u bytes) is larger than "
"the maximum allowed (%u bytes).",
workgroupStorageSize, limits.v1.maxComputeWorkgroupStorageSize);
metadata->localWorkgroupSize.x = workgroup_size->x;
metadata->localWorkgroupSize.y = workgroup_size->y;
metadata->localWorkgroupSize.z = workgroup_size->z;
metadata->usesNumWorkgroups = entryPoint.num_workgroups_used; metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
} }
@ -883,6 +844,46 @@ MaybeError ReflectShaderUsingTint(const DeviceBase* device,
} }
} // anonymous namespace } // anonymous namespace
ResultOrError<Extent3D> ValidateComputeStageWorkgroupSize(
const tint::Program& program,
const char* entryPointName,
const LimitsForCompilationRequest& limits) {
tint::inspector::Inspector inspector(&program);
// At this point the entry point must exist and must have workgroup size values.
tint::inspector::EntryPoint entryPoint = inspector.GetEntryPoint(entryPointName);
ASSERT(entryPoint.workgroup_size.has_value());
const tint::inspector::WorkgroupSize& workgroup_size = entryPoint.workgroup_size.value();
DAWN_INVALID_IF(workgroup_size.x < 1 || workgroup_size.y < 1 || workgroup_size.z < 1,
"Entry-point uses workgroup_size(%u, %u, %u) that are below the "
"minimum allowed (1, 1, 1).",
workgroup_size.x, workgroup_size.y, workgroup_size.z);
DAWN_INVALID_IF(workgroup_size.x > limits.maxComputeWorkgroupSizeX ||
workgroup_size.y > limits.maxComputeWorkgroupSizeY ||
workgroup_size.z > limits.maxComputeWorkgroupSizeZ,
"Entry-point uses workgroup_size(%u, %u, %u) that exceeds the "
"maximum allowed (%u, %u, %u).",
workgroup_size.x, workgroup_size.y, workgroup_size.z,
limits.maxComputeWorkgroupSizeX, limits.maxComputeWorkgroupSizeY,
limits.maxComputeWorkgroupSizeZ);
uint64_t numInvocations =
static_cast<uint64_t>(workgroup_size.x) * workgroup_size.y * workgroup_size.z;
DAWN_INVALID_IF(numInvocations > limits.maxComputeInvocationsPerWorkgroup,
"The total number of workgroup invocations (%u) exceeds the "
"maximum allowed (%u).",
numInvocations, limits.maxComputeInvocationsPerWorkgroup);
const size_t workgroupStorageSize = inspector.GetWorkgroupStorageSize(entryPointName);
DAWN_INVALID_IF(workgroupStorageSize > limits.maxComputeWorkgroupStorageSize,
"The total use of workgroup storage (%u bytes) is larger than "
"the maximum allowed (%u bytes).",
workgroupStorageSize, limits.maxComputeWorkgroupStorageSize);
return Extent3D{workgroup_size.x, workgroup_size.y, workgroup_size.z};
}
ShaderModuleParseResult::ShaderModuleParseResult() = default; ShaderModuleParseResult::ShaderModuleParseResult() = default;
ShaderModuleParseResult::~ShaderModuleParseResult() = default; ShaderModuleParseResult::~ShaderModuleParseResult() = default;
@ -1200,11 +1201,4 @@ MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult
return {}; return {};
} }
size_t PipelineLayoutEntryPointPairHashFunc::operator()(
const PipelineLayoutEntryPointPair& pair) const {
size_t hash = 0;
HashCombine(&hash, pair.first, pair.second);
return hash;
}
} // namespace dawn::native } // namespace dawn::native

View File

@ -33,10 +33,12 @@
#include "dawn/native/Format.h" #include "dawn/native/Format.h"
#include "dawn/native/Forward.h" #include "dawn/native/Forward.h"
#include "dawn/native/IntegerTypes.h" #include "dawn/native/IntegerTypes.h"
#include "dawn/native/Limits.h"
#include "dawn/native/ObjectBase.h" #include "dawn/native/ObjectBase.h"
#include "dawn/native/PerStage.h" #include "dawn/native/PerStage.h"
#include "dawn/native/VertexFormat.h" #include "dawn/native/VertexFormat.h"
#include "dawn/native/dawn_platform.h" #include "dawn/native/dawn_platform.h"
#include "tint/override_id.h"
namespace tint { namespace tint {
@ -76,10 +78,8 @@ enum class InterpolationSampling {
Sample, Sample,
}; };
using PipelineLayoutEntryPointPair = std::pair<const PipelineLayoutBase*, std::string>; // Use map to make sure constant keys are sorted for creating shader cache keys
struct PipelineLayoutEntryPointPairHashFunc { using PipelineConstantEntries = std::map<std::string, double>;
size_t operator()(const PipelineLayoutEntryPointPair& pair) const;
};
// A map from name to EntryPointMetadata. // A map from name to EntryPointMetadata.
using EntryPointMetadataTable = using EntryPointMetadataTable =
@ -108,6 +108,13 @@ MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
const EntryPointMetadata& entryPoint, const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout); const PipelineLayoutBase* layout);
// Return extent3D with workgroup size dimension info if it is valid
// width = x, height = y, depthOrArrayLength = z
ResultOrError<Extent3D> ValidateComputeStageWorkgroupSize(
const tint::Program& program,
const char* entryPointName,
const LimitsForCompilationRequest& limits);
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint, RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout); const PipelineLayoutBase* layout);
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform, ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
@ -204,14 +211,12 @@ struct EntryPointMetadata {
std::bitset<kMaxInterStageShaderVariables> usedInterStageVariables; std::bitset<kMaxInterStageShaderVariables> usedInterStageVariables;
std::array<InterStageVariableInfo, kMaxInterStageShaderVariables> interStageVariables; std::array<InterStageVariableInfo, kMaxInterStageShaderVariables> interStageVariables;
// The local workgroup size declared for a compute entry point (or 0s otherwise).
Origin3D localWorkgroupSize;
// The shader stage for this binding. // The shader stage for this binding.
SingleShaderStage stage; SingleShaderStage stage;
struct Override { struct Override {
uint32_t id; tint::OverrideId id;
// Match tint::inspector::Override::Type // Match tint::inspector::Override::Type
// Bool is defined as a macro on linux X11 and cannot compile // Bool is defined as a macro on linux X11 and cannot compile
enum class Type { Boolean, Float32, Uint32, Int32 } type; enum class Type { Boolean, Float32, Uint32, Int32 } type;
@ -273,6 +278,7 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject {
bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const; bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const;
}; };
// This returns tint program before running transforms.
const tint::Program* GetTintProgram() const; const tint::Program* GetTintProgram() const;
void APIGetCompilationInfo(wgpu::CompilationInfoCallback callback, void* userdata); void APIGetCompilationInfo(wgpu::CompilationInfoCallback callback, void* userdata);

View File

@ -65,6 +65,23 @@ void stream::Stream<tint::transform::VertexPulling::Config>::Write(
StreamInTintObject(cfg, sink); StreamInTintObject(cfg, sink);
} }
// static
template <>
void stream::Stream<tint::transform::SubstituteOverride::Config>::Write(
stream::Sink* sink,
const tint::transform::SubstituteOverride::Config& cfg) {
StreamInTintObject(cfg, sink);
}
// static
template <>
void stream::Stream<tint::OverrideId>::Write(stream::Sink* sink, const tint::OverrideId& id) {
// TODO(tint:1640): fix the include build issues and use StreamInTintObject instead.
static_assert(offsetof(tint::OverrideId, value) == 0,
"Please update serialization for tint::OverrideId");
StreamIn(sink, id.value);
}
// static // static
template <> template <>
void stream::Stream<tint::transform::VertexBufferLayoutDescriptor>::Write( void stream::Stream<tint::transform::VertexBufferLayoutDescriptor>::Write(

View File

@ -16,6 +16,7 @@
#include "dawn/native/BindGroupLayout.h" #include "dawn/native/BindGroupLayout.h"
#include "dawn/native/Device.h" #include "dawn/native/Device.h"
#include "dawn/native/Pipeline.h"
#include "dawn/native/PipelineLayout.h" #include "dawn/native/PipelineLayout.h"
#include "dawn/native/RenderPipeline.h" #include "dawn/native/RenderPipeline.h"
@ -183,6 +184,21 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig(
return cfg; return cfg;
} }
tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig(
const ProgrammableStage& stage) {
const EntryPointMetadata& metadata = *stage.metadata;
const auto& constants = stage.constants;
tint::transform::SubstituteOverride::Config cfg;
for (const auto& [key, value] : constants) {
const auto& o = metadata.overrides.at(key);
cfg.map.insert({o.id, value});
}
return cfg;
}
} // namespace dawn::native } // namespace dawn::native
namespace tint::sem { namespace tint::sem {

View File

@ -26,6 +26,7 @@ namespace dawn::native {
class DeviceBase; class DeviceBase;
class PipelineLayoutBase; class PipelineLayoutBase;
struct ProgrammableStage;
class RenderPipelineBase; class RenderPipelineBase;
// Indicates that for the lifetime of this object tint internal compiler errors should be // Indicates that for the lifetime of this object tint internal compiler errors should be
@ -47,6 +48,9 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig(
const std::string_view& entryPoint, const std::string_view& entryPoint,
BindGroupIndex pullingBufferBindingSet); BindGroupIndex pullingBufferBindingSet);
tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig(
const ProgrammableStage& stage);
} // namespace dawn::native } // namespace dawn::native
namespace tint::sem { namespace tint::sem {

View File

@ -351,8 +351,6 @@ MaybeError RenderPipeline::Initialize() {
D3D12_GRAPHICS_PIPELINE_STATE_DESC descriptorD3D12 = {}; D3D12_GRAPHICS_PIPELINE_STATE_DESC descriptorD3D12 = {};
PerStage<ProgrammableStage> pipelineStages = GetAllStages();
PerStage<D3D12_SHADER_BYTECODE*> shaders; PerStage<D3D12_SHADER_BYTECODE*> shaders;
shaders[SingleShaderStage::Vertex] = &descriptorD3D12.VS; shaders[SingleShaderStage::Vertex] = &descriptorD3D12.VS;
shaders[SingleShaderStage::Fragment] = &descriptorD3D12.PS; shaders[SingleShaderStage::Fragment] = &descriptorD3D12.PS;
@ -360,8 +358,9 @@ MaybeError RenderPipeline::Initialize() {
PerStage<CompiledShader> compiledShader; PerStage<CompiledShader> compiledShader;
for (auto stage : IterateStages(GetStageMask())) { for (auto stage : IterateStages(GetStageMask())) {
DAWN_TRY_ASSIGN(compiledShader[stage], ToBackend(pipelineStages[stage].module) const ProgrammableStage& programmableStage = GetStage(stage);
->Compile(pipelineStages[stage], stage, DAWN_TRY_ASSIGN(compiledShader[stage], ToBackend(programmableStage.module)
->Compile(programmableStage, stage,
ToBackend(GetLayout()), compileFlags)); ToBackend(GetLayout()), compileFlags));
*shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode(); *shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode();
} }

View File

@ -65,100 +65,35 @@ namespace dawn::native::d3d12 {
namespace { namespace {
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
std::ostringstream out;
out.precision(n);
out << std::fixed << v;
return out.str();
}
std::string GetHLSLValueString(EntryPointMetadata::Override::Type dawnType,
const OverrideScalar* entry,
double value = 0) {
switch (dawnType) {
case EntryPointMetadata::Override::Type::Boolean:
return std::to_string(entry ? entry->b : static_cast<int32_t>(value));
case EntryPointMetadata::Override::Type::Float32:
return FloatToStringWithPrecision(entry ? entry->f32 : static_cast<float>(value));
case EntryPointMetadata::Override::Type::Int32:
return std::to_string(entry ? entry->i32 : static_cast<int32_t>(value));
case EntryPointMetadata::Override::Type::Uint32:
return std::to_string(entry ? entry->u32 : static_cast<uint32_t>(value));
default:
UNREACHABLE();
}
}
constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
using DefineStrings = std::vector<std::pair<std::string, std::string>>;
DefineStrings GetOverridableConstantsDefines(
const PipelineConstantEntries& pipelineConstantEntries,
const EntryPointMetadata::OverridesMap& shaderEntryPointConstants) {
DefineStrings defineStrings;
std::unordered_set<std::string> overriddenConstants;
// Set pipeline overridden values
for (const auto& [name, value] : pipelineConstantEntries) {
overriddenConstants.insert(name);
// This is already validated so `name` must exist
const auto& moduleConstant = shaderEntryPointConstants.at(name);
defineStrings.emplace_back(
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
GetHLSLValueString(moduleConstant.type, nullptr, value));
}
// Set shader initialized default values
for (const auto& iter : shaderEntryPointConstants) {
const std::string& name = iter.first;
if (overriddenConstants.count(name) != 0) {
// This constant already has overridden value
continue;
}
const auto& moduleConstant = shaderEntryPointConstants.at(name);
// Uninitialized default values are okay since they are only defined to pass
// compilation but not used
defineStrings.emplace_back(
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue));
}
return defineStrings;
}
enum class Compiler { FXC, DXC }; enum class Compiler { FXC, DXC };
#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \ #define HLSL_COMPILATION_REQUEST_MEMBERS(X) \
X(const tint::Program*, inputProgram) \ X(const tint::Program*, inputProgram) \
X(std::string_view, entryPointName) \ X(std::string_view, entryPointName) \
X(SingleShaderStage, stage) \ X(SingleShaderStage, stage) \
X(uint32_t, shaderModel) \ X(uint32_t, shaderModel) \
X(uint32_t, compileFlags) \ X(uint32_t, compileFlags) \
X(Compiler, compiler) \ X(Compiler, compiler) \
X(uint64_t, compilerVersion) \ X(uint64_t, compilerVersion) \
X(std::wstring_view, dxcShaderProfile) \ X(std::wstring_view, dxcShaderProfile) \
X(std::string_view, fxcShaderProfile) \ X(std::string_view, fxcShaderProfile) \
X(pD3DCompile, d3dCompile) \ X(pD3DCompile, d3dCompile) \
X(IDxcLibrary*, dxcLibrary) \ X(IDxcLibrary*, dxcLibrary) \
X(IDxcCompiler*, dxcCompiler) \ X(IDxcCompiler*, dxcCompiler) \
X(uint32_t, firstIndexOffsetShaderRegister) \ X(uint32_t, firstIndexOffsetShaderRegister) \
X(uint32_t, firstIndexOffsetRegisterSpace) \ X(uint32_t, firstIndexOffsetRegisterSpace) \
X(bool, usesNumWorkgroups) \ X(bool, usesNumWorkgroups) \
X(uint32_t, numWorkgroupsShaderRegister) \ X(uint32_t, numWorkgroupsShaderRegister) \
X(uint32_t, numWorkgroupsRegisterSpace) \ X(uint32_t, numWorkgroupsRegisterSpace) \
X(DefineStrings, defineStrings) \ X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \ X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform) \
X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform) \ X(tint::transform::BindingRemapper::BindingPoints, remappedBindingPoints) \
X(tint::transform::BindingRemapper::BindingPoints, remappedBindingPoints) \ X(tint::transform::BindingRemapper::AccessControls, remappedAccessControls) \
X(tint::transform::BindingRemapper::AccessControls, remappedAccessControls) \ X(std::optional<tint::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
X(bool, disableSymbolRenaming) \ X(LimitsForCompilationRequest, limits) \
X(bool, isRobustnessEnabled) \ X(bool, disableSymbolRenaming) \
X(bool, disableWorkgroupInit) \ X(bool, isRobustnessEnabled) \
X(bool, disableWorkgroupInit) \
X(bool, dumpShaders) X(bool, dumpShaders)
#define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \ #define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \
@ -170,8 +105,7 @@ enum class Compiler { FXC, DXC };
X(std::string_view, fxcShaderProfile) \ X(std::string_view, fxcShaderProfile) \
X(pD3DCompile, d3dCompile) \ X(pD3DCompile, d3dCompile) \
X(IDxcLibrary*, dxcLibrary) \ X(IDxcLibrary*, dxcLibrary) \
X(IDxcCompiler*, dxcCompiler) \ X(IDxcCompiler*, dxcCompiler)
X(DefineStrings, defineStrings)
DAWN_SERIALIZABLE(struct, HlslCompilationRequest, HLSL_COMPILATION_REQUEST_MEMBERS){}; DAWN_SERIALIZABLE(struct, HlslCompilationRequest, HLSL_COMPILATION_REQUEST_MEMBERS){};
#undef HLSL_COMPILATION_REQUEST_MEMBERS #undef HLSL_COMPILATION_REQUEST_MEMBERS
@ -255,25 +189,11 @@ ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(const D3DBytecodeCompilationReq
std::vector<const wchar_t*> arguments = std::vector<const wchar_t*> arguments =
GetDXCArguments(r.compileFlags, r.hasShaderFloat16Feature); GetDXCArguments(r.compileFlags, r.hasShaderFloat16Feature);
// Build defines for overridable constants
std::vector<std::pair<std::wstring, std::wstring>> defineStrings;
defineStrings.reserve(r.defineStrings.size());
for (const auto& [name, value] : r.defineStrings) {
defineStrings.emplace_back(UTF8ToWStr(name.c_str()), UTF8ToWStr(value.c_str()));
}
std::vector<DxcDefine> dxcDefines;
dxcDefines.reserve(defineStrings.size());
for (const auto& [name, value] : defineStrings) {
dxcDefines.push_back({name.c_str(), value.c_str()});
}
ComPtr<IDxcOperationResult> result; ComPtr<IDxcOperationResult> result;
DAWN_TRY(CheckHRESULT( DAWN_TRY(CheckHRESULT(r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(), r.dxcShaderProfile.data(), arguments.data(),
r.dxcShaderProfile.data(), arguments.data(), arguments.size(), arguments.size(), nullptr, 0, nullptr, &result),
dxcDefines.data(), dxcDefines.size(), nullptr, &result), "DXC compile"));
"DXC compile"));
HRESULT hr; HRESULT hr;
DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status")); DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
@ -359,20 +279,7 @@ ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const D3DBytecodeCompilationReq
ComPtr<ID3DBlob> compiledShader; ComPtr<ID3DBlob> compiledShader;
ComPtr<ID3DBlob> errors; ComPtr<ID3DBlob> errors;
// Build defines for overridable constants DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
const D3D_SHADER_MACRO* pDefines = nullptr;
std::vector<D3D_SHADER_MACRO> fxcDefines;
if (r.defineStrings.size() > 0) {
fxcDefines.reserve(r.defineStrings.size() + 1);
for (const auto& [name, value] : r.defineStrings) {
fxcDefines.push_back({name.c_str(), value.c_str()});
}
// d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array
fxcDefines.push_back({nullptr, nullptr});
pDefines = fxcDefines.data();
}
DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, pDefines,
nullptr, entryPointName.c_str(), r.fxcShaderProfile.data(), nullptr, entryPointName.c_str(), r.fxcShaderProfile.data(),
r.compileFlags, 0, &compiledShader, &errors)), r.compileFlags, 0, &compiledShader, &errors)),
"D3D compile failed with: %s", static_cast<char*>(errors->GetBufferPointer())); "D3D compile failed with: %s", static_cast<char*>(errors->GetBufferPointer()));
@ -420,6 +327,14 @@ ResultOrError<std::string> TranslateToHLSL(
tint::transform::Renamer::Target::kHlslKeywords); tint::transform::Renamer::Target::kHlslKeywords);
} }
if (r.substituteOverrideConfig) {
// This needs to run after SingleEntryPoint transform to get rid of overrides not used for
// the current entry point.
transformManager.Add<tint::transform::SubstituteOverride>();
transformInputs.Add<tint::transform::SubstituteOverride::Config>(
std::move(r.substituteOverrideConfig).value());
}
// D3D12 registers like `t3` and `c3` have the same bindingOffset number in // D3D12 registers like `t3` and `c3` have the same bindingOffset number in
// the remapping but should not be considered a collision because they have // the remapping but should not be considered a collision because they have
// different types. // different types.
@ -450,6 +365,13 @@ ResultOrError<std::string> TranslateToHLSL(
return DAWN_VALIDATION_ERROR("Transform output missing renamer data."); return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
} }
if (r.stage == SingleShaderStage::Compute) {
// Validate workgroup size after program runs transforms.
Extent3D _;
DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize(
transformedProgram, remappedEntryPointName->data(), r.limits));
}
if (r.stage == SingleShaderStage::Vertex) { if (r.stage == SingleShaderStage::Vertex) {
if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) { if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
*usesVertexOrInstanceIndex = data->has_vertex_or_instance_index; *usesVertexOrInstanceIndex = data->has_vertex_or_instance_index;
@ -555,8 +477,7 @@ ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& pro
req.bytecode.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16); req.bytecode.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
req.bytecode.compileFlags = compileFlags; req.bytecode.compileFlags = compileFlags;
req.bytecode.defineStrings =
GetOverridableConstantsDefines(programmableStage.constants, entryPoint.overrides);
if (device->IsToggleEnabled(Toggle::UseDXC)) { if (device->IsToggleEnabled(Toggle::UseDXC)) {
req.bytecode.compiler = Compiler::DXC; req.bytecode.compiler = Compiler::DXC;
req.bytecode.dxcLibrary = device->GetDxcLibrary().Get(); req.bytecode.dxcLibrary = device->GetDxcLibrary().Get();
@ -645,6 +566,11 @@ ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& pro
} }
} }
std::optional<tint::transform::SubstituteOverride::Config> substituteOverrideConfig;
if (!programmableStage.metadata->overrides.empty()) {
substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
}
req.hlsl.inputProgram = GetTintProgram(); req.hlsl.inputProgram = GetTintProgram();
req.hlsl.entryPointName = programmableStage.entryPoint.c_str(); req.hlsl.entryPointName = programmableStage.entryPoint.c_str();
req.hlsl.stage = stage; req.hlsl.stage = stage;
@ -657,6 +583,10 @@ ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& pro
req.hlsl.remappedAccessControls = std::move(remappedAccessControls); req.hlsl.remappedAccessControls = std::move(remappedAccessControls);
req.hlsl.newBindingsMap = BuildExternalTextureTransformBindings(layout); req.hlsl.newBindingsMap = BuildExternalTextureTransformBindings(layout);
req.hlsl.arrayLengthFromUniform = std::move(arrayLengthFromUniform); req.hlsl.arrayLengthFromUniform = std::move(arrayLengthFromUniform);
req.hlsl.substituteOverrideConfig = std::move(substituteOverrideConfig);
const CombinedLimits& limits = device->GetLimits();
req.hlsl.limits = LimitsForCompilationRequest::Create(limits.v1);
CacheResult<CompiledShader> compiledShader; CacheResult<CompiledShader> compiledShader;
DAWN_TRY_LOAD_OR_RUN(compiledShader, device, std::move(req), CompiledShader::FromBlob, DAWN_TRY_LOAD_OR_RUN(compiledShader, device, std::move(req), CompiledShader::FromBlob,

View File

@ -40,8 +40,9 @@ MaybeError ComputePipeline::Initialize() {
const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute); const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
ShaderModule::MetalFunctionData computeData; ShaderModule::MetalFunctionData computeData;
DAWN_TRY(CreateMTLFunction(computeStage, SingleShaderStage::Compute, ToBackend(GetLayout()), DAWN_TRY(ToBackend(computeStage.module.Get())
&computeData)); ->CreateFunction(SingleShaderStage::Compute, computeStage, ToBackend(GetLayout()),
&computeData));
NSError* error = nullptr; NSError* error = nullptr;
mMtlComputePipelineState.Acquire( mMtlComputePipelineState.Acquire(
@ -53,8 +54,7 @@ MaybeError ComputePipeline::Initialize() {
ASSERT(mMtlComputePipelineState != nil); ASSERT(mMtlComputePipelineState != nil);
// Copy over the local workgroup size as it is passed to dispatch explicitly in Metal // Copy over the local workgroup size as it is passed to dispatch explicitly in Metal
Origin3D localSize = GetStage(SingleShaderStage::Compute).metadata->localWorkgroupSize; mLocalWorkgroupSize = computeData.localWorkgroupSize;
mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z);
mRequiresStorageBufferLength = computeData.needsStorageBufferLength; mRequiresStorageBufferLength = computeData.needsStorageBufferLength;
mWorkgroupAllocations = std::move(computeData.workgroupAllocations); mWorkgroupAllocations = std::move(computeData.workgroupAllocations);

View File

@ -340,8 +340,9 @@ MaybeError RenderPipeline::Initialize() {
const PerStage<ProgrammableStage>& allStages = GetAllStages(); const PerStage<ProgrammableStage>& allStages = GetAllStages();
const ProgrammableStage& vertexStage = allStages[wgpu::ShaderStage::Vertex]; const ProgrammableStage& vertexStage = allStages[wgpu::ShaderStage::Vertex];
ShaderModule::MetalFunctionData vertexData; ShaderModule::MetalFunctionData vertexData;
DAWN_TRY(CreateMTLFunction(vertexStage, SingleShaderStage::Vertex, ToBackend(GetLayout()), DAWN_TRY(ToBackend(vertexStage.module.Get())
&vertexData, 0xFFFFFFFF, this)); ->CreateFunction(SingleShaderStage::Vertex, vertexStage, ToBackend(GetLayout()),
&vertexData, 0xFFFFFFFF, this));
descriptorMTL.vertexFunction = vertexData.function.Get(); descriptorMTL.vertexFunction = vertexData.function.Get();
if (vertexData.needsStorageBufferLength) { if (vertexData.needsStorageBufferLength) {
@ -351,8 +352,9 @@ MaybeError RenderPipeline::Initialize() {
if (GetStageMask() & wgpu::ShaderStage::Fragment) { if (GetStageMask() & wgpu::ShaderStage::Fragment) {
const ProgrammableStage& fragmentStage = allStages[wgpu::ShaderStage::Fragment]; const ProgrammableStage& fragmentStage = allStages[wgpu::ShaderStage::Fragment];
ShaderModule::MetalFunctionData fragmentData; ShaderModule::MetalFunctionData fragmentData;
DAWN_TRY(CreateMTLFunction(fragmentStage, SingleShaderStage::Fragment, DAWN_TRY(ToBackend(fragmentStage.module.Get())
ToBackend(GetLayout()), &fragmentData, GetSampleMask())); ->CreateFunction(SingleShaderStage::Fragment, fragmentStage,
ToBackend(GetLayout()), &fragmentData, GetSampleMask()));
descriptorMTL.fragmentFunction = fragmentData.function.Get(); descriptorMTL.fragmentFunction = fragmentData.function.Get();
if (fragmentData.needsStorageBufferLength) { if (fragmentData.needsStorageBufferLength) {

View File

@ -25,6 +25,10 @@
#import <Metal/Metal.h> #import <Metal/Metal.h>
namespace dawn::native {
struct ProgrammableStage;
}
namespace dawn::native::metal { namespace dawn::native::metal {
class Device; class Device;
@ -42,15 +46,13 @@ class ShaderModule final : public ShaderModuleBase {
NSPRef<id<MTLFunction>> function; NSPRef<id<MTLFunction>> function;
bool needsStorageBufferLength; bool needsStorageBufferLength;
std::vector<uint32_t> workgroupAllocations; std::vector<uint32_t> workgroupAllocations;
MTLSize localWorkgroupSize;
}; };
// MTLFunctionConstantValues needs @available tag to compile MaybeError CreateFunction(SingleShaderStage stage,
// Use id (like void*) in function signature as workaround and do static cast inside const ProgrammableStage& programmableStage,
MaybeError CreateFunction(const char* entryPointName,
SingleShaderStage stage,
const PipelineLayout* layout, const PipelineLayout* layout,
MetalFunctionData* out, MetalFunctionData* out,
id constantValues = nil,
uint32_t sampleMask = 0xFFFFFFFF, uint32_t sampleMask = 0xFFFFFFFF,
const RenderPipeline* renderPipeline = nullptr); const RenderPipeline* renderPipeline = nullptr);

View File

@ -35,17 +35,20 @@ namespace {
using OptionalVertexPullingTransformConfig = std::optional<tint::transform::VertexPulling::Config>; using OptionalVertexPullingTransformConfig = std::optional<tint::transform::VertexPulling::Config>;
#define MSL_COMPILATION_REQUEST_MEMBERS(X) \ #define MSL_COMPILATION_REQUEST_MEMBERS(X) \
X(const tint::Program*, inputProgram) \ X(SingleShaderStage, stage) \
X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \ X(const tint::Program*, inputProgram) \
X(tint::transform::MultiplanarExternalTexture::BindingsMap, externalTextureBindings) \ X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \
X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig) \ X(tint::transform::MultiplanarExternalTexture::BindingsMap, externalTextureBindings) \
X(std::string, entryPointName) \ X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig) \
X(uint32_t, sampleMask) \ X(std::optional<tint::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
X(bool, emitVertexPointSize) \ X(LimitsForCompilationRequest, limits) \
X(bool, isRobustnessEnabled) \ X(std::string, entryPointName) \
X(bool, disableSymbolRenaming) \ X(uint32_t, sampleMask) \
X(bool, disableWorkgroupInit) \ X(bool, emitVertexPointSize) \
X(bool, isRobustnessEnabled) \
X(bool, disableSymbolRenaming) \
X(bool, disableWorkgroupInit) \
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform) X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform)
DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS); DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS);
@ -53,12 +56,13 @@ DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS);
using WorkgroupAllocations = std::vector<uint32_t>; using WorkgroupAllocations = std::vector<uint32_t>;
#define MSL_COMPILATION_MEMBERS(X) \ #define MSL_COMPILATION_MEMBERS(X) \
X(std::string, msl) \ X(std::string, msl) \
X(std::string, remappedEntryPointName) \ X(std::string, remappedEntryPointName) \
X(bool, needsStorageBufferLength) \ X(bool, needsStorageBufferLength) \
X(bool, hasInvariantAttribute) \ X(bool, hasInvariantAttribute) \
X(WorkgroupAllocations, workgroupAllocations) X(WorkgroupAllocations, workgroupAllocations) \
X(Extent3D, localWorkgroupSize)
DAWN_SERIALIZABLE(struct, MslCompilation, MSL_COMPILATION_MEMBERS){}; DAWN_SERIALIZABLE(struct, MslCompilation, MSL_COMPILATION_MEMBERS){};
#undef MSL_COMPILATION_MEMBERS #undef MSL_COMPILATION_MEMBERS
@ -92,13 +96,14 @@ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
namespace { namespace {
ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device, ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(
const tint::Program* inputProgram, DeviceBase* device,
const char* entryPointName, const ProgrammableStage& programmableStage,
SingleShaderStage stage, SingleShaderStage stage,
const PipelineLayout* layout, const PipelineLayout* layout,
uint32_t sampleMask, ShaderModule::MetalFunctionData* out,
const RenderPipeline* renderPipeline) { uint32_t sampleMask,
const RenderPipeline* renderPipeline) {
ScopedTintICEHandler scopedICEHandler(device); ScopedTintICEHandler scopedICEHandler(device);
std::ostringstream errorStream; std::ostringstream errorStream;
@ -137,7 +142,7 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
if (stage == SingleShaderStage::Vertex && if (stage == SingleShaderStage::Vertex &&
device->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) { device->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
vertexPullingTransformConfig = BuildVertexPullingTransformConfig( vertexPullingTransformConfig = BuildVertexPullingTransformConfig(
*renderPipeline, entryPointName, kPullingBufferBindingSet); *renderPipeline, programmableStage.entryPoint.c_str(), kPullingBufferBindingSet);
for (VertexBufferSlot slot : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) { for (VertexBufferSlot slot : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot); uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot);
@ -152,12 +157,19 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
} }
} }
std::optional<tint::transform::SubstituteOverride::Config> substituteOverrideConfig;
if (!programmableStage.metadata->overrides.empty()) {
substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
}
MslCompilationRequest req = {}; MslCompilationRequest req = {};
req.inputProgram = inputProgram; req.stage = stage;
req.inputProgram = programmableStage.module->GetTintProgram();
req.bindingPoints = std::move(bindingPoints); req.bindingPoints = std::move(bindingPoints);
req.externalTextureBindings = std::move(externalTextureBindings); req.externalTextureBindings = std::move(externalTextureBindings);
req.vertexPullingTransformConfig = std::move(vertexPullingTransformConfig); req.vertexPullingTransformConfig = std::move(vertexPullingTransformConfig);
req.entryPointName = entryPointName; req.substituteOverrideConfig = std::move(substituteOverrideConfig);
req.entryPointName = programmableStage.entryPoint.c_str();
req.sampleMask = sampleMask; req.sampleMask = sampleMask;
req.emitVertexPointSize = req.emitVertexPointSize =
stage == SingleShaderStage::Vertex && stage == SingleShaderStage::Vertex &&
@ -166,6 +178,9 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming); req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform()); req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform());
const CombinedLimits& limits = device->GetLimits();
req.limits = LimitsForCompilationRequest::Create(limits.v1);
CacheResult<MslCompilation> mslCompilation; CacheResult<MslCompilation> mslCompilation;
DAWN_TRY_LOAD_OR_RUN( DAWN_TRY_LOAD_OR_RUN(
mslCompilation, device, std::move(req), MslCompilation::FromBlob, mslCompilation, device, std::move(req), MslCompilation::FromBlob,
@ -190,6 +205,14 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
std::move(r.vertexPullingTransformConfig).value()); std::move(r.vertexPullingTransformConfig).value());
} }
if (r.substituteOverrideConfig) {
// This needs to run after SingleEntryPoint transform to get rid of overrides not
// used for the current entry point.
transformManager.Add<tint::transform::SubstituteOverride>();
transformInputs.Add<tint::transform::SubstituteOverride::Config>(
std::move(r.substituteOverrideConfig).value());
}
if (r.isRobustnessEnabled) { if (r.isRobustnessEnabled) {
transformManager.Add<tint::transform::Robustness>(); transformManager.Add<tint::transform::Robustness>();
} }
@ -230,6 +253,13 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
return DAWN_VALIDATION_ERROR("Transform output missing renamer data."); return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
} }
Extent3D localSize{0, 0, 0};
if (r.stage == SingleShaderStage::Compute) {
// Validate workgroup size after program runs transforms.
DAWN_TRY_ASSIGN(localSize, ValidateComputeStageWorkgroupSize(
program, remappedEntryPointName.data(), r.limits));
}
tint::writer::msl::Options options; tint::writer::msl::Options options;
options.buffer_size_ubo_index = kBufferLengthBufferSlot; options.buffer_size_ubo_index = kBufferLengthBufferSlot;
options.fixed_sample_mask = r.sampleMask; options.fixed_sample_mask = r.sampleMask;
@ -258,6 +288,7 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
result.needs_storage_buffer_sizes, result.needs_storage_buffer_sizes,
result.has_invariant_attribute, result.has_invariant_attribute,
std::move(workgroupAllocations), std::move(workgroupAllocations),
localSize,
}}; }};
}); });
@ -272,11 +303,10 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
} // namespace } // namespace
MaybeError ShaderModule::CreateFunction(const char* entryPointName, MaybeError ShaderModule::CreateFunction(SingleShaderStage stage,
SingleShaderStage stage, const ProgrammableStage& programmableStage,
const PipelineLayout* layout, const PipelineLayout* layout,
ShaderModule::MetalFunctionData* out, ShaderModule::MetalFunctionData* out,
id constantValuesPointer,
uint32_t sampleMask, uint32_t sampleMask,
const RenderPipeline* renderPipeline) { const RenderPipeline* renderPipeline) {
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleMTL::CreateFunction"); TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleMTL::CreateFunction");
@ -284,16 +314,21 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName,
ASSERT(!IsError()); ASSERT(!IsError());
ASSERT(out); ASSERT(out);
const char* entryPointName = programmableStage.entryPoint.c_str();
// Vertex stages must specify a renderPipeline // Vertex stages must specify a renderPipeline
if (stage == SingleShaderStage::Vertex) { if (stage == SingleShaderStage::Vertex) {
ASSERT(renderPipeline != nullptr); ASSERT(renderPipeline != nullptr);
} }
CacheResult<MslCompilation> mslCompilation; CacheResult<MslCompilation> mslCompilation;
DAWN_TRY_ASSIGN(mslCompilation, TranslateToMSL(GetDevice(), GetTintProgram(), entryPointName, DAWN_TRY_ASSIGN(mslCompilation, TranslateToMSL(GetDevice(), programmableStage, stage, layout,
stage, layout, sampleMask, renderPipeline)); out, sampleMask, renderPipeline));
out->needsStorageBufferLength = mslCompilation->needsStorageBufferLength; out->needsStorageBufferLength = mslCompilation->needsStorageBufferLength;
out->workgroupAllocations = std::move(mslCompilation->workgroupAllocations); out->workgroupAllocations = std::move(mslCompilation->workgroupAllocations);
out->localWorkgroupSize = MTLSizeMake(mslCompilation->localWorkgroupSize.width,
mslCompilation->localWorkgroupSize.height,
mslCompilation->localWorkgroupSize.depthOrArrayLayers);
NSRef<NSString> mslSource = NSRef<NSString> mslSource =
AcquireNSRef([[NSString alloc] initWithUTF8String:mslCompilation->msl.c_str()]); AcquireNSRef([[NSString alloc] initWithUTF8String:mslCompilation->msl.c_str()]);
@ -327,25 +362,7 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName,
{ {
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "MTLLibrary::newFunctionWithName"); TRACE_EVENT0(GetDevice()->GetPlatform(), General, "MTLLibrary::newFunctionWithName");
if (constantValuesPointer != nil) { out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
if (@available(macOS 10.12, *)) {
MTLFunctionConstantValues* constantValues = constantValuesPointer;
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()
constantValues:constantValues
error:&error]);
if (error != nullptr) {
if (error.code != MTLLibraryErrorCompileWarning) {
return DAWN_VALIDATION_ERROR("Function compile error: %s",
[error.localizedDescription UTF8String]);
}
}
ASSERT(out->function != nil);
} else {
UNREACHABLE();
}
} else {
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
}
} }
if (BlobCache* cache = GetDevice()->GetBlobCache()) { if (BlobCache* cache = GetDevice()->GetBlobCache()) {

View File

@ -68,15 +68,6 @@ void EnsureDestinationTextureInitialized(CommandRecordingContext* commandContext
MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect); MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect);
// Helper function to create function with constant values wrapped in
// if available branch
MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
SingleShaderStage singleShaderStage,
PipelineLayout* pipelineLayout,
ShaderModule::MetalFunctionData* functionData,
uint32_t sampleMask = 0xFFFFFFFF,
const RenderPipeline* renderPipeline = nullptr);
// Allow use MTLStoreActionStoreAndMultismapleResolve because the logic in the backend is // Allow use MTLStoreActionStoreAndMultismapleResolve because the logic in the backend is
// first to compute what the "best" Metal render pass descriptor is, then fix it up if we // first to compute what the "best" Metal render pass descriptor is, then fix it up if we
// are not on macOS 10.12 (i.e. the EmulateStoreAndMSAAResolve toggle is on). // are not on macOS 10.12 (i.e. the EmulateStoreAndMSAAResolve toggle is on).

View File

@ -323,104 +323,6 @@ MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect) {
return MTLBlitOptionNone; return MTLBlitOptionNone;
} }
MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
SingleShaderStage singleShaderStage,
PipelineLayout* pipelineLayout,
ShaderModule::MetalFunctionData* functionData,
uint32_t sampleMask,
const RenderPipeline* renderPipeline) {
ShaderModule* shaderModule = ToBackend(programmableStage.module.Get());
const char* shaderEntryPoint = programmableStage.entryPoint.c_str();
const auto& entryPointMetadata = programmableStage.module->GetEntryPoint(shaderEntryPoint);
if (entryPointMetadata.overrides.size() == 0) {
DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout,
functionData, nil, sampleMask, renderPipeline));
return {};
}
if (@available(macOS 10.12, *)) {
// MTLFunctionConstantValues can only be created within the if available branch
NSRef<MTLFunctionConstantValues> constantValues =
AcquireNSRef([MTLFunctionConstantValues new]);
std::unordered_set<std::string> overriddenConstants;
auto switchType = [&](EntryPointMetadata::Override::Type dawnType,
MTLDataType* type, OverrideScalar* entry,
double value = 0) {
switch (dawnType) {
case EntryPointMetadata::Override::Type::Boolean:
*type = MTLDataTypeBool;
if (entry) {
entry->b = static_cast<int32_t>(value);
}
break;
case EntryPointMetadata::Override::Type::Float32:
*type = MTLDataTypeFloat;
if (entry) {
entry->f32 = static_cast<float>(value);
}
break;
case EntryPointMetadata::Override::Type::Int32:
*type = MTLDataTypeInt;
if (entry) {
entry->i32 = static_cast<int32_t>(value);
}
break;
case EntryPointMetadata::Override::Type::Uint32:
*type = MTLDataTypeUInt;
if (entry) {
entry->u32 = static_cast<uint32_t>(value);
}
break;
default:
UNREACHABLE();
}
};
for (const auto& [name, value] : programmableStage.constants) {
overriddenConstants.insert(name);
// This is already validated so `name` must exist
const auto& moduleConstant = entryPointMetadata.overrides.at(name);
MTLDataType type;
OverrideScalar entry{};
switchType(moduleConstant.type, &type, &entry, value);
[constantValues.Get() setConstantValue:&entry type:type atIndex:moduleConstant.id];
}
// Set shader initialized default values because MSL function_constant
// has no default value
for (const std::string& name : entryPointMetadata.initializedOverrides) {
if (overriddenConstants.count(name) != 0) {
// This constant already has overridden value
continue;
}
// Must exist because it is validated
const auto& moduleConstant = entryPointMetadata.overrides.at(name);
ASSERT(moduleConstant.isInitialized);
MTLDataType type;
switchType(moduleConstant.type, &type, nullptr);
[constantValues.Get() setConstantValue:&moduleConstant.defaultValue
type:type
atIndex:moduleConstant.id];
}
DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout,
functionData, constantValues.Get(), sampleMask,
renderPipeline));
} else {
UNREACHABLE();
}
return {};
}
MaybeError EncodeMetalRenderPass(Device* device, MaybeError EncodeMetalRenderPass(Device* device,
CommandRecordingContext* commandContext, CommandRecordingContext* commandContext,
MTLRenderPassDescriptor* mtlRenderPass, MTLRenderPassDescriptor* mtlRenderPass,

View File

@ -16,6 +16,8 @@
#include <string> #include <string>
#include "dawn/native/Limits.h"
namespace dawn::native::stream { namespace dawn::native::stream {
template <> template <>

View File

@ -61,16 +61,12 @@ MaybeError ComputePipeline::Initialize() {
ShaderModule::ModuleAndSpirv moduleAndSpirv; ShaderModule::ModuleAndSpirv moduleAndSpirv;
DAWN_TRY_ASSIGN(moduleAndSpirv, DAWN_TRY_ASSIGN(moduleAndSpirv,
module->GetHandleAndSpirv(computeStage.entryPoint.c_str(), layout)); module->GetHandleAndSpirv(SingleShaderStage::Compute, computeStage, layout));
createInfo.stage.module = moduleAndSpirv.module; createInfo.stage.module = moduleAndSpirv.module;
createInfo.stage.pName = computeStage.entryPoint.c_str(); createInfo.stage.pName = computeStage.entryPoint.c_str();
std::vector<OverrideScalar> specializationDataEntries; createInfo.stage.pSpecializationInfo = nullptr;
std::vector<VkSpecializationMapEntry> specializationMapEntries;
VkSpecializationInfo specializationInfo{};
createInfo.stage.pSpecializationInfo = GetVkSpecializationInfo(
computeStage, &specializationInfo, &specializationDataEntries, &specializationMapEntries);
PNextChainBuilder stageExtChain(&createInfo.stage); PNextChainBuilder stageExtChain(&createInfo.stage);

View File

@ -341,9 +341,6 @@ MaybeError RenderPipeline::Initialize() {
// There are at most 2 shader stages in render pipeline, i.e. vertex and fragment // There are at most 2 shader stages in render pipeline, i.e. vertex and fragment
std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages; std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
std::array<std::vector<OverrideScalar>, 2> specializationDataEntriesPerStages;
std::array<std::vector<VkSpecializationMapEntry>, 2> specializationMapEntriesPerStages;
std::array<VkSpecializationInfo, 2> specializationInfoPerStages;
uint32_t stageCount = 0; uint32_t stageCount = 0;
for (auto stage : IterateStages(this->GetStageMask())) { for (auto stage : IterateStages(this->GetStageMask())) {
@ -354,7 +351,7 @@ MaybeError RenderPipeline::Initialize() {
ShaderModule::ModuleAndSpirv moduleAndSpirv; ShaderModule::ModuleAndSpirv moduleAndSpirv;
DAWN_TRY_ASSIGN(moduleAndSpirv, DAWN_TRY_ASSIGN(moduleAndSpirv,
module->GetHandleAndSpirv(programmableStage.entryPoint.c_str(), layout)); module->GetHandleAndSpirv(stage, programmableStage, layout));
shaderStage.module = moduleAndSpirv.module; shaderStage.module = moduleAndSpirv.module;
shaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; shaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
@ -379,11 +376,6 @@ MaybeError RenderPipeline::Initialize() {
} }
} }
shaderStage.pSpecializationInfo =
GetVkSpecializationInfo(programmableStage, &specializationInfoPerStages[stageCount],
&specializationDataEntriesPerStages[stageCount],
&specializationMapEntriesPerStages[stageCount]);
DAWN_ASSERT(stageCount < 2); DAWN_ASSERT(stageCount < 2);
shaderStages[stageCount] = shaderStage; shaderStages[stageCount] = shaderStage;
stageCount++; stageCount++;

View File

@ -60,13 +60,35 @@ class ShaderModule::Spirv : private Blob {
namespace dawn::native::vulkan { namespace dawn::native::vulkan {
bool TransformedShaderModuleCacheKey::operator==(
const TransformedShaderModuleCacheKey& other) const {
if (layout != other.layout || entryPoint != other.entryPoint ||
constants.size() != other.constants.size()) {
return false;
}
if (!std::equal(constants.begin(), constants.end(), other.constants.begin())) {
return false;
}
return true;
}
size_t TransformedShaderModuleCacheKeyHashFunc::operator()(
const TransformedShaderModuleCacheKey& key) const {
size_t hash = 0;
HashCombine(&hash, key.layout, key.entryPoint);
for (const auto& entry : key.constants) {
HashCombine(&hash, entry.first, entry.second);
}
return hash;
}
class ShaderModule::ConcurrentTransformedShaderModuleCache { class ShaderModule::ConcurrentTransformedShaderModuleCache {
public: public:
explicit ConcurrentTransformedShaderModuleCache(Device* device); explicit ConcurrentTransformedShaderModuleCache(Device* device);
~ConcurrentTransformedShaderModuleCache(); ~ConcurrentTransformedShaderModuleCache();
std::optional<ModuleAndSpirv> Find(const PipelineLayoutEntryPointPair& key); std::optional<ModuleAndSpirv> Find(const TransformedShaderModuleCacheKey& key);
ModuleAndSpirv AddOrGet(const PipelineLayoutEntryPointPair& key, ModuleAndSpirv AddOrGet(const TransformedShaderModuleCacheKey& key,
VkShaderModule module, VkShaderModule module,
Spirv&& spirv); Spirv&& spirv);
@ -75,7 +97,9 @@ class ShaderModule::ConcurrentTransformedShaderModuleCache {
Device* mDevice; Device* mDevice;
std::mutex mMutex; std::mutex mMutex;
std::unordered_map<PipelineLayoutEntryPointPair, Entry, PipelineLayoutEntryPointPairHashFunc> std::unordered_map<TransformedShaderModuleCacheKey,
Entry,
TransformedShaderModuleCacheKeyHashFunc>
mTransformedShaderModuleCache; mTransformedShaderModuleCache;
}; };
@ -92,7 +116,7 @@ ShaderModule::ConcurrentTransformedShaderModuleCache::~ConcurrentTransformedShad
std::optional<ShaderModule::ModuleAndSpirv> std::optional<ShaderModule::ModuleAndSpirv>
ShaderModule::ConcurrentTransformedShaderModuleCache::Find( ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
const PipelineLayoutEntryPointPair& key) { const TransformedShaderModuleCacheKey& key) {
std::lock_guard<std::mutex> lock(mMutex); std::lock_guard<std::mutex> lock(mMutex);
auto iter = mTransformedShaderModuleCache.find(key); auto iter = mTransformedShaderModuleCache.find(key);
if (iter != mTransformedShaderModuleCache.end()) { if (iter != mTransformedShaderModuleCache.end()) {
@ -106,7 +130,7 @@ ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
} }
ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet( ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet(
const PipelineLayoutEntryPointPair& key, const TransformedShaderModuleCacheKey& key,
VkShaderModule module, VkShaderModule module,
Spirv&& spirv) { Spirv&& spirv) {
ASSERT(module != VK_NULL_HANDLE); ASSERT(module != VK_NULL_HANDLE);
@ -168,20 +192,24 @@ void ShaderModule::DestroyImpl() {
ShaderModule::~ShaderModule() = default; ShaderModule::~ShaderModule() = default;
#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \ #define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \
X(const tint::Program*, inputProgram) \ X(SingleShaderStage, stage) \
X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \ X(const tint::Program*, inputProgram) \
X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \ X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \
X(std::string_view, entryPointName) \ X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
X(bool, disableWorkgroupInit) \ X(std::optional<tint::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
X(bool, useZeroInitializeWorkgroupMemoryExtension) \ X(LimitsForCompilationRequest, limits) \
X(std::string_view, entryPointName) \
X(bool, disableWorkgroupInit) \
X(bool, useZeroInitializeWorkgroupMemoryExtension) \
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform) X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform)
DAWN_MAKE_CACHE_REQUEST(SpirvCompilationRequest, SPIRV_COMPILATION_REQUEST_MEMBERS); DAWN_MAKE_CACHE_REQUEST(SpirvCompilationRequest, SPIRV_COMPILATION_REQUEST_MEMBERS);
#undef SPIRV_COMPILATION_REQUEST_MEMBERS #undef SPIRV_COMPILATION_REQUEST_MEMBERS
ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv( ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
const char* entryPointName, SingleShaderStage stage,
const ProgrammableStage& programmableStage,
const PipelineLayout* layout) { const PipelineLayout* layout) {
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleVk::GetHandleAndSpirv"); TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleVk::GetHandleAndSpirv");
@ -191,7 +219,8 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
ScopedTintICEHandler scopedICEHandler(GetDevice()); ScopedTintICEHandler scopedICEHandler(GetDevice());
// Check to see if we have the handle and spirv cached already. // Check to see if we have the handle and spirv cached already.
auto cacheKey = std::make_pair(layout, entryPointName); auto cacheKey = TransformedShaderModuleCacheKey{layout, programmableStage.entryPoint.c_str(),
programmableStage.constants};
auto handleAndSpirv = mTransformedShaderModuleCache->Find(cacheKey); auto handleAndSpirv = mTransformedShaderModuleCache->Find(cacheKey);
if (handleAndSpirv.has_value()) { if (handleAndSpirv.has_value()) {
return std::move(*handleAndSpirv); return std::move(*handleAndSpirv);
@ -204,7 +233,8 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
using BindingPoint = tint::transform::BindingPoint; using BindingPoint = tint::transform::BindingPoint;
BindingRemapper::BindingPoints bindingPoints; BindingRemapper::BindingPoints bindingPoints;
const BindingInfoArray& moduleBindingInfo = GetEntryPoint(entryPointName).bindings; const BindingInfoArray& moduleBindingInfo =
GetEntryPoint(programmableStage.entryPoint.c_str()).bindings;
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group)); const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
@ -238,16 +268,26 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
} }
} }
std::optional<tint::transform::SubstituteOverride::Config> substituteOverrideConfig;
if (!programmableStage.metadata->overrides.empty()) {
substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
}
#if TINT_BUILD_SPV_WRITER #if TINT_BUILD_SPV_WRITER
SpirvCompilationRequest req = {}; SpirvCompilationRequest req = {};
req.stage = stage;
req.inputProgram = GetTintProgram(); req.inputProgram = GetTintProgram();
req.bindingPoints = std::move(bindingPoints); req.bindingPoints = std::move(bindingPoints);
req.newBindingsMap = std::move(newBindingsMap); req.newBindingsMap = std::move(newBindingsMap);
req.entryPointName = entryPointName; req.entryPointName = programmableStage.entryPoint;
req.disableWorkgroupInit = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit); req.disableWorkgroupInit = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
req.useZeroInitializeWorkgroupMemoryExtension = req.useZeroInitializeWorkgroupMemoryExtension =
GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension); GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension);
req.tracePlatform = UnsafeUnkeyedValue(GetDevice()->GetPlatform()); req.tracePlatform = UnsafeUnkeyedValue(GetDevice()->GetPlatform());
req.substituteOverrideConfig = std::move(substituteOverrideConfig);
const CombinedLimits& limits = GetDevice()->GetLimits();
req.limits = LimitsForCompilationRequest::Create(limits.v1);
CacheResult<Spirv> spirv; CacheResult<Spirv> spirv;
DAWN_TRY_LOAD_OR_RUN( DAWN_TRY_LOAD_OR_RUN(
@ -270,12 +310,27 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>( transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
r.newBindingsMap); r.newBindingsMap);
} }
if (r.substituteOverrideConfig) {
// This needs to run after SingleEntryPoint transform to get rid of overrides not
// used for the current entry point.
transformManager.Add<tint::transform::SubstituteOverride>();
transformInputs.Add<tint::transform::SubstituteOverride::Config>(
std::move(r.substituteOverrideConfig).value());
}
tint::Program program; tint::Program program;
{ {
TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "RunTransforms"); TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "RunTransforms");
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, r.inputProgram, DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, r.inputProgram,
transformInputs, nullptr, nullptr)); transformInputs, nullptr, nullptr));
} }
if (r.stage == SingleShaderStage::Compute) {
// Validate workgroup size after program runs transforms.
Extent3D _;
DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize(
program, r.entryPointName.data(), r.limits));
}
tint::writer::spirv::Options options; tint::writer::spirv::Options options;
options.emit_vertex_point_size = true; options.emit_vertex_point_size = true;
options.disable_workgroup_init = r.disableWorkgroupInit; options.disable_workgroup_init = r.disableWorkgroupInit;

View File

@ -18,14 +18,32 @@
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <optional> #include <optional>
#include <string>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include "dawn/common/HashUtils.h"
#include "dawn/common/vulkan_platform.h" #include "dawn/common/vulkan_platform.h"
#include "dawn/native/Error.h" #include "dawn/native/Error.h"
#include "dawn/native/ShaderModule.h" #include "dawn/native/ShaderModule.h"
namespace dawn::native::vulkan { namespace dawn::native {
struct ProgrammableStage;
namespace vulkan {
struct TransformedShaderModuleCacheKey {
const PipelineLayoutBase* layout;
std::string entryPoint;
PipelineConstantEntries constants;
bool operator==(const TransformedShaderModuleCacheKey& other) const;
};
struct TransformedShaderModuleCacheKeyHashFunc {
size_t operator()(const TransformedShaderModuleCacheKey& key) const;
};
class Device; class Device;
class PipelineLayout; class PipelineLayout;
@ -44,7 +62,8 @@ class ShaderModule final : public ShaderModuleBase {
ShaderModuleParseResult* parseResult, ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages); OwnedCompilationMessages* compilationMessages);
ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(const char* entryPointName, ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(SingleShaderStage stage,
const ProgrammableStage& programmableStage,
const PipelineLayout* layout); const PipelineLayout* layout);
private: private:
@ -59,6 +78,8 @@ class ShaderModule final : public ShaderModuleBase {
std::unique_ptr<ConcurrentTransformedShaderModuleCache> mTransformedShaderModuleCache; std::unique_ptr<ConcurrentTransformedShaderModuleCache> mTransformedShaderModuleCache;
}; };
} // namespace dawn::native::vulkan } // namespace vulkan
} // namespace dawn::native
#endif // SRC_DAWN_NATIVE_VULKAN_SHADERMODULEVK_H_ #endif // SRC_DAWN_NATIVE_VULKAN_SHADERMODULEVK_H_

View File

@ -258,60 +258,4 @@ std::string GetDeviceDebugPrefixFromDebugName(const char* debugName) {
return std::string(debugName, length); return std::string(debugName, length);
} }
VkSpecializationInfo* GetVkSpecializationInfo(
const ProgrammableStage& programmableStage,
VkSpecializationInfo* specializationInfo,
std::vector<OverrideScalar>* specializationDataEntries,
std::vector<VkSpecializationMapEntry>* specializationMapEntries) {
ASSERT(specializationInfo);
ASSERT(specializationDataEntries);
ASSERT(specializationMapEntries);
if (programmableStage.constants.size() == 0) {
return nullptr;
}
const EntryPointMetadata& entryPointMetaData =
programmableStage.module->GetEntryPoint(programmableStage.entryPoint);
for (const auto& pipelineConstant : programmableStage.constants) {
const std::string& identifier = pipelineConstant.first;
double value = pipelineConstant.second;
// This is already validated so `identifier` must exist
const auto& moduleConstant = entryPointMetaData.overrides.at(identifier);
specializationMapEntries->push_back(VkSpecializationMapEntry{
moduleConstant.id,
static_cast<uint32_t>(specializationDataEntries->size() * sizeof(OverrideScalar)),
sizeof(OverrideScalar)});
OverrideScalar entry{};
switch (moduleConstant.type) {
case EntryPointMetadata::Override::Type::Boolean:
entry.b = static_cast<int32_t>(value);
break;
case EntryPointMetadata::Override::Type::Float32:
entry.f32 = static_cast<float>(value);
break;
case EntryPointMetadata::Override::Type::Int32:
entry.i32 = static_cast<int32_t>(value);
break;
case EntryPointMetadata::Override::Type::Uint32:
entry.u32 = static_cast<uint32_t>(value);
break;
default:
UNREACHABLE();
}
specializationDataEntries->push_back(entry);
}
specializationInfo->mapEntryCount = static_cast<uint32_t>(specializationMapEntries->size());
specializationInfo->pMapEntries = specializationMapEntries->data();
specializationInfo->dataSize = specializationDataEntries->size() * sizeof(OverrideScalar);
specializationInfo->pData = specializationDataEntries->data();
return specializationInfo;
}
} // namespace dawn::native::vulkan } // namespace dawn::native::vulkan

View File

@ -144,15 +144,6 @@ void SetDebugName(Device* device,
std::string GetNextDeviceDebugPrefix(); std::string GetNextDeviceDebugPrefix();
std::string GetDeviceDebugPrefixFromDebugName(const char* debugName); std::string GetDeviceDebugPrefixFromDebugName(const char* debugName);
// Returns nullptr or &specializationInfo
// specializationInfo, specializationDataEntries, specializationMapEntries needs to
// be alive at least until VkSpecializationInfo is passed into Vulkan Create*Pipelines
VkSpecializationInfo* GetVkSpecializationInfo(
const ProgrammableStage& programmableStage,
VkSpecializationInfo* specializationInfo,
std::vector<OverrideScalar>* specializationDataEntries,
std::vector<VkSpecializationMapEntry>* specializationMapEntries);
} // namespace dawn::native::vulkan } // namespace dawn::native::vulkan
#endif // SRC_DAWN_NATIVE_VULKAN_UTILSVULKAN_H_ #endif // SRC_DAWN_NATIVE_VULKAN_UTILSVULKAN_H_

View File

@ -265,6 +265,7 @@ dawn_test("dawn_unittests") {
"unittests/native/CreatePipelineAsyncTaskTests.cpp", "unittests/native/CreatePipelineAsyncTaskTests.cpp",
"unittests/native/DestroyObjectTests.cpp", "unittests/native/DestroyObjectTests.cpp",
"unittests/native/DeviceCreationTests.cpp", "unittests/native/DeviceCreationTests.cpp",
"unittests/native/ObjectContentHasherTests.cpp",
"unittests/native/StreamTests.cpp", "unittests/native/StreamTests.cpp",
"unittests/validation/BindGroupValidationTests.cpp", "unittests/validation/BindGroupValidationTests.cpp",
"unittests/validation/BufferValidationTests.cpp", "unittests/validation/BufferValidationTests.cpp",
@ -489,6 +490,7 @@ source_set("end2end_tests_sources") {
"end2end/ScissorTests.cpp", "end2end/ScissorTests.cpp",
"end2end/ShaderFloat16Tests.cpp", "end2end/ShaderFloat16Tests.cpp",
"end2end/ShaderTests.cpp", "end2end/ShaderTests.cpp",
"end2end/ShaderValidationTests.cpp",
"end2end/StorageTextureTests.cpp", "end2end/StorageTextureTests.cpp",
"end2end/SubresourceRenderAttachmentTests.cpp", "end2end/SubresourceRenderAttachmentTests.cpp",
"end2end/Texture3DTests.cpp", "end2end/Texture3DTests.cpp",

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <vector>
#include "dawn/tests/DawnTest.h" #include "dawn/tests/DawnTest.h"
#include "dawn/utils/ComboRenderPipelineDescriptor.h" #include "dawn/utils/ComboRenderPipelineDescriptor.h"
@ -158,6 +160,46 @@ TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnShaderModule) {
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire()); EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
} }
// Test that ComputePipeline are correctly deduplicated wrt. their constants override values
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnOverrides) {
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
override x: u32 = 1u;
var<workgroup> i : u32;
@compute @workgroup_size(x) fn main() {
i = 0u;
})");
wgpu::PipelineLayout layout = utils::MakeBasicPipelineLayout(device, nullptr);
wgpu::ComputePipelineDescriptor desc;
desc.compute.entryPoint = "main";
desc.layout = layout;
desc.compute.module = module;
std::vector<wgpu::ConstantEntry> constants{{nullptr, "x", 16}};
desc.compute.constantCount = constants.size();
desc.compute.constants = constants.data();
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&desc);
std::vector<wgpu::ConstantEntry> sameConstants{{nullptr, "x", 16}};
desc.compute.constantCount = sameConstants.size();
desc.compute.constants = sameConstants.data();
wgpu::ComputePipeline samePipeline = device.CreateComputePipeline(&desc);
desc.compute.constantCount = 0;
desc.compute.constants = nullptr;
wgpu::ComputePipeline otherPipeline1 = device.CreateComputePipeline(&desc);
std::vector<wgpu::ConstantEntry> otherConstants{{nullptr, "x", 4}};
desc.compute.constantCount = otherConstants.size();
desc.compute.constants = otherConstants.data();
wgpu::ComputePipeline otherPipeline2 = device.CreateComputePipeline(&desc);
EXPECT_NE(pipeline.Get(), otherPipeline1.Get());
EXPECT_NE(pipeline.Get(), otherPipeline2.Get());
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
}
// Test that ComputePipeline are correctly deduplicated wrt. their layout // Test that ComputePipeline are correctly deduplicated wrt. their layout
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) { TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) {
wgpu::BindGroupLayout bgl = utils::MakeBindGroupLayout( wgpu::BindGroupLayout bgl = utils::MakeBindGroupLayout(
@ -303,6 +345,48 @@ TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnFragmentModule) {
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire()); EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
} }
// Test that Renderpipelines are correctly deduplicated wrt. their constants override values
TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnOverrides) {
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
override a: f32 = 1.0;
@vertex fn vertexMain() -> @builtin(position) vec4<f32> {
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
@fragment fn fragmentMain() -> @location(0) vec4<f32> {
return vec4<f32>(0.0, 0.0, 0.0, a);
})");
utils::ComboRenderPipelineDescriptor desc;
desc.vertex.module = module;
desc.vertex.entryPoint = "vertexMain";
desc.cFragment.module = module;
desc.cFragment.entryPoint = "fragmentMain";
desc.cTargets[0].writeMask = wgpu::ColorWriteMask::None;
std::vector<wgpu::ConstantEntry> constants{{nullptr, "a", 0.5}};
desc.cFragment.constantCount = constants.size();
desc.cFragment.constants = constants.data();
wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&desc);
std::vector<wgpu::ConstantEntry> sameConstants{{nullptr, "a", 0.5}};
desc.cFragment.constantCount = sameConstants.size();
desc.cFragment.constants = sameConstants.data();
wgpu::RenderPipeline samePipeline = device.CreateRenderPipeline(&desc);
std::vector<wgpu::ConstantEntry> otherConstants{{nullptr, "a", 1.0}};
desc.cFragment.constantCount = otherConstants.size();
desc.cFragment.constants = otherConstants.data();
wgpu::RenderPipeline otherPipeline1 = device.CreateRenderPipeline(&desc);
desc.cFragment.constantCount = 0;
desc.cFragment.constants = nullptr;
wgpu::RenderPipeline otherPipeline2 = device.CreateRenderPipeline(&desc);
EXPECT_NE(pipeline.Get(), otherPipeline1.Get());
EXPECT_NE(pipeline.Get(), otherPipeline2.Get());
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
}
// Test that Samplers are correctly deduplicated. // Test that Samplers are correctly deduplicated.
TEST_P(ObjectCachingTest, SamplerDeduplication) { TEST_P(ObjectCachingTest, SamplerDeduplication) {
wgpu::SamplerDescriptor samplerDesc; wgpu::SamplerDescriptor samplerDesc;

View File

@ -471,6 +471,127 @@ struct Buf {
EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), buffer, 0, kCount); EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), buffer, 0, kCount);
} }
// Test one shader shared by two pipelines with different constants overridden
TEST_P(ShaderTests, OverridableConstantsSharedShader) {
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
std::vector<uint32_t> expected1{1};
wgpu::Buffer buffer1 = CreateBuffer(expected1.size());
std::vector<uint32_t> expected2{2};
wgpu::Buffer buffer2 = CreateBuffer(expected2.size());
std::string shader = R"(
override a: u32;
struct Buf {
data : array<u32, 1>
}
@group(0) @binding(0) var<storage, read_write> buf : Buf;
@compute @workgroup_size(1) fn main() {
buf.data[0] = a;
})";
std::vector<wgpu::ConstantEntry> constants1;
constants1.push_back({nullptr, "a", 1});
std::vector<wgpu::ConstantEntry> constants2;
constants2.push_back({nullptr, "a", 2});
wgpu::ComputePipeline pipeline1 = CreateComputePipeline(shader, "main", &constants1);
wgpu::ComputePipeline pipeline2 = CreateComputePipeline(shader, "main", &constants2);
wgpu::BindGroup bindGroup1 =
utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}});
wgpu::BindGroup bindGroup2 =
utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}});
wgpu::CommandBuffer commands;
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(pipeline1);
pass.SetBindGroup(0, bindGroup1);
pass.DispatchWorkgroups(1);
pass.SetPipeline(pipeline2);
pass.SetBindGroup(0, bindGroup2);
pass.DispatchWorkgroups(1);
pass.End();
commands = encoder.Finish();
}
queue.Submit(1, &commands);
EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, expected1.size());
EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, expected2.size());
}
// Test overridable constants work with workgroup size
TEST_P(ShaderTests, OverridableConstantsWorkgroupSize) {
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
std::string shader = R"(
override x: u32;
struct Buf {
data : array<u32, 1>
}
@group(0) @binding(0) var<storage, read_write> buf : Buf;
@compute @workgroup_size(x) fn main(
@builtin(local_invocation_id) local_invocation_id : vec3<u32>
) {
if (local_invocation_id.x >= x - 1) {
buf.data[0] = local_invocation_id.x + 1;
}
})";
const uint32_t workgroup_size_x_1 = 16u;
const uint32_t workgroup_size_x_2 = 64u;
std::vector<uint32_t> expected1{workgroup_size_x_1};
wgpu::Buffer buffer1 = CreateBuffer(expected1.size());
std::vector<uint32_t> expected2{workgroup_size_x_2};
wgpu::Buffer buffer2 = CreateBuffer(expected2.size());
std::vector<wgpu::ConstantEntry> constants1;
constants1.push_back({nullptr, "x", static_cast<double>(workgroup_size_x_1)});
std::vector<wgpu::ConstantEntry> constants2;
constants2.push_back({nullptr, "x", static_cast<double>(workgroup_size_x_2)});
wgpu::ComputePipeline pipeline1 = CreateComputePipeline(shader, "main", &constants1);
wgpu::ComputePipeline pipeline2 = CreateComputePipeline(shader, "main", &constants2);
wgpu::BindGroup bindGroup1 =
utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}});
wgpu::BindGroup bindGroup2 =
utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}});
wgpu::CommandBuffer commands;
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(pipeline1);
pass.SetBindGroup(0, bindGroup1);
pass.DispatchWorkgroups(1);
pass.SetPipeline(pipeline2);
pass.SetBindGroup(0, bindGroup2);
pass.DispatchWorkgroups(1);
pass.End();
commands = encoder.Finish();
}
queue.Submit(1, &commands);
EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, expected1.size());
EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, expected2.size());
}
// Test overridable constants with numeric identifiers // Test overridable constants with numeric identifiers
TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) { TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL()); DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
@ -596,6 +717,7 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
std::string shader = R"( std::string shader = R"(
@id(1001) override c1: u32; @id(1001) override c1: u32;
@id(1002) override c2: u32; @id(1002) override c2: u32;
@id(1003) override c3: u32;
struct Buf { struct Buf {
data : array<u32, 1> data : array<u32, 1>
@ -611,7 +733,7 @@ struct Buf {
buf.data[0] = c2; buf.data[0] = c2;
} }
@compute @workgroup_size(1) fn main3() { @compute @workgroup_size(c3) fn main3() {
buf.data[0] = 3u; buf.data[0] = 3u;
} }
)"; )";
@ -620,6 +742,8 @@ struct Buf {
constants1.push_back({nullptr, "1001", 1}); constants1.push_back({nullptr, "1001", 1});
std::vector<wgpu::ConstantEntry> constants2; std::vector<wgpu::ConstantEntry> constants2;
constants2.push_back({nullptr, "1002", 2}); constants2.push_back({nullptr, "1002", 2});
std::vector<wgpu::ConstantEntry> constants3;
constants3.push_back({nullptr, "1003", 1});
wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, shader.c_str()); wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, shader.c_str());
@ -640,6 +764,8 @@ struct Buf {
wgpu::ComputePipelineDescriptor csDesc3; wgpu::ComputePipelineDescriptor csDesc3;
csDesc3.compute.module = shaderModule; csDesc3.compute.module = shaderModule;
csDesc3.compute.entryPoint = "main3"; csDesc3.compute.entryPoint = "main3";
csDesc3.compute.constants = constants3.data();
csDesc3.compute.constantCount = constants3.size();
wgpu::ComputePipeline pipeline3 = device.CreateComputePipeline(&csDesc3); wgpu::ComputePipeline pipeline3 = device.CreateComputePipeline(&csDesc3);
wgpu::BindGroup bindGroup1 = wgpu::BindGroup bindGroup1 =
@ -765,8 +891,6 @@ TEST_P(ShaderTests, ConflictingBindingsDueToTransformOrder) {
device.CreateRenderPipeline(&desc); device.CreateRenderPipeline(&desc);
} }
// TODO(tint:1155): Test overridable constants used for workgroup size
DAWN_INSTANTIATE_TEST(ShaderTests, DAWN_INSTANTIATE_TEST(ShaderTests,
D3D12Backend(), D3D12Backend(),
MetalBackend(), MetalBackend(),

View File

@ -0,0 +1,395 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <numeric>
#include <string>
#include <vector>
#include "dawn/tests/DawnTest.h"
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
#include "dawn/utils/WGPUHelpers.h"
// The compute shader workgroup size is settled at compute pipeline creation time.
// The validation code in dawn is in each backend (not including Null backend) thus this test needs
// to be as part of a dawn_end2end_tests instead of the dawn_unittests
// TODO(dawn:1504): Add support for GL backend.
class WorkgroupSizeValidationTest : public DawnTest {
public:
wgpu::ShaderModule SetUpShaderWithValidDefaultValueConstants() {
return utils::CreateShaderModule(device, R"(
override x: u32 = 1u;
override y: u32 = 1u;
override z: u32 = 1u;
@compute @workgroup_size(x, y, z) fn main() {
_ = 0u;
})");
}
wgpu::ShaderModule SetUpShaderWithZeroDefaultValueConstants() {
return utils::CreateShaderModule(device, R"(
override x: u32 = 0u;
override y: u32 = 0u;
override z: u32 = 0u;
@compute @workgroup_size(x, y, z) fn main() {
_ = 0u;
})");
}
wgpu::ShaderModule SetUpShaderWithOutOfLimitsDefaultValueConstants() {
return utils::CreateShaderModule(device, R"(
override x: u32 = 1u;
override y: u32 = 1u;
override z: u32 = 9999u;
@compute @workgroup_size(x, y, z) fn main() {
_ = 0u;
})");
}
wgpu::ShaderModule SetUpShaderWithUninitializedConstants() {
return utils::CreateShaderModule(device, R"(
override x: u32;
override y: u32;
override z: u32;
@compute @workgroup_size(x, y, z) fn main() {
_ = 0u;
})");
}
wgpu::ShaderModule SetUpShaderWithPartialConstants() {
return utils::CreateShaderModule(device, R"(
override x: u32;
@compute @workgroup_size(x, 1, 1) fn main() {
_ = 0u;
})");
}
void TestCreatePipeline(const wgpu::ShaderModule& module) {
wgpu::ComputePipelineDescriptor csDesc;
csDesc.compute.module = module;
csDesc.compute.entryPoint = "main";
csDesc.compute.constants = nullptr;
csDesc.compute.constantCount = 0;
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
}
void TestCreatePipeline(const wgpu::ShaderModule& module,
const std::vector<wgpu::ConstantEntry>& constants) {
wgpu::ComputePipelineDescriptor csDesc;
csDesc.compute.module = module;
csDesc.compute.entryPoint = "main";
csDesc.compute.constants = constants.data();
csDesc.compute.constantCount = constants.size();
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
}
void TestInitializedWithZero(const wgpu::ShaderModule& module) {
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "x", 0}, {nullptr, "y", 0}, {nullptr, "z", 0}};
TestCreatePipeline(module, constants);
}
void TestInitializedWithOutOfLimitValue(const wgpu::ShaderModule& module) {
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "x",
static_cast<double>(GetSupportedLimits().limits.maxComputeWorkgroupSizeX + 1)},
{nullptr, "y", 1},
{nullptr, "z", 1}};
TestCreatePipeline(module, constants);
}
void TestInitializedWithValidValue(const wgpu::ShaderModule& module) {
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "x", 4}, {nullptr, "y", 4}, {nullptr, "z", 4}};
TestCreatePipeline(module, constants);
}
void TestInitializedPartially(const wgpu::ShaderModule& module) {
std::vector<wgpu::ConstantEntry> constants{{nullptr, "y", 4}};
TestCreatePipeline(module, constants);
}
wgpu::Buffer buffer;
};
// Test workgroup size validation with fixed values.
TEST_P(WorkgroupSizeValidationTest, WithFixedValues) {
auto CheckShaderWithWorkgroupSize = [this](bool success, uint32_t x, uint32_t y, uint32_t z) {
std::ostringstream ss;
ss << "@compute @workgroup_size(" << x << "," << y << "," << z << ") fn main() {}";
wgpu::ComputePipelineDescriptor desc;
desc.compute.entryPoint = "main";
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
if (success) {
device.CreateComputePipeline(&desc);
} else {
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
}
};
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
CheckShaderWithWorkgroupSize(true, 1, 1, 1);
CheckShaderWithWorkgroupSize(true, supportedLimits.maxComputeWorkgroupSizeX, 1, 1);
CheckShaderWithWorkgroupSize(true, 1, supportedLimits.maxComputeWorkgroupSizeY, 1);
CheckShaderWithWorkgroupSize(true, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ);
CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX + 1, 1, 1);
CheckShaderWithWorkgroupSize(false, 1, supportedLimits.maxComputeWorkgroupSizeY + 1, 1);
CheckShaderWithWorkgroupSize(false, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ + 1);
// No individual dimension exceeds its limit, but the combined size should definitely exceed the
// total invocation limit.
DAWN_ASSERT(supportedLimits.maxComputeWorkgroupSizeX *
supportedLimits.maxComputeWorkgroupSizeY *
supportedLimits.maxComputeWorkgroupSizeZ >
supportedLimits.maxComputeInvocationsPerWorkgroup);
CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX,
supportedLimits.maxComputeWorkgroupSizeY,
supportedLimits.maxComputeWorkgroupSizeZ);
}
// Test workgroup size validation with fixed values (storage size limits validation).
TEST_P(WorkgroupSizeValidationTest, WithFixedValuesStorageSizeLimits) {
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
constexpr uint32_t kVec4Size = 16;
const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size;
constexpr uint32_t kMat4Size = 64;
const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size;
auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count,
uint32_t mat4_count) {
std::ostringstream ss;
std::ostringstream body;
if (vec4_count > 0) {
ss << "var<workgroup> vec4_data: array<vec4<f32>, " << vec4_count << ">;";
body << "_ = vec4_data;";
}
if (mat4_count > 0) {
ss << "var<workgroup> mat4_data: array<mat4x4<f32>, " << mat4_count << ">;";
body << "_ = mat4_data;";
}
ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }";
wgpu::ComputePipelineDescriptor desc;
desc.compute.entryPoint = "main";
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
if (success) {
device.CreateComputePipeline(&desc);
} else {
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
}
};
CheckPipelineWithWorkgroupStorage(true, 1, 1);
CheckPipelineWithWorkgroupStorage(true, maxVec4Count, 0);
CheckPipelineWithWorkgroupStorage(true, 0, maxMat4Count);
CheckPipelineWithWorkgroupStorage(true, maxVec4Count - 4, 1);
CheckPipelineWithWorkgroupStorage(true, 4, maxMat4Count - 1);
CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0);
CheckPipelineWithWorkgroupStorage(false, maxVec4Count - 3, 1);
CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1);
CheckPipelineWithWorkgroupStorage(false, 4, maxMat4Count);
}
// Test workgroup size validation with valid overrides default values.
TEST_P(WorkgroupSizeValidationTest, OverridesWithValidDefault) {
wgpu::ShaderModule module = SetUpShaderWithValidDefaultValueConstants();
{
// Valid default
TestCreatePipeline(module);
}
{
// Error: invalid value (zero)
ASSERT_DEVICE_ERROR(TestInitializedWithZero(module));
}
{
// Error: invalid value (out of device limits)
ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module));
}
{
// Valid: initialized partially
TestInitializedPartially(module);
}
{
// Valid
TestInitializedWithValidValue(module);
}
}
// Test workgroup size validation with zero as the overrides default values.
TEST_P(WorkgroupSizeValidationTest, OverridesWithZeroDefault) {
// Error: zero is detected as invalid at shader creation time
ASSERT_DEVICE_ERROR(SetUpShaderWithZeroDefaultValueConstants());
}
// Test workgroup size validation with out-of-limits overrides default values.
// An out-of-limit default is only an error if the dimension is not overridden
// to a valid value at pipeline creation time.
TEST_P(WorkgroupSizeValidationTest, OverridesWithOutOfLimitsDefault) {
wgpu::ShaderModule module = SetUpShaderWithOutOfLimitsDefaultValueConstants();
{
// Error: invalid default (no constants provided, defaults exceed limits)
ASSERT_DEVICE_ERROR(TestCreatePipeline(module));
}
{
// Error: invalid value (zero)
ASSERT_DEVICE_ERROR(TestInitializedWithZero(module));
}
{
// Error: invalid value (out of device limits)
ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module));
}
{
// Error: initialized partially (remaining dimensions keep invalid defaults)
ASSERT_DEVICE_ERROR(TestInitializedPartially(module));
}
{
// Valid: every dimension overridden with an in-limit value
TestInitializedWithValidValue(module);
}
}
// Test workgroup size validation without overrides default values specified.
// Every dimension must then be supplied via pipeline constants.
TEST_P(WorkgroupSizeValidationTest, OverridesWithUninitialized) {
wgpu::ShaderModule module = SetUpShaderWithUninitializedConstants();
{
// Error: uninitialized (no constants provided for overrides without defaults)
ASSERT_DEVICE_ERROR(TestCreatePipeline(module));
}
{
// Error: invalid value (zero)
ASSERT_DEVICE_ERROR(TestInitializedWithZero(module));
}
{
// Error: invalid value (out of device limits)
ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module));
}
{
// Error: initialized partially (some dimensions remain uninitialized)
ASSERT_DEVICE_ERROR(TestInitializedPartially(module));
}
{
// Valid: every dimension overridden with an in-limit value
TestInitializedWithValidValue(module);
}
}
// Test workgroup size validation when only some dimensions are overrides
// (the others are literal constants in the shader).
TEST_P(WorkgroupSizeValidationTest, PartialOverrides) {
wgpu::ShaderModule module = SetUpShaderWithPartialConstants();
{
// Error: uninitialized ("x" has no default and no constant is provided)
ASSERT_DEVICE_ERROR(TestCreatePipeline(module));
}
{
// Error: invalid value (zero)
std::vector<wgpu::ConstantEntry> constants{{nullptr, "x", 0}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
}
{
// Error: invalid value (out of device limits)
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "x",
static_cast<double>(GetSupportedLimits().limits.maxComputeWorkgroupSizeX + 1)}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
}
{
// Valid: "x" overridden with an in-limit value
std::vector<wgpu::ConstantEntry> constants{{nullptr, "x", 16}};
TestCreatePipeline(module, constants);
}
}
// Test workgroup size validation after being overridden with invalid values.
// Limits are checked against the final (post-override) workgroup size at
// compute pipeline creation time.
TEST_P(WorkgroupSizeValidationTest, ValidationAfterOverride) {
wgpu::ShaderModule module = SetUpShaderWithUninitializedConstants();
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
{
// Error: exceed maxComputeWorkgroupSizeZ
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "x", 1},
{nullptr, "y", 1},
{nullptr, "z", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeZ + 1)}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
}
{
// Error: exceed maxComputeInvocationsPerWorkgroup
// (each dimension is at its own limit, so the product must exceed the
// total invocation limit; the assert documents that precondition)
DAWN_ASSERT(supportedLimits.maxComputeWorkgroupSizeX *
supportedLimits.maxComputeWorkgroupSizeY *
supportedLimits.maxComputeWorkgroupSizeZ >
supportedLimits.maxComputeInvocationsPerWorkgroup);
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "x", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeX)},
{nullptr, "y", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeY)},
{nullptr, "z", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeZ)}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
}
}
// Test workgroup size validation after being overridden with invalid values (storage size limits
// validation). The override values "a"/"b" size workgroup-storage arrays, so the
// storage-size limit can only be checked after substitution at pipeline creation.
// TODO(tint:1660): re-enable after override can be used as array size.
TEST_P(WorkgroupSizeValidationTest, DISABLED_ValidationAfterOverrideStorageSize) {
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
// vec4<f32> occupies 16 bytes of workgroup storage.
constexpr uint32_t kVec4Size = 16;
const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size;
// mat4x4<f32> occupies 64 bytes of workgroup storage.
constexpr uint32_t kMat4Size = 64;
const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size;
// Builds a compute pipeline whose workgroup arrays are sized by overrides
// "a" (vec4 count) and "b" (mat4 count), and checks the expected outcome.
auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count,
uint32_t mat4_count) {
std::ostringstream ss;
std::ostringstream body;
ss << "override a: u32;";
ss << "override b: u32;";
if (vec4_count > 0) {
ss << "var<workgroup> vec4_data: array<vec4<f32>, a>;";
body << "_ = vec4_data;";
}
if (mat4_count > 0) {
ss << "var<workgroup> mat4_data: array<mat4x4<f32>, b>;";
body << "_ = mat4_data;";
}
ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }";
wgpu::ComputePipelineDescriptor desc;
desc.compute.entryPoint = "main";
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
std::vector<wgpu::ConstantEntry> constants{{nullptr, "a", static_cast<double>(vec4_count)},
{nullptr, "b", static_cast<double>(mat4_count)}};
desc.compute.constants = constants.data();
desc.compute.constantCount = constants.size();
if (success) {
device.CreateComputePipeline(&desc);
} else {
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
}
};
// Error: each case exceeds maxComputeWorkgroupStorageSize by one element.
CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0);
CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1);
}
// Instantiate the parameterized workgroup-size tests on the D3D12, Metal and Vulkan backends.
DAWN_INSTANTIATE_TEST(WorkgroupSizeValidationTest, D3D12Backend(), MetalBackend(), VulkanBackend());

View File

@ -0,0 +1,81 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "dawn/native/ObjectContentHasher.h"
#include "dawn/tests/DawnNativeTest.h"
namespace dawn::native {
namespace {
// Fixture for ObjectContentHasher tests; DawnNativeTest supplies the test environment.
class ObjectContentHasherTests : public DawnNativeTest {};
// Records `a` and `b` into two fresh hashers and expects their content hashes
// to compare equal iff `eq` is true. (No comments inside the macro body: a //
// comment would swallow the line-continuation backslash.)
#define EXPECT_IF_HASH_EQ(eq, a, b) \
do { \
ObjectContentHasher ra, rb; \
ra.Record(a); \
rb.Record(b); \
EXPECT_EQ(eq, ra.GetContentHash() == rb.GetContentHash()); \
} while (0)
// Hashing std::pair: equal pairs collide, and any difference in element
// order, first element, or second element changes the hash.
TEST(ObjectContentHasherTests, Pair) {
    using StrU8 = std::pair<std::string, uint8_t>;
    using U8Str = std::pair<uint8_t, std::string>;
    // Identical pairs hash identically.
    EXPECT_IF_HASH_EQ(true, (StrU8{"a", 1}), (StrU8{"a", 1}));
    // Swapped element types hash differently.
    EXPECT_IF_HASH_EQ(false, (U8Str{1, "a"}), (StrU8{"a", 1}));
    // Differing second elements hash differently.
    EXPECT_IF_HASH_EQ(false, (StrU8{"a", 1}), (StrU8{"a", 2}));
    // Differing first elements hash differently.
    EXPECT_IF_HASH_EQ(false, (StrU8{"a", 1}), (StrU8{"b", 1}));
}
// Hashing std::vector: the hash depends on element values, order, length,
// and element type.
TEST(ObjectContentHasherTests, Vector) {
EXPECT_IF_HASH_EQ(true, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{0, 1}));
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{0, 1, 2}));
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{1, 0}));
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{}));
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<float>{0, 1}));
}
// Hashing std::map: insertion order is irrelevant (std::map iterates in key
// order), but differing values, extra entries, or emptiness change the hash.
TEST(ObjectContentHasherTests, Map) {
EXPECT_IF_HASH_EQ(true, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
(std::map<std::string, uint8_t>{{"b", 2}, {"a", 1}}));
EXPECT_IF_HASH_EQ(false, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
(std::map<std::string, uint8_t>{{"a", 2}, {"b", 1}}));
EXPECT_IF_HASH_EQ(false, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
(std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}, {"c", 1}}));
EXPECT_IF_HASH_EQ(false, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
(std::map<std::string, uint8_t>{}));
}
// Recording the same two values in a different order must yield a different
// content hash: hash combining is order-sensitive.
TEST(ObjectContentHasherTests, HashCombine) {
    ObjectContentHasher vectorFirst, mapFirst;
    vectorFirst.Record(std::vector<uint8_t>{0, 1});
    vectorFirst.Record(std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}});
    mapFirst.Record(std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}});
    mapFirst.Record(std::vector<uint8_t>{0, 1});
    EXPECT_NE(vectorFirst.GetContentHash(), mapFirst.GetContentHash());
}
} // namespace
} // namespace dawn::native

View File

@ -450,87 +450,6 @@ TEST_F(ShaderModuleValidationTest, MaximumInterStageShaderComponents) {
} }
} }
// Tests that we validate workgroup size limits.
TEST_F(ShaderModuleValidationTest, ComputeWorkgroupSizeLimits) {
    // Builds a trivial compute shader with the given workgroup size and checks
    // whether compute pipeline creation succeeds or raises a device error.
    auto CheckShaderWithWorkgroupSize = [this](bool expectSuccess, uint32_t sizeX, uint32_t sizeY,
                                               uint32_t sizeZ) {
        std::ostringstream shader;
        shader << "@compute @workgroup_size(" << sizeX << "," << sizeY << "," << sizeZ
               << ") fn main() {}";
        wgpu::ComputePipelineDescriptor desc;
        desc.compute.entryPoint = "main";
        desc.compute.module = utils::CreateShaderModule(device, shader.str().c_str());
        if (expectSuccess) {
            device.CreateComputePipeline(&desc);
        } else {
            ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
        }
    };

    wgpu::Limits supportedLimits = GetSupportedLimits().limits;

    // Valid: each dimension may go up to its own limit while the others stay at 1.
    CheckShaderWithWorkgroupSize(true, 1, 1, 1);
    CheckShaderWithWorkgroupSize(true, supportedLimits.maxComputeWorkgroupSizeX, 1, 1);
    CheckShaderWithWorkgroupSize(true, 1, supportedLimits.maxComputeWorkgroupSizeY, 1);
    CheckShaderWithWorkgroupSize(true, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ);

    // Error: exceeding any single dimension's limit.
    CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX + 1, 1, 1);
    CheckShaderWithWorkgroupSize(false, 1, supportedLimits.maxComputeWorkgroupSizeY + 1, 1);
    CheckShaderWithWorkgroupSize(false, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ + 1);

    // No individual dimension exceeds its limit, but the combined size should definitely exceed
    // the total invocation limit.
    CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX,
                                 supportedLimits.maxComputeWorkgroupSizeY,
                                 supportedLimits.maxComputeWorkgroupSizeZ);
}
// Tests that we validate workgroup storage size limits.
TEST_F(ShaderModuleValidationTest, ComputeWorkgroupStorageSizeLimits) {
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
// vec4<f32> occupies 16 bytes of workgroup storage.
constexpr uint32_t kVec4Size = 16;
const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size;
// mat4x4<f32> occupies 64 bytes of workgroup storage (4 vec4 columns).
constexpr uint32_t kMat4Size = 64;
const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size;
// Builds a compute pipeline declaring `vec4_count` vec4s and `mat4_count`
// mat4s of workgroup storage, and checks the expected creation outcome.
auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count,
uint32_t mat4_count) {
std::ostringstream ss;
std::ostringstream body;
if (vec4_count > 0) {
ss << "var<workgroup> vec4_data: array<vec4<f32>, " << vec4_count << ">;";
body << "_ = vec4_data;";
}
if (mat4_count > 0) {
ss << "var<workgroup> mat4_data: array<mat4x4<f32>, " << mat4_count << ">;";
body << "_ = mat4_data;";
}
ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }";
wgpu::ComputePipelineDescriptor desc;
desc.compute.entryPoint = "main";
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
if (success) {
device.CreateComputePipeline(&desc);
} else {
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
}
};
// Valid: at or under the storage limit (one mat4 == four vec4s).
CheckPipelineWithWorkgroupStorage(true, 1, 1);
CheckPipelineWithWorkgroupStorage(true, maxVec4Count, 0);
CheckPipelineWithWorkgroupStorage(true, 0, maxMat4Count);
CheckPipelineWithWorkgroupStorage(true, maxVec4Count - 4, 1);
CheckPipelineWithWorkgroupStorage(true, 4, maxMat4Count - 1);
// Error: each case exceeds the storage limit by at least one vec4.
CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0);
CheckPipelineWithWorkgroupStorage(false, maxVec4Count - 3, 1);
CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1);
CheckPipelineWithWorkgroupStorage(false, 4, maxMat4Count);
}
// Test that numeric ID must be unique // Test that numeric ID must be unique
TEST_F(ShaderModuleValidationTest, OverridableConstantsNumericIDConflicts) { TEST_F(ShaderModuleValidationTest, OverridableConstantsNumericIDConflicts) {
ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"( ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"(

View File

@ -37,46 +37,6 @@ class UnsafeAPIValidationTest : public ValidationTest {
} }
}; };
// Check that pipeline overridable constants are disallowed as part of unsafe APIs.
// TODO(dawn:1041) Remove when implementation for all backend is added
TEST_F(UnsafeAPIValidationTest, PipelineOverridableConstants) {
// Create the placeholder compute pipeline.
wgpu::ComputePipelineDescriptor pipelineDescBase;
pipelineDescBase.compute.entryPoint = "main";
// Control case: shader without overridable constant is allowed.
{
wgpu::ComputePipelineDescriptor pipelineDesc = pipelineDescBase;
pipelineDesc.compute.module =
utils::CreateShaderModule(device, "@compute @workgroup_size(1) fn main() {}");
device.CreateComputePipeline(&pipelineDesc);
}
// Error case: shader with overridable constant with default value
// (rejected already at shader module creation time)
{
ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"(
@id(1000) override c0: u32 = 1u;
@id(1000) override c1: u32;
@compute @workgroup_size(1) fn main() {
_ = c0;
_ = c1;
})"));
}
// Error case: pipeline stage with constant entry is disallowed
// (even when the shader itself declares no overrides)
{
wgpu::ComputePipelineDescriptor pipelineDesc = pipelineDescBase;
pipelineDesc.compute.module =
utils::CreateShaderModule(device, "@compute @workgroup_size(1) fn main() {}");
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c", 1u}};
pipelineDesc.compute.constants = constants.data();
pipelineDesc.compute.constantCount = constants.size();
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&pipelineDesc));
}
}
class UnsafeQueryAPIValidationTest : public ValidationTest { class UnsafeQueryAPIValidationTest : public ValidationTest {
protected: protected:
WGPUDevice CreateTestDevice(dawn::native::Adapter dawnAdapter) override { WGPUDevice CreateTestDevice(dawn::native::Adapter dawnAdapter) override {

View File

@ -133,6 +133,110 @@ Inspector::Inspector(const Program* program) : program_(program) {}
Inspector::~Inspector() = default; Inspector::~Inspector() = default;
// Builds the reflection information for a single entry-point function.
// `func` must be non-null and an entry point (asserted below).
EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) {
EntryPoint entry_point;
TINT_ASSERT(Inspector, func != nullptr);
TINT_ASSERT(Inspector, func->IsEntryPoint());
auto* sem = program_->Sem().Get(func);
entry_point.name = program_->Symbols().NameFor(func->symbol);
entry_point.remapped_name = program_->Symbols().NameFor(func->symbol);
switch (func->PipelineStage()) {
case ast::PipelineStage::kCompute: {
entry_point.stage = PipelineStage::kCompute;
// Only report a concrete workgroup size when no dimension depends on an
// overridable constant; otherwise workgroup_size stays unset.
auto wgsize = sem->WorkgroupSize();
if (!wgsize[0].overridable_const && !wgsize[1].overridable_const &&
!wgsize[2].overridable_const) {
entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value, wgsize[2].value};
}
break;
}
case ast::PipelineStage::kFragment: {
entry_point.stage = PipelineStage::kFragment;
break;
}
case ast::PipelineStage::kVertex: {
entry_point.stage = PipelineStage::kVertex;
break;
}
default: {
TINT_UNREACHABLE(Inspector, diagnostics_)
<< "invalid pipeline stage for entry point '" << entry_point.name << "'";
break;
}
}
// Collect input variables and track which input builtins the stage consumes.
for (auto* param : sem->Parameters()) {
AddEntryPointInOutVariables(program_->Symbols().NameFor(param->Declaration()->symbol),
param->Type(), param->Declaration()->attributes,
entry_point.input_variables);
entry_point.input_position_used |= ContainsBuiltin(
ast::BuiltinValue::kPosition, param->Type(), param->Declaration()->attributes);
entry_point.front_facing_used |= ContainsBuiltin(
ast::BuiltinValue::kFrontFacing, param->Type(), param->Declaration()->attributes);
entry_point.sample_index_used |= ContainsBuiltin(
ast::BuiltinValue::kSampleIndex, param->Type(), param->Declaration()->attributes);
entry_point.input_sample_mask_used |= ContainsBuiltin(
ast::BuiltinValue::kSampleMask, param->Type(), param->Declaration()->attributes);
entry_point.num_workgroups_used |= ContainsBuiltin(
ast::BuiltinValue::kNumWorkgroups, param->Type(), param->Declaration()->attributes);
}
// A non-void return value is reported as the "<retval>" output variable.
if (!sem->ReturnType()->Is<sem::Void>()) {
AddEntryPointInOutVariables("<retval>", sem->ReturnType(), func->return_type_attributes,
entry_point.output_variables);
entry_point.output_sample_mask_used = ContainsBuiltin(
ast::BuiltinValue::kSampleMask, sem->ReturnType(), func->return_type_attributes);
}
// Record every override (pipeline-overridable constant) transitively
// referenced by this entry point.
for (auto* var : sem->TransitivelyReferencedGlobals()) {
auto* decl = var->Declaration();
auto name = program_->Symbols().NameFor(decl->symbol);
auto* global = var->As<sem::GlobalVariable>();
if (global && global->Declaration()->Is<ast::Override>()) {
Override override;
override.name = name;
override.id = global->OverrideId();
// Overrides are scalar-typed; map the scalar type to the reflection enum.
auto* type = var->Type();
TINT_ASSERT(Inspector, type->is_scalar());
if (type->is_bool_scalar_or_vector()) {
override.type = Override::Type::kBool;
} else if (type->is_float_scalar()) {
override.type = Override::Type::kFloat32;
} else if (type->is_signed_integer_scalar()) {
override.type = Override::Type::kInt32;
} else if (type->is_unsigned_integer_scalar()) {
override.type = Override::Type::kUint32;
} else {
TINT_UNREACHABLE(Inspector, diagnostics_);
}
// Non-null constructor pointer means the override has a default value.
override.is_initialized = global->Declaration()->constructor;
override.is_id_specified =
ast::HasAttribute<ast::IdAttribute>(global->Declaration()->attributes);
entry_point.overrides.push_back(override);
}
}
return entry_point;
}
// Looks up an entry point by name and returns its reflection information.
// An unknown name yields a default-constructed EntryPoint.
EntryPoint Inspector::GetEntryPoint(const std::string& entry_point_name) {
    if (auto* func = FindEntryPointByName(entry_point_name)) {
        return GetEntryPoint(func);
    }
    return EntryPoint();
}
std::vector<EntryPoint> Inspector::GetEntryPoints() { std::vector<EntryPoint> Inspector::GetEntryPoints() {
std::vector<EntryPoint> result; std::vector<EntryPoint> result;
@ -141,97 +245,7 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
continue; continue;
} }
auto* sem = program_->Sem().Get(func); result.push_back(GetEntryPoint(func));
EntryPoint entry_point;
entry_point.name = program_->Symbols().NameFor(func->symbol);
entry_point.remapped_name = program_->Symbols().NameFor(func->symbol);
switch (func->PipelineStage()) {
case ast::PipelineStage::kCompute: {
entry_point.stage = PipelineStage::kCompute;
auto wgsize = sem->WorkgroupSize();
if (!wgsize[0].overridable_const && !wgsize[1].overridable_const &&
!wgsize[2].overridable_const) {
entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value,
wgsize[2].value};
}
break;
}
case ast::PipelineStage::kFragment: {
entry_point.stage = PipelineStage::kFragment;
break;
}
case ast::PipelineStage::kVertex: {
entry_point.stage = PipelineStage::kVertex;
break;
}
default: {
TINT_UNREACHABLE(Inspector, diagnostics_)
<< "invalid pipeline stage for entry point '" << entry_point.name << "'";
break;
}
}
for (auto* param : sem->Parameters()) {
AddEntryPointInOutVariables(program_->Symbols().NameFor(param->Declaration()->symbol),
param->Type(), param->Declaration()->attributes,
entry_point.input_variables);
entry_point.input_position_used |= ContainsBuiltin(
ast::BuiltinValue::kPosition, param->Type(), param->Declaration()->attributes);
entry_point.front_facing_used |= ContainsBuiltin(
ast::BuiltinValue::kFrontFacing, param->Type(), param->Declaration()->attributes);
entry_point.sample_index_used |= ContainsBuiltin(
ast::BuiltinValue::kSampleIndex, param->Type(), param->Declaration()->attributes);
entry_point.input_sample_mask_used |= ContainsBuiltin(
ast::BuiltinValue::kSampleMask, param->Type(), param->Declaration()->attributes);
entry_point.num_workgroups_used |= ContainsBuiltin(
ast::BuiltinValue::kNumWorkgroups, param->Type(), param->Declaration()->attributes);
}
if (!sem->ReturnType()->Is<sem::Void>()) {
AddEntryPointInOutVariables("<retval>", sem->ReturnType(), func->return_type_attributes,
entry_point.output_variables);
entry_point.output_sample_mask_used = ContainsBuiltin(
ast::BuiltinValue::kSampleMask, sem->ReturnType(), func->return_type_attributes);
}
for (auto* var : sem->TransitivelyReferencedGlobals()) {
auto* decl = var->Declaration();
auto name = program_->Symbols().NameFor(decl->symbol);
auto* global = var->As<sem::GlobalVariable>();
if (global && global->Declaration()->Is<ast::Override>()) {
Override override;
override.name = name;
override.id = global->OverrideId();
auto* type = var->Type();
TINT_ASSERT(Inspector, type->is_scalar());
if (type->is_bool_scalar_or_vector()) {
override.type = Override::Type::kBool;
} else if (type->is_float_scalar()) {
override.type = Override::Type::kFloat32;
} else if (type->is_signed_integer_scalar()) {
override.type = Override::Type::kInt32;
} else if (type->is_unsigned_integer_scalar()) {
override.type = Override::Type::kUint32;
} else {
TINT_UNREACHABLE(Inspector, diagnostics_);
}
override.is_initialized = global->Declaration()->constructor;
override.is_id_specified =
ast::HasAttribute<ast::IdAttribute>(global->Declaration()->attributes);
entry_point.overrides.push_back(override);
}
}
result.push_back(std::move(entry_point));
} }
return result; return result;

View File

@ -55,6 +55,10 @@ class Inspector {
/// @returns vector of entry point information /// @returns vector of entry point information
std::vector<EntryPoint> GetEntryPoints(); std::vector<EntryPoint> GetEntryPoints();
/// @param entry_point name of the entry point to get information about
/// @returns the entry point information
EntryPoint GetEntryPoint(const std::string& entry_point);
/// @returns map of override identifier to initial value /// @returns map of override identifier to initial value
std::map<OverrideId, Scalar> GetOverrideDefaultValues(); std::map<OverrideId, Scalar> GetOverrideDefaultValues();
@ -230,6 +234,10 @@ class Inspector {
/// whenever a set of expressions are resolved to globals. /// whenever a set of expressions are resolved to globals.
template <size_t N, typename F> template <size_t N, typename F>
void GetOriginatingResources(std::array<const ast::Expression*, N> exprs, F&& cb); void GetOriginatingResources(std::array<const ast::Expression*, N> exprs, F&& cb);
/// @param func the function of the entry point. Must be non-nullptr and true for IsEntryPoint()
/// @returns the entry point information
EntryPoint GetEntryPoint(const tint::ast::Function* func);
}; };
} // namespace tint::inspector } // namespace tint::inspector

View File

@ -20,6 +20,7 @@
#include "tint/override_id.h" #include "tint/override_id.h"
#include "src/tint/reflection.h"
#include "src/tint/transform/transform.h" #include "src/tint/transform/transform.h"
namespace tint::transform { namespace tint::transform {
@ -63,6 +64,9 @@ class SubstituteOverride final : public Castable<SubstituteOverride, Transform>
/// The value is always a double coming into the transform and will be /// The value is always a double coming into the transform and will be
/// converted to the correct type through and initializer. /// converted to the correct type through and initializer.
std::unordered_map<OverrideId, double> map; std::unordered_map<OverrideId, double> map;
/// Reflect the fields of this class so that it can be used by tint::ForeachField()
TINT_REFLECT(map);
}; };
/// Constructor /// Constructor