validation: validate builtin pipeline stage and Input/Output

Bug: tint:957
Change-Id: I5f509e61501b39f2a0b3bc10a204ae1f39a0d460
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/57105
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Sarah 2021-07-08 14:02:56 +00:00 committed by Sarah Mashayekhi
parent a7392fbd8a
commit 99a78ad72f
8 changed files with 334 additions and 77 deletions

View File

@ -1059,11 +1059,11 @@ TEST_F(InspectorGetEntryPointTest, MultipleEntryPointsInOutVariables) {
TEST_F(InspectorGetEntryPointTest, BuiltInsNotStageVariables) { TEST_F(InspectorGetEntryPointTest, BuiltInsNotStageVariables) {
auto* in_var0 = auto* in_var0 =
Param("in_var0", ty.u32(), {Builtin(ast::Builtin::kInstanceIndex)}); Param("in_var0", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)});
auto* in_var1 = Param("in_var1", ty.u32(), {Location(0u)}); auto* in_var1 = Param("in_var1", ty.f32(), {Location(0u)});
Func("foo", {in_var0, in_var1}, ty.u32(), {Return("in_var1")}, Func("foo", {in_var0, in_var1}, ty.f32(), {Return("in_var1")},
{Stage(ast::PipelineStage::kFragment)}, {Stage(ast::PipelineStage::kFragment)},
{Builtin(ast::Builtin::kSampleMask)}); {Builtin(ast::Builtin::kFragDepth)});
Inspector& inspector = Build(); Inspector& inspector = Build();
auto result = inspector.GetEntryPoints(); auto result = inspector.GetEntryPoints();
@ -1075,7 +1075,7 @@ TEST_F(InspectorGetEntryPointTest, BuiltInsNotStageVariables) {
EXPECT_EQ("in_var1", result[0].input_variables[0].name); EXPECT_EQ("in_var1", result[0].input_variables[0].name);
EXPECT_TRUE(result[0].input_variables[0].has_location_decoration); EXPECT_TRUE(result[0].input_variables[0].has_location_decoration);
EXPECT_EQ(0u, result[0].input_variables[0].location_decoration); EXPECT_EQ(0u, result[0].input_variables[0].location_decoration);
EXPECT_EQ(ComponentType::kUInt, result[0].input_variables[0].component_type); EXPECT_EQ(ComponentType::kFloat, result[0].input_variables[0].component_type);
ASSERT_EQ(0u, result[0].output_variables.size()); ASSERT_EQ(0u, result[0].output_variables.size());
} }

View File

