diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index 4c23874d4a..c02e809637 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -888,6 +888,8 @@ namespace dawn_native { return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); } + constexpr uint32_t kMaxInterStageShaderLocation = + kMaxInterStageShaderComponents / 4 - 1; for (auto& entryPoint : entryPoints) { ASSERT(result.count(entryPoint.name) == 0); @@ -928,6 +930,12 @@ namespace dawn_native { << output_var.name; return DAWN_VALIDATION_ERROR(ss.str()); } + uint32_t location = output_var.location_decoration; + if (DAWN_UNLIKELY(location > kMaxInterStageShaderLocation)) { + std::stringstream ss; + ss << "Vertex output location (" << location << ") over limits"; + return DAWN_VALIDATION_ERROR(ss.str()); + } } } @@ -937,6 +945,12 @@ namespace dawn_native { return DAWN_VALIDATION_ERROR( "Need location decoration on fragment input"); } + uint32_t location = input_var.location_decoration; + if (DAWN_UNLIKELY(location > kMaxInterStageShaderLocation)) { + std::stringstream ss; + ss << "Fragment input location (" << location << ") over limits"; + return DAWN_VALIDATION_ERROR(ss.str()); + } } for (const auto& output_var : entryPoint.output_variables) { diff --git a/src/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/tests/unittests/validation/ShaderModuleValidationTests.cpp index 16b943a148..a83ab88603 100644 --- a/src/tests/unittests/validation/ShaderModuleValidationTests.cpp +++ b/src/tests/unittests/validation/ShaderModuleValidationTests.cpp @@ -212,3 +212,73 @@ TEST_F(ShaderModuleValidationTest, GetCompilationMessages) { shaderModule.GetCompilationInfo(callback, nullptr); } + +// Validate the maximum location of effective inter-stage variables cannot be greater than 14 +// (kMaxInterStageShaderComponents / 4 - 1). +TEST_F(ShaderModuleValidationTest, MaximumShaderIOLocations) { + auto generateShaderForTest = [](uint32_t maximumOutputLocation, wgpu::ShaderStage shaderStage) { + std::ostringstream stream; + stream << "struct ShaderIO {" << std::endl; + for (uint32_t location = 0; location <= maximumOutputLocation; ++location) { + stream << "[[location(" << location << ")]] var" << location << ": f32;" << std::endl; + } + switch (shaderStage) { + case wgpu::ShaderStage::Vertex: { + stream << R"( + [[builtin(position)]] pos: vec4; + }; + [[stage(vertex)]] fn main() -> ShaderIO { + var shaderIO : ShaderIO; + shaderIO.pos = vec4(0.0, 0.0, 0.0, 1.0); + return shaderIO; + })"; + } break; + + case wgpu::ShaderStage::Fragment: { + stream << R"( + }; + [[stage(fragment)]] fn main(shaderIO: ShaderIO) -> [[location(0)]] vec4 { + return vec4(0.0, 0.0, 0.0, 1.0); + })"; + } break; + + case wgpu::ShaderStage::Compute: + default: + UNREACHABLE(); + } + + return stream.str(); + }; + + constexpr uint32_t kMaxInterShaderIOLocation = kMaxInterStageShaderComponents / 4 - 1; + + // It is allowed to create a shader module with the maximum active vertex output location == 14; + { + std::string vertexShader = + generateShaderForTest(kMaxInterShaderIOLocation, wgpu::ShaderStage::Vertex); + utils::CreateShaderModule(device, vertexShader.c_str()); + } + + // It isn't allowed to create a shader module with the maximum active vertex output location > + // 14; + { + std::string vertexShader = + generateShaderForTest(kMaxInterShaderIOLocation + 1, wgpu::ShaderStage::Vertex); + ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, vertexShader.c_str())); + } + + // It is allowed to create a shader module with the maximum active fragment input location == + // 14; + { + std::string fragmentShader = + generateShaderForTest(kMaxInterShaderIOLocation, wgpu::ShaderStage::Fragment); + utils::CreateShaderModule(device, fragmentShader.c_str()); + } + + // It is allowed to create a shader module with the maximum active vertex output location > 14; + { + std::string fragmentShader = + generateShaderForTest(kMaxInterShaderIOLocation + 1, wgpu::ShaderStage::Fragment); + ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, fragmentShader.c_str())); + } +}