Cleanup transform usage

Use tint::transform::DataMap for inputs as well as outputs.

This allows tint to nest transforms inside each other (e.g. embedding
transforms inside sanitizers), and still having a consistent way to pass
data in and out of these transforms, regardless of nesting depth.

Transforms can also now be fully pre-built and used multiple times as
there is no state held by the transform itself.

Bug: tint:389

Change-Id: If1616c77f2776be449021a32f4a6b0b89159aa2a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/48060
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
Auto-Submit: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-04-19 19:42:19 +00:00 committed by Commit Bot service account
parent eae70b75ae
commit 8091c68450
8 changed files with 91 additions and 81 deletions

View File

@ -1110,11 +1110,13 @@ namespace dawn_native {
parseResult->tintSource = std::move(tintSource); parseResult->tintSource = std::move(tintSource);
} else { } else {
tint::transform::Manager transformManager; tint::transform::Manager transformManager;
transformManager.append( transformManager.Add<tint::transform::EmitVertexPointSize>();
std::make_unique<tint::transform::EmitVertexPointSize>()); transformManager.Add<tint::transform::Spirv>();
transformManager.append(std::make_unique<tint::transform::Spirv>());
DAWN_TRY_ASSIGN(program, tint::transform::DataMap transformInputs;
RunTransforms(&transformManager, &program, outMessages));
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, &program,
transformInputs, nullptr, outMessages));
std::vector<uint32_t> spirv; std::vector<uint32_t> spirv;
DAWN_TRY_ASSIGN(spirv, ModuleToSPIRV(&program)); DAWN_TRY_ASSIGN(spirv, ModuleToSPIRV(&program));
@ -1144,8 +1146,10 @@ namespace dawn_native {
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform, ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
const tint::Program* program, const tint::Program* program,
const tint::transform::DataMap& inputs,
tint::transform::DataMap* outputs,
OwnedCompilationMessages* outMessages) { OwnedCompilationMessages* outMessages) {
tint::transform::Transform::Output output = transform->Run(program); tint::transform::Output output = transform->Run(program, inputs);
if (outMessages != nullptr) { if (outMessages != nullptr) {
outMessages->AddMessages(output.program.Diagnostics()); outMessages->AddMessages(output.program.Diagnostics());
} }
@ -1153,13 +1157,16 @@ namespace dawn_native {
std::string err = "Tint program failure: " + output.program.Diagnostics().str(); std::string err = "Tint program failure: " + output.program.Diagnostics().str();
return DAWN_VALIDATION_ERROR(err.c_str()); return DAWN_VALIDATION_ERROR(err.c_str());
} }
if (outputs != nullptr) {
*outputs = std::move(output.data);
}
return std::move(output.program); return std::move(output.program);
} }
std::unique_ptr<tint::transform::VertexPulling> MakeVertexPullingTransform( void AddVertexPullingTransformConfig(const VertexState& vertexState,
const VertexState& vertexState, const std::string& entryPoint,
const std::string& entryPoint, BindGroupIndex pullingBufferBindingSet,
BindGroupIndex pullingBufferBindingSet) { tint::transform::DataMap* transformInputs) {
tint::transform::VertexPulling::Config cfg; tint::transform::VertexPulling::Config cfg;
cfg.entry_point_name = entryPoint; cfg.entry_point_name = entryPoint;
cfg.pulling_group = static_cast<uint32_t>(pullingBufferBindingSet); cfg.pulling_group = static_cast<uint32_t>(pullingBufferBindingSet);
@ -1181,7 +1188,7 @@ namespace dawn_native {
cfg.vertex_state.push_back(std::move(layout)); cfg.vertex_state.push_back(std::move(layout));
} }
return std::make_unique<tint::transform::VertexPulling>(cfg); transformInputs->Add<tint::transform::VertexPulling::Config>(cfg);
} }
MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device, MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
@ -1314,19 +1321,23 @@ namespace dawn_native {
errorStream << "Tint vertex pulling failure:" << std::endl; errorStream << "Tint vertex pulling failure:" << std::endl;
tint::transform::Manager transformManager; tint::transform::Manager transformManager;
transformManager.append( transformManager.Add<tint::transform::VertexPulling>();
MakeVertexPullingTransform(vertexState, entryPoint, pullingBufferBindingSet)); transformManager.Add<tint::transform::EmitVertexPointSize>();
transformManager.append(std::make_unique<tint::transform::EmitVertexPointSize>()); transformManager.Add<tint::transform::Spirv>();
transformManager.append(std::make_unique<tint::transform::Spirv>());
if (GetDevice()->IsRobustnessEnabled()) { if (GetDevice()->IsRobustnessEnabled()) {
transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>()); transformManager.Add<tint::transform::BoundArrayAccessors>();
} }
tint::transform::DataMap transformInputs;
AddVertexPullingTransformConfig(vertexState, entryPoint, pullingBufferBindingSet,
&transformInputs);
// A nullptr is passed in for the CompilationMessages here since this method is called // A nullptr is passed in for the CompilationMessages here since this method is called
// during RenderPipeline creation, by which point the shader module's CompilationInfo may // during RenderPipeline creation, by which point the shader module's CompilationInfo
// have already been queried. // may have already been queried.
tint::Program program; tint::Program program;
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, programIn, nullptr)); DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, programIn, transformInputs,
nullptr, nullptr));
tint::writer::spirv::Generator generator(&program); tint::writer::spirv::Generator generator(&program);
if (!generator.Generate()) { if (!generator.Generate()) {

View File

@ -37,6 +37,7 @@ namespace tint {
class Program; class Program;
namespace transform { namespace transform {
class DataMap;
class Transform; class Transform;
class VertexPulling; class VertexPulling;
} // namespace transform } // namespace transform
@ -88,12 +89,15 @@ namespace dawn_native {
const PipelineLayoutBase* layout); const PipelineLayoutBase* layout);
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform, ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
const tint::Program* program, const tint::Program* program,
const tint::transform::DataMap& inputs,
tint::transform::DataMap* outputs,
OwnedCompilationMessages* messages); OwnedCompilationMessages* messages);
std::unique_ptr<tint::transform::VertexPulling> MakeVertexPullingTransform( /// Creates and adds the tint::transform::VertexPulling::Config to transformInputs.
const VertexState& vertexState, void AddVertexPullingTransformConfig(const VertexState& vertexState,
const std::string& entryPoint, const std::string& entryPoint,
BindGroupIndex pullingBufferBindingSet); BindGroupIndex pullingBufferBindingSet,
tint::transform::DataMap* transformInputs);
// Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are // Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are
// stored in the ShaderModuleBase and destroyed only when the shader program is destroyed so // stored in the ShaderModuleBase and destroyed only when the shader program is destroyed so
@ -173,7 +177,7 @@ namespace dawn_native {
const std::string& entryPoint, const std::string& entryPoint,
BindGroupIndex pullingBufferBindingSet) const; BindGroupIndex pullingBufferBindingSet) const;
OwnedCompilationMessages* CompilationMessages() { OwnedCompilationMessages* GetCompilationMessages() {
return mCompilationMessages.get(); return mCompilationMessages.get();
} }