@ -16,9 +16,209 @@
#include "src/resolver/resolver_test_helper.h" #include "src/resolver/resolver_test_helper.h"
namespace tint { namespace tint {
namespace resolver {
namespace { namespace {
template <typename T>
using DataType = builder::DataType<T>;
template <typename T>
using vec2 = builder::vec2<T>;
template <typename T>
using vec3 = builder::vec3<T>;
template <typename T>
using vec4 = builder::vec4<T>;
template <typename T>
using f32 = builder::f32;
using i32 = builder::i32;
using u32 = builder::u32;
class ResolverBuiltinsValidationTest : public resolver::TestHelper, class ResolverBuiltinsValidationTest : public resolver::TestHelper,
public testing::Test {}; public testing::Test {};
namespace TypeTemp {
struct Params {
builder::ast_type_func_ptr type;
ast::Builtin builtin;
ast::PipelineStage stage;
bool is_valid;
};
template <typename T>
constexpr Params ParamsFor(ast::Builtin builtin,
ast::PipelineStage stage,
bool is_valid) {
return Params{DataType<T>::AST, builtin, stage, is_valid};
}
static constexpr Params cases[] = {
ParamsFor<u32>(ast::Builtin::kVertexIndex,
ast::PipelineStage::kVertex,
true),
ParamsFor<u32>(ast::Builtin::kVertexIndex,
ast::PipelineStage::kFragment,
false),
ParamsFor<u32>(ast::Builtin::kVertexIndex,
ast::PipelineStage::kCompute,
false),
ParamsFor<u32>(ast::Builtin::kInstanceIndex,
ast::PipelineStage::kVertex,
true),
ParamsFor<u32>(ast::Builtin::kInstanceIndex,
ast::PipelineStage::kFragment,
false),
ParamsFor<u32>(ast::Builtin::kInstanceIndex,
ast::PipelineStage::kCompute,
false),
ParamsFor<bool>(ast::Builtin::kFrontFacing,
ast::PipelineStage::kVertex,
false),
ParamsFor<bool>(ast::Builtin::kFrontFacing,
ast::PipelineStage::kFragment,
true),
ParamsFor<bool>(ast::Builtin::kFrontFacing,
ast::PipelineStage::kCompute,
false),
ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
ast::PipelineStage::kVertex,
false),
ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
ast::PipelineStage::kFragment,
false),
ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
ast::PipelineStage::kCompute,
true),
ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
ast::PipelineStage::kVertex,
false),
ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
ast::PipelineStage::kFragment,
false),
ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
ast::PipelineStage::kCompute,
true),
ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
ast::PipelineStage::kVertex,
false),
ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
ast::PipelineStage::kFragment,
false),
ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
ast::PipelineStage::kCompute,
true),
ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
ast::PipelineStage::kVertex,
false),
ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
ast::PipelineStage::kFragment,
false),
ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
ast::PipelineStage::kCompute,
true),
ParamsFor<u32>(ast::Builtin::kSampleIndex,
ast::PipelineStage::kVertex,
false),
ParamsFor<u32>(ast::Builtin::kSampleIndex,
ast::PipelineStage::kFragment,
true),
ParamsFor<u32>(ast::Builtin::kSampleIndex,
ast::PipelineStage::kCompute,
false),
ParamsFor<u32>(ast::Builtin::kSampleMask,
ast::PipelineStage::kVertex,
false),
ParamsFor<u32>(ast::Builtin::kSampleMask,
ast::PipelineStage::kFragment,
true),
ParamsFor<u32>(ast::Builtin::kSampleMask,
ast::PipelineStage::kCompute,
false),
};
using ResolverBuiltinsStageTest = ResolverTestWithParam<Params>;
TEST_P(ResolverBuiltinsStageTest, All_input) {
const Params& params = GetParam();
auto* p = Global("p", ty.vec4<f32>(), ast::StorageClass::kPrivate);
auto* input =
Param("input", params.type(*this),
ast::DecorationList{Builtin(Source{{12, 34}}, params.builtin)});
switch (params.stage) {
case ast::PipelineStage::kVertex:
Func("main", {input}, ty.vec4<f32>(), {Return(p)},
{Stage(ast::PipelineStage::kVertex)},
{Builtin(Source{{12, 34}}, ast::Builtin::kPosition)});
break;
case ast::PipelineStage::kFragment:
Func("main", {input}, ty.void_(), {},
{Stage(ast::PipelineStage::kFragment)}, {});
break;
case ast::PipelineStage::kCompute:
Func("main", {input}, ty.void_(), {},
ast::DecorationList{Stage(ast::PipelineStage::kCompute),
WorkgroupSize(1)});
break;
default:
break;
}
if (params.is_valid) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
std::stringstream err;
err << "12:34 error: builtin(" << params.builtin << ")";
err << " cannot be used in input of " << params.stage << " pipeline stage";
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), err.str());
}
}
INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
ResolverBuiltinsStageTest,
testing::ValuesIn(cases));
TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInput_Fail) {
// [[stage(fragment)]]
// fn fs_main(
// [[builtin(kFragDepth)]] fd: f32,
// ) -> [[location(0)]] f32 { return 1.0; }
auto* fd = Param(
"fd", ty.f32(),
ast::DecorationList{Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)});
Func("fs_main", ast::VariableList{fd}, ty.f32(), {Return(1.0f)},
ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
{Location(0)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: builtin(frag_depth) cannot be used in input of "
"fragment pipeline stage");
}
TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInputStruct_Fail) {
// Struct MyInputs {
// [[builtin(front_facing)]] ff: bool;
// };
// [[stage(fragment)]]
// fn fragShader(arg: MyInputs) -> [[location(0)]] f32 { return 1.0; }
auto* s = Structure(
"MyInputs", {Member("frag_depth", ty.f32(),
ast::DecorationList{Builtin(
Source{{12, 34}}, ast::Builtin::kFragDepth)})});
Func("fragShader", {Param("arg", ty.Of(s))}, ty.f32(), {Return(1.0f)},
{Stage(ast::PipelineStage::kFragment)}, {Location(0)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: builtin(frag_depth) cannot be used in input of fragment "
"pipeline stage\nnote: while analysing entry point fragShader");
}
} // namespace TypeTemp
TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_Struct_Fail) { TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_Struct_Fail) {
// struct MyInputs { // struct MyInputs {
@ -170,15 +370,12 @@ TEST_F(ResolverBuiltinsValidationTest, PositionIsNotF32_Fail) {
TEST_F(ResolverBuiltinsValidationTest, FragDepthIsNotF32_Fail) { TEST_F(ResolverBuiltinsValidationTest, FragDepthIsNotF32_Fail) {
// [[stage(fragment)]] // [[stage(fragment)]]
// fn fs_main( // fn fs_main() -> [[builtin(kFragDepth)]] f32 { var fd: i32; return fd; }
// [[builtin(kFragDepth)]] fd: f32, auto* fd = Var("fd", ty.i32());
// ) -> [[location(0)]] f32 { return 1.0; } Func(
auto* fd = Param( "fs_main", {}, ty.i32(), {Decl(fd), Return(fd)},
"fd", ty.i32(),
ast::DecorationList{Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)});
Func("fs_main", ast::VariableList{fd}, ty.f32(), {Return(1.0f)},
ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
{Location(0)}); ast::DecorationList{Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)});
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: store type of builtin(frag_depth) must be 'f32'"); "12:34 error: store type of builtin(frag_depth) must be 'f32'");
@ -227,44 +424,43 @@ TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltin_Pass) {
// fn fs_main( // fn fs_main(
// [[builtin(kPosition)]] p: vec4<f32>, // [[builtin(kPosition)]] p: vec4<f32>,
// [[builtin(front_facing)]] ff: bool, // [[builtin(front_facing)]] ff: bool,
// [[builtin(frag_depth)]] fd: f32,
// [[builtin(sample_index)]] si: u32, // [[builtin(sample_index)]] si: u32,
// [[builtin(sample_mask)]] sm : u32 // [[builtin(sample_mask)]] sm : u32
// ) -> [[location(0)]] f32 { return 1.0; } // ) -> [[builtin(frag_depth)]] f32 { var fd: f32; return fd; }
auto* p = Param("p", ty.vec4<f32>(), auto* p = Param("p", ty.vec4<f32>(),
ast::DecorationList{Builtin(ast::Builtin::kPosition)}); ast::DecorationList{Builtin(ast::Builtin::kPosition)});
auto* ff = Param("ff", ty.bool_(), auto* ff = Param("ff", ty.bool_(),
ast::DecorationList{Builtin(ast::Builtin::kFrontFacing)}); ast::DecorationList{Builtin(ast::Builtin::kFrontFacing)});
auto* fd = Param("fd", ty.f32(),
ast::DecorationList{Builtin(ast::Builtin::kFragDepth)});
auto* si = Param("si", ty.u32(), auto* si = Param("si", ty.u32(),
ast::DecorationList{Builtin(ast::Builtin::kSampleIndex)}); ast::DecorationList{Builtin(ast::Builtin::kSampleIndex)});
auto* sm = Param("sm", ty.u32(), auto* sm = Param("sm", ty.u32(),
ast::DecorationList{Builtin(ast::Builtin::kSampleMask)}); ast::DecorationList{Builtin(ast::Builtin::kSampleMask)});
Func( auto* var_fd = Var("fd", ty.f32());
"fs_main", ast::VariableList{p, ff, fd, si, sm}, ty.f32(), {Return(1.0f)}, Func("fs_main", ast::VariableList{p, ff, si, sm}, ty.f32(),
ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, {Location(0)}); {Decl(var_fd), Return(var_fd)},
ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
ast::DecorationList{Builtin(ast::Builtin::kFragDepth)});
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_F(ResolverBuiltinsValidationTest, VertexBuiltin_Pass) { TEST_F(ResolverBuiltinsValidationTest, VertexBuiltin_Pass) {
// [[stage(vertex)]] // [[stage(vertex)]]
// fn main( // fn main(
// [[builtin(kVertexIndex)]] vi : u32, // [[builtin(vertex_index)]] vi : u32,
// [[builtin(kInstanceIndex)]] ii : u32, // [[builtin(instance_index)]] ii : u32,
// [[builtin(kPosition)]] p :vec4<f32> // ) -> [[builtin(position)]] vec4<f32> { var p :vec4<f32>; return p; }
// ) {}
auto* vi = Param("vi", ty.u32(), auto* vi = Param("vi", ty.u32(),
ast::DecorationList{ ast::DecorationList{
Builtin(Source{{12, 34}}, ast::Builtin::kVertexIndex)}); Builtin(Source{{12, 34}}, ast::Builtin::kVertexIndex)});
auto* p = Param("p", ty.vec4<f32>(),
ast::DecorationList{Builtin(ast::Builtin::kPosition)});
auto* ii = Param("ii", ty.u32(), auto* ii = Param("ii", ty.u32(),
ast::DecorationList{Builtin(Source{{12, 34}}, ast::DecorationList{Builtin(Source{{12, 34}},
ast::Builtin::kInstanceIndex)}); ast::Builtin::kInstanceIndex)});
Func("main", ast::VariableList{vi, ii, p}, ty.vec4<f32>(), auto* p = Var("p", ty.vec4<f32>());
Func("main", ast::VariableList{vi, ii}, ty.vec4<f32>(),
{ {
Return(Expr(p)), Decl(p),
Return(p),
}, },
ast::DecorationList{Stage(ast::PipelineStage::kVertex)}, ast::DecorationList{Stage(ast::PipelineStage::kVertex)},
ast::DecorationList{Builtin(ast::Builtin::kPosition)}); ast::DecorationList{Builtin(ast::Builtin::kPosition)});
@ -369,7 +565,6 @@ TEST_F(ResolverBuiltinsValidationTest,
TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltinStruct_Pass) { TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltinStruct_Pass) {
// Struct MyInputs { // Struct MyInputs {
// [[builtin(kPosition)]] p: vec4<f32>; // [[builtin(kPosition)]] p: vec4<f32>;
// [[builtin(front_facing)]] ff: bool;
// [[builtin(frag_depth)]] fd: f32; // [[builtin(frag_depth)]] fd: f32;
// [[builtin(sample_index)]] si: u32; // [[builtin(sample_index)]] si: u32;
// [[builtin(sample_mask)]] sm : u32;; // [[builtin(sample_mask)]] sm : u32;;
@ -383,8 +578,6 @@ TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltinStruct_Pass) {
ast::DecorationList{Builtin(ast::Builtin::kPosition)}), ast::DecorationList{Builtin(ast::Builtin::kPosition)}),
Member("front_facing", ty.bool_(), Member("front_facing", ty.bool_(),
ast::DecorationList{Builtin(ast::Builtin::kFrontFacing)}), ast::DecorationList{Builtin(ast::Builtin::kFrontFacing)}),
Member("frag_depth", ty.f32(),
ast::DecorationList{Builtin(ast::Builtin::kFragDepth)}),
Member("sample_index", ty.u32(), Member("sample_index", ty.u32(),
ast::DecorationList{Builtin(ast::Builtin::kSampleIndex)}), ast::DecorationList{Builtin(ast::Builtin::kSampleIndex)}),
Member("sample_mask", ty.u32(), Member("sample_mask", ty.u32(),
@ -1006,4 +1199,5 @@ INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
"pack2x16float")); "pack2x16float"));
} // namespace } // namespace
} // namespace resolver
} // namespace tint } // namespace tint

