Use SubstituteOverride transform to implement overrides
Remove the old backend specific implementation for overrides. Use tint SubstituteOverride transform to replace overrides with const expressions and use the updated program at pipeline creation time. This CL also adds support for overrides used as workgroup size and related tests. Workgroup size validation now happens in backend code and at compute pipeline creation time. Bug: dawn:1504 Change-Id: I7df1fe9c3e358caa23235eacd6d13ba0b2998aec Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99821 Commit-Queue: Shrek Shao <shrekshao@google.com> Reviewed-by: Austin Eng <enga@chromium.org> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
23cf74c30e
commit
145337f309
|
@ -25,7 +25,15 @@
|
||||||
namespace {{native_namespace}} {
|
namespace {{native_namespace}} {
|
||||||
|
|
||||||
//
|
//
|
||||||
// Cache key writers for wgpu structures used in caching.
|
// Streaming readers for wgpu structures.
|
||||||
|
//
|
||||||
|
{% macro render_reader(member) %}
|
||||||
|
{%- set name = member.name.camelCase() -%}
|
||||||
|
DAWN_TRY(StreamOut(source, &t->{{name}}));
|
||||||
|
{% endmacro %}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Streaming writers for wgpu structures.
|
||||||
//
|
//
|
||||||
{% macro render_writer(member) %}
|
{% macro render_writer(member) %}
|
||||||
{%- set name = member.name.camelCase() -%}
|
{%- set name = member.name.camelCase() -%}
|
||||||
|
@ -38,31 +46,50 @@ namespace {{native_namespace}} {
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% endmacro %}
|
{% endmacro %}
|
||||||
|
|
||||||
{# Helper macro to render writers. Should be used in a call block to provide additional custom
|
{# Helper macro to render readers and writers. Should be used in a call block to provide additional custom
|
||||||
handling when necessary. The optional `omit` field can be used to omit fields that are either
|
handling when necessary. The optional `omit` field can be used to omit fields that are either
|
||||||
handled in the custom code, or unnecessary in the serialized output.
|
handled in the custom code, or unnecessary in the serialized output.
|
||||||
Example:
|
Example:
|
||||||
{% call render_cache_key_writer("struct name", omits=["omit field"]) %}
|
{% call render_streaming_impl("struct name", writer=true, reader=false, omits=["omit field"]) %}
|
||||||
// Custom C++ code to handle special types/members that are hard to generate code for
|
// Custom C++ code to handle special types/members that are hard to generate code for
|
||||||
{% endcall %}
|
{% endcall %}
|
||||||
|
One day we should probably make the generator smart enough to generate everything it can
|
||||||
|
instead of manually adding streaming implementations here.
|
||||||
#}
|
#}
|
||||||
{% macro render_cache_key_writer(json_type, omits=[]) %}
|
{% macro render_streaming_impl(json_type, writer, reader, omits=[]) %}
|
||||||
{%- set cpp_type = types[json_type].name.CamelCase() -%}
|
{%- set cpp_type = types[json_type].name.CamelCase() -%}
|
||||||
template <>
|
{% if reader %}
|
||||||
void stream::Stream<{{cpp_type}}>::Write(stream::Sink* sink, const {{cpp_type}}& t) {
|
template <>
|
||||||
{{ caller() }}
|
MaybeError stream::Stream<{{cpp_type}}>::Read(stream::Source* source, {{cpp_type}}* t) {
|
||||||
{% for member in types[json_type].members %}
|
{{ caller() }}
|
||||||
{%- if not member.name.get() in omits %}
|
{% for member in types[json_type].members %}
|
||||||
{{render_writer(member)}}
|
{% if not member.name.get() in omits %}
|
||||||
{%- endif %}
|
{{render_reader(member)}}
|
||||||
{% endfor %}
|
{% endif %}
|
||||||
}
|
{% endfor %}
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
{% endif %}
|
||||||
|
{% if writer %}
|
||||||
|
template <>
|
||||||
|
void stream::Stream<{{cpp_type}}>::Write(stream::Sink* sink, const {{cpp_type}}& t) {
|
||||||
|
{{ caller() }}
|
||||||
|
{% for member in types[json_type].members %}
|
||||||
|
{% if not member.name.get() in omits %}
|
||||||
|
{{render_writer(member)}}
|
||||||
|
{% endif %}
|
||||||
|
{% endfor %}
|
||||||
|
}
|
||||||
|
{% endif %}
|
||||||
{% endmacro %}
|
{% endmacro %}
|
||||||
|
|
||||||
{% call render_cache_key_writer("adapter properties") %}
|
{% call render_streaming_impl("adapter properties", true, false) %}
|
||||||
{% endcall %}
|
{% endcall %}
|
||||||
|
|
||||||
{% call render_cache_key_writer("dawn cache device descriptor") %}
|
{% call render_streaming_impl("dawn cache device descriptor", true, false) %}
|
||||||
|
{% endcall %}
|
||||||
|
|
||||||
|
{% call render_streaming_impl("extent 3D", true, true) %}
|
||||||
{% endcall %}
|
{% endcall %}
|
||||||
|
|
||||||
} // namespace {{native_namespace}}
|
} // namespace {{native_namespace}}
|
||||||
|
|
|
@ -215,4 +215,22 @@ Limits ApplyLimitTiers(Limits limits) {
|
||||||
return limits;
|
return limits;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT(type, name) \
|
||||||
|
{ result.name = limits.name; }
|
||||||
|
#define DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT(MEMBERS) \
|
||||||
|
MEMBERS(DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT)
|
||||||
|
LimitsForCompilationRequest LimitsForCompilationRequest::Create(const Limits& limits) {
|
||||||
|
LimitsForCompilationRequest result;
|
||||||
|
DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT(LIMITS_FOR_COMPILATION_REQUEST_MEMBERS)
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
#undef DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT
|
||||||
|
#undef DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void stream::Stream<LimitsForCompilationRequest>::Write(Sink* s,
|
||||||
|
const LimitsForCompilationRequest& t) {
|
||||||
|
t.VisitAll([&](const auto&... members) { StreamIn(s, members...); });
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace dawn::native
|
} // namespace dawn::native
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
#define SRC_DAWN_NATIVE_LIMITS_H_
|
#define SRC_DAWN_NATIVE_LIMITS_H_
|
||||||
|
|
||||||
#include "dawn/native/Error.h"
|
#include "dawn/native/Error.h"
|
||||||
|
#include "dawn/native/VisitableMembers.h"
|
||||||
#include "dawn/native/dawn_platform.h"
|
#include "dawn/native/dawn_platform.h"
|
||||||
|
|
||||||
namespace dawn::native {
|
namespace dawn::native {
|
||||||
|
@ -38,6 +39,20 @@ MaybeError ValidateLimits(const Limits& supportedLimits, const Limits& requiredL
|
||||||
// Returns a copy of |limits| where limit tiers are applied.
|
// Returns a copy of |limits| where limit tiers are applied.
|
||||||
Limits ApplyLimitTiers(Limits limits);
|
Limits ApplyLimitTiers(Limits limits);
|
||||||
|
|
||||||
|
// If there are new limit member needed at shader compilation time
|
||||||
|
// Simply append a new X(type, name) here.
|
||||||
|
#define LIMITS_FOR_COMPILATION_REQUEST_MEMBERS(X) \
|
||||||
|
X(uint32_t, maxComputeWorkgroupSizeX) \
|
||||||
|
X(uint32_t, maxComputeWorkgroupSizeY) \
|
||||||
|
X(uint32_t, maxComputeWorkgroupSizeZ) \
|
||||||
|
X(uint32_t, maxComputeInvocationsPerWorkgroup) \
|
||||||
|
X(uint32_t, maxComputeWorkgroupStorageSize)
|
||||||
|
|
||||||
|
struct LimitsForCompilationRequest {
|
||||||
|
static LimitsForCompilationRequest Create(const Limits& limits);
|
||||||
|
DAWN_VISITABLE_MEMBERS(LIMITS_FOR_COMPILATION_REQUEST_MEMBERS)
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace dawn::native
|
} // namespace dawn::native
|
||||||
|
|
||||||
#endif // SRC_DAWN_NATIVE_LIMITS_H_
|
#endif // SRC_DAWN_NATIVE_LIMITS_H_
|
||||||
|
|
|
@ -15,7 +15,9 @@
|
||||||
#ifndef SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_
|
#ifndef SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_
|
||||||
#define SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_
|
#define SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "dawn/common/HashUtils.h"
|
#include "dawn/common/HashUtils.h"
|
||||||
|
@ -60,6 +62,13 @@ class ObjectContentHasher {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T, typename E>
|
||||||
|
struct RecordImpl<std::map<T, E>> {
|
||||||
|
static constexpr void Call(ObjectContentHasher* recorder, const std::map<T, E>& map) {
|
||||||
|
recorder->RecordIterable<std::map<T, E>>(map);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename IteratorT>
|
template <typename IteratorT>
|
||||||
constexpr void RecordIterable(const IteratorT& iterable) {
|
constexpr void RecordIterable(const IteratorT& iterable) {
|
||||||
for (auto it = iterable.begin(); it != iterable.end(); ++it) {
|
for (auto it = iterable.begin(); it != iterable.end(); ++it) {
|
||||||
|
@ -67,6 +76,14 @@ class ObjectContentHasher {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, typename E>
|
||||||
|
struct RecordImpl<std::pair<T, E>> {
|
||||||
|
static constexpr void Call(ObjectContentHasher* recorder, const std::pair<T, E>& pair) {
|
||||||
|
recorder->Record(pair.first);
|
||||||
|
recorder->Record(pair.second);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
size_t mContentHash = 0;
|
size_t mContentHash = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -58,12 +58,6 @@ MaybeError ValidateProgrammableStage(DeviceBase* device,
|
||||||
DAWN_TRY(ValidateCompatibilityWithPipelineLayout(device, metadata, layout));
|
DAWN_TRY(ValidateCompatibilityWithPipelineLayout(device, metadata, layout));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (constantCount > 0u && device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs)) {
|
|
||||||
return DAWN_VALIDATION_ERROR(
|
|
||||||
"Pipeline overridable constants are disallowed because they are partially "
|
|
||||||
"implemented.");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate if overridable constants exist in shader module
|
// Validate if overridable constants exist in shader module
|
||||||
// pipelineBase is not yet constructed at this moment so iterate constants from descriptor
|
// pipelineBase is not yet constructed at this moment so iterate constants from descriptor
|
||||||
size_t numUninitializedConstants = metadata.uninitializedOverrides.size();
|
size_t numUninitializedConstants = metadata.uninitializedOverrides.size();
|
||||||
|
@ -233,6 +227,7 @@ size_t PipelineBase::ComputeContentHash() {
|
||||||
for (SingleShaderStage stage : IterateStages(mStageMask)) {
|
for (SingleShaderStage stage : IterateStages(mStageMask)) {
|
||||||
recorder.Record(mStages[stage].module->GetContentHash());
|
recorder.Record(mStages[stage].module->GetContentHash());
|
||||||
recorder.Record(mStages[stage].entryPoint);
|
recorder.Record(mStages[stage].entryPoint);
|
||||||
|
recorder.Record(mStages[stage].constants);
|
||||||
}
|
}
|
||||||
|
|
||||||
return recorder.GetContentHash();
|
return recorder.GetContentHash();
|
||||||
|
@ -248,7 +243,14 @@ bool PipelineBase::EqualForCache(const PipelineBase* a, const PipelineBase* b) {
|
||||||
for (SingleShaderStage stage : IterateStages(a->mStageMask)) {
|
for (SingleShaderStage stage : IterateStages(a->mStageMask)) {
|
||||||
// The module is deduplicated so it can be compared by pointer.
|
// The module is deduplicated so it can be compared by pointer.
|
||||||
if (a->mStages[stage].module.Get() != b->mStages[stage].module.Get() ||
|
if (a->mStages[stage].module.Get() != b->mStages[stage].module.Get() ||
|
||||||
a->mStages[stage].entryPoint != b->mStages[stage].entryPoint) {
|
a->mStages[stage].entryPoint != b->mStages[stage].entryPoint ||
|
||||||
|
a->mStages[stage].constants.size() != b->mStages[stage].constants.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the constants.size are the same, we still need to compare the key and value.
|
||||||
|
if (!std::equal(a->mStages[stage].constants.begin(), a->mStages[stage].constants.end(),
|
||||||
|
b->mStages[stage].constants.begin())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,9 +40,6 @@ MaybeError ValidateProgrammableStage(DeviceBase* device,
|
||||||
const PipelineLayoutBase* layout,
|
const PipelineLayoutBase* layout,
|
||||||
SingleShaderStage stage);
|
SingleShaderStage stage);
|
||||||
|
|
||||||
// Use map to make sure constant keys are sorted for creating shader cache keys
|
|
||||||
using PipelineConstantEntries = std::map<std::string, double>;
|
|
||||||
|
|
||||||
struct ProgrammableStage {
|
struct ProgrammableStage {
|
||||||
Ref<ShaderModuleBase> module;
|
Ref<ShaderModuleBase> module;
|
||||||
std::string entryPoint;
|
std::string entryPoint;
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "dawn/common/BitSetIterator.h"
|
#include "dawn/common/BitSetIterator.h"
|
||||||
#include "dawn/common/Constants.h"
|
#include "dawn/common/Constants.h"
|
||||||
#include "dawn/common/HashUtils.h"
|
|
||||||
#include "dawn/native/BindGroupLayout.h"
|
#include "dawn/native/BindGroupLayout.h"
|
||||||
#include "dawn/native/ChainUtils_autogen.h"
|
#include "dawn/native/ChainUtils_autogen.h"
|
||||||
#include "dawn/native/CompilationMessages.h"
|
#include "dawn/native/CompilationMessages.h"
|
||||||
|
@ -511,7 +510,6 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
|
||||||
const DeviceBase* device,
|
const DeviceBase* device,
|
||||||
tint::inspector::Inspector* inspector,
|
tint::inspector::Inspector* inspector,
|
||||||
const tint::inspector::EntryPoint& entryPoint) {
|
const tint::inspector::EntryPoint& entryPoint) {
|
||||||
const CombinedLimits& limits = device->GetLimits();
|
|
||||||
constexpr uint32_t kMaxInterStageShaderLocation = kMaxInterStageShaderVariables - 1;
|
constexpr uint32_t kMaxInterStageShaderLocation = kMaxInterStageShaderVariables - 1;
|
||||||
|
|
||||||
std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
|
std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
|
||||||
|
@ -528,10 +526,6 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
|
||||||
})()
|
})()
|
||||||
|
|
||||||
if (!entryPoint.overrides.empty()) {
|
if (!entryPoint.overrides.empty()) {
|
||||||
DAWN_INVALID_IF(device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs),
|
|
||||||
"Pipeline overridable constants are disallowed because they "
|
|
||||||
"are partially implemented.");
|
|
||||||
|
|
||||||
const auto& name2Id = inspector->GetNamedOverrideIds();
|
const auto& name2Id = inspector->GetNamedOverrideIds();
|
||||||
const auto& id2Scalar = inspector->GetOverrideDefaultValues();
|
const auto& id2Scalar = inspector->GetOverrideDefaultValues();
|
||||||
|
|
||||||
|
@ -553,10 +547,10 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
|
||||||
UNREACHABLE();
|
UNREACHABLE();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EntryPointMetadata::Override override = {id.value, FromTintOverrideType(c.type),
|
EntryPointMetadata::Override override = {id, FromTintOverrideType(c.type),
|
||||||
c.is_initialized, defaultValue};
|
c.is_initialized, defaultValue};
|
||||||
|
|
||||||
std::string identifier = c.is_id_specified ? std::to_string(override.id) : c.name;
|
std::string identifier = c.is_id_specified ? std::to_string(override.id.value) : c.name;
|
||||||
metadata->overrides[identifier] = override;
|
metadata->overrides[identifier] = override;
|
||||||
|
|
||||||
if (!c.is_initialized) {
|
if (!c.is_initialized) {
|
||||||
|
@ -575,39 +569,6 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
|
||||||
DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
|
DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
|
||||||
|
|
||||||
if (metadata->stage == SingleShaderStage::Compute) {
|
if (metadata->stage == SingleShaderStage::Compute) {
|
||||||
auto workgroup_size = entryPoint.workgroup_size;
|
|
||||||
DAWN_INVALID_IF(
|
|
||||||
!workgroup_size.has_value(),
|
|
||||||
"TODO(crbug.com/dawn/1504): Dawn does not currently support @workgroup_size "
|
|
||||||
"attributes using override-expressions");
|
|
||||||
DelayedInvalidIf(workgroup_size->x > limits.v1.maxComputeWorkgroupSizeX ||
|
|
||||||
workgroup_size->y > limits.v1.maxComputeWorkgroupSizeY ||
|
|
||||||
workgroup_size->z > limits.v1.maxComputeWorkgroupSizeZ,
|
|
||||||
"Entry-point uses workgroup_size(%u, %u, %u) that exceeds the "
|
|
||||||
"maximum allowed (%u, %u, %u).",
|
|
||||||
workgroup_size->x, workgroup_size->y, workgroup_size->z,
|
|
||||||
limits.v1.maxComputeWorkgroupSizeX, limits.v1.maxComputeWorkgroupSizeY,
|
|
||||||
limits.v1.maxComputeWorkgroupSizeZ);
|
|
||||||
|
|
||||||
// Dimensions have already been validated against their individual limits above.
|
|
||||||
// Cast to uint64_t to avoid overflow in this multiplication.
|
|
||||||
uint64_t numInvocations =
|
|
||||||
static_cast<uint64_t>(workgroup_size->x) * workgroup_size->y * workgroup_size->z;
|
|
||||||
DelayedInvalidIf(numInvocations > limits.v1.maxComputeInvocationsPerWorkgroup,
|
|
||||||
"The total number of workgroup invocations (%u) exceeds the "
|
|
||||||
"maximum allowed (%u).",
|
|
||||||
numInvocations, limits.v1.maxComputeInvocationsPerWorkgroup);
|
|
||||||
|
|
||||||
const size_t workgroupStorageSize = inspector->GetWorkgroupStorageSize(entryPoint.name);
|
|
||||||
DelayedInvalidIf(workgroupStorageSize > limits.v1.maxComputeWorkgroupStorageSize,
|
|
||||||
"The total use of workgroup storage (%u bytes) is larger than "
|
|
||||||
"the maximum allowed (%u bytes).",
|
|
||||||
workgroupStorageSize, limits.v1.maxComputeWorkgroupStorageSize);
|
|
||||||
|
|
||||||
metadata->localWorkgroupSize.x = workgroup_size->x;
|
|
||||||
metadata->localWorkgroupSize.y = workgroup_size->y;
|
|
||||||
metadata->localWorkgroupSize.z = workgroup_size->z;
|
|
||||||
|
|
||||||
metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
|
metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -883,6 +844,46 @@ MaybeError ReflectShaderUsingTint(const DeviceBase* device,
|
||||||
}
|
}
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
|
ResultOrError<Extent3D> ValidateComputeStageWorkgroupSize(
|
||||||
|
const tint::Program& program,
|
||||||
|
const char* entryPointName,
|
||||||
|
const LimitsForCompilationRequest& limits) {
|
||||||
|
tint::inspector::Inspector inspector(&program);
|
||||||
|
// At this point the entry point must exist and must have workgroup size values.
|
||||||
|
tint::inspector::EntryPoint entryPoint = inspector.GetEntryPoint(entryPointName);
|
||||||
|
ASSERT(entryPoint.workgroup_size.has_value());
|
||||||
|
const tint::inspector::WorkgroupSize& workgroup_size = entryPoint.workgroup_size.value();
|
||||||
|
|
||||||
|
DAWN_INVALID_IF(workgroup_size.x < 1 || workgroup_size.y < 1 || workgroup_size.z < 1,
|
||||||
|
"Entry-point uses workgroup_size(%u, %u, %u) that are below the "
|
||||||
|
"minimum allowed (1, 1, 1).",
|
||||||
|
workgroup_size.x, workgroup_size.y, workgroup_size.z);
|
||||||
|
|
||||||
|
DAWN_INVALID_IF(workgroup_size.x > limits.maxComputeWorkgroupSizeX ||
|
||||||
|
workgroup_size.y > limits.maxComputeWorkgroupSizeY ||
|
||||||
|
workgroup_size.z > limits.maxComputeWorkgroupSizeZ,
|
||||||
|
"Entry-point uses workgroup_size(%u, %u, %u) that exceeds the "
|
||||||
|
"maximum allowed (%u, %u, %u).",
|
||||||
|
workgroup_size.x, workgroup_size.y, workgroup_size.z,
|
||||||
|
limits.maxComputeWorkgroupSizeX, limits.maxComputeWorkgroupSizeY,
|
||||||
|
limits.maxComputeWorkgroupSizeZ);
|
||||||
|
|
||||||
|
uint64_t numInvocations =
|
||||||
|
static_cast<uint64_t>(workgroup_size.x) * workgroup_size.y * workgroup_size.z;
|
||||||
|
DAWN_INVALID_IF(numInvocations > limits.maxComputeInvocationsPerWorkgroup,
|
||||||
|
"The total number of workgroup invocations (%u) exceeds the "
|
||||||
|
"maximum allowed (%u).",
|
||||||
|
numInvocations, limits.maxComputeInvocationsPerWorkgroup);
|
||||||
|
|
||||||
|
const size_t workgroupStorageSize = inspector.GetWorkgroupStorageSize(entryPointName);
|
||||||
|
DAWN_INVALID_IF(workgroupStorageSize > limits.maxComputeWorkgroupStorageSize,
|
||||||
|
"The total use of workgroup storage (%u bytes) is larger than "
|
||||||
|
"the maximum allowed (%u bytes).",
|
||||||
|
workgroupStorageSize, limits.maxComputeWorkgroupStorageSize);
|
||||||
|
|
||||||
|
return Extent3D{workgroup_size.x, workgroup_size.y, workgroup_size.z};
|
||||||
|
}
|
||||||
|
|
||||||
ShaderModuleParseResult::ShaderModuleParseResult() = default;
|
ShaderModuleParseResult::ShaderModuleParseResult() = default;
|
||||||
ShaderModuleParseResult::~ShaderModuleParseResult() = default;
|
ShaderModuleParseResult::~ShaderModuleParseResult() = default;
|
||||||
|
|
||||||
|
@ -1200,11 +1201,4 @@ MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t PipelineLayoutEntryPointPairHashFunc::operator()(
|
|
||||||
const PipelineLayoutEntryPointPair& pair) const {
|
|
||||||
size_t hash = 0;
|
|
||||||
HashCombine(&hash, pair.first, pair.second);
|
|
||||||
return hash;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace dawn::native
|
} // namespace dawn::native
|
||||||
|
|
|
@ -33,10 +33,12 @@
|
||||||
#include "dawn/native/Format.h"
|
#include "dawn/native/Format.h"
|
||||||
#include "dawn/native/Forward.h"
|
#include "dawn/native/Forward.h"
|
||||||
#include "dawn/native/IntegerTypes.h"
|
#include "dawn/native/IntegerTypes.h"
|
||||||
|
#include "dawn/native/Limits.h"
|
||||||
#include "dawn/native/ObjectBase.h"
|
#include "dawn/native/ObjectBase.h"
|
||||||
#include "dawn/native/PerStage.h"
|
#include "dawn/native/PerStage.h"
|
||||||
#include "dawn/native/VertexFormat.h"
|
#include "dawn/native/VertexFormat.h"
|
||||||
#include "dawn/native/dawn_platform.h"
|
#include "dawn/native/dawn_platform.h"
|
||||||
|
#include "tint/override_id.h"
|
||||||
|
|
||||||
namespace tint {
|
namespace tint {
|
||||||
|
|
||||||
|
@ -76,10 +78,8 @@ enum class InterpolationSampling {
|
||||||
Sample,
|
Sample,
|
||||||
};
|
};
|
||||||
|
|
||||||
using PipelineLayoutEntryPointPair = std::pair<const PipelineLayoutBase*, std::string>;
|
// Use map to make sure constant keys are sorted for creating shader cache keys
|
||||||
struct PipelineLayoutEntryPointPairHashFunc {
|
using PipelineConstantEntries = std::map<std::string, double>;
|
||||||
size_t operator()(const PipelineLayoutEntryPointPair& pair) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
// A map from name to EntryPointMetadata.
|
// A map from name to EntryPointMetadata.
|
||||||
using EntryPointMetadataTable =
|
using EntryPointMetadataTable =
|
||||||
|
@ -108,6 +108,13 @@ MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
|
||||||
const EntryPointMetadata& entryPoint,
|
const EntryPointMetadata& entryPoint,
|
||||||
const PipelineLayoutBase* layout);
|
const PipelineLayoutBase* layout);
|
||||||
|
|
||||||
|
// Return extent3D with workgroup size dimension info if it is valid
|
||||||
|
// width = x, height = y, depthOrArrayLength = z
|
||||||
|
ResultOrError<Extent3D> ValidateComputeStageWorkgroupSize(
|
||||||
|
const tint::Program& program,
|
||||||
|
const char* entryPointName,
|
||||||
|
const LimitsForCompilationRequest& limits);
|
||||||
|
|
||||||
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
|
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
|
||||||
const PipelineLayoutBase* layout);
|
const PipelineLayoutBase* layout);
|
||||||
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
|
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
|
||||||
|
@ -204,14 +211,12 @@ struct EntryPointMetadata {
|
||||||
std::bitset<kMaxInterStageShaderVariables> usedInterStageVariables;
|
std::bitset<kMaxInterStageShaderVariables> usedInterStageVariables;
|
||||||
std::array<InterStageVariableInfo, kMaxInterStageShaderVariables> interStageVariables;
|
std::array<InterStageVariableInfo, kMaxInterStageShaderVariables> interStageVariables;
|
||||||
|
|
||||||
// The local workgroup size declared for a compute entry point (or 0s otehrwise).
|
|
||||||
Origin3D localWorkgroupSize;
|
|
||||||
|
|
||||||
// The shader stage for this binding.
|
// The shader stage for this binding.
|
||||||
SingleShaderStage stage;
|
SingleShaderStage stage;
|
||||||
|
|
||||||
struct Override {
|
struct Override {
|
||||||
uint32_t id;
|
tint::OverrideId id;
|
||||||
|
|
||||||
// Match tint::inspector::Override::Type
|
// Match tint::inspector::Override::Type
|
||||||
// Bool is defined as a macro on linux X11 and cannot compile
|
// Bool is defined as a macro on linux X11 and cannot compile
|
||||||
enum class Type { Boolean, Float32, Uint32, Int32 } type;
|
enum class Type { Boolean, Float32, Uint32, Int32 } type;
|
||||||
|
@ -273,6 +278,7 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject {
|
||||||
bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const;
|
bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// This returns tint program before running transforms.
|
||||||
const tint::Program* GetTintProgram() const;
|
const tint::Program* GetTintProgram() const;
|
||||||
|
|
||||||
void APIGetCompilationInfo(wgpu::CompilationInfoCallback callback, void* userdata);
|
void APIGetCompilationInfo(wgpu::CompilationInfoCallback callback, void* userdata);
|
||||||
|
|
|
@ -65,6 +65,23 @@ void stream::Stream<tint::transform::VertexPulling::Config>::Write(
|
||||||
StreamInTintObject(cfg, sink);
|
StreamInTintObject(cfg, sink);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// static
|
||||||
|
template <>
|
||||||
|
void stream::Stream<tint::transform::SubstituteOverride::Config>::Write(
|
||||||
|
stream::Sink* sink,
|
||||||
|
const tint::transform::SubstituteOverride::Config& cfg) {
|
||||||
|
StreamInTintObject(cfg, sink);
|
||||||
|
}
|
||||||
|
|
||||||
|
// static
|
||||||
|
template <>
|
||||||
|
void stream::Stream<tint::OverrideId>::Write(stream::Sink* sink, const tint::OverrideId& id) {
|
||||||
|
// TODO(tint:1640): fix the include build issues and use StreamInTintObject instead.
|
||||||
|
static_assert(offsetof(tint::OverrideId, value) == 0,
|
||||||
|
"Please update serialization for tint::OverrideId");
|
||||||
|
StreamIn(sink, id.value);
|
||||||
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
template <>
|
template <>
|
||||||
void stream::Stream<tint::transform::VertexBufferLayoutDescriptor>::Write(
|
void stream::Stream<tint::transform::VertexBufferLayoutDescriptor>::Write(
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
#include "dawn/native/BindGroupLayout.h"
|
#include "dawn/native/BindGroupLayout.h"
|
||||||
#include "dawn/native/Device.h"
|
#include "dawn/native/Device.h"
|
||||||
|
#include "dawn/native/Pipeline.h"
|
||||||
#include "dawn/native/PipelineLayout.h"
|
#include "dawn/native/PipelineLayout.h"
|
||||||
#include "dawn/native/RenderPipeline.h"
|
#include "dawn/native/RenderPipeline.h"
|
||||||
|
|
||||||
|
@ -183,6 +184,21 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig(
|
||||||
return cfg;
|
return cfg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig(
|
||||||
|
const ProgrammableStage& stage) {
|
||||||
|
const EntryPointMetadata& metadata = *stage.metadata;
|
||||||
|
const auto& constants = stage.constants;
|
||||||
|
|
||||||
|
tint::transform::SubstituteOverride::Config cfg;
|
||||||
|
|
||||||
|
for (const auto& [key, value] : constants) {
|
||||||
|
const auto& o = metadata.overrides.at(key);
|
||||||
|
cfg.map.insert({o.id, value});
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace dawn::native
|
} // namespace dawn::native
|
||||||
|
|
||||||
namespace tint::sem {
|
namespace tint::sem {
|
||||||
|
|
|
@ -26,6 +26,7 @@ namespace dawn::native {
|
||||||
|
|
||||||
class DeviceBase;
|
class DeviceBase;
|
||||||
class PipelineLayoutBase;
|
class PipelineLayoutBase;
|
||||||
|
struct ProgrammableStage;
|
||||||
class RenderPipelineBase;
|
class RenderPipelineBase;
|
||||||
|
|
||||||
// Indicates that for the lifetime of this object tint internal compiler errors should be
|
// Indicates that for the lifetime of this object tint internal compiler errors should be
|
||||||
|
@ -47,6 +48,9 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig(
|
||||||
const std::string_view& entryPoint,
|
const std::string_view& entryPoint,
|
||||||
BindGroupIndex pullingBufferBindingSet);
|
BindGroupIndex pullingBufferBindingSet);
|
||||||
|
|
||||||
|
tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig(
|
||||||
|
const ProgrammableStage& stage);
|
||||||
|
|
||||||
} // namespace dawn::native
|
} // namespace dawn::native
|
||||||
|
|
||||||
namespace tint::sem {
|
namespace tint::sem {
|
||||||
|
|
|
@ -351,8 +351,6 @@ MaybeError RenderPipeline::Initialize() {
|
||||||
|
|
||||||
D3D12_GRAPHICS_PIPELINE_STATE_DESC descriptorD3D12 = {};
|
D3D12_GRAPHICS_PIPELINE_STATE_DESC descriptorD3D12 = {};
|
||||||
|
|
||||||
PerStage<ProgrammableStage> pipelineStages = GetAllStages();
|
|
||||||
|
|
||||||
PerStage<D3D12_SHADER_BYTECODE*> shaders;
|
PerStage<D3D12_SHADER_BYTECODE*> shaders;
|
||||||
shaders[SingleShaderStage::Vertex] = &descriptorD3D12.VS;
|
shaders[SingleShaderStage::Vertex] = &descriptorD3D12.VS;
|
||||||
shaders[SingleShaderStage::Fragment] = &descriptorD3D12.PS;
|
shaders[SingleShaderStage::Fragment] = &descriptorD3D12.PS;
|
||||||
|
@ -360,8 +358,9 @@ MaybeError RenderPipeline::Initialize() {
|
||||||
PerStage<CompiledShader> compiledShader;
|
PerStage<CompiledShader> compiledShader;
|
||||||
|
|
||||||
for (auto stage : IterateStages(GetStageMask())) {
|
for (auto stage : IterateStages(GetStageMask())) {
|
||||||
DAWN_TRY_ASSIGN(compiledShader[stage], ToBackend(pipelineStages[stage].module)
|
const ProgrammableStage& programmableStage = GetStage(stage);
|
||||||
->Compile(pipelineStages[stage], stage,
|
DAWN_TRY_ASSIGN(compiledShader[stage], ToBackend(programmableStage.module)
|
||||||
|
->Compile(programmableStage, stage,
|
||||||
ToBackend(GetLayout()), compileFlags));
|
ToBackend(GetLayout()), compileFlags));
|
||||||
*shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode();
|
*shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode();
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,100 +65,35 @@ namespace dawn::native::d3d12 {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
|
|
||||||
std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
|
|
||||||
std::ostringstream out;
|
|
||||||
out.precision(n);
|
|
||||||
out << std::fixed << v;
|
|
||||||
return out.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string GetHLSLValueString(EntryPointMetadata::Override::Type dawnType,
|
|
||||||
const OverrideScalar* entry,
|
|
||||||
double value = 0) {
|
|
||||||
switch (dawnType) {
|
|
||||||
case EntryPointMetadata::Override::Type::Boolean:
|
|
||||||
return std::to_string(entry ? entry->b : static_cast<int32_t>(value));
|
|
||||||
case EntryPointMetadata::Override::Type::Float32:
|
|
||||||
return FloatToStringWithPrecision(entry ? entry->f32 : static_cast<float>(value));
|
|
||||||
case EntryPointMetadata::Override::Type::Int32:
|
|
||||||
return std::to_string(entry ? entry->i32 : static_cast<int32_t>(value));
|
|
||||||
case EntryPointMetadata::Override::Type::Uint32:
|
|
||||||
return std::to_string(entry ? entry->u32 : static_cast<uint32_t>(value));
|
|
||||||
default:
|
|
||||||
UNREACHABLE();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
|
|
||||||
|
|
||||||
using DefineStrings = std::vector<std::pair<std::string, std::string>>;
|
|
||||||
|
|
||||||
DefineStrings GetOverridableConstantsDefines(
|
|
||||||
const PipelineConstantEntries& pipelineConstantEntries,
|
|
||||||
const EntryPointMetadata::OverridesMap& shaderEntryPointConstants) {
|
|
||||||
DefineStrings defineStrings;
|
|
||||||
std::unordered_set<std::string> overriddenConstants;
|
|
||||||
|
|
||||||
// Set pipeline overridden values
|
|
||||||
for (const auto& [name, value] : pipelineConstantEntries) {
|
|
||||||
overriddenConstants.insert(name);
|
|
||||||
|
|
||||||
// This is already validated so `name` must exist
|
|
||||||
const auto& moduleConstant = shaderEntryPointConstants.at(name);
|
|
||||||
|
|
||||||
defineStrings.emplace_back(
|
|
||||||
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
|
|
||||||
GetHLSLValueString(moduleConstant.type, nullptr, value));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set shader initialized default values
|
|
||||||
for (const auto& iter : shaderEntryPointConstants) {
|
|
||||||
const std::string& name = iter.first;
|
|
||||||
if (overriddenConstants.count(name) != 0) {
|
|
||||||
// This constant already has overridden value
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto& moduleConstant = shaderEntryPointConstants.at(name);
|
|
||||||
|
|
||||||
// Uninitialized default values are okay since they ar only defined to pass
|
|
||||||
// compilation but not used
|
|
||||||
defineStrings.emplace_back(
|
|
||||||
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
|
|
||||||
GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue));
|
|
||||||
}
|
|
||||||
return defineStrings;
|
|
||||||
}
|
|
||||||
|
|
||||||
enum class Compiler { FXC, DXC };
|
enum class Compiler { FXC, DXC };
|
||||||
|
|
||||||
#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \
|
#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \
|
||||||
X(const tint::Program*, inputProgram) \
|
X(const tint::Program*, inputProgram) \
|
||||||
X(std::string_view, entryPointName) \
|
X(std::string_view, entryPointName) \
|
||||||
X(SingleShaderStage, stage) \
|
X(SingleShaderStage, stage) \
|
||||||
X(uint32_t, shaderModel) \
|
X(uint32_t, shaderModel) \
|
||||||
X(uint32_t, compileFlags) \
|
X(uint32_t, compileFlags) \
|
||||||
X(Compiler, compiler) \
|
X(Compiler, compiler) \
|
||||||
X(uint64_t, compilerVersion) \
|
X(uint64_t, compilerVersion) \
|
||||||
X(std::wstring_view, dxcShaderProfile) \
|
X(std::wstring_view, dxcShaderProfile) \
|
||||||
X(std::string_view, fxcShaderProfile) \
|
X(std::string_view, fxcShaderProfile) \
|
||||||
X(pD3DCompile, d3dCompile) \
|
X(pD3DCompile, d3dCompile) \
|
||||||
X(IDxcLibrary*, dxcLibrary) \
|
X(IDxcLibrary*, dxcLibrary) \
|
||||||
X(IDxcCompiler*, dxcCompiler) \
|
X(IDxcCompiler*, dxcCompiler) \
|
||||||
X(uint32_t, firstIndexOffsetShaderRegister) \
|
X(uint32_t, firstIndexOffsetShaderRegister) \
|
||||||
X(uint32_t, firstIndexOffsetRegisterSpace) \
|
X(uint32_t, firstIndexOffsetRegisterSpace) \
|
||||||
X(bool, usesNumWorkgroups) \
|
X(bool, usesNumWorkgroups) \
|
||||||
X(uint32_t, numWorkgroupsShaderRegister) \
|
X(uint32_t, numWorkgroupsShaderRegister) \
|
||||||
X(uint32_t, numWorkgroupsRegisterSpace) \
|
X(uint32_t, numWorkgroupsRegisterSpace) \
|
||||||
X(DefineStrings, defineStrings) \
|
X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
|
||||||
X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
|
X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform) \
|
||||||
X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform) \
|
X(tint::transform::BindingRemapper::BindingPoints, remappedBindingPoints) \
|
||||||
X(tint::transform::BindingRemapper::BindingPoints, remappedBindingPoints) \
|
X(tint::transform::BindingRemapper::AccessControls, remappedAccessControls) \
|
||||||
X(tint::transform::BindingRemapper::AccessControls, remappedAccessControls) \
|
X(std::optional<tint::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
|
||||||
X(bool, disableSymbolRenaming) \
|
X(LimitsForCompilationRequest, limits) \
|
||||||
X(bool, isRobustnessEnabled) \
|
X(bool, disableSymbolRenaming) \
|
||||||
X(bool, disableWorkgroupInit) \
|
X(bool, isRobustnessEnabled) \
|
||||||
|
X(bool, disableWorkgroupInit) \
|
||||||
X(bool, dumpShaders)
|
X(bool, dumpShaders)
|
||||||
|
|
||||||
#define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \
|
#define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \
|
||||||
|
@ -170,8 +105,7 @@ enum class Compiler { FXC, DXC };
|
||||||
X(std::string_view, fxcShaderProfile) \
|
X(std::string_view, fxcShaderProfile) \
|
||||||
X(pD3DCompile, d3dCompile) \
|
X(pD3DCompile, d3dCompile) \
|
||||||
X(IDxcLibrary*, dxcLibrary) \
|
X(IDxcLibrary*, dxcLibrary) \
|
||||||
X(IDxcCompiler*, dxcCompiler) \
|
X(IDxcCompiler*, dxcCompiler)
|
||||||
X(DefineStrings, defineStrings)
|
|
||||||
|
|
||||||
DAWN_SERIALIZABLE(struct, HlslCompilationRequest, HLSL_COMPILATION_REQUEST_MEMBERS){};
|
DAWN_SERIALIZABLE(struct, HlslCompilationRequest, HLSL_COMPILATION_REQUEST_MEMBERS){};
|
||||||
#undef HLSL_COMPILATION_REQUEST_MEMBERS
|
#undef HLSL_COMPILATION_REQUEST_MEMBERS
|
||||||
|
@ -255,25 +189,11 @@ ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(const D3DBytecodeCompilationReq
|
||||||
std::vector<const wchar_t*> arguments =
|
std::vector<const wchar_t*> arguments =
|
||||||
GetDXCArguments(r.compileFlags, r.hasShaderFloat16Feature);
|
GetDXCArguments(r.compileFlags, r.hasShaderFloat16Feature);
|
||||||
|
|
||||||
// Build defines for overridable constants
|
|
||||||
std::vector<std::pair<std::wstring, std::wstring>> defineStrings;
|
|
||||||
defineStrings.reserve(r.defineStrings.size());
|
|
||||||
for (const auto& [name, value] : r.defineStrings) {
|
|
||||||
defineStrings.emplace_back(UTF8ToWStr(name.c_str()), UTF8ToWStr(value.c_str()));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<DxcDefine> dxcDefines;
|
|
||||||
dxcDefines.reserve(defineStrings.size());
|
|
||||||
for (const auto& [name, value] : defineStrings) {
|
|
||||||
dxcDefines.push_back({name.c_str(), value.c_str()});
|
|
||||||
}
|
|
||||||
|
|
||||||
ComPtr<IDxcOperationResult> result;
|
ComPtr<IDxcOperationResult> result;
|
||||||
DAWN_TRY(CheckHRESULT(
|
DAWN_TRY(CheckHRESULT(r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
|
||||||
r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
|
r.dxcShaderProfile.data(), arguments.data(),
|
||||||
r.dxcShaderProfile.data(), arguments.data(), arguments.size(),
|
arguments.size(), nullptr, 0, nullptr, &result),
|
||||||
dxcDefines.data(), dxcDefines.size(), nullptr, &result),
|
"DXC compile"));
|
||||||
"DXC compile"));
|
|
||||||
|
|
||||||
HRESULT hr;
|
HRESULT hr;
|
||||||
DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
|
DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
|
||||||
|
@ -359,20 +279,7 @@ ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const D3DBytecodeCompilationReq
|
||||||
ComPtr<ID3DBlob> compiledShader;
|
ComPtr<ID3DBlob> compiledShader;
|
||||||
ComPtr<ID3DBlob> errors;
|
ComPtr<ID3DBlob> errors;
|
||||||
|
|
||||||
// Build defines for overridable constants
|
DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
|
||||||
const D3D_SHADER_MACRO* pDefines = nullptr;
|
|
||||||
std::vector<D3D_SHADER_MACRO> fxcDefines;
|
|
||||||
if (r.defineStrings.size() > 0) {
|
|
||||||
fxcDefines.reserve(r.defineStrings.size() + 1);
|
|
||||||
for (const auto& [name, value] : r.defineStrings) {
|
|
||||||
fxcDefines.push_back({name.c_str(), value.c_str()});
|
|
||||||
}
|
|
||||||
// d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array
|
|
||||||
fxcDefines.push_back({nullptr, nullptr});
|
|
||||||
pDefines = fxcDefines.data();
|
|
||||||
}
|
|
||||||
|
|
||||||
DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, pDefines,
|
|
||||||
nullptr, entryPointName.c_str(), r.fxcShaderProfile.data(),
|
nullptr, entryPointName.c_str(), r.fxcShaderProfile.data(),
|
||||||
r.compileFlags, 0, &compiledShader, &errors)),
|
r.compileFlags, 0, &compiledShader, &errors)),
|
||||||
"D3D compile failed with: %s", static_cast<char*>(errors->GetBufferPointer()));
|
"D3D compile failed with: %s", static_cast<char*>(errors->GetBufferPointer()));
|
||||||
|
@ -420,6 +327,14 @@ ResultOrError<std::string> TranslateToHLSL(
|
||||||
tint::transform::Renamer::Target::kHlslKeywords);
|
tint::transform::Renamer::Target::kHlslKeywords);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (r.substituteOverrideConfig) {
|
||||||
|
// This needs to run after SingleEntryPoint transform to get rid of overrides not used for
|
||||||
|
// the current entry point.
|
||||||
|
transformManager.Add<tint::transform::SubstituteOverride>();
|
||||||
|
transformInputs.Add<tint::transform::SubstituteOverride::Config>(
|
||||||
|
std::move(r.substituteOverrideConfig).value());
|
||||||
|
}
|
||||||
|
|
||||||
// D3D12 registers like `t3` and `c3` have the same bindingOffset number in
|
// D3D12 registers like `t3` and `c3` have the same bindingOffset number in
|
||||||
// the remapping but should not be considered a collision because they have
|
// the remapping but should not be considered a collision because they have
|
||||||
// different types.
|
// different types.
|
||||||
|
@ -450,6 +365,13 @@ ResultOrError<std::string> TranslateToHLSL(
|
||||||
return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
|
return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (r.stage == SingleShaderStage::Compute) {
|
||||||
|
// Validate workgroup size after program runs transforms.
|
||||||
|
Extent3D _;
|
||||||
|
DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize(
|
||||||
|
transformedProgram, remappedEntryPointName->data(), r.limits));
|
||||||
|
}
|
||||||
|
|
||||||
if (r.stage == SingleShaderStage::Vertex) {
|
if (r.stage == SingleShaderStage::Vertex) {
|
||||||
if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
|
if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
|
||||||
*usesVertexOrInstanceIndex = data->has_vertex_or_instance_index;
|
*usesVertexOrInstanceIndex = data->has_vertex_or_instance_index;
|
||||||
|
@ -555,8 +477,7 @@ ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& pro
|
||||||
|
|
||||||
req.bytecode.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
|
req.bytecode.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
|
||||||
req.bytecode.compileFlags = compileFlags;
|
req.bytecode.compileFlags = compileFlags;
|
||||||
req.bytecode.defineStrings =
|
|
||||||
GetOverridableConstantsDefines(programmableStage.constants, entryPoint.overrides);
|
|
||||||
if (device->IsToggleEnabled(Toggle::UseDXC)) {
|
if (device->IsToggleEnabled(Toggle::UseDXC)) {
|
||||||
req.bytecode.compiler = Compiler::DXC;
|
req.bytecode.compiler = Compiler::DXC;
|
||||||
req.bytecode.dxcLibrary = device->GetDxcLibrary().Get();
|
req.bytecode.dxcLibrary = device->GetDxcLibrary().Get();
|
||||||
|
@ -645,6 +566,11 @@ ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& pro
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::optional<tint::transform::SubstituteOverride::Config> substituteOverrideConfig;
|
||||||
|
if (!programmableStage.metadata->overrides.empty()) {
|
||||||
|
substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
|
||||||
|
}
|
||||||
|
|
||||||
req.hlsl.inputProgram = GetTintProgram();
|
req.hlsl.inputProgram = GetTintProgram();
|
||||||
req.hlsl.entryPointName = programmableStage.entryPoint.c_str();
|
req.hlsl.entryPointName = programmableStage.entryPoint.c_str();
|
||||||
req.hlsl.stage = stage;
|
req.hlsl.stage = stage;
|
||||||
|
@ -657,6 +583,10 @@ ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& pro
|
||||||
req.hlsl.remappedAccessControls = std::move(remappedAccessControls);
|
req.hlsl.remappedAccessControls = std::move(remappedAccessControls);
|
||||||
req.hlsl.newBindingsMap = BuildExternalTextureTransformBindings(layout);
|
req.hlsl.newBindingsMap = BuildExternalTextureTransformBindings(layout);
|
||||||
req.hlsl.arrayLengthFromUniform = std::move(arrayLengthFromUniform);
|
req.hlsl.arrayLengthFromUniform = std::move(arrayLengthFromUniform);
|
||||||
|
req.hlsl.substituteOverrideConfig = std::move(substituteOverrideConfig);
|
||||||
|
|
||||||
|
const CombinedLimits& limits = device->GetLimits();
|
||||||
|
req.hlsl.limits = LimitsForCompilationRequest::Create(limits.v1);
|
||||||
|
|
||||||
CacheResult<CompiledShader> compiledShader;
|
CacheResult<CompiledShader> compiledShader;
|
||||||
DAWN_TRY_LOAD_OR_RUN(compiledShader, device, std::move(req), CompiledShader::FromBlob,
|
DAWN_TRY_LOAD_OR_RUN(compiledShader, device, std::move(req), CompiledShader::FromBlob,
|
||||||
|
|
|
@ -40,8 +40,9 @@ MaybeError ComputePipeline::Initialize() {
|
||||||
const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
|
const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
|
||||||
ShaderModule::MetalFunctionData computeData;
|
ShaderModule::MetalFunctionData computeData;
|
||||||
|
|
||||||
DAWN_TRY(CreateMTLFunction(computeStage, SingleShaderStage::Compute, ToBackend(GetLayout()),
|
DAWN_TRY(ToBackend(computeStage.module.Get())
|
||||||
&computeData));
|
->CreateFunction(SingleShaderStage::Compute, computeStage, ToBackend(GetLayout()),
|
||||||
|
&computeData));
|
||||||
|
|
||||||
NSError* error = nullptr;
|
NSError* error = nullptr;
|
||||||
mMtlComputePipelineState.Acquire(
|
mMtlComputePipelineState.Acquire(
|
||||||
|
@ -53,8 +54,7 @@ MaybeError ComputePipeline::Initialize() {
|
||||||
ASSERT(mMtlComputePipelineState != nil);
|
ASSERT(mMtlComputePipelineState != nil);
|
||||||
|
|
||||||
// Copy over the local workgroup size as it is passed to dispatch explicitly in Metal
|
// Copy over the local workgroup size as it is passed to dispatch explicitly in Metal
|
||||||
Origin3D localSize = GetStage(SingleShaderStage::Compute).metadata->localWorkgroupSize;
|
mLocalWorkgroupSize = computeData.localWorkgroupSize;
|
||||||
mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z);
|
|
||||||
|
|
||||||
mRequiresStorageBufferLength = computeData.needsStorageBufferLength;
|
mRequiresStorageBufferLength = computeData.needsStorageBufferLength;
|
||||||
mWorkgroupAllocations = std::move(computeData.workgroupAllocations);
|
mWorkgroupAllocations = std::move(computeData.workgroupAllocations);
|
||||||
|
|
|
@ -340,8 +340,9 @@ MaybeError RenderPipeline::Initialize() {
|
||||||
const PerStage<ProgrammableStage>& allStages = GetAllStages();
|
const PerStage<ProgrammableStage>& allStages = GetAllStages();
|
||||||
const ProgrammableStage& vertexStage = allStages[wgpu::ShaderStage::Vertex];
|
const ProgrammableStage& vertexStage = allStages[wgpu::ShaderStage::Vertex];
|
||||||
ShaderModule::MetalFunctionData vertexData;
|
ShaderModule::MetalFunctionData vertexData;
|
||||||
DAWN_TRY(CreateMTLFunction(vertexStage, SingleShaderStage::Vertex, ToBackend(GetLayout()),
|
DAWN_TRY(ToBackend(vertexStage.module.Get())
|
||||||
&vertexData, 0xFFFFFFFF, this));
|
->CreateFunction(SingleShaderStage::Vertex, vertexStage, ToBackend(GetLayout()),
|
||||||
|
&vertexData, 0xFFFFFFFF, this));
|
||||||
|
|
||||||
descriptorMTL.vertexFunction = vertexData.function.Get();
|
descriptorMTL.vertexFunction = vertexData.function.Get();
|
||||||
if (vertexData.needsStorageBufferLength) {
|
if (vertexData.needsStorageBufferLength) {
|
||||||
|
@ -351,8 +352,9 @@ MaybeError RenderPipeline::Initialize() {
|
||||||
if (GetStageMask() & wgpu::ShaderStage::Fragment) {
|
if (GetStageMask() & wgpu::ShaderStage::Fragment) {
|
||||||
const ProgrammableStage& fragmentStage = allStages[wgpu::ShaderStage::Fragment];
|
const ProgrammableStage& fragmentStage = allStages[wgpu::ShaderStage::Fragment];
|
||||||
ShaderModule::MetalFunctionData fragmentData;
|
ShaderModule::MetalFunctionData fragmentData;
|
||||||
DAWN_TRY(CreateMTLFunction(fragmentStage, SingleShaderStage::Fragment,
|
DAWN_TRY(ToBackend(fragmentStage.module.Get())
|
||||||
ToBackend(GetLayout()), &fragmentData, GetSampleMask()));
|
->CreateFunction(SingleShaderStage::Fragment, fragmentStage,
|
||||||
|
ToBackend(GetLayout()), &fragmentData, GetSampleMask()));
|
||||||
|
|
||||||
descriptorMTL.fragmentFunction = fragmentData.function.Get();
|
descriptorMTL.fragmentFunction = fragmentData.function.Get();
|
||||||
if (fragmentData.needsStorageBufferLength) {
|
if (fragmentData.needsStorageBufferLength) {
|
||||||
|
|
|
@ -25,6 +25,10 @@
|
||||||
|
|
||||||
#import <Metal/Metal.h>
|
#import <Metal/Metal.h>
|
||||||
|
|
||||||
|
namespace dawn::native {
|
||||||
|
struct ProgrammableStage;
|
||||||
|
}
|
||||||
|
|
||||||
namespace dawn::native::metal {
|
namespace dawn::native::metal {
|
||||||
|
|
||||||
class Device;
|
class Device;
|
||||||
|
@ -42,15 +46,13 @@ class ShaderModule final : public ShaderModuleBase {
|
||||||
NSPRef<id<MTLFunction>> function;
|
NSPRef<id<MTLFunction>> function;
|
||||||
bool needsStorageBufferLength;
|
bool needsStorageBufferLength;
|
||||||
std::vector<uint32_t> workgroupAllocations;
|
std::vector<uint32_t> workgroupAllocations;
|
||||||
|
MTLSize localWorkgroupSize;
|
||||||
};
|
};
|
||||||
|
|
||||||
// MTLFunctionConstantValues needs @available tag to compile
|
MaybeError CreateFunction(SingleShaderStage stage,
|
||||||
// Use id (like void*) in function signature as workaround and do static cast inside
|
const ProgrammableStage& programmableStage,
|
||||||
MaybeError CreateFunction(const char* entryPointName,
|
|
||||||
SingleShaderStage stage,
|
|
||||||
const PipelineLayout* layout,
|
const PipelineLayout* layout,
|
||||||
MetalFunctionData* out,
|
MetalFunctionData* out,
|
||||||
id constantValues = nil,
|
|
||||||
uint32_t sampleMask = 0xFFFFFFFF,
|
uint32_t sampleMask = 0xFFFFFFFF,
|
||||||
const RenderPipeline* renderPipeline = nullptr);
|
const RenderPipeline* renderPipeline = nullptr);
|
||||||
|
|
||||||
|
|
|
@ -35,17 +35,20 @@ namespace {
|
||||||
|
|
||||||
using OptionalVertexPullingTransformConfig = std::optional<tint::transform::VertexPulling::Config>;
|
using OptionalVertexPullingTransformConfig = std::optional<tint::transform::VertexPulling::Config>;
|
||||||
|
|
||||||
#define MSL_COMPILATION_REQUEST_MEMBERS(X) \
|
#define MSL_COMPILATION_REQUEST_MEMBERS(X) \
|
||||||
X(const tint::Program*, inputProgram) \
|
X(SingleShaderStage, stage) \
|
||||||
X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \
|
X(const tint::Program*, inputProgram) \
|
||||||
X(tint::transform::MultiplanarExternalTexture::BindingsMap, externalTextureBindings) \
|
X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \
|
||||||
X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig) \
|
X(tint::transform::MultiplanarExternalTexture::BindingsMap, externalTextureBindings) \
|
||||||
X(std::string, entryPointName) \
|
X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig) \
|
||||||
X(uint32_t, sampleMask) \
|
X(std::optional<tint::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
|
||||||
X(bool, emitVertexPointSize) \
|
X(LimitsForCompilationRequest, limits) \
|
||||||
X(bool, isRobustnessEnabled) \
|
X(std::string, entryPointName) \
|
||||||
X(bool, disableSymbolRenaming) \
|
X(uint32_t, sampleMask) \
|
||||||
X(bool, disableWorkgroupInit) \
|
X(bool, emitVertexPointSize) \
|
||||||
|
X(bool, isRobustnessEnabled) \
|
||||||
|
X(bool, disableSymbolRenaming) \
|
||||||
|
X(bool, disableWorkgroupInit) \
|
||||||
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform)
|
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform)
|
||||||
|
|
||||||
DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS);
|
DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS);
|
||||||
|
@ -53,12 +56,13 @@ DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS);
|
||||||
|
|
||||||
using WorkgroupAllocations = std::vector<uint32_t>;
|
using WorkgroupAllocations = std::vector<uint32_t>;
|
||||||
|
|
||||||
#define MSL_COMPILATION_MEMBERS(X) \
|
#define MSL_COMPILATION_MEMBERS(X) \
|
||||||
X(std::string, msl) \
|
X(std::string, msl) \
|
||||||
X(std::string, remappedEntryPointName) \
|
X(std::string, remappedEntryPointName) \
|
||||||
X(bool, needsStorageBufferLength) \
|
X(bool, needsStorageBufferLength) \
|
||||||
X(bool, hasInvariantAttribute) \
|
X(bool, hasInvariantAttribute) \
|
||||||
X(WorkgroupAllocations, workgroupAllocations)
|
X(WorkgroupAllocations, workgroupAllocations) \
|
||||||
|
X(Extent3D, localWorkgroupSize)
|
||||||
|
|
||||||
DAWN_SERIALIZABLE(struct, MslCompilation, MSL_COMPILATION_MEMBERS){};
|
DAWN_SERIALIZABLE(struct, MslCompilation, MSL_COMPILATION_MEMBERS){};
|
||||||
#undef MSL_COMPILATION_MEMBERS
|
#undef MSL_COMPILATION_MEMBERS
|
||||||
|
@ -92,13 +96,14 @@ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(
|
||||||
const tint::Program* inputProgram,
|
DeviceBase* device,
|
||||||
const char* entryPointName,
|
const ProgrammableStage& programmableStage,
|
||||||
SingleShaderStage stage,
|
SingleShaderStage stage,
|
||||||
const PipelineLayout* layout,
|
const PipelineLayout* layout,
|
||||||
uint32_t sampleMask,
|
ShaderModule::MetalFunctionData* out,
|
||||||
const RenderPipeline* renderPipeline) {
|
uint32_t sampleMask,
|
||||||
|
const RenderPipeline* renderPipeline) {
|
||||||
ScopedTintICEHandler scopedICEHandler(device);
|
ScopedTintICEHandler scopedICEHandler(device);
|
||||||
|
|
||||||
std::ostringstream errorStream;
|
std::ostringstream errorStream;
|
||||||
|
@ -137,7 +142,7 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
||||||
if (stage == SingleShaderStage::Vertex &&
|
if (stage == SingleShaderStage::Vertex &&
|
||||||
device->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
|
device->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
|
||||||
vertexPullingTransformConfig = BuildVertexPullingTransformConfig(
|
vertexPullingTransformConfig = BuildVertexPullingTransformConfig(
|
||||||
*renderPipeline, entryPointName, kPullingBufferBindingSet);
|
*renderPipeline, programmableStage.entryPoint.c_str(), kPullingBufferBindingSet);
|
||||||
|
|
||||||
for (VertexBufferSlot slot : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
|
for (VertexBufferSlot slot : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
|
||||||
uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot);
|
uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot);
|
||||||
|
@ -152,12 +157,19 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::optional<tint::transform::SubstituteOverride::Config> substituteOverrideConfig;
|
||||||
|
if (!programmableStage.metadata->overrides.empty()) {
|
||||||
|
substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
|
||||||
|
}
|
||||||
|
|
||||||
MslCompilationRequest req = {};
|
MslCompilationRequest req = {};
|
||||||
req.inputProgram = inputProgram;
|
req.stage = stage;
|
||||||
|
req.inputProgram = programmableStage.module->GetTintProgram();
|
||||||
req.bindingPoints = std::move(bindingPoints);
|
req.bindingPoints = std::move(bindingPoints);
|
||||||
req.externalTextureBindings = std::move(externalTextureBindings);
|
req.externalTextureBindings = std::move(externalTextureBindings);
|
||||||
req.vertexPullingTransformConfig = std::move(vertexPullingTransformConfig);
|
req.vertexPullingTransformConfig = std::move(vertexPullingTransformConfig);
|
||||||
req.entryPointName = entryPointName;
|
req.substituteOverrideConfig = std::move(substituteOverrideConfig);
|
||||||
|
req.entryPointName = programmableStage.entryPoint.c_str();
|
||||||
req.sampleMask = sampleMask;
|
req.sampleMask = sampleMask;
|
||||||
req.emitVertexPointSize =
|
req.emitVertexPointSize =
|
||||||
stage == SingleShaderStage::Vertex &&
|
stage == SingleShaderStage::Vertex &&
|
||||||
|
@ -166,6 +178,9 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
||||||
req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
|
req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
|
||||||
req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform());
|
req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform());
|
||||||
|
|
||||||
|
const CombinedLimits& limits = device->GetLimits();
|
||||||
|
req.limits = LimitsForCompilationRequest::Create(limits.v1);
|
||||||
|
|
||||||
CacheResult<MslCompilation> mslCompilation;
|
CacheResult<MslCompilation> mslCompilation;
|
||||||
DAWN_TRY_LOAD_OR_RUN(
|
DAWN_TRY_LOAD_OR_RUN(
|
||||||
mslCompilation, device, std::move(req), MslCompilation::FromBlob,
|
mslCompilation, device, std::move(req), MslCompilation::FromBlob,
|
||||||
|
@ -190,6 +205,14 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
||||||
std::move(r.vertexPullingTransformConfig).value());
|
std::move(r.vertexPullingTransformConfig).value());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (r.substituteOverrideConfig) {
|
||||||
|
// This needs to run after SingleEntryPoint transform to get rid of overrides not
|
||||||
|
// used for the current entry point.
|
||||||
|
transformManager.Add<tint::transform::SubstituteOverride>();
|
||||||
|
transformInputs.Add<tint::transform::SubstituteOverride::Config>(
|
||||||
|
std::move(r.substituteOverrideConfig).value());
|
||||||
|
}
|
||||||
|
|
||||||
if (r.isRobustnessEnabled) {
|
if (r.isRobustnessEnabled) {
|
||||||
transformManager.Add<tint::transform::Robustness>();
|
transformManager.Add<tint::transform::Robustness>();
|
||||||
}
|
}
|
||||||
|
@ -230,6 +253,13 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
||||||
return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
|
return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Extent3D localSize{0, 0, 0};
|
||||||
|
if (r.stage == SingleShaderStage::Compute) {
|
||||||
|
// Validate workgroup size after program runs transforms.
|
||||||
|
DAWN_TRY_ASSIGN(localSize, ValidateComputeStageWorkgroupSize(
|
||||||
|
program, remappedEntryPointName.data(), r.limits));
|
||||||
|
}
|
||||||
|
|
||||||
tint::writer::msl::Options options;
|
tint::writer::msl::Options options;
|
||||||
options.buffer_size_ubo_index = kBufferLengthBufferSlot;
|
options.buffer_size_ubo_index = kBufferLengthBufferSlot;
|
||||||
options.fixed_sample_mask = r.sampleMask;
|
options.fixed_sample_mask = r.sampleMask;
|
||||||
|
@ -258,6 +288,7 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
||||||
result.needs_storage_buffer_sizes,
|
result.needs_storage_buffer_sizes,
|
||||||
result.has_invariant_attribute,
|
result.has_invariant_attribute,
|
||||||
std::move(workgroupAllocations),
|
std::move(workgroupAllocations),
|
||||||
|
localSize,
|
||||||
}};
|
}};
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -272,11 +303,10 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
MaybeError ShaderModule::CreateFunction(const char* entryPointName,
|
MaybeError ShaderModule::CreateFunction(SingleShaderStage stage,
|
||||||
SingleShaderStage stage,
|
const ProgrammableStage& programmableStage,
|
||||||
const PipelineLayout* layout,
|
const PipelineLayout* layout,
|
||||||
ShaderModule::MetalFunctionData* out,
|
ShaderModule::MetalFunctionData* out,
|
||||||
id constantValuesPointer,
|
|
||||||
uint32_t sampleMask,
|
uint32_t sampleMask,
|
||||||
const RenderPipeline* renderPipeline) {
|
const RenderPipeline* renderPipeline) {
|
||||||
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleMTL::CreateFunction");
|
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleMTL::CreateFunction");
|
||||||
|
@ -284,16 +314,21 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName,
|
||||||
ASSERT(!IsError());
|
ASSERT(!IsError());
|
||||||
ASSERT(out);
|
ASSERT(out);
|
||||||
|
|
||||||
|
const char* entryPointName = programmableStage.entryPoint.c_str();
|
||||||
|
|
||||||
// Vertex stages must specify a renderPipeline
|
// Vertex stages must specify a renderPipeline
|
||||||
if (stage == SingleShaderStage::Vertex) {
|
if (stage == SingleShaderStage::Vertex) {
|
||||||
ASSERT(renderPipeline != nullptr);
|
ASSERT(renderPipeline != nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
CacheResult<MslCompilation> mslCompilation;
|
CacheResult<MslCompilation> mslCompilation;
|
||||||
DAWN_TRY_ASSIGN(mslCompilation, TranslateToMSL(GetDevice(), GetTintProgram(), entryPointName,
|
DAWN_TRY_ASSIGN(mslCompilation, TranslateToMSL(GetDevice(), programmableStage, stage, layout,
|
||||||
stage, layout, sampleMask, renderPipeline));
|
out, sampleMask, renderPipeline));
|
||||||
out->needsStorageBufferLength = mslCompilation->needsStorageBufferLength;
|
out->needsStorageBufferLength = mslCompilation->needsStorageBufferLength;
|
||||||
out->workgroupAllocations = std::move(mslCompilation->workgroupAllocations);
|
out->workgroupAllocations = std::move(mslCompilation->workgroupAllocations);
|
||||||
|
out->localWorkgroupSize = MTLSizeMake(mslCompilation->localWorkgroupSize.width,
|
||||||
|
mslCompilation->localWorkgroupSize.height,
|
||||||
|
mslCompilation->localWorkgroupSize.depthOrArrayLayers);
|
||||||
|
|
||||||
NSRef<NSString> mslSource =
|
NSRef<NSString> mslSource =
|
||||||
AcquireNSRef([[NSString alloc] initWithUTF8String:mslCompilation->msl.c_str()]);
|
AcquireNSRef([[NSString alloc] initWithUTF8String:mslCompilation->msl.c_str()]);
|
||||||
|
@ -327,25 +362,7 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName,
|
||||||
|
|
||||||
{
|
{
|
||||||
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "MTLLibrary::newFunctionWithName");
|
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "MTLLibrary::newFunctionWithName");
|
||||||
if (constantValuesPointer != nil) {
|
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
|
||||||
if (@available(macOS 10.12, *)) {
|
|
||||||
MTLFunctionConstantValues* constantValues = constantValuesPointer;
|
|
||||||
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()
|
|
||||||
constantValues:constantValues
|
|
||||||
error:&error]);
|
|
||||||
if (error != nullptr) {
|
|
||||||
if (error.code != MTLLibraryErrorCompileWarning) {
|
|
||||||
return DAWN_VALIDATION_ERROR("Function compile error: %s",
|
|
||||||
[error.localizedDescription UTF8String]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ASSERT(out->function != nil);
|
|
||||||
} else {
|
|
||||||
UNREACHABLE();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (BlobCache* cache = GetDevice()->GetBlobCache()) {
|
if (BlobCache* cache = GetDevice()->GetBlobCache()) {
|
||||||
|
|
|
@ -68,15 +68,6 @@ void EnsureDestinationTextureInitialized(CommandRecordingContext* commandContext
|
||||||
|
|
||||||
MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect);
|
MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect);
|
||||||
|
|
||||||
// Helper function to create function with constant values wrapped in
|
|
||||||
// if available branch
|
|
||||||
MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
|
|
||||||
SingleShaderStage singleShaderStage,
|
|
||||||
PipelineLayout* pipelineLayout,
|
|
||||||
ShaderModule::MetalFunctionData* functionData,
|
|
||||||
uint32_t sampleMask = 0xFFFFFFFF,
|
|
||||||
const RenderPipeline* renderPipeline = nullptr);
|
|
||||||
|
|
||||||
// Allow use MTLStoreActionStoreAndMultismapleResolve because the logic in the backend is
|
// Allow use MTLStoreActionStoreAndMultismapleResolve because the logic in the backend is
|
||||||
// first to compute what the "best" Metal render pass descriptor is, then fix it up if we
|
// first to compute what the "best" Metal render pass descriptor is, then fix it up if we
|
||||||
// are not on macOS 10.12 (i.e. the EmulateStoreAndMSAAResolve toggle is on).
|
// are not on macOS 10.12 (i.e. the EmulateStoreAndMSAAResolve toggle is on).
|
||||||
|
|
|
@ -323,104 +323,6 @@ MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect) {
|
||||||
return MTLBlitOptionNone;
|
return MTLBlitOptionNone;
|
||||||
}
|
}
|
||||||
|
|
||||||
MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
|
|
||||||
SingleShaderStage singleShaderStage,
|
|
||||||
PipelineLayout* pipelineLayout,
|
|
||||||
ShaderModule::MetalFunctionData* functionData,
|
|
||||||
uint32_t sampleMask,
|
|
||||||
const RenderPipeline* renderPipeline) {
|
|
||||||
ShaderModule* shaderModule = ToBackend(programmableStage.module.Get());
|
|
||||||
const char* shaderEntryPoint = programmableStage.entryPoint.c_str();
|
|
||||||
const auto& entryPointMetadata = programmableStage.module->GetEntryPoint(shaderEntryPoint);
|
|
||||||
if (entryPointMetadata.overrides.size() == 0) {
|
|
||||||
DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout,
|
|
||||||
functionData, nil, sampleMask, renderPipeline));
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (@available(macOS 10.12, *)) {
|
|
||||||
// MTLFunctionConstantValues can only be created within the if available branch
|
|
||||||
NSRef<MTLFunctionConstantValues> constantValues =
|
|
||||||
AcquireNSRef([MTLFunctionConstantValues new]);
|
|
||||||
|
|
||||||
std::unordered_set<std::string> overriddenConstants;
|
|
||||||
|
|
||||||
auto switchType = [&](EntryPointMetadata::Override::Type dawnType,
|
|
||||||
MTLDataType* type, OverrideScalar* entry,
|
|
||||||
double value = 0) {
|
|
||||||
switch (dawnType) {
|
|
||||||
case EntryPointMetadata::Override::Type::Boolean:
|
|
||||||
*type = MTLDataTypeBool;
|
|
||||||
if (entry) {
|
|
||||||
entry->b = static_cast<int32_t>(value);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case EntryPointMetadata::Override::Type::Float32:
|
|
||||||
*type = MTLDataTypeFloat;
|
|
||||||
if (entry) {
|
|
||||||
entry->f32 = static_cast<float>(value);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case EntryPointMetadata::Override::Type::Int32:
|
|
||||||
*type = MTLDataTypeInt;
|
|
||||||
if (entry) {
|
|
||||||
entry->i32 = static_cast<int32_t>(value);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case EntryPointMetadata::Override::Type::Uint32:
|
|
||||||
*type = MTLDataTypeUInt;
|
|
||||||
if (entry) {
|
|
||||||
entry->u32 = static_cast<uint32_t>(value);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
UNREACHABLE();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
for (const auto& [name, value] : programmableStage.constants) {
|
|
||||||
overriddenConstants.insert(name);
|
|
||||||
|
|
||||||
// This is already validated so `name` must exist
|
|
||||||
const auto& moduleConstant = entryPointMetadata.overrides.at(name);
|
|
||||||
|
|
||||||
MTLDataType type;
|
|
||||||
OverrideScalar entry{};
|
|
||||||
|
|
||||||
switchType(moduleConstant.type, &type, &entry, value);
|
|
||||||
|
|
||||||
[constantValues.Get() setConstantValue:&entry type:type atIndex:moduleConstant.id];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set shader initialized default values because MSL function_constant
|
|
||||||
// has no default value
|
|
||||||
for (const std::string& name : entryPointMetadata.initializedOverrides) {
|
|
||||||
if (overriddenConstants.count(name) != 0) {
|
|
||||||
// This constant already has overridden value
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Must exist because it is validated
|
|
||||||
const auto& moduleConstant = entryPointMetadata.overrides.at(name);
|
|
||||||
ASSERT(moduleConstant.isInitialized);
|
|
||||||
MTLDataType type;
|
|
||||||
|
|
||||||
switchType(moduleConstant.type, &type, nullptr);
|
|
||||||
|
|
||||||
[constantValues.Get() setConstantValue:&moduleConstant.defaultValue
|
|
||||||
type:type
|
|
||||||
atIndex:moduleConstant.id];
|
|
||||||
}
|
|
||||||
|
|
||||||
DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout,
|
|
||||||
functionData, constantValues.Get(), sampleMask,
|
|
||||||
renderPipeline));
|
|
||||||
} else {
|
|
||||||
UNREACHABLE();
|
|
||||||
}
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
MaybeError EncodeMetalRenderPass(Device* device,
|
MaybeError EncodeMetalRenderPass(Device* device,
|
||||||
CommandRecordingContext* commandContext,
|
CommandRecordingContext* commandContext,
|
||||||
MTLRenderPassDescriptor* mtlRenderPass,
|
MTLRenderPassDescriptor* mtlRenderPass,
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "dawn/native/Limits.h"
|
||||||
|
|
||||||
namespace dawn::native::stream {
|
namespace dawn::native::stream {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
|
|
@ -61,16 +61,12 @@ MaybeError ComputePipeline::Initialize() {
|
||||||
|
|
||||||
ShaderModule::ModuleAndSpirv moduleAndSpirv;
|
ShaderModule::ModuleAndSpirv moduleAndSpirv;
|
||||||
DAWN_TRY_ASSIGN(moduleAndSpirv,
|
DAWN_TRY_ASSIGN(moduleAndSpirv,
|
||||||
module->GetHandleAndSpirv(computeStage.entryPoint.c_str(), layout));
|
module->GetHandleAndSpirv(SingleShaderStage::Compute, computeStage, layout));
|
||||||
|
|
||||||
createInfo.stage.module = moduleAndSpirv.module;
|
createInfo.stage.module = moduleAndSpirv.module;
|
||||||
createInfo.stage.pName = computeStage.entryPoint.c_str();
|
createInfo.stage.pName = computeStage.entryPoint.c_str();
|
||||||
|
|
||||||
std::vector<OverrideScalar> specializationDataEntries;
|
createInfo.stage.pSpecializationInfo = nullptr;
|
||||||
std::vector<VkSpecializationMapEntry> specializationMapEntries;
|
|
||||||
VkSpecializationInfo specializationInfo{};
|
|
||||||
createInfo.stage.pSpecializationInfo = GetVkSpecializationInfo(
|
|
||||||
computeStage, &specializationInfo, &specializationDataEntries, &specializationMapEntries);
|
|
||||||
|
|
||||||
PNextChainBuilder stageExtChain(&createInfo.stage);
|
PNextChainBuilder stageExtChain(&createInfo.stage);
|
||||||
|
|
||||||
|
|
|
@ -341,9 +341,6 @@ MaybeError RenderPipeline::Initialize() {
|
||||||
|
|
||||||
// There are at most 2 shader stages in render pipeline, i.e. vertex and fragment
|
// There are at most 2 shader stages in render pipeline, i.e. vertex and fragment
|
||||||
std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
|
std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
|
||||||
std::array<std::vector<OverrideScalar>, 2> specializationDataEntriesPerStages;
|
|
||||||
std::array<std::vector<VkSpecializationMapEntry>, 2> specializationMapEntriesPerStages;
|
|
||||||
std::array<VkSpecializationInfo, 2> specializationInfoPerStages;
|
|
||||||
uint32_t stageCount = 0;
|
uint32_t stageCount = 0;
|
||||||
|
|
||||||
for (auto stage : IterateStages(this->GetStageMask())) {
|
for (auto stage : IterateStages(this->GetStageMask())) {
|
||||||
|
@ -354,7 +351,7 @@ MaybeError RenderPipeline::Initialize() {
|
||||||
|
|
||||||
ShaderModule::ModuleAndSpirv moduleAndSpirv;
|
ShaderModule::ModuleAndSpirv moduleAndSpirv;
|
||||||
DAWN_TRY_ASSIGN(moduleAndSpirv,
|
DAWN_TRY_ASSIGN(moduleAndSpirv,
|
||||||
module->GetHandleAndSpirv(programmableStage.entryPoint.c_str(), layout));
|
module->GetHandleAndSpirv(stage, programmableStage, layout));
|
||||||
|
|
||||||
shaderStage.module = moduleAndSpirv.module;
|
shaderStage.module = moduleAndSpirv.module;
|
||||||
shaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
|
shaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
|
||||||
|
@ -379,11 +376,6 @@ MaybeError RenderPipeline::Initialize() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
shaderStage.pSpecializationInfo =
|
|
||||||
GetVkSpecializationInfo(programmableStage, &specializationInfoPerStages[stageCount],
|
|
||||||
&specializationDataEntriesPerStages[stageCount],
|
|
||||||
&specializationMapEntriesPerStages[stageCount]);
|
|
||||||
|
|
||||||
DAWN_ASSERT(stageCount < 2);
|
DAWN_ASSERT(stageCount < 2);
|
||||||
shaderStages[stageCount] = shaderStage;
|
shaderStages[stageCount] = shaderStage;
|
||||||
stageCount++;
|
stageCount++;
|
||||||
|
|
|
@ -60,13 +60,35 @@ class ShaderModule::Spirv : private Blob {
|
||||||
|
|
||||||
namespace dawn::native::vulkan {
|
namespace dawn::native::vulkan {
|
||||||
|
|
||||||
|
bool TransformedShaderModuleCacheKey::operator==(
|
||||||
|
const TransformedShaderModuleCacheKey& other) const {
|
||||||
|
if (layout != other.layout || entryPoint != other.entryPoint ||
|
||||||
|
constants.size() != other.constants.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!std::equal(constants.begin(), constants.end(), other.constants.begin())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t TransformedShaderModuleCacheKeyHashFunc::operator()(
|
||||||
|
const TransformedShaderModuleCacheKey& key) const {
|
||||||
|
size_t hash = 0;
|
||||||
|
HashCombine(&hash, key.layout, key.entryPoint);
|
||||||
|
for (const auto& entry : key.constants) {
|
||||||
|
HashCombine(&hash, entry.first, entry.second);
|
||||||
|
}
|
||||||
|
return hash;
|
||||||
|
}
|
||||||
|
|
||||||
class ShaderModule::ConcurrentTransformedShaderModuleCache {
|
class ShaderModule::ConcurrentTransformedShaderModuleCache {
|
||||||
public:
|
public:
|
||||||
explicit ConcurrentTransformedShaderModuleCache(Device* device);
|
explicit ConcurrentTransformedShaderModuleCache(Device* device);
|
||||||
~ConcurrentTransformedShaderModuleCache();
|
~ConcurrentTransformedShaderModuleCache();
|
||||||
|
|
||||||
std::optional<ModuleAndSpirv> Find(const PipelineLayoutEntryPointPair& key);
|
std::optional<ModuleAndSpirv> Find(const TransformedShaderModuleCacheKey& key);
|
||||||
ModuleAndSpirv AddOrGet(const PipelineLayoutEntryPointPair& key,
|
ModuleAndSpirv AddOrGet(const TransformedShaderModuleCacheKey& key,
|
||||||
VkShaderModule module,
|
VkShaderModule module,
|
||||||
Spirv&& spirv);
|
Spirv&& spirv);
|
||||||
|
|
||||||
|
@ -75,7 +97,9 @@ class ShaderModule::ConcurrentTransformedShaderModuleCache {
|
||||||
|
|
||||||
Device* mDevice;
|
Device* mDevice;
|
||||||
std::mutex mMutex;
|
std::mutex mMutex;
|
||||||
std::unordered_map<PipelineLayoutEntryPointPair, Entry, PipelineLayoutEntryPointPairHashFunc>
|
std::unordered_map<TransformedShaderModuleCacheKey,
|
||||||
|
Entry,
|
||||||
|
TransformedShaderModuleCacheKeyHashFunc>
|
||||||
mTransformedShaderModuleCache;
|
mTransformedShaderModuleCache;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -92,7 +116,7 @@ ShaderModule::ConcurrentTransformedShaderModuleCache::~ConcurrentTransformedShad
|
||||||
|
|
||||||
std::optional<ShaderModule::ModuleAndSpirv>
|
std::optional<ShaderModule::ModuleAndSpirv>
|
||||||
ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
|
ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
|
||||||
const PipelineLayoutEntryPointPair& key) {
|
const TransformedShaderModuleCacheKey& key) {
|
||||||
std::lock_guard<std::mutex> lock(mMutex);
|
std::lock_guard<std::mutex> lock(mMutex);
|
||||||
auto iter = mTransformedShaderModuleCache.find(key);
|
auto iter = mTransformedShaderModuleCache.find(key);
|
||||||
if (iter != mTransformedShaderModuleCache.end()) {
|
if (iter != mTransformedShaderModuleCache.end()) {
|
||||||
|
@ -106,7 +130,7 @@ ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
|
||||||
}
|
}
|
||||||
|
|
||||||
ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet(
|
ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet(
|
||||||
const PipelineLayoutEntryPointPair& key,
|
const TransformedShaderModuleCacheKey& key,
|
||||||
VkShaderModule module,
|
VkShaderModule module,
|
||||||
Spirv&& spirv) {
|
Spirv&& spirv) {
|
||||||
ASSERT(module != VK_NULL_HANDLE);
|
ASSERT(module != VK_NULL_HANDLE);
|
||||||
|
@ -168,20 +192,24 @@ void ShaderModule::DestroyImpl() {
|
||||||
|
|
||||||
ShaderModule::~ShaderModule() = default;
|
ShaderModule::~ShaderModule() = default;
|
||||||
|
|
||||||
#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \
|
#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \
|
||||||
X(const tint::Program*, inputProgram) \
|
X(SingleShaderStage, stage) \
|
||||||
X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \
|
X(const tint::Program*, inputProgram) \
|
||||||
X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
|
X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \
|
||||||
X(std::string_view, entryPointName) \
|
X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
|
||||||
X(bool, disableWorkgroupInit) \
|
X(std::optional<tint::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
|
||||||
X(bool, useZeroInitializeWorkgroupMemoryExtension) \
|
X(LimitsForCompilationRequest, limits) \
|
||||||
|
X(std::string_view, entryPointName) \
|
||||||
|
X(bool, disableWorkgroupInit) \
|
||||||
|
X(bool, useZeroInitializeWorkgroupMemoryExtension) \
|
||||||
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform)
|
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform)
|
||||||
|
|
||||||
DAWN_MAKE_CACHE_REQUEST(SpirvCompilationRequest, SPIRV_COMPILATION_REQUEST_MEMBERS);
|
DAWN_MAKE_CACHE_REQUEST(SpirvCompilationRequest, SPIRV_COMPILATION_REQUEST_MEMBERS);
|
||||||
#undef SPIRV_COMPILATION_REQUEST_MEMBERS
|
#undef SPIRV_COMPILATION_REQUEST_MEMBERS
|
||||||
|
|
||||||
ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
|
ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
|
||||||
const char* entryPointName,
|
SingleShaderStage stage,
|
||||||
|
const ProgrammableStage& programmableStage,
|
||||||
const PipelineLayout* layout) {
|
const PipelineLayout* layout) {
|
||||||
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleVk::GetHandleAndSpirv");
|
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleVk::GetHandleAndSpirv");
|
||||||
|
|
||||||
|
@ -191,7 +219,8 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
|
||||||
ScopedTintICEHandler scopedICEHandler(GetDevice());
|
ScopedTintICEHandler scopedICEHandler(GetDevice());
|
||||||
|
|
||||||
// Check to see if we have the handle and spirv cached already.
|
// Check to see if we have the handle and spirv cached already.
|
||||||
auto cacheKey = std::make_pair(layout, entryPointName);
|
auto cacheKey = TransformedShaderModuleCacheKey{layout, programmableStage.entryPoint.c_str(),
|
||||||
|
programmableStage.constants};
|
||||||
auto handleAndSpirv = mTransformedShaderModuleCache->Find(cacheKey);
|
auto handleAndSpirv = mTransformedShaderModuleCache->Find(cacheKey);
|
||||||
if (handleAndSpirv.has_value()) {
|
if (handleAndSpirv.has_value()) {
|
||||||
return std::move(*handleAndSpirv);
|
return std::move(*handleAndSpirv);
|
||||||
|
@ -204,7 +233,8 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
|
||||||
using BindingPoint = tint::transform::BindingPoint;
|
using BindingPoint = tint::transform::BindingPoint;
|
||||||
BindingRemapper::BindingPoints bindingPoints;
|
BindingRemapper::BindingPoints bindingPoints;
|
||||||
|
|
||||||
const BindingInfoArray& moduleBindingInfo = GetEntryPoint(entryPointName).bindings;
|
const BindingInfoArray& moduleBindingInfo =
|
||||||
|
GetEntryPoint(programmableStage.entryPoint.c_str()).bindings;
|
||||||
|
|
||||||
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
|
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
|
||||||
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
|
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
|
||||||
|
@ -238,16 +268,26 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::optional<tint::transform::SubstituteOverride::Config> substituteOverrideConfig;
|
||||||
|
if (!programmableStage.metadata->overrides.empty()) {
|
||||||
|
substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
|
||||||
|
}
|
||||||
|
|
||||||
#if TINT_BUILD_SPV_WRITER
|
#if TINT_BUILD_SPV_WRITER
|
||||||
SpirvCompilationRequest req = {};
|
SpirvCompilationRequest req = {};
|
||||||
|
req.stage = stage;
|
||||||
req.inputProgram = GetTintProgram();
|
req.inputProgram = GetTintProgram();
|
||||||
req.bindingPoints = std::move(bindingPoints);
|
req.bindingPoints = std::move(bindingPoints);
|
||||||
req.newBindingsMap = std::move(newBindingsMap);
|
req.newBindingsMap = std::move(newBindingsMap);
|
||||||
req.entryPointName = entryPointName;
|
req.entryPointName = programmableStage.entryPoint;
|
||||||
req.disableWorkgroupInit = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
|
req.disableWorkgroupInit = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
|
||||||
req.useZeroInitializeWorkgroupMemoryExtension =
|
req.useZeroInitializeWorkgroupMemoryExtension =
|
||||||
GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension);
|
GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension);
|
||||||
req.tracePlatform = UnsafeUnkeyedValue(GetDevice()->GetPlatform());
|
req.tracePlatform = UnsafeUnkeyedValue(GetDevice()->GetPlatform());
|
||||||
|
req.substituteOverrideConfig = std::move(substituteOverrideConfig);
|
||||||
|
|
||||||
|
const CombinedLimits& limits = GetDevice()->GetLimits();
|
||||||
|
req.limits = LimitsForCompilationRequest::Create(limits.v1);
|
||||||
|
|
||||||
CacheResult<Spirv> spirv;
|
CacheResult<Spirv> spirv;
|
||||||
DAWN_TRY_LOAD_OR_RUN(
|
DAWN_TRY_LOAD_OR_RUN(
|
||||||
|
@ -270,12 +310,27 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
|
||||||
transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
|
transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
|
||||||
r.newBindingsMap);
|
r.newBindingsMap);
|
||||||
}
|
}
|
||||||
|
if (r.substituteOverrideConfig) {
|
||||||
|
// This needs to run after SingleEntryPoint transform to get rid of overrides not
|
||||||
|
// used for the current entry point.
|
||||||
|
transformManager.Add<tint::transform::SubstituteOverride>();
|
||||||
|
transformInputs.Add<tint::transform::SubstituteOverride::Config>(
|
||||||
|
std::move(r.substituteOverrideConfig).value());
|
||||||
|
}
|
||||||
tint::Program program;
|
tint::Program program;
|
||||||
{
|
{
|
||||||
TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "RunTransforms");
|
TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "RunTransforms");
|
||||||
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, r.inputProgram,
|
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, r.inputProgram,
|
||||||
transformInputs, nullptr, nullptr));
|
transformInputs, nullptr, nullptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (r.stage == SingleShaderStage::Compute) {
|
||||||
|
// Validate workgroup size after program runs transforms.
|
||||||
|
Extent3D _;
|
||||||
|
DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize(
|
||||||
|
program, r.entryPointName.data(), r.limits));
|
||||||
|
}
|
||||||
|
|
||||||
tint::writer::spirv::Options options;
|
tint::writer::spirv::Options options;
|
||||||
options.emit_vertex_point_size = true;
|
options.emit_vertex_point_size = true;
|
||||||
options.disable_workgroup_init = r.disableWorkgroupInit;
|
options.disable_workgroup_init = r.disableWorkgroupInit;
|
||||||
|
|
|
@ -18,14 +18,32 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "dawn/common/HashUtils.h"
|
||||||
#include "dawn/common/vulkan_platform.h"
|
#include "dawn/common/vulkan_platform.h"
|
||||||
#include "dawn/native/Error.h"
|
#include "dawn/native/Error.h"
|
||||||
#include "dawn/native/ShaderModule.h"
|
#include "dawn/native/ShaderModule.h"
|
||||||
|
|
||||||
namespace dawn::native::vulkan {
|
namespace dawn::native {
|
||||||
|
|
||||||
|
struct ProgrammableStage;
|
||||||
|
|
||||||
|
namespace vulkan {
|
||||||
|
|
||||||
|
struct TransformedShaderModuleCacheKey {
|
||||||
|
const PipelineLayoutBase* layout;
|
||||||
|
std::string entryPoint;
|
||||||
|
PipelineConstantEntries constants;
|
||||||
|
|
||||||
|
bool operator==(const TransformedShaderModuleCacheKey& other) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TransformedShaderModuleCacheKeyHashFunc {
|
||||||
|
size_t operator()(const TransformedShaderModuleCacheKey& key) const;
|
||||||
|
};
|
||||||
|
|
||||||
class Device;
|
class Device;
|
||||||
class PipelineLayout;
|
class PipelineLayout;
|
||||||
|
@ -44,7 +62,8 @@ class ShaderModule final : public ShaderModuleBase {
|
||||||
ShaderModuleParseResult* parseResult,
|
ShaderModuleParseResult* parseResult,
|
||||||
OwnedCompilationMessages* compilationMessages);
|
OwnedCompilationMessages* compilationMessages);
|
||||||
|
|
||||||
ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(const char* entryPointName,
|
ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(SingleShaderStage stage,
|
||||||
|
const ProgrammableStage& programmableStage,
|
||||||
const PipelineLayout* layout);
|
const PipelineLayout* layout);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -59,6 +78,8 @@ class ShaderModule final : public ShaderModuleBase {
|
||||||
std::unique_ptr<ConcurrentTransformedShaderModuleCache> mTransformedShaderModuleCache;
|
std::unique_ptr<ConcurrentTransformedShaderModuleCache> mTransformedShaderModuleCache;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace dawn::native::vulkan
|
} // namespace vulkan
|
||||||
|
|
||||||
|
} // namespace dawn::native
|
||||||
|
|
||||||
#endif // SRC_DAWN_NATIVE_VULKAN_SHADERMODULEVK_H_
|
#endif // SRC_DAWN_NATIVE_VULKAN_SHADERMODULEVK_H_
|
||||||
|
|
|
@ -258,60 +258,4 @@ std::string GetDeviceDebugPrefixFromDebugName(const char* debugName) {
|
||||||
return std::string(debugName, length);
|
return std::string(debugName, length);
|
||||||
}
|
}
|
||||||
|
|
||||||
VkSpecializationInfo* GetVkSpecializationInfo(
|
|
||||||
const ProgrammableStage& programmableStage,
|
|
||||||
VkSpecializationInfo* specializationInfo,
|
|
||||||
std::vector<OverrideScalar>* specializationDataEntries,
|
|
||||||
std::vector<VkSpecializationMapEntry>* specializationMapEntries) {
|
|
||||||
ASSERT(specializationInfo);
|
|
||||||
ASSERT(specializationDataEntries);
|
|
||||||
ASSERT(specializationMapEntries);
|
|
||||||
|
|
||||||
if (programmableStage.constants.size() == 0) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
const EntryPointMetadata& entryPointMetaData =
|
|
||||||
programmableStage.module->GetEntryPoint(programmableStage.entryPoint);
|
|
||||||
|
|
||||||
for (const auto& pipelineConstant : programmableStage.constants) {
|
|
||||||
const std::string& identifier = pipelineConstant.first;
|
|
||||||
double value = pipelineConstant.second;
|
|
||||||
|
|
||||||
// This is already validated so `identifier` must exist
|
|
||||||
const auto& moduleConstant = entryPointMetaData.overrides.at(identifier);
|
|
||||||
|
|
||||||
specializationMapEntries->push_back(VkSpecializationMapEntry{
|
|
||||||
moduleConstant.id,
|
|
||||||
static_cast<uint32_t>(specializationDataEntries->size() * sizeof(OverrideScalar)),
|
|
||||||
sizeof(OverrideScalar)});
|
|
||||||
|
|
||||||
OverrideScalar entry{};
|
|
||||||
switch (moduleConstant.type) {
|
|
||||||
case EntryPointMetadata::Override::Type::Boolean:
|
|
||||||
entry.b = static_cast<int32_t>(value);
|
|
||||||
break;
|
|
||||||
case EntryPointMetadata::Override::Type::Float32:
|
|
||||||
entry.f32 = static_cast<float>(value);
|
|
||||||
break;
|
|
||||||
case EntryPointMetadata::Override::Type::Int32:
|
|
||||||
entry.i32 = static_cast<int32_t>(value);
|
|
||||||
break;
|
|
||||||
case EntryPointMetadata::Override::Type::Uint32:
|
|
||||||
entry.u32 = static_cast<uint32_t>(value);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
UNREACHABLE();
|
|
||||||
}
|
|
||||||
specializationDataEntries->push_back(entry);
|
|
||||||
}
|
|
||||||
|
|
||||||
specializationInfo->mapEntryCount = static_cast<uint32_t>(specializationMapEntries->size());
|
|
||||||
specializationInfo->pMapEntries = specializationMapEntries->data();
|
|
||||||
specializationInfo->dataSize = specializationDataEntries->size() * sizeof(OverrideScalar);
|
|
||||||
specializationInfo->pData = specializationDataEntries->data();
|
|
||||||
|
|
||||||
return specializationInfo;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace dawn::native::vulkan
|
} // namespace dawn::native::vulkan
|
||||||
|
|
|
@ -144,15 +144,6 @@ void SetDebugName(Device* device,
|
||||||
std::string GetNextDeviceDebugPrefix();
|
std::string GetNextDeviceDebugPrefix();
|
||||||
std::string GetDeviceDebugPrefixFromDebugName(const char* debugName);
|
std::string GetDeviceDebugPrefixFromDebugName(const char* debugName);
|
||||||
|
|
||||||
// Returns nullptr or &specializationInfo
|
|
||||||
// specializationInfo, specializationDataEntries, specializationMapEntries needs to
|
|
||||||
// be alive at least until VkSpecializationInfo is passed into Vulkan Create*Pipelines
|
|
||||||
VkSpecializationInfo* GetVkSpecializationInfo(
|
|
||||||
const ProgrammableStage& programmableStage,
|
|
||||||
VkSpecializationInfo* specializationInfo,
|
|
||||||
std::vector<OverrideScalar>* specializationDataEntries,
|
|
||||||
std::vector<VkSpecializationMapEntry>* specializationMapEntries);
|
|
||||||
|
|
||||||
} // namespace dawn::native::vulkan
|
} // namespace dawn::native::vulkan
|
||||||
|
|
||||||
#endif // SRC_DAWN_NATIVE_VULKAN_UTILSVULKAN_H_
|
#endif // SRC_DAWN_NATIVE_VULKAN_UTILSVULKAN_H_
|
||||||
|
|
|
@ -265,6 +265,7 @@ dawn_test("dawn_unittests") {
|
||||||
"unittests/native/CreatePipelineAsyncTaskTests.cpp",
|
"unittests/native/CreatePipelineAsyncTaskTests.cpp",
|
||||||
"unittests/native/DestroyObjectTests.cpp",
|
"unittests/native/DestroyObjectTests.cpp",
|
||||||
"unittests/native/DeviceCreationTests.cpp",
|
"unittests/native/DeviceCreationTests.cpp",
|
||||||
|
"unittests/native/ObjectContentHasherTests.cpp",
|
||||||
"unittests/native/StreamTests.cpp",
|
"unittests/native/StreamTests.cpp",
|
||||||
"unittests/validation/BindGroupValidationTests.cpp",
|
"unittests/validation/BindGroupValidationTests.cpp",
|
||||||
"unittests/validation/BufferValidationTests.cpp",
|
"unittests/validation/BufferValidationTests.cpp",
|
||||||
|
@ -489,6 +490,7 @@ source_set("end2end_tests_sources") {
|
||||||
"end2end/ScissorTests.cpp",
|
"end2end/ScissorTests.cpp",
|
||||||
"end2end/ShaderFloat16Tests.cpp",
|
"end2end/ShaderFloat16Tests.cpp",
|
||||||
"end2end/ShaderTests.cpp",
|
"end2end/ShaderTests.cpp",
|
||||||
|
"end2end/ShaderValidationTests.cpp",
|
||||||
"end2end/StorageTextureTests.cpp",
|
"end2end/StorageTextureTests.cpp",
|
||||||
"end2end/SubresourceRenderAttachmentTests.cpp",
|
"end2end/SubresourceRenderAttachmentTests.cpp",
|
||||||
"end2end/Texture3DTests.cpp",
|
"end2end/Texture3DTests.cpp",
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "dawn/tests/DawnTest.h"
|
#include "dawn/tests/DawnTest.h"
|
||||||
|
|
||||||
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
|
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
|
||||||
|
@ -158,6 +160,46 @@ TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnShaderModule) {
|
||||||
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
|
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that ComputePipeline are correctly deduplicated wrt. their constants override values
|
||||||
|
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnOverrides) {
|
||||||
|
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
|
||||||
|
override x: u32 = 1u;
|
||||||
|
var<workgroup> i : u32;
|
||||||
|
@compute @workgroup_size(x) fn main() {
|
||||||
|
i = 0u;
|
||||||
|
})");
|
||||||
|
|
||||||
|
wgpu::PipelineLayout layout = utils::MakeBasicPipelineLayout(device, nullptr);
|
||||||
|
|
||||||
|
wgpu::ComputePipelineDescriptor desc;
|
||||||
|
desc.compute.entryPoint = "main";
|
||||||
|
desc.layout = layout;
|
||||||
|
desc.compute.module = module;
|
||||||
|
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{{nullptr, "x", 16}};
|
||||||
|
desc.compute.constantCount = constants.size();
|
||||||
|
desc.compute.constants = constants.data();
|
||||||
|
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&desc);
|
||||||
|
|
||||||
|
std::vector<wgpu::ConstantEntry> sameConstants{{nullptr, "x", 16}};
|
||||||
|
desc.compute.constantCount = sameConstants.size();
|
||||||
|
desc.compute.constants = sameConstants.data();
|
||||||
|
wgpu::ComputePipeline samePipeline = device.CreateComputePipeline(&desc);
|
||||||
|
|
||||||
|
desc.compute.constantCount = 0;
|
||||||
|
desc.compute.constants = nullptr;
|
||||||
|
wgpu::ComputePipeline otherPipeline1 = device.CreateComputePipeline(&desc);
|
||||||
|
|
||||||
|
std::vector<wgpu::ConstantEntry> otherConstants{{nullptr, "x", 4}};
|
||||||
|
desc.compute.constantCount = otherConstants.size();
|
||||||
|
desc.compute.constants = otherConstants.data();
|
||||||
|
wgpu::ComputePipeline otherPipeline2 = device.CreateComputePipeline(&desc);
|
||||||
|
|
||||||
|
EXPECT_NE(pipeline.Get(), otherPipeline1.Get());
|
||||||
|
EXPECT_NE(pipeline.Get(), otherPipeline2.Get());
|
||||||
|
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
|
||||||
|
}
|
||||||
|
|
||||||
// Test that ComputePipeline are correctly deduplicated wrt. their layout
|
// Test that ComputePipeline are correctly deduplicated wrt. their layout
|
||||||
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) {
|
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) {
|
||||||
wgpu::BindGroupLayout bgl = utils::MakeBindGroupLayout(
|
wgpu::BindGroupLayout bgl = utils::MakeBindGroupLayout(
|
||||||
|
@ -303,6 +345,48 @@ TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnFragmentModule) {
|
||||||
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
|
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that Renderpipelines are correctly deduplicated wrt. their constants override values
|
||||||
|
TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnOverrides) {
|
||||||
|
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
|
||||||
|
override a: f32 = 1.0;
|
||||||
|
@vertex fn vertexMain() -> @builtin(position) vec4<f32> {
|
||||||
|
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||||
|
}
|
||||||
|
@fragment fn fragmentMain() -> @location(0) vec4<f32> {
|
||||||
|
return vec4<f32>(0.0, 0.0, 0.0, a);
|
||||||
|
})");
|
||||||
|
|
||||||
|
utils::ComboRenderPipelineDescriptor desc;
|
||||||
|
desc.vertex.module = module;
|
||||||
|
desc.vertex.entryPoint = "vertexMain";
|
||||||
|
desc.cFragment.module = module;
|
||||||
|
desc.cFragment.entryPoint = "fragmentMain";
|
||||||
|
desc.cTargets[0].writeMask = wgpu::ColorWriteMask::None;
|
||||||
|
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{{nullptr, "a", 0.5}};
|
||||||
|
desc.cFragment.constantCount = constants.size();
|
||||||
|
desc.cFragment.constants = constants.data();
|
||||||
|
wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&desc);
|
||||||
|
|
||||||
|
std::vector<wgpu::ConstantEntry> sameConstants{{nullptr, "a", 0.5}};
|
||||||
|
desc.cFragment.constantCount = sameConstants.size();
|
||||||
|
desc.cFragment.constants = sameConstants.data();
|
||||||
|
wgpu::RenderPipeline samePipeline = device.CreateRenderPipeline(&desc);
|
||||||
|
|
||||||
|
std::vector<wgpu::ConstantEntry> otherConstants{{nullptr, "a", 1.0}};
|
||||||
|
desc.cFragment.constantCount = otherConstants.size();
|
||||||
|
desc.cFragment.constants = otherConstants.data();
|
||||||
|
wgpu::RenderPipeline otherPipeline1 = device.CreateRenderPipeline(&desc);
|
||||||
|
|
||||||
|
desc.cFragment.constantCount = 0;
|
||||||
|
desc.cFragment.constants = nullptr;
|
||||||
|
wgpu::RenderPipeline otherPipeline2 = device.CreateRenderPipeline(&desc);
|
||||||
|
|
||||||
|
EXPECT_NE(pipeline.Get(), otherPipeline1.Get());
|
||||||
|
EXPECT_NE(pipeline.Get(), otherPipeline2.Get());
|
||||||
|
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
|
||||||
|
}
|
||||||
|
|
||||||
// Test that Samplers are correctly deduplicated.
|
// Test that Samplers are correctly deduplicated.
|
||||||
TEST_P(ObjectCachingTest, SamplerDeduplication) {
|
TEST_P(ObjectCachingTest, SamplerDeduplication) {
|
||||||
wgpu::SamplerDescriptor samplerDesc;
|
wgpu::SamplerDescriptor samplerDesc;
|
||||||
|
|
|
@ -471,6 +471,127 @@ struct Buf {
|
||||||
EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), buffer, 0, kCount);
|
EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), buffer, 0, kCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test one shader shared by two pipelines with different constants overridden
|
||||||
|
TEST_P(ShaderTests, OverridableConstantsSharedShader) {
|
||||||
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||||
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
|
||||||
|
|
||||||
|
std::vector<uint32_t> expected1{1};
|
||||||
|
wgpu::Buffer buffer1 = CreateBuffer(expected1.size());
|
||||||
|
std::vector<uint32_t> expected2{2};
|
||||||
|
wgpu::Buffer buffer2 = CreateBuffer(expected2.size());
|
||||||
|
|
||||||
|
std::string shader = R"(
|
||||||
|
override a: u32;
|
||||||
|
|
||||||
|
struct Buf {
|
||||||
|
data : array<u32, 1>
|
||||||
|
}
|
||||||
|
|
||||||
|
@group(0) @binding(0) var<storage, read_write> buf : Buf;
|
||||||
|
|
||||||
|
@compute @workgroup_size(1) fn main() {
|
||||||
|
buf.data[0] = a;
|
||||||
|
})";
|
||||||
|
|
||||||
|
std::vector<wgpu::ConstantEntry> constants1;
|
||||||
|
constants1.push_back({nullptr, "a", 1});
|
||||||
|
std::vector<wgpu::ConstantEntry> constants2;
|
||||||
|
constants2.push_back({nullptr, "a", 2});
|
||||||
|
|
||||||
|
wgpu::ComputePipeline pipeline1 = CreateComputePipeline(shader, "main", &constants1);
|
||||||
|
wgpu::ComputePipeline pipeline2 = CreateComputePipeline(shader, "main", &constants2);
|
||||||
|
|
||||||
|
wgpu::BindGroup bindGroup1 =
|
||||||
|
utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}});
|
||||||
|
wgpu::BindGroup bindGroup2 =
|
||||||
|
utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}});
|
||||||
|
|
||||||
|
wgpu::CommandBuffer commands;
|
||||||
|
{
|
||||||
|
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||||
|
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||||
|
pass.SetPipeline(pipeline1);
|
||||||
|
pass.SetBindGroup(0, bindGroup1);
|
||||||
|
pass.DispatchWorkgroups(1);
|
||||||
|
pass.SetPipeline(pipeline2);
|
||||||
|
pass.SetBindGroup(0, bindGroup2);
|
||||||
|
pass.DispatchWorkgroups(1);
|
||||||
|
pass.End();
|
||||||
|
|
||||||
|
commands = encoder.Finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
queue.Submit(1, &commands);
|
||||||
|
|
||||||
|
EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, expected1.size());
|
||||||
|
EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, expected2.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test overridable constants work with workgroup size
|
||||||
|
TEST_P(ShaderTests, OverridableConstantsWorkgroupSize) {
|
||||||
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||||
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
|
||||||
|
|
||||||
|
std::string shader = R"(
|
||||||
|
override x: u32;
|
||||||
|
|
||||||
|
struct Buf {
|
||||||
|
data : array<u32, 1>
|
||||||
|
}
|
||||||
|
|
||||||
|
@group(0) @binding(0) var<storage, read_write> buf : Buf;
|
||||||
|
|
||||||
|
@compute @workgroup_size(x) fn main(
|
||||||
|
@builtin(local_invocation_id) local_invocation_id : vec3<u32>
|
||||||
|
) {
|
||||||
|
if (local_invocation_id.x >= x - 1) {
|
||||||
|
buf.data[0] = local_invocation_id.x + 1;
|
||||||
|
}
|
||||||
|
})";
|
||||||
|
|
||||||
|
const uint32_t workgroup_size_x_1 = 16u;
|
||||||
|
const uint32_t workgroup_size_x_2 = 64u;
|
||||||
|
|
||||||
|
std::vector<uint32_t> expected1{workgroup_size_x_1};
|
||||||
|
wgpu::Buffer buffer1 = CreateBuffer(expected1.size());
|
||||||
|
std::vector<uint32_t> expected2{workgroup_size_x_2};
|
||||||
|
wgpu::Buffer buffer2 = CreateBuffer(expected2.size());
|
||||||
|
|
||||||
|
std::vector<wgpu::ConstantEntry> constants1;
|
||||||
|
constants1.push_back({nullptr, "x", static_cast<double>(workgroup_size_x_1)});
|
||||||
|
std::vector<wgpu::ConstantEntry> constants2;
|
||||||
|
constants2.push_back({nullptr, "x", static_cast<double>(workgroup_size_x_2)});
|
||||||
|
|
||||||
|
wgpu::ComputePipeline pipeline1 = CreateComputePipeline(shader, "main", &constants1);
|
||||||
|
wgpu::ComputePipeline pipeline2 = CreateComputePipeline(shader, "main", &constants2);
|
||||||
|
|
||||||
|
wgpu::BindGroup bindGroup1 =
|
||||||
|
utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}});
|
||||||
|
wgpu::BindGroup bindGroup2 =
|
||||||
|
utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}});
|
||||||
|
|
||||||
|
wgpu::CommandBuffer commands;
|
||||||
|
{
|
||||||
|
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||||
|
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||||
|
pass.SetPipeline(pipeline1);
|
||||||
|
pass.SetBindGroup(0, bindGroup1);
|
||||||
|
pass.DispatchWorkgroups(1);
|
||||||
|
pass.SetPipeline(pipeline2);
|
||||||
|
pass.SetBindGroup(0, bindGroup2);
|
||||||
|
pass.DispatchWorkgroups(1);
|
||||||
|
pass.End();
|
||||||
|
|
||||||
|
commands = encoder.Finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
queue.Submit(1, &commands);
|
||||||
|
|
||||||
|
EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, expected1.size());
|
||||||
|
EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, expected2.size());
|
||||||
|
}
|
||||||
|
|
||||||
// Test overridable constants with numeric identifiers
|
// Test overridable constants with numeric identifiers
|
||||||
TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
|
TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
|
||||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||||
|
@ -596,6 +717,7 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
||||||
std::string shader = R"(
|
std::string shader = R"(
|
||||||
@id(1001) override c1: u32;
|
@id(1001) override c1: u32;
|
||||||
@id(1002) override c2: u32;
|
@id(1002) override c2: u32;
|
||||||
|
@id(1003) override c3: u32;
|
||||||
|
|
||||||
struct Buf {
|
struct Buf {
|
||||||
data : array<u32, 1>
|
data : array<u32, 1>
|
||||||
|
@ -611,7 +733,7 @@ struct Buf {
|
||||||
buf.data[0] = c2;
|
buf.data[0] = c2;
|
||||||
}
|
}
|
||||||
|
|
||||||
@compute @workgroup_size(1) fn main3() {
|
@compute @workgroup_size(c3) fn main3() {
|
||||||
buf.data[0] = 3u;
|
buf.data[0] = 3u;
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
@ -620,6 +742,8 @@ struct Buf {
|
||||||
constants1.push_back({nullptr, "1001", 1});
|
constants1.push_back({nullptr, "1001", 1});
|
||||||
std::vector<wgpu::ConstantEntry> constants2;
|
std::vector<wgpu::ConstantEntry> constants2;
|
||||||
constants2.push_back({nullptr, "1002", 2});
|
constants2.push_back({nullptr, "1002", 2});
|
||||||
|
std::vector<wgpu::ConstantEntry> constants3;
|
||||||
|
constants3.push_back({nullptr, "1003", 1});
|
||||||
|
|
||||||
wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, shader.c_str());
|
wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, shader.c_str());
|
||||||
|
|
||||||
|
@ -640,6 +764,8 @@ struct Buf {
|
||||||
wgpu::ComputePipelineDescriptor csDesc3;
|
wgpu::ComputePipelineDescriptor csDesc3;
|
||||||
csDesc3.compute.module = shaderModule;
|
csDesc3.compute.module = shaderModule;
|
||||||
csDesc3.compute.entryPoint = "main3";
|
csDesc3.compute.entryPoint = "main3";
|
||||||
|
csDesc3.compute.constants = constants3.data();
|
||||||
|
csDesc3.compute.constantCount = constants3.size();
|
||||||
wgpu::ComputePipeline pipeline3 = device.CreateComputePipeline(&csDesc3);
|
wgpu::ComputePipeline pipeline3 = device.CreateComputePipeline(&csDesc3);
|
||||||
|
|
||||||
wgpu::BindGroup bindGroup1 =
|
wgpu::BindGroup bindGroup1 =
|
||||||
|
@ -765,8 +891,6 @@ TEST_P(ShaderTests, ConflictingBindingsDueToTransformOrder) {
|
||||||
device.CreateRenderPipeline(&desc);
|
device.CreateRenderPipeline(&desc);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(tint:1155): Test overridable constants used for workgroup size
|
|
||||||
|
|
||||||
DAWN_INSTANTIATE_TEST(ShaderTests,
|
DAWN_INSTANTIATE_TEST(ShaderTests,
|
||||||
D3D12Backend(),
|
D3D12Backend(),
|
||||||
MetalBackend(),
|
MetalBackend(),
|
||||||
|
|
|
@ -0,0 +1,395 @@
|
||||||
|
// Copyright 2022 The Dawn Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "dawn/tests/DawnTest.h"
|
||||||
|
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
|
||||||
|
#include "dawn/utils/WGPUHelpers.h"
|
||||||
|
|
||||||
|
// The compute shader workgroup size is settled at compute pipeline creation time.
|
||||||
|
// The validation code in dawn is in each backend (not including Null backend) thus this test needs
|
||||||
|
// to be as part of a dawn_end2end_tests instead of the dawn_unittests
|
||||||
|
// TODO(dawn:1504): Add support for GL backend.
|
||||||
|
class WorkgroupSizeValidationTest : public DawnTest {
|
||||||
|
public:
|
||||||
|
wgpu::ShaderModule SetUpShaderWithValidDefaultValueConstants() {
|
||||||
|
return utils::CreateShaderModule(device, R"(
|
||||||
|
override x: u32 = 1u;
|
||||||
|
override y: u32 = 1u;
|
||||||
|
override z: u32 = 1u;
|
||||||
|
|
||||||
|
@compute @workgroup_size(x, y, z) fn main() {
|
||||||
|
_ = 0u;
|
||||||
|
})");
|
||||||
|
}
|
||||||
|
|
||||||
|
wgpu::ShaderModule SetUpShaderWithZeroDefaultValueConstants() {
|
||||||
|
return utils::CreateShaderModule(device, R"(
|
||||||
|
override x: u32 = 0u;
|
||||||
|
override y: u32 = 0u;
|
||||||
|
override z: u32 = 0u;
|
||||||
|
|
||||||
|
@compute @workgroup_size(x, y, z) fn main() {
|
||||||
|
_ = 0u;
|
||||||
|
})");
|
||||||
|
}
|
||||||
|
|
||||||
|
wgpu::ShaderModule SetUpShaderWithOutOfLimitsDefaultValueConstants() {
|
||||||
|
return utils::CreateShaderModule(device, R"(
|
||||||
|
override x: u32 = 1u;
|
||||||
|
override y: u32 = 1u;
|
||||||
|
override z: u32 = 9999u;
|
||||||
|
|
||||||
|
@compute @workgroup_size(x, y, z) fn main() {
|
||||||
|
_ = 0u;
|
||||||
|
})");
|
||||||
|
}
|
||||||
|
|
||||||
|
wgpu::ShaderModule SetUpShaderWithUninitializedConstants() {
|
||||||
|
return utils::CreateShaderModule(device, R"(
|
||||||
|
override x: u32;
|
||||||
|
override y: u32;
|
||||||
|
override z: u32;
|
||||||
|
|
||||||
|
@compute @workgroup_size(x, y, z) fn main() {
|
||||||
|
_ = 0u;
|
||||||
|
})");
|
||||||
|
}
|
||||||
|
|
||||||
|
wgpu::ShaderModule SetUpShaderWithPartialConstants() {
|
||||||
|
return utils::CreateShaderModule(device, R"(
|
||||||
|
override x: u32;
|
||||||
|
|
||||||
|
@compute @workgroup_size(x, 1, 1) fn main() {
|
||||||
|
_ = 0u;
|
||||||
|
})");
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestCreatePipeline(const wgpu::ShaderModule& module) {
|
||||||
|
wgpu::ComputePipelineDescriptor csDesc;
|
||||||
|
csDesc.compute.module = module;
|
||||||
|
csDesc.compute.entryPoint = "main";
|
||||||
|
csDesc.compute.constants = nullptr;
|
||||||
|
csDesc.compute.constantCount = 0;
|
||||||
|
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestCreatePipeline(const wgpu::ShaderModule& module,
|
||||||
|
const std::vector<wgpu::ConstantEntry>& constants) {
|
||||||
|
wgpu::ComputePipelineDescriptor csDesc;
|
||||||
|
csDesc.compute.module = module;
|
||||||
|
csDesc.compute.entryPoint = "main";
|
||||||
|
csDesc.compute.constants = constants.data();
|
||||||
|
csDesc.compute.constantCount = constants.size();
|
||||||
|
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestInitializedWithZero(const wgpu::ShaderModule& module) {
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
|
{nullptr, "x", 0}, {nullptr, "y", 0}, {nullptr, "z", 0}};
|
||||||
|
TestCreatePipeline(module, constants);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestInitializedWithOutOfLimitValue(const wgpu::ShaderModule& module) {
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
|
{nullptr, "x",
|
||||||
|
static_cast<double>(GetSupportedLimits().limits.maxComputeWorkgroupSizeX + 1)},
|
||||||
|
{nullptr, "y", 1},
|
||||||
|
{nullptr, "z", 1}};
|
||||||
|
TestCreatePipeline(module, constants);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestInitializedWithValidValue(const wgpu::ShaderModule& module) {
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
|
{nullptr, "x", 4}, {nullptr, "y", 4}, {nullptr, "z", 4}};
|
||||||
|
TestCreatePipeline(module, constants);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestInitializedPartially(const wgpu::ShaderModule& module) {
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{{nullptr, "y", 4}};
|
||||||
|
TestCreatePipeline(module, constants);
|
||||||
|
}
|
||||||
|
|
||||||
|
wgpu::Buffer buffer;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Test workgroup size validation with fixed values.
|
||||||
|
TEST_P(WorkgroupSizeValidationTest, WithFixedValues) {
|
||||||
|
auto CheckShaderWithWorkgroupSize = [this](bool success, uint32_t x, uint32_t y, uint32_t z) {
|
||||||
|
std::ostringstream ss;
|
||||||
|
ss << "@compute @workgroup_size(" << x << "," << y << "," << z << ") fn main() {}";
|
||||||
|
|
||||||
|
wgpu::ComputePipelineDescriptor desc;
|
||||||
|
desc.compute.entryPoint = "main";
|
||||||
|
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
|
||||||
|
|
||||||
|
if (success) {
|
||||||
|
device.CreateComputePipeline(&desc);
|
||||||
|
} else {
|
||||||
|
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
|
||||||
|
|
||||||
|
CheckShaderWithWorkgroupSize(true, 1, 1, 1);
|
||||||
|
CheckShaderWithWorkgroupSize(true, supportedLimits.maxComputeWorkgroupSizeX, 1, 1);
|
||||||
|
CheckShaderWithWorkgroupSize(true, 1, supportedLimits.maxComputeWorkgroupSizeY, 1);
|
||||||
|
CheckShaderWithWorkgroupSize(true, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ);
|
||||||
|
|
||||||
|
CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX + 1, 1, 1);
|
||||||
|
CheckShaderWithWorkgroupSize(false, 1, supportedLimits.maxComputeWorkgroupSizeY + 1, 1);
|
||||||
|
CheckShaderWithWorkgroupSize(false, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ + 1);
|
||||||
|
|
||||||
|
// No individual dimension exceeds its limit, but the combined size should definitely exceed the
|
||||||
|
// total invocation limit.
|
||||||
|
DAWN_ASSERT(supportedLimits.maxComputeWorkgroupSizeX *
|
||||||
|
supportedLimits.maxComputeWorkgroupSizeY *
|
||||||
|
supportedLimits.maxComputeWorkgroupSizeZ >
|
||||||
|
supportedLimits.maxComputeInvocationsPerWorkgroup);
|
||||||
|
CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX,
|
||||||
|
supportedLimits.maxComputeWorkgroupSizeY,
|
||||||
|
supportedLimits.maxComputeWorkgroupSizeZ);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test workgroup size validation with fixed values (storage size limits validation).
|
||||||
|
TEST_P(WorkgroupSizeValidationTest, WithFixedValuesStorageSizeLimits) {
|
||||||
|
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
|
||||||
|
|
||||||
|
constexpr uint32_t kVec4Size = 16;
|
||||||
|
const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size;
|
||||||
|
constexpr uint32_t kMat4Size = 64;
|
||||||
|
const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size;
|
||||||
|
|
||||||
|
auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count,
|
||||||
|
uint32_t mat4_count) {
|
||||||
|
std::ostringstream ss;
|
||||||
|
std::ostringstream body;
|
||||||
|
if (vec4_count > 0) {
|
||||||
|
ss << "var<workgroup> vec4_data: array<vec4<f32>, " << vec4_count << ">;";
|
||||||
|
body << "_ = vec4_data;";
|
||||||
|
}
|
||||||
|
if (mat4_count > 0) {
|
||||||
|
ss << "var<workgroup> mat4_data: array<mat4x4<f32>, " << mat4_count << ">;";
|
||||||
|
body << "_ = mat4_data;";
|
||||||
|
}
|
||||||
|
ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }";
|
||||||
|
|
||||||
|
wgpu::ComputePipelineDescriptor desc;
|
||||||
|
desc.compute.entryPoint = "main";
|
||||||
|
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
|
||||||
|
|
||||||
|
if (success) {
|
||||||
|
device.CreateComputePipeline(&desc);
|
||||||
|
} else {
|
||||||
|
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
CheckPipelineWithWorkgroupStorage(true, 1, 1);
|
||||||
|
CheckPipelineWithWorkgroupStorage(true, maxVec4Count, 0);
|
||||||
|
CheckPipelineWithWorkgroupStorage(true, 0, maxMat4Count);
|
||||||
|
CheckPipelineWithWorkgroupStorage(true, maxVec4Count - 4, 1);
|
||||||
|
CheckPipelineWithWorkgroupStorage(true, 4, maxMat4Count - 1);
|
||||||
|
|
||||||
|
CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0);
|
||||||
|
CheckPipelineWithWorkgroupStorage(false, maxVec4Count - 3, 1);
|
||||||
|
CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1);
|
||||||
|
CheckPipelineWithWorkgroupStorage(false, 4, maxMat4Count);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test workgroup size validation with valid overrides default values.
|
||||||
|
TEST_P(WorkgroupSizeValidationTest, OverridesWithValidDefault) {
|
||||||
|
wgpu::ShaderModule module = SetUpShaderWithValidDefaultValueConstants();
|
||||||
|
{
|
||||||
|
// Valid default
|
||||||
|
TestCreatePipeline(module);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: invalid value (zero)
|
||||||
|
ASSERT_DEVICE_ERROR(TestInitializedWithZero(module));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: invalid value (out of device limits)
|
||||||
|
ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Valid: initialized partially
|
||||||
|
TestInitializedPartially(module);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Valid
|
||||||
|
TestInitializedWithValidValue(module);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test workgroup size validation with zero as the overrides default values.
|
||||||
|
TEST_P(WorkgroupSizeValidationTest, OverridesWithZeroDefault) {
|
||||||
|
// Error: zero is detected as invalid at shader creation time
|
||||||
|
ASSERT_DEVICE_ERROR(SetUpShaderWithZeroDefaultValueConstants());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test workgroup size validation with out-of-limits overrides default values.
|
||||||
|
TEST_P(WorkgroupSizeValidationTest, OverridesWithOutOfLimitsDefault) {
|
||||||
|
wgpu::ShaderModule module = SetUpShaderWithOutOfLimitsDefaultValueConstants();
|
||||||
|
{
|
||||||
|
// Error: invalid default
|
||||||
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(module));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: invalid value (zero)
|
||||||
|
ASSERT_DEVICE_ERROR(TestInitializedWithZero(module));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: invalid value (out of device limits)
|
||||||
|
ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: initialized partially
|
||||||
|
ASSERT_DEVICE_ERROR(TestInitializedPartially(module));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Valid
|
||||||
|
TestInitializedWithValidValue(module);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test workgroup size validation without overrides default values specified.
|
||||||
|
TEST_P(WorkgroupSizeValidationTest, OverridesWithUninitialized) {
|
||||||
|
wgpu::ShaderModule module = SetUpShaderWithUninitializedConstants();
|
||||||
|
{
|
||||||
|
// Error: uninitialized
|
||||||
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(module));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: invalid value (zero)
|
||||||
|
ASSERT_DEVICE_ERROR(TestInitializedWithZero(module));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: invalid value (out of device limits)
|
||||||
|
ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: initialized partially
|
||||||
|
ASSERT_DEVICE_ERROR(TestInitializedPartially(module));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Valid
|
||||||
|
TestInitializedWithValidValue(module);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test workgroup size validation with only partial dimensions are overrides.
|
||||||
|
TEST_P(WorkgroupSizeValidationTest, PartialOverrides) {
|
||||||
|
wgpu::ShaderModule module = SetUpShaderWithPartialConstants();
|
||||||
|
{
|
||||||
|
// Error: uninitialized
|
||||||
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(module));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: invalid value (zero)
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{{nullptr, "x", 0}};
|
||||||
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: invalid value (out of device limits)
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
|
{nullptr, "x",
|
||||||
|
static_cast<double>(GetSupportedLimits().limits.maxComputeWorkgroupSizeX + 1)}};
|
||||||
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Valid
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{{nullptr, "x", 16}};
|
||||||
|
TestCreatePipeline(module, constants);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test workgroup size validation after being overrided with invalid values.
|
||||||
|
TEST_P(WorkgroupSizeValidationTest, ValidationAfterOverride) {
|
||||||
|
wgpu::ShaderModule module = SetUpShaderWithUninitializedConstants();
|
||||||
|
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
|
||||||
|
{
|
||||||
|
// Error: exceed maxComputeWorkgroupSizeZ
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
|
{nullptr, "x", 1},
|
||||||
|
{nullptr, "y", 1},
|
||||||
|
{nullptr, "z", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeZ + 1)}};
|
||||||
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: exceed maxComputeInvocationsPerWorkgroup
|
||||||
|
DAWN_ASSERT(supportedLimits.maxComputeWorkgroupSizeX *
|
||||||
|
supportedLimits.maxComputeWorkgroupSizeY *
|
||||||
|
supportedLimits.maxComputeWorkgroupSizeZ >
|
||||||
|
supportedLimits.maxComputeInvocationsPerWorkgroup);
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
|
{nullptr, "x", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeX)},
|
||||||
|
{nullptr, "y", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeY)},
|
||||||
|
{nullptr, "z", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeZ)}};
|
||||||
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test workgroup size validation after being overrided with invalid values (storage size limits
|
||||||
|
// validation).
|
||||||
|
// TODO(tint:1660): re-enable after override can be used as array size.
|
||||||
|
TEST_P(WorkgroupSizeValidationTest, DISABLED_ValidationAfterOverrideStorageSize) {
|
||||||
|
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
|
||||||
|
|
||||||
|
constexpr uint32_t kVec4Size = 16;
|
||||||
|
const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size;
|
||||||
|
constexpr uint32_t kMat4Size = 64;
|
||||||
|
const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size;
|
||||||
|
|
||||||
|
auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count,
|
||||||
|
uint32_t mat4_count) {
|
||||||
|
std::ostringstream ss;
|
||||||
|
std::ostringstream body;
|
||||||
|
ss << "override a: u32;";
|
||||||
|
ss << "override b: u32;";
|
||||||
|
if (vec4_count > 0) {
|
||||||
|
ss << "var<workgroup> vec4_data: array<vec4<f32>, a>;";
|
||||||
|
body << "_ = vec4_data;";
|
||||||
|
}
|
||||||
|
if (mat4_count > 0) {
|
||||||
|
ss << "var<workgroup> mat4_data: array<mat4x4<f32>, b>;";
|
||||||
|
body << "_ = mat4_data;";
|
||||||
|
}
|
||||||
|
ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }";
|
||||||
|
|
||||||
|
wgpu::ComputePipelineDescriptor desc;
|
||||||
|
desc.compute.entryPoint = "main";
|
||||||
|
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
|
||||||
|
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{{nullptr, "a", static_cast<double>(vec4_count)},
|
||||||
|
{nullptr, "b", static_cast<double>(mat4_count)}};
|
||||||
|
desc.compute.constants = constants.data();
|
||||||
|
desc.compute.constantCount = constants.size();
|
||||||
|
|
||||||
|
if (success) {
|
||||||
|
device.CreateComputePipeline(&desc);
|
||||||
|
} else {
|
||||||
|
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0);
|
||||||
|
CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
DAWN_INSTANTIATE_TEST(WorkgroupSizeValidationTest, D3D12Backend(), MetalBackend(), VulkanBackend());
|
|
@ -0,0 +1,81 @@
|
||||||
|
// Copyright 2022 The Dawn Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "dawn/native/ObjectContentHasher.h"
|
||||||
|
#include "dawn/tests/DawnNativeTest.h"
|
||||||
|
|
||||||
|
namespace dawn::native {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class ObjectContentHasherTests : public DawnNativeTest {};
|
||||||
|
|
||||||
|
#define EXPECT_IF_HASH_EQ(eq, a, b) \
|
||||||
|
do { \
|
||||||
|
ObjectContentHasher ra, rb; \
|
||||||
|
ra.Record(a); \
|
||||||
|
rb.Record(b); \
|
||||||
|
EXPECT_EQ(eq, ra.GetContentHash() == rb.GetContentHash()); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
TEST(ObjectContentHasherTests, Pair) {
|
||||||
|
EXPECT_IF_HASH_EQ(true, (std::pair<std::string, uint8_t>{"a", 1}),
|
||||||
|
(std::pair<std::string, uint8_t>{"a", 1}));
|
||||||
|
EXPECT_IF_HASH_EQ(false, (std::pair<uint8_t, std::string>{1, "a"}),
|
||||||
|
(std::pair<std::string, uint8_t>{"a", 1}));
|
||||||
|
EXPECT_IF_HASH_EQ(false, (std::pair<std::string, uint8_t>{"a", 1}),
|
||||||
|
(std::pair<std::string, uint8_t>{"a", 2}));
|
||||||
|
EXPECT_IF_HASH_EQ(false, (std::pair<std::string, uint8_t>{"a", 1}),
|
||||||
|
(std::pair<std::string, uint8_t>{"b", 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ObjectContentHasherTests, Vector) {
|
||||||
|
EXPECT_IF_HASH_EQ(true, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{0, 1}));
|
||||||
|
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{0, 1, 2}));
|
||||||
|
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{1, 0}));
|
||||||
|
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{}));
|
||||||
|
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<float>{0, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ObjectContentHasherTests, Map) {
|
||||||
|
EXPECT_IF_HASH_EQ(true, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
|
||||||
|
(std::map<std::string, uint8_t>{{"b", 2}, {"a", 1}}));
|
||||||
|
EXPECT_IF_HASH_EQ(false, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
|
||||||
|
(std::map<std::string, uint8_t>{{"a", 2}, {"b", 1}}));
|
||||||
|
EXPECT_IF_HASH_EQ(false, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
|
||||||
|
(std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}, {"c", 1}}));
|
||||||
|
EXPECT_IF_HASH_EQ(false, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
|
||||||
|
(std::map<std::string, uint8_t>{}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ObjectContentHasherTests, HashCombine) {
|
||||||
|
ObjectContentHasher ra, rb;
|
||||||
|
|
||||||
|
ra.Record(std::vector<uint8_t>{0, 1});
|
||||||
|
ra.Record(std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}});
|
||||||
|
|
||||||
|
rb.Record(std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}});
|
||||||
|
rb.Record(std::vector<uint8_t>{0, 1});
|
||||||
|
|
||||||
|
EXPECT_NE(ra.GetContentHash(), rb.GetContentHash());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace dawn::native
|
|
@ -450,87 +450,6 @@ TEST_F(ShaderModuleValidationTest, MaximumInterStageShaderComponents) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests that we validate workgroup size limits.
|
|
||||||
TEST_F(ShaderModuleValidationTest, ComputeWorkgroupSizeLimits) {
|
|
||||||
auto CheckShaderWithWorkgroupSize = [this](bool success, uint32_t x, uint32_t y, uint32_t z) {
|
|
||||||
std::ostringstream ss;
|
|
||||||
ss << "@compute @workgroup_size(" << x << "," << y << "," << z << ") fn main() {}";
|
|
||||||
|
|
||||||
wgpu::ComputePipelineDescriptor desc;
|
|
||||||
desc.compute.entryPoint = "main";
|
|
||||||
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
|
|
||||||
|
|
||||||
if (success) {
|
|
||||||
device.CreateComputePipeline(&desc);
|
|
||||||
} else {
|
|
||||||
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
|
|
||||||
|
|
||||||
CheckShaderWithWorkgroupSize(true, 1, 1, 1);
|
|
||||||
CheckShaderWithWorkgroupSize(true, supportedLimits.maxComputeWorkgroupSizeX, 1, 1);
|
|
||||||
CheckShaderWithWorkgroupSize(true, 1, supportedLimits.maxComputeWorkgroupSizeY, 1);
|
|
||||||
CheckShaderWithWorkgroupSize(true, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ);
|
|
||||||
|
|
||||||
CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX + 1, 1, 1);
|
|
||||||
CheckShaderWithWorkgroupSize(false, 1, supportedLimits.maxComputeWorkgroupSizeY + 1, 1);
|
|
||||||
CheckShaderWithWorkgroupSize(false, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ + 1);
|
|
||||||
|
|
||||||
// No individual dimension exceeds its limit, but the combined size should definitely exceed the
|
|
||||||
// total invocation limit.
|
|
||||||
CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX,
|
|
||||||
supportedLimits.maxComputeWorkgroupSizeY,
|
|
||||||
supportedLimits.maxComputeWorkgroupSizeZ);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tests that we validate workgroup storage size limits.
|
|
||||||
TEST_F(ShaderModuleValidationTest, ComputeWorkgroupStorageSizeLimits) {
|
|
||||||
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
|
|
||||||
|
|
||||||
constexpr uint32_t kVec4Size = 16;
|
|
||||||
const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size;
|
|
||||||
constexpr uint32_t kMat4Size = 64;
|
|
||||||
const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size;
|
|
||||||
|
|
||||||
auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count,
|
|
||||||
uint32_t mat4_count) {
|
|
||||||
std::ostringstream ss;
|
|
||||||
std::ostringstream body;
|
|
||||||
if (vec4_count > 0) {
|
|
||||||
ss << "var<workgroup> vec4_data: array<vec4<f32>, " << vec4_count << ">;";
|
|
||||||
body << "_ = vec4_data;";
|
|
||||||
}
|
|
||||||
if (mat4_count > 0) {
|
|
||||||
ss << "var<workgroup> mat4_data: array<mat4x4<f32>, " << mat4_count << ">;";
|
|
||||||
body << "_ = mat4_data;";
|
|
||||||
}
|
|
||||||
ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }";
|
|
||||||
|
|
||||||
wgpu::ComputePipelineDescriptor desc;
|
|
||||||
desc.compute.entryPoint = "main";
|
|
||||||
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
|
|
||||||
|
|
||||||
if (success) {
|
|
||||||
device.CreateComputePipeline(&desc);
|
|
||||||
} else {
|
|
||||||
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
CheckPipelineWithWorkgroupStorage(true, 1, 1);
|
|
||||||
CheckPipelineWithWorkgroupStorage(true, maxVec4Count, 0);
|
|
||||||
CheckPipelineWithWorkgroupStorage(true, 0, maxMat4Count);
|
|
||||||
CheckPipelineWithWorkgroupStorage(true, maxVec4Count - 4, 1);
|
|
||||||
CheckPipelineWithWorkgroupStorage(true, 4, maxMat4Count - 1);
|
|
||||||
|
|
||||||
CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0);
|
|
||||||
CheckPipelineWithWorkgroupStorage(false, maxVec4Count - 3, 1);
|
|
||||||
CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1);
|
|
||||||
CheckPipelineWithWorkgroupStorage(false, 4, maxMat4Count);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test that numeric ID must be unique
|
// Test that numeric ID must be unique
|
||||||
TEST_F(ShaderModuleValidationTest, OverridableConstantsNumericIDConflicts) {
|
TEST_F(ShaderModuleValidationTest, OverridableConstantsNumericIDConflicts) {
|
||||||
ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"(
|
ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"(
|
||||||
|
|
|
@ -37,46 +37,6 @@ class UnsafeAPIValidationTest : public ValidationTest {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Check that pipeline overridable constants are disallowed as part of unsafe APIs.
|
|
||||||
// TODO(dawn:1041) Remove when implementation for all backend is added
|
|
||||||
TEST_F(UnsafeAPIValidationTest, PipelineOverridableConstants) {
|
|
||||||
// Create the placeholder compute pipeline.
|
|
||||||
wgpu::ComputePipelineDescriptor pipelineDescBase;
|
|
||||||
pipelineDescBase.compute.entryPoint = "main";
|
|
||||||
|
|
||||||
// Control case: shader without overridable constant is allowed.
|
|
||||||
{
|
|
||||||
wgpu::ComputePipelineDescriptor pipelineDesc = pipelineDescBase;
|
|
||||||
pipelineDesc.compute.module =
|
|
||||||
utils::CreateShaderModule(device, "@compute @workgroup_size(1) fn main() {}");
|
|
||||||
|
|
||||||
device.CreateComputePipeline(&pipelineDesc);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error case: shader with overridable constant with default value
|
|
||||||
{
|
|
||||||
ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"(
|
|
||||||
@id(1000) override c0: u32 = 1u;
|
|
||||||
@id(1000) override c1: u32;
|
|
||||||
|
|
||||||
@compute @workgroup_size(1) fn main() {
|
|
||||||
_ = c0;
|
|
||||||
_ = c1;
|
|
||||||
})"));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error case: pipeline stage with constant entry is disallowed
|
|
||||||
{
|
|
||||||
wgpu::ComputePipelineDescriptor pipelineDesc = pipelineDescBase;
|
|
||||||
pipelineDesc.compute.module =
|
|
||||||
utils::CreateShaderModule(device, "@compute @workgroup_size(1) fn main() {}");
|
|
||||||
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c", 1u}};
|
|
||||||
pipelineDesc.compute.constants = constants.data();
|
|
||||||
pipelineDesc.compute.constantCount = constants.size();
|
|
||||||
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&pipelineDesc));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class UnsafeQueryAPIValidationTest : public ValidationTest {
|
class UnsafeQueryAPIValidationTest : public ValidationTest {
|
||||||
protected:
|
protected:
|
||||||
WGPUDevice CreateTestDevice(dawn::native::Adapter dawnAdapter) override {
|
WGPUDevice CreateTestDevice(dawn::native::Adapter dawnAdapter) override {
|
||||||
|
|
|
@ -133,6 +133,110 @@ Inspector::Inspector(const Program* program) : program_(program) {}
|
||||||
|
|
||||||
Inspector::~Inspector() = default;
|
Inspector::~Inspector() = default;
|
||||||
|
|
||||||
|
EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) {
|
||||||
|
EntryPoint entry_point;
|
||||||
|
TINT_ASSERT(Inspector, func != nullptr);
|
||||||
|
TINT_ASSERT(Inspector, func->IsEntryPoint());
|
||||||
|
|
||||||
|
auto* sem = program_->Sem().Get(func);
|
||||||
|
|
||||||
|
entry_point.name = program_->Symbols().NameFor(func->symbol);
|
||||||
|
entry_point.remapped_name = program_->Symbols().NameFor(func->symbol);
|
||||||
|
|
||||||
|
switch (func->PipelineStage()) {
|
||||||
|
case ast::PipelineStage::kCompute: {
|
||||||
|
entry_point.stage = PipelineStage::kCompute;
|
||||||
|
|
||||||
|
auto wgsize = sem->WorkgroupSize();
|
||||||
|
if (!wgsize[0].overridable_const && !wgsize[1].overridable_const &&
|
||||||
|
!wgsize[2].overridable_const) {
|
||||||
|
entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value, wgsize[2].value};
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case ast::PipelineStage::kFragment: {
|
||||||
|
entry_point.stage = PipelineStage::kFragment;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case ast::PipelineStage::kVertex: {
|
||||||
|
entry_point.stage = PipelineStage::kVertex;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
TINT_UNREACHABLE(Inspector, diagnostics_)
|
||||||
|
<< "invalid pipeline stage for entry point '" << entry_point.name << "'";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto* param : sem->Parameters()) {
|
||||||
|
AddEntryPointInOutVariables(program_->Symbols().NameFor(param->Declaration()->symbol),
|
||||||
|
param->Type(), param->Declaration()->attributes,
|
||||||
|
entry_point.input_variables);
|
||||||
|
|
||||||
|
entry_point.input_position_used |= ContainsBuiltin(
|
||||||
|
ast::BuiltinValue::kPosition, param->Type(), param->Declaration()->attributes);
|
||||||
|
entry_point.front_facing_used |= ContainsBuiltin(
|
||||||
|
ast::BuiltinValue::kFrontFacing, param->Type(), param->Declaration()->attributes);
|
||||||
|
entry_point.sample_index_used |= ContainsBuiltin(
|
||||||
|
ast::BuiltinValue::kSampleIndex, param->Type(), param->Declaration()->attributes);
|
||||||
|
entry_point.input_sample_mask_used |= ContainsBuiltin(
|
||||||
|
ast::BuiltinValue::kSampleMask, param->Type(), param->Declaration()->attributes);
|
||||||
|
entry_point.num_workgroups_used |= ContainsBuiltin(
|
||||||
|
ast::BuiltinValue::kNumWorkgroups, param->Type(), param->Declaration()->attributes);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!sem->ReturnType()->Is<sem::Void>()) {
|
||||||
|
AddEntryPointInOutVariables("<retval>", sem->ReturnType(), func->return_type_attributes,
|
||||||
|
entry_point.output_variables);
|
||||||
|
|
||||||
|
entry_point.output_sample_mask_used = ContainsBuiltin(
|
||||||
|
ast::BuiltinValue::kSampleMask, sem->ReturnType(), func->return_type_attributes);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto* var : sem->TransitivelyReferencedGlobals()) {
|
||||||
|
auto* decl = var->Declaration();
|
||||||
|
|
||||||
|
auto name = program_->Symbols().NameFor(decl->symbol);
|
||||||
|
|
||||||
|
auto* global = var->As<sem::GlobalVariable>();
|
||||||
|
if (global && global->Declaration()->Is<ast::Override>()) {
|
||||||
|
Override override;
|
||||||
|
override.name = name;
|
||||||
|
override.id = global->OverrideId();
|
||||||
|
auto* type = var->Type();
|
||||||
|
TINT_ASSERT(Inspector, type->is_scalar());
|
||||||
|
if (type->is_bool_scalar_or_vector()) {
|
||||||
|
override.type = Override::Type::kBool;
|
||||||
|
} else if (type->is_float_scalar()) {
|
||||||
|
override.type = Override::Type::kFloat32;
|
||||||
|
} else if (type->is_signed_integer_scalar()) {
|
||||||
|
override.type = Override::Type::kInt32;
|
||||||
|
} else if (type->is_unsigned_integer_scalar()) {
|
||||||
|
override.type = Override::Type::kUint32;
|
||||||
|
} else {
|
||||||
|
TINT_UNREACHABLE(Inspector, diagnostics_);
|
||||||
|
}
|
||||||
|
|
||||||
|
override.is_initialized = global->Declaration()->constructor;
|
||||||
|
override.is_id_specified =
|
||||||
|
ast::HasAttribute<ast::IdAttribute>(global->Declaration()->attributes);
|
||||||
|
|
||||||
|
entry_point.overrides.push_back(override);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry_point;
|
||||||
|
}
|
||||||
|
|
||||||
|
EntryPoint Inspector::GetEntryPoint(const std::string& entry_point_name) {
|
||||||
|
auto* func = FindEntryPointByName(entry_point_name);
|
||||||
|
if (!func) {
|
||||||
|
return EntryPoint();
|
||||||
|
}
|
||||||
|
return GetEntryPoint(func);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<EntryPoint> Inspector::GetEntryPoints() {
|
std::vector<EntryPoint> Inspector::GetEntryPoints() {
|
||||||
std::vector<EntryPoint> result;
|
std::vector<EntryPoint> result;
|
||||||
|
|
||||||
|
@ -141,97 +245,7 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto* sem = program_->Sem().Get(func);
|
result.push_back(GetEntryPoint(func));
|
||||||
|
|
||||||
EntryPoint entry_point;
|
|
||||||
entry_point.name = program_->Symbols().NameFor(func->symbol);
|
|
||||||
entry_point.remapped_name = program_->Symbols().NameFor(func->symbol);
|
|
||||||
|
|
||||||
switch (func->PipelineStage()) {
|
|
||||||
case ast::PipelineStage::kCompute: {
|
|
||||||
entry_point.stage = PipelineStage::kCompute;
|
|
||||||
|
|
||||||
auto wgsize = sem->WorkgroupSize();
|
|
||||||
if (!wgsize[0].overridable_const && !wgsize[1].overridable_const &&
|
|
||||||
!wgsize[2].overridable_const) {
|
|
||||||
entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value,
|
|
||||||
wgsize[2].value};
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ast::PipelineStage::kFragment: {
|
|
||||||
entry_point.stage = PipelineStage::kFragment;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ast::PipelineStage::kVertex: {
|
|
||||||
entry_point.stage = PipelineStage::kVertex;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default: {
|
|
||||||
TINT_UNREACHABLE(Inspector, diagnostics_)
|
|
||||||
<< "invalid pipeline stage for entry point '" << entry_point.name << "'";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto* param : sem->Parameters()) {
|
|
||||||
AddEntryPointInOutVariables(program_->Symbols().NameFor(param->Declaration()->symbol),
|
|
||||||
param->Type(), param->Declaration()->attributes,
|
|
||||||
entry_point.input_variables);
|
|
||||||
|
|
||||||
entry_point.input_position_used |= ContainsBuiltin(
|
|
||||||
ast::BuiltinValue::kPosition, param->Type(), param->Declaration()->attributes);
|
|
||||||
entry_point.front_facing_used |= ContainsBuiltin(
|
|
||||||
ast::BuiltinValue::kFrontFacing, param->Type(), param->Declaration()->attributes);
|
|
||||||
entry_point.sample_index_used |= ContainsBuiltin(
|
|
||||||
ast::BuiltinValue::kSampleIndex, param->Type(), param->Declaration()->attributes);
|
|
||||||
entry_point.input_sample_mask_used |= ContainsBuiltin(
|
|
||||||
ast::BuiltinValue::kSampleMask, param->Type(), param->Declaration()->attributes);
|
|
||||||
entry_point.num_workgroups_used |= ContainsBuiltin(
|
|
||||||
ast::BuiltinValue::kNumWorkgroups, param->Type(), param->Declaration()->attributes);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!sem->ReturnType()->Is<sem::Void>()) {
|
|
||||||
AddEntryPointInOutVariables("<retval>", sem->ReturnType(), func->return_type_attributes,
|
|
||||||
entry_point.output_variables);
|
|
||||||
|
|
||||||
entry_point.output_sample_mask_used = ContainsBuiltin(
|
|
||||||
ast::BuiltinValue::kSampleMask, sem->ReturnType(), func->return_type_attributes);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto* var : sem->TransitivelyReferencedGlobals()) {
|
|
||||||
auto* decl = var->Declaration();
|
|
||||||
|
|
||||||
auto name = program_->Symbols().NameFor(decl->symbol);
|
|
||||||
|
|
||||||
auto* global = var->As<sem::GlobalVariable>();
|
|
||||||
if (global && global->Declaration()->Is<ast::Override>()) {
|
|
||||||
Override override;
|
|
||||||
override.name = name;
|
|
||||||
override.id = global->OverrideId();
|
|
||||||
auto* type = var->Type();
|
|
||||||
TINT_ASSERT(Inspector, type->is_scalar());
|
|
||||||
if (type->is_bool_scalar_or_vector()) {
|
|
||||||
override.type = Override::Type::kBool;
|
|
||||||
} else if (type->is_float_scalar()) {
|
|
||||||
override.type = Override::Type::kFloat32;
|
|
||||||
} else if (type->is_signed_integer_scalar()) {
|
|
||||||
override.type = Override::Type::kInt32;
|
|
||||||
} else if (type->is_unsigned_integer_scalar()) {
|
|
||||||
override.type = Override::Type::kUint32;
|
|
||||||
} else {
|
|
||||||
TINT_UNREACHABLE(Inspector, diagnostics_);
|
|
||||||
}
|
|
||||||
|
|
||||||
override.is_initialized = global->Declaration()->constructor;
|
|
||||||
override.is_id_specified =
|
|
||||||
ast::HasAttribute<ast::IdAttribute>(global->Declaration()->attributes);
|
|
||||||
|
|
||||||
entry_point.overrides.push_back(override);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result.push_back(std::move(entry_point));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
|
|
@ -55,6 +55,10 @@ class Inspector {
|
||||||
/// @returns vector of entry point information
|
/// @returns vector of entry point information
|
||||||
std::vector<EntryPoint> GetEntryPoints();
|
std::vector<EntryPoint> GetEntryPoints();
|
||||||
|
|
||||||
|
/// @param entry_point name of the entry point to get information about
|
||||||
|
/// @returns the entry point information
|
||||||
|
EntryPoint GetEntryPoint(const std::string& entry_point);
|
||||||
|
|
||||||
/// @returns map of override identifier to initial value
|
/// @returns map of override identifier to initial value
|
||||||
std::map<OverrideId, Scalar> GetOverrideDefaultValues();
|
std::map<OverrideId, Scalar> GetOverrideDefaultValues();
|
||||||
|
|
||||||
|
@ -230,6 +234,10 @@ class Inspector {
|
||||||
/// whenever a set of expressions are resolved to globals.
|
/// whenever a set of expressions are resolved to globals.
|
||||||
template <size_t N, typename F>
|
template <size_t N, typename F>
|
||||||
void GetOriginatingResources(std::array<const ast::Expression*, N> exprs, F&& cb);
|
void GetOriginatingResources(std::array<const ast::Expression*, N> exprs, F&& cb);
|
||||||
|
|
||||||
|
/// @param func the function of the entry point. Must be non-nullptr and true for IsEntryPoint()
|
||||||
|
/// @returns the entry point information
|
||||||
|
EntryPoint GetEntryPoint(const tint::ast::Function* func);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tint::inspector
|
} // namespace tint::inspector
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
#include "tint/override_id.h"
|
#include "tint/override_id.h"
|
||||||
|
|
||||||
|
#include "src/tint/reflection.h"
|
||||||
#include "src/tint/transform/transform.h"
|
#include "src/tint/transform/transform.h"
|
||||||
|
|
||||||
namespace tint::transform {
|
namespace tint::transform {
|
||||||
|
@ -63,6 +64,9 @@ class SubstituteOverride final : public Castable<SubstituteOverride, Transform>
|
||||||
/// The value is always a double coming into the transform and will be
|
/// The value is always a double coming into the transform and will be
|
||||||
/// converted to the correct type through and initializer.
|
/// converted to the correct type through and initializer.
|
||||||
std::unordered_map<OverrideId, double> map;
|
std::unordered_map<OverrideId, double> map;
|
||||||
|
|
||||||
|
/// Reflect the fields of this class so that it can be used by tint::ForeachField()
|
||||||
|
TINT_REFLECT(map);
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Constructor
|
/// Constructor
|
||||||
|
|
Loading…
Reference in New Issue