Use SubstituteOverride transform to implement overrides

Remove the old backend specific implementation for
overrides. Use tint SubstituteOverride transform to replace
overrides with const expressions and use the updated program
at pipeline creation time.

This CL also adds support for overrides used as workgroup size
and related tests. Workgroup size validation now happens
in backend code and at compute pipeline creation time.

Bug: dawn:1504
Change-Id: I7df1fe9c3e358caa23235eacd6d13ba0b2998aec
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99821
Commit-Queue: Shrek Shao <shrekshao@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
shrekshao 2022-09-07 20:09:54 +00:00 committed by Dawn LUCI CQ
parent 23cf74c30e
commit 145337f309
36 changed files with 1246 additions and 698 deletions

View File

@ -25,7 +25,15 @@
namespace {{native_namespace}} {
//
// Cache key writers for wgpu structures used in caching.
// Streaming readers for wgpu structures.
//
{% macro render_reader(member) %}
{%- set name = member.name.camelCase() -%}
DAWN_TRY(StreamOut(source, &t->{{name}}));
{% endmacro %}
//
// Streaming writers for wgpu structures.
//
{% macro render_writer(member) %}
{%- set name = member.name.camelCase() -%}
@ -38,31 +46,50 @@ namespace {{native_namespace}} {
{% endif %}
{% endmacro %}
{# Helper macro to render writers. Should be used in a call block to provide additional custom
{# Helper macro to render readers and writers. Should be used in a call block to provide additional custom
handling when necessary. The optional `omit` field can be used to omit fields that are either
handled in the custom code, or unnecessary in the serialized output.
Example:
{% call render_cache_key_writer("struct name", omits=["omit field"]) %}
{% call render_streaming_impl("struct name", writer=true, reader=false, omits=["omit field"]) %}
// Custom C++ code to handle special types/members that are hard to generate code for
{% endcall %}
One day we should probably make the generator smart enough to generate everything it can
instead of manually adding streaming implementations here.
#}
{% macro render_cache_key_writer(json_type, omits=[]) %}
{% macro render_streaming_impl(json_type, writer, reader, omits=[]) %}
{%- set cpp_type = types[json_type].name.CamelCase() -%}
{% if reader %}
template <>
MaybeError stream::Stream<{{cpp_type}}>::Read(stream::Source* source, {{cpp_type}}* t) {
{{ caller() }}
{% for member in types[json_type].members %}
{% if not member.name.get() in omits %}
{{render_reader(member)}}
{% endif %}
{% endfor %}
return {};
}
{% endif %}
{% if writer %}
template <>
void stream::Stream<{{cpp_type}}>::Write(stream::Sink* sink, const {{cpp_type}}& t) {
{{ caller() }}
{% for member in types[json_type].members %}
{%- if not member.name.get() in omits %}
{% if not member.name.get() in omits %}
{{render_writer(member)}}
{%- endif %}
{% endif %}
{% endfor %}
}
{% endif %}
{% endmacro %}
{% call render_cache_key_writer("adapter properties") %}
{% call render_streaming_impl("adapter properties", true, false) %}
{% endcall %}
{% call render_cache_key_writer("dawn cache device descriptor") %}
{% call render_streaming_impl("dawn cache device descriptor", true, false) %}
{% endcall %}
{% call render_streaming_impl("extent 3D", true, true) %}
{% endcall %}
} // namespace {{native_namespace}}

View File

@ -215,4 +215,22 @@ Limits ApplyLimitTiers(Limits limits) {
return limits;
}
#define DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT(type, name) \
{ result.name = limits.name; }
#define DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT(MEMBERS) \
MEMBERS(DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT)
LimitsForCompilationRequest LimitsForCompilationRequest::Create(const Limits& limits) {
LimitsForCompilationRequest result;
DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT(LIMITS_FOR_COMPILATION_REQUEST_MEMBERS)
return result;
}
#undef DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT
#undef DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT
template <>
void stream::Stream<LimitsForCompilationRequest>::Write(Sink* s,
const LimitsForCompilationRequest& t) {
t.VisitAll([&](const auto&... members) { StreamIn(s, members...); });
}
} // namespace dawn::native

View File

@ -16,6 +16,7 @@
#define SRC_DAWN_NATIVE_LIMITS_H_
#include "dawn/native/Error.h"
#include "dawn/native/VisitableMembers.h"
#include "dawn/native/dawn_platform.h"
namespace dawn::native {
@ -38,6 +39,20 @@ MaybeError ValidateLimits(const Limits& supportedLimits, const Limits& requiredL
// Returns a copy of |limits| where limit tiers are applied.
Limits ApplyLimitTiers(Limits limits);
// If there are new limit member needed at shader compilation time
// Simply append a new X(type, name) here.
#define LIMITS_FOR_COMPILATION_REQUEST_MEMBERS(X) \
X(uint32_t, maxComputeWorkgroupSizeX) \
X(uint32_t, maxComputeWorkgroupSizeY) \
X(uint32_t, maxComputeWorkgroupSizeZ) \
X(uint32_t, maxComputeInvocationsPerWorkgroup) \
X(uint32_t, maxComputeWorkgroupStorageSize)
struct LimitsForCompilationRequest {
static LimitsForCompilationRequest Create(const Limits& limits);
DAWN_VISITABLE_MEMBERS(LIMITS_FOR_COMPILATION_REQUEST_MEMBERS)
};
} // namespace dawn::native
#endif // SRC_DAWN_NATIVE_LIMITS_H_

View File

@ -15,7 +15,9 @@
#ifndef SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_
#define SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "dawn/common/HashUtils.h"
@ -60,6 +62,13 @@ class ObjectContentHasher {
}
};
template <typename T, typename E>
struct RecordImpl<std::map<T, E>> {
static constexpr void Call(ObjectContentHasher* recorder, const std::map<T, E>& map) {
recorder->RecordIterable<std::map<T, E>>(map);
}
};
template <typename IteratorT>
constexpr void RecordIterable(const IteratorT& iterable) {
for (auto it = iterable.begin(); it != iterable.end(); ++it) {
@ -67,6 +76,14 @@ class ObjectContentHasher {
}
}
template <typename T, typename E>
struct RecordImpl<std::pair<T, E>> {
static constexpr void Call(ObjectContentHasher* recorder, const std::pair<T, E>& pair) {
recorder->Record(pair.first);
recorder->Record(pair.second);
}
};
size_t mContentHash = 0;
};

View File

@ -58,12 +58,6 @@ MaybeError ValidateProgrammableStage(DeviceBase* device,
DAWN_TRY(ValidateCompatibilityWithPipelineLayout(device, metadata, layout));
}
if (constantCount > 0u && device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs)) {
return DAWN_VALIDATION_ERROR(
"Pipeline overridable constants are disallowed because they are partially "
"implemented.");
}
// Validate if overridable constants exist in shader module
// pipelineBase is not yet constructed at this moment so iterate constants from descriptor
size_t numUninitializedConstants = metadata.uninitializedOverrides.size();
@ -233,6 +227,7 @@ size_t PipelineBase::ComputeContentHash() {
for (SingleShaderStage stage : IterateStages(mStageMask)) {
recorder.Record(mStages[stage].module->GetContentHash());
recorder.Record(mStages[stage].entryPoint);
recorder.Record(mStages[stage].constants);
}
return recorder.GetContentHash();
@ -248,7 +243,14 @@ bool PipelineBase::EqualForCache(const PipelineBase* a, const PipelineBase* b) {
for (SingleShaderStage stage : IterateStages(a->mStageMask)) {
// The module is deduplicated so it can be compared by pointer.
if (a->mStages[stage].module.Get() != b->mStages[stage].module.Get() ||
a->mStages[stage].entryPoint != b->mStages[stage].entryPoint) {
a->mStages[stage].entryPoint != b->mStages[stage].entryPoint ||
a->mStages[stage].constants.size() != b->mStages[stage].constants.size()) {
return false;
}
// If the constants.size are the same, we still need to compare the key and value.
if (!std::equal(a->mStages[stage].constants.begin(), a->mStages[stage].constants.end(),
b->mStages[stage].constants.begin())) {
return false;
}
}

View File

@ -40,9 +40,6 @@ MaybeError ValidateProgrammableStage(DeviceBase* device,
const PipelineLayoutBase* layout,
SingleShaderStage stage);
// Use map to make sure constant keys are sorted for creating shader cache keys
using PipelineConstantEntries = std::map<std::string, double>;
struct ProgrammableStage {
Ref<ShaderModuleBase> module;
std::string entryPoint;

View File

@ -20,7 +20,6 @@
#include "absl/strings/str_format.h"
#include "dawn/common/BitSetIterator.h"
#include "dawn/common/Constants.h"
#include "dawn/common/HashUtils.h"
#include "dawn/native/BindGroupLayout.h"
#include "dawn/native/ChainUtils_autogen.h"
#include "dawn/native/CompilationMessages.h"
@ -511,7 +510,6 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
const DeviceBase* device,
tint::inspector::Inspector* inspector,
const tint::inspector::EntryPoint& entryPoint) {
const CombinedLimits& limits = device->GetLimits();
constexpr uint32_t kMaxInterStageShaderLocation = kMaxInterStageShaderVariables - 1;
std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
@ -528,10 +526,6 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
})()
if (!entryPoint.overrides.empty()) {
DAWN_INVALID_IF(device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs),
"Pipeline overridable constants are disallowed because they "
"are partially implemented.");
const auto& name2Id = inspector->GetNamedOverrideIds();
const auto& id2Scalar = inspector->GetOverrideDefaultValues();
@ -553,10 +547,10 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
UNREACHABLE();
}
}
EntryPointMetadata::Override override = {id.value, FromTintOverrideType(c.type),
EntryPointMetadata::Override override = {id, FromTintOverrideType(c.type),
c.is_initialized, defaultValue};
std::string identifier = c.is_id_specified ? std::to_string(override.id) : c.name;
std::string identifier = c.is_id_specified ? std::to_string(override.id.value) : c.name;
metadata->overrides[identifier] = override;
if (!c.is_initialized) {
@ -575,39 +569,6 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
if (metadata->stage == SingleShaderStage::Compute) {
auto workgroup_size = entryPoint.workgroup_size;
DAWN_INVALID_IF(
!workgroup_size.has_value(),
"TODO(crbug.com/dawn/1504): Dawn does not currently support @workgroup_size "
"attributes using override-expressions");
DelayedInvalidIf(workgroup_size->x > limits.v1.maxComputeWorkgroupSizeX ||
workgroup_size->y > limits.v1.maxComputeWorkgroupSizeY ||
workgroup_size->z > limits.v1.maxComputeWorkgroupSizeZ,
"Entry-point uses workgroup_size(%u, %u, %u) that exceeds the "
"maximum allowed (%u, %u, %u).",
workgroup_size->x, workgroup_size->y, workgroup_size->z,
limits.v1.maxComputeWorkgroupSizeX, limits.v1.maxComputeWorkgroupSizeY,
limits.v1.maxComputeWorkgroupSizeZ);
// Dimensions have already been validated against their individual limits above.
// Cast to uint64_t to avoid overflow in this multiplication.
uint64_t numInvocations =
static_cast<uint64_t>(workgroup_size->x) * workgroup_size->y * workgroup_size->z;
DelayedInvalidIf(numInvocations > limits.v1.maxComputeInvocationsPerWorkgroup,
"The total number of workgroup invocations (%u) exceeds the "
"maximum allowed (%u).",
numInvocations, limits.v1.maxComputeInvocationsPerWorkgroup);
const size_t workgroupStorageSize = inspector->GetWorkgroupStorageSize(entryPoint.name);
DelayedInvalidIf(workgroupStorageSize > limits.v1.maxComputeWorkgroupStorageSize,
"The total use of workgroup storage (%u bytes) is larger than "
"the maximum allowed (%u bytes).",
workgroupStorageSize, limits.v1.maxComputeWorkgroupStorageSize);
metadata->localWorkgroupSize.x = workgroup_size->x;
metadata->localWorkgroupSize.y = workgroup_size->y;
metadata->localWorkgroupSize.z = workgroup_size->z;
metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
}
@ -883,6 +844,46 @@ MaybeError ReflectShaderUsingTint(const DeviceBase* device,
}
} // anonymous namespace
ResultOrError<Extent3D> ValidateComputeStageWorkgroupSize(
const tint::Program& program,
const char* entryPointName,
const LimitsForCompilationRequest& limits) {
tint::inspector::Inspector inspector(&program);
// At this point the entry point must exist and must have workgroup size values.
tint::inspector::EntryPoint entryPoint = inspector.GetEntryPoint(entryPointName);
ASSERT(entryPoint.workgroup_size.has_value());
const tint::inspector::WorkgroupSize& workgroup_size = entryPoint.workgroup_size.value();
DAWN_INVALID_IF(workgroup_size.x < 1 || workgroup_size.y < 1 || workgroup_size.z < 1,
"Entry-point uses workgroup_size(%u, %u, %u) that are below the "
"minimum allowed (1, 1, 1).",
workgroup_size.x, workgroup_size.y, workgroup_size.z);
DAWN_INVALID_IF(workgroup_size.x > limits.maxComputeWorkgroupSizeX ||
workgroup_size.y > limits.maxComputeWorkgroupSizeY ||
workgroup_size.z > limits.maxComputeWorkgroupSizeZ,
"Entry-point uses workgroup_size(%u, %u, %u) that exceeds the "
"maximum allowed (%u, %u, %u).",
workgroup_size.x, workgroup_size.y, workgroup_size.z,
limits.maxComputeWorkgroupSizeX, limits.maxComputeWorkgroupSizeY,
limits.maxComputeWorkgroupSizeZ);
uint64_t numInvocations =
static_cast<uint64_t>(workgroup_size.x) * workgroup_size.y * workgroup_size.z;
DAWN_INVALID_IF(numInvocations > limits.maxComputeInvocationsPerWorkgroup,
"The total number of workgroup invocations (%u) exceeds the "
"maximum allowed (%u).",
numInvocations, limits.maxComputeInvocationsPerWorkgroup);
const size_t workgroupStorageSize = inspector.GetWorkgroupStorageSize(entryPointName);
DAWN_INVALID_IF(workgroupStorageSize > limits.maxComputeWorkgroupStorageSize,
"The total use of workgroup storage (%u bytes) is larger than "
"the maximum allowed (%u bytes).",
workgroupStorageSize, limits.maxComputeWorkgroupStorageSize);
return Extent3D{workgroup_size.x, workgroup_size.y, workgroup_size.z};
}
ShaderModuleParseResult::ShaderModuleParseResult() = default;
ShaderModuleParseResult::~ShaderModuleParseResult() = default;
@ -1200,11 +1201,4 @@ MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult
return {};
}
size_t PipelineLayoutEntryPointPairHashFunc::operator()(
const PipelineLayoutEntryPointPair& pair) const {
size_t hash = 0;
HashCombine(&hash, pair.first, pair.second);
return hash;
}
} // namespace dawn::native

