diff --git a/BUILD.gn b/BUILD.gn index 5407a7e6e1..a49c65cae6 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -820,6 +820,7 @@ source_set("tint_unittests_core_src") { "src/inspector/inspector_test.cc", "src/namer_test.cc", "src/program_test.cc", + "src/program_builder_test.cc", "src/scope_stack_test.cc", "src/symbol_table_test.cc", "src/symbol_test.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 207fe4af28..0f34ad9499 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -450,6 +450,7 @@ if(${TINT_BUILD_TESTS}) inspector/inspector_test.cc namer_test.cc program_test.cc + program_builder_test.cc scope_stack_test.cc symbol_table_test.cc symbol_test.cc diff --git a/src/program_builder.cc b/src/program_builder.cc index d4d08fe5ca..146d644e5d 100644 --- a/src/program_builder.cc +++ b/src/program_builder.cc @@ -51,6 +51,18 @@ ProgramBuilder& ProgramBuilder::operator=(ProgramBuilder&& rhs) { return *this; } +ProgramBuilder ProgramBuilder::Wrap(const Program* program) { + ProgramBuilder builder; + builder.types_ = type::Manager::Wrap(program->Types()); + builder.ast_ = builder.create( + program->AST().source(), program->AST().ConstructedTypes(), + program->AST().Functions(), program->AST().GlobalVariables()); + builder.sem_ = semantic::Info::Wrap(program->Sem()); + builder.symbols_ = program->Symbols(); + builder.diagnostics_ = program->Diagnostics(); + return builder; +} + bool ProgramBuilder::IsValid() const { return !diagnostics_.contains_errors() && ast_->IsValid(); } diff --git a/src/program_builder.h b/src/program_builder.h index e5071379b8..5355e1e6bb 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -96,6 +96,18 @@ class ProgramBuilder { /// @return this builder ProgramBuilder& operator=(ProgramBuilder&& rhs); + /// Wrap returns a new ProgramBuilder wrapping the Program `program` without + /// making a deep clone of the Program contents. + /// ProgramBuilder returned by Wrap() is intended to temporarily extend an + /// existing immutable program. + /// As the returned ProgramBuilder wraps `program`, `program` must not be + /// destructed or assigned while using the returned ProgramBuilder. + /// TODO(bclayton) - Evaluate whether there are safer alternatives to this + /// function. See crbug.com/tint/460. + /// @param program the immutable Program to wrap + /// @return the ProgramBuilder that wraps `program` + static ProgramBuilder Wrap(const Program* program); + /// @returns a reference to the program's types type::Manager& Types() { AssertNotMoved(); diff --git a/src/program_builder_test.cc b/src/program_builder_test.cc new file mode 100644 index 0000000000..b5dbaa7876 --- /dev/null +++ b/src/program_builder_test.cc @@ -0,0 +1,65 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/program_builder.h" + +#include "gtest/gtest.h" + +namespace tint { +namespace { + +using ProgramBuilderTest = testing::Test; + +TEST_F(ProgramBuilderTest, WrapDoesntAffectInner) { + Program inner([] { + ProgramBuilder builder; + auto* ty = builder.ty.f32(); + auto* func = builder.Func("a", {}, ty, {}, {}); + builder.AST().Functions().Add(func); + return builder; + }()); + + ASSERT_EQ(inner.AST().Functions().size(), 1u); + ASSERT_TRUE(inner.Symbols().Get("a").IsValid()); + ASSERT_FALSE(inner.Symbols().Get("b").IsValid()); + + ProgramBuilder outer = ProgramBuilder::Wrap(&inner); + + ASSERT_EQ(inner.AST().Functions().size(), 1u); + ASSERT_EQ(outer.AST().Functions().size(), 1u); + EXPECT_EQ(inner.AST().Functions()[0], outer.AST().Functions()[0]); + EXPECT_TRUE(inner.Symbols().Get("a").IsValid()); + EXPECT_EQ(inner.Symbols().Get("a"), outer.Symbols().Get("a")); + EXPECT_TRUE(inner.Symbols().Get("a").IsValid()); + EXPECT_TRUE(outer.Symbols().Get("a").IsValid()); + EXPECT_FALSE(inner.Symbols().Get("b").IsValid()); + EXPECT_FALSE(outer.Symbols().Get("b").IsValid()); + + auto* ty = outer.ty.f32(); + auto* func = outer.Func("b", {}, ty, {}, {}); + outer.AST().Functions().Add(func); + + ASSERT_EQ(inner.AST().Functions().size(), 1u); + ASSERT_EQ(outer.AST().Functions().size(), 2u); + EXPECT_EQ(inner.AST().Functions()[0], outer.AST().Functions()[0]); + EXPECT_EQ(outer.AST().Functions()[1]->symbol(), outer.Symbols().Get("b")); + EXPECT_EQ(inner.Symbols().Get("a"), outer.Symbols().Get("a")); + EXPECT_TRUE(inner.Symbols().Get("a").IsValid()); + EXPECT_TRUE(outer.Symbols().Get("a").IsValid()); + EXPECT_FALSE(inner.Symbols().Get("b").IsValid()); + EXPECT_TRUE(outer.Symbols().Get("b").IsValid()); +} + +} // namespace +} // namespace tint diff --git a/src/semantic/info.h b/src/semantic/info.h index fff6558944..4b36b2fca9 100644 --- a/src/semantic/info.h +++ b/src/semantic/info.h @@ -35,6 +35,18 @@ class Info { /// @param rhs the Program to move /// @return this Program Info& operator=(Info&& rhs); + + /// Wrap returns a new Info created with the contents of `inner`. + /// The Info returned by Wrap is intended to temporarily extend the contents + /// of an existing immutable Info. + /// As the copied contents are owned by `inner`, `inner` must not be + /// destructed or assigned while using the returned Info. + /// @param inner the immutable Info to extend + /// @return the Info that wraps `inner` + static Info Wrap(const Info& inner) { + (void)inner; + return Info(); + } }; } // namespace semantic diff --git a/src/type/type_manager.h b/src/type/type_manager.h index 779521b3b4..a3ece09971 100644 --- a/src/type/type_manager.h +++ b/src/type/type_manager.h @@ -68,6 +68,8 @@ class Manager { /// of an existing immutable Manager. /// As the copied types are owned by `inner`, `inner` must not be destructed /// or assigned while using the returned Manager. + /// TODO(bclayton) - Evaluate whether there are safer alternatives to this + /// function. See crbug.com/tint/460. /// @param inner the immutable Manager to extend /// @return the Manager that wraps `inner` static Manager Wrap(const Manager& inner) { diff --git a/src/writer/append_vector.cc b/src/writer/append_vector.cc index 4873723760..0ac2447a0e 100644 --- a/src/writer/append_vector.cc +++ b/src/writer/append_vector.cc @@ -18,6 +18,7 @@ #include "src/ast/expression.h" #include "src/ast/type_constructor_expression.h" +#include "src/semantic/info.h" #include "src/type/vector_type.h" namespace tint { @@ -36,10 +37,9 @@ ast::TypeConstructorExpression* AsVectorConstructor(ast::Expression* expr) { } // namespace -bool AppendVector( - ast::Expression* vector, - ast::Expression* scalar, - std::function callback) { +ast::TypeConstructorExpression* AppendVector(ProgramBuilder* b, + ast::Expression* vector, + ast::Expression* scalar) { uint32_t packed_size; type::Type* packed_el_ty; // Currently must be f32. if (auto* vec = vector->result_type()->As()) { @@ -51,14 +51,14 @@ bool AppendVector( } if (!packed_el_ty) { - return false; // missing type info + return nullptr; // missing type info } // Cast scalar to the vector element type - ast::TypeConstructorExpression scalar_cast(Source{}, packed_el_ty, {scalar}); - scalar_cast.set_result_type(packed_el_ty); + auto* scalar_cast = b->Construct(packed_el_ty, scalar); + scalar_cast->set_result_type(packed_el_ty); - type::Vector packed_ty(packed_el_ty, packed_size); + auto* packed_ty = b->create(packed_el_ty, packed_size); // If the coordinates are already passed in a vector constructor, extract // the elements into the new vector instead of nesting a vector-in-vector. @@ -69,16 +69,15 @@ bool AppendVector( packed.emplace_back(vector); } if (packed_el_ty != scalar->result_type()) { - packed.emplace_back(&scalar_cast); + packed.emplace_back(scalar_cast); } else { packed.emplace_back(scalar); } - ast::TypeConstructorExpression constructor{Source{}, &packed_ty, - std::move(packed)}; - constructor.set_result_type(&packed_ty); + auto* constructor = b->Construct(packed_ty, std::move(packed)); + constructor->set_result_type(packed_ty); - return callback(&constructor); + return constructor; } } // namespace writer diff --git a/src/writer/append_vector.h b/src/writer/append_vector.h index d431336f19..818b7fd084 100644 --- a/src/writer/append_vector.h +++ b/src/writer/append_vector.h @@ -15,9 +15,7 @@ #ifndef SRC_WRITER_APPEND_VECTOR_H_ #define SRC_WRITER_APPEND_VECTOR_H_ -#include - -#include "src/source.h" +#include "src/program_builder.h" namespace tint { @@ -32,20 +30,16 @@ namespace writer { /// AppendVector is used to generate texture intrinsic function calls for /// backends that expect the texture coordinates to be packed with an additional /// mip-level or array-index parameter. -/// AppendVector() calls the `callback` function with a vector -/// expression containing the elements of `vector` followed by the single -/// element of `scalar` cast to the `vector` element type. /// All types must have been assigned to the expressions and their child nodes /// before calling. +/// @param builder the program builder. /// @param vector the vector to be appended. May be a scalar, `vec2` or `vec3`. /// @param scalar the scalar to append to the vector. Must be a scalar. -/// @param callback the function called with the packed result. Note that the -/// pointer argument is only valid for the duration of the call. -/// @returns the value returned by `callback` to indicate success -bool AppendVector( - ast::Expression* vector, - ast::Expression* scalar, - std::function callback); +/// @returns a vector expression containing the elements of `vector` followed by +/// the single element of `scalar` cast to the `vector` element type. +ast::TypeConstructorExpression* AppendVector(ProgramBuilder* builder, + ast::Expression* vector, + ast::Expression* scalar); } // namespace writer } // namespace tint diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index e667955d06..5e3caeed0d 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -128,7 +128,7 @@ const char* image_format_to_rwtexture_type(type::ImageFormat image_format) { } // namespace GeneratorImpl::GeneratorImpl(const Program* program) - : program_(program), types_(type::Manager::Wrap(program->Types())) {} + : builder_(ProgramBuilder::Wrap(program)) {} GeneratorImpl::~GeneratorImpl() = default; @@ -139,20 +139,20 @@ void GeneratorImpl::make_indent(std::ostream& out) { } bool GeneratorImpl::Generate(std::ostream& out) { - for (auto* global : program_->AST().GlobalVariables()) { + for (auto* global : builder_.AST().GlobalVariables()) { register_global(global); } - for (auto* const ty : program_->AST().ConstructedTypes()) { + for (auto* const ty : builder_.AST().ConstructedTypes()) { if (!EmitConstructedType(out, ty)) { return false; } } - if (!program_->AST().ConstructedTypes().empty()) { + if (!builder_.AST().ConstructedTypes().empty()) { out << std::endl; } - for (auto* var : program_->AST().GlobalVariables()) { + for (auto* var : builder_.AST().GlobalVariables()) { if (!var->is_const()) { continue; } @@ -163,7 +163,7 @@ bool GeneratorImpl::Generate(std::ostream& out) { std::unordered_set emitted_globals; // Make sure all entry point data is emitted before the entry point functions - for (auto* func : program_->AST().Functions()) { + for (auto* func : builder_.AST().Functions()) { if (!func->IsEntryPoint()) { continue; } @@ -173,13 +173,13 @@ bool GeneratorImpl::Generate(std::ostream& out) { } } - for (auto* func : program_->AST().Functions()) { + for (auto* func : builder_.AST().Functions()) { if (!EmitFunction(out, func)) { return false; } } - for (auto* func : program_->AST().Functions()) { + for (auto* func : builder_.AST().Functions()) { if (!func->IsEntryPoint()) { continue; } @@ -236,7 +236,7 @@ bool GeneratorImpl::EmitConstructedType(std::ostream& out, // generate a secondary struct with the new name. if (auto* str = alias->type()->As()) { if (!EmitStructType(out, str, - program_->Symbols().NameFor(alias->symbol()))) { + builder_.Symbols().NameFor(alias->symbol()))) { return false; } return true; @@ -245,10 +245,10 @@ bool GeneratorImpl::EmitConstructedType(std::ostream& out, if (!EmitType(out, alias->type(), "")) { return false; } - out << " " << namer_.NameFor(program_->Symbols().NameFor(alias->symbol())) + out << " " << namer_.NameFor(builder_.Symbols().NameFor(alias->symbol())) << ";" << std::endl; } else if (auto* str = ty->As()) { - if (!EmitStructType(out, str, program_->Symbols().NameFor(str->symbol()))) { + if (!EmitStructType(out, str, builder_.Symbols().NameFor(str->symbol()))) { return false; } } else { @@ -627,7 +627,7 @@ bool GeneratorImpl::EmitCall(std::ostream& pre, return true; } - auto name = program_->Symbols().NameFor(ident->symbol()); + auto name = builder_.Symbols().NameFor(ident->symbol()); auto caller_sym = ident->symbol(); auto it = ep_func_name_remapped_.find(current_ep_sym_.to_str() + "_" + caller_sym.to_str()); @@ -635,10 +635,10 @@ bool GeneratorImpl::EmitCall(std::ostream& pre, name = it->second; } - auto* func = program_->AST().Functions().Find(ident->symbol()); + auto* func = builder_.AST().Functions().Find(ident->symbol()); if (func == nullptr) { error_ = "Unable to find function: " + - program_->Symbols().NameFor(ident->symbol()); + builder_.Symbols().NameFor(ident->symbol()); return false; } @@ -872,7 +872,7 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& pre, break; default: error_ = "Internal compiler error: Unhandled texture intrinsic '" + - program_->Symbols().NameFor(ident->symbol()) + "'"; + builder_.Symbols().NameFor(ident->symbol()) + "'"; return false; } @@ -885,28 +885,25 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& pre, auto* param_coords = params[pidx.coords]; auto emit_vector_appended_with_i32_zero = [&](tint::ast::Expression* vector) { - auto* i32 = types_.Get(); - ast::SintLiteral zero_lit(Source{}, i32, 0); - ast::ScalarConstructorExpression zero(Source{}, &zero_lit); - zero.set_result_type(i32); - return AppendVector(vector, &zero, - [&](ast::TypeConstructorExpression* packed) { - return EmitExpression(pre, out, packed); - }); + auto* i32 = builder_.create(); + auto* zero = builder_.Expr(0); + zero->set_result_type(i32); + auto* packed = AppendVector(&builder_, vector, zero); + return EmitExpression(pre, out, packed); }; if (pidx.array_index != kNotUsed) { // Array index needs to be appended to the coordinates. auto* param_array_index = params[pidx.array_index]; - if (!AppendVector(param_coords, param_array_index, - [&](ast::TypeConstructorExpression* packed) { - if (pack_mip_in_coords) { - return emit_vector_appended_with_i32_zero(packed); - } else { - return EmitExpression(pre, out, packed); - } - })) { - return false; + auto* packed = AppendVector(&builder_, param_coords, param_array_index); + if (pack_mip_in_coords) { + if (!emit_vector_appended_with_i32_zero(packed)) { + return false; + } + } else { + if (!EmitExpression(pre, out, packed)) { + return false; + } } } else if (pack_mip_in_coords) { // Mip level needs to be appended to the coordinates, but is always zero. @@ -935,7 +932,7 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& pre, } return true; -} +} // namespace hlsl std::string GeneratorImpl::generate_builtin_name(ast::CallExpression* expr) { std::string out; @@ -975,7 +972,7 @@ std::string GeneratorImpl::generate_builtin_name(ast::CallExpression* expr) { case ast::Intrinsic::kMax: case ast::Intrinsic::kMin: case ast::Intrinsic::kClamp: - out = program_->Symbols().NameFor(ident->symbol()); + out = builder_.Symbols().NameFor(ident->symbol()); break; case ast::Intrinsic::kFaceForward: out = "faceforward"; @@ -991,7 +988,7 @@ std::string GeneratorImpl::generate_builtin_name(ast::CallExpression* expr) { break; default: error_ = "Unknown builtin method: " + - program_->Symbols().NameFor(ident->symbol()); + builder_.Symbols().NameFor(ident->symbol()); return ""; } @@ -1141,7 +1138,7 @@ bool GeneratorImpl::EmitExpression(std::ostream& pre, return EmitUnaryOp(pre, out, u); } - error_ = "unknown expression type: " + program_->str(expr); + error_ = "unknown expression type: " + builder_.str(expr); return false; } @@ -1174,9 +1171,9 @@ bool GeneratorImpl::EmitIdentifier(std::ostream&, // Swizzles output the name directly if (ident->IsSwizzle()) { - out << program_->Symbols().NameFor(ident->symbol()); + out << builder_.Symbols().NameFor(ident->symbol()); } else { - out << namer_.NameFor(program_->Symbols().NameFor(ident->symbol())); + out << namer_.NameFor(builder_.Symbols().NameFor(ident->symbol())); } return true; @@ -1339,12 +1336,12 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, auto ep_name = ep_sym.to_str(); // TODO(dsinclair): The SymbolToName should go away and just use // to_str() here when the conversion is complete. - name = generate_name(program_->Symbols().NameFor(func->symbol()) + "_" + - program_->Symbols().NameFor(ep_sym)); + name = generate_name(builder_.Symbols().NameFor(func->symbol()) + "_" + + builder_.Symbols().NameFor(ep_sym)); ep_func_name_remapped_[ep_name + "_" + func_name] = name; } else { // TODO(dsinclair): this should be updated to a remapped name - name = namer_.NameFor(program_->Symbols().NameFor(func->symbol())); + name = namer_.NameFor(builder_.Symbols().NameFor(func->symbol())); } out << name << "("; @@ -1380,12 +1377,12 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, } first = false; - if (!EmitType(out, v->type(), program_->Symbols().NameFor(v->symbol()))) { + if (!EmitType(out, v->type(), builder_.Symbols().NameFor(v->symbol()))) { return false; } // Array name is output as part of the type if (!v->type()->Is()) { - out << " " << program_->Symbols().NameFor(v->symbol()); + out << " " << builder_.Symbols().NameFor(v->symbol()); } } @@ -1439,7 +1436,7 @@ bool GeneratorImpl::EmitEntryPointData( auto* binding = data.second.binding; if (binding == nullptr) { error_ = "unable to find binding information for uniform: " + - program_->Symbols().NameFor(var->symbol()); + builder_.Symbols().NameFor(var->symbol()); return false; } // auto* set = data.second.set; @@ -1453,8 +1450,8 @@ bool GeneratorImpl::EmitEntryPointData( auto* type = var->type()->UnwrapIfNeeded(); if (auto* strct = type->As()) { - out << "ConstantBuffer<" << program_->Symbols().NameFor(strct->symbol()) - << "> " << program_->Symbols().NameFor(var->symbol()) + out << "ConstantBuffer<" << builder_.Symbols().NameFor(strct->symbol()) + << "> " << builder_.Symbols().NameFor(var->symbol()) << " : register(b" << binding->value() << ");" << std::endl; } else { // TODO(dsinclair): There is outstanding spec work to require all uniform @@ -1463,7 +1460,7 @@ bool GeneratorImpl::EmitEntryPointData( // is not a block. // Relevant: https://github.com/gpuweb/gpuweb/issues/1004 // https://github.com/gpuweb/gpuweb/issues/1008 - auto name = "cbuffer_" + program_->Symbols().NameFor(var->symbol()); + auto name = "cbuffer_" + builder_.Symbols().NameFor(var->symbol()); out << "cbuffer " << name << " : register(b" << binding->value() << ") {" << std::endl; @@ -1472,7 +1469,7 @@ bool GeneratorImpl::EmitEntryPointData( if (!EmitType(out, type, "")) { return false; } - out << " " << program_->Symbols().NameFor(var->symbol()) << ";" + out << " " << builder_.Symbols().NameFor(var->symbol()) << ";" << std::endl; decrement_indent(); out << "};" << std::endl; @@ -1505,7 +1502,7 @@ bool GeneratorImpl::EmitEntryPointData( if (ac->IsReadWrite()) { out << "RW"; } - out << "ByteAddressBuffer " << program_->Symbols().NameFor(var->symbol()) + out << "ByteAddressBuffer " << builder_.Symbols().NameFor(var->symbol()) << " : register(u" << binding->value() << ");" << std::endl; emitted_storagebuffer = true; } @@ -1514,9 +1511,8 @@ bool GeneratorImpl::EmitEntryPointData( } if (!in_variables.empty()) { - auto in_struct_name = - generate_name(program_->Symbols().NameFor(func->symbol()) + "_" + - kInStructNameSuffix); + auto in_struct_name = generate_name( + builder_.Symbols().NameFor(func->symbol()) + "_" + kInStructNameSuffix); auto in_var_name = generate_name(kTintStructInVarPrefix); ep_sym_to_in_data_[func->symbol()] = {in_struct_name, in_var_name}; @@ -1531,11 +1527,11 @@ bool GeneratorImpl::EmitEntryPointData( make_indent(out); if (!EmitType(out, var->type(), - program_->Symbols().NameFor(var->symbol()))) { + builder_.Symbols().NameFor(var->symbol()))) { return false; } - out << " " << program_->Symbols().NameFor(var->symbol()) << " : "; + out << " " << builder_.Symbols().NameFor(var->symbol()) << " : "; if (auto* location = deco->As()) { if (func->pipeline_stage() == ast::PipelineStage::kCompute) { error_ = "invalid location variable for pipeline stage"; @@ -1563,7 +1559,7 @@ bool GeneratorImpl::EmitEntryPointData( if (!outvariables.empty()) { auto outstruct_name = - generate_name(program_->Symbols().NameFor(func->symbol()) + "_" + + generate_name(builder_.Symbols().NameFor(func->symbol()) + "_" + kOutStructNameSuffix); auto outvar_name = generate_name(kTintStructOutVarPrefix); ep_sym_to_out_data_[func->symbol()] = {outstruct_name, outvar_name}; @@ -1578,11 +1574,11 @@ bool GeneratorImpl::EmitEntryPointData( make_indent(out); if (!EmitType(out, var->type(), - program_->Symbols().NameFor(var->symbol()))) { + builder_.Symbols().NameFor(var->symbol()))) { return false; } - out << " " << program_->Symbols().NameFor(var->symbol()) << " : "; + out << " " << builder_.Symbols().NameFor(var->symbol()) << " : "; if (auto* location = deco->As()) { auto loc = location->value(); @@ -1665,7 +1661,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, out << "void"; } // TODO(dsinclair): This should output the remapped name - out << " " << namer_.NameFor(program_->Symbols().NameFor(current_ep_sym_)) + out << " " << namer_.NameFor(builder_.Symbols().NameFor(current_ep_sym_)) << "("; auto in_data = ep_sym_to_in_data_.find(current_ep_sym_); @@ -1814,7 +1810,7 @@ bool GeneratorImpl::EmitLoop(std::ostream& out, ast::LoopStatement* stmt) { } out << pre.str(); - out << program_->Symbols().NameFor(var->symbol()) << " = "; + out << builder_.Symbols().NameFor(var->symbol()) << " = "; if (var->constructor() != nullptr) { out << constructor_out.str(); } else { @@ -1877,7 +1873,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression( // // This must be a single element swizzle if we've got a vector at this // point. - if (program_->Symbols().NameFor(mem->member()->symbol()).size() != 1) { + if (builder_.Symbols().NameFor(mem->member()->symbol()).size() != 1) { error_ = "Encountered multi-element swizzle when should have only one " "level"; @@ -1889,7 +1885,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression( // f64 types. out << "(4 * " << convert_swizzle_to_index( - program_->Symbols().NameFor(mem->member()->symbol())) + builder_.Symbols().NameFor(mem->member()->symbol())) << ")"; } else { error_ = @@ -2068,7 +2064,7 @@ bool GeneratorImpl::is_storage_buffer_access( // If the data is a multi-element swizzle then we will not load the swizzle // portion through the Load command. if (data_type->Is() && - program_->Symbols().NameFor(expr->member()->symbol()).size() > 1) { + builder_.Symbols().NameFor(expr->member()->symbol()).size() > 1) { return false; } @@ -2179,7 +2175,7 @@ bool GeneratorImpl::EmitStatement(std::ostream& out, ast::Statement* stmt) { return EmitVariable(out, v->variable(), false); } - error_ = "unknown statement type: " + program_->str(stmt); + error_ = "unknown statement type: " + builder_.str(stmt); return false; } @@ -2220,7 +2216,7 @@ bool GeneratorImpl::EmitType(std::ostream& out, } if (auto* alias = type->As()) { - out << namer_.NameFor(program_->Symbols().NameFor(alias->symbol())); + out << namer_.NameFor(builder_.Symbols().NameFor(alias->symbol())); } else if (auto* ary = type->As()) { type::Type* base_type = ary; std::vector sizes; @@ -2267,7 +2263,7 @@ bool GeneratorImpl::EmitType(std::ostream& out, } out << "State"; } else if (auto* str = type->As()) { - out << program_->Symbols().NameFor(str->symbol()); + out << builder_.Symbols().NameFor(str->symbol()); } else if (auto* tex = type->As()) { if (tex->Is()) { out << "RW"; @@ -2352,12 +2348,12 @@ bool GeneratorImpl::EmitStructType(std::ostream& out, // https://bugs.chromium.org/p/tint/issues/detail?id=184 if (!EmitType(out, mem->type(), - program_->Symbols().NameFor(mem->symbol()))) { + builder_.Symbols().NameFor(mem->symbol()))) { return false; } // Array member name will be output with the type if (!mem->type()->Is()) { - out << " " << namer_.NameFor(program_->Symbols().NameFor(mem->symbol())); + out << " " << namer_.NameFor(builder_.Symbols().NameFor(mem->symbol())); } out << ";" << std::endl; } @@ -2416,11 +2412,11 @@ bool GeneratorImpl::EmitVariable(std::ostream& out, if (var->is_const()) { out << "const "; } - if (!EmitType(out, var->type(), program_->Symbols().NameFor(var->symbol()))) { + if (!EmitType(out, var->type(), builder_.Symbols().NameFor(var->symbol()))) { return false; } if (!var->type()->Is()) { - out << " " << program_->Symbols().NameFor(var->symbol()); + out << " " << builder_.Symbols().NameFor(var->symbol()); } out << constructor_out.str() << ";" << std::endl; @@ -2466,20 +2462,20 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out, out << "#endif" << std::endl; out << "static const "; if (!EmitType(out, var->type(), - program_->Symbols().NameFor(var->symbol()))) { + builder_.Symbols().NameFor(var->symbol()))) { return false; } - out << " " << program_->Symbols().NameFor(var->symbol()) + out << " " << builder_.Symbols().NameFor(var->symbol()) << " = WGSL_SPEC_CONSTANT_" << const_id << ";" << std::endl; out << "#undef WGSL_SPEC_CONSTANT_" << const_id << std::endl; } else { out << "static const "; if (!EmitType(out, var->type(), - program_->Symbols().NameFor(var->symbol()))) { + builder_.Symbols().NameFor(var->symbol()))) { return false; } if (!var->type()->Is()) { - out << " " << program_->Symbols().NameFor(var->symbol()); + out << " " << builder_.Symbols().NameFor(var->symbol()); } if (var->constructor() != nullptr) { @@ -2494,7 +2490,7 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out, std::string GeneratorImpl::get_buffer_name(ast::Expression* expr) { for (;;) { if (auto* ident = expr->As()) { - return program_->Symbols().NameFor(ident->symbol()); + return builder_.Symbols().NameFor(ident->symbol()); } else if (auto* member = expr->As()) { expr = member->structure(); } else if (auto* array = expr->As()) { diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h index d1ca0ab045..4a10975a23 100644 --- a/src/writer/hlsl/generator_impl.h +++ b/src/writer/hlsl/generator_impl.h @@ -39,7 +39,7 @@ #include "src/ast/switch_statement.h" #include "src/ast/type_constructor_expression.h" #include "src/ast/unary_op_expression.h" -#include "src/program.h" +#include "src/program_builder.h" #include "src/scope_stack.h" #include "src/type/struct_type.h" #include "src/writer/hlsl/namer.h" @@ -394,8 +394,7 @@ class GeneratorImpl { size_t indent_ = 0; Namer namer_; - const Program* program_ = nullptr; - type::Manager types_; + ProgramBuilder builder_; Symbol current_ep_sym_; bool generating_entry_point_ = false; uint32_t loop_emission_counter_ = 0; diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index b6c826649a..577a8420e3 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -289,9 +289,7 @@ Builder::AccessorInfo::AccessorInfo() : source_id(0), source_type(nullptr) {} Builder::AccessorInfo::~AccessorInfo() {} Builder::Builder(const Program* program) - : program_(program), - type_mgr_(type::Manager::Wrap(program->Types())), - scope_stack_({}) {} + : builder_(ProgramBuilder::Wrap(program)), scope_stack_({}) {} Builder::~Builder() = default; @@ -302,13 +300,13 @@ bool Builder::Build() { {Operand::Int(SpvAddressingModelLogical), Operand::Int(SpvMemoryModelGLSL450)}); - for (auto* var : program_->AST().GlobalVariables()) { + for (auto* var : builder_.AST().GlobalVariables()) { if (!GenerateGlobalVariable(var)) { return false; } } - for (auto* func : program_->AST().Functions()) { + for (auto* func : builder_.AST().Functions()) { if (!GenerateFunction(func)) { return false; } @@ -455,7 +453,7 @@ bool Builder::GenerateEntryPoint(ast::Function* func, uint32_t id) { OperandList operands = { Operand::Int(stage), Operand::Int(id), - Operand::String(program_->Symbols().NameFor(func->symbol()))}; + Operand::String(builder_.Symbols().NameFor(func->symbol()))}; for (const auto* var : func->referenced_module_variables()) { // For SPIR-V 1.3 we only output Input/output variables. If we update to @@ -468,7 +466,7 @@ bool Builder::GenerateEntryPoint(ast::Function* func, uint32_t id) { uint32_t var_id; if (!scope_stack_.get(var->symbol(), &var_id)) { error_ = "unable to find ID for global variable: " + - program_->Symbols().NameFor(var->symbol()); + builder_.Symbols().NameFor(var->symbol()); return false; } @@ -533,7 +531,7 @@ uint32_t Builder::GenerateExpression(ast::Expression* expr) { return GenerateUnaryOpExpression(u); } - error_ = "unknown expression type: " + program_->str(expr); + error_ = "unknown expression type: " + builder_.str(expr); return 0; } @@ -548,7 +546,7 @@ bool Builder::GenerateFunction(ast::Function* func) { push_debug(spv::Op::OpName, {Operand::Int(func_id), - Operand::String(program_->Symbols().NameFor(func->symbol()))}); + Operand::String(builder_.Symbols().NameFor(func->symbol()))}); auto ret_id = GenerateTypeIfNeeded(func->return_type()); if (ret_id == 0) { @@ -574,7 +572,7 @@ bool Builder::GenerateFunction(ast::Function* func) { push_debug(spv::Op::OpName, {Operand::Int(param_id), - Operand::String(program_->Symbols().NameFor(param->symbol()))}); + Operand::String(builder_.Symbols().NameFor(param->symbol()))}); params.push_back(Instruction{spv::Op::OpFunctionParameter, {Operand::Int(param_type_id), param_op}}); @@ -665,7 +663,7 @@ bool Builder::GenerateFunctionVariable(ast::Variable* var) { push_debug(spv::Op::OpName, {Operand::Int(var_id), - Operand::String(program_->Symbols().NameFor(var->symbol()))}); + Operand::String(builder_.Symbols().NameFor(var->symbol()))}); // TODO(dsinclair) We could detect if the constructor is fully const and emit // an initializer value for the variable instead of doing the OpLoad. @@ -717,7 +715,7 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { } push_debug(spv::Op::OpName, {Operand::Int(init_id), - Operand::String(program_->Symbols().NameFor(var->symbol()))}); + Operand::String(builder_.Symbols().NameFor(var->symbol()))}); scope_stack_.set_global(var->symbol(), init_id); spirv_id_to_variable_[init_id] = var; @@ -739,7 +737,7 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { push_debug(spv::Op::OpName, {Operand::Int(var_id), - Operand::String(program_->Symbols().NameFor(var->symbol()))}); + Operand::String(builder_.Symbols().NameFor(var->symbol()))}); OperandList ops = {Operand::Int(type_id), result, Operand::Int(ConvertStorageClass(sc))}; @@ -920,7 +918,7 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr, } // TODO(dsinclair): Swizzle stuff - auto swiz = program_->Symbols().NameFor(expr->member()->symbol()); + auto swiz = builder_.Symbols().NameFor(expr->member()->symbol()); // Single element swizzle is either an access chain or a composite extract if (swiz.size() == 1) { auto val = IndexFromName(swiz[0]); @@ -1091,7 +1089,7 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) { } } else { - error_ = "invalid accessor in list: " + program_->str(accessor); + error_ = "invalid accessor in list: " + builder_.str(accessor); return 0; } } @@ -1128,7 +1126,7 @@ uint32_t Builder::GenerateIdentifierExpression( } error_ = "unable to find variable with identifier: " + - program_->Symbols().NameFor(expr->symbol()); + builder_.Symbols().NameFor(expr->symbol()); return 0; } @@ -1821,7 +1819,7 @@ uint32_t Builder::GenerateCallExpression(ast::CallExpression* expr) { auto func_id = func_symbol_to_id_[ident->symbol()]; if (func_id == 0) { error_ = "unable to find called function: " + - program_->Symbols().NameFor(ident->symbol()); + builder_.Symbols().NameFor(ident->symbol()); return 0; } ops.push_back(Operand::Int(func_id)); @@ -1953,7 +1951,7 @@ uint32_t Builder::GenerateIntrinsic(ast::IdentifierExpression* ident, auto inst_id = intrinsic_to_glsl_method(ident->result_type(), ident->intrinsic()); if (inst_id == 0) { - error_ = "unknown method " + program_->Symbols().NameFor(ident->symbol()); + error_ = "unknown method " + builder_.Symbols().NameFor(ident->symbol()); return 0; } @@ -1965,7 +1963,7 @@ uint32_t Builder::GenerateIntrinsic(ast::IdentifierExpression* ident, if (op == spv::Op::OpNop) { error_ = "unable to determine operator for: " + - program_->Symbols().NameFor(ident->symbol()); + builder_.Symbols().NameFor(ident->symbol()); return 0; } @@ -2048,8 +2046,8 @@ bool Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident, // to calling append_result_type_and_id_to_spirv_params(). auto append_result_type_and_id_to_spirv_params_for_read = [&]() { if (texture_type->Is()) { - auto* f32 = type_mgr_.Get(); - auto* spirv_result_type = type_mgr_.Get(f32, 4); + auto* f32 = builder_.create(); + auto* spirv_result_type = builder_.create(f32, 4); auto spirv_result = result_op(); post_emission = [=] { return push_function_inst( @@ -2081,7 +2079,7 @@ bool Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident, auto* element_type = ElementTypeOf(call->result_type()); auto spirv_result = result_op(); auto* spirv_result_type = - type_mgr_.Get(element_type, spirv_result_width); + builder_.create(element_type, spirv_result_width); if (swizzle.size() > 1) { post_emission = [=] { OperandList operands{ @@ -2118,18 +2116,12 @@ bool Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident, auto* param_coords = call->params()[pidx.coords]; auto* param_array_index = call->params()[pidx.array_index]; - if (!AppendVector(param_coords, param_array_index, - [&](ast::TypeConstructorExpression* packed) { - auto param = - GenerateTypeConstructorExpression(packed, false); - if (param == 0) { - return false; - } - spirv_params.emplace_back(Operand::Int(param)); - return true; - })) { + auto* packed = AppendVector(&builder_, param_coords, param_array_index); + auto param = GenerateTypeConstructorExpression(packed, false); + if (param == 0) { return false; } + spirv_params.emplace_back(Operand::Int(param)); } else { spirv_params.emplace_back(gen_param(pidx.coords)); // coordinates } @@ -2199,7 +2191,7 @@ bool Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident, op = spv::Op::OpImageQuerySizeLod; spirv_params.emplace_back(gen_param(pidx.level)); } else { - ast::SintLiteral i32_0(Source{}, type_mgr_.Get(), 0); + ast::SintLiteral i32_0(Source{}, builder_.create(), 0); op = spv::Op::OpImageQuerySizeLod; spirv_params.emplace_back( Operand::Int(GenerateLiteralIfNeeded(nullptr, &i32_0))); @@ -2234,7 +2226,7 @@ bool Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident, texture_type->Is()) { op = spv::Op::OpImageQuerySize; } else { - ast::SintLiteral i32_0(Source{}, type_mgr_.Get(), 0); + ast::SintLiteral i32_0(Source{}, builder_.create(), 0); op = spv::Op::OpImageQuerySizeLod; spirv_params.emplace_back( Operand::Int(GenerateLiteralIfNeeded(nullptr, &i32_0))); @@ -2313,7 +2305,7 @@ bool Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident, if (call->params()[pidx.level]->result_type()->Is()) { // Depth textures have i32 parameters for the level, but SPIR-V expects // F32. Cast. - auto* f32 = type_mgr_.Get(); + auto* f32 = builder_.create(); ast::TypeConstructorExpression cast(Source{}, f32, {call->params()[pidx.level]}); level = Operand::Int(GenerateExpression(&cast)); @@ -2380,7 +2372,7 @@ bool Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident, if (op == spv::Op::OpNop) { error_ = "unable to determine operator for: " + - program_->Symbols().NameFor(ident->symbol()); + builder_.Symbols().NameFor(ident->symbol()); return false; } @@ -2763,7 +2755,7 @@ bool Builder::GenerateStatement(ast::Statement* stmt) { return GenerateVariableDeclStatement(v); } - error_ = "Unknown statement: " + program_->str(stmt); + error_ = "Unknown statement: " + builder_.str(stmt); return false; } @@ -3002,7 +2994,7 @@ bool Builder::GenerateStructType(type::Struct* struct_type, push_debug( spv::Op::OpName, {Operand::Int(struct_id), - Operand::String(program_->Symbols().NameFor(struct_type->symbol()))}); + Operand::String(builder_.Symbols().NameFor(struct_type->symbol()))}); } OperandList ops; @@ -3045,7 +3037,7 @@ uint32_t Builder::GenerateStructMember(uint32_t struct_id, ast::StructMember* member) { push_debug(spv::Op::OpMemberName, {Operand::Int(struct_id), Operand::Int(idx), - Operand::String(program_->Symbols().NameFor(member->symbol()))}); + Operand::String(builder_.Symbols().NameFor(member->symbol()))}); bool has_layout = false; for (auto* deco : member->decorations()) { diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index f58ebcd081..2c73098091 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -43,7 +43,7 @@ #include "src/ast/type_constructor_expression.h" #include "src/ast/unary_op_expression.h" #include "src/ast/variable_decl_statement.h" -#include "src/program.h" +#include "src/program_builder.h" #include "src/scope_stack.h" #include "src/type/access_control_type.h" #include "src/type/array_type.h" @@ -488,8 +488,7 @@ class Builder { /// automatically. Operand result_op(); - const Program* program_; - type::Manager type_mgr_; + ProgramBuilder builder_; std::string error_; uint32_t next_id_ = 1; uint32_t current_label_id_ = 0;