transform: Fix multiple mutability issues with VertexPulling

ConvertVertexInputVariablesToPrivate() mutated the source program global variables, and copied them into the destination program.

Symbols and types were assigned across the program boundary without cloning.

Bug: tint:390
Change-Id: I03c8924e6ba94b745e74de0ab57f8a489e85cc50
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/38554
Reviewed-by: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2021-01-26 16:57:10 +00:00
parent 5d9b1c38c2
commit deb02019d5
2 changed files with 121 additions and 111 deletions

View File

@ -103,21 +103,24 @@ Transform::Output VertexPulling::Run(const Program* in) {
// following stages will pass // following stages will pass
Output out; Output out;
State state{in, &out.program, cfg}; CloneContext ctx(&out.program, in);
State state{ctx, cfg};
state.FindOrInsertVertexIndexIfUsed(); state.FindOrInsertVertexIndexIfUsed();
state.FindOrInsertInstanceIndexIfUsed(); state.FindOrInsertInstanceIndexIfUsed();
state.ConvertVertexInputVariablesToPrivate(); state.ConvertVertexInputVariablesToPrivate();
state.AddVertexStorageBuffers(); state.AddVertexStorageBuffers();
CloneContext(&out.program, in) for (auto& replacement : state.location_replacements) {
.ReplaceAll([&](CloneContext* ctx, ast::Function* f) -> ast::Function* { ctx.Replace(replacement.from, replacement.to);
if (f == func) { }
return CloneWithStatementsAtStart( ctx.ReplaceAll([&](CloneContext*, ast::Function* f) -> ast::Function* {
ctx, f, {state.CreateVertexPullingPreamble()}); if (f == func) {
} return CloneWithStatementsAtStart(&ctx, f,
return nullptr; // Just clone func {state.CreateVertexPullingPreamble()});
}) }
.Clone(); return nullptr; // Just clone func
});
ctx.Clone();
return out; return out;
} }
@ -126,8 +129,8 @@ VertexPulling::Config::Config() = default;
VertexPulling::Config::Config(const Config&) = default; VertexPulling::Config::Config(const Config&) = default;
VertexPulling::Config::~Config() = default; VertexPulling::Config::~Config() = default;
VertexPulling::State::State(const Program* i, Program* o, const Config& c) VertexPulling::State::State(CloneContext& context, const Config& c)
: in(i), out(o), cfg(c) {} : ctx(context), cfg(c) {}
VertexPulling::State::State(const State&) = default; VertexPulling::State::State(const State&) = default;
@ -150,7 +153,7 @@ void VertexPulling::State::FindOrInsertVertexIndexIfUsed() {
} }
// Look for an existing vertex index builtin // Look for an existing vertex index builtin
for (auto* v : in->AST().GlobalVariables()) { for (auto* v : ctx.src->AST().GlobalVariables()) {
if (v->storage_class() != ast::StorageClass::kInput) { if (v->storage_class() != ast::StorageClass::kInput) {
continue; continue;
} }
@ -158,7 +161,7 @@ void VertexPulling::State::FindOrInsertVertexIndexIfUsed() {
for (auto* d : v->decorations()) { for (auto* d : v->decorations()) {
if (auto* builtin = d->As<ast::BuiltinDecoration>()) { if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
if (builtin->value() == ast::Builtin::kVertexIndex) { if (builtin->value() == ast::Builtin::kVertexIndex) {
vertex_index_name = in->Symbols().NameFor(v->symbol()); vertex_index_name = ctx.src->Symbols().NameFor(v->symbol());
return; return;
} }
} }
@ -168,20 +171,19 @@ void VertexPulling::State::FindOrInsertVertexIndexIfUsed() {
// We didn't find a vertex index builtin, so create one // We didn't find a vertex index builtin, so create one
vertex_index_name = kDefaultVertexIndexName; vertex_index_name = kDefaultVertexIndexName;
auto* var = out->create<ast::Variable>( auto* var = ctx.dst->create<ast::Variable>(
Source{}, // source Source{}, // source
out->Symbols().Register(vertex_index_name), // symbol ctx.dst->Symbols().Register(vertex_index_name), // symbol
ast::StorageClass::kInput, // storage_class ast::StorageClass::kInput, // storage_class
GetI32Type(), // type GetI32Type(), // type
false, // is_const false, // is_const
nullptr, // constructor nullptr, // constructor
ast::VariableDecorationList{ ast::VariableDecorationList{
// decorations ctx.dst->create<ast::BuiltinDecoration>(Source{},
out->create<ast::BuiltinDecoration>(Source{}, ast::Builtin::kVertexIndex),
ast::Builtin::kVertexIndex),
}); });
out->AST().AddGlobalVariable(var); ctx.dst->AST().AddGlobalVariable(var);
} }
void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() { void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
@ -197,7 +199,7 @@ void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
} }
// Look for an existing instance index builtin // Look for an existing instance index builtin
for (auto* v : in->AST().GlobalVariables()) { for (auto* v : ctx.src->AST().GlobalVariables()) {
if (v->storage_class() != ast::StorageClass::kInput) { if (v->storage_class() != ast::StorageClass::kInput) {
continue; continue;
} }
@ -205,7 +207,7 @@ void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
for (auto* d : v->decorations()) { for (auto* d : v->decorations()) {
if (auto* builtin = d->As<ast::BuiltinDecoration>()) { if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
if (builtin->value() == ast::Builtin::kInstanceIndex) { if (builtin->value() == ast::Builtin::kInstanceIndex) {
instance_index_name = in->Symbols().NameFor(v->symbol()); instance_index_name = ctx.src->Symbols().NameFor(v->symbol());
return; return;
} }
} }
@ -215,24 +217,22 @@ void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
// We didn't find an instance index builtin, so create one // We didn't find an instance index builtin, so create one
instance_index_name = kDefaultInstanceIndexName; instance_index_name = kDefaultInstanceIndexName;
auto* var = out->create<ast::Variable>( auto* var = ctx.dst->create<ast::Variable>(
Source{}, // source Source{}, // source
out->Symbols().Register(instance_index_name), // symbol ctx.dst->Symbols().Register(instance_index_name), // symbol
ast::StorageClass::kInput, // storage_class ast::StorageClass::kInput, // storage_class
GetI32Type(), // type GetI32Type(), // type
false, // is_const false, // is_const
nullptr, // constructor nullptr, // constructor
ast::VariableDecorationList{ ast::VariableDecorationList{
// decorations ctx.dst->create<ast::BuiltinDecoration>(Source{},
out->create<ast::BuiltinDecoration>(Source{}, ast::Builtin::kInstanceIndex),
ast::Builtin::kInstanceIndex),
}); });
out->AST().AddGlobalVariable(var); ctx.dst->AST().AddGlobalVariable(var);
} }
void VertexPulling::State::ConvertVertexInputVariablesToPrivate() { void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
// TODO(https://crbug.com/tint/390): Remove this const_cast hack! for (auto* v : ctx.src->AST().GlobalVariables()) {
for (auto*& v : const_cast<Program*>(in)->AST().GlobalVariables()) {
if (v->storage_class() != ast::StorageClass::kInput) { if (v->storage_class() != ast::StorageClass::kInput) {
continue; continue;
} }
@ -240,18 +240,20 @@ void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
for (auto* d : v->decorations()) { for (auto* d : v->decorations()) {
if (auto* l = d->As<ast::LocationDecoration>()) { if (auto* l = d->As<ast::LocationDecoration>()) {
uint32_t location = l->value(); uint32_t location = l->value();
// This is where the replacement happens. Expressions use identifier // This is where the replacement is created. Expressions use identifier
// strings instead of pointers, so we don't need to update any other // strings instead of pointers, so we don't need to update any other
// place in the AST. // place in the AST.
v = out->create<ast::Variable>( auto name = ctx.src->Symbols().NameFor(v->symbol());
Source{}, // source auto* replacement = ctx.dst->create<ast::Variable>(
v->symbol(), // symbol Source{}, // source
ast::StorageClass::kPrivate, // storage_class ctx.dst->Symbols().Register(name), // symbol
v->type(), // type ast::StorageClass::kPrivate, // storage_class
false, // is_const ctx.Clone(v->type()), // type
nullptr, // constructor false, // is_const
ast::VariableDecorationList{}); // decorations nullptr, // constructor
location_to_var[location] = v; ast::VariableDecorationList{}); // decorations
location_to_var[location] = replacement;
location_replacements.emplace_back(LocationReplacement{v, replacement});
break; break;
} }
} }
@ -261,47 +263,47 @@ void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
void VertexPulling::State::AddVertexStorageBuffers() { void VertexPulling::State::AddVertexStorageBuffers() {
// TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935 // TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935
// The array inside the struct definition // The array inside the struct definition
auto* internal_array_type = out->create<type::Array>( auto* internal_array_type = ctx.dst->create<type::Array>(
GetU32Type(), 0, GetU32Type(), 0,
ast::ArrayDecorationList{ ast::ArrayDecorationList{
out->create<ast::StrideDecoration>(Source{}, 4u), ctx.dst->create<ast::StrideDecoration>(Source{}, 4u),
}); });
// Creating the struct type // Creating the struct type
ast::StructMemberList members; ast::StructMemberList members;
ast::StructMemberDecorationList member_dec; ast::StructMemberDecorationList member_dec;
member_dec.push_back( member_dec.push_back(
out->create<ast::StructMemberOffsetDecoration>(Source{}, 0u)); ctx.dst->create<ast::StructMemberOffsetDecoration>(Source{}, 0u));
members.push_back(out->create<ast::StructMember>( members.push_back(ctx.dst->create<ast::StructMember>(
Source{}, out->Symbols().Register(kStructBufferName), internal_array_type, Source{}, ctx.dst->Symbols().Register(kStructBufferName),
std::move(member_dec))); internal_array_type, std::move(member_dec)));
ast::StructDecorationList decos; ast::StructDecorationList decos;
decos.push_back(out->create<ast::StructBlockDecoration>(Source{})); decos.push_back(ctx.dst->create<ast::StructBlockDecoration>(Source{}));
auto* struct_type = out->create<type::Struct>( auto* struct_type = ctx.dst->create<type::Struct>(
out->Symbols().Register(kStructName), ctx.dst->Symbols().Register(kStructName),
out->create<ast::Struct>(Source{}, std::move(members), std::move(decos))); ctx.dst->create<ast::Struct>(Source{}, std::move(members),
std::move(decos)));
for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) { for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
// The decorated variable with struct type // The decorated variable with struct type
std::string name = GetVertexBufferName(i); std::string name = GetVertexBufferName(i);
auto* var = out->create<ast::Variable>( auto* var = ctx.dst->create<ast::Variable>(
Source{}, // source Source{}, // source
out->Symbols().Register(name), // symbol ctx.dst->Symbols().Register(name), // symbol
ast::StorageClass::kStorage, // storage_class ast::StorageClass::kStorage, // storage_class
struct_type, // type struct_type, // type
false, // is_const false, // is_const
nullptr, // constructor nullptr, // constructor
ast::VariableDecorationList{ ast::VariableDecorationList{
// decorations ctx.dst->create<ast::BindingDecoration>(Source{}, i),
out->create<ast::BindingDecoration>(Source{}, i), ctx.dst->create<ast::GroupDecoration>(Source{}, cfg.pulling_group),
out->create<ast::GroupDecoration>(Source{}, cfg.pulling_group),
}); });
out->AST().AddGlobalVariable(var); ctx.dst->AST().AddGlobalVariable(var);
} }
out->AST().AddConstructedType(struct_type); ctx.dst->AST().AddConstructedType(struct_type);
} }
ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const { ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
@ -311,10 +313,10 @@ ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
ast::StatementList stmts; ast::StatementList stmts;
// Declare the |kPullingPosVarName| variable in the shader // Declare the |kPullingPosVarName| variable in the shader
auto* pos_declaration = out->create<ast::VariableDeclStatement>( auto* pos_declaration = ctx.dst->create<ast::VariableDeclStatement>(
Source{}, out->create<ast::Variable>( Source{}, ctx.dst->create<ast::Variable>(
Source{}, // source Source{}, // source
out->Symbols().Register(kPullingPosVarName), // symbol ctx.dst->Symbols().Register(kPullingPosVarName), // symbol
ast::StorageClass::kFunction, // storage_class ast::StorageClass::kFunction, // storage_class
GetI32Type(), // type GetI32Type(), // type
false, // is_const false, // is_const
@ -341,42 +343,41 @@ ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
? vertex_index_name ? vertex_index_name
: instance_index_name; : instance_index_name;
// Identifier to index by // Identifier to index by
auto* index_identifier = out->create<ast::IdentifierExpression>( auto* index_identifier = ctx.dst->create<ast::IdentifierExpression>(
Source{}, out->Symbols().Register(name)); Source{}, ctx.dst->Symbols().Register(name));
// An expression for the start of the read in the buffer in bytes // An expression for the start of the read in the buffer in bytes
auto* pos_value = out->create<ast::BinaryExpression>( auto* pos_value = ctx.dst->create<ast::BinaryExpression>(
Source{}, ast::BinaryOp::kAdd, Source{}, ast::BinaryOp::kAdd,
out->create<ast::BinaryExpression>( ctx.dst->create<ast::BinaryExpression>(
Source{}, ast::BinaryOp::kMultiply, index_identifier, Source{}, ast::BinaryOp::kMultiply, index_identifier,
GenUint(static_cast<uint32_t>(buffer_layout.array_stride))), GenUint(static_cast<uint32_t>(buffer_layout.array_stride))),
GenUint(static_cast<uint32_t>(attribute_desc.offset))); GenUint(static_cast<uint32_t>(attribute_desc.offset)));
// Update position of the read // Update position of the read
auto* set_pos_expr = out->create<ast::AssignmentStatement>( auto* set_pos_expr = ctx.dst->create<ast::AssignmentStatement>(
Source{}, CreatePullingPositionIdent(), pos_value); Source{}, CreatePullingPositionIdent(), pos_value);
stmts.emplace_back(set_pos_expr); stmts.emplace_back(set_pos_expr);
auto ident_name = in->Symbols().NameFor(v->symbol()); stmts.emplace_back(ctx.dst->create<ast::AssignmentStatement>(
stmts.emplace_back(out->create<ast::AssignmentStatement>(
Source{}, Source{},
out->create<ast::IdentifierExpression>( ctx.dst->create<ast::IdentifierExpression>(Source{}, v->symbol()),
Source{}, out->Symbols().Register(ident_name)),
AccessByFormat(i, attribute_desc.format))); AccessByFormat(i, attribute_desc.format)));
} }
} }
return out->create<ast::BlockStatement>(Source{}, stmts); return ctx.dst->create<ast::BlockStatement>(Source{}, stmts);
} }
ast::Expression* VertexPulling::State::GenUint(uint32_t value) const { ast::Expression* VertexPulling::State::GenUint(uint32_t value) const {
return out->create<ast::ScalarConstructorExpression>( return ctx.dst->create<ast::ScalarConstructorExpression>(
Source{}, out->create<ast::UintLiteral>(Source{}, GetU32Type(), value)); Source{},
ctx.dst->create<ast::UintLiteral>(Source{}, GetU32Type(), value));
} }
ast::Expression* VertexPulling::State::CreatePullingPositionIdent() const { ast::Expression* VertexPulling::State::CreatePullingPositionIdent() const {
return out->create<ast::IdentifierExpression>( return ctx.dst->create<ast::IdentifierExpression>(
Source{}, out->Symbols().Register(kPullingPosVarName)); Source{}, ctx.dst->Symbols().Register(kPullingPosVarName));
} }
ast::Expression* VertexPulling::State::AccessByFormat( ast::Expression* VertexPulling::State::AccessByFormat(
@ -415,30 +416,30 @@ ast::Expression* VertexPulling::State::AccessU32(uint32_t buffer,
// unpacked into an appropriate variable. All reads should end up here as a // unpacked into an appropriate variable. All reads should end up here as a
// base case. // base case.
auto vbuf_name = GetVertexBufferName(buffer); auto vbuf_name = GetVertexBufferName(buffer);
return out->create<ast::ArrayAccessorExpression>( return ctx.dst->create<ast::ArrayAccessorExpression>(
Source{}, Source{},
out->create<ast::MemberAccessorExpression>( ctx.dst->create<ast::MemberAccessorExpression>(
Source{}, Source{},
out->create<ast::IdentifierExpression>( ctx.dst->create<ast::IdentifierExpression>(
Source{}, out->Symbols().Register(vbuf_name)), Source{}, ctx.dst->Symbols().Register(vbuf_name)),
out->create<ast::IdentifierExpression>( ctx.dst->create<ast::IdentifierExpression>(
Source{}, out->Symbols().Register(kStructBufferName))), Source{}, ctx.dst->Symbols().Register(kStructBufferName))),
out->create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kDivide, pos, ctx.dst->create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kDivide,
GenUint(4))); pos, GenUint(4)));
} }
ast::Expression* VertexPulling::State::AccessI32(uint32_t buffer, ast::Expression* VertexPulling::State::AccessI32(uint32_t buffer,
ast::Expression* pos) const { ast::Expression* pos) const {
// as<T> reinterprets bits // as<T> reinterprets bits
return out->create<ast::BitcastExpression>(Source{}, GetI32Type(), return ctx.dst->create<ast::BitcastExpression>(Source{}, GetI32Type(),
AccessU32(buffer, pos)); AccessU32(buffer, pos));
} }
ast::Expression* VertexPulling::State::AccessF32(uint32_t buffer, ast::Expression* VertexPulling::State::AccessF32(uint32_t buffer,
ast::Expression* pos) const { ast::Expression* pos) const {
// as<T> reinterprets bits // as<T> reinterprets bits
return out->create<ast::BitcastExpression>(Source{}, GetF32Type(), return ctx.dst->create<ast::BitcastExpression>(Source{}, GetF32Type(),
AccessU32(buffer, pos)); AccessU32(buffer, pos));
} }
ast::Expression* VertexPulling::State::AccessPrimitive( ast::Expression* VertexPulling::State::AccessPrimitive(
@ -469,27 +470,27 @@ ast::Expression* VertexPulling::State::AccessVec(uint32_t buffer,
ast::ExpressionList expr_list; ast::ExpressionList expr_list;
for (uint32_t i = 0; i < count; ++i) { for (uint32_t i = 0; i < count; ++i) {
// Offset read position by element_stride for each component // Offset read position by element_stride for each component
auto* cur_pos = out->create<ast::BinaryExpression>( auto* cur_pos = ctx.dst->create<ast::BinaryExpression>(
Source{}, ast::BinaryOp::kAdd, CreatePullingPositionIdent(), Source{}, ast::BinaryOp::kAdd, CreatePullingPositionIdent(),
GenUint(element_stride * i)); GenUint(element_stride * i));
expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format)); expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format));
} }
return out->create<ast::TypeConstructorExpression>( return ctx.dst->create<ast::TypeConstructorExpression>(
Source{}, out->create<type::Vector>(base_type, count), Source{}, ctx.dst->create<type::Vector>(base_type, count),
std::move(expr_list)); std::move(expr_list));
} }
type::Type* VertexPulling::State::GetU32Type() const { type::Type* VertexPulling::State::GetU32Type() const {
return out->create<type::U32>(); return ctx.dst->create<type::U32>();
} }
type::Type* VertexPulling::State::GetI32Type() const { type::Type* VertexPulling::State::GetI32Type() const {
return out->create<type::I32>(); return ctx.dst->create<type::I32>();
} }
type::Type* VertexPulling::State::GetF32Type() const { type::Type* VertexPulling::State::GetF32Type() const {
return out->create<type::F32>(); return ctx.dst->create<type::F32>();
} }
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default; VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;

View File

@ -183,7 +183,7 @@ class VertexPulling : public Transform {
Config cfg; Config cfg;
struct State { struct State {
State(const Program* in, Program* out, const Config& c); State(CloneContext& ctx, const Config& c);
explicit State(const State&); explicit State(const State&);
~State(); ~State();
@ -263,11 +263,20 @@ class VertexPulling : public Transform {
type::Type* GetI32Type() const; type::Type* GetI32Type() const;
type::Type* GetF32Type() const; type::Type* GetF32Type() const;
const Program* const in; CloneContext& ctx;
Program* const out;
Config const cfg; Config const cfg;
/// LocationReplacement describes an ast::Variable replacement for a
/// location input.
struct LocationReplacement {
/// The variable to replace in the source Program
ast::Variable* from;
/// The replacement to use in the target ProgramBuilder
ast::Variable* to;
};
std::unordered_map<uint32_t, ast::Variable*> location_to_var; std::unordered_map<uint32_t, ast::Variable*> location_to_var;
std::vector<LocationReplacement> location_replacements;
std::string vertex_index_name; std::string vertex_index_name;
std::string instance_index_name; std::string instance_index_name;
}; };