View File

@ -33,10 +33,12 @@
#include "dawn/native/Format.h"
#include "dawn/native/Forward.h"
#include "dawn/native/IntegerTypes.h"
#include "dawn/native/Limits.h"
#include "dawn/native/ObjectBase.h"
#include "dawn/native/PerStage.h"
#include "dawn/native/VertexFormat.h"
#include "dawn/native/dawn_platform.h"
#include "tint/override_id.h"
namespace tint {
@ -76,10 +78,8 @@ enum class InterpolationSampling {
Sample,
};
using PipelineLayoutEntryPointPair = std::pair<const PipelineLayoutBase*, std::string>;
struct PipelineLayoutEntryPointPairHashFunc {
size_t operator()(const PipelineLayoutEntryPointPair& pair) const;
};
// Use map to make sure constant keys are sorted for creating shader cache keys
using PipelineConstantEntries = std::map<std::string, double>;
// A map from name to EntryPointMetadata.
using EntryPointMetadataTable =
@ -108,6 +108,13 @@ MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout);
// Return extent3D with workgroup size dimension info if it is valid
// width = x, height = y, depthOrArrayLength = z
ResultOrError<Extent3D> ValidateComputeStageWorkgroupSize(
const tint::Program& program,
const char* entryPointName,
const LimitsForCompilationRequest& limits);
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout);
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
@ -204,14 +211,12 @@ struct EntryPointMetadata {
std::bitset<kMaxInterStageShaderVariables> usedInterStageVariables;
std::array<InterStageVariableInfo, kMaxInterStageShaderVariables> interStageVariables;
// The local workgroup size declared for a compute entry point (or 0s otehrwise).
Origin3D localWorkgroupSize;
// The shader stage for this binding.
SingleShaderStage stage;
struct Override {
uint32_t id;
tint::OverrideId id;
// Match tint::inspector::Override::Type
// Bool is defined as a macro on linux X11 and cannot compile
enum class Type { Boolean, Float32, Uint32, Int32 } type;
@ -273,6 +278,7 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject {
bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const;
};
// This returns tint program before running transforms.
const tint::Program* GetTintProgram() const;
void APIGetCompilationInfo(wgpu::CompilationInfoCallback callback, void* userdata);

View File

@ -65,6 +65,23 @@ void stream::Stream<tint::transform::VertexPulling::Config>::Write(
StreamInTintObject(cfg, sink);
}
// static
template <>
void stream::Stream<tint::transform::SubstituteOverride::Config>::Write(
stream::Sink* sink,
const tint::transform::SubstituteOverride::Config& cfg) {
StreamInTintObject(cfg, sink);
}
// static
template <>
void stream::Stream<tint::OverrideId>::Write(stream::Sink* sink, const tint::OverrideId& id) {
// TODO(tint:1640): fix the include build issues and use StreamInTintObject instead.
static_assert(offsetof(tint::OverrideId, value) == 0,
"Please update serialization for tint::OverrideId");
StreamIn(sink, id.value);
}
// static
template <>
void stream::Stream<tint::transform::VertexBufferLayoutDescriptor>::Write(

View File

@ -16,6 +16,7 @@
#include "dawn/native/BindGroupLayout.h"
#include "dawn/native/Device.h"
#include "dawn/native/Pipeline.h"
#include "dawn/native/PipelineLayout.h"
#include "dawn/native/RenderPipeline.h"
@ -183,6 +184,21 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig(
return cfg;
}
tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig(
const ProgrammableStage& stage) {
const EntryPointMetadata& metadata = *stage.metadata;
const auto& constants = stage.constants;
tint::transform::SubstituteOverride::Config cfg;
for (const auto& [key, value] : constants) {
const auto& o = metadata.overrides.at(key);
cfg.map.insert({o.id, value});
}
return cfg;
}
} // namespace dawn::native
namespace tint::sem {

View File

@ -26,6 +26,7 @@ namespace dawn::native {
class DeviceBase;
class PipelineLayoutBase;
struct ProgrammableStage;
class RenderPipelineBase;
// Indicates that for the lifetime of this object tint internal compiler errors should be
@ -47,6 +48,9 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig(
const std::string_view& entryPoint,
BindGroupIndex pullingBufferBindingSet);
tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig(
const ProgrammableStage& stage);
} // namespace dawn::native
namespace tint::sem {

View File

@ -351,8 +351,6 @@ MaybeError RenderPipeline::Initialize() {
D3D12_GRAPHICS_PIPELINE_STATE_DESC descriptorD3D12 = {};
PerStage<ProgrammableStage> pipelineStages = GetAllStages();
PerStage<D3D12_SHADER_BYTECODE*> shaders;
shaders[SingleShaderStage::Vertex] = &descriptorD3D12.VS;
shaders[SingleShaderStage::Fragment] = &descriptorD3D12.PS;
@ -360,8 +358,9 @@ MaybeError RenderPipeline::Initialize() {
PerStage<CompiledShader> compiledShader;
for (auto stage : IterateStages(GetStageMask())) {
DAWN_TRY_ASSIGN(compiledShader[stage], ToBackend(pipelineStages[stage].module)
->Compile(pipelineStages[stage], stage,
const ProgrammableStage& programmableStage = GetStage(stage);
DAWN_TRY_ASSIGN(compiledShader[stage], ToBackend(programmableStage.module)
->Compile(programmableStage, stage,
ToBackend(GetLayout()), compileFlags));
*shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode();
}

View File

@ -65,72 +65,6 @@ namespace dawn::native::d3d12 {
namespace {
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
std::ostringstream out;
out.precision(n);
out << std::fixed << v;
return out.str();
}
std::string GetHLSLValueString(EntryPointMetadata::Override::Type dawnType,
const OverrideScalar* entry,
double value = 0) {
switch (dawnType) {
case EntryPointMetadata::Override::Type::Boolean:
return std::to_string(entry ? entry->b : static_cast<int32_t>(value));
case EntryPointMetadata::Override::Type::Float32:
return FloatToStringWithPrecision(entry ? entry->f32 : static_cast<float>(value));
case EntryPointMetadata::Override::Type::Int32:
return std::to_string(entry ? entry->i32 : static_cast<int32_t>(value));
case EntryPointMetadata::Override::Type::Uint32:
return std::to_string(entry ? entry->u32 : static_cast<uint32_t>(value));
default:
UNREACHABLE();
}
}
constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
using DefineStrings = std::vector<std::pair<std::string, std::string>>;
DefineStrings GetOverridableConstantsDefines(
const PipelineConstantEntries& pipelineConstantEntries,
const EntryPointMetadata::OverridesMap& shaderEntryPointConstants) {
DefineStrings defineStrings;
std::unordered_set<std::string> overriddenConstants;
// Set pipeline overridden values
for (const auto& [name, value] : pipelineConstantEntries) {
overriddenConstants.insert(name);
// This is already validated so `name` must exist
const auto& moduleConstant = shaderEntryPointConstants.at(name);
defineStrings.emplace_back(
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
GetHLSLValueString(moduleConstant.type, nullptr, value));
}
// Set shader initialized default values
for (const auto& iter : shaderEntryPointConstants) {
const std::string& name = iter.first;
if (overriddenConstants.count(name) != 0) {
// This constant already has overridden value
continue;
}
const auto& moduleConstant = shaderEntryPointConstants.at(name);
// Uninitialized default values are okay since they ar only defined to pass
// compilation but not used
defineStrings.emplace_back(
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue));
}
return defineStrings;
}
enum class Compiler { FXC, DXC };
#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \
@ -151,11 +85,12 @@ enum class Compiler { FXC, DXC };
X(bool, usesNumWorkgroups) \
X(uint32_t, numWorkgroupsShaderRegister) \
X(uint32_t, numWorkgroupsRegisterSpace) \
X(DefineStrings, defineStrings) \
X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform) \
X(tint::transform::BindingRemapper::BindingPoints, remappedBindingPoints) \
X(tint::transform::BindingRemapper::AccessControls, remappedAccessControls) \
X(std::optional<tint::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
X(LimitsForCompilationRequest, limits) \
X(bool, disableSymbolRenaming) \
X(bool, isRobustnessEnabled) \
X(bool, disableWorkgroupInit) \
@ -170,8 +105,7 @@ enum class Compiler { FXC, DXC };
X(std::string_view, fxcShaderProfile) \
X(pD3DCompile, d3dCompile) \
X(IDxcLibrary*, dxcLibrary) \
X(IDxcCompiler*, dxcCompiler) \
X(DefineStrings, defineStrings)
X(IDxcCompiler*, dxcCompiler)
DAWN_SERIALIZABLE(struct, HlslCompilationRequest, HLSL_COMPILATION_REQUEST_MEMBERS){};
#undef HLSL_COMPILATION_REQUEST_MEMBERS
@ -255,24 +189,10 @@ ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(const D3DBytecodeCompilationReq
std::vector<const wchar_t*> arguments =
GetDXCArguments(r.compileFlags, r.hasShaderFloat16Feature);
// Build defines for overridable constants
std::vector<std::pair<std::wstring, std::wstring>> defineStrings;
defineStrings.reserve(r.defineStrings.size());
for (const auto& [name, value] : r.defineStrings) {
defineStrings.emplace_back(UTF8ToWStr(name.c_str()), UTF8ToWStr(value.c_str()));
}
std::vector<DxcDefine> dxcDefines;
dxcDefines.reserve(defineStrings.size());
for (const auto& [name, value] : defineStrings) {
dxcDefines.push_back({name.c_str(), value.c_str()});
}
ComPtr<IDxcOperationResult> result;
DAWN_TRY(CheckHRESULT(
r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
r.dxcShaderProfile.data(), arguments.data(), arguments.size(),
dxcDefines.data(), dxcDefines.size(), nullptr, &result),
DAWN_TRY(CheckHRESULT(r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
r.dxcShaderProfile.data(), arguments.data(),
arguments.size(), nullptr, 0, nullptr, &result),
"DXC compile"));
HRESULT hr;
@ -359,20 +279,7 @@ ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const D3DBytecodeCompilationReq
ComPtr<ID3DBlob> compiledShader;
ComPtr<ID3DBlob> errors;
// Build defines for overridable constants
const D3D_SHADER_MACRO* pDefines = nullptr;
std::vector<D3D_SHADER_MACRO> fxcDefines;
if (r.defineStrings.size() > 0) {
fxcDefines.reserve(r.defineStrings.size() + 1);
for (const auto& [name, value] : r.defineStrings) {
fxcDefines.push_back({name.c_str(), value.c_str()});
}
// d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array
fxcDefines.push_back({nullptr, nullptr});
pDefines = fxcDefines.data();
}
DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, pDefines,
DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
nullptr, entryPointName.c_str(), r.fxcShaderProfile.data(),
r.compileFlags, 0, &compiledShader, &errors)),
"D3D compile failed with: %s", static_cast<char*>(errors->GetBufferPointer()));
@ -420,6 +327,14 @@ ResultOrError<std::string> TranslateToHLSL(
tint::transform::Renamer::Target::kHlslKeywords);
}
if (r.substituteOverrideConfig) {
// This needs to run after SingleEntryPoint transform to get rid of overrides not used for
// the current entry point.
transformManager.Add<tint::transform::SubstituteOverride>();
transformInputs.Add<tint::transform::SubstituteOverride::Config>(
std::move(r.substituteOverrideConfig).value());
}
// D3D12 registers like `t3` and `c3` have the same bindingOffset number in
// the remapping but should not be considered a collision because they have
// different types.
@ -450,6 +365,13 @@ ResultOrError<std::string> TranslateToHLSL(
return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
}
if (r.stage == SingleShaderStage::Compute) {
// Validate workgroup size after program runs transforms.
Extent3D _;
DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize(
transformedProgram, remappedEntryPointName->data(), r.limits));
}
if (r.stage == SingleShaderStage::Vertex) {
if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
*usesVertexOrInstanceIndex = data->has_vertex_or_instance_index;
@ -555,8 +477,7 @@ ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& pro
req.bytecode.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
req.bytecode.compileFlags = compileFlags;
req.bytecode.defineStrings =
GetOverridableConstantsDefines(programmableStage.constants, entryPoint.overrides);
if (device->IsToggleEnabled(Toggle::UseDXC)) {
req.bytecode.compiler = Compiler::DXC;
req.bytecode.dxcLibrary = device->GetDxcLibrary().Get();
@ -645,6 +566,11 @@ ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& pro
}
}
std::optional<tint::transform::SubstituteOverride::Config> substituteOverrideConfig;
if (!programmableStage.metadata->overrides.empty()) {
substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
}
req.hlsl.inputProgram = GetTintProgram();
req.hlsl.entryPointName = programmableStage.entryPoint.c_str();
req.hlsl.stage = stage;
@ -657,6 +583,10 @@ ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& pro
req.hlsl.remappedAccessControls = std::move(remappedAccessControls);
req.hlsl.newBindingsMap = BuildExternalTextureTransformBindings(layout);
req.hlsl.arrayLengthFromUniform = std::move(arrayLengthFromUniform);
req.hlsl.substituteOverrideConfig = std::move(substituteOverrideConfig);
const CombinedLimits& limits = device->GetLimits();
req.hlsl.limits = LimitsForCompilationRequest::Create(limits.v1);
CacheResult<CompiledShader> compiledShader;
DAWN_TRY_LOAD_OR_RUN(compiledShader, device, std::move(req), CompiledShader::FromBlob,

View File

@ -40,7 +40,8 @@ MaybeError ComputePipeline::Initialize() {
const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
ShaderModule::MetalFunctionData computeData;
DAWN_TRY(CreateMTLFunction(computeStage, SingleShaderStage::Compute, ToBackend(GetLayout()),
DAWN_TRY(ToBackend(computeStage.module.Get())
->CreateFunction(SingleShaderStage::Compute, computeStage, ToBackend(GetLayout()),
&computeData));
NSError* error = nullptr;
@ -53,8 +54,7 @@ MaybeError ComputePipeline::Initialize() {
ASSERT(mMtlComputePipelineState != nil);
// Copy over the local workgroup size as it is passed to dispatch explicitly in Metal
Origin3D localSize = GetStage(SingleShaderStage::Compute).metadata->localWorkgroupSize;
mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z);
mLocalWorkgroupSize = computeData.localWorkgroupSize;
mRequiresStorageBufferLength = computeData.needsStorageBufferLength;
mWorkgroupAllocations = std::move(computeData.workgroupAllocations);

View File

@ -340,7 +340,8 @@ MaybeError RenderPipeline::Initialize() {
const PerStage<ProgrammableStage>& allStages = GetAllStages();
const ProgrammableStage& vertexStage = allStages[wgpu::ShaderStage::Vertex];
ShaderModule::MetalFunctionData vertexData;
DAWN_TRY(CreateMTLFunction(vertexStage, SingleShaderStage::Vertex, ToBackend(GetLayout()),
DAWN_TRY(ToBackend(vertexStage.module.Get())
->CreateFunction(SingleShaderStage::Vertex, vertexStage, ToBackend(GetLayout()),
&vertexData, 0xFFFFFFFF, this));
descriptorMTL.vertexFunction = vertexData.function.Get();
@ -351,7 +352,8 @@ MaybeError RenderPipeline::Initialize() {
if (GetStageMask() & wgpu::ShaderStage::Fragment) {
const ProgrammableStage& fragmentStage = allStages[wgpu::ShaderStage::Fragment];
ShaderModule::MetalFunctionData fragmentData;
DAWN_TRY(CreateMTLFunction(fragmentStage, SingleShaderStage::Fragment,
DAWN_TRY(ToBackend(fragmentStage.module.Get())
->CreateFunction(SingleShaderStage::Fragment, fragmentStage,
ToBackend(GetLayout()), &fragmentData, GetSampleMask()));
descriptorMTL.fragmentFunction = fragmentData.function.Get();

View File

@ -25,6 +25,10 @@
#import <Metal/Metal.h>
namespace dawn::native {
struct ProgrammableStage;
}
namespace dawn::native::metal {
class Device;
@ -42,15 +46,13 @@ class ShaderModule final : public ShaderModuleBase {
NSPRef<id<MTLFunction>> function;
bool needsStorageBufferLength;
std::vector<uint32_t> workgroupAllocations;
MTLSize localWorkgroupSize;
};
// MTLFunctionConstantValues needs @available tag to compile
// Use id (like void*) in function signature as workaround and do static cast inside
MaybeError CreateFunction(const char* entryPointName,
SingleShaderStage stage,
MaybeError CreateFunction(SingleShaderStage stage,
const ProgrammableStage& programmableStage,
const PipelineLayout* layout,
MetalFunctionData* out,
id constantValues = nil,
uint32_t sampleMask = 0xFFFFFFFF,
const RenderPipeline* renderPipeline = nullptr);

View File

@ -36,10 +36,13 @@ namespace {
using OptionalVertexPullingTransformConfig = std::optional<tint::transform::VertexPulling::Config>;
#define MSL_COMPILATION_REQUEST_MEMBERS(X) \
X(SingleShaderStage, stage) \
X(const tint::Program*, inputProgram) \
X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \
X(tint::transform::MultiplanarExternalTexture::BindingsMap, externalTextureBindings) \
X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig) \
X(std::optional<tint::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
X(LimitsForCompilationRequest, limits) \
X(std::string, entryPointName) \
X(uint32_t, sampleMask) \
X(bool, emitVertexPointSize) \
@ -58,7 +61,8 @@ using WorkgroupAllocations = std::vector<uint32_t>;
X(std::string, remappedEntryPointName) \
X(bool, needsStorageBufferLength) \
X(bool, hasInvariantAttribute) \
X(WorkgroupAllocations, workgroupAllocations)
X(WorkgroupAllocations, workgroupAllocations) \
X(Extent3D, localWorkgroupSize)
DAWN_SERIALIZABLE(struct, MslCompilation, MSL_COMPILATION_MEMBERS){};
#undef MSL_COMPILATION_MEMBERS
@ -92,11 +96,12 @@ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
namespace {
ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
const tint::Program* inputProgram,
const char* entryPointName,
ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(
DeviceBase* device,
const ProgrammableStage& programmableStage,
SingleShaderStage stage,
const PipelineLayout* layout,
ShaderModule::MetalFunctionData* out,
uint32_t sampleMask,
const RenderPipeline* renderPipeline) {
ScopedTintICEHandler scopedICEHandler(device);
@ -137,7 +142,7 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
if (stage == SingleShaderStage::Vertex &&
device->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
vertexPullingTransformConfig = BuildVertexPullingTransformConfig(
*renderPipeline, entryPointName, kPullingBufferBindingSet);
*renderPipeline, programmableStage.entryPoint.c_str(), kPullingBufferBindingSet);
for (VertexBufferSlot slot : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot);
@ -152,12 +157,19 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
}
}
std::optional<tint::transform::SubstituteOverride::Config> substituteOverrideConfig;
if (!programmableStage.metadata->overrides.empty()) {
substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
}
MslCompilationRequest req = {};
req.inputProgram = inputProgram;
req.stage = stage;
req.inputProgram = programmableStage.module->GetTintProgram();
req.bindingPoints = std::move(bindingPoints);
req.externalTextureBindings = std::move(externalTextureBindings);
req.vertexPullingTransformConfig = std::move(vertexPullingTransformConfig);
req.entryPointName = entryPointName;
req.substituteOverrideConfig = std::move(substituteOverrideConfig);
req.entryPointName = programmableStage.entryPoint.c_str();
req.sampleMask = sampleMask;
req.emitVertexPointSize =
stage == SingleShaderStage::Vertex &&
@ -166,6 +178,9 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform());
const CombinedLimits& limits = device->GetLimits();
req.limits = LimitsForCompilationRequest::Create(limits.v1);
CacheResult<MslCompilation> mslCompilation;
DAWN_TRY_LOAD_OR_RUN(
mslCompilation, device, std::move(req), MslCompilation::FromBlob,
@ -190,6 +205,14 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
std::move(r.vertexPullingTransformConfig).value());
}
if (r.substituteOverrideConfig) {
// This needs to run after SingleEntryPoint transform to get rid of overrides not
// used for the current entry point.
transformManager.Add<tint::transform::SubstituteOverride>();
transformInputs.Add<tint::transform::SubstituteOverride::Config>(
std::move(r.substituteOverrideConfig).value());
}
if (r.isRobustnessEnabled) {
transformManager.Add<tint::transform::Robustness>();
}
@ -230,6 +253,13 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
}
Extent3D localSize{0, 0, 0};
if (r.stage == SingleShaderStage::Compute) {
// Validate workgroup size after program runs transforms.
DAWN_TRY_ASSIGN(localSize, ValidateComputeStageWorkgroupSize(
program, remappedEntryPointName.data(), r.limits));
}
tint::writer::msl::Options options;
options.buffer_size_ubo_index = kBufferLengthBufferSlot;
options.fixed_sample_mask = r.sampleMask;
@ -258,6 +288,7 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
result.needs_storage_buffer_sizes,
result.has_invariant_attribute,
std::move(workgroupAllocations),
localSize,
}};
});
@ -272,11 +303,10 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
} // namespace
MaybeError ShaderModule::CreateFunction(const char* entryPointName,
SingleShaderStage stage,
MaybeError ShaderModule::CreateFunction(SingleShaderStage stage,
const ProgrammableStage& programmableStage,
const PipelineLayout* layout,
ShaderModule::MetalFunctionData* out,
id constantValuesPointer,
uint32_t sampleMask,
const RenderPipeline* renderPipeline) {
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleMTL::CreateFunction");
@ -284,16 +314,21 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName,
ASSERT(!IsError());
ASSERT(out);
const char* entryPointName = programmableStage.entryPoint.c_str();
// Vertex stages must specify a renderPipeline
if (stage == SingleShaderStage::Vertex) {
ASSERT(renderPipeline != nullptr);
}
CacheResult<MslCompilation> mslCompilation;
DAWN_TRY_ASSIGN(mslCompilation, TranslateToMSL(GetDevice(), GetTintProgram(), entryPointName,
stage, layout, sampleMask, renderPipeline));
DAWN_TRY_ASSIGN(mslCompilation, TranslateToMSL(GetDevice(), programmableStage, stage, layout,
out, sampleMask, renderPipeline));
out->needsStorageBufferLength = mslCompilation->needsStorageBufferLength;
out->workgroupAllocations = std::move(mslCompilation->workgroupAllocations);
out->localWorkgroupSize = MTLSizeMake(mslCompilation->localWorkgroupSize.width,
mslCompilation->localWorkgroupSize.height,
mslCompilation->localWorkgroupSize.depthOrArrayLayers);
NSRef<NSString> mslSource =
AcquireNSRef([[NSString alloc] initWithUTF8String:mslCompilation->msl.c_str()]);
@ -327,26 +362,8 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName,
{
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "MTLLibrary::newFunctionWithName");
if (constantValuesPointer != nil) {
if (@available(macOS 10.12, *)) {
MTLFunctionConstantValues* constantValues = constantValuesPointer;
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()
constantValues:constantValues
error:&error]);
if (error != nullptr) {
if (error.code != MTLLibraryErrorCompileWarning) {
return DAWN_VALIDATION_ERROR("Function compile error: %s",
[error.localizedDescription UTF8String]);
}
}
ASSERT(out->function != nil);
} else {
UNREACHABLE();
}
} else {
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
}
}
if (BlobCache* cache = GetDevice()->GetBlobCache()) {
cache->EnsureStored(mslCompilation);

View File

@ -68,15 +68,6 @@ void EnsureDestinationTextureInitialized(CommandRecordingContext* commandContext
MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect);
// Helper function to create function with constant values wrapped in
// if available branch
MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
SingleShaderStage singleShaderStage,
PipelineLayout* pipelineLayout,
ShaderModule::MetalFunctionData* functionData,
uint32_t sampleMask = 0xFFFFFFFF,
const RenderPipeline* renderPipeline = nullptr);
// Allow use MTLStoreActionStoreAndMultismapleResolve because the logic in the backend is
// first to compute what the "best" Metal render pass descriptor is, then fix it up if we
// are not on macOS 10.12 (i.e. the EmulateStoreAndMSAAResolve toggle is on).

View File

@ -323,104 +323,6 @@ MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect) {
return MTLBlitOptionNone;
}
MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
SingleShaderStage singleShaderStage,
PipelineLayout* pipelineLayout,
ShaderModule::MetalFunctionData* functionData,
uint32_t sampleMask,
const RenderPipeline* renderPipeline) {
ShaderModule* shaderModule = ToBackend(programmableStage.module.Get());
const char* shaderEntryPoint = programmableStage.entryPoint.c_str();
const auto& entryPointMetadata = programmableStage.module->GetEntryPoint(shaderEntryPoint);
if (entryPointMetadata.overrides.size() == 0) {
DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout,
functionData, nil, sampleMask, renderPipeline));
return {};
}
if (@available(macOS 10.12, *)) {
// MTLFunctionConstantValues can only be created within the if available branch
NSRef<MTLFunctionConstantValues> constantValues =
AcquireNSRef([MTLFunctionConstantValues new]);
std::unordered_set<std::string> overriddenConstants;
auto switchType = [&](EntryPointMetadata::Override::Type dawnType,
MTLDataType* type, OverrideScalar* entry,
double value = 0) {
switch (dawnType) {
case EntryPointMetadata::Override::Type::Boolean:
*type = MTLDataTypeBool;
if (entry) {
entry->b = static_cast<int32_t>(value);
}
break;
case EntryPointMetadata::Override::Type::Float32:
*type = MTLDataTypeFloat;
if (entry) {
entry->f32 = static_cast<float>(value);
}
break;
case EntryPointMetadata::Override::Type::Int32:
*type = MTLDataTypeInt;
if (entry) {
entry->i32 = static_cast<int32_t>(value);
}
break;
case EntryPointMetadata::Override::Type::Uint32:
*type = MTLDataTypeUInt;
if (entry) {
entry->u32 = static_cast<uint32_t>(value);
}
break;
default:
UNREACHABLE();
}
};
for (const auto& [name, value] : programmableStage.constants) {
overriddenConstants.insert(name);
// This is already validated so `name` must exist
const auto& moduleConstant = entryPointMetadata.overrides.at(name);
MTLDataType type;
OverrideScalar entry{};
switchType(moduleConstant.type, &type, &entry, value);
[constantValues.Get() setConstantValue:&entry type:type atIndex:moduleConstant.id];
}
// Set shader initialized default values because MSL function_constant
// has no default value
for (const std::string& name : entryPointMetadata.initializedOverrides) {
if (overriddenConstants.count(name) != 0) {
// This constant already has overridden value
continue;
}
// Must exist because it is validated
const auto& moduleConstant = entryPointMetadata.overrides.at(name);
ASSERT(moduleConstant.isInitialized);
MTLDataType type;
switchType(moduleConstant.type, &type, nullptr);
[constantValues.Get() setConstantValue:&moduleConstant.defaultValue
type:type
atIndex:moduleConstant.id];
}
DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout,
functionData, constantValues.Get(), sampleMask,
renderPipeline));
} else {
UNREACHABLE();
}
return {};
}
MaybeError EncodeMetalRenderPass(Device* device,
CommandRecordingContext* commandContext,
MTLRenderPassDescriptor* mtlRenderPass,

View File

@ -16,6 +16,8 @@
#include <string>
#include "dawn/native/Limits.h"
namespace dawn::native::stream {
template <>

View File

@ -61,16 +61,12 @@ MaybeError ComputePipeline::Initialize() {
ShaderModule::ModuleAndSpirv moduleAndSpirv;
DAWN_TRY_ASSIGN(moduleAndSpirv,
module->GetHandleAndSpirv(computeStage.entryPoint.c_str(), layout));
module->GetHandleAndSpirv(SingleShaderStage::Compute, computeStage, layout));
createInfo.stage.module = moduleAndSpirv.module;
createInfo.stage.pName = computeStage.entryPoint.c_str();
std::vector<OverrideScalar> specializationDataEntries;
std::vector<VkSpecializationMapEntry> specializationMapEntries;
VkSpecializationInfo specializationInfo{};
createInfo.stage.pSpecializationInfo = GetVkSpecializationInfo(
computeStage, &specializationInfo, &specializationDataEntries, &specializationMapEntries);
createInfo.stage.pSpecializationInfo = nullptr;
PNextChainBuilder stageExtChain(&createInfo.stage);

View File

@ -341,9 +341,6 @@ MaybeError RenderPipeline::Initialize() {
// There are at most 2 shader stages in render pipeline, i.e. vertex and fragment
std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
std::array<std::vector<OverrideScalar>, 2> specializationDataEntriesPerStages;
std::array<std::vector<VkSpecializationMapEntry>, 2> specializationMapEntriesPerStages;
std::array<VkSpecializationInfo, 2> specializationInfoPerStages;
uint32_t stageCount = 0;
for (auto stage : IterateStages(this->GetStageMask())) {
@ -354,7 +351,7 @@ MaybeError RenderPipeline::Initialize() {
ShaderModule::ModuleAndSpirv moduleAndSpirv;
DAWN_TRY_ASSIGN(moduleAndSpirv,
module->GetHandleAndSpirv(programmableStage.entryPoint.c_str(), layout));
module->GetHandleAndSpirv(stage, programmableStage, layout));
shaderStage.module = moduleAndSpirv.module;
shaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
@ -379,11 +376,6 @@ MaybeError RenderPipeline::Initialize() {
}
}
shaderStage.pSpecializationInfo =
GetVkSpecializationInfo(programmableStage, &specializationInfoPerStages[stageCount],
&specializationDataEntriesPerStages[stageCount],
&specializationMapEntriesPerStages[stageCount]);
DAWN_ASSERT(stageCount < 2);
shaderStages[stageCount] = shaderStage;
stageCount++;

View File

@ -60,13 +60,35 @@ class ShaderModule::Spirv : private Blob {
namespace dawn::native::vulkan {
bool TransformedShaderModuleCacheKey::operator==(
const TransformedShaderModuleCacheKey& other) const {
if (layout != other.layout || entryPoint != other.entryPoint ||
constants.size() != other.constants.size()) {
return false;
}
if (!std::equal(constants.begin(), constants.end(), other.constants.begin())) {
return false;
}
return true;
}
size_t TransformedShaderModuleCacheKeyHashFunc::operator()(
const TransformedShaderModuleCacheKey& key) const {
size_t hash = 0;
HashCombine(&hash, key.layout, key.entryPoint);
for (const auto& entry : key.constants) {
HashCombine(&hash, entry.first, entry.second);
}
return hash;
}
class ShaderModule::ConcurrentTransformedShaderModuleCache {
public:
explicit ConcurrentTransformedShaderModuleCache(Device* device);
~ConcurrentTransformedShaderModuleCache();
std::optional<ModuleAndSpirv> Find(const PipelineLayoutEntryPointPair& key);
ModuleAndSpirv AddOrGet(const PipelineLayoutEntryPointPair& key,
std::optional<ModuleAndSpirv> Find(const TransformedShaderModuleCacheKey& key);
ModuleAndSpirv AddOrGet(const TransformedShaderModuleCacheKey& key,
VkShaderModule module,
Spirv&& spirv);
@ -75,7 +97,9 @@ class ShaderModule::ConcurrentTransformedShaderModuleCache {
Device* mDevice;
std::mutex mMutex;
std::unordered_map<PipelineLayoutEntryPointPair, Entry, PipelineLayoutEntryPointPairHashFunc>
std::unordered_map<TransformedShaderModuleCacheKey,
Entry,
TransformedShaderModuleCacheKeyHashFunc>
mTransformedShaderModuleCache;
};
@ -92,7 +116,7 @@ ShaderModule::ConcurrentTransformedShaderModuleCache::~ConcurrentTransformedShad
std::optional<ShaderModule::ModuleAndSpirv>
ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
const PipelineLayoutEntryPointPair& key) {
const TransformedShaderModuleCacheKey& key) {
std::lock_guard<std::mutex> lock(mMutex);
auto iter = mTransformedShaderModuleCache.find(key);
if (iter != mTransformedShaderModuleCache.end()) {
@ -106,7 +130,7 @@ ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
}
ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet(
const PipelineLayoutEntryPointPair& key,
const TransformedShaderModuleCacheKey& key,
VkShaderModule module,
Spirv&& spirv) {
ASSERT(module != VK_NULL_HANDLE);
@ -169,9 +193,12 @@ void ShaderModule::DestroyImpl() {
ShaderModule::~ShaderModule() = default;
#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \
X(SingleShaderStage, stage) \
X(const tint::Program*, inputProgram) \
X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \
X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
X(std::optional<tint::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
X(LimitsForCompilationRequest, limits) \
X(std::string_view, entryPointName) \
X(bool, disableWorkgroupInit) \
X(bool, useZeroInitializeWorkgroupMemoryExtension) \
@ -181,7 +208,8 @@ DAWN_MAKE_CACHE_REQUEST(SpirvCompilationRequest, SPIRV_COMPILATION_REQUEST_MEMBE
#undef SPIRV_COMPILATION_REQUEST_MEMBERS
ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
const char* entryPointName,
SingleShaderStage stage,
const ProgrammableStage& programmableStage,
const PipelineLayout* layout) {
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleVk::GetHandleAndSpirv");
@ -191,7 +219,8 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
ScopedTintICEHandler scopedICEHandler(GetDevice());
// Check to see if we have the handle and spirv cached already.
auto cacheKey = std::make_pair(layout, entryPointName);
auto cacheKey = TransformedShaderModuleCacheKey{layout, programmableStage.entryPoint.c_str(),
programmableStage.constants};
auto handleAndSpirv = mTransformedShaderModuleCache->Find(cacheKey);
if (handleAndSpirv.has_value()) {
return std::move(*handleAndSpirv);
@ -204,7 +233,8 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
using BindingPoint = tint::transform::BindingPoint;
BindingRemapper::BindingPoints bindingPoints;
const BindingInfoArray& moduleBindingInfo = GetEntryPoint(entryPointName).bindings;
const BindingInfoArray& moduleBindingInfo =
GetEntryPoint(programmableStage.entryPoint.c_str()).bindings;
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
@ -238,16 +268,26 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
}
}
std::optional<tint::transform::SubstituteOverride::Config> substituteOverrideConfig;
if (!programmableStage.metadata->overrides.empty()) {
substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
}
#if TINT_BUILD_SPV_WRITER
SpirvCompilationRequest req = {};
req.stage = stage;
req.inputProgram = GetTintProgram();
req.bindingPoints = std::move(bindingPoints);
req.newBindingsMap = std::move(newBindingsMap);
req.entryPointName = entryPointName;
req.entryPointName = programmableStage.entryPoint;
req.disableWorkgroupInit = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
req.useZeroInitializeWorkgroupMemoryExtension =
GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension);
req.tracePlatform = UnsafeUnkeyedValue(GetDevice()->GetPlatform());
req.substituteOverrideConfig = std::move(substituteOverrideConfig);
const CombinedLimits& limits = GetDevice()->GetLimits();
req.limits = LimitsForCompilationRequest::Create(limits.v1);
CacheResult<Spirv> spirv;
DAWN_TRY_LOAD_OR_RUN(
@ -270,12 +310,27 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
r.newBindingsMap);
}
if (r.substituteOverrideConfig) {
// This needs to run after SingleEntryPoint transform to get rid of overrides not
// used for the current entry point.
transformManager.Add<tint::transform::SubstituteOverride>();
transformInputs.Add<tint::transform::SubstituteOverride::Config>(
std::move(r.substituteOverrideConfig).value());
}
tint::Program program;
{
TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "RunTransforms");
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, r.inputProgram,
transformInputs, nullptr, nullptr));
}
if (r.stage == SingleShaderStage::Compute) {
// Validate workgroup size after program runs transforms.
Extent3D _;
DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize(
program, r.entryPointName.data(), r.limits));
}
tint::writer::spirv::Options options;
options.emit_vertex_point_size = true;
options.disable_workgroup_init = r.disableWorkgroupInit;

View File

@ -18,14 +18,32 @@
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include "dawn/common/HashUtils.h"
#include "dawn/common/vulkan_platform.h"
#include "dawn/native/Error.h"
#include "dawn/native/ShaderModule.h"
namespace dawn::native::vulkan {
namespace dawn::native {
struct ProgrammableStage;
namespace vulkan {
struct TransformedShaderModuleCacheKey {
const PipelineLayoutBase* layout;
std::string entryPoint;
PipelineConstantEntries constants;
bool operator==(const TransformedShaderModuleCacheKey& other) const;
};
struct TransformedShaderModuleCacheKeyHashFunc {
size_t operator()(const TransformedShaderModuleCacheKey& key) const;
};
class Device;
class PipelineLayout;
@ -44,7 +62,8 @@ class ShaderModule final : public ShaderModuleBase {
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(const char* entryPointName,
ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(SingleShaderStage stage,
const ProgrammableStage& programmableStage,
const PipelineLayout* layout);
private:
@ -59,6 +78,8 @@ class ShaderModule final : public ShaderModuleBase {
std::unique_ptr<ConcurrentTransformedShaderModuleCache> mTransformedShaderModuleCache;
};
} // namespace dawn::native::vulkan
} // namespace vulkan
} // namespace dawn::native
#endif // SRC_DAWN_NATIVE_VULKAN_SHADERMODULEVK_H_

View File

@ -258,60 +258,4 @@ std::string GetDeviceDebugPrefixFromDebugName(const char* debugName) {
return std::string(debugName, length);
}
VkSpecializationInfo* GetVkSpecializationInfo(
const ProgrammableStage& programmableStage,
VkSpecializationInfo* specializationInfo,
std::vector<OverrideScalar>* specializationDataEntries,
std::vector<VkSpecializationMapEntry>* specializationMapEntries) {
ASSERT(specializationInfo);
ASSERT(specializationDataEntries);
ASSERT(specializationMapEntries);
if (programmableStage.constants.size() == 0) {
return nullptr;
}
const EntryPointMetadata& entryPointMetaData =
programmableStage.module->GetEntryPoint(programmableStage.entryPoint);
for (const auto& pipelineConstant : programmableStage.constants) {
const std::string& identifier = pipelineConstant.first;
double value = pipelineConstant.second;
// This is already validated so `identifier` must exist
const auto& moduleConstant = entryPointMetaData.overrides.at(identifier);
specializationMapEntries->push_back(VkSpecializationMapEntry{
moduleConstant.id,
static_cast<uint32_t>(specializationDataEntries->size() * sizeof(OverrideScalar)),
sizeof(OverrideScalar)});
OverrideScalar entry{};
switch (moduleConstant.type) {
case EntryPointMetadata::Override::Type::Boolean:
entry.b = static_cast<int32_t>(value);
break;
case EntryPointMetadata::Override::Type::Float32:
entry.f32 = static_cast<float>(value);
break;
case EntryPointMetadata::Override::Type::Int32:
entry.i32 = static_cast<int32_t>(value);
break;
case EntryPointMetadata::Override::Type::Uint32:
entry.u32 = static_cast<uint32_t>(value);
break;
default:
UNREACHABLE();
}
specializationDataEntries->push_back(entry);
}
specializationInfo->mapEntryCount = static_cast<uint32_t>(specializationMapEntries->size());
specializationInfo->pMapEntries = specializationMapEntries->data();
specializationInfo->dataSize = specializationDataEntries->size() * sizeof(OverrideScalar);
specializationInfo->pData = specializationDataEntries->data();
return specializationInfo;
}
} // namespace dawn::native::vulkan

View File

@ -144,15 +144,6 @@ void SetDebugName(Device* device,
std::string GetNextDeviceDebugPrefix();
std::string GetDeviceDebugPrefixFromDebugName(const char* debugName);
// Returns nullptr or &specializationInfo
// specializationInfo, specializationDataEntries, specializationMapEntries needs to
// be alive at least until VkSpecializationInfo is passed into Vulkan Create*Pipelines
VkSpecializationInfo* GetVkSpecializationInfo(
const ProgrammableStage& programmableStage,
VkSpecializationInfo* specializationInfo,
std::vector<OverrideScalar>* specializationDataEntries,
std::vector<VkSpecializationMapEntry>* specializationMapEntries);
} // namespace dawn::native::vulkan
#endif // SRC_DAWN_NATIVE_VULKAN_UTILSVULKAN_H_

View File

@ -265,6 +265,7 @@ dawn_test("dawn_unittests") {
"unittests/native/CreatePipelineAsyncTaskTests.cpp",
"unittests/native/DestroyObjectTests.cpp",
"unittests/native/DeviceCreationTests.cpp",
"unittests/native/ObjectContentHasherTests.cpp",
"unittests/native/StreamTests.cpp",
"unittests/validation/BindGroupValidationTests.cpp",
"unittests/validation/BufferValidationTests.cpp",
@ -489,6 +490,7 @@ source_set("end2end_tests_sources") {
"end2end/ScissorTests.cpp",
"end2end/ShaderFloat16Tests.cpp",
"end2end/ShaderTests.cpp",
"end2end/ShaderValidationTests.cpp",
"end2end/StorageTextureTests.cpp",
"end2end/SubresourceRenderAttachmentTests.cpp",
"end2end/Texture3DTests.cpp",

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "dawn/tests/DawnTest.h"
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
@ -158,6 +160,46 @@ TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnShaderModule) {
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
}
// Test that ComputePipeline are correctly deduplicated wrt. their constants override values
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnOverrides) {
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
override x: u32 = 1u;
var<workgroup> i : u32;
@compute @workgroup_size(x) fn main() {
i = 0u;
})");
wgpu::PipelineLayout layout = utils::MakeBasicPipelineLayout(device, nullptr);
wgpu::ComputePipelineDescriptor desc;
desc.compute.entryPoint = "main";
desc.layout = layout;
desc.compute.module = module;
std::vector<wgpu::ConstantEntry> constants{{nullptr, "x", 16}};
desc.compute.constantCount = constants.size();
desc.compute.constants = constants.data();
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&desc);
std::vector<wgpu::ConstantEntry> sameConstants{{nullptr, "x", 16}};
desc.compute.constantCount = sameConstants.size();
desc.compute.constants = sameConstants.data();
wgpu::ComputePipeline samePipeline = device.CreateComputePipeline(&desc);
desc.compute.constantCount = 0;
desc.compute.constants = nullptr;
wgpu::ComputePipeline otherPipeline1 = device.CreateComputePipeline(&desc);
std::vector<wgpu::ConstantEntry> otherConstants{{nullptr, "x", 4}};
desc.compute.constantCount = otherConstants.size();
desc.compute.constants = otherConstants.data();
wgpu::ComputePipeline otherPipeline2 = device.CreateComputePipeline(&desc);
EXPECT_NE(pipeline.Get(), otherPipeline1.Get());
EXPECT_NE(pipeline.Get(), otherPipeline2.Get());
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
}
// Test that ComputePipeline are correctly deduplicated wrt. their layout
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) {
wgpu::BindGroupLayout bgl = utils::MakeBindGroupLayout(
@ -303,6 +345,48 @@ TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnFragmentModule) {
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
}
// Test that Renderpipelines are correctly deduplicated wrt. their constants override values
TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnOverrides) {
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
override a: f32 = 1.0;
@vertex fn vertexMain() -> @builtin(position) vec4<f32> {
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
@fragment fn fragmentMain() -> @location(0) vec4<f32> {
return vec4<f32>(0.0, 0.0, 0.0, a);
})");
utils::ComboRenderPipelineDescriptor desc;
desc.vertex.module = module;
desc.vertex.entryPoint = "vertexMain";
desc.cFragment.module = module;
desc.cFragment.entryPoint = "fragmentMain";
desc.cTargets[0].writeMask = wgpu::ColorWriteMask::None;
std::vector<wgpu::ConstantEntry> constants{{nullptr, "a", 0.5}};
desc.cFragment.constantCount = constants.size();
desc.cFragment.constants = constants.data();
wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&desc);
std::vector<wgpu::ConstantEntry> sameConstants{{nullptr, "a", 0.5}};
desc.cFragment.constantCount = sameConstants.size();
desc.cFragment.constants = sameConstants.data();
wgpu::RenderPipeline samePipeline = device.CreateRenderPipeline(&desc);
std::vector<wgpu::ConstantEntry> otherConstants{{nullptr, "a", 1.0}};
desc.cFragment.constantCount = otherConstants.size();
desc.cFragment.constants = otherConstants.data();
wgpu::RenderPipeline otherPipeline1 = device.CreateRenderPipeline(&desc);
desc.cFragment.constantCount = 0;
desc.cFragment.constants = nullptr;
wgpu::RenderPipeline otherPipeline2 = device.CreateRenderPipeline(&desc);
EXPECT_NE(pipeline.Get(), otherPipeline1.Get());
EXPECT_NE(pipeline.Get(), otherPipeline2.Get());
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
}
// Test that Samplers are correctly deduplicated.
TEST_P(ObjectCachingTest, SamplerDeduplication) {
wgpu::SamplerDescriptor samplerDesc;

View File

@ -471,6 +471,127 @@ struct Buf {
EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), buffer, 0, kCount);
}
// Test one shader shared by two pipelines with different constants overridden
TEST_P(ShaderTests, OverridableConstantsSharedShader) {
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
std::vector<uint32_t> expected1{1};
wgpu::Buffer buffer1 = CreateBuffer(expected1.size());
std::vector<uint32_t> expected2{2};
wgpu::Buffer buffer2 = CreateBuffer(expected2.size());
std::string shader = R"(
override a: u32;
struct Buf {
data : array<u32, 1>
}
@group(0) @binding(0) var<storage, read_write> buf : Buf;
@compute @workgroup_size(1) fn main() {
buf.data[0] = a;
})";
std::vector<wgpu::ConstantEntry> constants1;
constants1.push_back({nullptr, "a", 1});
std::vector<wgpu::ConstantEntry> constants2;
constants2.push_back({nullptr, "a", 2});
wgpu::ComputePipeline pipeline1 = CreateComputePipeline(shader, "main", &constants1);
wgpu::ComputePipeline pipeline2 = CreateComputePipeline(shader, "main", &constants2);
wgpu::BindGroup bindGroup1 =
utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}});
wgpu::BindGroup bindGroup2 =
utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}});
wgpu::CommandBuffer commands;
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(pipeline1);
pass.SetBindGroup(0, bindGroup1);
pass.DispatchWorkgroups(1);
pass.SetPipeline(pipeline2);
pass.SetBindGroup(0, bindGroup2);
pass.DispatchWorkgroups(1);
pass.End();
commands = encoder.Finish();
}
queue.Submit(1, &commands);
EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, expected1.size());
EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, expected2.size());
}
// Test overridable constants work with workgroup size
TEST_P(ShaderTests, OverridableConstantsWorkgroupSize) {
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
std::string shader = R"(
override x: u32;
struct Buf {
data : array<u32, 1>
}
@group(0) @binding(0) var<storage, read_write> buf : Buf;
@compute @workgroup_size(x) fn main(
@builtin(local_invocation_id) local_invocation_id : vec3<u32>
) {
if (local_invocation_id.x >= x - 1) {
buf.data[0] = local_invocation_id.x + 1;
}
})";
const uint32_t workgroup_size_x_1 = 16u;
const uint32_t workgroup_size_x_2 = 64u;
std::vector<uint32_t> expected1{workgroup_size_x_1};
wgpu::Buffer buffer1 = CreateBuffer(expected1.size());
std::vector<uint32_t> expected2{workgroup_size_x_2};
wgpu::Buffer buffer2 = CreateBuffer(expected2.size());
std::vector<wgpu::ConstantEntry> constants1;
constants1.push_back({nullptr, "x", static_cast<double>(workgroup_size_x_1)});
std::vector<wgpu::ConstantEntry> constants2;
constants2.push_back({nullptr, "x", static_cast<double>(workgroup_size_x_2)});
wgpu::ComputePipeline pipeline1 = CreateComputePipeline(shader, "main", &constants1);
wgpu::ComputePipeline pipeline2 = CreateComputePipeline(shader, "main", &constants2);
wgpu::BindGroup bindGroup1 =
utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}});
wgpu::BindGroup bindGroup2 =
utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}});
wgpu::CommandBuffer commands;
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(pipeline1);
pass.SetBindGroup(0, bindGroup1);
pass.DispatchWorkgroups(1);
pass.SetPipeline(pipeline2);
pass.SetBindGroup(0, bindGroup2);
pass.DispatchWorkgroups(1);
pass.End();
commands = encoder.Finish();
}
queue.Submit(1, &commands);
EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, expected1.size());
EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, expected2.size());
}
// Test overridable constants with numeric identifiers
TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
@ -596,6 +717,7 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
std::string shader = R"(
@id(1001) override c1: u32;
@id(1002) override c2: u32;
@id(1003) override c3: u32;
struct Buf {
data : array<u32, 1>
@ -611,7 +733,7 @@ struct Buf {
buf.data[0] = c2;
}
@compute @workgroup_size(1) fn main3() {
@compute @workgroup_size(c3) fn main3() {
buf.data[0] = 3u;
}
)";
@ -620,6 +742,8 @@ struct Buf {
constants1.push_back({nullptr, "1001", 1});
std::vector<wgpu::ConstantEntry> constants2;
constants2.push_back({nullptr, "1002", 2});
std::vector<wgpu::ConstantEntry> constants3;
constants3.push_back({nullptr, "1003", 1});
wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, shader.c_str());
@ -640,6 +764,8 @@ struct Buf {
wgpu::ComputePipelineDescriptor csDesc3;
csDesc3.compute.module = shaderModule;
csDesc3.compute.entryPoint = "main3";
csDesc3.compute.constants = constants3.data();
csDesc3.compute.constantCount = constants3.size();
wgpu::ComputePipeline pipeline3 = device.CreateComputePipeline(&csDesc3);
wgpu::BindGroup bindGroup1 =
@ -765,8 +891,6 @@ TEST_P(ShaderTests, ConflictingBindingsDueToTransformOrder) {
device.CreateRenderPipeline(&desc);
}
// TODO(tint:1155): Test overridable constants used for workgroup size
DAWN_INSTANTIATE_TEST(ShaderTests,
D3D12Backend(),
MetalBackend(),

View File

@ -0,0 +1,395 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <numeric>
#include <string>
#include <vector>
#include "dawn/tests/DawnTest.h"
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
#include "dawn/utils/WGPUHelpers.h"
// The compute shader workgroup size is settled at compute pipeline creation time.
// The validation code in dawn is in each backend (not including Null backend) thus this test needs
// to be as part of a dawn_end2end_tests instead of the dawn_unittests
// TODO(dawn:1504): Add support for GL backend.
class WorkgroupSizeValidationTest : public DawnTest {
public:
wgpu::ShaderModule SetUpShaderWithValidDefaultValueConstants() {
return utils::CreateShaderModule(device, R"(
override x: u32 = 1u;
override y: u32 = 1u;
override z: u32 = 1u;
@compute @workgroup_size(x, y, z) fn main() {
_ = 0u;
})");
}
wgpu::ShaderModule SetUpShaderWithZeroDefaultValueConstants() {
return utils::CreateShaderModule(device, R"(
override x: u32 = 0u;
override y: u32 = 0u;
override z: u32 = 0u;
@compute @workgroup_size(x, y, z) fn main() {
_ = 0u;
})");
}
wgpu::ShaderModule SetUpShaderWithOutOfLimitsDefaultValueConstants() {
return utils::CreateShaderModule(device, R"(
override x: u32 = 1u;
override y: u32 = 1u;
override z: u32 = 9999u;
@compute @workgroup_size(x, y, z) fn main() {
_ = 0u;
})");
}
wgpu::ShaderModule SetUpShaderWithUninitializedConstants() {
return utils::CreateShaderModule(device, R"(
override x: u32;
override y: u32;
override z: u32;
@compute @workgroup_size(x, y, z) fn main() {
_ = 0u;
})");
}
wgpu::ShaderModule SetUpShaderWithPartialConstants() {
return utils::CreateShaderModule(device, R"(
override x: u32;
@compute @workgroup_size(x, 1, 1) fn main() {
_ = 0u;
})");
}
void TestCreatePipeline(const wgpu::ShaderModule& module) {
wgpu::ComputePipelineDescriptor csDesc;
csDesc.compute.module = module;
csDesc.compute.entryPoint = "main";
csDesc.compute.constants = nullptr;
csDesc.compute.constantCount = 0;
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
}
void TestCreatePipeline(const wgpu::ShaderModule& module,
const std::vector<wgpu::ConstantEntry>& constants) {
wgpu::ComputePipelineDescriptor csDesc;
csDesc.compute.module = module;
csDesc.compute.entryPoint = "main";
csDesc.compute.constants = constants.data();
csDesc.compute.constantCount = constants.size();
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
}
void TestInitializedWithZero(const wgpu::ShaderModule& module) {
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "x", 0}, {nullptr, "y", 0}, {nullptr, "z", 0}};
TestCreatePipeline(module, constants);
}
void TestInitializedWithOutOfLimitValue(const wgpu::ShaderModule& module) {
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "x",
static_cast<double>(GetSupportedLimits().limits.maxComputeWorkgroupSizeX + 1)},
{nullptr, "y", 1},
{nullptr, "z", 1}};
TestCreatePipeline(module, constants);
}
void TestInitializedWithValidValue(const wgpu::ShaderModule& module) {
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "x", 4}, {nullptr, "y", 4}, {nullptr, "z", 4}};
TestCreatePipeline(module, constants);
}
void TestInitializedPartially(const wgpu::ShaderModule& module) {
std::vector<wgpu::ConstantEntry> constants{{nullptr, "y", 4}};
TestCreatePipeline(module, constants);
}
wgpu::Buffer buffer;
};
// Test workgroup size validation with fixed values.
TEST_P(WorkgroupSizeValidationTest, WithFixedValues) {
auto CheckShaderWithWorkgroupSize = [this](bool success, uint32_t x, uint32_t y, uint32_t z) {
std::ostringstream ss;
ss << "@compute @workgroup_size(" << x << "," << y << "," << z << ") fn main() {}";
wgpu::ComputePipelineDescriptor desc;
desc.compute.entryPoint = "main";
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
if (success) {
device.CreateComputePipeline(&desc);
} else {
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
}
};
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
CheckShaderWithWorkgroupSize(true, 1, 1, 1);
CheckShaderWithWorkgroupSize(true, supportedLimits.maxComputeWorkgroupSizeX, 1, 1);
CheckShaderWithWorkgroupSize(true, 1, supportedLimits.maxComputeWorkgroupSizeY, 1);
CheckShaderWithWorkgroupSize(true, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ);
CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX + 1, 1, 1);
CheckShaderWithWorkgroupSize(false, 1, supportedLimits.maxComputeWorkgroupSizeY + 1, 1);
CheckShaderWithWorkgroupSize(false, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ + 1);
// No individual dimension exceeds its limit, but the combined size should definitely exceed the
// total invocation limit.
DAWN_ASSERT(supportedLimits.maxComputeWorkgroupSizeX *
supportedLimits.maxComputeWorkgroupSizeY *
supportedLimits.maxComputeWorkgroupSizeZ >
supportedLimits.maxComputeInvocationsPerWorkgroup);
CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX,
supportedLimits.maxComputeWorkgroupSizeY,
supportedLimits.maxComputeWorkgroupSizeZ);
}
// Test workgroup size validation with fixed values (storage size limits validation).
TEST_P(WorkgroupSizeValidationTest, WithFixedValuesStorageSizeLimits) {
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
constexpr uint32_t kVec4Size = 16;
const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size;
constexpr uint32_t kMat4Size = 64;
const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size;
auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count,
uint32_t mat4_count) {
std::ostringstream ss;
std::ostringstream body;
if (vec4_count > 0) {
ss << "var<workgroup> vec4_data: array<vec4<f32>, " << vec4_count << ">;";
body << "_ = vec4_data;";
}
if (mat4_count > 0) {
ss << "var<workgroup> mat4_data: array<mat4x4<f32>, " << mat4_count << ">;";
body << "_ = mat4_data;";
}
ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }";
wgpu::ComputePipelineDescriptor desc;
desc.compute.entryPoint = "main";
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
if (success) {
device.CreateComputePipeline(&desc);
} else {
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
}
};
CheckPipelineWithWorkgroupStorage(true, 1, 1);
CheckPipelineWithWorkgroupStorage(true, maxVec4Count, 0);
CheckPipelineWithWorkgroupStorage(true, 0, maxMat4Count);
CheckPipelineWithWorkgroupStorage(true, maxVec4Count - 4, 1);
CheckPipelineWithWorkgroupStorage(true, 4, maxMat4Count - 1);
CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0);
CheckPipelineWithWorkgroupStorage(false, maxVec4Count - 3, 1);
CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1);
CheckPipelineWithWorkgroupStorage(false, 4, maxMat4Count);
}
// Test workgroup size validation with valid overrides default values.
TEST_P(WorkgroupSizeValidationTest, OverridesWithValidDefault) {
wgpu::ShaderModule module = SetUpShaderWithValidDefaultValueConstants();
{
// Valid default
TestCreatePipeline(module);
}
{
// Error: invalid value (zero)
ASSERT_DEVICE_ERROR(TestInitializedWithZero(module));
}
{
// Error: invalid value (out of device limits)
ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module));
}
{
// Valid: initialized partially
TestInitializedPartially(module);
}
{
// Valid
TestInitializedWithValidValue(module);
}
}
// Test workgroup size validation with zero as the overrides default values.
TEST_P(WorkgroupSizeValidationTest, OverridesWithZeroDefault) {
// Error: zero is detected as invalid at shader creation time
ASSERT_DEVICE_ERROR(SetUpShaderWithZeroDefaultValueConstants());
}
// Test workgroup size validation with out-of-limits overrides default values.
TEST_P(WorkgroupSizeValidationTest, OverridesWithOutOfLimitsDefault) {
wgpu::ShaderModule module = SetUpShaderWithOutOfLimitsDefaultValueConstants();
{
// Error: invalid default
ASSERT_DEVICE_ERROR(TestCreatePipeline(module));
}
{
// Error: invalid value (zero)
ASSERT_DEVICE_ERROR(TestInitializedWithZero(module));
}
{
// Error: invalid value (out of device limits)
ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module));
}
{
// Error: initialized partially
ASSERT_DEVICE_ERROR(TestInitializedPartially(module));
}
{
// Valid
TestInitializedWithValidValue(module);
}
}
// Test workgroup size validation without overrides default values specified.
TEST_P(WorkgroupSizeValidationTest, OverridesWithUninitialized) {
wgpu::ShaderModule module = SetUpShaderWithUninitializedConstants();
{
// Error: uninitialized
ASSERT_DEVICE_ERROR(TestCreatePipeline(module));
}
{
// Error: invalid value (zero)
ASSERT_DEVICE_ERROR(TestInitializedWithZero(module));
}
{
// Error: invalid value (out of device limits)
ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module));
}
{
// Error: initialized partially
ASSERT_DEVICE_ERROR(TestInitializedPartially(module));
}
{
// Valid
TestInitializedWithValidValue(module);
}
}
// Test workgroup size validation with only partial dimensions are overrides.
TEST_P(WorkgroupSizeValidationTest, PartialOverrides) {
wgpu::ShaderModule module = SetUpShaderWithPartialConstants();
{
// Error: uninitialized
ASSERT_DEVICE_ERROR(TestCreatePipeline(module));
}
{
// Error: invalid value (zero)
std::vector<wgpu::ConstantEntry> constants{{nullptr, "x", 0}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
}
{
// Error: invalid value (out of device limits)
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "x",
static_cast<double>(GetSupportedLimits().limits.maxComputeWorkgroupSizeX + 1)}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
}
{
// Valid
std::vector<wgpu::ConstantEntry> constants{{nullptr, "x", 16}};
TestCreatePipeline(module, constants);
}
}
// Test workgroup size validation after being overrided with invalid values.
TEST_P(WorkgroupSizeValidationTest, ValidationAfterOverride) {
wgpu::ShaderModule module = SetUpShaderWithUninitializedConstants();
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
{
// Error: exceed maxComputeWorkgroupSizeZ
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "x", 1},
{nullptr, "y", 1},
{nullptr, "z", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeZ + 1)}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
}
{
// Error: exceed maxComputeInvocationsPerWorkgroup
DAWN_ASSERT(supportedLimits.maxComputeWorkgroupSizeX *
supportedLimits.maxComputeWorkgroupSizeY *
supportedLimits.maxComputeWorkgroupSizeZ >
supportedLimits.maxComputeInvocationsPerWorkgroup);
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "x", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeX)},
{nullptr, "y", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeY)},
{nullptr, "z", static_cast<double>(supportedLimits.maxComputeWorkgroupSizeZ)}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants));
}
}
// Test workgroup size validation after being overrided with invalid values (storage size limits
// validation).
// TODO(tint:1660): re-enable after override can be used as array size.
TEST_P(WorkgroupSizeValidationTest, DISABLED_ValidationAfterOverrideStorageSize) {
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
constexpr uint32_t kVec4Size = 16;
const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size;
constexpr uint32_t kMat4Size = 64;
const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size;
auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count,
uint32_t mat4_count) {
std::ostringstream ss;
std::ostringstream body;
ss << "override a: u32;";
ss << "override b: u32;";
if (vec4_count > 0) {
ss << "var<workgroup> vec4_data: array<vec4<f32>, a>;";
body << "_ = vec4_data;";
}
if (mat4_count > 0) {
ss << "var<workgroup> mat4_data: array<mat4x4<f32>, b>;";
body << "_ = mat4_data;";
}
ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }";
wgpu::ComputePipelineDescriptor desc;
desc.compute.entryPoint = "main";
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
std::vector<wgpu::ConstantEntry> constants{{nullptr, "a", static_cast<double>(vec4_count)},
{nullptr, "b", static_cast<double>(mat4_count)}};
desc.compute.constants = constants.data();
desc.compute.constantCount = constants.size();
if (success) {
device.CreateComputePipeline(&desc);
} else {
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
}
};
CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0);
CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1);
}
DAWN_INSTANTIATE_TEST(WorkgroupSizeValidationTest, D3D12Backend(), MetalBackend(), VulkanBackend());