View File

@ -145,7 +145,8 @@ TEST_P(FunctionParameterDecorationTest, IsValid) {
} else { } else {
EXPECT_FALSE(r()->Resolve()) << r()->error(); EXPECT_FALSE(r()->Resolve()) << r()->error();
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"error: decoration is not valid for function parameters"); "error: decoration is not valid for non-entry point function "
"parameters");
} }
} }
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
@ -244,7 +245,8 @@ TEST_P(FunctionReturnTypeDecorationTest, IsValid) {
} else { } else {
EXPECT_FALSE(r()->Resolve()) << r()->error(); EXPECT_FALSE(r()->Resolve()) << r()->error();
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"error: decoration is not valid for function return types"); "error: decoration is not valid for non-entry point function "
"return types");
} }
} }
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(

View File

@ -289,16 +289,6 @@ TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Location) {
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Builtin) {
// [[stage(fragment)]]
// fn main([[builtin(frag_depth)]] param : f32) {}
auto* param = Param("param", ty.f32(), {Builtin(ast::Builtin::kFragDepth)});
Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
{Stage(ast::PipelineStage::kFragment)});
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Missing) { TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Missing) {
// [[stage(fragment)]] // [[stage(fragment)]]
// fn main(param : f32) {} // fn main(param : f32) {}
@ -313,10 +303,10 @@ TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Missing) {
TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Multiple) { TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Multiple) {
// [[stage(fragment)]] // [[stage(fragment)]]
// fn main([[location(0)]] [[builtin(vertex_index)]] param : u32) {} // fn main([[location(0)]] [[builtin(sample_index)]] param : u32) {}
auto* param = Param("param", ty.u32(), auto* param = Param("param", ty.u32(),
{Location(Source{{13, 43}}, 0), {Location(Source{{13, 43}}, 0),
Builtin(Source{{14, 52}}, ast::Builtin::kVertexIndex)}); Builtin(Source{{14, 52}}, ast::Builtin::kSampleIndex)});
Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
{Stage(ast::PipelineStage::kFragment)}); {Stage(ast::PipelineStage::kFragment)});