View File

@ -194,7 +194,7 @@ namespace dawn_native { namespace d3d12 {
SingleShaderStage stage, SingleShaderStage stage,
PipelineLayout* layout, PipelineLayout* layout,
std::string* remappedEntryPointName, std::string* remappedEntryPointName,
FirstOffsetInfo* firstOffsetInfo) const { FirstOffsetInfo* firstOffsetInfo) {
ASSERT(!IsError()); ASSERT(!IsError());
ScopedTintICEHandler scopedICEHandler(GetDevice()); ScopedTintICEHandler scopedICEHandler(GetDevice());
@ -245,30 +245,28 @@ namespace dawn_native { namespace d3d12 {
errorStream << "Tint HLSL failure:" << std::endl; errorStream << "Tint HLSL failure:" << std::endl;
tint::transform::Manager transformManager; tint::transform::Manager transformManager;
transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
if (stage == SingleShaderStage::Vertex) {
transformManager.append(std::make_unique<tint::transform::FirstIndexOffset>(
layout->GetFirstIndexOffsetShaderRegister(),
layout->GetFirstIndexOffsetRegisterSpace()));
}
transformManager.append(std::make_unique<tint::transform::BindingRemapper>());
transformManager.append(std::make_unique<tint::transform::Renamer>());
transformManager.append(std::make_unique<tint::transform::Hlsl>());
tint::transform::DataMap transformInputs; tint::transform::DataMap transformInputs;
transformManager.Add<tint::transform::BoundArrayAccessors>();
if (stage == SingleShaderStage::Vertex) {
transformManager.Add<tint::transform::FirstIndexOffset>();
transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(
layout->GetFirstIndexOffsetShaderRegister(),
layout->GetFirstIndexOffsetRegisterSpace());
}
transformManager.Add<tint::transform::BindingRemapper>();
transformManager.Add<tint::transform::Renamer>();
transformManager.Add<tint::transform::Hlsl>();
transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints), transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
std::move(accessControls)); std::move(accessControls));
tint::transform::Transform::Output output =
transformManager.Run(GetTintProgram(), transformInputs);
const tint::Program& program = output.program; tint::Program program;
if (!program.IsValid()) { tint::transform::DataMap transformOutputs;
errorStream << "Tint program transform error: " << program.Diagnostics().str() DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
<< std::endl; &transformOutputs, nullptr));
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
if (auto* data = output.data.Get<tint::transform::FirstIndexOffset::Data>()) { if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
firstOffsetInfo->usesVertexIndex = data->has_vertex_index; firstOffsetInfo->usesVertexIndex = data->has_vertex_index;
if (firstOffsetInfo->usesVertexIndex) { if (firstOffsetInfo->usesVertexIndex) {
firstOffsetInfo->vertexIndexOffset = data->first_vertex_offset; firstOffsetInfo->vertexIndexOffset = data->first_vertex_offset;
@ -279,7 +277,7 @@ namespace dawn_native { namespace d3d12 {
} }
} }
if (auto* data = output.data.Get<tint::transform::Renamer::Data>()) { if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
auto it = data->remappings.find(entryPointName); auto it = data->remappings.find(entryPointName);
if (it == data->remappings.end()) { if (it == data->remappings.end()) {
return DAWN_VALIDATION_ERROR("Could not find remapped name for entry point."); return DAWN_VALIDATION_ERROR("Could not find remapped name for entry point.");

View File

@ -63,7 +63,7 @@ namespace dawn_native { namespace d3d12 {
SingleShaderStage stage, SingleShaderStage stage,
PipelineLayout* layout, PipelineLayout* layout,
std::string* remappedEntryPointName, std::string* remappedEntryPointName,
FirstOffsetInfo* firstOffsetInfo) const; FirstOffsetInfo* firstOffsetInfo);
ResultOrError<std::string> TranslateToHLSLWithSPIRVCross(const char* entryPointName, ResultOrError<std::string> TranslateToHLSLWithSPIRVCross(const char* entryPointName,
SingleShaderStage stage, SingleShaderStage stage,

View File

@ -70,10 +70,12 @@ namespace dawn_native { namespace metal {
errorStream << "Tint MSL failure:" << std::endl; errorStream << "Tint MSL failure:" << std::endl;
tint::transform::Manager transformManager; tint::transform::Manager transformManager;
tint::transform::DataMap transformInputs;
if (stage == SingleShaderStage::Vertex && if (stage == SingleShaderStage::Vertex &&
GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) { GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
transformManager.append( AddVertexPullingTransformConfig(*vertexState, entryPointName, kPullingBufferBindingSet,
MakeVertexPullingTransform(*vertexState, entryPointName, kPullingBufferBindingSet)); &transformInputs);
for (VertexBufferSlot slot : for (VertexBufferSlot slot :
IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) { IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
@ -83,20 +85,16 @@ namespace dawn_native { namespace metal {
// this MSL buffer index. // this MSL buffer index.
} }
} }
transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>()); transformManager.Add<tint::transform::BoundArrayAccessors>();
transformManager.append(std::make_unique<tint::transform::Renamer>()); transformManager.Add<tint::transform::Renamer>();
transformManager.append(std::make_unique<tint::transform::Msl>()); transformManager.Add<tint::transform::Msl>();
tint::transform::Transform::Output output = transformManager.Run(GetTintProgram()); tint::Program program;
tint::transform::DataMap transformOutputs;
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
&transformOutputs, nullptr));
tint::Program& program = output.program; if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
if (!program.IsValid()) {
errorStream << "Tint program transform error: " << program.Diagnostics().str()
<< std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
if (auto* data = output.data.Get<tint::transform::Renamer::Data>()) {
auto it = data->remappings.find(entryPointName); auto it = data->remappings.find(entryPointName);
if (it == data->remappings.end()) { if (it == data->remappings.end()) {
return DAWN_VALIDATION_ERROR("Could not find remapped name for entry point."); return DAWN_VALIDATION_ERROR("Could not find remapped name for entry point.");

View File

@ -87,9 +87,12 @@ namespace dawn_native { namespace opengl {
tint::transform::Manager transformManager; tint::transform::Manager transformManager;
transformManager.append(std::make_unique<tint::transform::Spirv>()); transformManager.append(std::make_unique<tint::transform::Spirv>());
tint::transform::DataMap transformInputs;
tint::Program program; tint::Program program;
DAWN_TRY_ASSIGN( DAWN_TRY_ASSIGN(program,
program, RunTransforms(&transformManager, GetTintProgram(), CompilationMessages())); RunTransforms(&transformManager, GetTintProgram(), transformInputs,
nullptr, GetCompilationMessages()));
tint::writer::spirv::Generator generator(&program); tint::writer::spirv::Generator generator(&program);
if (!generator.Generate()) { if (!generator.Generate()) {

View File

@ -55,14 +55,16 @@ namespace dawn_native { namespace vulkan {
errorStream << "Tint SPIR-V writer failure:" << std::endl; errorStream << "Tint SPIR-V writer failure:" << std::endl;
tint::transform::Manager transformManager; tint::transform::Manager transformManager;
transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>()); transformManager.Add<tint::transform::BoundArrayAccessors>();
transformManager.append(std::make_unique<tint::transform::EmitVertexPointSize>()); transformManager.Add<tint::transform::EmitVertexPointSize>();
transformManager.append(std::make_unique<tint::transform::Spirv>()); transformManager.Add<tint::transform::Spirv>();
tint::transform::DataMap transformInputs;
tint::Program program; tint::Program program;
DAWN_TRY_ASSIGN(program, DAWN_TRY_ASSIGN(program,
RunTransforms(&transformManager, parseResult->tintProgram.get(), RunTransforms(&transformManager, parseResult->tintProgram.get(),
CompilationMessages())); transformInputs, nullptr, GetCompilationMessages()));
tint::writer::spirv::Generator generator(&program); tint::writer::spirv::Generator generator(&program);
if (!generator.Generate()) { if (!generator.Generate()) {
@ -166,15 +168,10 @@ namespace dawn_native { namespace vulkan {
tint::transform::DataMap transformInputs; tint::transform::DataMap transformInputs;
transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints), transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
std::move(accessControls)); std::move(accessControls));
tint::transform::Transform::Output output =
transformManager.Run(GetTintProgram(), transformInputs);
const tint::Program& program = output.program; tint::Program program;
if (!program.IsValid()) { DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
errorStream << "Tint program transform error: " << program.Diagnostics().str() nullptr, nullptr));
<< std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::writer::spirv::Generator generator(&program); tint::writer::spirv::Generator generator(&program);
if (!generator.Generate()) { if (!generator.Generate()) {

View File

@ -160,7 +160,7 @@ TEST_F(ShaderModuleValidationTest, MultisampledArrayTexture) {
} }
// Tests that shader module compilation messages can be queried. // Tests that shader module compilation messages can be queried.
TEST_F(ShaderModuleValidationTest, CompilationMessages) { TEST_F(ShaderModuleValidationTest, GetCompilationMessages) {
// This test works assuming ShaderModule is backed by a dawn_native::ShaderModuleBase, which // This test works assuming ShaderModule is backed by a dawn_native::ShaderModuleBase, which
// is not the case on the wire. // is not the case on the wire.
DAWN_SKIP_TEST_IF(UsesWire()); DAWN_SKIP_TEST_IF(UsesWire());
@ -172,12 +172,11 @@ TEST_F(ShaderModuleValidationTest, CompilationMessages) {
dawn_native::ShaderModuleBase* shaderModuleBase = dawn_native::ShaderModuleBase* shaderModuleBase =
reinterpret_cast<dawn_native::ShaderModuleBase*>(shaderModule.Get()); reinterpret_cast<dawn_native::ShaderModuleBase*>(shaderModule.Get());
shaderModuleBase->CompilationMessages()->ClearMessages(); dawn_native::OwnedCompilationMessages* messages = shaderModuleBase->GetCompilationMessages();
shaderModuleBase->CompilationMessages()->AddMessage("Info Message"); messages->ClearMessages();
shaderModuleBase->CompilationMessages()->AddMessage("Warning Message", messages->AddMessage("Info Message");
wgpu::CompilationMessageType::Warning); messages->AddMessage("Warning Message", wgpu::CompilationMessageType::Warning);
shaderModuleBase->CompilationMessages()->AddMessage("Error Message", messages->AddMessage("Error Message", wgpu::CompilationMessageType::Error, 3, 4);
wgpu::CompilationMessageType::Error, 3, 4);
auto callback = [](WGPUCompilationInfoRequestStatus status, const WGPUCompilationInfo* info, auto callback = [](WGPUCompilationInfoRequestStatus status, const WGPUCompilationInfo* info,
void* userdata) { void* userdata) {