View File

@ -0,0 +1,81 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "dawn/native/ObjectContentHasher.h"
#include "dawn/tests/DawnNativeTest.h"
namespace dawn::native {
namespace {
class ObjectContentHasherTests : public DawnNativeTest {};
#define EXPECT_IF_HASH_EQ(eq, a, b) \
do { \
ObjectContentHasher ra, rb; \
ra.Record(a); \
rb.Record(b); \
EXPECT_EQ(eq, ra.GetContentHash() == rb.GetContentHash()); \
} while (0)
TEST(ObjectContentHasherTests, Pair) {
EXPECT_IF_HASH_EQ(true, (std::pair<std::string, uint8_t>{"a", 1}),
(std::pair<std::string, uint8_t>{"a", 1}));
EXPECT_IF_HASH_EQ(false, (std::pair<uint8_t, std::string>{1, "a"}),
(std::pair<std::string, uint8_t>{"a", 1}));
EXPECT_IF_HASH_EQ(false, (std::pair<std::string, uint8_t>{"a", 1}),
(std::pair<std::string, uint8_t>{"a", 2}));
EXPECT_IF_HASH_EQ(false, (std::pair<std::string, uint8_t>{"a", 1}),
(std::pair<std::string, uint8_t>{"b", 1}));
}
TEST(ObjectContentHasherTests, Vector) {
EXPECT_IF_HASH_EQ(true, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{0, 1}));
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{0, 1, 2}));
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{1, 0}));
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<uint8_t>{}));
EXPECT_IF_HASH_EQ(false, (std::vector<uint8_t>{0, 1}), (std::vector<float>{0, 1}));
}
TEST(ObjectContentHasherTests, Map) {
EXPECT_IF_HASH_EQ(true, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
(std::map<std::string, uint8_t>{{"b", 2}, {"a", 1}}));
EXPECT_IF_HASH_EQ(false, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
(std::map<std::string, uint8_t>{{"a", 2}, {"b", 1}}));
EXPECT_IF_HASH_EQ(false, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
(std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}, {"c", 1}}));
EXPECT_IF_HASH_EQ(false, (std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}}),
(std::map<std::string, uint8_t>{}));
}
TEST(ObjectContentHasherTests, HashCombine) {
ObjectContentHasher ra, rb;
ra.Record(std::vector<uint8_t>{0, 1});
ra.Record(std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}});
rb.Record(std::map<std::string, uint8_t>{{"a", 1}, {"b", 2}});
rb.Record(std::vector<uint8_t>{0, 1});
EXPECT_NE(ra.GetContentHash(), rb.GetContentHash());
}
} // namespace
} // namespace dawn::native