View File

@ -928,20 +928,15 @@ bool Resolver::ValidateFunctionParameter(const ast::Function* func,
for (auto* deco : info->declaration->decorations()) { for (auto* deco : info->declaration->decorations()) {
if (!func->IsEntryPoint() && !deco->Is<ast::InternalDecoration>()) { if (!func->IsEntryPoint() && !deco->Is<ast::InternalDecoration>()) {
AddError("decoration is not valid for function parameters", AddError(
"decoration is not valid for non-entry point function parameters",
deco->source()); deco->source());
return false; return false;
}
if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
if (!ValidateBuiltinDecoration(builtin, info->type)) {
return false;
}
} else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) { } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
if (!ValidateInterpolateDecoration(interpolate, info->type)) { if (!ValidateInterpolateDecoration(interpolate, info->type)) {
return false; return false;
} }
} else if (!deco->IsAnyOf<ast::LocationDecoration, } else if (!deco->IsAnyOf<ast::LocationDecoration, ast::BuiltinDecoration,
ast::InternalDecoration>() && ast::InternalDecoration>() &&
(IsValidationEnabled( (IsValidationEnabled(
info->declaration->decorations(), info->declaration->decorations(),
@ -989,10 +984,25 @@ bool Resolver::ValidateFunctionParameter(const ast::Function* func,
} }
bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type) { const sem::Type* storage_type,
const bool is_input) {
auto* type = storage_type->UnwrapRef(); auto* type = storage_type->UnwrapRef();
const auto stage = current_function_
? current_function_->declaration->pipeline_stage()
: ast::PipelineStage::kNone;
std::stringstream stage_name;
stage_name << stage;
bool is_stage_mismatch = false;
switch (deco->value()) { switch (deco->value()) {
case ast::Builtin::kPosition: case ast::Builtin::kPosition:
if (stage != ast::PipelineStage::kNone &&
!(stage == ast::PipelineStage::kFragment && is_input) &&
!(stage == ast::PipelineStage::kVertex && !is_input)) {
AddError(deco_to_str(deco) + " cannot be used in " +
(is_input ? "input of " : "output of ") +
stage_name.str() + " pipeline stage",
deco->source());
}
if (!(type->is_float_vector() && type->As<sem::Vector>()->size() == 4)) { if (!(type->is_float_vector() && type->As<sem::Vector>()->size() == 4)) {
AddError("store type of " + deco_to_str(deco) + " must be 'vec4<f32>'", AddError("store type of " + deco_to_str(deco) + " must be 'vec4<f32>'",
deco->source()); deco->source());
@ -1002,6 +1012,10 @@ bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
case ast::Builtin::kGlobalInvocationId: case ast::Builtin::kGlobalInvocationId:
case ast::Builtin::kLocalInvocationId: case ast::Builtin::kLocalInvocationId:
case ast::Builtin::kWorkgroupId: case ast::Builtin::kWorkgroupId:
if (stage != ast::PipelineStage::kNone &&
!(stage == ast::PipelineStage::kCompute && is_input)) {
is_stage_mismatch = true;
}
if (!(type->is_unsigned_integer_vector() && if (!(type->is_unsigned_integer_vector() &&
type->As<sem::Vector>()->size() == 3)) { type->As<sem::Vector>()->size() == 3)) {
AddError("store type of " + deco_to_str(deco) + " must be 'vec3<u32>'", AddError("store type of " + deco_to_str(deco) + " must be 'vec3<u32>'",
@ -1010,6 +1024,10 @@ bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
} }
break; break;
case ast::Builtin::kFragDepth: case ast::Builtin::kFragDepth:
if (stage != ast::PipelineStage::kNone &&
!(stage == ast::PipelineStage::kFragment && !is_input)) {
is_stage_mismatch = true;
}
if (!type->Is<sem::F32>()) { if (!type->Is<sem::F32>()) {
AddError("store type of " + deco_to_str(deco) + " must be 'f32'", AddError("store type of " + deco_to_str(deco) + " must be 'f32'",
deco->source()); deco->source());
@ -1017,6 +1035,10 @@ bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
} }
break; break;
case ast::Builtin::kFrontFacing: case ast::Builtin::kFrontFacing:
if (stage != ast::PipelineStage::kNone &&
!(stage == ast::PipelineStage::kFragment && is_input)) {
is_stage_mismatch = true;
}
if (!type->Is<sem::Bool>()) { if (!type->Is<sem::Bool>()) {
AddError("store type of " + deco_to_str(deco) + " must be 'bool'", AddError("store type of " + deco_to_str(deco) + " must be 'bool'",
deco->source()); deco->source());
@ -1024,10 +1046,44 @@ bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
} }
break; break;
case ast::Builtin::kLocalInvocationIndex: case ast::Builtin::kLocalInvocationIndex:
if (stage != ast::PipelineStage::kNone &&
!(stage == ast::PipelineStage::kCompute && is_input)) {
is_stage_mismatch = true;
}
if (!type->Is<sem::U32>()) {
AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
deco->source());
return false;
}
break;
case ast::Builtin::kVertexIndex: case ast::Builtin::kVertexIndex:
case ast::Builtin::kInstanceIndex: case ast::Builtin::kInstanceIndex:
if (stage != ast::PipelineStage::kNone &&
!(stage == ast::PipelineStage::kVertex && is_input)) {
is_stage_mismatch = true;
}
if (!type->Is<sem::U32>()) {
AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
deco->source());
return false;
}
break;
case ast::Builtin::kSampleMask: case ast::Builtin::kSampleMask:
if (stage != ast::PipelineStage::kNone &&
!(stage == ast::PipelineStage::kFragment)) {
is_stage_mismatch = true;
}
if (!type->Is<sem::U32>()) {
AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
deco->source());
return false;
}
break;
case ast::Builtin::kSampleIndex: case ast::Builtin::kSampleIndex:
if (stage != ast::PipelineStage::kNone &&
!(stage == ast::PipelineStage::kFragment && is_input)) {
is_stage_mismatch = true;
}
if (!type->Is<sem::U32>()) { if (!type->Is<sem::U32>()) {
AddError("store type of " + deco_to_str(deco) + " must be 'u32'", AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
deco->source()); deco->source());
@ -1037,6 +1093,15 @@ bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
default: default:
break; break;
} }
if (is_stage_mismatch) {
AddError(deco_to_str(deco) + " cannot be used in " +
(is_input ? "input of " : "output of ") + stage_name.str() +
" pipeline stage",
deco->source());
return false;
}
return true; return true;
} }
@ -1070,12 +1135,9 @@ bool Resolver::ValidateFunction(const ast::Function* func,
return false; return false;
} }
auto stage_deco_count = 0;
auto workgroup_deco_count = 0; auto workgroup_deco_count = 0;
for (auto* deco : func->decorations()) { for (auto* deco : func->decorations()) {
if (deco->Is<ast::StageDecoration>()) { if (deco->Is<ast::WorkgroupDecoration>()) {
stage_deco_count++;
} else if (deco->Is<ast::WorkgroupDecoration>()) {
workgroup_deco_count++; workgroup_deco_count++;
if (func->pipeline_stage() != ast::PipelineStage::kCompute) { if (func->pipeline_stage() != ast::PipelineStage::kCompute) {
AddError( AddError(
@ -1083,7 +1145,8 @@ bool Resolver::ValidateFunction(const ast::Function* func,
deco->source()); deco->source());
return false; return false;
} }
} else if (!deco->Is<ast::InternalDecoration>()) { } else if (!deco->IsAnyOf<ast::StageDecoration,
ast::InternalDecoration>()) {
AddError("decoration is not valid for functions", deco->source()); AddError("decoration is not valid for functions", deco->source());
return false; return false;
} }
@ -1119,20 +1182,24 @@ bool Resolver::ValidateFunction(const ast::Function* func,
for (auto* deco : func->return_type_decorations()) { for (auto* deco : func->return_type_decorations()) {
if (!func->IsEntryPoint()) { if (!func->IsEntryPoint()) {
AddError("decoration is not valid for function return types", AddError(
"decoration is not valid for non-entry point function return types",
deco->source()); deco->source());
return false; return false;
} }
if (auto* builtin = deco->As<ast::BuiltinDecoration>()) { if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
if (!ValidateBuiltinDecoration(builtin, info->return_type)) {
return false;
}
} else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
if (!ValidateInterpolateDecoration(interpolate, info->return_type)) { if (!ValidateInterpolateDecoration(interpolate, info->return_type)) {
return false; return false;
} }
} else if (!deco->Is<ast::LocationDecoration>()) { } else if (!deco->IsAnyOf<ast::LocationDecoration, ast::BuiltinDecoration,
ast::InternalDecoration>() &&
(IsValidationEnabled(
info->declaration->decorations(),
ast::DisabledValidation::kEntryPointParameter) &&
IsValidationEnabled(info->declaration->decorations(),
ast::DisabledValidation::
kIgnoreAtomicFunctionParameter))) {
AddError("decoration is not valid for entry point return types", AddError("decoration is not valid for entry point return types",
deco->source()); deco->source());
return false; return false;
@ -1192,6 +1259,12 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
} }
builtins.emplace(builtin->value()); builtins.emplace(builtin->value());
if (!ValidateBuiltinDecoration(builtin, ty,
/* is_input */ param_or_ret ==
ParamOrRetType::kParameter)) {
return false;
}
} else if (auto* location = deco->As<ast::LocationDecoration>()) { } else if (auto* location = deco->As<ast::LocationDecoration>()) {
if (pipeline_io_attribute) { if (pipeline_io_attribute) {
AddError("multiple entry point IO attributes", deco->source()); AddError("multiple entry point IO attributes", deco->source());
@ -1409,7 +1482,6 @@ bool Resolver::Function(ast::Function* func) {
return false; return false;
} }
// TODO(amaiorano): Validate parameter decorations
for (auto* deco : param->decorations()) { for (auto* deco : param->decorations()) {
Mark(deco); Mark(deco);
} }

View File

@ -273,7 +273,8 @@ class Resolver {
bool ValidateAtomicUses(); bool ValidateAtomicUses();
bool ValidateAssignment(const ast::AssignmentStatement* a); bool ValidateAssignment(const ast::AssignmentStatement* a);
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type); const sem::Type* storage_type,
const bool is_input = true);
bool ValidateCallStatement(ast::CallStatement* stmt); bool ValidateCallStatement(ast::CallStatement* stmt);
bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info); bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info); bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);

View File

@ -85,9 +85,8 @@ TEST_F(ResolverValidationTest, WorkgroupMemoryUsedInFragmentStage) {
// fn f2(){ dst = wg; } // fn f2(){ dst = wg; }
// fn f1() { f2(); } // fn f1() { f2(); }
// [[stage(fragment)]] // [[stage(fragment)]]
// fn f0() -> [[builtin(position)]] vec4<f32> { // fn f0() {
// f1(); // f1();
// return dst;
//} //}
Global(Source{{1, 2}}, "wg", ty.vec4<f32>(), ast::StorageClass::kWorkgroup); Global(Source{{1, 2}}, "wg", ty.vec4<f32>(), ast::StorageClass::kWorkgroup);
@ -97,10 +96,9 @@ TEST_F(ResolverValidationTest, WorkgroupMemoryUsedInFragmentStage) {
Func(Source{{5, 6}}, "f2", ast::VariableList{}, ty.void_(), {stmt}); Func(Source{{5, 6}}, "f2", ast::VariableList{}, ty.void_(), {stmt});
Func(Source{{7, 8}}, "f1", ast::VariableList{}, ty.void_(), Func(Source{{7, 8}}, "f1", ast::VariableList{}, ty.void_(),
{Ignore(Call("f2"))}); {Ignore(Call("f2"))});
Func(Source{{9, 10}}, "f0", ast::VariableList{}, ty.vec4<f32>(), Func(Source{{9, 10}}, "f0", ast::VariableList{}, ty.void_(),
{Ignore(Call("f1")), Return(Expr("dst"))}, {Ignore(Call("f1"))},
ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, ast::DecorationList{Stage(ast::PipelineStage::kFragment)});
ast::DecorationList{Builtin(ast::Builtin::kPosition)});
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ( EXPECT_EQ(

View File

@ -296,7 +296,7 @@ OpFunctionEnd
TEST_F(BuilderTest, SampleIndex_SampleRateShadingCapability) { TEST_F(BuilderTest, SampleIndex_SampleRateShadingCapability) {
Func("main", Func("main",
{Param("sample_index", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)})}, {Param("sample_index", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)})},
ty.void_(), {}, {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)}); ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)});
spirv::Builder& b = SanitizeAndBuild(); spirv::Builder& b = SanitizeAndBuild();