diff --git a/src/ast/function.cc b/src/ast/function.cc index 8f8ab730cd..d921e893e5 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -38,7 +38,7 @@ Function::Function(const Source& source, decorations_(std::move(decorations)), return_type_decorations_(std::move(return_type_decorations)) { for (auto* param : params_) { - TINT_ASSERT(param); + TINT_ASSERT(param && param->is_const()); } TINT_ASSERT(symbol_.IsValid()); TINT_ASSERT(return_type_); diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc index ccb9071fd9..a9bdf0e74b 100644 --- a/src/ast/function_test.cc +++ b/src/ast/function_test.cc @@ -26,7 +26,7 @@ using FunctionTest = TestHelper; TEST_F(FunctionTest, Creation) { VariableList params; - params.push_back(Var("var", ty.i32(), StorageClass::kNone)); + params.push_back(Param("var", ty.i32())); auto* var = params[0]; auto* f = Func("func", params, ty.void_(), StatementList{}, DecorationList{}); @@ -38,7 +38,7 @@ TEST_F(FunctionTest, Creation) { TEST_F(FunctionTest, Creation_WithSource) { VariableList params; - params.push_back(Var("var", ty.i32(), StorageClass::kNone)); + params.push_back(Param("var", ty.i32())); auto* f = Func(Source{Source::Location{20, 2}}, "func", params, ty.void_(), StatementList{}, DecorationList{}); @@ -71,7 +71,7 @@ TEST_F(FunctionTest, Assert_NullParam) { { ProgramBuilder b; VariableList params; - params.push_back(b.Var("var", b.ty.i32(), StorageClass::kNone)); + params.push_back(b.Param("var", b.ty.i32())); params.push_back(nullptr); b.Func("f", params, b.ty.void_(), StatementList{}, DecorationList{}); @@ -79,6 +79,18 @@ TEST_F(FunctionTest, Assert_NullParam) { "internal compiler error"); } +TEST_F(FunctionTest, Assert_NonConstParam) { + EXPECT_FATAL_FAILURE( + { + ProgramBuilder b; + VariableList params; + params.push_back(b.Var("var", b.ty.i32(), ast::StorageClass::kNone)); + + b.Func("f", params, b.ty.void_(), StatementList{}, DecorationList{}); + }, + "internal compiler error"); +} + TEST_F(FunctionTest, ToStr) { auto* f = Func("func", VariableList{}, ty.void_(), StatementList{ @@ -112,7 +124,7 @@ WorkgroupDecoration{2 4 6} TEST_F(FunctionTest, ToStr_WithParams) { VariableList params; - params.push_back(Var("var", ty.i32(), StorageClass::kNone)); + params.push_back(Param("var", ty.i32())); auto* f = Func("func", params, ty.void_(), StatementList{ @@ -122,7 +134,7 @@ TEST_F(FunctionTest, ToStr_WithParams) { EXPECT_EQ(str(f), R"(Function func -> __void ( - Variable{ + VariableConst{ var none __i32 @@ -142,8 +154,8 @@ TEST_F(FunctionTest, TypeName) { TEST_F(FunctionTest, TypeName_WithParams) { VariableList params; - params.push_back(Var("var1", ty.i32(), StorageClass::kNone)); - params.push_back(Var("var2", ty.f32(), StorageClass::kNone)); + params.push_back(Param("var1", ty.i32())); + params.push_back(Param("var2", ty.f32())); auto* f = Func("func", params, ty.void_(), StatementList{}, DecorationList{}); EXPECT_EQ(f->type_name(), "__func__void__i32__f32"); diff --git a/src/program_builder.h b/src/program_builder.h index ae197ccac4..6c728cc4b5 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -815,7 +815,7 @@ class ProgramBuilder { /// @param type the variable type /// @param constructor optional constructor expression /// @param decorations optional variable decorations - /// @returns a constant `ast::Variable` with the given name, storage and type + /// @returns a constant `ast::Variable` with the given name and type template ast::Variable* Const(NAME&& name, type::Type* type, @@ -831,7 +831,7 @@ class ProgramBuilder { /// @param type the variable type /// @param constructor optional constructor expression /// @param decorations optional variable decorations - /// @returns a constant `ast::Variable` with the given name, storage and type + /// @returns a constant `ast::Variable` with the given name and type template ast::Variable* Const(const Source& source, NAME&& name, @@ -843,6 +843,34 @@ class ProgramBuilder { constructor, decorations); } + /// @param name the parameter name + /// @param type the parameter type + /// @param decorations optional parameter decorations + /// @returns a constant `ast::Variable` with the given name and type + template + ast::Variable* Param(NAME&& name, + type::Type* type, + ast::DecorationList decorations = {}) { + return create(Sym(std::forward(name)), + ast::StorageClass::kNone, type, true, nullptr, + decorations); + } + + /// @param source the parameter source + /// @param name the parameter name + /// @param type the parameter type + /// @param decorations optional parameter decorations + /// @returns a constant `ast::Variable` with the given name and type + template + ast::Variable* Param(const Source& source, + NAME&& name, + type::Type* type, + ast::DecorationList decorations = {}) { + return create(source, Sym(std::forward(name)), + ast::StorageClass::kNone, type, true, nullptr, + decorations); + } + /// @param args the arguments to pass to Var() /// @returns a `ast::Variable` constructed by calling Var() with the arguments /// of `args`, which is automatically registered as a global variable with the diff --git a/src/resolver/entry_point_validation_test.cc b/src/resolver/entry_point_validation_test.cc index 067186b51f..4b31d1a314 100644 --- a/src/resolver/entry_point_validation_test.cc +++ b/src/resolver/entry_point_validation_test.cc @@ -253,7 +253,7 @@ TEST_F(ResolverEntryPointValidationTest, ReturnType_Struct_DuplicateLocation) { TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Location) { // [[stage(fragment)]] // fn main([[location(0)]] param : f32) {} - auto* param = Const("param", ty.f32(), nullptr, {Location(0)}); + auto* param = Param("param", ty.f32(), {Location(0)}); Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); @@ -263,8 +263,8 @@ TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Location) { TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Builtin) { // [[stage(fragment)]] // fn main([[builtin(frag_depth)]] param : f32) {} - auto* param = Const("param", ty.vec4(), nullptr, - {Builtin(ast::Builtin::kFragDepth)}); + auto* param = + Param("param", ty.vec4(), {Builtin(ast::Builtin::kFragDepth)}); Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); @@ -274,7 +274,7 @@ TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Builtin) { TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Missing) { // [[stage(fragment)]] // fn main(param : f32) {} - auto* param = Const(Source{{13, 43}}, "param", ty.vec4(), nullptr); + auto* param = Param(Source{{13, 43}}, "param", ty.vec4()); Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); @@ -286,7 +286,7 @@ TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Missing) { TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Multiple) { // [[stage(fragment)]] // fn main([[location(0)]] [[builtin(vertex_index)]] param : u32) {} - auto* param = Const("param", ty.u32(), nullptr, + auto* param = Param("param", ty.u32(), {Location(Source{{13, 43}}, 0), Builtin(Source{{14, 52}}, ast::Builtin::kVertexIndex)}); Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, @@ -303,7 +303,7 @@ TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Struct) { // [[stage(fragment)]] // fn main([[location(0)]] param : Input) {} auto* input = Structure("Input", {}); - auto* param = Const("param", input, nullptr, {Location(Source{{13, 43}}, 0)}); + auto* param = Param("param", input, {Location(Source{{13, 43}}, 0)}); Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); @@ -324,7 +324,7 @@ TEST_F(ResolverEntryPointValidationTest, Parameter_Struct_Valid) { auto* input = Structure( "Input", {Member("a", ty.f32(), {Location(0)}), Member("b", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)})}); - auto* param = Const("param", input); + auto* param = Param("param", input); Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); @@ -343,7 +343,7 @@ TEST_F(ResolverEntryPointValidationTest, {Member("a", ty.f32(), {Location(Source{{13, 43}}, 0), Builtin(Source{{14, 52}}, ast::Builtin::kSampleIndex)})}); - auto* param = Const("param", input); + auto* param = Param("param", input); Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); @@ -364,7 +364,7 @@ TEST_F(ResolverEntryPointValidationTest, auto* input = Structure( "Input", {Member(Source{{13, 43}}, "a", ty.f32(), {Location(0)}), Member(Source{{14, 52}}, "b", ty.f32(), {})}); - auto* param = Const("param", input); + auto* param = Param("param", input); Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); @@ -386,7 +386,7 @@ TEST_F(ResolverEntryPointValidationTest, Parameter_Struct_NestedStruct) { "Inner", {Member(Source{{13, 43}}, "a", ty.f32(), {Location(0)})}); auto* input = Structure("Input", {Member(Source{{14, 52}}, "a", inner, {Location(0)})}); - auto* param = Const("param", input); + auto* param = Param("param", input); Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); @@ -408,7 +408,7 @@ TEST_F(ResolverEntryPointValidationTest, Parameter_Struct_RuntimeArray) { "Input", {Member(Source{{13, 43}}, "a", ty.array(), {Location(0)})}, {create()}); - auto* param = Const("param", input); + auto* param = Param("param", input); Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); @@ -423,10 +423,10 @@ TEST_F(ResolverEntryPointValidationTest, Parameter_DuplicateBuiltins) { // [[stage(fragment)]] // fn main([[builtin(sample_index)]] param_a : u32, // [[builtin(sample_index)]] param_b : u32) {} - auto* param_a = Const("param_a", ty.u32(), nullptr, - {Builtin(ast::Builtin::kSampleIndex)}); - auto* param_b = Const("param_b", ty.u32(), nullptr, - {Builtin(ast::Builtin::kSampleIndex)}); + auto* param_a = + Param("param_a", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)}); + auto* param_b = + Param("param_b", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)}); Func(Source{{12, 34}}, "main", {param_a, param_b}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); @@ -450,8 +450,8 @@ TEST_F(ResolverEntryPointValidationTest, Parameter_Struct_DuplicateBuiltins) { "InputA", {Member("a", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)})}); auto* input_b = Structure( "InputB", {Member("a", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)})}); - auto* param_a = Const("param_a", input_a); - auto* param_b = Const("param_b", input_b); + auto* param_a = Param("param_a", input_a); + auto* param_b = Param("param_b", input_b); Func(Source{{12, 34}}, "main", {param_a, param_b}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); @@ -466,8 +466,8 @@ TEST_F(ResolverEntryPointValidationTest, Parameter_DuplicateLocation) { // [[stage(fragment)]] // fn main([[location(1)]] param_a : f32, // [[location(1)]] param_b : f32) {} - auto* param_a = Const("param_a", ty.u32(), nullptr, {Location(1)}); - auto* param_b = Const("param_b", ty.u32(), nullptr, {Location(1)}); + auto* param_a = Param("param_a", ty.u32(), {Location(1)}); + auto* param_b = Param("param_b", ty.u32(), {Location(1)}); Func(Source{{12, 34}}, "main", {param_a, param_b}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); @@ -488,8 +488,8 @@ TEST_F(ResolverEntryPointValidationTest, Parameter_Struct_DuplicateLocation) { // fn main(param_a : InputA, param_b : InputB) {} auto* input_a = Structure("InputA", {Member("a", ty.f32(), {Location(1)})}); auto* input_b = Structure("InputB", {Member("a", ty.f32(), {Location(1)})}); - auto* param_a = Const("param_a", input_a); - auto* param_b = Const("param_b", input_b); + auto* param_a = Param("param_a", input_a); + auto* param_b = Param("param_b", input_b); Func(Source{{12, 34}}, "main", {param_a, param_b}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); diff --git a/src/resolver/function_validation_test.cc b/src/resolver/function_validation_test.cc index 76ae8f54e1..daf14ad66b 100644 --- a/src/resolver/function_validation_test.cc +++ b/src/resolver/function_validation_test.cc @@ -248,7 +248,7 @@ TEST_F(ResolverFunctionValidationTest, FunctionVarInitWithParam) { // var baz : f32 = bar; // } - auto* bar = Var("bar", ty.f32(), ast::StorageClass::kFunction); + auto* bar = Param("bar", ty.f32()); auto* baz = Var("baz", ty.f32(), ast::StorageClass::kFunction, Expr("bar")); Func("foo", ast::VariableList{bar}, ty.void_(), ast::StatementList{Decl(baz)}, @@ -262,7 +262,7 @@ TEST_F(ResolverFunctionValidationTest, FunctionConstInitWithParam) { // let baz : f32 = bar; // } - auto* bar = Var("bar", ty.f32(), ast::StorageClass::kFunction); + auto* bar = Param("bar", ty.f32()); auto* baz = Const("baz", ty.f32(), Expr("bar")); Func("foo", ast::VariableList{bar}, ty.void_(), ast::StatementList{Decl(baz)}, diff --git a/src/resolver/struct_pipeline_stage_use_test.cc b/src/resolver/struct_pipeline_stage_use_test.cc index 6804e47589..d4c615399f 100644 --- a/src/resolver/struct_pipeline_stage_use_test.cc +++ b/src/resolver/struct_pipeline_stage_use_test.cc @@ -42,7 +42,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsNonEntryPointParam) { auto* s = Structure( "S", {Member("a", ty.f32(), {create(0)})}); - Func("foo", {Const("param", s)}, ty.void_(), {}, {}); + Func("foo", {Param("param", s)}, ty.void_(), {}, {}); ASSERT_TRUE(r()->Resolve()) << r()->error(); @@ -68,7 +68,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsVertexShaderParam) { auto* s = Structure( "S", {Member("a", ty.f32(), {create(0)})}); - Func("main", {Const("param", s)}, ty.void_(), {}, + Func("main", {Param("param", s)}, ty.void_(), {}, {create(ast::PipelineStage::kVertex)}); ASSERT_TRUE(r()->Resolve()) << r()->error(); @@ -99,7 +99,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderParam) { auto* s = Structure( "S", {Member("a", ty.f32(), {create(0)})}); - Func("main", {Const("param", s)}, ty.void_(), {}, + Func("main", {Param("param", s)}, ty.void_(), {}, {create(ast::PipelineStage::kFragment)}); ASSERT_TRUE(r()->Resolve()) << r()->error(); @@ -132,7 +132,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsComputeShaderParam) { {create( ast::Builtin::kLocalInvocationIndex)})}); - Func("main", {Const("param", s)}, ty.void_(), {}, + Func("main", {Param("param", s)}, ty.void_(), {}, {create(ast::PipelineStage::kCompute)}); ASSERT_TRUE(r()->Resolve()) << r()->error(); @@ -148,10 +148,10 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedMultipleStages) { auto* s = Structure( "S", {Member("a", ty.f32(), {create(0)})}); - Func("vert_main", {Const("param", s)}, s, {Return(Construct(s, Expr(0.f)))}, + Func("vert_main", {Param("param", s)}, s, {Return(Construct(s, Expr(0.f)))}, {create(ast::PipelineStage::kVertex)}); - Func("frag_main", {Const("param", s)}, ty.void_(), {}, + Func("frag_main", {Param("param", s)}, ty.void_(), {}, {create(ast::PipelineStage::kFragment)}); ASSERT_TRUE(r()->Resolve()) << r()->error(); @@ -170,7 +170,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamViaAlias) { "S", {Member("a", ty.f32(), {create(0)})}); auto* s_alias = ty.alias("S_alias", s); - Func("main", {Const("param", s_alias)}, ty.void_(), {}, + Func("main", {Param("param", s_alias)}, ty.void_(), {}, {create(ast::PipelineStage::kFragment)}); ASSERT_TRUE(r()->Resolve()) << r()->error(); diff --git a/src/resolver/struct_storage_class_use_test.cc b/src/resolver/struct_storage_class_use_test.cc index 8213758100..ac1b8f5927 100644 --- a/src/resolver/struct_storage_class_use_test.cc +++ b/src/resolver/struct_storage_class_use_test.cc @@ -39,7 +39,7 @@ TEST_F(ResolverStorageClassUseTest, UnreachableStruct) { TEST_F(ResolverStorageClassUseTest, StructReachableFromParameter) { auto* s = Structure("S", {Member("a", ty.f32())}); - Func("f", {Var("param", s, ast::StorageClass::kNone)}, ty.void_(), {}, {}); + Func("f", {Param("param", s)}, ty.void_(), {}, {}); ASSERT_TRUE(r()->Resolve()) << r()->error(); diff --git a/src/resolver/type_validation_test.cc b/src/resolver/type_validation_test.cc index a5dd6bbb1a..ec2113c033 100644 --- a/src/resolver/type_validation_test.cc +++ b/src/resolver/type_validation_test.cc @@ -388,8 +388,7 @@ TEST_F(ResolverTypeValidationTest, RuntimeArrayAsParameter_Fail) { // fn func(a : array) {} // [[stage(vertex)]] fn main() {} - auto* param = - Var(Source{{12, 34}}, "a", ty.array(), ast::StorageClass::kNone); + auto* param = Param(Source{{12, 34}}, "a", ty.array()); Func("func", ast::VariableList{param}, ty.void_(), ast::StatementList{ diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc index b263767d09..db1b5a7dfe 100644 --- a/src/transform/canonicalize_entry_point_io.cc +++ b/src/transform/canonicalize_entry_point_io.cc @@ -137,8 +137,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in, ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, in_struct); // Create a new function parameter using this struct type. - auto* struct_param = ctx.dst->Var(new_struct_param_symbol, in_struct, - ast::StorageClass::kNone); + auto* struct_param = ctx.dst->Param(new_struct_param_symbol, in_struct); new_parameters.push_back(struct_param); } diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc index bb4f7ae865..6945cbfd19 100644 --- a/src/transform/spirv.cc +++ b/src/transform/spirv.cc @@ -159,7 +159,7 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const { // Create a function that writes a return value to all output variables. auto* store_value = - ctx.dst->Const(store_value_symbol, ctx.Clone(func->return_type())); + ctx.dst->Param(store_value_symbol, ctx.Clone(func->return_type())); auto return_func_symbol = ctx.dst->Symbols().New(); auto* return_func = ctx.dst->create( return_func_symbol, ast::VariableList{store_value}, diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index bfd7f8c474..fcf000adf1 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -68,9 +68,7 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Name_Collision) { } TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithParams) { - Func("my_func", - ast::VariableList{Var("a", ty.f32(), ast::StorageClass::kNone), - Var("b", ty.i32(), ast::StorageClass::kNone)}, + Func("my_func", ast::VariableList{Param("a", ty.f32()), Param("b", ty.i32())}, ty.void_(), ast::StatementList{ create(), @@ -114,8 +112,7 @@ TEST_F(HlslGeneratorImplTest_Function, // fn frag_main([[location(0)]] foo : f32) -> [[location(1)]] f32 { // return foo; // } - auto* foo_in = - Const("foo", ty.f32(), nullptr, {create(0)}); + auto* foo_in = Param("foo", ty.f32(), {create(0)}); Func("frag_main", ast::VariableList{foo_in}, ty.f32(), {create(Expr("foo"))}, {create(ast::PipelineStage::kFragment)}, @@ -148,7 +145,7 @@ TEST_F(HlslGeneratorImplTest_Function, // return coord.x; // } auto* coord_in = - Const("coord", ty.vec4(), nullptr, + Param("coord", ty.vec4(), {create(ast::Builtin::kFragCoord)}); Func("frag_main", ast::VariableList{coord_in}, ty.f32(), {create(MemberAccessor("coord", "x"))}, @@ -199,7 +196,7 @@ TEST_F(HlslGeneratorImplTest_Function, Construct(interface_struct, Expr(0.5f), Expr(0.25f)))}, {create(ast::PipelineStage::kVertex)}); - Func("frag_main", {Const("colors", interface_struct)}, ty.void_(), + Func("frag_main", {Param("colors", interface_struct)}, ty.void_(), { WrapInStatement( Const("r", ty.f32(), MemberAccessor(Expr("colors"), "col1"))), @@ -261,7 +258,7 @@ TEST_F(HlslGeneratorImplTest_Function, {Member("pos", ty.vec4(), {create(ast::Builtin::kPosition)})}); - Func("foo", {Const("x", ty.f32())}, vertex_output_struct, + Func("foo", {Param("x", ty.f32())}, vertex_output_struct, {create(Construct( vertex_output_struct, Construct(ty.vec4(), Expr("x"), Expr("x"), Expr("x"), Expr(1.f))))}, @@ -570,9 +567,7 @@ TEST_F( create(0), }); - Func("sub_func", - ast::VariableList{Var("param", ty.f32(), ast::StorageClass::kNone)}, - ty.f32(), + Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(), ast::StatementList{ create(Expr("bar"), Expr("foo")), create(Expr("val"), Expr("param")), @@ -626,9 +621,7 @@ TEST_F(HlslGeneratorImplTest_Function, create(ast::Builtin::kFragDepth), }); - Func("sub_func", - ast::VariableList{Var("param", ty.f32(), ast::StorageClass::kFunction)}, - ty.f32(), + Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(), ast::StatementList{ create(Expr("param")), }, @@ -680,9 +673,7 @@ TEST_F( create(ast::Builtin::kFragDepth), }); - Func("sub_func", - ast::VariableList{Var("param", ty.f32(), ast::StorageClass::kNone)}, - ty.f32(), + Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(), ast::StatementList{ create(Expr("depth"), MemberAccessor("coord", "x")), @@ -735,9 +726,7 @@ TEST_F(HlslGeneratorImplTest_Function, create(1), }); - Func("sub_func", - ast::VariableList{Var("param", ty.f32(), ast::StorageClass::kFunction)}, - ty.f32(), + Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(), ast::StatementList{ create(MemberAccessor("coord", "x")), }, @@ -785,9 +774,7 @@ TEST_F(HlslGeneratorImplTest_Function, create(1), }); - Func("sub_func", - ast::VariableList{Var("param", ty.f32(), ast::StorageClass::kFunction)}, - ty.f32(), + Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(), ast::StatementList{ create(MemberAccessor("coord", "x")), }, @@ -932,13 +919,10 @@ void main() { } TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) { - Func( - "my_func", - ast::VariableList{Var("a", ty.array(), ast::StorageClass::kNone)}, - ty.void_(), - ast::StatementList{ - create(), - }); + Func("my_func", ast::VariableList{Param("a", ty.array())}, ty.void_(), + ast::StatementList{ + create(), + }); GeneratorImpl& gen = Build(); diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index c29db72745..a7660616fa 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc @@ -49,8 +49,8 @@ using namespace metal; TEST_F(MslGeneratorImplTest, Emit_Function_WithParams) { ast::VariableList params; - params.push_back(Var("a", ty.f32(), ast::StorageClass::kNone)); - params.push_back(Var("b", ty.i32(), ast::StorageClass::kNone)); + params.push_back(Param("a", ty.f32())); + params.push_back(Param("b", ty.i32())); Func("my_func", params, ty.void_(), ast::StatementList{ @@ -96,8 +96,7 @@ TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_WithInOutVars) { // fn frag_main([[location(0)]] foo : f32) -> [[location(1)]] f32 { // return foo; // } - auto* foo_in = - Const("foo", ty.f32(), nullptr, {create(0)}); + auto* foo_in = Param("foo", ty.f32(), {create(0)}); Func("frag_main", ast::VariableList{foo_in}, ty.f32(), {create(Expr("foo"))}, {create(ast::PipelineStage::kFragment)}, @@ -129,7 +128,7 @@ TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_WithInOut_Builtins) { // return coord.x; // } auto* coord_in = - Const("coord", ty.vec4(), nullptr, + Param("coord", ty.vec4(), {create(ast::Builtin::kFragCoord)}); Func("frag_main", ast::VariableList{coord_in}, ty.f32(), {create(MemberAccessor("coord", "x"))}, @@ -180,7 +179,7 @@ TEST_F(MslGeneratorImplTest, Construct(interface_struct, Expr(0.5f), Expr(0.25f)))}, {create(ast::PipelineStage::kVertex)}); - Func("frag_main", {Const("colors", interface_struct)}, ty.void_(), + Func("frag_main", {Param("colors", interface_struct)}, ty.void_(), { WrapInStatement( Const("r", ty.f32(), MemberAccessor(Expr("colors"), "col1"))), @@ -242,7 +241,7 @@ TEST_F(MslGeneratorImplTest, {Member("pos", ty.vec4(), {create(ast::Builtin::kPosition)})}); - Func("foo", {Const("x", ty.f32())}, vertex_output_struct, + Func("foo", {Param("x", ty.f32())}, vertex_output_struct, {create(Construct( vertex_output_struct, Construct(ty.vec4(), Expr("x"), Expr("x"), Expr("x"), Expr(1.f))))}, @@ -393,7 +392,7 @@ TEST_F( ast::DecorationList{create(0)}); ast::VariableList params; - params.push_back(Var("param", ty.f32(), ast::StorageClass::kNone)); + params.push_back(Param("param", ty.f32())); auto body = ast::StatementList{ create(Expr("bar"), Expr("foo")), @@ -450,7 +449,7 @@ TEST_F(MslGeneratorImplTest, create(ast::Builtin::kFragDepth)}); ast::VariableList params; - params.push_back(Var("param", ty.f32(), ast::StorageClass::kFunction)); + params.push_back(Param("param", ty.f32())); Func("sub_func", params, ty.f32(), ast::StatementList{ @@ -504,7 +503,7 @@ TEST_F( create(ast::Builtin::kFragDepth)}); ast::VariableList params; - params.push_back(Var("param", ty.f32(), ast::StorageClass::kNone)); + params.push_back(Param("param", ty.f32())); auto body = ast::StatementList{ create(Expr("depth"), @@ -555,7 +554,7 @@ TEST_F(MslGeneratorImplTest, create(1)}); ast::VariableList params; - params.push_back(Var("param", ty.f32(), ast::StorageClass::kFunction)); + params.push_back(Param("param", ty.f32())); auto body = ast::StatementList{ create(MemberAccessor("coord", "x")), @@ -610,7 +609,7 @@ TEST_F(MslGeneratorImplTest, create(1)}); ast::VariableList params; - params.push_back(Var("param", ty.f32(), ast::StorageClass::kFunction)); + params.push_back(Param("param", ty.f32())); auto body = ast::StatementList{ create(MemberAccessor("coord", "b"))}; @@ -666,7 +665,7 @@ TEST_F(MslGeneratorImplTest, create(1)}); ast::VariableList params; - params.push_back(Var("param", ty.f32(), ast::StorageClass::kFunction)); + params.push_back(Param("param", ty.f32())); auto body = ast::StatementList{ create(MemberAccessor("coord", "b"))}; @@ -758,7 +757,7 @@ fragment ep_1_out ep_1() { TEST_F(MslGeneratorImplTest, Emit_Function_WithArrayParams) { ast::VariableList params; - params.push_back(Var("a", ty.array(), ast::StorageClass::kNone)); + params.push_back(Param("a", ty.array())); Func("my_func", params, ty.void_(), ast::StatementList{ diff --git a/src/writer/spirv/builder_call_test.cc b/src/writer/spirv/builder_call_test.cc index 3f5311003e..253eef90e4 100644 --- a/src/writer/spirv/builder_call_test.cc +++ b/src/writer/spirv/builder_call_test.cc @@ -26,8 +26,8 @@ using BuilderTest = TestHelper; TEST_F(BuilderTest, Expression_Call) { ast::VariableList func_params; - func_params.push_back(Var("a", ty.f32(), ast::StorageClass::kFunction)); - func_params.push_back(Var("b", ty.f32(), ast::StorageClass::kFunction)); + func_params.push_back(Param("a", ty.f32())); + func_params.push_back(Param("b", ty.f32())); auto* a_func = Func("a_func", func_params, ty.f32(), @@ -46,28 +46,26 @@ TEST_F(BuilderTest, Expression_Call) { ASSERT_TRUE(b.GenerateFunction(a_func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(func)) << b.error(); - EXPECT_EQ(b.GenerateCallExpression(expr), 14u) << b.error(); + EXPECT_EQ(b.GenerateCallExpression(expr), 12u) << b.error(); EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func" OpName %4 "a" OpName %5 "b" -OpName %12 "main" +OpName %10 "main" %2 = OpTypeFloat 32 %1 = OpTypeFunction %2 %2 %2 -%11 = OpTypeVoid -%10 = OpTypeFunction %11 -%15 = OpConstant %2 1 +%9 = OpTypeVoid +%8 = OpTypeFunction %9 +%13 = OpConstant %2 1 %3 = OpFunction %2 None %1 %4 = OpFunctionParameter %2 %5 = OpFunctionParameter %2 %6 = OpLabel -%7 = OpLoad %2 %4 -%8 = OpLoad %2 %5 -%9 = OpFAdd %2 %7 %8 -OpReturnValue %9 +%7 = OpFAdd %2 %4 %5 +OpReturnValue %7 OpFunctionEnd -%12 = OpFunction %11 None %10 -%13 = OpLabel -%14 = OpFunctionCall %2 %3 %15 %15 +%10 = OpFunction %9 None %8 +%11 = OpLabel +%12 = OpFunctionCall %2 %3 %13 %13 OpReturn OpFunctionEnd )"); @@ -75,8 +73,8 @@ OpFunctionEnd TEST_F(BuilderTest, Statement_Call) { ast::VariableList func_params; - func_params.push_back(Var("a", ty.f32(), ast::StorageClass::kFunction)); - func_params.push_back(Var("b", ty.f32(), ast::StorageClass::kFunction)); + func_params.push_back(Param("a", ty.f32())); + func_params.push_back(Param("b", ty.f32())); auto* a_func = Func("a_func", func_params, ty.f32(), @@ -99,24 +97,22 @@ TEST_F(BuilderTest, Statement_Call) { EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func" OpName %4 "a" OpName %5 "b" -OpName %12 "main" +OpName %10 "main" %2 = OpTypeFloat 32 %1 = OpTypeFunction %2 %2 %2 -%11 = OpTypeVoid -%10 = OpTypeFunction %11 -%15 = OpConstant %2 1 +%9 = OpTypeVoid +%8 = OpTypeFunction %9 +%13 = OpConstant %2 1 %3 = OpFunction %2 None %1 %4 = OpFunctionParameter %2 %5 = OpFunctionParameter %2 %6 = OpLabel -%7 = OpLoad %2 %4 -%8 = OpLoad %2 %5 -%9 = OpFAdd %2 %7 %8 -OpReturnValue %9 +%7 = OpFAdd %2 %4 %5 +OpReturnValue %7 OpFunctionEnd -%12 = OpFunction %11 None %10 -%13 = OpLabel -%14 = OpFunctionCall %2 %3 %15 %15 +%10 = OpFunction %9 None %8 +%11 = OpLabel +%12 = OpFunctionCall %2 %3 %13 %13 OpReturn OpFunctionEnd )"); diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc index a4792e7b68..ea7bad6f4b 100644 --- a/src/writer/spirv/builder_entry_point_test.cc +++ b/src/writer/spirv/builder_entry_point_test.cc @@ -44,10 +44,10 @@ TEST_F(BuilderTest, EntryPoint_Parameters) { // } auto* f32 = ty.f32(); auto* vec4 = ty.vec4(); - auto* coord = Var("coord", vec4, ast::StorageClass::kInput, nullptr, - {create(ast::Builtin::kFragCoord)}); - auto* loc1 = Var("loc1", f32, ast::StorageClass::kInput, nullptr, - {create(1u)}); + auto* coord = + Param("coord", vec4, + {create(ast::Builtin::kFragCoord)}); + auto* loc1 = Param("loc1", f32, {create(1u)}); auto* mul = Mul(Expr(MemberAccessor("coord", "x")), Expr("loc1")); auto* col = Var("col", f32, ast::StorageClass::kFunction, mul, {}); Func("frag_main", ast::VariableList{coord, loc1}, ty.void_(), @@ -109,8 +109,7 @@ TEST_F(BuilderTest, EntryPoint_ReturnValue) { // } auto* f32 = ty.f32(); auto* u32 = ty.u32(); - auto* loc_in = Var("loc_in", u32, ast::StorageClass::kFunction, nullptr, - {create(0)}); + auto* loc_in = Param("loc_in", u32, {create(0)}); auto* cond = create(ast::BinaryOp::kGreaterThan, Expr("loc_in"), Expr(10u)); Func("frag_main", ast::VariableList{loc_in}, f32, @@ -203,8 +202,7 @@ TEST_F(BuilderTest, EntryPoint_SharedStruct) { {create(vert_retval)}, {create(ast::PipelineStage::kVertex)}); - auto* frag_inputs = - Var("inputs", interface, ast::StorageClass::kFunction, nullptr); + auto* frag_inputs = Param("inputs", interface); Func("frag_main", ast::VariableList{frag_inputs}, ty.f32(), {create(MemberAccessor(Expr("inputs"), "value"))}, {create(ast::PipelineStage::kFragment)}, diff --git a/src/writer/spirv/builder_function_test.cc b/src/writer/spirv/builder_function_test.cc index 66d93dcd8b..be78a581e6 100644 --- a/src/writer/spirv/builder_function_test.cc +++ b/src/writer/spirv/builder_function_test.cc @@ -113,8 +113,7 @@ OpFunctionEnd } TEST_F(BuilderTest, Function_WithParams) { - ast::VariableList params = {Var("a", ty.f32(), ast::StorageClass::kFunction), - Var("b", ty.i32(), ast::StorageClass::kFunction)}; + ast::VariableList params = {Param("a", ty.f32()), Param("b", ty.i32())}; Func("a_func", params, ty.f32(), ast::StatementList{create(Expr("a"))}, @@ -134,8 +133,7 @@ OpName %6 "b" %5 = OpFunctionParameter %2 %6 = OpFunctionParameter %3 %7 = OpLabel -%8 = OpLoad %2 %5 -OpReturnValue %8 +OpReturnValue %5 OpFunctionEnd )") << DumpBuilder(b); } diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc index f7f6d6a790..98f3a8da21 100644 --- a/src/writer/wgsl/generator_impl_function_test.cc +++ b/src/writer/wgsl/generator_impl_function_test.cc @@ -47,16 +47,14 @@ TEST_F(WgslGeneratorImplTest, Emit_Function) { } TEST_F(WgslGeneratorImplTest, Emit_Function_WithParams) { - auto* func = - Func("my_func", - ast::VariableList{Var("a", ty.f32(), ast::StorageClass::kNone), - Var("b", ty.i32(), ast::StorageClass::kNone)}, - ty.void_(), - ast::StatementList{ - create(), - create(), - }, - ast::DecorationList{}); + auto* func = Func( + "my_func", ast::VariableList{Param("a", ty.f32()), Param("b", ty.i32())}, + ty.void_(), + ast::StatementList{ + create(), + create(), + }, + ast::DecorationList{}); GeneratorImpl& gen = Build(); @@ -145,10 +143,10 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Multiple) { TEST_F(WgslGeneratorImplTest, Emit_Function_EntryPoint_Parameters) { auto* vec4 = ty.vec4(); - auto* coord = Var("coord", vec4, ast::StorageClass::kInput, nullptr, - {create(ast::Builtin::kFragCoord)}); - auto* loc1 = Var("loc1", ty.f32(), ast::StorageClass::kInput, nullptr, - {create(1u)}); + auto* coord = + Param("coord", vec4, + {create(ast::Builtin::kFragCoord)}); + auto* loc1 = Param("loc1", ty.f32(), {create(1u)}); auto* func = Func("frag_main", ast::VariableList{coord, loc1}, ty.void_(), ast::StatementList{},