View File

@ -450,87 +450,6 @@ TEST_F(ShaderModuleValidationTest, MaximumInterStageShaderComponents) {
}
}
// Tests that we validate workgroup size limits.
TEST_F(ShaderModuleValidationTest, ComputeWorkgroupSizeLimits) {
auto CheckShaderWithWorkgroupSize = [this](bool success, uint32_t x, uint32_t y, uint32_t z) {
std::ostringstream ss;
ss << "@compute @workgroup_size(" << x << "," << y << "," << z << ") fn main() {}";
wgpu::ComputePipelineDescriptor desc;
desc.compute.entryPoint = "main";
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
if (success) {
device.CreateComputePipeline(&desc);
} else {
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
}
};
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
CheckShaderWithWorkgroupSize(true, 1, 1, 1);
CheckShaderWithWorkgroupSize(true, supportedLimits.maxComputeWorkgroupSizeX, 1, 1);
CheckShaderWithWorkgroupSize(true, 1, supportedLimits.maxComputeWorkgroupSizeY, 1);
CheckShaderWithWorkgroupSize(true, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ);
CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX + 1, 1, 1);
CheckShaderWithWorkgroupSize(false, 1, supportedLimits.maxComputeWorkgroupSizeY + 1, 1);
CheckShaderWithWorkgroupSize(false, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ + 1);
// No individual dimension exceeds its limit, but the combined size should definitely exceed the
// total invocation limit.
CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX,
supportedLimits.maxComputeWorkgroupSizeY,
supportedLimits.maxComputeWorkgroupSizeZ);
}
// Tests that we validate workgroup storage size limits.
TEST_F(ShaderModuleValidationTest, ComputeWorkgroupStorageSizeLimits) {
wgpu::Limits supportedLimits = GetSupportedLimits().limits;
constexpr uint32_t kVec4Size = 16;
const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size;
constexpr uint32_t kMat4Size = 64;
const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size;
auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count,
uint32_t mat4_count) {
std::ostringstream ss;
std::ostringstream body;
if (vec4_count > 0) {
ss << "var<workgroup> vec4_data: array<vec4<f32>, " << vec4_count << ">;";
body << "_ = vec4_data;";
}
if (mat4_count > 0) {
ss << "var<workgroup> mat4_data: array<mat4x4<f32>, " << mat4_count << ">;";
body << "_ = mat4_data;";
}
ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }";
wgpu::ComputePipelineDescriptor desc;
desc.compute.entryPoint = "main";
desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str());
if (success) {
device.CreateComputePipeline(&desc);
} else {
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
}
};
CheckPipelineWithWorkgroupStorage(true, 1, 1);
CheckPipelineWithWorkgroupStorage(true, maxVec4Count, 0);
CheckPipelineWithWorkgroupStorage(true, 0, maxMat4Count);
CheckPipelineWithWorkgroupStorage(true, maxVec4Count - 4, 1);
CheckPipelineWithWorkgroupStorage(true, 4, maxMat4Count - 1);
CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0);
CheckPipelineWithWorkgroupStorage(false, maxVec4Count - 3, 1);
CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1);
CheckPipelineWithWorkgroupStorage(false, 4, maxMat4Count);
}
// Test that numeric ID must be unique
TEST_F(ShaderModuleValidationTest, OverridableConstantsNumericIDConflicts) {
ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"(

View File

@ -37,46 +37,6 @@ class UnsafeAPIValidationTest : public ValidationTest {
}
};
// Check that pipeline overridable constants are disallowed as part of unsafe APIs.
// TODO(dawn:1041) Remove when implementation for all backend is added
TEST_F(UnsafeAPIValidationTest, PipelineOverridableConstants) {
// Create the placeholder compute pipeline.
wgpu::ComputePipelineDescriptor pipelineDescBase;
pipelineDescBase.compute.entryPoint = "main";
// Control case: shader without overridable constant is allowed.
{
wgpu::ComputePipelineDescriptor pipelineDesc = pipelineDescBase;
pipelineDesc.compute.module =
utils::CreateShaderModule(device, "@compute @workgroup_size(1) fn main() {}");
device.CreateComputePipeline(&pipelineDesc);
}
// Error case: shader with overridable constant with default value
{
ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"(
@id(1000) override c0: u32 = 1u;
@id(1000) override c1: u32;
@compute @workgroup_size(1) fn main() {
_ = c0;
_ = c1;
})"));
}
// Error case: pipeline stage with constant entry is disallowed
{
wgpu::ComputePipelineDescriptor pipelineDesc = pipelineDescBase;
pipelineDesc.compute.module =
utils::CreateShaderModule(device, "@compute @workgroup_size(1) fn main() {}");
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c", 1u}};
pipelineDesc.compute.constants = constants.data();
pipelineDesc.compute.constantCount = constants.size();
ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&pipelineDesc));
}
}
class UnsafeQueryAPIValidationTest : public ValidationTest {
protected:
WGPUDevice CreateTestDevice(dawn::native::Adapter dawnAdapter) override {

View File

@ -133,17 +133,13 @@ Inspector::Inspector(const Program* program) : program_(program) {}
Inspector::~Inspector() = default;
std::vector<EntryPoint> Inspector::GetEntryPoints() {
std::vector<EntryPoint> result;
for (auto* func : program_->AST().Functions()) {
if (!func->IsEntryPoint()) {
continue;
}
EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) {
EntryPoint entry_point;
TINT_ASSERT(Inspector, func != nullptr);
TINT_ASSERT(Inspector, func->IsEntryPoint());
auto* sem = program_->Sem().Get(func);
EntryPoint entry_point;
entry_point.name = program_->Symbols().NameFor(func->symbol);
entry_point.remapped_name = program_->Symbols().NameFor(func->symbol);
@ -154,8 +150,7 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
auto wgsize = sem->WorkgroupSize();
if (!wgsize[0].overridable_const && !wgsize[1].overridable_const &&
!wgsize[2].overridable_const) {
entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value,
wgsize[2].value};
entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value, wgsize[2].value};
}
break;
}
@ -231,7 +226,26 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
}
}
result.push_back(std::move(entry_point));
return entry_point;
}
EntryPoint Inspector::GetEntryPoint(const std::string& entry_point_name) {
auto* func = FindEntryPointByName(entry_point_name);
if (!func) {
return EntryPoint();
}
return GetEntryPoint(func);
}
std::vector<EntryPoint> Inspector::GetEntryPoints() {
std::vector<EntryPoint> result;
for (auto* func : program_->AST().Functions()) {
if (!func->IsEntryPoint()) {
continue;
}
result.push_back(GetEntryPoint(func));
}
return result;

View File

@ -55,6 +55,10 @@ class Inspector {
/// @returns vector of entry point information
std::vector<EntryPoint> GetEntryPoints();
/// @param entry_point name of the entry point to get information about
/// @returns the entry point information
EntryPoint GetEntryPoint(const std::string& entry_point);
/// @returns map of override identifier to initial value
std::map<OverrideId, Scalar> GetOverrideDefaultValues();
@ -230,6 +234,10 @@ class Inspector {
/// whenever a set of expressions are resolved to globals.
template <size_t N, typename F>
void GetOriginatingResources(std::array<const ast::Expression*, N> exprs, F&& cb);
/// @param func the function of the entry point. Must be non-nullptr and true for IsEntryPoint()
/// @returns the entry point information
EntryPoint GetEntryPoint(const tint::ast::Function* func);
};
} // namespace tint::inspector

View File

@ -20,6 +20,7 @@
#include "tint/override_id.h"
#include "src/tint/reflection.h"
#include "src/tint/transform/transform.h"
namespace tint::transform {
@ -63,6 +64,9 @@ class SubstituteOverride final : public Castable<SubstituteOverride, Transform>
/// The value is always a double coming into the transform and will be
/// converted to the correct type through and initializer.
std::unordered_map<OverrideId, double> map;
/// Reflect the fields of this class so that it can be used by tint::ForeachField()
TINT_REFLECT(map);
};
/// Constructor