transform: Fix multiple mutability issues with VertexPulling
ConvertVertexInputVariablesToPrivate() mutated the source program global variables, and copied them into the destination program. Symbols and types were assigned across the program boundary without cloning. Bug: tint:390 Change-Id: I03c8924e6ba94b745e74de0ab57f8a489e85cc50 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/38554 Reviewed-by: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
parent
5d9b1c38c2
commit
deb02019d5
|
@ -103,21 +103,24 @@ Transform::Output VertexPulling::Run(const Program* in) {
|
|||
// following stages will pass
|
||||
Output out;
|
||||
|
||||
State state{in, &out.program, cfg};
|
||||
CloneContext ctx(&out.program, in);
|
||||
State state{ctx, cfg};
|
||||
state.FindOrInsertVertexIndexIfUsed();
|
||||
state.FindOrInsertInstanceIndexIfUsed();
|
||||
state.ConvertVertexInputVariablesToPrivate();
|
||||
state.AddVertexStorageBuffers();
|
||||
|
||||
CloneContext(&out.program, in)
|
||||
.ReplaceAll([&](CloneContext* ctx, ast::Function* f) -> ast::Function* {
|
||||
if (f == func) {
|
||||
return CloneWithStatementsAtStart(
|
||||
ctx, f, {state.CreateVertexPullingPreamble()});
|
||||
}
|
||||
return nullptr; // Just clone func
|
||||
})
|
||||
.Clone();
|
||||
for (auto& replacement : state.location_replacements) {
|
||||
ctx.Replace(replacement.from, replacement.to);
|
||||
}
|
||||
ctx.ReplaceAll([&](CloneContext*, ast::Function* f) -> ast::Function* {
|
||||
if (f == func) {
|
||||
return CloneWithStatementsAtStart(&ctx, f,
|
||||
{state.CreateVertexPullingPreamble()});
|
||||
}
|
||||
return nullptr; // Just clone func
|
||||
});
|
||||
ctx.Clone();
|
||||
|
||||
return out;
|
||||
}
|
||||
|
@ -126,8 +129,8 @@ VertexPulling::Config::Config() = default;
|
|||
VertexPulling::Config::Config(const Config&) = default;
|
||||
VertexPulling::Config::~Config() = default;
|
||||
|
||||
VertexPulling::State::State(const Program* i, Program* o, const Config& c)
|
||||
: in(i), out(o), cfg(c) {}
|
||||
VertexPulling::State::State(CloneContext& context, const Config& c)
|
||||
: ctx(context), cfg(c) {}
|
||||
|
||||
VertexPulling::State::State(const State&) = default;
|
||||
|
||||
|
@ -150,7 +153,7 @@ void VertexPulling::State::FindOrInsertVertexIndexIfUsed() {
|
|||
}
|
||||
|
||||
// Look for an existing vertex index builtin
|
||||
for (auto* v : in->AST().GlobalVariables()) {
|
||||
for (auto* v : ctx.src->AST().GlobalVariables()) {
|
||||
if (v->storage_class() != ast::StorageClass::kInput) {
|
||||
continue;
|
||||
}
|
||||
|
@ -158,7 +161,7 @@ void VertexPulling::State::FindOrInsertVertexIndexIfUsed() {
|
|||
for (auto* d : v->decorations()) {
|
||||
if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
|
||||
if (builtin->value() == ast::Builtin::kVertexIndex) {
|
||||
vertex_index_name = in->Symbols().NameFor(v->symbol());
|
||||
vertex_index_name = ctx.src->Symbols().NameFor(v->symbol());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -168,20 +171,19 @@ void VertexPulling::State::FindOrInsertVertexIndexIfUsed() {
|
|||
// We didn't find a vertex index builtin, so create one
|
||||
vertex_index_name = kDefaultVertexIndexName;
|
||||
|
||||
auto* var = out->create<ast::Variable>(
|
||||
Source{}, // source
|
||||
out->Symbols().Register(vertex_index_name), // symbol
|
||||
ast::StorageClass::kInput, // storage_class
|
||||
GetI32Type(), // type
|
||||
false, // is_const
|
||||
nullptr, // constructor
|
||||
auto* var = ctx.dst->create<ast::Variable>(
|
||||
Source{}, // source
|
||||
ctx.dst->Symbols().Register(vertex_index_name), // symbol
|
||||
ast::StorageClass::kInput, // storage_class
|
||||
GetI32Type(), // type
|
||||
false, // is_const
|
||||
nullptr, // constructor
|
||||
ast::VariableDecorationList{
|
||||
// decorations
|
||||
out->create<ast::BuiltinDecoration>(Source{},
|
||||
ast::Builtin::kVertexIndex),
|
||||
ctx.dst->create<ast::BuiltinDecoration>(Source{},
|
||||
ast::Builtin::kVertexIndex),
|
||||
});
|
||||
|
||||
out->AST().AddGlobalVariable(var);
|
||||
ctx.dst->AST().AddGlobalVariable(var);
|
||||
}
|
||||
|
||||
void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
|
||||
|
@ -197,7 +199,7 @@ void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
|
|||
}
|
||||
|
||||
// Look for an existing instance index builtin
|
||||
for (auto* v : in->AST().GlobalVariables()) {
|
||||
for (auto* v : ctx.src->AST().GlobalVariables()) {
|
||||
if (v->storage_class() != ast::StorageClass::kInput) {
|
||||
continue;
|
||||
}
|
||||
|
@ -205,7 +207,7 @@ void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
|
|||
for (auto* d : v->decorations()) {
|
||||
if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
|
||||
if (builtin->value() == ast::Builtin::kInstanceIndex) {
|
||||
instance_index_name = in->Symbols().NameFor(v->symbol());
|
||||
instance_index_name = ctx.src->Symbols().NameFor(v->symbol());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -215,24 +217,22 @@ void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
|
|||
// We didn't find an instance index builtin, so create one
|
||||
instance_index_name = kDefaultInstanceIndexName;
|
||||
|
||||
auto* var = out->create<ast::Variable>(
|
||||
Source{}, // source
|
||||
out->Symbols().Register(instance_index_name), // symbol
|
||||
ast::StorageClass::kInput, // storage_class
|
||||
GetI32Type(), // type
|
||||
false, // is_const
|
||||
nullptr, // constructor
|
||||
auto* var = ctx.dst->create<ast::Variable>(
|
||||
Source{}, // source
|
||||
ctx.dst->Symbols().Register(instance_index_name), // symbol
|
||||
ast::StorageClass::kInput, // storage_class
|
||||
GetI32Type(), // type
|
||||
false, // is_const
|
||||
nullptr, // constructor
|
||||
ast::VariableDecorationList{
|
||||
// decorations
|
||||
out->create<ast::BuiltinDecoration>(Source{},
|
||||
ast::Builtin::kInstanceIndex),
|
||||
ctx.dst->create<ast::BuiltinDecoration>(Source{},
|
||||
ast::Builtin::kInstanceIndex),
|
||||
});
|
||||
out->AST().AddGlobalVariable(var);
|
||||
ctx.dst->AST().AddGlobalVariable(var);
|
||||
}
|
||||
|
||||
void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
|
||||
// TODO(https://crbug.com/tint/390): Remove this const_cast hack!
|
||||
for (auto*& v : const_cast<Program*>(in)->AST().GlobalVariables()) {
|
||||
for (auto* v : ctx.src->AST().GlobalVariables()) {
|
||||
if (v->storage_class() != ast::StorageClass::kInput) {
|
||||
continue;
|
||||
}
|
||||
|
@ -240,18 +240,20 @@ void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
|
|||
for (auto* d : v->decorations()) {
|
||||
if (auto* l = d->As<ast::LocationDecoration>()) {
|
||||
uint32_t location = l->value();
|
||||
// This is where the replacement happens. Expressions use identifier
|
||||
// This is where the replacement is created. Expressions use identifier
|
||||
// strings instead of pointers, so we don't need to update any other
|
||||
// place in the AST.
|
||||
v = out->create<ast::Variable>(
|
||||
Source{}, // source
|
||||
v->symbol(), // symbol
|
||||
ast::StorageClass::kPrivate, // storage_class
|
||||
v->type(), // type
|
||||
false, // is_const
|
||||
nullptr, // constructor
|
||||
ast::VariableDecorationList{}); // decorations
|
||||
location_to_var[location] = v;
|
||||
auto name = ctx.src->Symbols().NameFor(v->symbol());
|
||||
auto* replacement = ctx.dst->create<ast::Variable>(
|
||||
Source{}, // source
|
||||
ctx.dst->Symbols().Register(name), // symbol
|
||||
ast::StorageClass::kPrivate, // storage_class
|
||||
ctx.Clone(v->type()), // type
|
||||
false, // is_const
|
||||
nullptr, // constructor
|
||||
ast::VariableDecorationList{}); // decorations
|
||||
location_to_var[location] = replacement;
|
||||
location_replacements.emplace_back(LocationReplacement{v, replacement});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -261,47 +263,47 @@ void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
|
|||
void VertexPulling::State::AddVertexStorageBuffers() {
|
||||
// TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935
|
||||
// The array inside the struct definition
|
||||
auto* internal_array_type = out->create<type::Array>(
|
||||
auto* internal_array_type = ctx.dst->create<type::Array>(
|
||||
GetU32Type(), 0,
|
||||
ast::ArrayDecorationList{
|
||||
out->create<ast::StrideDecoration>(Source{}, 4u),
|
||||
ctx.dst->create<ast::StrideDecoration>(Source{}, 4u),
|
||||
});
|
||||
|
||||
// Creating the struct type
|
||||
ast::StructMemberList members;
|
||||
ast::StructMemberDecorationList member_dec;
|
||||
member_dec.push_back(
|
||||
out->create<ast::StructMemberOffsetDecoration>(Source{}, 0u));
|
||||
ctx.dst->create<ast::StructMemberOffsetDecoration>(Source{}, 0u));
|
||||
|
||||
members.push_back(out->create<ast::StructMember>(
|
||||
Source{}, out->Symbols().Register(kStructBufferName), internal_array_type,
|
||||
std::move(member_dec)));
|
||||
members.push_back(ctx.dst->create<ast::StructMember>(
|
||||
Source{}, ctx.dst->Symbols().Register(kStructBufferName),
|
||||
internal_array_type, std::move(member_dec)));
|
||||
|
||||
ast::StructDecorationList decos;
|
||||
decos.push_back(out->create<ast::StructBlockDecoration>(Source{}));
|
||||
decos.push_back(ctx.dst->create<ast::StructBlockDecoration>(Source{}));
|
||||
|
||||
auto* struct_type = out->create<type::Struct>(
|
||||
out->Symbols().Register(kStructName),
|
||||
out->create<ast::Struct>(Source{}, std::move(members), std::move(decos)));
|
||||
auto* struct_type = ctx.dst->create<type::Struct>(
|
||||
ctx.dst->Symbols().Register(kStructName),
|
||||
ctx.dst->create<ast::Struct>(Source{}, std::move(members),
|
||||
std::move(decos)));
|
||||
|
||||
for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
|
||||
// The decorated variable with struct type
|
||||
std::string name = GetVertexBufferName(i);
|
||||
auto* var = out->create<ast::Variable>(
|
||||
Source{}, // source
|
||||
out->Symbols().Register(name), // symbol
|
||||
ast::StorageClass::kStorage, // storage_class
|
||||
struct_type, // type
|
||||
false, // is_const
|
||||
nullptr, // constructor
|
||||
auto* var = ctx.dst->create<ast::Variable>(
|
||||
Source{}, // source
|
||||
ctx.dst->Symbols().Register(name), // symbol
|
||||
ast::StorageClass::kStorage, // storage_class
|
||||
struct_type, // type
|
||||
false, // is_const
|
||||
nullptr, // constructor
|
||||
ast::VariableDecorationList{
|
||||
// decorations
|
||||
out->create<ast::BindingDecoration>(Source{}, i),
|
||||
out->create<ast::GroupDecoration>(Source{}, cfg.pulling_group),
|
||||
ctx.dst->create<ast::BindingDecoration>(Source{}, i),
|
||||
ctx.dst->create<ast::GroupDecoration>(Source{}, cfg.pulling_group),
|
||||
});
|
||||
out->AST().AddGlobalVariable(var);
|
||||
ctx.dst->AST().AddGlobalVariable(var);
|
||||
}
|
||||
out->AST().AddConstructedType(struct_type);
|
||||
ctx.dst->AST().AddConstructedType(struct_type);
|
||||
}
|
||||
|
||||
ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
|
||||
|
@ -311,10 +313,10 @@ ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
|
|||
ast::StatementList stmts;
|
||||
|
||||
// Declare the |kPullingPosVarName| variable in the shader
|
||||
auto* pos_declaration = out->create<ast::VariableDeclStatement>(
|
||||
Source{}, out->create<ast::Variable>(
|
||||
Source{}, // source
|
||||
out->Symbols().Register(kPullingPosVarName), // symbol
|
||||
auto* pos_declaration = ctx.dst->create<ast::VariableDeclStatement>(
|
||||
Source{}, ctx.dst->create<ast::Variable>(
|
||||
Source{}, // source
|
||||
ctx.dst->Symbols().Register(kPullingPosVarName), // symbol
|
||||
ast::StorageClass::kFunction, // storage_class
|
||||
GetI32Type(), // type
|
||||
false, // is_const
|
||||
|
@ -341,42 +343,41 @@ ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
|
|||
? vertex_index_name
|
||||
: instance_index_name;
|
||||
// Identifier to index by
|
||||
auto* index_identifier = out->create<ast::IdentifierExpression>(
|
||||
Source{}, out->Symbols().Register(name));
|
||||
auto* index_identifier = ctx.dst->create<ast::IdentifierExpression>(
|
||||
Source{}, ctx.dst->Symbols().Register(name));
|
||||
|
||||
// An expression for the start of the read in the buffer in bytes
|
||||
auto* pos_value = out->create<ast::BinaryExpression>(
|
||||
auto* pos_value = ctx.dst->create<ast::BinaryExpression>(
|
||||
Source{}, ast::BinaryOp::kAdd,
|
||||
out->create<ast::BinaryExpression>(
|
||||
ctx.dst->create<ast::BinaryExpression>(
|
||||
Source{}, ast::BinaryOp::kMultiply, index_identifier,
|
||||
GenUint(static_cast<uint32_t>(buffer_layout.array_stride))),
|
||||
GenUint(static_cast<uint32_t>(attribute_desc.offset)));
|
||||
|
||||
// Update position of the read
|
||||
auto* set_pos_expr = out->create<ast::AssignmentStatement>(
|
||||
auto* set_pos_expr = ctx.dst->create<ast::AssignmentStatement>(
|
||||
Source{}, CreatePullingPositionIdent(), pos_value);
|
||||
stmts.emplace_back(set_pos_expr);
|
||||
|
||||
auto ident_name = in->Symbols().NameFor(v->symbol());
|
||||
stmts.emplace_back(out->create<ast::AssignmentStatement>(
|
||||
stmts.emplace_back(ctx.dst->create<ast::AssignmentStatement>(
|
||||
Source{},
|
||||
out->create<ast::IdentifierExpression>(
|
||||
Source{}, out->Symbols().Register(ident_name)),
|
||||
ctx.dst->create<ast::IdentifierExpression>(Source{}, v->symbol()),
|
||||
AccessByFormat(i, attribute_desc.format)));
|
||||
}
|
||||
}
|
||||
|
||||
return out->create<ast::BlockStatement>(Source{}, stmts);
|
||||
return ctx.dst->create<ast::BlockStatement>(Source{}, stmts);
|
||||
}
|
||||
|
||||
ast::Expression* VertexPulling::State::GenUint(uint32_t value) const {
|
||||
return out->create<ast::ScalarConstructorExpression>(
|
||||
Source{}, out->create<ast::UintLiteral>(Source{}, GetU32Type(), value));
|
||||
return ctx.dst->create<ast::ScalarConstructorExpression>(
|
||||
Source{},
|
||||
ctx.dst->create<ast::UintLiteral>(Source{}, GetU32Type(), value));
|
||||
}
|
||||
|
||||
ast::Expression* VertexPulling::State::CreatePullingPositionIdent() const {
|
||||
return out->create<ast::IdentifierExpression>(
|
||||
Source{}, out->Symbols().Register(kPullingPosVarName));
|
||||
return ctx.dst->create<ast::IdentifierExpression>(
|
||||
Source{}, ctx.dst->Symbols().Register(kPullingPosVarName));
|
||||
}
|
||||
|
||||
ast::Expression* VertexPulling::State::AccessByFormat(
|
||||
|
@ -415,30 +416,30 @@ ast::Expression* VertexPulling::State::AccessU32(uint32_t buffer,
|
|||
// unpacked into an appropriate variable. All reads should end up here as a
|
||||
// base case.
|
||||
auto vbuf_name = GetVertexBufferName(buffer);
|
||||
return out->create<ast::ArrayAccessorExpression>(
|
||||
return ctx.dst->create<ast::ArrayAccessorExpression>(
|
||||
Source{},
|
||||
out->create<ast::MemberAccessorExpression>(
|
||||
ctx.dst->create<ast::MemberAccessorExpression>(
|
||||
Source{},
|
||||
out->create<ast::IdentifierExpression>(
|
||||
Source{}, out->Symbols().Register(vbuf_name)),
|
||||
out->create<ast::IdentifierExpression>(
|
||||
Source{}, out->Symbols().Register(kStructBufferName))),
|
||||
out->create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kDivide, pos,
|
||||
GenUint(4)));
|
||||
ctx.dst->create<ast::IdentifierExpression>(
|
||||
Source{}, ctx.dst->Symbols().Register(vbuf_name)),
|
||||
ctx.dst->create<ast::IdentifierExpression>(
|
||||
Source{}, ctx.dst->Symbols().Register(kStructBufferName))),
|
||||
ctx.dst->create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kDivide,
|
||||
pos, GenUint(4)));
|
||||
}
|
||||
|
||||
ast::Expression* VertexPulling::State::AccessI32(uint32_t buffer,
|
||||
ast::Expression* pos) const {
|
||||
// as<T> reinterprets bits
|
||||
return out->create<ast::BitcastExpression>(Source{}, GetI32Type(),
|
||||
AccessU32(buffer, pos));
|
||||
return ctx.dst->create<ast::BitcastExpression>(Source{}, GetI32Type(),
|
||||
AccessU32(buffer, pos));
|
||||
}
|
||||
|
||||
ast::Expression* VertexPulling::State::AccessF32(uint32_t buffer,
|
||||
ast::Expression* pos) const {
|
||||
// as<T> reinterprets bits
|
||||
return out->create<ast::BitcastExpression>(Source{}, GetF32Type(),
|
||||
AccessU32(buffer, pos));
|
||||
return ctx.dst->create<ast::BitcastExpression>(Source{}, GetF32Type(),
|
||||
AccessU32(buffer, pos));
|
||||
}
|
||||
|
||||
ast::Expression* VertexPulling::State::AccessPrimitive(
|
||||
|
@ -469,27 +470,27 @@ ast::Expression* VertexPulling::State::AccessVec(uint32_t buffer,
|
|||
ast::ExpressionList expr_list;
|
||||
for (uint32_t i = 0; i < count; ++i) {
|
||||
// Offset read position by element_stride for each component
|
||||
auto* cur_pos = out->create<ast::BinaryExpression>(
|
||||
auto* cur_pos = ctx.dst->create<ast::BinaryExpression>(
|
||||
Source{}, ast::BinaryOp::kAdd, CreatePullingPositionIdent(),
|
||||
GenUint(element_stride * i));
|
||||
expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format));
|
||||
}
|
||||
|
||||
return out->create<ast::TypeConstructorExpression>(
|
||||
Source{}, out->create<type::Vector>(base_type, count),
|
||||
return ctx.dst->create<ast::TypeConstructorExpression>(
|
||||
Source{}, ctx.dst->create<type::Vector>(base_type, count),
|
||||
std::move(expr_list));
|
||||
}
|
||||
|
||||
type::Type* VertexPulling::State::GetU32Type() const {
|
||||
return out->create<type::U32>();
|
||||
return ctx.dst->create<type::U32>();
|
||||
}
|
||||
|
||||
type::Type* VertexPulling::State::GetI32Type() const {
|
||||
return out->create<type::I32>();
|
||||
return ctx.dst->create<type::I32>();
|
||||
}
|
||||
|
||||
type::Type* VertexPulling::State::GetF32Type() const {
|
||||
return out->create<type::F32>();
|
||||
return ctx.dst->create<type::F32>();
|
||||
}
|
||||
|
||||
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
|
||||
|
|
|
@ -183,7 +183,7 @@ class VertexPulling : public Transform {
|
|||
Config cfg;
|
||||
|
||||
struct State {
|
||||
State(const Program* in, Program* out, const Config& c);
|
||||
State(CloneContext& ctx, const Config& c);
|
||||
explicit State(const State&);
|
||||
~State();
|
||||
|
||||
|
@ -263,11 +263,20 @@ class VertexPulling : public Transform {
|
|||
type::Type* GetI32Type() const;
|
||||
type::Type* GetF32Type() const;
|
||||
|
||||
const Program* const in;
|
||||
Program* const out;
|
||||
CloneContext& ctx;
|
||||
Config const cfg;
|
||||
|
||||
/// LocationReplacement describes an ast::Variable replacement for a
|
||||
/// location input.
|
||||
struct LocationReplacement {
|
||||
/// The variable to replace in the source Program
|
||||
ast::Variable* from;
|
||||
/// The replacement to use in the target ProgramBuilder
|
||||
ast::Variable* to;
|
||||
};
|
||||
|
||||
std::unordered_map<uint32_t, ast::Variable*> location_to_var;
|
||||
std::vector<LocationReplacement> location_replacements;
|
||||
std::string vertex_index_name;
|
||||
std::string instance_index_name;
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue