transform: Fix multiple mutability issues with VertexPulling

ConvertVertexInputVariablesToPrivate() mutated the source program global variables, and copied them into the destination program.

Symbols and types were assigned across the program boundary without cloning.

Bug: tint:390
Change-Id: I03c8924e6ba94b745e74de0ab57f8a489e85cc50
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/38554
Reviewed-by: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2021-01-26 16:57:10 +00:00
parent 5d9b1c38c2
commit deb02019d5
2 changed files with 121 additions and 111 deletions

View File

@ -103,21 +103,24 @@ Transform::Output VertexPulling::Run(const Program* in) {
// following stages will pass
Output out;
State state{in, &out.program, cfg};
CloneContext ctx(&out.program, in);
State state{ctx, cfg};
state.FindOrInsertVertexIndexIfUsed();
state.FindOrInsertInstanceIndexIfUsed();
state.ConvertVertexInputVariablesToPrivate();
state.AddVertexStorageBuffers();
CloneContext(&out.program, in)
.ReplaceAll([&](CloneContext* ctx, ast::Function* f) -> ast::Function* {
if (f == func) {
return CloneWithStatementsAtStart(
ctx, f, {state.CreateVertexPullingPreamble()});
}
return nullptr; // Just clone func
})
.Clone();
for (auto& replacement : state.location_replacements) {
ctx.Replace(replacement.from, replacement.to);
}
ctx.ReplaceAll([&](CloneContext*, ast::Function* f) -> ast::Function* {
if (f == func) {
return CloneWithStatementsAtStart(&ctx, f,
{state.CreateVertexPullingPreamble()});
}
return nullptr; // Just clone func
});
ctx.Clone();
return out;
}
@ -126,8 +129,8 @@ VertexPulling::Config::Config() = default;
VertexPulling::Config::Config(const Config&) = default;
VertexPulling::Config::~Config() = default;
VertexPulling::State::State(const Program* i, Program* o, const Config& c)
: in(i), out(o), cfg(c) {}
VertexPulling::State::State(CloneContext& context, const Config& c)
: ctx(context), cfg(c) {}
VertexPulling::State::State(const State&) = default;
@ -150,7 +153,7 @@ void VertexPulling::State::FindOrInsertVertexIndexIfUsed() {
}
// Look for an existing vertex index builtin
for (auto* v : in->AST().GlobalVariables()) {
for (auto* v : ctx.src->AST().GlobalVariables()) {
if (v->storage_class() != ast::StorageClass::kInput) {
continue;
}
@ -158,7 +161,7 @@ void VertexPulling::State::FindOrInsertVertexIndexIfUsed() {
for (auto* d : v->decorations()) {
if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
if (builtin->value() == ast::Builtin::kVertexIndex) {
vertex_index_name = in->Symbols().NameFor(v->symbol());
vertex_index_name = ctx.src->Symbols().NameFor(v->symbol());
return;
}
}
@ -168,20 +171,19 @@ void VertexPulling::State::FindOrInsertVertexIndexIfUsed() {
// We didn't find a vertex index builtin, so create one
vertex_index_name = kDefaultVertexIndexName;
auto* var = out->create<ast::Variable>(
Source{}, // source
out->Symbols().Register(vertex_index_name), // symbol
ast::StorageClass::kInput, // storage_class
GetI32Type(), // type
false, // is_const
nullptr, // constructor
auto* var = ctx.dst->create<ast::Variable>(
Source{}, // source
ctx.dst->Symbols().Register(vertex_index_name), // symbol
ast::StorageClass::kInput, // storage_class
GetI32Type(), // type
false, // is_const
nullptr, // constructor
ast::VariableDecorationList{
// decorations
out->create<ast::BuiltinDecoration>(Source{},
ast::Builtin::kVertexIndex),
ctx.dst->create<ast::BuiltinDecoration>(Source{},
ast::Builtin::kVertexIndex),
});
out->AST().AddGlobalVariable(var);
ctx.dst->AST().AddGlobalVariable(var);
}
void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
@ -197,7 +199,7 @@ void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
}
// Look for an existing instance index builtin
for (auto* v : in->AST().GlobalVariables()) {
for (auto* v : ctx.src->AST().GlobalVariables()) {
if (v->storage_class() != ast::StorageClass::kInput) {
continue;
}
@ -205,7 +207,7 @@ void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
for (auto* d : v->decorations()) {
if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
if (builtin->value() == ast::Builtin::kInstanceIndex) {
instance_index_name = in->Symbols().NameFor(v->symbol());
instance_index_name = ctx.src->Symbols().NameFor(v->symbol());
return;
}
}
@ -215,24 +217,22 @@ void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
// We didn't find an instance index builtin, so create one
instance_index_name = kDefaultInstanceIndexName;
auto* var = out->create<ast::Variable>(
Source{}, // source
out->Symbols().Register(instance_index_name), // symbol
ast::StorageClass::kInput, // storage_class
GetI32Type(), // type
false, // is_const
nullptr, // constructor
auto* var = ctx.dst->create<ast::Variable>(
Source{}, // source
ctx.dst->Symbols().Register(instance_index_name), // symbol
ast::StorageClass::kInput, // storage_class
GetI32Type(), // type
false, // is_const
nullptr, // constructor
ast::VariableDecorationList{
// decorations
out->create<ast::BuiltinDecoration>(Source{},
ast::Builtin::kInstanceIndex),
ctx.dst->create<ast::BuiltinDecoration>(Source{},
ast::Builtin::kInstanceIndex),
});
out->AST().AddGlobalVariable(var);
ctx.dst->AST().AddGlobalVariable(var);
}
void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
// TODO(https://crbug.com/tint/390): Remove this const_cast hack!
for (auto*& v : const_cast<Program*>(in)->AST().GlobalVariables()) {
for (auto* v : ctx.src->AST().GlobalVariables()) {
if (v->storage_class() != ast::StorageClass::kInput) {
continue;
}
@ -240,18 +240,20 @@ void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
for (auto* d : v->decorations()) {
if (auto* l = d->As<ast::LocationDecoration>()) {
uint32_t location = l->value();
// This is where the replacement happens. Expressions use identifier
// This is where the replacement is created. Expressions use identifier
// strings instead of pointers, so we don't need to update any other
// place in the AST.
v = out->create<ast::Variable>(
Source{}, // source
v->symbol(), // symbol
ast::StorageClass::kPrivate, // storage_class
v->type(), // type
false, // is_const
nullptr, // constructor
ast::VariableDecorationList{}); // decorations
location_to_var[location] = v;
auto name = ctx.src->Symbols().NameFor(v->symbol());
auto* replacement = ctx.dst->create<ast::Variable>(
Source{}, // source
ctx.dst->Symbols().Register(name), // symbol
ast::StorageClass::kPrivate, // storage_class
ctx.Clone(v->type()), // type
false, // is_const
nullptr, // constructor
ast::VariableDecorationList{}); // decorations
location_to_var[location] = replacement;
location_replacements.emplace_back(LocationReplacement{v, replacement});
break;
}
}
@ -261,47 +263,47 @@ void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
void VertexPulling::State::AddVertexStorageBuffers() {
// TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935
// The array inside the struct definition
auto* internal_array_type = out->create<type::Array>(
auto* internal_array_type = ctx.dst->create<type::Array>(
GetU32Type(), 0,
ast::ArrayDecorationList{
out->create<ast::StrideDecoration>(Source{}, 4u),
ctx.dst->create<ast::StrideDecoration>(Source{}, 4u),
});
// Creating the struct type
ast::StructMemberList members;
ast::StructMemberDecorationList member_dec;
member_dec.push_back(
out->create<ast::StructMemberOffsetDecoration>(Source{}, 0u));
ctx.dst->create<ast::StructMemberOffsetDecoration>(Source{}, 0u));
members.push_back(out->create<ast::StructMember>(
Source{}, out->Symbols().Register(kStructBufferName), internal_array_type,
std::move(member_dec)));
members.push_back(ctx.dst->create<ast::StructMember>(
Source{}, ctx.dst->Symbols().Register(kStructBufferName),
internal_array_type, std::move(member_dec)));
ast::StructDecorationList decos;
decos.push_back(out->create<ast::StructBlockDecoration>(Source{}));
decos.push_back(ctx.dst->create<ast::StructBlockDecoration>(Source{}));
auto* struct_type = out->create<type::Struct>(
out->Symbols().Register(kStructName),
out->create<ast::Struct>(Source{}, std::move(members), std::move(decos)));
auto* struct_type = ctx.dst->create<type::Struct>(
ctx.dst->Symbols().Register(kStructName),
ctx.dst->create<ast::Struct>(Source{}, std::move(members),
std::move(decos)));
for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
// The decorated variable with struct type
std::string name = GetVertexBufferName(i);
auto* var = out->create<ast::Variable>(
Source{}, // source
out->Symbols().Register(name), // symbol
ast::StorageClass::kStorage, // storage_class
struct_type, // type
false, // is_const
nullptr, // constructor
auto* var = ctx.dst->create<ast::Variable>(
Source{}, // source
ctx.dst->Symbols().Register(name), // symbol
ast::StorageClass::kStorage, // storage_class
struct_type, // type
false, // is_const
nullptr, // constructor
ast::VariableDecorationList{
// decorations
out->create<ast::BindingDecoration>(Source{}, i),
out->create<ast::GroupDecoration>(Source{}, cfg.pulling_group),
ctx.dst->create<ast::BindingDecoration>(Source{}, i),
ctx.dst->create<ast::GroupDecoration>(Source{}, cfg.pulling_group),
});
out->AST().AddGlobalVariable(var);
ctx.dst->AST().AddGlobalVariable(var);
}
out->AST().AddConstructedType(struct_type);
ctx.dst->AST().AddConstructedType(struct_type);
}
ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
@ -311,10 +313,10 @@ ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
ast::StatementList stmts;
// Declare the |kPullingPosVarName| variable in the shader
auto* pos_declaration = out->create<ast::VariableDeclStatement>(
Source{}, out->create<ast::Variable>(
Source{}, // source
out->Symbols().Register(kPullingPosVarName), // symbol
auto* pos_declaration = ctx.dst->create<ast::VariableDeclStatement>(
Source{}, ctx.dst->create<ast::Variable>(
Source{}, // source
ctx.dst->Symbols().Register(kPullingPosVarName), // symbol
ast::StorageClass::kFunction, // storage_class
GetI32Type(), // type
false, // is_const
@ -341,42 +343,41 @@ ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
? vertex_index_name
: instance_index_name;
// Identifier to index by
auto* index_identifier = out->create<ast::IdentifierExpression>(
Source{}, out->Symbols().Register(name));
auto* index_identifier = ctx.dst->create<ast::IdentifierExpression>(
Source{}, ctx.dst->Symbols().Register(name));
// An expression for the start of the read in the buffer in bytes
auto* pos_value = out->create<ast::BinaryExpression>(
auto* pos_value = ctx.dst->create<ast::BinaryExpression>(
Source{}, ast::BinaryOp::kAdd,
out->create<ast::BinaryExpression>(
ctx.dst->create<ast::BinaryExpression>(
Source{}, ast::BinaryOp::kMultiply, index_identifier,
GenUint(static_cast<uint32_t>(buffer_layout.array_stride))),
GenUint(static_cast<uint32_t>(attribute_desc.offset)));
// Update position of the read
auto* set_pos_expr = out->create<ast::AssignmentStatement>(
auto* set_pos_expr = ctx.dst->create<ast::AssignmentStatement>(
Source{}, CreatePullingPositionIdent(), pos_value);
stmts.emplace_back(set_pos_expr);
auto ident_name = in->Symbols().NameFor(v->symbol());
stmts.emplace_back(out->create<ast::AssignmentStatement>(
stmts.emplace_back(ctx.dst->create<ast::AssignmentStatement>(
Source{},
out->create<ast::IdentifierExpression>(
Source{}, out->Symbols().Register(ident_name)),
ctx.dst->create<ast::IdentifierExpression>(Source{}, v->symbol()),
AccessByFormat(i, attribute_desc.format)));
}
}
return out->create<ast::BlockStatement>(Source{}, stmts);
return ctx.dst->create<ast::BlockStatement>(Source{}, stmts);
}
ast::Expression* VertexPulling::State::GenUint(uint32_t value) const {
return out->create<ast::ScalarConstructorExpression>(
Source{}, out->create<ast::UintLiteral>(Source{}, GetU32Type(), value));
return ctx.dst->create<ast::ScalarConstructorExpression>(
Source{},
ctx.dst->create<ast::UintLiteral>(Source{}, GetU32Type(), value));
}
ast::Expression* VertexPulling::State::CreatePullingPositionIdent() const {
return out->create<ast::IdentifierExpression>(
Source{}, out->Symbols().Register(kPullingPosVarName));
return ctx.dst->create<ast::IdentifierExpression>(
Source{}, ctx.dst->Symbols().Register(kPullingPosVarName));
}
ast::Expression* VertexPulling::State::AccessByFormat(
@ -415,30 +416,30 @@ ast::Expression* VertexPulling::State::AccessU32(uint32_t buffer,
// unpacked into an appropriate variable. All reads should end up here as a
// base case.
auto vbuf_name = GetVertexBufferName(buffer);
return out->create<ast::ArrayAccessorExpression>(
return ctx.dst->create<ast::ArrayAccessorExpression>(
Source{},
out->create<ast::MemberAccessorExpression>(
ctx.dst->create<ast::MemberAccessorExpression>(
Source{},
out->create<ast::IdentifierExpression>(
Source{}, out->Symbols().Register(vbuf_name)),
out->create<ast::IdentifierExpression>(
Source{}, out->Symbols().Register(kStructBufferName))),
out->create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kDivide, pos,
GenUint(4)));
ctx.dst->create<ast::IdentifierExpression>(
Source{}, ctx.dst->Symbols().Register(vbuf_name)),
ctx.dst->create<ast::IdentifierExpression>(
Source{}, ctx.dst->Symbols().Register(kStructBufferName))),
ctx.dst->create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kDivide,
pos, GenUint(4)));
}
ast::Expression* VertexPulling::State::AccessI32(uint32_t buffer,
ast::Expression* pos) const {
// as<T> reinterprets bits
return out->create<ast::BitcastExpression>(Source{}, GetI32Type(),
AccessU32(buffer, pos));
return ctx.dst->create<ast::BitcastExpression>(Source{}, GetI32Type(),
AccessU32(buffer, pos));
}
ast::Expression* VertexPulling::State::AccessF32(uint32_t buffer,
ast::Expression* pos) const {
// as<T> reinterprets bits
return out->create<ast::BitcastExpression>(Source{}, GetF32Type(),
AccessU32(buffer, pos));
return ctx.dst->create<ast::BitcastExpression>(Source{}, GetF32Type(),
AccessU32(buffer, pos));
}
ast::Expression* VertexPulling::State::AccessPrimitive(
@ -469,27 +470,27 @@ ast::Expression* VertexPulling::State::AccessVec(uint32_t buffer,
ast::ExpressionList expr_list;
for (uint32_t i = 0; i < count; ++i) {
// Offset read position by element_stride for each component
auto* cur_pos = out->create<ast::BinaryExpression>(
auto* cur_pos = ctx.dst->create<ast::BinaryExpression>(
Source{}, ast::BinaryOp::kAdd, CreatePullingPositionIdent(),
GenUint(element_stride * i));
expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format));
}
return out->create<ast::TypeConstructorExpression>(
Source{}, out->create<type::Vector>(base_type, count),
return ctx.dst->create<ast::TypeConstructorExpression>(
Source{}, ctx.dst->create<type::Vector>(base_type, count),
std::move(expr_list));
}
type::Type* VertexPulling::State::GetU32Type() const {
return out->create<type::U32>();
return ctx.dst->create<type::U32>();
}
type::Type* VertexPulling::State::GetI32Type() const {
return out->create<type::I32>();
return ctx.dst->create<type::I32>();
}
type::Type* VertexPulling::State::GetF32Type() const {
return out->create<type::F32>();
return ctx.dst->create<type::F32>();
}
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;

View File

@ -183,7 +183,7 @@ class VertexPulling : public Transform {
Config cfg;
struct State {
State(const Program* in, Program* out, const Config& c);
State(CloneContext& ctx, const Config& c);
explicit State(const State&);
~State();
@ -263,11 +263,20 @@ class VertexPulling : public Transform {
type::Type* GetI32Type() const;
type::Type* GetF32Type() const;
const Program* const in;
Program* const out;
CloneContext& ctx;
Config const cfg;
/// LocationReplacement describes an ast::Variable replacement for a
/// location input.
struct LocationReplacement {
/// The variable to replace in the source Program
ast::Variable* from;
/// The replacement to use in the target ProgramBuilder
ast::Variable* to;
};
std::unordered_map<uint32_t, ast::Variable*> location_to_var;
std::vector<LocationReplacement> location_replacements;
std::string vertex_index_name;
std::string instance_index_name;
};