Cleanup transform usage

Use tint::transform::DataMap for inputs as well as outputs.

This allows tint to nest transforms inside each other (e.g. embedding
transforms inside sanitizers), and still having a consistent way to pass
data in and out of these transforms, regardless of nesting depth.

Transforms can also now be fully pre-built and used multiple times as
there is no state held by the transform itself.

Bug: tint:389

Change-Id: If1616c77f2776be449021a32f4a6b0b89159aa2a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/48060
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
Auto-Submit: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-04-19 19:42:19 +00:00 committed by Commit Bot service account
parent eae70b75ae
commit 8091c68450
8 changed files with 91 additions and 81 deletions

View File

@ -1110,11 +1110,13 @@ namespace dawn_native {
parseResult->tintSource = std::move(tintSource);
} else {
tint::transform::Manager transformManager;
transformManager.append(
std::make_unique<tint::transform::EmitVertexPointSize>());
transformManager.append(std::make_unique<tint::transform::Spirv>());
DAWN_TRY_ASSIGN(program,
RunTransforms(&transformManager, &program, outMessages));
transformManager.Add<tint::transform::EmitVertexPointSize>();
transformManager.Add<tint::transform::Spirv>();
tint::transform::DataMap transformInputs;
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, &program,
transformInputs, nullptr, outMessages));
std::vector<uint32_t> spirv;
DAWN_TRY_ASSIGN(spirv, ModuleToSPIRV(&program));
@ -1144,8 +1146,10 @@ namespace dawn_native {
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
const tint::Program* program,
const tint::transform::DataMap& inputs,
tint::transform::DataMap* outputs,
OwnedCompilationMessages* outMessages) {
tint::transform::Transform::Output output = transform->Run(program);
tint::transform::Output output = transform->Run(program, inputs);
if (outMessages != nullptr) {
outMessages->AddMessages(output.program.Diagnostics());
}
@ -1153,13 +1157,16 @@ namespace dawn_native {
std::string err = "Tint program failure: " + output.program.Diagnostics().str();
return DAWN_VALIDATION_ERROR(err.c_str());
}
if (outputs != nullptr) {
*outputs = std::move(output.data);
}
return std::move(output.program);
}
std::unique_ptr<tint::transform::VertexPulling> MakeVertexPullingTransform(
const VertexState& vertexState,
void AddVertexPullingTransformConfig(const VertexState& vertexState,
const std::string& entryPoint,
BindGroupIndex pullingBufferBindingSet) {
BindGroupIndex pullingBufferBindingSet,
tint::transform::DataMap* transformInputs) {
tint::transform::VertexPulling::Config cfg;
cfg.entry_point_name = entryPoint;
cfg.pulling_group = static_cast<uint32_t>(pullingBufferBindingSet);
@ -1181,7 +1188,7 @@ namespace dawn_native {
cfg.vertex_state.push_back(std::move(layout));
}
return std::make_unique<tint::transform::VertexPulling>(cfg);
transformInputs->Add<tint::transform::VertexPulling::Config>(cfg);
}
MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
@ -1314,19 +1321,23 @@ namespace dawn_native {
errorStream << "Tint vertex pulling failure:" << std::endl;
tint::transform::Manager transformManager;
transformManager.append(
MakeVertexPullingTransform(vertexState, entryPoint, pullingBufferBindingSet));
transformManager.append(std::make_unique<tint::transform::EmitVertexPointSize>());
transformManager.append(std::make_unique<tint::transform::Spirv>());
transformManager.Add<tint::transform::VertexPulling>();
transformManager.Add<tint::transform::EmitVertexPointSize>();
transformManager.Add<tint::transform::Spirv>();
if (GetDevice()->IsRobustnessEnabled()) {
transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
transformManager.Add<tint::transform::BoundArrayAccessors>();
}
tint::transform::DataMap transformInputs;
AddVertexPullingTransformConfig(vertexState, entryPoint, pullingBufferBindingSet,
&transformInputs);
// A nullptr is passed in for the CompilationMessages here since this method is called
// during RenderPipeline creation, by which point the shader module's CompilationInfo may
// have already been queried.
// during RenderPipeline creation, by which point the shader module's CompilationInfo
// may have already been queried.
tint::Program program;
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, programIn, nullptr));
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, programIn, transformInputs,
nullptr, nullptr));
tint::writer::spirv::Generator generator(&program);
if (!generator.Generate()) {

View File

@ -37,6 +37,7 @@ namespace tint {
class Program;
namespace transform {
class DataMap;
class Transform;
class VertexPulling;
} // namespace transform
@ -88,12 +89,15 @@ namespace dawn_native {
const PipelineLayoutBase* layout);
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
const tint::Program* program,
const tint::transform::DataMap& inputs,
tint::transform::DataMap* outputs,
OwnedCompilationMessages* messages);
std::unique_ptr<tint::transform::VertexPulling> MakeVertexPullingTransform(
const VertexState& vertexState,
/// Creates and adds the tint::transform::VertexPulling::Config to transformInputs.
void AddVertexPullingTransformConfig(const VertexState& vertexState,
const std::string& entryPoint,
BindGroupIndex pullingBufferBindingSet);
BindGroupIndex pullingBufferBindingSet,
tint::transform::DataMap* transformInputs);
// Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are
// stored in the ShaderModuleBase and destroyed only when the shader program is destroyed so
@ -173,7 +177,7 @@ namespace dawn_native {
const std::string& entryPoint,
BindGroupIndex pullingBufferBindingSet) const;
OwnedCompilationMessages* CompilationMessages() {
OwnedCompilationMessages* GetCompilationMessages() {
return mCompilationMessages.get();
}

View File

@ -194,7 +194,7 @@ namespace dawn_native { namespace d3d12 {
SingleShaderStage stage,
PipelineLayout* layout,
std::string* remappedEntryPointName,
FirstOffsetInfo* firstOffsetInfo) const {
FirstOffsetInfo* firstOffsetInfo) {
ASSERT(!IsError());
ScopedTintICEHandler scopedICEHandler(GetDevice());
@ -245,30 +245,28 @@ namespace dawn_native { namespace d3d12 {
errorStream << "Tint HLSL failure:" << std::endl;
tint::transform::Manager transformManager;
transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
if (stage == SingleShaderStage::Vertex) {
transformManager.append(std::make_unique<tint::transform::FirstIndexOffset>(
layout->GetFirstIndexOffsetShaderRegister(),
layout->GetFirstIndexOffsetRegisterSpace()));
}
transformManager.append(std::make_unique<tint::transform::BindingRemapper>());
transformManager.append(std::make_unique<tint::transform::Renamer>());
transformManager.append(std::make_unique<tint::transform::Hlsl>());
tint::transform::DataMap transformInputs;
transformManager.Add<tint::transform::BoundArrayAccessors>();
if (stage == SingleShaderStage::Vertex) {
transformManager.Add<tint::transform::FirstIndexOffset>();
transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(
layout->GetFirstIndexOffsetShaderRegister(),
layout->GetFirstIndexOffsetRegisterSpace());
}
transformManager.Add<tint::transform::BindingRemapper>();
transformManager.Add<tint::transform::Renamer>();
transformManager.Add<tint::transform::Hlsl>();
transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
std::move(accessControls));
tint::transform::Transform::Output output =
transformManager.Run(GetTintProgram(), transformInputs);
const tint::Program& program = output.program;
if (!program.IsValid()) {
errorStream << "Tint program transform error: " << program.Diagnostics().str()
<< std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::Program program;
tint::transform::DataMap transformOutputs;
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
&transformOutputs, nullptr));
if (auto* data = output.data.Get<tint::transform::FirstIndexOffset::Data>()) {
if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
firstOffsetInfo->usesVertexIndex = data->has_vertex_index;
if (firstOffsetInfo->usesVertexIndex) {
firstOffsetInfo->vertexIndexOffset = data->first_vertex_offset;
@ -279,7 +277,7 @@ namespace dawn_native { namespace d3d12 {
}
}
if (auto* data = output.data.Get<tint::transform::Renamer::Data>()) {
if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
auto it = data->remappings.find(entryPointName);
if (it == data->remappings.end()) {
return DAWN_VALIDATION_ERROR("Could not find remapped name for entry point.");

View File

@ -63,7 +63,7 @@ namespace dawn_native { namespace d3d12 {
SingleShaderStage stage,
PipelineLayout* layout,
std::string* remappedEntryPointName,
FirstOffsetInfo* firstOffsetInfo) const;
FirstOffsetInfo* firstOffsetInfo);
ResultOrError<std::string> TranslateToHLSLWithSPIRVCross(const char* entryPointName,
SingleShaderStage stage,

View File

@ -70,10 +70,12 @@ namespace dawn_native { namespace metal {
errorStream << "Tint MSL failure:" << std::endl;
tint::transform::Manager transformManager;
tint::transform::DataMap transformInputs;
if (stage == SingleShaderStage::Vertex &&
GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
transformManager.append(
MakeVertexPullingTransform(*vertexState, entryPointName, kPullingBufferBindingSet));
AddVertexPullingTransformConfig(*vertexState, entryPointName, kPullingBufferBindingSet,
&transformInputs);
for (VertexBufferSlot slot :
IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
@ -83,20 +85,16 @@ namespace dawn_native { namespace metal {
// this MSL buffer index.
}
}
transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
transformManager.append(std::make_unique<tint::transform::Renamer>());
transformManager.append(std::make_unique<tint::transform::Msl>());
transformManager.Add<tint::transform::BoundArrayAccessors>();
transformManager.Add<tint::transform::Renamer>();
transformManager.Add<tint::transform::Msl>();
tint::transform::Transform::Output output = transformManager.Run(GetTintProgram());
tint::Program program;
tint::transform::DataMap transformOutputs;
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
&transformOutputs, nullptr));
tint::Program& program = output.program;
if (!program.IsValid()) {
errorStream << "Tint program transform error: " << program.Diagnostics().str()
<< std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
if (auto* data = output.data.Get<tint::transform::Renamer::Data>()) {
if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
auto it = data->remappings.find(entryPointName);
if (it == data->remappings.end()) {
return DAWN_VALIDATION_ERROR("Could not find remapped name for entry point.");

View File

@ -87,9 +87,12 @@ namespace dawn_native { namespace opengl {
tint::transform::Manager transformManager;
transformManager.append(std::make_unique<tint::transform::Spirv>());
tint::transform::DataMap transformInputs;
tint::Program program;
DAWN_TRY_ASSIGN(
program, RunTransforms(&transformManager, GetTintProgram(), CompilationMessages()));
DAWN_TRY_ASSIGN(program,
RunTransforms(&transformManager, GetTintProgram(), transformInputs,
nullptr, GetCompilationMessages()));
tint::writer::spirv::Generator generator(&program);
if (!generator.Generate()) {

View File

@ -55,14 +55,16 @@ namespace dawn_native { namespace vulkan {
errorStream << "Tint SPIR-V writer failure:" << std::endl;
tint::transform::Manager transformManager;
transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
transformManager.append(std::make_unique<tint::transform::EmitVertexPointSize>());
transformManager.append(std::make_unique<tint::transform::Spirv>());
transformManager.Add<tint::transform::BoundArrayAccessors>();
transformManager.Add<tint::transform::EmitVertexPointSize>();
transformManager.Add<tint::transform::Spirv>();
tint::transform::DataMap transformInputs;
tint::Program program;
DAWN_TRY_ASSIGN(program,
RunTransforms(&transformManager, parseResult->tintProgram.get(),
CompilationMessages()));
transformInputs, nullptr, GetCompilationMessages()));
tint::writer::spirv::Generator generator(&program);
if (!generator.Generate()) {
@ -166,15 +168,10 @@ namespace dawn_native { namespace vulkan {
tint::transform::DataMap transformInputs;
transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
std::move(accessControls));
tint::transform::Transform::Output output =
transformManager.Run(GetTintProgram(), transformInputs);
const tint::Program& program = output.program;
if (!program.IsValid()) {
errorStream << "Tint program transform error: " << program.Diagnostics().str()
<< std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::Program program;
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
nullptr, nullptr));
tint::writer::spirv::Generator generator(&program);
if (!generator.Generate()) {

View File

@ -160,7 +160,7 @@ TEST_F(ShaderModuleValidationTest, MultisampledArrayTexture) {
}
// Tests that shader module compilation messages can be queried.
TEST_F(ShaderModuleValidationTest, CompilationMessages) {
TEST_F(ShaderModuleValidationTest, GetCompilationMessages) {
// This test works assuming ShaderModule is backed by a dawn_native::ShaderModuleBase, which
// is not the case on the wire.
DAWN_SKIP_TEST_IF(UsesWire());
@ -172,12 +172,11 @@ TEST_F(ShaderModuleValidationTest, CompilationMessages) {
dawn_native::ShaderModuleBase* shaderModuleBase =
reinterpret_cast<dawn_native::ShaderModuleBase*>(shaderModule.Get());
shaderModuleBase->CompilationMessages()->ClearMessages();
shaderModuleBase->CompilationMessages()->AddMessage("Info Message");
shaderModuleBase->CompilationMessages()->AddMessage("Warning Message",
wgpu::CompilationMessageType::Warning);
shaderModuleBase->CompilationMessages()->AddMessage("Error Message",
wgpu::CompilationMessageType::Error, 3, 4);
dawn_native::OwnedCompilationMessages* messages = shaderModuleBase->GetCompilationMessages();
messages->ClearMessages();
messages->AddMessage("Info Message");
messages->AddMessage("Warning Message", wgpu::CompilationMessageType::Warning);
messages->AddMessage("Error Message", wgpu::CompilationMessageType::Error, 3, 4);
auto callback = [](WGPUCompilationInfoRequestStatus status, const WGPUCompilationInfo* info,
void* userdata) {