Use SubstituteOverride transform to implement overrides
Remove the old backend specific implementation for overrides. Use tint SubstituteOverride transform to replace overrides with const expressions and use the updated program at pipeline creation time. This CL also adds support for overrides used as workgroup size and related tests. Workgroup size validation now happens in backend code and at compute pipeline creation time. Bug: dawn:1504 Change-Id: I7df1fe9c3e358caa23235eacd6d13ba0b2998aec Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99821 Commit-Queue: Shrek Shao <shrekshao@google.com> Reviewed-by: Austin Eng <enga@chromium.org> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
23cf74c30e
commit
145337f309
|
@ -25,7 +25,15 @@
|
|||
namespace {{native_namespace}} {
|
||||
|
||||
//
|
||||
// Cache key writers for wgpu structures used in caching.
|
||||
// Streaming readers for wgpu structures.
|
||||
//
|
||||
{% macro render_reader(member) %}
|
||||
{%- set name = member.name.camelCase() -%}
|
||||
DAWN_TRY(StreamOut(source, &t->{{name}}));
|
||||
{% endmacro %}
|
||||
|
||||
//
|
||||
// Streaming writers for wgpu structures.
|
||||
//
|
||||
{% macro render_writer(member) %}
|
||||
{%- set name = member.name.camelCase() -%}
|
||||
|
@ -38,31 +46,50 @@ namespace {{native_namespace}} {
|
|||
{% endif %}
|
||||
{% endmacro %}
|
||||
|
||||
{# Helper macro to render writers. Should be used in a call block to provide additional custom
|
||||
{# Helper macro to render readers and writers. Should be used in a call block to provide additional custom
|
||||
handling when necessary. The optional `omit` field can be used to omit fields that are either
|
||||
handled in the custom code, or unnecessary in the serialized output.
|
||||
Example:
|
||||
{% call render_cache_key_writer("struct name", omits=["omit field"]) %}
|
||||
{% call render_streaming_impl("struct name", writer=true, reader=false, omits=["omit field"]) %}
|
||||
// Custom C++ code to handle special types/members that are hard to generate code for
|
||||
{% endcall %}
|
||||
One day we should probably make the generator smart enough to generate everything it can
|
||||
instead of manually adding streaming implementations here.
|
||||
#}
|
||||
{% macro render_cache_key_writer(json_type, omits=[]) %}
|
||||
{% macro render_streaming_impl(json_type, writer, reader, omits=[]) %}
|
||||
{%- set cpp_type = types[json_type].name.CamelCase() -%}
|
||||
template <>
|
||||
void stream::Stream<{{cpp_type}}>::Write(stream::Sink* sink, const {{cpp_type}}& t) {
|
||||
{{ caller() }}
|
||||
{% for member in types[json_type].members %}
|
||||
{%- if not member.name.get() in omits %}
|
||||
{{render_writer(member)}}
|
||||
{%- endif %}
|
||||
{% endfor %}
|
||||
}
|
||||
{% if reader %}
|
||||
template <>
|
||||
MaybeError stream::Stream<{{cpp_type}}>::Read(stream::Source* source, {{cpp_type}}* t) {
|
||||
{{ caller() }}
|
||||
{% for member in types[json_type].members %}
|
||||
{% if not member.name.get() in omits %}
|
||||
{{render_reader(member)}}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
return {};
|
||||
}
|
||||
{% endif %}
|
||||
{% if writer %}
|
||||
template <>
|
||||
void stream::Stream<{{cpp_type}}>::Write(stream::Sink* sink, const {{cpp_type}}& t) {
|
||||
{{ caller() }}
|
||||
{% for member in types[json_type].members %}
|
||||
{% if not member.name.get() in omits %}
|
||||
{{render_writer(member)}}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
}
|
||||
{% endif %}
|
||||
{% endmacro %}
|
||||
|
||||
{% call render_cache_key_writer("adapter properties") %}
|
||||
{% call render_streaming_impl("adapter properties", true, false) %}
|
||||
{% endcall %}
|
||||
|
||||
{% call render_cache_key_writer("dawn cache device descriptor") %}
|
||||
{% call render_streaming_impl("dawn cache device descriptor", true, false) %}
|
||||
{% endcall %}
|
||||
|
||||
{% call render_streaming_impl("extent 3D", true, true) %}
|
||||
{% endcall %}
|
||||
|
||||
} // namespace {{native_namespace}}
|
||||
|
|
|
@ -215,4 +215,22 @@ Limits ApplyLimitTiers(Limits limits) {
|
|||
return limits;
|
||||
}
|
||||
|
||||
#define DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT(type, name) \
|
||||
{ result.name = limits.name; }
|
||||
#define DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT(MEMBERS) \
|
||||
MEMBERS(DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT)
|
||||
LimitsForCompilationRequest LimitsForCompilationRequest::Create(const Limits& limits) {
|
||||
LimitsForCompilationRequest result;
|
||||
DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT(LIMITS_FOR_COMPILATION_REQUEST_MEMBERS)
|
||||
return result;
|
||||
}
|
||||
#undef DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT
|
||||
#undef DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT
|
||||
|
||||
template <>
|
||||
void stream::Stream<LimitsForCompilationRequest>::Write(Sink* s,
|
||||
const LimitsForCompilationRequest& t) {
|
||||
t.VisitAll([&](const auto&... members) { StreamIn(s, members...); });
|
||||
}
|
||||
|
||||
} // namespace dawn::native
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#define SRC_DAWN_NATIVE_LIMITS_H_
|
||||
|
||||
#include "dawn/native/Error.h"
|
||||
#include "dawn/native/VisitableMembers.h"
|
||||
#include "dawn/native/dawn_platform.h"
|
||||
|
||||
namespace dawn::native {
|
||||
|
@ -38,6 +39,20 @@ MaybeError ValidateLimits(const Limits& supportedLimits, const Limits& requiredL
|
|||
// Returns a copy of |limits| where limit tiers are applied.
|
||||
Limits ApplyLimitTiers(Limits limits);
|
||||
|
||||
// If there are new limit member needed at shader compilation time
|
||||
// Simply append a new X(type, name) here.
|
||||
#define LIMITS_FOR_COMPILATION_REQUEST_MEMBERS(X) \
|
||||
X(uint32_t, maxComputeWorkgroupSizeX) \
|
||||
X(uint32_t, maxComputeWorkgroupSizeY) \
|
||||
X(uint32_t, maxComputeWorkgroupSizeZ) \
|
||||
X(uint32_t, maxComputeInvocationsPerWorkgroup) \
|
||||
X(uint32_t, maxComputeWorkgroupStorageSize)
|
||||
|
||||
struct LimitsForCompilationRequest {
|
||||
static LimitsForCompilationRequest Create(const Limits& limits);
|
||||
DAWN_VISITABLE_MEMBERS(LIMITS_FOR_COMPILATION_REQUEST_MEMBERS)
|
||||
};
|
||||
|
||||
} // namespace dawn::native
|
||||
|
||||
#endif // SRC_DAWN_NATIVE_LIMITS_H_
|
||||
|
|
|
@ -15,7 +15,9 @@
|
|||
#ifndef SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_
|
||||
#define SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "dawn/common/HashUtils.h"
|
||||
|
@ -60,6 +62,13 @@ class ObjectContentHasher {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename T, typename E>
|
||||
struct RecordImpl<std::map<T, E>> {
|
||||
static constexpr void Call(ObjectContentHasher* recorder, const std::map<T, E>& map) {
|
||||
recorder->RecordIterable<std::map<T, E>>(map);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename IteratorT>
|
||||
constexpr void RecordIterable(const IteratorT& iterable) {
|
||||
for (auto it = iterable.begin(); it != iterable.end(); ++it) {
|
||||
|
@ -67,6 +76,14 @@ class ObjectContentHasher {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T, typename E>
|
||||
struct RecordImpl<std::pair<T, E>> {
|
||||
static constexpr void Call(ObjectContentHasher* recorder, const std::pair<T, E>& pair) {
|
||||
recorder->Record(pair.first);
|
||||
recorder->Record(pair.second);
|
||||
}
|
||||
};
|
||||
|
||||
size_t mContentHash = 0;
|
||||
};
|
||||
|
||||
|
|
|
@ -58,12 +58,6 @@ MaybeError ValidateProgrammableStage(DeviceBase* device,
|
|||
DAWN_TRY(ValidateCompatibilityWithPipelineLayout(device, metadata, layout));
|
||||
}
|
||||
|
||||
if (constantCount > 0u && device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs)) {
|
||||
return DAWN_VALIDATION_ERROR(
|
||||
"Pipeline overridable constants are disallowed because they are partially "
|
||||
"implemented.");
|
||||
}
|
||||
|
||||
// Validate if overridable constants exist in shader module
|
||||
// pipelineBase is not yet constructed at this moment so iterate constants from descriptor
|
||||
size_t numUninitializedConstants = metadata.uninitializedOverrides.size();
|
||||
|
@ -233,6 +227,7 @@ size_t PipelineBase::ComputeContentHash() {
|
|||
for (SingleShaderStage stage : IterateStages(mStageMask)) {
|
||||
recorder.Record(mStages[stage].module->GetContentHash());
|
||||
recorder.Record(mStages[stage].entryPoint);
|
||||
recorder.Record(mStages[stage].constants);
|
||||
}
|
||||
|
||||
return recorder.GetContentHash();
|
||||
|
@ -248,7 +243,14 @@ bool PipelineBase::EqualForCache(const PipelineBase* a, const PipelineBase* b) {
|
|||
for (SingleShaderStage stage : IterateStages(a->mStageMask)) {
|
||||
// The module is deduplicated so it can be compared by pointer.
|
||||
if (a->mStages[stage].module.Get() != b->mStages[stage].module.Get() ||
|
||||
a->mStages[stage].entryPoint != b->mStages[stage].entryPoint) {
|
||||
a->mStages[stage].entryPoint != b->mStages[stage].entryPoint ||
|
||||
a->mStages[stage].constants.size() != b->mStages[stage].constants.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If the constants.size are the same, we still need to compare the key and value.
|
||||
if (!std::equal(a->mStages[stage].constants.begin(), a->mStages[stage].constants.end(),
|
||||
b->mStages[stage].constants.begin())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,9 +40,6 @@ MaybeError ValidateProgrammableStage(DeviceBase* device,
|
|||
const PipelineLayoutBase* layout,
|
||||
SingleShaderStage stage);
|
||||
|
||||
// Use map to make sure constant keys are sorted for creating shader cache keys
|
||||
using PipelineConstantEntries = std::map<std::string, double>;
|
||||
|
||||
struct ProgrammableStage {
|
||||
Ref<ShaderModuleBase> module;
|
||||
std::string entryPoint;
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "absl/strings/str_format.h"
|
||||
#include "dawn/common/BitSetIterator.h"
|
||||
#include "dawn/common/Constants.h"
|
||||
#include "dawn/common/HashUtils.h"
|
||||
#include "dawn/native/BindGroupLayout.h"
|
||||
#include "dawn/native/ChainUtils_autogen.h"
|
||||
#include "dawn/native/CompilationMessages.h"
|
||||
|
@ -511,7 +510,6 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
|
|||
const DeviceBase* device,
|
||||
tint::inspector::Inspector* inspector,
|
||||
const tint::inspector::EntryPoint& entryPoint) {
|
||||
const CombinedLimits& limits = device->GetLimits();
|
||||
constexpr uint32_t kMaxInterStageShaderLocation = kMaxInterStageShaderVariables - 1;
|
||||
|
||||
std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
|
||||
|
@ -528,10 +526,6 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
|
|||
})()
|
||||
|
||||
if (!entryPoint.overrides.empty()) {
|
||||
DAWN_INVALID_IF(device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs),
|
||||
"Pipeline overridable constants are disallowed because they "
|
||||
"are partially implemented.");
|
||||
|
||||
const auto& name2Id = inspector->GetNamedOverrideIds();
|
||||
const auto& id2Scalar = inspector->GetOverrideDefaultValues();
|
||||
|
||||
|
@ -553,10 +547,10 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
|
|||
UNREACHABLE();
|
||||
}
|
||||
}
|
||||
EntryPointMetadata::Override override = {id.value, FromTintOverrideType(c.type),
|
||||
EntryPointMetadata::Override override = {id, FromTintOverrideType(c.type),
|
||||
c.is_initialized, defaultValue};
|
||||
|
||||
std::string identifier = c.is_id_specified ? std::to_string(override.id) : c.name;
|
||||
std::string identifier = c.is_id_specified ? std::to_string(override.id.value) : c.name;
|
||||
metadata->overrides[identifier] = override;
|
||||
|
||||
if (!c.is_initialized) {
|
||||
|
@ -575,39 +569,6 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
|
|||
DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
|
||||
|
||||
if (metadata->stage == SingleShaderStage::Compute) {
|
||||
auto workgroup_size = entryPoint.workgroup_size;
|
||||
DAWN_INVALID_IF(
|
||||
!workgroup_size.has_value(),
|
||||
"TODO(crbug.com/dawn/1504): Dawn does not currently support @workgroup_size "
|
||||
"attributes using override-expressions");
|
||||
DelayedInvalidIf(workgroup_size->x > limits.v1.maxComputeWorkgroupSizeX ||
|
||||
workgroup_size->y > limits.v1.maxComputeWorkgroupSizeY ||
|
||||
workgroup_size->z > limits.v1.maxComputeWorkgroupSizeZ,
|
||||
"Entry-point uses workgroup_size(%u, %u, %u) that exceeds the "
|
||||
"maximum allowed (%u, %u, %u).",
|
||||
workgroup_size->x, workgroup_size->y, workgroup_size->z,
|
||||
limits.v1.maxComputeWorkgroupSizeX, limits.v1.maxComputeWorkgroupSizeY,
|
||||
limits.v1.maxComputeWorkgroupSizeZ);
|
||||
|
||||
// Dimensions have already been validated against their individual limits above.
|
||||
// Cast to uint64_t to avoid overflow in this multiplication.
|
||||
uint64_t numInvocations =
|
||||
static_cast<uint64_t>(workgroup_size->x) * workgroup_size->y * workgroup_size->z;
|
||||
DelayedInvalidIf(numInvocations > limits.v1.maxComputeInvocationsPerWorkgroup,
|
||||
"The total number of workgroup invocations (%u) exceeds the "
|
||||
"maximum allowed (%u).",
|
||||
numInvocations, limits.v1.maxComputeInvocationsPerWorkgroup);
|
||||
|
||||
const size_t workgroupStorageSize = inspector->GetWorkgroupStorageSize(entryPoint.name);
|
||||
DelayedInvalidIf(workgroupStorageSize > limits.v1.maxComputeWorkgroupStorageSize,
|
||||
"The total use of workgroup storage (%u bytes) is larger than "
|
||||
"the maximum allowed (%u bytes).",
|
||||
workgroupStorageSize, limits.v1.maxComputeWorkgroupStorageSize);
|
||||
|
||||
metadata->localWorkgroupSize.x = workgroup_size->x;
|
||||
metadata->localWorkgroupSize.y = workgroup_size->y;
|
||||
metadata->localWorkgroupSize.z = workgroup_size->z;
|
||||
|
||||
metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
|
||||
}
|
||||
|
||||
|
@ -883,6 +844,46 @@ MaybeError ReflectShaderUsingTint(const DeviceBase* device,
|
|||
}
|
||||
} // anonymous namespace
|
||||
|
||||
ResultOrError<Extent3D> ValidateComputeStageWorkgroupSize(
|
||||
const tint::Program& program,
|
||||
const char* entryPointName,
|
||||
const LimitsForCompilationRequest& limits) {
|
||||
tint::inspector::Inspector inspector(&program);
|
||||
// At this point the entry point must exist and must have workgroup size values.
|
||||
tint::inspector::EntryPoint entryPoint = inspector.GetEntryPoint(entryPointName);
|
||||
ASSERT(entryPoint.workgroup_size.has_value());
|
||||
const tint::inspector::WorkgroupSize& workgroup_size = entryPoint.workgroup_size.value();
|
||||
|
||||
DAWN_INVALID_IF(workgroup_size.x < 1 || workgroup_size.y < 1 || workgroup_size.z < 1,
|
||||
"Entry-point uses workgroup_size(%u, %u, %u) that are below the "
|
||||
"minimum allowed (1, 1, 1).",
|
||||
workgroup_size.x, workgroup_size.y, workgroup_size.z);
|
||||
|
||||
DAWN_INVALID_IF(workgroup_size.x > limits.maxComputeWorkgroupSizeX ||
|
||||
workgroup_size.y > limits.maxComputeWorkgroupSizeY ||
|
||||
workgroup_size.z > limits.maxComputeWorkgroupSizeZ,
|
||||
"Entry-point uses workgroup_size(%u, %u, %u) that exceeds the "
|
||||
"maximum allowed (%u, %u, %u).",
|
||||
workgroup_size.x, workgroup_size.y, workgroup_size.z,
|
||||
limits.maxComputeWorkgroupSizeX, limits.maxComputeWorkgroupSizeY,
|
||||
limits.maxComputeWorkgroupSizeZ);
|
||||
|
||||
uint64_t numInvocations =
|
||||
static_cast<uint64_t>(workgroup_size.x) * workgroup_size.y * workgroup_size.z;
|
||||
DAWN_INVALID_IF(numInvocations > limits.maxComputeInvocationsPerWorkgroup,
|
||||
"The total number of workgroup invocations (%u) exceeds the "
|
||||
"maximum allowed (%u).",
|
||||
numInvocations, limits.maxComputeInvocationsPerWorkgroup);
|
||||
|
||||
const size_t workgroupStorageSize = inspector.GetWorkgroupStorageSize(entryPointName);
|
||||
DAWN_INVALID_IF(workgroupStorageSize > limits.maxComputeWorkgroupStorageSize,
|
||||
"The total use of workgroup storage (%u bytes) is larger than "
|
||||
"the maximum allowed (%u bytes).",
|
||||
workgroupStorageSize, limits.maxComputeWorkgroupStorageSize);
|
||||
|
||||
return Extent3D{workgroup_size.x, workgroup_size.y, workgroup_size.z};
|
||||
}
|
||||
|
||||
ShaderModuleParseResult::ShaderModuleParseResult() = default;
|
||||
ShaderModuleParseResult::~ShaderModuleParseResult() = default;
|
||||
|
||||
|
@ -1200,11 +1201,4 @@ MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult
|
|||
return {};
|
||||
}
|
||||
|
||||
size_t PipelineLayoutEntryPointPairHashFunc::operator()(
|
||||
const PipelineLayoutEntryPointPair& pair) const {
|
||||
size_t hash = 0;
|
||||
HashCombine(&hash, pair.first, pair.second);
|
||||
return hash;
|
||||
}
|
||||
|
||||
} // namespace dawn::native
|
||||
|
|
|
@ -33,10 +33,12 @@
|
|||
#include "dawn/native/Format.h"
|
||||
#include "dawn/native/Forward.h"
|
||||
#include "dawn/native/IntegerTypes.h"
|
||||
#include "dawn/native/Limits.h"
|
||||
#include "dawn/native/ObjectBase.h"
|
||||
#include "dawn/native/PerStage.h"
|
||||
#include "dawn/native/VertexFormat.h"
|
||||
#include "dawn/native/dawn_platform.h"
|
||||
#include "tint/override_id.h"
|
||||
|
||||
namespace tint {
|
||||
|
||||
|
@ -76,10 +78,8 @@ enum class InterpolationSampling {
|
|||
Sample,
|
||||
};
|
||||
|
||||
using PipelineLayoutEntryPointPair = std::pair<const PipelineLayoutBase*, std::string>;
|
||||
struct PipelineLayoutEntryPointPairHashFunc {
|
||||
size_t operator()(const PipelineLayoutEntryPointPair& pair) const;
|
||||
};
|
||||
// Use map to make sure constant keys are sorted for creating shader cache keys
|
||||
using PipelineConstantEntries = std::map<std::string, double>;
|
||||
|
||||
// A map from name to EntryPointMetadata.
|
||||
using EntryPointMetadataTable =
|
||||
|
@ -108,6 +108,13 @@ MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
|
|||
const EntryPointMetadata& entryPoint,
|
||||
const PipelineLayoutBase* layout);
|
||||
|
||||
// Return extent3D with workgroup size dimension info if it is valid
|
||||
// width = x, height = y, depthOrArrayLength = z
|
||||
ResultOrError<Extent3D> ValidateComputeStageWorkgroupSize(
|
||||
const tint::Program& program,
|
||||
const char* entryPointName,
|
||||
const LimitsForCompilationRequest& limits);
|
||||
|
||||
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
|
||||
const PipelineLayoutBase* layout);
|
||||
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
|
||||
|
@ -204,14 +211,12 @@ struct EntryPointMetadata {
|
|||
std::bitset<kMaxInterStageShaderVariables> usedInterStageVariables;
|
||||
std::array<InterStageVariableInfo, kMaxInterStageShaderVariables> interStageVariables;
|
||||
|
||||
// The local workgroup size declared for a compute entry point (or 0s otehrwise).
|
||||
Origin3D localWorkgroupSize;
|
||||
|
||||
// The shader stage for this binding.
|
||||
SingleShaderStage stage;
|
||||
|
||||
struct Override {
|
||||
uint32_t id;
|
||||
tint::OverrideId id;
|
||||
|
||||
// Match tint::inspector::Override::Type
|
||||
// Bool is defined as a macro on linux X11 and cannot compile
|
||||
enum class Type { Boolean, Float32, Uint32, Int32 } type;
|
||||
|
@ -273,6 +278,7 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject {
|
|||
bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const;
|
||||
};
|
||||
|
||||
// This returns tint program before running transforms.
|
||||
const tint::Program* GetTintProgram() const;
|
||||
|
||||
void APIGetCompilationInfo(wgpu::CompilationInfoCallback callback, void* userdata);
|
||||
|
|
|
@ -65,6 +65,23 @@ void stream::Stream<tint::transform::VertexPulling::Config>::Write(
|
|||
StreamInTintObject(cfg, sink);
|
||||
}
|
||||
|
||||
// static
|
||||
template <>
|
||||
void stream::Stream<tint::transform::SubstituteOverride::Config>::Write(
|
||||
stream::Sink* sink,
|
||||
const tint::transform::SubstituteOverride::Config& cfg) {
|
||||
StreamInTintObject(cfg, sink);
|
||||
}
|
||||
|
||||
// static
|
||||
template <>
|
||||
void stream::Stream<tint::OverrideId>::Write(stream::Sink* sink, const tint::OverrideId& id) {
|
||||
// TODO(tint:1640): fix the include build issues and use StreamInTintObject instead.
|
||||
static_assert(offsetof(tint::OverrideId, value) == 0,
|
||||
"Please update serialization for tint::OverrideId");
|
||||
StreamIn(sink, id.value);
|
||||
}
|
||||
|
||||
// static
|
||||
template <>
|
||||
void stream::Stream<tint::transform::VertexBufferLayoutDescriptor>::Write(
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "dawn/native/BindGroupLayout.h"
|
||||
#include "dawn/native/Device.h"
|
||||
#include "dawn/native/Pipeline.h"
|
||||
#include "dawn/native/PipelineLayout.h"
|
||||
#include "dawn/native/RenderPipeline.h"
|
||||
|
||||
|
@ -183,6 +184,21 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig(
|
|||
return cfg;
|
||||
}
|
||||
|
||||
tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig(
|
||||
const ProgrammableStage& stage) {
|
||||
const EntryPointMetadata& metadata = *stage.metadata;
|
||||
const auto& constants = stage.constants;
|
||||
|
||||
tint::transform::SubstituteOverride::Config cfg;
|
||||
|
||||
for (const auto& [key, value] : constants) {
|
||||
const auto& o = metadata.overrides.at(key);
|
||||
cfg.map.insert({o.id, value});
|
||||
}
|
||||
|
||||
return cfg;
|
||||
}
|
||||
|
||||
} // namespace dawn::native
|
||||
|
||||
namespace tint::sem {
|
||||
|
|
|
@ -26,6 +26,7 @@ namespace dawn::native {
|
|||
|
||||
class DeviceBase;
|
||||
class PipelineLayoutBase;
|
||||
struct ProgrammableStage;
|
||||
class RenderPipelineBase;
|
||||
|
||||
// Indicates that for the lifetime of this object tint internal compiler errors should be
|
||||
|
@ -47,6 +48,9 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig(
|
|||
const std::string_view& entryPoint,
|
||||
BindGroupIndex pullingBufferBindingSet);
|
||||
|
||||
tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig(
|
||||
const ProgrammableStage& stage);
|
||||
|
||||
} // namespace dawn::native
|
||||
|
||||
namespace tint::sem {
|
||||
|
|
|
@ -351,8 +351,6 @@ MaybeError RenderPipeline::Initialize() {
|
|||
|
||||
D3D12_GRAPHICS_PIPELINE_STATE_DESC descriptorD3D12 = {};
|
||||
|
||||
PerStage<ProgrammableStage> pipelineStages = GetAllStages();
|
||||
|
||||
PerStage<D3D12_SHADER_BYTECODE*> shaders;
|
||||
shaders[SingleShaderStage::Vertex] = &descriptorD3D12.VS;
|
||||
shaders[SingleShaderStage::Fragment] = &descriptorD3D12.PS;
|
||||
|
@ -360,8 +358,9 @@ MaybeError RenderPipeline::Initialize() {
|
|||
PerStage<CompiledShader> compiledShader;
|
||||
|
||||
for (auto stage : IterateStages(GetStageMask())) {
|
||||
DAWN_TRY_ASSIGN(compiledShader[stage], ToBackend(pipelineStages[stage].module)
|
||||
->Compile(pipelineStages[stage], stage,
|
||||
const ProgrammableStage& programmableStage = GetStage(stage);
|
||||
DAWN_TRY_ASSIGN(compiledShader[stage], ToBackend(programmableStage.module)
|
||||
->Compile(programmableStage, stage,
|
||||
ToBackend(GetLayout()), compileFlags));
|
||||
*shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode();
|
||||
}
|
||||
|
|
|
@ -65,100 +65,35 @@ namespace dawn::native::d3d12 {
|
|||
|
||||
namespace {
|
||||
|
||||
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
|
||||
std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
|
||||
std::ostringstream out;
|
||||
out.precision(n);
|
||||
out << std::fixed << v;
|
||||
return out.str();
|
||||
}
|
||||
|
||||
std::string GetHLSLValueString(EntryPointMetadata::Override::Type dawnType,
|
||||
const OverrideScalar* entry,
|
||||
double value = 0) {
|
||||
switch (dawnType) {
|
||||
case EntryPointMetadata::Override::Type::Boolean:
|
||||
return std::to_string(entry ? entry->b : static_cast<int32_t>(value));
|
||||
case EntryPointMetadata::Override::Type::Float32:
|
||||
return FloatToStringWithPrecision(entry ? entry->f32 : static_cast<float>(value));
|
||||
case EntryPointMetadata::Override::Type::Int32:
|
||||
return std::to_string(entry ? entry->i32 : static_cast<int32_t>(value));
|
||||
case EntryPointMetadata::Override::Type::Uint32:
|
||||
return std::to_string(entry ? entry->u32 : static_cast<uint32_t>(value));
|
||||
default:
|
||||
UNREACHABLE();
|
||||
}
|
||||
}
|
||||
|
||||
constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
|
||||
|
||||
using DefineStrings = std::vector<std::pair<std::string, std::string>>;
|
||||
|
||||
DefineStrings GetOverridableConstantsDefines(
|
||||
const PipelineConstantEntries& pipelineConstantEntries,
|
||||
const EntryPointMetadata::OverridesMap& shaderEntryPointConstants) {
|
||||
DefineStrings defineStrings;
|
||||
std::unordered_set<std::string> overriddenConstants;
|
||||
|
||||
// Set pipeline overridden values
|
||||
for (const auto& [name, value] : pipelineConstantEntries) {
|
||||
overriddenConstants.insert(name);
|
||||
|
||||
// This is already validated so `name` must exist
|
||||
const auto& moduleConstant = shaderEntryPointConstants.at(name);
|
||||
|
||||
defineStrings.emplace_back(
|
||||
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
|
||||
GetHLSLValueString(moduleConstant.type, nullptr, value));
|
||||
}
|
||||
|
||||
// Set shader initialized default values
|
||||
for (const auto& iter : shaderEntryPointConstants) {
|
||||
const std::string& name = iter.first;
|
||||
if (overriddenConstants.count(name) != 0) {
|
||||
// This constant already has overridden value
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& moduleConstant = shaderEntryPointConstants.at(name);
|
||||
|
||||
// Uninitialized default values are okay since they ar only defined to pass
|
||||
// compilation but not used
|
||||
defineStrings.emplace_back(
|
||||
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
|
||||
GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue));
|
||||
}
|
||||
return defineStrings;
|
||||
}
|
||||
|
||||
enum class Compiler { FXC, DXC };
|
||||
|
||||
#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \
|
||||
X(const tint::Program*, inputProgram) \
|
||||
X(std::string_view, entryPointName) \
|
||||
X(SingleShaderStage, stage) \
|
||||
X(uint32_t, shaderModel) \
|
||||
X(uint32_t, compileFlags) \
|
||||
X(Compiler, compiler) \
|
||||
X(uint64_t, compilerVersion) \
|
||||
X(std::wstring_view, dxcShaderProfile) \
|
||||
X(std::string_view, fxcShaderProfile) \
|
||||
X(pD3DCompile, d3dCompile) \
|
||||
X(IDxcLibrary*, dxcLibrary) \
|
||||
X(IDxcCompiler*, dxcCompiler) \
|
||||
X(uint32_t, firstIndexOffsetShaderRegister) \
|
||||
X(uint32_t, firstIndexOffsetRegisterSpace) \
|
||||
X(bool, usesNumWorkgroups) \
|
||||
X(uint32_t, numWorkgroupsShaderRegister) \
|
||||
X(uint32_t, numWorkgroupsRegisterSpace) \
|
||||
X(DefineStrings, defineStrings) \
|
||||
X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
|
||||
X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform) \
|
||||
X(tint::transform::BindingRemapper::BindingPoints, remappedBindingPoints) \
|
||||
X(tint::transform::BindingRemapper::AccessControls, remappedAccessControls) \
|
||||
X(bool, disableSymbolRenaming) \
|
||||
X(bool, isRobustnessEnabled) \
|
||||
X(bool, disableWorkgroupInit) \
|
||||
#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \
|
||||
X(const tint::Program*, inputProgram) \
|
||||
X(std::string_view, entryPointName) \
|
||||
X(SingleShaderStage, stage) \
|
||||
X(uint32_t, shaderModel) \
|
||||
X(uint32_t, compileFlags) \
|
||||
X(Compiler, compiler) \
|
||||
X(uint64_t, compilerVersion) \
|
||||
X(std::wstring_view, dxcShaderProfile) \
|
||||
X(std::string_view, fxcShaderProfile) \
|
||||
X(pD3DCompile, d3dCompile) \
|
||||
X(IDxcLibrary*, dxcLibrary) \
|
||||
X(IDxcCompiler*, dxcCompiler) \
|
||||
X(uint32_t, firstIndexOffsetShaderRegister) \
|
||||
X(uint32_t, firstIndexOffsetRegisterSpace) \
|
||||
X(bool, usesNumWorkgroups) \
|
||||
X(uint32_t, numWorkgroupsShaderRegister) \
|
||||
X(uint32_t, numWorkgroupsRegisterSpace) \
|
||||
X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
|
||||
X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform) \
|
||||
X(tint::transform::BindingRemapper::BindingPoints, remappedBindingPoints) \
|
||||
X(tint::transform::BindingRemapper::AccessControls, remappedAccessControls) \
|
||||
X(std::optional<tint::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
|
||||
X(LimitsForCompilationRequest, limits) \
|
||||
X(bool, disableSymbolRenaming) \
|
||||
X(bool, isRobustnessEnabled) \
|
||||
X(bool, disableWorkgroupInit) \
|
||||
X(bool, dumpShaders)
|
||||
|
||||
#define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \
|
||||
|
@ -170,8 +105,7 @@ enum class Compiler { FXC, DXC };
|
|||
X(std::string_view, fxcShaderProfile) \
|
||||
X(pD3DCompile, d3dCompile) \
|
||||
X(IDxcLibrary*, dxcLibrary) \
|
||||
X(IDxcCompiler*, dxcCompiler) \
|
||||
X(DefineStrings, defineStrings)
|
||||
X(IDxcCompiler*, dxcCompiler)
|
||||
|
||||
DAWN_SERIALIZABLE(struct, HlslCompilationRequest, HLSL_COMPILATION_REQUEST_MEMBERS){};
|
||||
#undef HLSL_COMPILATION_REQUEST_MEMBERS
|
||||
|
@ -255,25 +189,11 @@ ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(const D3DBytecodeCompilationReq
|
|||
std::vector<const wchar_t*> arguments =
|
||||
GetDXCArguments(r.compileFlags, r.hasShaderFloat16Feature);
|
||||
|
||||
// Build defines for overridable constants
|
||||
std::vector<std::pair<std::wstring, std::wstring>> defineStrings;
|
||||
defineStrings.reserve(r.defineStrings.size());
|
||||
for (const auto& [name, value] : r.defineStrings) {
|
||||
defineStrings.emplace_back(UTF8ToWStr(name.c_str()), UTF8ToWStr(value.c_str()));
|
||||
}
|
||||
|
||||
std::vector<DxcDefine> dxcDefines;
|
||||
dxcDefines.reserve(defineStrings.size());
|
||||
for (const auto& [name, value] : defineStrings) {
|
||||
dxcDefines.push_back({name.c_str(), value.c_str()});
|
||||
}
|
||||
|
||||
ComPtr<IDxcOperationResult> result;
|
||||
DAWN_TRY(CheckHRESULT(
|
||||
r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
|
||||
r.dxcShaderProfile.data(), arguments.data(), arguments.size(),
|
||||
dxcDefines.data(), dxcDefines.size(), nullptr, &result),
|
||||
"DXC compile"));
|
||||
DAWN_TRY(CheckHRESULT(r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
|
||||
r.dxcShaderProfile.data(), arguments.data(),
|
||||
arguments.size(), nullptr, 0, nullptr, &result),
|
||||
"DXC compile"));
|
||||
|
||||
HRESULT hr;
|
||||
DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
|
||||
|
@ -359,20 +279,7 @@ ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const D3DBytecodeCompilationReq
|
|||
ComPtr<ID3DBlob> compiledShader;
|
||||
ComPtr<ID3DBlob> errors;
|
||||
|
||||
// Build defines for overridable constants
|
||||
const D3D_SHADER_MACRO* pDefines = nullptr;
|
||||
std::vector<D3D_SHADER_MACRO> fxcDefines;
|
||||
if (r.defineStrings.size() > 0) {
|
||||
fxcDefines.reserve(r.defineStrings.size() + 1);
|
||||
for (const auto& [name, value] : r.defineStrings) {
|
||||
fxcDefines.push_back({name.c_str(), value.c_str()});
|
||||
}
|
||||
// d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array
|
||||
fxcDefines.push_back({nullptr, nullptr});
|
||||
pDefines = fxcDefines.data();
|
||||
}
|
||||
|
||||
DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, pDefines,
|
||||
DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
|
||||
nullptr, entryPointName.c_str(), r.fxcShaderProfile.data(),
|
||||
r.compileFlags, 0, &compiledShader, &errors)),
|
||||
"D3D compile failed with: %s", static_cast<char*>(errors->GetBufferPointer()));
|
||||
|
@ -420,6 +327,14 @@ ResultOrError<std::string> TranslateToHLSL(
|
|||
tint::transform::Renamer::Target::kHlslKeywords);
|
||||
}
|
||||
|
||||
if (r.substituteOverrideConfig) {
|
||||
// This needs to run after SingleEntryPoint transform to get rid of overrides not used for
|
||||
// the current entry point.
|
||||
transformManager.Add<tint::transform::SubstituteOverride>();
|
||||
transformInputs.Add<tint::transform::SubstituteOverride::Config>(
|
||||
std::move(r.substituteOverrideConfig).value());
|
||||
}
|
||||
|
||||
// D3D12 registers like `t3` and `c3` have the same bindingOffset number in
|
||||
// the remapping but should not be considered a collision because they have
|
||||
// different types.
|
||||
|
@ -450,6 +365,13 @@ ResultOrError<std::string> TranslateToHLSL(
|
|||
return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
|
||||
}
|
||||
|
||||
if (r.stage == SingleShaderStage::Compute) {
|
||||
// Validate workgroup size after program runs transforms.
|
||||
Extent3D _;
|
||||
DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize(
|
||||
transformedProgram, remappedEntryPointName->data(), r.limits));
|
||||
}
|
||||
|
||||
if (r.stage == SingleShaderStage::Vertex) {
|
||||
if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
|
||||
*usesVertexOrInstanceIndex = data->has_vertex_or_instance_index;
|
||||
|
@ -555,8 +477,7 @@ ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& pro
|
|||
|
||||
req.bytecode.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
|
||||
req.bytecode.compileFlags = compileFlags;
|
||||
req.bytecode.defineStrings =
|
||||
GetOverridableConstantsDefines(programmableStage.constants, entryPoint.overrides);
|
||||
|
||||
if (device->IsToggleEnabled(Toggle::UseDXC)) {
|
||||
req.bytecode.compiler = Compiler::DXC;
|
||||
req.bytecode.dxcLibrary = device->GetDxcLibrary().Get();
|
||||
|
@ -645,6 +566,11 @@ ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& pro
|
|||
}
|
||||
}
|
||||
|
||||
std::optional<tint::transform::SubstituteOverride::Config> substituteOverrideConfig;
|
||||
if (!programmableStage.metadata->overrides.empty()) {
|
||||
substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
|
||||
}
|
||||
|
||||
req.hlsl.inputProgram = GetTintProgram();
|
||||
req.hlsl.entryPointName = programmableStage.entryPoint.c_str();
|
||||
req.hlsl.stage = stage;
|
||||
|
@ -657,6 +583,10 @@ ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& pro
|
|||
req.hlsl.remappedAccessControls = std::move(remappedAccessControls);
|
||||
req.hlsl.newBindingsMap = BuildExternalTextureTransformBindings(layout);
|
||||
req.hlsl.arrayLengthFromUniform = std::move(arrayLengthFromUniform);
|
||||
req.hlsl.substituteOverrideConfig = std::move(substituteOverrideConfig);
|
||||
|
||||
const CombinedLimits& limits = device->GetLimits();
|
||||
req.hlsl.limits = LimitsForCompilationRequest::Create(limits.v1);
|
||||
|
||||
CacheResult<CompiledShader> compiledShader;
|
||||
DAWN_TRY_LOAD_OR_RUN(compiledShader, device, std::move(req), CompiledShader::FromBlob,
|
||||
|
|
|
@ -40,8 +40,9 @@ MaybeError ComputePipeline::Initialize() {
|
|||
const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
|
||||
ShaderModule::MetalFunctionData computeData;
|
||||
|
||||
DAWN_TRY(CreateMTLFunction(computeStage, SingleShaderStage::Compute, ToBackend(GetLayout()),
|
||||
&computeData));
|
||||
DAWN_TRY(ToBackend(computeStage.module.Get())
|
||||
->CreateFunction(SingleShaderStage::Compute, computeStage, ToBackend(GetLayout()),
|
||||
&computeData));
|
||||
|
||||
NSError* error = nullptr;
|
||||
mMtlComputePipelineState.Acquire(
|
||||
|
@ -53,8 +54,7 @@ MaybeError ComputePipeline::Initialize() {
|
|||
ASSERT(mMtlComputePipelineState != nil);
|
||||
|
||||
// Copy over the local workgroup size as it is passed to dispatch explicitly in Metal
|
||||
Origin3D localSize = GetStage(SingleShaderStage::Compute).metadata->localWorkgroupSize;
|
||||
mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z);
|
||||
mLocalWorkgroupSize = computeData.localWorkgroupSize;
|
||||
|
||||
mRequiresStorageBufferLength = computeData.needsStorageBufferLength;
|
||||
mWorkgroupAllocations = std::move(computeData.workgroupAllocations);
|
||||
|
|
|
@ -340,8 +340,9 @@ MaybeError RenderPipeline::Initialize() {
|
|||
const PerStage<ProgrammableStage>& allStages = GetAllStages();
|
||||
const ProgrammableStage& vertexStage = allStages[wgpu::ShaderStage::Vertex];
|
||||
ShaderModule::MetalFunctionData vertexData;
|
||||
DAWN_TRY(CreateMTLFunction(vertexStage, SingleShaderStage::Vertex, ToBackend(GetLayout()),
|
||||
&vertexData, 0xFFFFFFFF, this));
|
||||
DAWN_TRY(ToBackend(vertexStage.module.Get())
|
||||
->CreateFunction(SingleShaderStage::Vertex, vertexStage, ToBackend(GetLayout()),
|
||||
&vertexData, 0xFFFFFFFF, this));
|
||||
|
||||
descriptorMTL.vertexFunction = vertexData.function.Get();
|
||||
if (vertexData.needsStorageBufferLength) {
|
||||
|
@ -351,8 +352,9 @@ MaybeError RenderPipeline::Initialize() {
|
|||
if (GetStageMask() & wgpu::ShaderStage::Fragment) {
|
||||
const ProgrammableStage& fragmentStage = allStages[wgpu::ShaderStage::Fragment];
|
||||
ShaderModule::MetalFunctionData fragmentData;
|
||||
DAWN_TRY(CreateMTLFunction(fragmentStage, SingleShaderStage::Fragment,
|
||||
ToBackend(GetLayout()), &fragmentData, GetSampleMask()));
|
||||
DAWN_TRY(ToBackend(fragmentStage.module.Get())
|
||||
->CreateFunction(SingleShaderStage::Fragment, fragmentStage,
|
||||
ToBackend(GetLayout()), &fragmentData, GetSampleMask()));
|
||||
|
||||
descriptorMTL.fragmentFunction = fragmentData.function.Get();
|
||||
if (fragmentData.needsStorageBufferLength) {
|
||||
|
|
|
@ -25,6 +25,10 @@
|
|||
|
||||
#import <Metal/Metal.h>
|
||||
|
||||
namespace dawn::native {
|
||||
struct ProgrammableStage;
|
||||
}
|
||||
|
||||
namespace dawn::native::metal {
|
||||
|
||||
class Device;
|
||||
|
@ -42,15 +46,13 @@ class ShaderModule final : public ShaderModuleBase {
|
|||
NSPRef<id<MTLFunction>> function;
|
||||
bool needsStorageBufferLength;
|
||||
std::vector<uint32_t> workgroupAllocations;
|
||||
MTLSize localWorkgroupSize;
|
||||
};
|
||||
|
||||
// MTLFunctionConstantValues needs @available tag to compile
|
||||
// Use id (like void*) in function signature as workaround and do static cast inside
|
||||
MaybeError CreateFunction(const char* entryPointName,
|
||||
SingleShaderStage stage,
|
||||
MaybeError CreateFunction(SingleShaderStage stage,
|
||||
const ProgrammableStage& programmableStage,
|
||||
const PipelineLayout* layout,
|
||||
MetalFunctionData* out,
|
||||
id constantValues = nil,
|
||||
uint32_t sampleMask = 0xFFFFFFFF,
|
||||
const RenderPipeline* renderPipeline = nullptr);
|
||||
|
||||
|
|
|
@ -35,17 +35,20 @@ namespace {
|
|||
|
||||
using OptionalVertexPullingTransformConfig = std::optional<tint::transform::VertexPulling::Config>;
|
||||
|
||||
#define MSL_COMPILATION_REQUEST_MEMBERS(X) \
|
||||
X(const tint::Program*, inputProgram) \
|
||||
X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \
|
||||
X(tint::transform::MultiplanarExternalTexture::BindingsMap, externalTextureBindings) \
|
||||
X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig) \
|
||||
X(std::string, entryPointName) \
|
||||
X(uint32_t, sampleMask) \
|
||||
X(bool, emitVertexPointSize) \
|
||||
X(bool, isRobustnessEnabled) \
|
||||
X(bool, disableSymbolRenaming) \
|
||||
X(bool, disableWorkgroupInit) \
|
||||
#define MSL_COMPILATION_REQUEST_MEMBERS(X) \
|
||||
X(SingleShaderStage, stage) \
|
||||
X(const tint::Program*, inputProgram) \
|
||||
X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \
|
||||
X(tint::transform::MultiplanarExternalTexture::BindingsMap, externalTextureBindings) \
|
||||
X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig) \
|
||||
X(std::optional<tint::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
|
||||
X(LimitsForCompilationRequest, limits) \
|
||||
X(std::string, entryPointName) \
|
||||
X(uint32_t, sampleMask) \
|
||||
X(bool, emitVertexPointSize) \
|
||||
X(bool, isRobustnessEnabled) \
|
||||
X(bool, disableSymbolRenaming) \
|
||||
X(bool, disableWorkgroupInit) \
|
||||
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform)
|
||||
|
||||
DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS);
|
||||
|
@ -53,12 +56,13 @@ DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS);
|
|||
|
||||
using WorkgroupAllocations = std::vector<uint32_t>;
|
||||
|
||||
#define MSL_COMPILATION_MEMBERS(X) \
|
||||
X(std::string, msl) \
|
||||
X(std::string, remappedEntryPointName) \
|
||||
X(bool, needsStorageBufferLength) \
|
||||
X(bool, hasInvariantAttribute) \
|
||||
X(WorkgroupAllocations, workgroupAllocations)
|
||||
#define MSL_COMPILATION_MEMBERS(X) \
|
||||
X(std::string, msl) \
|
||||
X(std::string, remappedEntryPointName) \
|
||||
X(bool, needsStorageBufferLength) \
|
||||
X(bool, hasInvariantAttribute) \
|
||||
X(WorkgroupAllocations, workgroupAllocations) \
|
||||
X(Extent3D, localWorkgroupSize)
|
||||
|
||||
DAWN_SERIALIZABLE(struct, MslCompilation, MSL_COMPILATION_MEMBERS){};
|
||||
#undef MSL_COMPILATION_MEMBERS
|
||||
|
@ -92,13 +96,14 @@ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
|
|||
|
||||
namespace {
|
||||
|
||||
ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
||||
const tint::Program* inputProgram,
|
||||
const char* entryPointName,
|
||||
SingleShaderStage stage,
|
||||
const PipelineLayout* layout,
|
||||
uint32_t sampleMask,
|
||||
const RenderPipeline* renderPipeline) {
|
||||
ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(
|
||||
DeviceBase* device,
|
||||
const ProgrammableStage& programmableStage,
|
||||
SingleShaderStage stage,
|
||||
const PipelineLayout* layout,
|
||||
ShaderModule::MetalFunctionData* out,
|
||||
uint32_t sampleMask,
|
||||
const RenderPipeline* renderPipeline) {
|
||||
ScopedTintICEHandler scopedICEHandler(device);
|
||||
|
||||
std::ostringstream errorStream;
|
||||
|
@ -137,7 +142,7 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
|||
if (stage == SingleShaderStage::Vertex &&
|
||||
device->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
|
||||
vertexPullingTransformConfig = BuildVertexPullingTransformConfig(
|
||||
*renderPipeline, entryPointName, kPullingBufferBindingSet);
|
||||
*renderPipeline, programmableStage.entryPoint.c_str(), kPullingBufferBindingSet);
|
||||
|
||||
for (VertexBufferSlot slot : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
|
||||
uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot);
|
||||
|
@ -152,12 +157,19 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
|||
}
|
||||
}
|
||||
|
||||
std::optional<tint::transform::SubstituteOverride::Config> substituteOverrideConfig;
|
||||
if (!programmableStage.metadata->overrides.empty()) {
|
||||
substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
|
||||
}
|
||||
|
||||
MslCompilationRequest req = {};
|
||||
req.inputProgram = inputProgram;
|
||||
req.stage = stage;
|
||||
req.inputProgram = programmableStage.module->GetTintProgram();
|
||||
req.bindingPoints = std::move(bindingPoints);
|
||||
req.externalTextureBindings = std::move(externalTextureBindings);
|
||||
req.vertexPullingTransformConfig = std::move(vertexPullingTransformConfig);
|
||||
req.entryPointName = entryPointName;
|
||||
req.substituteOverrideConfig = std::move(substituteOverrideConfig);
|
||||
req.entryPointName = programmableStage.entryPoint.c_str();
|
||||
req.sampleMask = sampleMask;
|
||||
req.emitVertexPointSize =
|
||||
stage == SingleShaderStage::Vertex &&
|
||||
|
@ -166,6 +178,9 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
|||
req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
|
||||
req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform());
|
||||
|
||||
const CombinedLimits& limits = device->GetLimits();
|
||||
req.limits = LimitsForCompilationRequest::Create(limits.v1);
|
||||
|
||||
CacheResult<MslCompilation> mslCompilation;
|
||||
DAWN_TRY_LOAD_OR_RUN(
|
||||
mslCompilation, device, std::move(req), MslCompilation::FromBlob,
|
||||
|
@ -190,6 +205,14 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
|||
std::move(r.vertexPullingTransformConfig).value());
|
||||
}
|
||||
|
||||
if (r.substituteOverrideConfig) {
|
||||
// This needs to run after SingleEntryPoint transform to get rid of overrides not
|
||||
// used for the current entry point.
|
||||
transformManager.Add<tint::transform::SubstituteOverride>();
|
||||
transformInputs.Add<tint::transform::SubstituteOverride::Config>(
|
||||
std::move(r.substituteOverrideConfig).value());
|
||||
}
|
||||
|
||||
if (r.isRobustnessEnabled) {
|
||||
transformManager.Add<tint::transform::Robustness>();
|
||||
}
|
||||
|
@ -230,6 +253,13 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
|||
return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
|
||||
}
|
||||
|
||||
Extent3D localSize{0, 0, 0};
|
||||
if (r.stage == SingleShaderStage::Compute) {
|
||||
// Validate workgroup size after program runs transforms.
|
||||
DAWN_TRY_ASSIGN(localSize, ValidateComputeStageWorkgroupSize(
|
||||
program, remappedEntryPointName.data(), r.limits));
|
||||
}
|
||||
|
||||
tint::writer::msl::Options options;
|
||||
options.buffer_size_ubo_index = kBufferLengthBufferSlot;
|
||||
options.fixed_sample_mask = r.sampleMask;
|
||||
|
@ -258,6 +288,7 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
|||
result.needs_storage_buffer_sizes,
|
||||
result.has_invariant_attribute,
|
||||
std::move(workgroupAllocations),
|
||||
localSize,
|
||||
}};
|
||||
});
|
||||
|
||||
|
@ -272,11 +303,10 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
|
|||
|
||||
} // namespace
|
||||
|
||||
MaybeError ShaderModule::CreateFunction(const char* entryPointName,
|
||||
SingleShaderStage stage,
|
||||
MaybeError ShaderModule::CreateFunction(SingleShaderStage stage,
|
||||
const ProgrammableStage& programmableStage,
|
||||
const PipelineLayout* layout,
|
||||
ShaderModule::MetalFunctionData* out,
|
||||
id constantValuesPointer,
|
||||
uint32_t sampleMask,
|
||||
const RenderPipeline* renderPipeline) {
|
||||
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleMTL::CreateFunction");
|
||||
|
@ -284,16 +314,21 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName,
|
|||
ASSERT(!IsError());
|
||||
ASSERT(out);
|
||||
|
||||
const char* entryPointName = programmableStage.entryPoint.c_str();
|
||||
|
||||
// Vertex stages must specify a renderPipeline
|
||||
if (stage == SingleShaderStage::Vertex) {
|
||||
ASSERT(renderPipeline != nullptr);
|
||||
}
|
||||
|
||||
CacheResult<MslCompilation> mslCompilation;
|
||||
DAWN_TRY_ASSIGN(mslCompilation, TranslateToMSL(GetDevice(), GetTintProgram(), entryPointName,
|
||||
stage, layout, sampleMask, renderPipeline));
|
||||
DAWN_TRY_ASSIGN(mslCompilation, TranslateToMSL(GetDevice(), programmableStage, stage, layout,
|
||||
out, sampleMask, renderPipeline));
|
||||
out->needsStorageBufferLength = mslCompilation->needsStorageBufferLength;
|
||||
out->workgroupAllocations = std::move(mslCompilation->workgroupAllocations);
|
||||
out->localWorkgroupSize = MTLSizeMake(mslCompilation->localWorkgroupSize.width,
|
||||
mslCompilation->localWorkgroupSize.height,
|
||||
mslCompilation->localWorkgroupSize.depthOrArrayLayers);
|
||||
|
||||
NSRef<NSString> mslSource =
|
||||
AcquireNSRef([[NSString alloc] initWithUTF8String:mslCompilation->msl.c_str()]);
|
||||
|
@ -327,25 +362,7 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName,
|
|||
|
||||
{
|
||||
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "MTLLibrary::newFunctionWithName");
|
||||
if (constantValuesPointer != nil) {
|
||||
if (@available(macOS 10.12, *)) {
|
||||
MTLFunctionConstantValues* constantValues = constantValuesPointer;
|
||||
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()
|
||||
constantValues:constantValues
|
||||
error:&error]);
|
||||
if (error != nullptr) {
|
||||
if (error.code != MTLLibraryErrorCompileWarning) {
|
||||
return DAWN_VALIDATION_ERROR("Function compile error: %s",
|
||||
[error.localizedDescription UTF8String]);
|
||||
}
|
||||
}
|
||||
ASSERT(out->function != nil);
|
||||
} else {
|
||||
UNREACHABLE();
|
||||
}
|
||||
} else {
|
||||
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
|
||||
}
|
||||
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
|
||||
}
|
||||
|
||||
if (BlobCache* cache = GetDevice()->GetBlobCache()) {
|
||||
|
|
|
@ -68,15 +68,6 @@ void EnsureDestinationTextureInitialized(CommandRecordingContext* commandContext
|
|||
|
||||
MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect);
|
||||
|
||||
// Helper function to create function with constant values wrapped in
|
||||
// if available branch
|
||||
MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
|
||||
SingleShaderStage singleShaderStage,
|
||||
PipelineLayout* pipelineLayout,
|
||||
ShaderModule::MetalFunctionData* functionData,
|
||||
uint32_t sampleMask = 0xFFFFFFFF,
|
||||
const RenderPipeline* renderPipeline = nullptr);
|
||||
|
||||
// Allow use MTLStoreActionStoreAndMultismapleResolve because the logic in the backend is
|
||||
// first to compute what the "best" Metal render pass descriptor is, then fix it up if we
|
||||
// are not on macOS 10.12 (i.e. the EmulateStoreAndMSAAResolve toggle is on).
|
||||
|
|
|
@ -323,104 +323,6 @@ MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect) {
|
|||
return MTLBlitOptionNone;
|
||||
}
|
||||
|
||||
MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
|
||||
SingleShaderStage singleShaderStage,
|
||||
PipelineLayout* pipelineLayout,
|
||||
ShaderModule::MetalFunctionData* functionData,
|
||||
uint32_t sampleMask,
|
||||
const RenderPipeline* renderPipeline) {
|
||||
ShaderModule* shaderModule = ToBackend(programmableStage.module.Get());
|
||||
const char* shaderEntryPoint = programmableStage.entryPoint.c_str();
|
||||
const auto& entryPointMetadata = programmableStage.module->GetEntryPoint(shaderEntryPoint);
|
||||
if (entryPointMetadata.overrides.size() == 0) {
|
||||
DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout,
|
||||
functionData, nil, sampleMask, renderPipeline));
|
||||
return {};
|
||||
}
|
||||
|
||||
if (@available(macOS 10.12, *)) {
|
||||
// MTLFunctionConstantValues can only be created within the if available branch
|
||||
NSRef<MTLFunctionConstantValues> constantValues =
|
||||
AcquireNSRef([MTLFunctionConstantValues new]);
|
||||
|
||||
std::unordered_set<std::string> overriddenConstants;
|
||||
|
||||
auto switchType = [&](EntryPointMetadata::Override::Type dawnType,
|
||||
MTLDataType* type, OverrideScalar* entry,
|
||||
double value = 0) {
|
||||
switch (dawnType) {
|
||||
case EntryPointMetadata::Override::Type::Boolean:
|
||||
*type = MTLDataTypeBool;
|
||||
if (entry) {
|
||||
entry->b = static_cast<int32_t>(value);
|
||||
}
|
||||
break;
|
||||
case EntryPointMetadata::Override::Type::Float32:
|
||||
*type = MTLDataTypeFloat;
|
||||
if (entry) {
|
||||
entry->f32 = static_cast<float>(value);
|
||||
}
|
||||
break;
|
||||
case EntryPointMetadata::Override::Type::Int32:
|
||||
*type = MTLDataTypeInt;
|
||||
if (entry) {
|
||||
entry->i32 = static_cast<int32_t>(value);
|
||||
}
|
||||
break;
|
||||
case EntryPointMetadata::Override::Type::Uint32:
|
||||
*type = MTLDataTypeUInt;
|
||||
if (entry) {
|
||||
entry->u32 = static_cast<uint32_t>(value);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
UNREACHABLE();
|
||||
}
|
||||
};
|
||||
|
||||
for (const auto& [name, value] : programmableStage.constants) {
|
||||
overriddenConstants.insert(name);
|
||||
|
||||
// This is already validated so `name` must exist
|
||||
const auto& moduleConstant = entryPointMetadata.overrides.at(name);
|
||||
|
||||
MTLDataType type;
|
||||
OverrideScalar entry{};
|
||||
|
||||
switchType(moduleConstant.type, &type, &entry, value);
|
||||
|
||||
[constantValues.Get() setConstantValue:&entry type:type atIndex:moduleConstant.id];
|
||||
}
|
||||
|
||||
// Set shader initialized default values because MSL function_constant
|
||||
// has no default value
|
||||
for (const std::string& name : entryPointMetadata.initializedOverrides) {
|
||||
if (overriddenConstants.count(name) != 0) {
|
||||
// This constant already has overridden value
|
||||
continue;
|
||||
}
|
||||
|
||||
// Must exist because it is validated
|
||||
const auto& moduleConstant = entryPointMetadata.overrides.at(name);
|
||||
ASSERT(moduleConstant.isInitialized);
|
||||
MTLDataType type;
|
||||
|
||||
switchType(moduleConstant.type, &type, nullptr);
|
||||
|
||||
[constantValues.Get() setConstantValue:&moduleConstant.defaultValue
|
||||
type:type
|
||||
atIndex:moduleConstant.id];
|
||||
}
|
||||
|
||||
DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout,
|
||||
functionData, constantValues.Get(), sampleMask,
|
||||
renderPipeline));
|
||||
} else {
|
||||
UNREACHABLE();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
MaybeError EncodeMetalRenderPass(Device* device,
|
||||
CommandRecordingContext* commandContext,
|
||||
MTLRenderPassDescriptor* mtlRenderPass,
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include <string>
|
||||
|
||||
#include "dawn/native/Limits.h"
|
||||
|
||||
namespace dawn::native::stream {
|
||||
|
||||
template <>
|
||||
|
|
|
@ -61,16 +61,12 @@ MaybeError ComputePipeline::Initialize() {
|
|||
|
||||
ShaderModule::ModuleAndSpirv moduleAndSpirv;
|
||||
DAWN_TRY_ASSIGN(moduleAndSpirv,
|
||||
module->GetHandleAndSpirv(computeStage.entryPoint.c_str(), layout));
|
||||
module->GetHandleAndSpirv(SingleShaderStage::Compute, computeStage, layout));
|
||||
|
||||
createInfo.stage.module = moduleAndSpirv.module;
|
||||
createInfo.stage.pName = computeStage.entryPoint.c_str();
|
||||
|
||||
std::vector<OverrideScalar> specializationDataEntries;
|
||||
std::vector<VkSpecializationMapEntry> specializationMapEntries;
|
||||
VkSpecializationInfo specializationInfo{};
|
||||
createInfo.stage.pSpecializationInfo = GetVkSpecializationInfo(
|
||||
computeStage, &specializationInfo, &specializationDataEntries, &specializationMapEntries);
|
||||
createInfo.stage.pSpecializationInfo = nullptr;
|
||||
|
||||
PNextChainBuilder stageExtChain(&createInfo.stage);
|
||||
|
||||
|
|
|
@ -341,9 +341,6 @@ MaybeError RenderPipeline::Initialize() {
|
|||
|
||||
// There are at most 2 shader stages in render pipeline, i.e. vertex and fragment
|
||||
std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
|
||||
std::array<std::vector<OverrideScalar>, 2> specializationDataEntriesPerStages;
|
||||
std::array<std::vector<VkSpecializationMapEntry>, 2> specializationMapEntriesPerStages;
|
||||
std::array<VkSpecializationInfo, 2> specializationInfoPerStages;
|
||||
uint32_t stageCount = 0;
|
||||
|
||||
for (auto stage : IterateStages(this->GetStageMask())) {
|
||||
|
@ -354,7 +351,7 @@ MaybeError RenderPipeline::Initialize() {
|
|||
|
||||
ShaderModule::ModuleAndSpirv moduleAndSpirv;
|
||||
DAWN_TRY_ASSIGN(moduleAndSpirv,
|
||||
module->GetHandleAndSpirv(programmableStage.entryPoint.c_str(), layout));
|
||||
module->GetHandleAndSpirv(stage, programmableStage, layout));
|
||||
|
||||
shaderStage.module = moduleAndSpirv.module;
|
||||
shaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
|
||||
|
@ -379,11 +376,6 @@ MaybeError RenderPipeline::Initialize() {
|
|||
}
|
||||
}
|
||||
|
||||
shaderStage.pSpecializationInfo =
|
||||
GetVkSpecializationInfo(programmableStage, &specializationInfoPerStages[stageCount],
|
||||
&specializationDataEntriesPerStages[stageCount],
|
||||
&specializationMapEntriesPerStages[stageCount]);
|
||||
|
||||
DAWN_ASSERT(stageCount < 2);
|
||||
shaderStages[stageCount] = shaderStage;
|
||||
stageCount++;
|
||||
|
|
|
@ -60,13 +60,35 @@ class ShaderModule::Spirv : private Blob {
|
|||
|
||||
namespace dawn::native::vulkan {
|
||||
|
||||
bool TransformedShaderModuleCacheKey::operator==(
|
||||
const TransformedShaderModuleCacheKey& other) const {
|
||||
if (layout != other.layout || entryPoint != other.entryPoint ||
|
||||
constants.size() != other.constants.size()) {
|
||||
return false;
|
||||
}
|
||||
if (!std::equal(constants.begin(), constants.end(), other.constants.begin())) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t TransformedShaderModuleCacheKeyHashFunc::operator()(
|
||||
const TransformedShaderModuleCacheKey& key) const {
|
||||
size_t hash = 0;
|
||||
HashCombine(&hash, key.layout, key.entryPoint);
|
||||
for (const auto& entry : key.constants) {
|
||||
HashCombine(&hash, entry.first, entry.second);
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
|
||||
class ShaderModule::ConcurrentTransformedShaderModuleCache {
|
||||
public:
|
||||
explicit ConcurrentTransformedShaderModuleCache(Device* device);
|
||||
~ConcurrentTransformedShaderModuleCache();
|
||||
|
||||
std::optional<ModuleAndSpirv> Find(const PipelineLayoutEntryPointPair& key);
|
||||
ModuleAndSpirv AddOrGet(const PipelineLayoutEntryPointPair& key,
|
||||
std::optional<ModuleAndSpirv> Find(const TransformedShaderModuleCacheKey& key);
|
||||
ModuleAndSpirv AddOrGet(const TransformedShaderModuleCacheKey& key,
|
||||
VkShaderModule module,
|
||||
Spirv&& spirv);
|
||||
|
||||
|
@ -75,7 +97,9 @@ class ShaderModule::ConcurrentTransformedShaderModuleCache {
|
|||
|
||||
Device* mDevice;
|
||||
std::mutex mMutex;
|
||||
std::unordered_map<PipelineLayoutEntryPointPair, Entry, PipelineLayoutEntryPointPairHashFunc>
|
||||
std::unordered_map<TransformedShaderModuleCacheKey,
|
||||
Entry,
|
||||
TransformedShaderModuleCacheKeyHashFunc>
|
||||
mTransformedShaderModuleCache;
|
||||
};
|
||||
|
||||
|
@ -92,7 +116,7 @@ ShaderModule::ConcurrentTransformedShaderModuleCache::~ConcurrentTransformedShad
|
|||
|
||||
std::optional<ShaderModule::ModuleAndSpirv>
|
||||
ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
|
||||
const PipelineLayoutEntryPointPair& key) {
|
||||
const TransformedShaderModuleCacheKey& key) {
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
auto iter = mTransformedShaderModuleCache.find(key);
|
||||
if (iter != mTransformedShaderModuleCache.end()) {
|
||||
|
@ -106,7 +130,7 @@ ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
|
|||
}
|
||||
|
||||
ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet(
|
||||
const PipelineLayoutEntryPointPair& key,
|
||||
const TransformedShaderModuleCacheKey& key,
|
||||
VkShaderModule module,
|
||||
Spirv&& spirv) {
|
||||
ASSERT(module != VK_NULL_HANDLE);
|
||||
|
@ -168,20 +192,24 @@ void ShaderModule::DestroyImpl() {
|
|||
|
||||
ShaderModule::~ShaderModule() = default;
|
||||
|
||||
#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \
|
||||
X(const tint::Program*, inputProgram) \
|
||||
X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \
|
||||
X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
|
||||
X(std::string_view, entryPointName) \
|
||||
X(bool, disableWorkgroupInit) \
|
||||
X(bool, useZeroInitializeWorkgroupMemoryExtension) \
|
||||
#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \
|
||||
X(SingleShaderStage, stage) \
|
||||
X(const tint::Program*, inputProgram) \
|
||||
X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \
|
||||
X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
|
||||
X(std::optional<tint::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
|
||||
X(LimitsForCompilationRequest, limits) \
|
||||
X(std::string_view, entryPointName) \
|
||||
X(bool, disableWorkgroupInit) \
|
||||
X(bool, useZeroInitializeWorkgroupMemoryExtension) \
|
||||
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform)
|
||||
|
||||
DAWN_MAKE_CACHE_REQUEST(SpirvCompilationRequest, SPIRV_COMPILATION_REQUEST_MEMBERS);
|
||||
#undef SPIRV_COMPILATION_REQUEST_MEMBERS
|
||||
|
||||
ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
|
||||
const char* entryPointName,
|
||||
SingleShaderStage stage,
|
||||
const ProgrammableStage& programmableStage,
|
||||
const PipelineLayout* layout) {
|
||||
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleVk::GetHandleAndSpirv");
|
||||
|
||||
|
@ -191,7 +219,8 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
|
|||
ScopedTintICEHandler scopedICEHandler(GetDevice());
|
||||
|
||||
// Check to see if we have the handle and spirv cached already.
|
||||
auto cacheKey = std::make_pair(layout, entryPointName);
|
||||
auto cacheKey = TransformedShaderModuleCacheKey{layout, programmableStage.entryPoint.c_str(),
|
||||
programmableStage.constants};
|
||||
auto handleAndSpirv = mTransformedShaderModuleCache->Find(cacheKey);
|
||||
if (handleAndSpirv.has_value()) {
|
||||
return std::move(*handleAndSpirv);
|
||||
|
@ -204,7 +233,8 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
|
|||
using BindingPoint = tint::transform::BindingPoint;
|
||||
BindingRemapper::BindingPoints bindingPoints;
|
||||
|
||||
const BindingInfoArray& moduleBindingInfo = GetEntryPoint(entryPointName).bindings;
|
||||
const BindingInfoArray& moduleBindingInfo =
|
||||
GetEntryPoint(programmableStage.entryPoint.c_str()).bindings;
|
||||
|
||||
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
|
||||
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
|
||||
|
@ -238,16 +268,26 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
|
|||
}
|
||||
}
|
||||
|
||||
std::optional<tint::transform::SubstituteOverride::Config> substituteOverrideConfig;
|
||||
if (!programmableStage.metadata->overrides.empty()) {
|
||||
substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
|
||||
}
|
||||
|
||||
#if TINT_BUILD_SPV_WRITER
|
||||
SpirvCompilationRequest req = {};
|
||||
req.stage = stage;
|
||||
req.inputProgram = GetTintProgram();
|
||||
req.bindingPoints = std::move(bindingPoints);
|
||||
req.newBindingsMap = std::move(newBindingsMap);
|
||||
req.entryPointName = entryPointName;
|
||||
req.entryPointName = programmableStage.entryPoint;
|
||||
req.disableWorkgroupInit = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
|
||||
req.useZeroInitializeWorkgroupMemoryExtension =
|
||||
GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension);
|
||||
req.tracePlatform = UnsafeUnkeyedValue(GetDevice()->GetPlatform());
|
||||
req.substituteOverrideConfig = std::move(substituteOverrideConfig);
|
||||
|
||||
const CombinedLimits& limits = GetDevice()->GetLimits();
|
||||
req.limits = LimitsForCompilationRequest::Create(limits.v1);
|
||||
|
||||
CacheResult<Spirv> spirv;
|
||||
DAWN_TRY_LOAD_OR_RUN(
|
||||
|
@ -270,12 +310,27 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
|
|||
transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
|
||||
r.newBindingsMap);
|
||||
}
|
||||
if (r.substituteOverrideConfig) {
|
||||
// This needs to run after SingleEntryPoint transform to get rid of overrides not
|
||||
// used for the current entry point.
|
||||
transformManager.Add<tint::transform::SubstituteOverride>();
|
||||
transformInputs.Add<tint::transform::SubstituteOverride::Config>(
|
||||
std::move(r.substituteOverrideConfig).value());
|
||||
}
|
||||
tint::Program program;
|
||||
{
|
||||
TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "RunTransforms");
|
||||
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, r.inputProgram,
|
||||
transformInputs, nullptr, nullptr));
|
||||
}
|
||||
|
||||
if (r.stage == SingleShaderStage::Compute) {
|
||||
// Validate workgroup size after program runs transforms.
|
||||
Extent3D _;
|
||||
DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize(
|
||||
program, r.entryPointName.data(), r.limits));
|
||||
}
|
||||
|
||||
tint::writer::spirv::Options options;
|
||||
options.emit_vertex_point_size = true;
|
||||
options.disable_workgroup_init = r.disableWorkgroupInit;
|
||||
|
|
|
@ -18,14 +18,32 @@
|
|||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "dawn/common/HashUtils.h"
|
||||
#include "dawn/common/vulkan_platform.h"
|
||||
#include "dawn/native/Error.h"
|
||||
#include "dawn/native/ShaderModule.h"
|
||||
|
||||
namespace dawn::native::vulkan {
|
||||
namespace dawn::native {
|
||||
|
||||
struct ProgrammableStage;
|
||||
|
||||
namespace vulkan {
|
||||
|
||||
struct TransformedShaderModuleCacheKey {
|
||||
const PipelineLayoutBase* layout;
|
||||
std::string entryPoint;
|
||||
PipelineConstantEntries constants;
|
||||
|
||||
bool operator==(const TransformedShaderModuleCacheKey& other) const;
|
||||
};
|
||||
|
||||
struct TransformedShaderModuleCacheKeyHashFunc {
|
||||
size_t operator()(const TransformedShaderModuleCacheKey& key) const;
|
||||
};
|
||||
|
||||
class Device;
|
||||
class PipelineLayout;
|
||||
|
@ -44,7 +62,8 @@ class ShaderModule final : public ShaderModuleBase {
|
|||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages);
|
||||
|
||||
ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(const char* entryPointName,
|
||||
ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(SingleShaderStage stage,
|
||||
const ProgrammableStage& programmableStage,
|
||||
const PipelineLayout* layout);
|
||||
|
||||
private:
|
||||
|
@ -59,6 +78,8 @@ class ShaderModule final : public ShaderModuleBase {
|
|||
std::unique_ptr<ConcurrentTransformedShaderModuleCache> mTransformedShaderModuleCache;
|
||||
};
|
||||
|
||||
} // namespace dawn::native::vulkan
|
||||
} // namespace vulkan
|
||||
|
||||
} // namespace dawn::native
|
||||
|
||||
#endif // SRC_DAWN_NATIVE_VULKAN_SHADERMODULEVK_H_
|
||||
|
|
|
@ -258,60 +258,4 @@ std::string GetDeviceDebugPrefixFromDebugName(const char* debugName) {
|
|||
return std::string(debugName, length);
|
||||
}
|
||||
|
||||
VkSpecializationInfo* GetVkSpecializationInfo(
|
||||
const ProgrammableStage& programmableStage,
|
||||
VkSpecializationInfo* specializationInfo,
|
||||
std::vector<OverrideScalar>* specializationDataEntries,
|
||||
std::vector<VkSpecializationMapEntry>* specializationMapEntries) {
|
||||
ASSERT(specializationInfo);
|
||||
ASSERT(specializationDataEntries);
|
||||
ASSERT(specializationMapEntries);
|
||||
|
||||
if (programmableStage.constants.size() == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const EntryPointMetadata& entryPointMetaData =
|
||||
programmableStage.module->GetEntryPoint(programmableStage.entryPoint);
|
||||
|
||||
for (const auto& pipelineConstant : programmableStage.constants) {
|
||||
const std::string& identifier = pipelineConstant.first;
|
||||
double value = pipelineConstant.second;
|
||||
|
||||
// This is already validated so `identifier` must exist
|
||||
const auto& moduleConstant = entryPointMetaData.overrides.at(identifier);
|
||||
|
||||
specializationMapEntries->push_back(VkSpecializationMapEntry{
|
||||
moduleConstant.id,
|
||||
static_cast<uint32_t>(specializationDataEntries->size() * sizeof(OverrideScalar)),
|
||||
sizeof(OverrideScalar)});
|
||||
|
||||
OverrideScalar entry{};
|
||||
switch (moduleConstant.type) {
|
||||
case EntryPointMetadata::Override::Type::Boolean:
|
||||
entry.b = static_cast<int32_t>(value);
|
||||
break;
|
||||
case EntryPointMetadata::Override::Type::Float32:
|
||||
entry.f32 = static_cast<float>(value);
|
||||
break;
|
||||
case EntryPointMetadata::Override::Type::Int32:
|
||||
entry.i32 = static_cast<int32_t>(value);
|
||||
break;
|
||||
case EntryPointMetadata::Override::Type::Uint32:
|
||||
entry.u32 = static_cast<uint32_t>(value);
|
||||
break;
|
||||
default:
|
||||
UNREACHABLE();
|
||||
}
|
||||
specializationDataEntries->push_back(entry);
|
||||
}
|
||||
|
||||
specializationInfo->mapEntryCount = static_cast<uint32_t>(specializationMapEntries->size());
|
||||
specializationInfo->pMapEntries = specializationMapEntries->data();
|
||||
specializationInfo->dataSize = specializationDataEntries->size() * sizeof(OverrideScalar);
|
||||
specializationInfo->pData = specializationDataEntries->data();
|
||||
|
||||
return specializationInfo;
|
||||
}
|
||||
|
||||
} // namespace dawn::native::vulkan
|
||||
|
|
|
@ -144,15 +144,6 @@ void SetDebugName(Device* device,
|
|||
std::string GetNextDeviceDebugPrefix();
|
||||
std::string GetDeviceDebugPrefixFromDebugName(const char* debugName);
|
||||
|
||||
// Returns nullptr or &specializationInfo
|
||||
// specializationInfo, specializationDataEntries, specializationMapEntries needs to
|
||||
// be alive at least until VkSpecializationInfo is passed into Vulkan Create*Pipelines
|
||||
VkSpecializationInfo* GetVkSpecializationInfo(
|
||||
const ProgrammableStage& programmableStage,
|
||||
VkSpecializationInfo* specializationInfo,
|
||||
std::vector<OverrideScalar>* specializationDataEntries,
|
||||
std::vector<VkSpecializationMapEntry>* specializationMapEntries);
|
||||
|
||||
} // namespace dawn::native::vulkan
|
||||
|
||||
#endif // SRC_DAWN_NATIVE_VULKAN_UTILSVULKAN_H_
|
||||
|
|
|
@ -265,6 +265,7 @@ dawn_test("dawn_unittests") {
|
|||
"unittests/native/CreatePipelineAsyncTaskTests.cpp",
|
||||
"unittests/native/DestroyObjectTests.cpp",
|
||||
"unittests/native/DeviceCreationTests.cpp",
|
||||
"unittests/native/ObjectContentHasherTests.cpp",
|
||||
"unittests/native/StreamTests.cpp",
|
||||
"unittests/validation/BindGroupValidationTests.cpp",
|
||||
"unittests/validation/BufferValidationTests.cpp",
|
||||
|
@ -489,6 +490,7 @@ source_set("end2end_tests_sources") {
|
|||
"end2end/ScissorTests.cpp",
|
||||
"end2end/ShaderFloat16Tests.cpp",
|
||||
"end2end/ShaderTests.cpp",
|
||||
"end2end/ShaderValidationTests.cpp",
|
||||
"end2end/StorageTextureTests.cpp",
|
||||
"end2end/SubresourceRenderAttachmentTests.cpp",
|
||||
"end2end/Texture3DTests.cpp",
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "dawn/tests/DawnTest.h"
|
||||
|
||||
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
|
||||
|
@ -158,6 +160,46 @@ TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnShaderModule) {
|
|||
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
|
||||
}
|
||||
|
||||
// Test that ComputePipeline are correctly deduplicated wrt. their constants override values
|
||||
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnOverrides) {
|
||||
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
|
||||
override x: u32 = 1u;
|
||||
var<workgroup> i : u32;
|
||||
@compute @workgroup_size(x) fn main() {
|
||||
i = 0u;
|
||||
})");
|
||||
|
||||
wgpu::PipelineLayout layout = utils::MakeBasicPipelineLayout(device, nullptr);
|
||||
|
||||
wgpu::ComputePipelineDescriptor desc;
|
||||
desc.compute.entryPoint = "main";
|
||||
desc.layout = layout;
|
||||
desc.compute.module = module;
|
||||
|
||||
std::vector<wgpu::ConstantEntry> constants{{nullptr, "x", 16}};
|
||||
desc.compute.constantCount = constants.size();
|
||||
desc.compute.constants = constants.data();
|
||||
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&desc);
|
||||
|
||||
std::vector<wgpu::ConstantEntry> sameConstants{{nullptr, "x", 16}};
|
||||
desc.compute.constantCount = sameConstants.size();
|
||||
desc.compute.constants = sameConstants.data();
|
||||
wgpu::ComputePipeline samePipeline = device.CreateComputePipeline(&desc);
|
||||
|
||||
desc.compute.constantCount = 0;
|
||||
desc.compute.constants = nullptr;
|
||||
wgpu::ComputePipeline otherPipeline1 = device.CreateComputePipeline(&desc);
|
||||
|
||||
std::vector<wgpu::ConstantEntry> otherConstants{{nullptr, "x", 4}};
|
||||
desc.compute.constantCount = otherConstants.size();
|
||||
desc.compute.constants = otherConstants.data();
|
||||
wgpu::ComputePipeline otherPipeline2 = device.CreateComputePipeline(&desc);
|
||||
|
||||
EXPECT_NE(pipeline.Get(), otherPipeline1.Get());
|
||||
EXPECT_NE(pipeline.Get(), otherPipeline2.Get());
|
||||
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
|
||||
}
|
||||
|
||||
// Test that ComputePipeline are correctly deduplicated wrt. their layout
|
||||
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) {
|
||||
wgpu::BindGroupLayout bgl = utils::MakeBindGroupLayout(
|
||||
|
@ -303,6 +345,48 @@ TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnFragmentModule) {
|
|||
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
|
||||
}
|
||||
|
||||
// Test that Renderpipelines are correctly deduplicated wrt. their constants override values
|
||||
TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnOverrides) {
|
||||
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
|
||||
override a: f32 = 1.0;
|
||||
@vertex fn vertexMain() -> @builtin(position) vec4<f32> {
|
||||
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||
}
|
||||
@fragment fn fragmentMain() -> @location(0) vec4<f32> {
|
||||
return vec4<f32>(0.0, 0.0, 0.0, a);
|
||||
})");
|
||||
|
||||
utils::ComboRenderPipelineDescriptor desc;
|
||||
desc.vertex.module = module;
|
||||
desc.vertex.entryPoint = "vertexMain";
|
||||
desc.cFragment.module = module;
|
||||
desc.cFragment.entryPoint = "fragmentMain";
|
||||
desc.cTargets[0].writeMask = wgpu::ColorWriteMask::None;
|
||||
|
||||
std::vector<wgpu::ConstantEntry> constants{{nullptr, "a", 0.5}};
|
||||
desc.cFragment.constantCount = constants.size();
|
||||
desc.cFragment.constants = constants.data();
|
||||
wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&desc);
|
||||
|
||||
std::vector<wgpu::ConstantEntry> sameConstants{{nullptr, "a", 0.5}};
|
||||
desc.cFragment.constantCount = sameConstants.size();
|
||||
desc.cFragment.constants = sameConstants.data();
|
||||
wgpu::RenderPipeline samePipeline = device.CreateRenderPipeline(&desc);
|
||||
|
||||
std::vector<wgpu::ConstantEntry> otherConstants{{nullptr, "a", 1.0}};
|
||||
desc.cFragment.constantCount = otherConstants.size();
|
||||
desc.cFragment.constants = otherConstants.data();
|
||||
wgpu::RenderPipeline otherPipeline1 = device.CreateRenderPipeline(&desc);
|
||||
|
||||
desc.cFragment.constantCount = 0;
|
||||
desc.cFragment.constants = nullptr;
|
||||
wgpu::RenderPipeline otherPipeline2 = device.CreateRenderPipeline(&desc);
|
||||
|
||||
EXPECT_NE(pipeline.Get(), otherPipeline1.Get());
|
||||
EXPECT_NE(pipeline.Get(), otherPipeline2.Get());
|
||||
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
|
||||
}
|
||||
|
||||
// Test that Samplers are correctly deduplicated.
|
||||
TEST_P(ObjectCachingTest, SamplerDeduplication) {
|
||||
wgpu::SamplerDescriptor samplerDesc;
|
||||
|
|
|
@ -471,6 +471,127 @@ struct Buf {
|
|||
EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), buffer, 0, kCount);
|
||||
}
|
||||
|
||||
// Test one shader shared by two pipelines with different constants overridden
|
||||
TEST_P(ShaderTests, OverridableConstantsSharedShader) {
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
|
||||
|
||||
std::vector<uint32_t> expected1{1};
|
||||
wgpu::Buffer buffer1 = CreateBuffer(expected1.size());
|
||||
std::vector<uint32_t> expected2{2};
|
||||
wgpu::Buffer buffer2 = CreateBuffer(expected2.size());
|
||||
|
||||
std::string shader = R"(
|
||||
override a: u32;
|
||||
|
||||
struct Buf {
|
||||
data : array<u32, 1>
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> buf : Buf;
|
||||
|
||||
@compute @workgroup_size(1) fn main() {
|
||||
buf.data[0] = a;
|
||||
})";
|
||||
|
||||
std::vector<wgpu::ConstantEntry> constants1;
|
||||
constants1.push_back({nullptr, "a", 1});
|
||||
std::vector<wgpu::ConstantEntry> constants2;
|
||||
constants2.push_back({nullptr, "a", 2});
|
||||
|
||||
wgpu::ComputePipeline pipeline1 = CreateComputePipeline(shader, "main", &constants1);
|
||||
wgpu::ComputePipeline pipeline2 = CreateComputePipeline(shader, "main", &constants2);
|
||||
|
||||
wgpu::BindGroup bindGroup1 =
|
||||
utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}});
|
||||
wgpu::BindGroup bindGroup2 =
|
||||
utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}});
|
||||
|
||||
wgpu::CommandBuffer commands;
|
||||
{
|
||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||
pass.SetPipeline(pipeline1);
|
||||
pass.SetBindGroup(0, bindGroup1);
|
||||
pass.DispatchWorkgroups(1);
|
||||
pass.SetPipeline(pipeline2);
|
||||
pass.SetBindGroup(0, bindGroup2);
|
||||
pass.DispatchWorkgroups(1);
|
||||
pass.End();
|
||||
|
||||
commands = encoder.Finish();
|
||||
}
|
||||
|
||||
queue.Submit(1, &commands);
|
||||
|
||||
EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, expected1.size());
|
||||
EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, expected2.size());
|
||||
}
|
||||
|
||||
// Test overridable constants work with workgroup size
|
||||
TEST_P(ShaderTests, OverridableConstantsWorkgroupSize) {
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
|
||||
|
||||
std::string shader = R"(
|
||||
override x: u32;
|
||||
|
||||
struct Buf {
|
||||
data : array<u32, 1>
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> buf : Buf;
|
||||
|
||||
@compute @workgroup_size(x) fn main(
|
||||
@builtin(local_invocation_id) local_invocation_id : vec3<u32>
|
||||
) {
|
||||
if (local_invocation_id.x >= x - 1) {
|
||||
buf.data[0] = local_invocation_id.x + 1;
|
||||
}
|
||||
})";
|
||||
|
||||
const uint32_t workgroup_size_x_1 = 16u;
|
||||
const uint32_t workgroup_size_x_2 = 64u;
|
||||
|
||||
std::vector<uint32_t> expected1{workgroup_size_x_1};
|
||||
wgpu::Buffer buffer1 = CreateBuffer(expected1.size());
|
||||
std::vector<uint32_t> expected2{workgroup_size_x_2};
|
||||
wgpu::Buffer buffer2 = CreateBuffer(expected2.size());
|
||||
|
||||
std::vector<wgpu::ConstantEntry> constants1;
|
||||
constants1.push_back({nullptr, "x", static_cast<double>(workgroup_size_x_1)});
|
||||
std::vector<wgpu::ConstantEntry> constants2;
|
||||
constants2.push_back({nullptr, "x", static_cast<double>(workgroup_size_x_2)});
|
||||
|
||||
wgpu::ComputePipeline pipeline1 = CreateComputePipeline(shader, "main", &constants1);
|
||||
wgpu::ComputePipeline pipeline2 = CreateComputePipeline(shader, "main", &constants2);
|
||||
|
||||
wgpu::BindGroup bindGroup1 =
|
||||
utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}});
|
||||
wgpu::BindGroup bindGroup2 =
|
||||
utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}});
|
||||
|
||||
wgpu::CommandBuffer commands;
|
||||
{
|
||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||
pass.SetPipeline(pipeline1);
|
||||
pass.SetBindGroup(0, bindGroup1);
|
||||
pass.DispatchWorkgroups(1);
|
||||
pass.SetPipeline(pipeline2);
|
||||
pass.SetBindGroup(0, bindGroup2);
|
||||
pass.DispatchWorkgroups(1);
|
||||
pass.End();
|
||||
|
||||
commands = encoder.Finish();
|
||||
}
|
||||
|
||||
queue.Submit(1, &commands);
|
||||
|
||||
EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, expected1.size());
|
||||
EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, expected2.size());
|
||||
}
|
||||
|
||||
// Test overridable constants with numeric identifiers
|
||||
TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||
|
@ -596,6 +717,7 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
|||
std::string shader = R"(
|
||||
@id(1001) override c1: u32;
|
||||
@id(1002) override c2: u32;
|
||||
@id(1003) override c3: u32;
|
||||
|
||||
struct Buf {
|
||||
data : array<u32, 1>
|
||||
|
@ -611,7 +733,7 @@ struct Buf {
|
|||
buf.data[0] = c2;
|
||||
}
|
||||
|
||||
@compute @workgroup_size(1) fn main3() {
|
||||
@compute @workgroup_size(c3) fn main3() {
|
||||
buf.data[0] = 3u;
|
||||
}
|
||||
)";
|
||||
|
@ -620,6 +742,8 @@ struct Buf {
|
|||
constants1.push_back({nullptr, "1001", 1});
|
||||
std::vector<wgpu::ConstantEntry> constants2;
|
||||
constants2.push_back({nullptr, "1002", 2});
|
||||
std::vector<wgpu::ConstantEntry> constants3;
|
||||
constants3.push_back({nullptr, "1003", 1});
|
||||
|
||||
wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, shader.c_str());
|
||||
|
||||
|
@ -640,6 +764,8 @@ struct Buf {
|
|||
wgpu::ComputePipelineDescriptor csDesc3;
|
||||
csDesc3.compute.module = shaderModule;
|
||||
csDesc3.compute.entryPoint = "main3";
|
||||
csDesc3.compute.constants = constants3.data();
|
||||
csDesc3.compute.constantCount = constants3.size();
|
||||
wgpu::ComputePipeline pipeline3 = device.CreateComputePipeline(&csDesc3);
|
||||
|
||||
wgpu::BindGroup bindGroup1 =
|
||||
|
@ -765,8 +891,6 @@ TEST_P(ShaderTests, ConflictingBindingsDueToTransformOrder) {
|
|||
device.CreateRenderPipeline(&desc);
|
||||
}
|
||||
|
||||
// TODO(tint:1155): Test overridable constants used for workgroup size
|
||||
|
||||
DAWN_INSTANTIATE_TEST(ShaderTests,
|
||||
D3D12Backend(),
|
||||
MetalBackend(),
|
||||
|
|
|
@ -0,0 +1,395 @@
|
|||
// Copyright 2022 The Dawn Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "dawn/tests/DawnTest.h"
|
||||
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
|
||||
#include "dawn/utils/WGPUHelpers.h"
|
||||
|
||||
// The compute shader workgroup size is settled at compute pipeline creation time.
|
||||
// The validation code in dawn is in each backend (not including Null backend) thus this test needs
|
||||
// to be as part of a dawn_end2end_tests instead of the dawn_unittests
|
||||
// TODO(dawn:1504): Add support for GL backend.
|
||||
class WorkgroupSizeValidationTest : public DawnTest {
|
||||
public:
|
||||
wgpu::ShaderModule SetUpShaderWithValidDefaultValueConstants() {
|
||||
return utils::CreateShaderModule(device, R"(
|
||||
override x: u32 = 1u;
|
||||
override y: u32 = 1u;
|
||||
override z: u32 = 1u;
|
||||
|
||||
@compute @workgroup_size(x, y, z) fn main() {
|
||||
_ = 0u;
|
||||
})");
|
||||
}
|
||||
|
||||
wgpu::ShaderModule SetUpShaderWithZeroDefaultValueConstants() {
|
||||
return utils::CreateShaderModule(device, R"(
|
||||
override x: u32 = 0u;
|
||||
override y: u32 = 0u;
|
||||
override z: u32 = 0u;
|
||||
|
||||
@compute @workgroup_size(x, y, z) fn main() {
|
||||
_ = 0u;
|
||||
})");
|
||||
}
|
||||
|
||||
wgpu::ShaderModule SetUpShaderWithOutOfLimitsDefaultValueConstants() {
|
||||
return utils::CreateShaderModule(device, R"(
|
||||
override x: u32 = 1u;
|
||||
override y: u32 = 1u;
|
||||
override z: u32 = 9999u;
|
||||
|
||||
@compute @workgroup_size(x, y, z) fn main() {
|
||||
_ = 0u;
|
||||
})");
|
||||
}
|
||||
|
||||
wgpu::ShaderModule SetUpShaderWithUninitializedConstants() {
|
||||
return utils::CreateShaderModule(device, R"(
|
||||
override x: u32;
|
||||
override y: u32;
|
||||
override z: u32;
|
||||
|
||||
@compute @workgroup_size(x, y, z) fn main() {
|
||||
_ = 0u;
|
||||
})");
|
||||
}
|
||||
|
||||
wgpu::ShaderModule SetUpShaderWithPartialConstants() {
|
||||
return utils::CreateShaderModule(device, R"(
|
||||
override x: u32;
|
||||
|
||||
@compute @workgroup_size(x, 1, 1) fn main() {
|
||||
_ = 0u;
|
||||
})");
|
||||
}
|
||||
|
||||
void TestCreatePipeline(const wgpu::ShaderModule& module) {
|
||||
wgpu::ComputePipelineDescriptor csDesc;
|
||||
csDesc.compute.module = module;
|
||||
csDesc.compute.entryPoint = "main";
|
||||
csDesc.compute.constants = nullptr;
|
||||
csDesc.compute.constantCount = 0;
|
||||
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
|
||||
}
|
||||
|
||||
void TestCreatePipeline(const wgpu::ShaderModule& module,
|
||||
const std::vector<wgpu::ConstantEntry>& constants) {
|
||||
wgpu::ComputePipelineDescriptor csDesc;
|
||||
csDesc.compute.module = module;
|
||||
csDesc.compute.entryPoint = "main";
|
||||
csDesc.compute.constants = constants.data();
|
||||
csDesc.compute.constantCount = constants.size();
|
||||
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
|
||||
}
|
||||
|
||||
void TestInitializedWithZero(const wgpu::ShaderModule& module) {
|
||||
std::vector<wgpu::ConstantEntry> constants{
|
||||
{nullptr, "x", 0}, {nullptr, "y", 0}, {nullptr, "z", 0}};
|
||||
TestCreatePipeline(module, constants);
|
||||
}
|
||||
|
||||
void TestInitializedWithOutOfLimitValue(const wgpu::ShaderModule& module) {
|
||||
std::vector<wgpu::ConstantEntry> constants{
|
||||
{nullptr, "x",
|
||||
static_cast<double>(GetSupportedLimits().limits.maxComputeWorkgroupSizeX + 1)},
|
||||
{nullptr, "y", 1},
|
||||
{nullptr, "z", 1}};
|
||||
TestCreatePipeline(module, constants);
|
||||
}
|
||||
|
||||
void TestInitializedWithValidValue(const wgpu::ShaderModule& module) {
|
||||
std::vector<wgpu::ConstantEntry> constants{
|
||||
{nullptr, "x", 4}, {nullptr, "y", 4}, {nullptr, "z", 4}};
|
||||
TestCreatePipeline(module, constants);
|
||||
}
|
||||
|
||||
void TestInitializedPartially(const wgpu::ShaderModule& module) {
|
||||
std::vector<wgpu::ConstantEntry> constants{{nullptr, "y", 4}};
|
||||
TestCreatePipeline(module, constants);
|
||||
}
|
||||
|
||||
wgpu::Buffer buffer;
|
||||
};
|
||||
|
||||
// Test workgroup size validation with fixed values.
|
||||
TEST_P(WorkgroupSizeValidationTest, WithFixedValues) {
|
||||
auto CheckShaderWithWorkgroupSize = [this](bool success, uint32_t x, uint32_t y, uint32_t z) {
|
||||
std::ostringstream ss;
|
||||
ss << "@compute @workgroup_size(" << x << "," << y << "," << z << ") fn main() {}";
|
||||
|
||||
wgpu::ComputePipelineDescriptor desc;
|
||||
desc.compute.entryPoint = "main";
|
||||
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
|
||||
|
||||
if (success) {
|
||||
device.CreateComputePipeline(&desc);
|
||||
} else {
|
||||
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
|
||||
}
|
||||
};
|
||||
|
||||
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
|
||||
|
||||
CheckShaderWithWorkgroupSize(true, 1, 1, 1);
|
||||
CheckShaderWithWorkgroupSize(true, supportedLimits.maxComputeWorkgroupSizeX, 1, 1);
|
||||
CheckShaderWithWorkgroupSize(true, 1, supportedLimits.maxComputeWorkgroupSizeY, 1);
|
||||
CheckShaderWithWorkgroupSize(true, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ);
|
||||
|
||||
CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX + 1, 1, 1);
|
||||
CheckShaderWithWorkgroupSize(false, 1, supportedLimits.maxComputeWorkgroupSizeY + 1, 1);
|
||||
CheckShaderWithWorkgroupSize(false, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ + 1);
|
||||
|
||||
// No individual dimension exceeds its limit, but the combined size should definitely exceed the
|
||||
// total invocation limit.
|
||||
DAWN_ASSERT(supportedLimits.maxComputeWorkgroupSizeX *
|
||||
supportedLimits.maxComputeWorkgroupSizeY *
|
||||
supportedLimits.maxComputeWorkgroupSizeZ >
|
||||
supportedLimits.maxComputeInvocationsPerWorkgroup);
|
||||
CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX,
|
||||
supportedLimits.maxComputeWorkgroupSizeY,
|
||||
supportedLimits.maxComputeWorkgroupSizeZ);
|
||||
}
|
||||
|
||||
// Test workgroup size validation with fixed values (storage size limits validation).
|
||||
TEST_P(WorkgroupSizeValidationTest, WithFixedValuesStorageSizeLimits) {
|
||||
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
|
||||
|
||||
constexpr uint32_t kVec4Size = 16;
|
||||
const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size;
|
||||
constexpr uint32_t kMat4Size = 64;
|
||||
const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size;
|
||||
|
||||
auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count,
|
||||
uint32_t mat4_count) {
|
||||
std::ostringstream ss;
|
||||
std::ostringstream body;
|
||||
if (vec4_count > 0) {
|
||||
ss << "var<workgroup> vec4_data: array<vec4<f32>, " << vec4_count << ">;";
|
||||
body << "_ = vec4_data;";
|
||||
}
|
||||
if (mat4_count > 0) {
|
||||
ss << "var<workgroup> mat4_data: array<mat4x4<f32>, " << mat4_count << ">;";
|
||||
body << "_ = mat4_data;";
|
||||
}
|
||||
ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }";
|
||||
|
||||
wgpu::ComputePipelineDescriptor desc;
|
||||
desc.compute.entryPoint = "main";
|
||||
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
|
||||
|
||||
if (success) {
|
||||
device.CreateComputePipeline(&desc);
|
||||
} else {
|
||||
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
|
||||
}
|
||||
};
|
||||
|
||||
CheckPipelineWithWorkgroupStorage(true, 1, 1);
|
||||
CheckPipelineWithWorkgroupStorage(true, maxVec4Count, 0);
|
||||
CheckPipelineWithWorkgroupStorage(true, 0, maxMat4Count);
|
||||
CheckPipelineWithWorkgroupStorage(true, maxVec4Count - 4, 1);
|
||||
CheckPipelineWithWorkgroupStorage(true, 4, maxMat4Count - 1);
|
||||
|
||||
CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0);
|
||||
CheckPipelineWithWorkgroupStorage(false, maxVec4Count - 3, 1);
|
||||
CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1);
|
||||
CheckPipelineWithWorkgroupStorage(false, 4, maxMat4Count);
|
||||
}
|
||||
|
||||
// Test workgroup size validation with valid overrides default values.
|
||||
TEST_P(WorkgroupSizeValidationTest, OverridesWithValidDefault) {
|
||||
wgpu::ShaderModule module = SetUpShaderWithValidDefaultValueConstants();
|
||||
{
|
||||
// Valid default
|
||||
TestCreatePipeline(module);
|
||||
}
|
||||
{
|
||||
// Error: invalid value (zero)
|
||||
ASSERT_DEVICE_ERROR(TestInitializedWithZero(module));
|
||||
}
|
||||
{
|
||||
// Error: invalid value (out of device limits)
|
||||
ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module));
|
||||
}
|
||||
{
|
||||
// Valid: initialized partially
|
||||
TestInitializedPartially(module);
|
||||
}
|
||||
{
|
||||
// Valid
|
||||
TestInitializedWithValidValue(module);
|
||||
}
|
||||
}
|
||||
|
||||
// Test workgroup size validation with zero as the overrides default values.
|
||||
TEST_P(WorkgroupSizeValidationTest, OverridesWithZeroDefault) {
|
||||
// Error: zero is detected as invalid at shader creation time
|
||||
ASSERT_DEVICE_ERROR(SetUpShaderWithZeroDefaultValueConstants());
|
||||
}
|
||||
|
||||
// Test workgroup size validation with out-of-limits overrides default values.
|
||||
TEST_P(WorkgroupSizeValidationTest, OverridesWithOutOfLimitsDefault) {
|
||||
wgpu::ShaderModule module = SetUpShaderWithOutOfLimitsDefaultValueConstants();
|
||||
{
|
||||
// Error: invalid default
|
||||
ASSERT_DEVICE_ERROR(TestCreatePipeline(module));
|
||||
}
|
||||
{
|
||||
// Error: invalid value (zero)
|
||||
ASSERT_DEVICE_ERROR(TestInitializedWithZero(module));
|
||||
}
|
||||
{
|
||||
// Error: invalid value (out of device limits)
|
||||
ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module));
|
||||
}
|
||||
{
|
||||
// Error: initialized partially
|
||||
ASSERT_DEVICE_ERROR(TestInitializedPartially(module));
|
||||
}
|
||||
{
|
||||
// Valid
|
||||
TestInitializedWithValidValue(module);
|
||||
}
|
||||
}
|
||||
|
||||
// Test workgroup size validation without overrides default values specified.
|
||||
TEST_P(WorkgroupSizeValidationTest, OverridesWithUninitialized) {
|
||||
wgpu::ShaderModule module = SetUpShaderWithUninitializedConstants();
|
||||
{
|
||||
// Error: uninitialized
|
||||
ASSERT_DEVICE_ERROR(TestCreatePipeline(module));
|
||||
}
|
||||
{
|
||||
// Error: invalid value (zero)
|
||||
ASSERT_DEVICE_ERROR(TestInitializedWithZero(module));
|
||||
}
|
||||
{
|
||||
// Error: invalid value (out of device limits)
|
||||
ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module));
|
||||
}
|
||||
{
|
||||
// Error: initialized partially
|
||||
ASSERT_DEVICE_ERROR(TestInitializedPartially(module));
|
||||
}
|
||||
{
|
||||
// Valid
|
||||
TestInitializedWithValidValue(module);
|
||||
}
|
||||
}
|
||||
|
||||
// Test workgroup size validation with only partial dimensions are overrides.
|
||||
TEST_P(WorkgroupSizeValidationTest, PartialOverrides) {
|
||||
wgpu::ShaderModule module = SetUpShaderWithPartialConstants();
|
||||
{
|
||||
// Error: uninitialized
|
||||
ASSERT_DEVICE_ERROR(TestCreatePipeline(module));
|
||||
}
|
||||
{
|
||||
// Error: invalid value (zero)
|
||||
std::vector<wgpu::ConstantEntry> constants{{nullptr, "x", 0}};
|
||||
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
|
||||
}
|
||||
{
|
||||
// Error: invalid value (out of device limits)
|
||||
std::vector<wgpu::ConstantEntry> constants{
|
||||
{nullptr, "x",
|
||||
static_cast<double>(GetSupportedLimits().limits.maxComputeWorkgroupSizeX + 1)}};
|
||||
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
|
||||
}
|
||||
{
|
||||
// Valid
|
||||
std::vector<wgpu::ConstantEntry> constants{{nullptr, "x", 16}};
|
||||
TestCreatePipeline(module, constants);
|
||||
}
|
||||
}
|
||||
|
||||
// Test workgroup size validation after being overrided with invalid values.
|
||||
TEST_P(WorkgroupSizeValidationTest, ValidationAfterOverride) {
|
||||
wgpu::ShaderModule module = SetUpShaderWithUninitializedConstants();
|
||||
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
|
||||
{
|
||||
// Error: exceed maxComputeWorkgroupSizeZ
|
||||
std::vector<wgpu::ConstantEntry> constants{
|
||||
{nullptr, "x", 1},
|
||||
{nullptr, "y", 1},
|
||||
{nullptr, "z", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeZ + 1)}};
|
||||
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
|
||||
}
|
||||
{
|
||||
// Error: exceed maxComputeInvocationsPerWorkgroup
|
||||
DAWN_ASSERT(supportedLimits.maxComputeWorkgroupSizeX *
|
||||
supportedLimits.maxComputeWorkgroupSizeY *
|
||||
supportedLimits.maxComputeWorkgroupSizeZ >
|
||||
supportedLimits.maxComputeInvocationsPerWorkgroup);
|
||||
std::vector<wgpu::ConstantEntry> constants{
|
||||
{nullptr, "x", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeX)},
|
||||
{nullptr, "y", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeY)},
|
||||
{nullptr, "z", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeZ)}};
|
||||
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
|
||||
}
|
||||
}
|
||||
|
||||
// Test workgroup size validation after being overrided with invalid values (storage size limits
|
||||
// validation).
|
||||
// TODO(tint:1660): re-enable after override can be used as array size.
|
||||
TEST_P(WorkgroupSizeValidationTest, DISABLED_ValidationAfterOverrideStorageSize) {
|
||||
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
|
||||
|
||||
constexpr uint32_t kVec4Size = 16;
|
||||
const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size;
|
||||
constexpr uint32_t kMat4Size = 64;
|
||||
const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size;
|
||||
|
||||
auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count,
|
||||
uint32_t mat4_count) {
|
||||
std::ostringstream ss;
|
||||
std::ostringstream body;
|
||||
ss << "override a: u32;";
|
||||
ss << "override b: u32;";
|
||||
if (vec4_count > 0) {
|
||||
ss << "var<workgroup> vec4_data: array<vec4<f32>, a>;";
|
||||
body << "_ = vec4_data;";
|
||||
}
|
||||
if (mat4_count > 0) {
|
||||
ss << "var<workgroup> mat4_data: array<mat4x4<f32>, b>;";
|
||||
body << "_ = mat4_data;";
|
||||
}
|
||||
ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }";
|
||||
|
||||
wgpu::ComputePipelineDescriptor desc;
|
||||
desc.compute.entryPoint = "main";
|
||||
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
|
||||
|
||||
std::vector<wgpu::ConstantEntry> constants{{nullptr, "a", static_cast<double>(vec4_count)},
|
||||
{nullptr, "b", static_cast<double>(mat4_count)}};
|
||||
desc.compute.constants = constants.data();
|
||||
desc.compute.constantCount = constants.size();
|
||||
|
||||
if (success) {
|
||||
device.CreateComputePipeline(&desc);
|
||||
} else {
|
||||
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
|
||||
}
|
||||
};
|
||||
|
||||
CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0);
|
||||
CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1);
|
||||
}
|
||||
|
||||
DAWN_INSTANTIATE_TEST(WorkgroupSizeValidationTest, D3D12Backend(), MetalBackend(), VulkanBackend());
|
|
@ -0,0 +1,81 @@
|
|||
// Copyright 2022 The Dawn Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "dawn/native/ObjectContentHasher.h"
|
||||
#include "dawn/tests/DawnNativeTest.h"
|
||||
|
||||
namespace dawn::native {
|
||||
|
||||
namespace {
|
||||
|
||||
class ObjectContentHasherTests : public DawnNativeTest {};
|
||||
|
||||
#define EXPECT_IF_HASH_EQ(eq, a, b) \
|
||||
do { \
|
||||
ObjectContentHasher ra, rb; \
|
||||
ra.Record(a); \
|
||||
rb.Record(b); \
|
||||
EXPECT_EQ(eq, ra.GetContentHash() == rb.GetContentHash()); \
|
||||
} while (0)
|
||||
|
||||
TEST(ObjectContentHasherTests, Pair) {
|
||||
EXPECT_IF_HASH_EQ(true, (std::pair<std::string, uint8_t>{"a", 1}),
|
||||
(std::pair<std::string, uint8_t>{"a", 1}));
|
||||
EXPECT_IF_HASH_EQ(false, (std::pair<uint8_t, std::string>{1, "a"}),
|
||||
(std::pair<std::string, uint8_t>{"a", 1}));
|
||||
EXPECT_IF_HASH_EQ(false, (std::pair<std::string, uint8_t>{"a", 1}),
|
||||
(std::pair<std::string, uint8_t>{"a", 2}));
|
||||
EXPECT_IF_HASH_EQ(false, (std::pair<std::string, uint8_t>{"a", 1}),
|
||||
(std::pair<std::string, uint8_t>{"b", 1}));
|
||||
}
|
||||
|
||||
TEST(ObjectContentHasherTests, Vector) {
|
||||
EXPECT_IF_HASH_EQ(true, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{0, 1}));
|
||||
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{0, 1, 2}));
|
||||
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{1, 0}));
|
||||
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{}));
|
||||
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<float>{0, 1}));
|
||||
}
|
||||
|
||||
TEST(ObjectContentHasherTests, Map) {
|
||||
EXPECT_IF_HASH_EQ(true, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
|
||||
(std::map<std::string, uint8_t>{{"b", 2}, {"a", 1}}));
|
||||
EXPECT_IF_HASH_EQ(false, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
|
||||
(std::map<std::string, uint8_t>{{"a", 2}, {"b", 1}}));
|
||||
EXPECT_IF_HASH_EQ(false, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
|
||||
(std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}, {"c", 1}}));
|
||||
EXPECT_IF_HASH_EQ(false, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
|
||||
(std::map<std::string, uint8_t>{}));
|
||||
}
|
||||
|
||||
TEST(ObjectContentHasherTests, HashCombine) {
|
||||
ObjectContentHasher ra, rb;
|
||||
|
||||
ra.Record(std::vector<uint8_t>{0, 1});
|
||||
ra.Record(std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}});
|
||||
|
||||
rb.Record(std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}});
|
||||
rb.Record(std::vector<uint8_t>{0, 1});
|
||||
|
||||
EXPECT_NE(ra.GetContentHash(), rb.GetContentHash());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace dawn::native
|
|
@ -450,87 +450,6 @@ TEST_F(ShaderModuleValidationTest, MaximumInterStageShaderComponents) {
|
|||
}
|
||||
}
|
||||
|
||||
// Tests that we validate workgroup size limits.
|
||||
TEST_F(ShaderModuleValidationTest, ComputeWorkgroupSizeLimits) {
|
||||
auto CheckShaderWithWorkgroupSize = [this](bool success, uint32_t x, uint32_t y, uint32_t z) {
|
||||
std::ostringstream ss;
|
||||
ss << "@compute @workgroup_size(" << x << "," << y << "," << z << ") fn main() {}";
|
||||
|
||||
wgpu::ComputePipelineDescriptor desc;
|
||||
desc.compute.entryPoint = "main";
|
||||
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
|
||||
|
||||
if (success) {
|
||||
device.CreateComputePipeline(&desc);
|
||||
} else {
|
||||
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
|
||||
}
|
||||
};
|
||||
|
||||
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
|
||||
|
||||
CheckShaderWithWorkgroupSize(true, 1, 1, 1);
|
||||
CheckShaderWithWorkgroupSize(true, supportedLimits.maxComputeWorkgroupSizeX, 1, 1);
|
||||
CheckShaderWithWorkgroupSize(true, 1, supportedLimits.maxComputeWorkgroupSizeY, 1);
|
||||
CheckShaderWithWorkgroupSize(true, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ);
|
||||
|
||||
CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX + 1, 1, 1);
|
||||
CheckShaderWithWorkgroupSize(false, 1, supportedLimits.maxComputeWorkgroupSizeY + 1, 1);
|
||||
CheckShaderWithWorkgroupSize(false, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ + 1);
|
||||
|
||||
// No individual dimension exceeds its limit, but the combined size should definitely exceed the
|
||||
// total invocation limit.
|
||||
CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX,
|
||||
supportedLimits.maxComputeWorkgroupSizeY,
|
||||
supportedLimits.maxComputeWorkgroupSizeZ);
|
||||
}
|
||||
|
||||
// Tests that we validate workgroup storage size limits.
|
||||
TEST_F(ShaderModuleValidationTest, ComputeWorkgroupStorageSizeLimits) {
|
||||
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
|
||||
|
||||
constexpr uint32_t kVec4Size = 16;
|
||||
const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size;
|
||||
constexpr uint32_t kMat4Size = 64;
|
||||
const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size;
|
||||
|
||||
auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count,
|
||||
uint32_t mat4_count) {
|
||||
std::ostringstream ss;
|
||||
std::ostringstream body;
|
||||
if (vec4_count > 0) {
|
||||
ss << "var<workgroup> vec4_data: array<vec4<f32>, " << vec4_count << ">;";
|
||||
body << "_ = vec4_data;";
|
||||
}
|
||||
if (mat4_count > 0) {
|
||||
ss << "var<workgroup> mat4_data: array<mat4x4<f32>, " << mat4_count << ">;";
|
||||
body << "_ = mat4_data;";
|
||||
}
|
||||
ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }";
|
||||
|
||||
wgpu::ComputePipelineDescriptor desc;
|
||||
desc.compute.entryPoint = "main";
|
||||
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
|
||||
|
||||
if (success) {
|
||||
device.CreateComputePipeline(&desc);
|
||||
} else {
|
||||
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
|
||||
}
|
||||
};
|
||||
|
||||
CheckPipelineWithWorkgroupStorage(true, 1, 1);
|
||||
CheckPipelineWithWorkgroupStorage(true, maxVec4Count, 0);
|
||||
CheckPipelineWithWorkgroupStorage(true, 0, maxMat4Count);
|
||||
CheckPipelineWithWorkgroupStorage(true, maxVec4Count - 4, 1);
|
||||
CheckPipelineWithWorkgroupStorage(true, 4, maxMat4Count - 1);
|
||||
|
||||
CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0);
|
||||
CheckPipelineWithWorkgroupStorage(false, maxVec4Count - 3, 1);
|
||||
CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1);
|
||||
CheckPipelineWithWorkgroupStorage(false, 4, maxMat4Count);
|
||||
}
|
||||
|
||||
// Test that numeric ID must be unique
|
||||
TEST_F(ShaderModuleValidationTest, OverridableConstantsNumericIDConflicts) {
|
||||
ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"(
|
||||
|
|
|
@ -37,46 +37,6 @@ class UnsafeAPIValidationTest : public ValidationTest {
|
|||
}
|
||||
};
|
||||
|
||||
// Check that pipeline overridable constants are disallowed as part of unsafe APIs.
|
||||
// TODO(dawn:1041) Remove when implementation for all backend is added
|
||||
TEST_F(UnsafeAPIValidationTest, PipelineOverridableConstants) {
|
||||
// Create the placeholder compute pipeline.
|
||||
wgpu::ComputePipelineDescriptor pipelineDescBase;
|
||||
pipelineDescBase.compute.entryPoint = "main";
|
||||
|
||||
// Control case: shader without overridable constant is allowed.
|
||||
{
|
||||
wgpu::ComputePipelineDescriptor pipelineDesc = pipelineDescBase;
|
||||
pipelineDesc.compute.module =
|
||||
utils::CreateShaderModule(device, "@compute @workgroup_size(1) fn main() {}");
|
||||
|
||||
device.CreateComputePipeline(&pipelineDesc);
|
||||
}
|
||||
|
||||
// Error case: shader with overridable constant with default value
|
||||
{
|
||||
ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"(
|
||||
@id(1000) override c0: u32 = 1u;
|
||||
@id(1000) override c1: u32;
|
||||
|
||||
@compute @workgroup_size(1) fn main() {
|
||||
_ = c0;
|
||||
_ = c1;
|
||||
})"));
|
||||
}
|
||||
|
||||
// Error case: pipeline stage with constant entry is disallowed
|
||||
{
|
||||
wgpu::ComputePipelineDescriptor pipelineDesc = pipelineDescBase;
|
||||
pipelineDesc.compute.module =
|
||||
utils::CreateShaderModule(device, "@compute @workgroup_size(1) fn main() {}");
|
||||
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c", 1u}};
|
||||
pipelineDesc.compute.constants = constants.data();
|
||||
pipelineDesc.compute.constantCount = constants.size();
|
||||
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&pipelineDesc));
|
||||
}
|
||||
}
|
||||
|
||||
class UnsafeQueryAPIValidationTest : public ValidationTest {
|
||||
protected:
|
||||
WGPUDevice CreateTestDevice(dawn::native::Adapter dawnAdapter) override {
|
||||
|
|
|
@ -133,6 +133,110 @@ Inspector::Inspector(const Program* program) : program_(program) {}
|
|||
|
||||
Inspector::~Inspector() = default;
|
||||
|
||||
EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) {
|
||||
EntryPoint entry_point;
|
||||
TINT_ASSERT(Inspector, func != nullptr);
|
||||
TINT_ASSERT(Inspector, func->IsEntryPoint());
|
||||
|
||||
auto* sem = program_->Sem().Get(func);
|
||||
|
||||
entry_point.name = program_->Symbols().NameFor(func->symbol);
|
||||
entry_point.remapped_name = program_->Symbols().NameFor(func->symbol);
|
||||
|
||||
switch (func->PipelineStage()) {
|
||||
case ast::PipelineStage::kCompute: {
|
||||
entry_point.stage = PipelineStage::kCompute;
|
||||
|
||||
auto wgsize = sem->WorkgroupSize();
|
||||
if (!wgsize[0].overridable_const && !wgsize[1].overridable_const &&
|
||||
!wgsize[2].overridable_const) {
|
||||
entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value, wgsize[2].value};
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ast::PipelineStage::kFragment: {
|
||||
entry_point.stage = PipelineStage::kFragment;
|
||||
break;
|
||||
}
|
||||
case ast::PipelineStage::kVertex: {
|
||||
entry_point.stage = PipelineStage::kVertex;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
TINT_UNREACHABLE(Inspector, diagnostics_)
|
||||
<< "invalid pipeline stage for entry point '" << entry_point.name << "'";
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* param : sem->Parameters()) {
|
||||
AddEntryPointInOutVariables(program_->Symbols().NameFor(param->Declaration()->symbol),
|
||||
param->Type(), param->Declaration()->attributes,
|
||||
entry_point.input_variables);
|
||||
|
||||
entry_point.input_position_used |= ContainsBuiltin(
|
||||
ast::BuiltinValue::kPosition, param->Type(), param->Declaration()->attributes);
|
||||
entry_point.front_facing_used |= ContainsBuiltin(
|
||||
ast::BuiltinValue::kFrontFacing, param->Type(), param->Declaration()->attributes);
|
||||
entry_point.sample_index_used |= ContainsBuiltin(
|
||||
ast::BuiltinValue::kSampleIndex, param->Type(), param->Declaration()->attributes);
|
||||
entry_point.input_sample_mask_used |= ContainsBuiltin(
|
||||
ast::BuiltinValue::kSampleMask, param->Type(), param->Declaration()->attributes);
|
||||
entry_point.num_workgroups_used |= ContainsBuiltin(
|
||||
ast::BuiltinValue::kNumWorkgroups, param->Type(), param->Declaration()->attributes);
|
||||
}
|
||||
|
||||
if (!sem->ReturnType()->Is<sem::Void>()) {
|
||||
AddEntryPointInOutVariables("<retval>", sem->ReturnType(), func->return_type_attributes,
|
||||
entry_point.output_variables);
|
||||
|
||||
entry_point.output_sample_mask_used = ContainsBuiltin(
|
||||
ast::BuiltinValue::kSampleMask, sem->ReturnType(), func->return_type_attributes);
|
||||
}
|
||||
|
||||
for (auto* var : sem->TransitivelyReferencedGlobals()) {
|
||||
auto* decl = var->Declaration();
|
||||
|
||||
auto name = program_->Symbols().NameFor(decl->symbol);
|
||||
|
||||
auto* global = var->As<sem::GlobalVariable>();
|
||||
if (global && global->Declaration()->Is<ast::Override>()) {
|
||||
Override override;
|
||||
override.name = name;
|
||||
override.id = global->OverrideId();
|
||||
auto* type = var->Type();
|
||||
TINT_ASSERT(Inspector, type->is_scalar());
|
||||
if (type->is_bool_scalar_or_vector()) {
|
||||
override.type = Override::Type::kBool;
|
||||
} else if (type->is_float_scalar()) {
|
||||
override.type = Override::Type::kFloat32;
|
||||
} else if (type->is_signed_integer_scalar()) {
|
||||
override.type = Override::Type::kInt32;
|
||||
} else if (type->is_unsigned_integer_scalar()) {
|
||||
override.type = Override::Type::kUint32;
|
||||
} else {
|
||||
TINT_UNREACHABLE(Inspector, diagnostics_);
|
||||
}
|
||||
|
||||
override.is_initialized = global->Declaration()->constructor;
|
||||
override.is_id_specified =
|
||||
ast::HasAttribute<ast::IdAttribute>(global->Declaration()->attributes);
|
||||
|
||||
entry_point.overrides.push_back(override);
|
||||
}
|
||||
}
|
||||
|
||||
return entry_point;
|
||||
}
|
||||
|
||||
EntryPoint Inspector::GetEntryPoint(const std::string& entry_point_name) {
|
||||
auto* func = FindEntryPointByName(entry_point_name);
|
||||
if (!func) {
|
||||
return EntryPoint();
|
||||
}
|
||||
return GetEntryPoint(func);
|
||||
}
|
||||
|
||||
std::vector<EntryPoint> Inspector::GetEntryPoints() {
|
||||
std::vector<EntryPoint> result;
|
||||
|
||||
|
@ -141,97 +245,7 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
|
|||
continue;
|
||||
}
|
||||
|
||||
auto* sem = program_->Sem().Get(func);
|
||||
|
||||
EntryPoint entry_point;
|
||||
entry_point.name = program_->Symbols().NameFor(func->symbol);
|
||||
entry_point.remapped_name = program_->Symbols().NameFor(func->symbol);
|
||||
|
||||
switch (func->PipelineStage()) {
|
||||
case ast::PipelineStage::kCompute: {
|
||||
entry_point.stage = PipelineStage::kCompute;
|
||||
|
||||
auto wgsize = sem->WorkgroupSize();
|
||||
if (!wgsize[0].overridable_const && !wgsize[1].overridable_const &&
|
||||
!wgsize[2].overridable_const) {
|
||||
entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value,
|
||||
wgsize[2].value};
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ast::PipelineStage::kFragment: {
|
||||
entry_point.stage = PipelineStage::kFragment;
|
||||
break;
|
||||
}
|
||||
case ast::PipelineStage::kVertex: {
|
||||
entry_point.stage = PipelineStage::kVertex;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
TINT_UNREACHABLE(Inspector, diagnostics_)
|
||||
<< "invalid pipeline stage for entry point '" << entry_point.name << "'";
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* param : sem->Parameters()) {
|
||||
AddEntryPointInOutVariables(program_->Symbols().NameFor(param->Declaration()->symbol),
|
||||
param->Type(), param->Declaration()->attributes,
|
||||
entry_point.input_variables);
|
||||
|
||||
entry_point.input_position_used |= ContainsBuiltin(
|
||||
ast::BuiltinValue::kPosition, param->Type(), param->Declaration()->attributes);
|
||||
entry_point.front_facing_used |= ContainsBuiltin(
|
||||
ast::BuiltinValue::kFrontFacing, param->Type(), param->Declaration()->attributes);
|
||||
entry_point.sample_index_used |= ContainsBuiltin(
|
||||
ast::BuiltinValue::kSampleIndex, param->Type(), param->Declaration()->attributes);
|
||||
entry_point.input_sample_mask_used |= ContainsBuiltin(
|
||||
ast::BuiltinValue::kSampleMask, param->Type(), param->Declaration()->attributes);
|
||||
entry_point.num_workgroups_used |= ContainsBuiltin(
|
||||
ast::BuiltinValue::kNumWorkgroups, param->Type(), param->Declaration()->attributes);
|
||||
}
|
||||
|
||||
if (!sem->ReturnType()->Is<sem::Void>()) {
|
||||
AddEntryPointInOutVariables("<retval>", sem->ReturnType(), func->return_type_attributes,
|
||||
entry_point.output_variables);
|
||||
|
||||
entry_point.output_sample_mask_used = ContainsBuiltin(
|
||||
ast::BuiltinValue::kSampleMask, sem->ReturnType(), func->return_type_attributes);
|
||||
}
|
||||
|
||||
for (auto* var : sem->TransitivelyReferencedGlobals()) {
|
||||
auto* decl = var->Declaration();
|
||||
|
||||
auto name = program_->Symbols().NameFor(decl->symbol);
|
||||
|
||||
auto* global = var->As<sem::GlobalVariable>();
|
||||
if (global && global->Declaration()->Is<ast::Override>()) {
|
||||
Override override;
|
||||
override.name = name;
|
||||
override.id = global->OverrideId();
|
||||
auto* type = var->Type();
|
||||
TINT_ASSERT(Inspector, type->is_scalar());
|
||||
if (type->is_bool_scalar_or_vector()) {
|
||||
override.type = Override::Type::kBool;
|
||||
} else if (type->is_float_scalar()) {
|
||||
override.type = Override::Type::kFloat32;
|
||||
} else if (type->is_signed_integer_scalar()) {
|
||||
override.type = Override::Type::kInt32;
|
||||
} else if (type->is_unsigned_integer_scalar()) {
|
||||
override.type = Override::Type::kUint32;
|
||||
} else {
|
||||
TINT_UNREACHABLE(Inspector, diagnostics_);
|
||||
}
|
||||
|
||||
override.is_initialized = global->Declaration()->constructor;
|
||||
override.is_id_specified =
|
||||
ast::HasAttribute<ast::IdAttribute>(global->Declaration()->attributes);
|
||||
|
||||
entry_point.overrides.push_back(override);
|
||||
}
|
||||
}
|
||||
|
||||
result.push_back(std::move(entry_point));
|
||||
result.push_back(GetEntryPoint(func));
|
||||
}
|
||||
|
||||
return result;
|
||||
|
|
|
@ -55,6 +55,10 @@ class Inspector {
|
|||
/// @returns vector of entry point information
|
||||
std::vector<EntryPoint> GetEntryPoints();
|
||||
|
||||
/// @param entry_point name of the entry point to get information about
|
||||
/// @returns the entry point information
|
||||
EntryPoint GetEntryPoint(const std::string& entry_point);
|
||||
|
||||
/// @returns map of override identifier to initial value
|
||||
std::map<OverrideId, Scalar> GetOverrideDefaultValues();
|
||||
|
||||
|
@ -230,6 +234,10 @@ class Inspector {
|
|||
/// whenever a set of expressions are resolved to globals.
|
||||
template <size_t N, typename F>
|
||||
void GetOriginatingResources(std::array<const ast::Expression*, N> exprs, F&& cb);
|
||||
|
||||
/// @param func the function of the entry point. Must be non-nullptr and true for IsEntryPoint()
|
||||
/// @returns the entry point information
|
||||
EntryPoint GetEntryPoint(const tint::ast::Function* func);
|
||||
};
|
||||
|
||||
} // namespace tint::inspector
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "tint/override_id.h"
|
||||
|
||||
#include "src/tint/reflection.h"
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint::transform {
|
||||
|
@ -63,6 +64,9 @@ class SubstituteOverride final : public Castable<SubstituteOverride, Transform>
|
|||
/// The value is always a double coming into the transform and will be
|
||||
/// converted to the correct type through and initializer.
|
||||
std::unordered_map<OverrideId, double> map;
|
||||
|
||||
/// Reflect the fields of this class so that it can be used by tint::ForeachField()
|
||||
TINT_REFLECT(map);
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
|
|
Loading…
Reference in New Issue