From 989a8e4d623fc2abd44867fb5b374bbcfa9dd455 Mon Sep 17 00:00:00 2001 From: James Price Date: Mon, 28 Jun 2021 23:04:43 +0000 Subject: [PATCH] validation: Validate interpolation attributes They are only valid on entry point parameters and return types, and struct members. They must only be used on floating point scalar and vector types. If the interpolation type is flat, the sampling type must not be specified. Bug: tint:746 Change-Id: Iab17816bc9947a74593a5937bdf513ac9ec664f1 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/56241 Auto-Submit: James Price Reviewed-by: Ben Clayton Kokoro: Kokoro --- src/program_builder.h | 21 ++++ src/resolver/decoration_validation_test.cc | 113 ++++++++++++++++++++- src/resolver/resolver.cc | 40 +++++++- src/resolver/resolver.h | 2 + 4 files changed, 174 insertions(+), 2 deletions(-) diff --git a/src/program_builder.h b/src/program_builder.h index facd53976f..8627c9afdd 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -36,6 +36,7 @@ #include "src/ast/float_literal.h" #include "src/ast/i32.h" #include "src/ast/if_statement.h" +#include "src/ast/interpolate_decoration.h" #include "src/ast/loop_statement.h" #include "src/ast/matrix.h" #include "src/ast/member_accessor_expression.h" @@ -1925,6 +1926,26 @@ class ProgramBuilder { return create(source_, builtin); } + /// Creates an ast::InterpolateDecoration + /// @param source the source information + /// @param type the interpolation type + /// @param sampling the interpolation sampling + /// @returns the interpolate decoration pointer + ast::InterpolateDecoration* Interpolate(const Source& source, + ast::InterpolationType type, + ast::InterpolationSampling sampling) { + return create(source, type, sampling); + } + + /// Creates an ast::InterpolateDecoration + /// @param type the interpolation type + /// @param sampling the interpolation sampling + /// @returns the interpolate decoration pointer + ast::InterpolateDecoration* Interpolate(ast::InterpolationType type, + ast::InterpolationSampling sampling) { + return create(source_, type, sampling); + } + /// Creates an ast::LocationDecoration /// @param source the source information /// @param location the location value diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc index 8702d2d876..c40c70ada9 100644 --- a/src/resolver/decoration_validation_test.cc +++ b/src/resolver/decoration_validation_test.cc @@ -61,6 +61,7 @@ enum class DecorationKind { kBinding, kBuiltin, kGroup, + kInterpolate, kLocation, kOverride, kOffset, @@ -102,6 +103,10 @@ static ast::DecorationList createDecorations(const Source& source, return {builder.Builtin(source, ast::Builtin::kPosition)}; case DecorationKind::kGroup: return {builder.create(source, 1u)}; + case DecorationKind::kInterpolate: + return {builder.Interpolate(source, ast::InterpolationType::kLinear, + ast::InterpolationSampling::kCenter), + builder.Location(0)}; case DecorationKind::kLocation: return {builder.Location(source, 1)}; case DecorationKind::kOverride: @@ -150,6 +155,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, false}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, false}, TestParams{DecorationKind::kLocation, false}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -185,6 +191,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, true}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, true}, TestParams{DecorationKind::kLocation, true}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -247,6 +254,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, false}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, false}, TestParams{DecorationKind::kLocation, false}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -282,6 +290,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, true}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, true}, TestParams{DecorationKind::kLocation, true}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -335,6 +344,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, false}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, false}, TestParams{DecorationKind::kLocation, false}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -369,6 +379,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, false}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, false}, TestParams{DecorationKind::kLocation, false}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -408,7 +419,7 @@ TEST_P(StructMemberDecorationTest, IsValid) { createDecorations(Source{{12, 34}}, *this, params.kind))}); } else { members.push_back( - {Member("a", ty.i32(), + {Member("a", ty.f32(), createDecorations(Source{{12, 34}}, *this, params.kind))}); } @@ -431,6 +442,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, true}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, true}, TestParams{DecorationKind::kLocation, true}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, true}, @@ -492,6 +504,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, false}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, false}, TestParams{DecorationKind::kLocation, false}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -542,6 +555,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, false}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, false}, TestParams{DecorationKind::kLocation, false}, TestParams{DecorationKind::kOverride, true}, TestParams{DecorationKind::kOffset, false}, @@ -590,6 +604,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, false}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, false}, TestParams{DecorationKind::kLocation, false}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -957,5 +972,101 @@ TEST_F(WorkgroupDecoration, DuplicateDecoration) { } // namespace } // namespace WorkgroupDecorationTests +namespace InterpolateTests { +namespace { + +using InterpolateTest = ResolverTest; + +struct Params { + ast::InterpolationType type; + ast::InterpolationSampling sampling; + bool should_pass; +}; + +struct TestWithParams : ResolverTestWithParam {}; + +using InterpolateParameterTest = TestWithParams; +TEST_P(InterpolateParameterTest, All) { + auto& params = GetParam(); + + Func("main", + ast::VariableList{Param( + "a", ty.f32(), + {Location(0), + Interpolate(Source{{12, 34}}, params.type, params.sampling)})}, + ty.void_(), {}, + ast::DecorationList{Stage(ast::PipelineStage::kFragment)}); + + if (params.should_pass) { + EXPECT_TRUE(r()->Resolve()) << r()->error(); + } else { + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: flat interpolation attribute must not have a " + "sampling parameter"); + } +} + +INSTANTIATE_TEST_SUITE_P( + ResolverDecorationValidationTest, + InterpolateParameterTest, + testing::Values(Params{ast::InterpolationType::kPerspective, + ast::InterpolationSampling::kNone, true}, + Params{ast::InterpolationType::kPerspective, + ast::InterpolationSampling::kCenter, true}, + Params{ast::InterpolationType::kPerspective, + ast::InterpolationSampling::kCentroid, true}, + Params{ast::InterpolationType::kPerspective, + ast::InterpolationSampling::kSample, true}, + Params{ast::InterpolationType::kLinear, + ast::InterpolationSampling::kNone, true}, + Params{ast::InterpolationType::kLinear, + ast::InterpolationSampling::kCenter, true}, + Params{ast::InterpolationType::kLinear, + ast::InterpolationSampling::kCentroid, true}, + Params{ast::InterpolationType::kLinear, + ast::InterpolationSampling::kSample, true}, + // flat interpolation must not have a sampling type + Params{ast::InterpolationType::kFlat, + ast::InterpolationSampling::kNone, true}, + Params{ast::InterpolationType::kFlat, + ast::InterpolationSampling::kCenter, false}, + Params{ast::InterpolationType::kFlat, + ast::InterpolationSampling::kCentroid, false}, + Params{ast::InterpolationType::kFlat, + ast::InterpolationSampling::kSample, false})); + +TEST_F(InterpolateTest, Parameter_NotFloatingPoint) { + Func("main", + ast::VariableList{ + Param("a", ty.i32(), + {Location(0), + Interpolate(Source{{12, 34}}, ast::InterpolationType::kFlat, + ast::InterpolationSampling::kNone)})}, + ty.void_(), {}, + ast::DecorationList{Stage(ast::PipelineStage::kFragment)}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: store type of interpolate attribute must be floating " + "point scalar or vector"); +} + +TEST_F(InterpolateTest, ReturnType_NotFloatingPoint) { + Func( + "main", {}, ty.i32(), {Return(1)}, + ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, + {Location(0), Interpolate(Source{{12, 34}}, ast::InterpolationType::kFlat, + ast::InterpolationSampling::kNone)}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: store type of interpolate attribute must be floating " + "point scalar or vector"); +} + +} // namespace +} // namespace InterpolateTests + } // namespace resolver } // namespace tint diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 65b34bd0a0..c6462df3ed 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -30,6 +30,7 @@ #include "src/ast/fallthrough_statement.h" #include "src/ast/if_statement.h" #include "src/ast/internal_decoration.h" +#include "src/ast/interpolate_decoration.h" #include "src/ast/loop_statement.h" #include "src/ast/matrix.h" #include "src/ast/override_decoration.h" @@ -719,7 +720,8 @@ bool Resolver::ValidateGlobalVariable(const VariableInfo* info) { } } else { bool is_shader_io_decoration = - deco->IsAnyOf(); + deco->IsAnyOf(); bool has_io_storage_class = info->storage_class == ast::StorageClass::kInput || info->storage_class == ast::StorageClass::kOutput; @@ -947,6 +949,10 @@ bool Resolver::ValidateParameter(const ast::Function* func, if (!ValidateBuiltinDecoration(builtin, info->type)) { return false; } + } else if (auto* interpolate = deco->As()) { + if (!ValidateInterpolateDecoration(interpolate, info->type)) { + return false; + } } else if (!deco->IsAnyOf() && !(IsValidationDisabled( @@ -1015,6 +1021,29 @@ bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, return true; } +bool Resolver::ValidateInterpolateDecoration( + const ast::InterpolateDecoration* deco, + const sem::Type* storage_type) { + auto* type = storage_type->UnwrapRef(); + + if (!type->is_float_scalar_or_vector()) { + AddError( + "store type of interpolate attribute must be floating point scalar or " + "vector", + deco->source()); + return false; + } + + if (deco->type() == ast::InterpolationType::kFlat && + deco->sampling() != ast::InterpolationSampling::kNone) { + AddError("flat interpolation attribute must not have a sampling parameter", + deco->source()); + return false; + } + + return true; +} + bool Resolver::ValidateFunction(const ast::Function* func, const FunctionInfo* info) { auto func_it = symbol_to_function_.find(func->symbol()); @@ -1101,6 +1130,10 @@ bool Resolver::ValidateFunction(const ast::Function* func, if (!ValidateBuiltinDecoration(builtin, info->return_type)) { return false; } + } else if (auto* interpolate = deco->As()) { + if (!ValidateInterpolateDecoration(interpolate, info->return_type)) { + return false; + } } else if (!deco->Is()) { AddError("decoration is not valid for entry point return types", deco->source()); @@ -3364,6 +3397,7 @@ bool Resolver::ValidateStructure(const sem::Struct* str) { for (auto* deco : member->Declaration()->decorations()) { if (!(deco->Is() || + deco->Is() || deco->Is() || deco->Is() || deco->Is() || @@ -3376,6 +3410,10 @@ bool Resolver::ValidateStructure(const sem::Struct* str) { if (!ValidateBuiltinDecoration(builtin, member->Type())) { return false; } + } else if (auto* interpolate = deco->As()) { + if (!ValidateInterpolateDecoration(interpolate, member->Type())) { + return false; + } } } diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 336b44c128..16758bfbdc 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -276,6 +276,8 @@ class Resolver { bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info); bool ValidateFunction(const ast::Function* func, const FunctionInfo* info); bool ValidateGlobalVariable(const VariableInfo* var); + bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco, + const sem::Type* storage_type); bool ValidateMatrix(const sem::Matrix* matirx_type, const Source& source); bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor, const sem::Matrix* matrix_type);