diff --git a/src/program_builder.h b/src/program_builder.h index facd53976f..8627c9afdd 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -36,6 +36,7 @@ #include "src/ast/float_literal.h" #include "src/ast/i32.h" #include "src/ast/if_statement.h" +#include "src/ast/interpolate_decoration.h" #include "src/ast/loop_statement.h" #include "src/ast/matrix.h" #include "src/ast/member_accessor_expression.h" @@ -1925,6 +1926,26 @@ class ProgramBuilder { return create(source_, builtin); } + /// Creates an ast::InterpolateDecoration + /// @param source the source information + /// @param type the interpolation type + /// @param sampling the interpolation sampling + /// @returns the interpolate decoration pointer + ast::InterpolateDecoration* Interpolate(const Source& source, + ast::InterpolationType type, + ast::InterpolationSampling sampling) { + return create(source, type, sampling); + } + + /// Creates an ast::InterpolateDecoration + /// @param type the interpolation type + /// @param sampling the interpolation sampling + /// @returns the interpolate decoration pointer + ast::InterpolateDecoration* Interpolate(ast::InterpolationType type, + ast::InterpolationSampling sampling) { + return create(source_, type, sampling); + } + /// Creates an ast::LocationDecoration /// @param source the source information /// @param location the location value diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc index 8702d2d876..c40c70ada9 100644 --- a/src/resolver/decoration_validation_test.cc +++ b/src/resolver/decoration_validation_test.cc @@ -61,6 +61,7 @@ enum class DecorationKind { kBinding, kBuiltin, kGroup, + kInterpolate, kLocation, kOverride, kOffset, @@ -102,6 +103,10 @@ static ast::DecorationList createDecorations(const Source& source, return {builder.Builtin(source, ast::Builtin::kPosition)}; case DecorationKind::kGroup: return {builder.create(source, 1u)}; + case DecorationKind::kInterpolate: + return {builder.Interpolate(source, ast::InterpolationType::kLinear, + ast::InterpolationSampling::kCenter), + builder.Location(0)}; case DecorationKind::kLocation: return {builder.Location(source, 1)}; case DecorationKind::kOverride: @@ -150,6 +155,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, false}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, false}, TestParams{DecorationKind::kLocation, false}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -185,6 +191,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, true}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, true}, TestParams{DecorationKind::kLocation, true}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -247,6 +254,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, false}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, false}, TestParams{DecorationKind::kLocation, false}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -282,6 +290,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, true}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, true}, TestParams{DecorationKind::kLocation, true}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -335,6 +344,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, false}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, false}, TestParams{DecorationKind::kLocation, false}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -369,6 +379,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, false}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, false}, TestParams{DecorationKind::kLocation, false}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -408,7 +419,7 @@ TEST_P(StructMemberDecorationTest, IsValid) { createDecorations(Source{{12, 34}}, *this, params.kind))}); } else { members.push_back( - {Member("a", ty.i32(), + {Member("a", ty.f32(), createDecorations(Source{{12, 34}}, *this, params.kind))}); } @@ -431,6 +442,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, true}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, true}, TestParams{DecorationKind::kLocation, true}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, true}, @@ -492,6 +504,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, false}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, false}, TestParams{DecorationKind::kLocation, false}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -542,6 +555,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, false}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, false}, TestParams{DecorationKind::kLocation, false}, TestParams{DecorationKind::kOverride, true}, TestParams{DecorationKind::kOffset, false}, @@ -590,6 +604,7 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, false}, TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kInterpolate, false}, TestParams{DecorationKind::kLocation, false}, TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, @@ -957,5 +972,101 @@ TEST_F(WorkgroupDecoration, DuplicateDecoration) { } // namespace } // namespace WorkgroupDecorationTests +namespace InterpolateTests { +namespace { + +using InterpolateTest = ResolverTest; + +struct Params { + ast::InterpolationType type; + ast::InterpolationSampling sampling; + bool should_pass; +}; + +struct TestWithParams : ResolverTestWithParam {}; + +using InterpolateParameterTest = TestWithParams; +TEST_P(InterpolateParameterTest, All) { + auto& params = GetParam(); + + Func("main", + ast::VariableList{Param( + "a", ty.f32(), + {Location(0), + Interpolate(Source{{12, 34}}, params.type, params.sampling)})}, + ty.void_(), {}, + ast::DecorationList{Stage(ast::PipelineStage::kFragment)}); + + if (params.should_pass) { + EXPECT_TRUE(r()->Resolve()) << r()->error(); + } else { + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: flat interpolation attribute must not have a " + "sampling parameter"); + } +} + +INSTANTIATE_TEST_SUITE_P( + ResolverDecorationValidationTest, + InterpolateParameterTest, + testing::Values(Params{ast::InterpolationType::kPerspective, + ast::InterpolationSampling::kNone, true}, + Params{ast::InterpolationType::kPerspective, + ast::InterpolationSampling::kCenter, true}, + Params{ast::InterpolationType::kPerspective, + ast::InterpolationSampling::kCentroid, true}, + Params{ast::InterpolationType::kPerspective, + ast::InterpolationSampling::kSample, true}, + Params{ast::InterpolationType::kLinear, + ast::InterpolationSampling::kNone, true}, + Params{ast::InterpolationType::kLinear, + ast::InterpolationSampling::kCenter, true}, + Params{ast::InterpolationType::kLinear, + ast::InterpolationSampling::kCentroid, true}, + Params{ast::InterpolationType::kLinear, + ast::InterpolationSampling::kSample, true}, + // flat interpolation must not have a sampling type + Params{ast::InterpolationType::kFlat, + ast::InterpolationSampling::kNone, true}, + Params{ast::InterpolationType::kFlat, + ast::InterpolationSampling::kCenter, false}, + Params{ast::InterpolationType::kFlat, + ast::InterpolationSampling::kCentroid, false}, + Params{ast::InterpolationType::kFlat, + ast::InterpolationSampling::kSample, false})); + +TEST_F(InterpolateTest, Parameter_NotFloatingPoint) { + Func("main", + ast::VariableList{ + Param("a", ty.i32(), + {Location(0), + Interpolate(Source{{12, 34}}, ast::InterpolationType::kFlat, + ast::InterpolationSampling::kNone)})}, + ty.void_(), {}, + ast::DecorationList{Stage(ast::PipelineStage::kFragment)}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: store type of interpolate attribute must be floating " + "point scalar or vector"); +} + +TEST_F(InterpolateTest, ReturnType_NotFloatingPoint) { + Func( + "main", {}, ty.i32(), {Return(1)}, + ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, + {Location(0), Interpolate(Source{{12, 34}}, ast::InterpolationType::kFlat, + ast::InterpolationSampling::kNone)}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: store type of interpolate attribute must be floating " + "point scalar or vector"); +} + +} // namespace +} // namespace InterpolateTests + } // namespace resolver } // namespace tint diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 65b34bd0a0..c6462df3ed 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -30,6 +30,7 @@ #include "src/ast/fallthrough_statement.h" #include "src/ast/if_statement.h" #include "src/ast/internal_decoration.h" +#include "src/ast/interpolate_decoration.h" #include "src/ast/loop_statement.h" #include "src/ast/matrix.h" #include "src/ast/override_decoration.h" @@ -719,7 +720,8 @@ bool Resolver::ValidateGlobalVariable(const VariableInfo* info) { } } else { bool is_shader_io_decoration = - deco->IsAnyOf(); + deco->IsAnyOf(); bool has_io_storage_class = info->storage_class == ast::StorageClass::kInput || info->storage_class == ast::StorageClass::kOutput; @@ -947,6 +949,10 @@ bool Resolver::ValidateParameter(const ast::Function* func, if (!ValidateBuiltinDecoration(builtin, info->type)) { return false; } + } else if (auto* interpolate = deco->As()) { + if (!ValidateInterpolateDecoration(interpolate, info->type)) { + return false; + } } else if (!deco->IsAnyOf() && !(IsValidationDisabled( @@ -1015,6 +1021,29 @@ bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, return true; } +bool Resolver::ValidateInterpolateDecoration( + const ast::InterpolateDecoration* deco, + const sem::Type* storage_type) { + auto* type = storage_type->UnwrapRef(); + + if (!type->is_float_scalar_or_vector()) { + AddError( + "store type of interpolate attribute must be floating point scalar or " + "vector", + deco->source()); + return false; + } + + if (deco->type() == ast::InterpolationType::kFlat && + deco->sampling() != ast::InterpolationSampling::kNone) { + AddError("flat interpolation attribute must not have a sampling parameter", + deco->source()); + return false; + } + + return true; +} + bool Resolver::ValidateFunction(const ast::Function* func, const FunctionInfo* info) { auto func_it = symbol_to_function_.find(func->symbol()); @@ -1101,6 +1130,10 @@ bool Resolver::ValidateFunction(const ast::Function* func, if (!ValidateBuiltinDecoration(builtin, info->return_type)) { return false; } + } else if (auto* interpolate = deco->As()) { + if (!ValidateInterpolateDecoration(interpolate, info->return_type)) { + return false; + } } else if (!deco->Is()) { AddError("decoration is not valid for entry point return types", deco->source()); @@ -3364,6 +3397,7 @@ bool Resolver::ValidateStructure(const sem::Struct* str) { for (auto* deco : member->Declaration()->decorations()) { if (!(deco->Is() || + deco->Is() || deco->Is() || deco->Is() || deco->Is() || @@ -3376,6 +3410,10 @@ bool Resolver::ValidateStructure(const sem::Struct* str) { if (!ValidateBuiltinDecoration(builtin, member->Type())) { return false; } + } else if (auto* interpolate = deco->As()) { + if (!ValidateInterpolateDecoration(interpolate, member->Type())) { + return false; + } } } diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 336b44c128..16758bfbdc 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -276,6 +276,8 @@ class Resolver { bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info); bool ValidateFunction(const ast::Function* func, const FunctionInfo* info); bool ValidateGlobalVariable(const VariableInfo* var); + bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco, + const sem::Type* storage_type); bool ValidateMatrix(const sem::Matrix* matirx_type, const Source& source); bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor, const sem::Matrix* matrix_type);