transform/hlsl: Hoist structure constructors to new var

HLSL has some pecular rules around structure constructors.
`S s = S(1,2,3)` is not valid, but `S s = {1,2,3}` is.

This matches the quirkiness with array initializers, so adjust the array
hoisting logic to also support structures.

Fixed: tint:702
Change-Id: Ifdcafd98292715ae2482f72ec06c87842176d270
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46875
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton 2021-04-07 17:29:31 +00:00 committed by Commit Bot service account
parent 8f42be3f60
commit 933d44a2c8
7 changed files with 305 additions and 74 deletions

View File

@ -32,24 +32,24 @@ Hlsl::~Hlsl() = default;
Transform::Output Hlsl::Run(const Program* in, const DataMap&) { Transform::Output Hlsl::Run(const Program* in, const DataMap&) {
ProgramBuilder out; ProgramBuilder out;
CloneContext ctx(&out, in); CloneContext ctx(&out, in);
PromoteArrayInitializerToConstVar(ctx); PromoteInitializersToConstVar(ctx);
AddEmptyEntryPoint(ctx); AddEmptyEntryPoint(ctx);
ctx.Clone(); ctx.Clone();
return Output{Program(std::move(out))}; return Output{Program(std::move(out))};
} }
void Hlsl::PromoteArrayInitializerToConstVar(CloneContext& ctx) const { void Hlsl::PromoteInitializersToConstVar(CloneContext& ctx) const {
// Scan the AST nodes for array initializers which need to be promoted to // Scan the AST nodes for array and structure initializers which
// their own constant declaration. // need to be promoted to their own constant declaration.
// Note: Correct handling of arrays-of-arrays is guaranteed due to the // Note: Correct handling of nested expressions is guaranteed due to the
// depth-first traversal of the ast::Node::Clone() methods: // depth-first traversal of the ast::Node::Clone() methods:
// //
// The inner-most array initializers are traversed first, and they are hoisted // The inner-most initializers are traversed first, and they are hoisted
// to const variables declared just above the statement of use. The outer // to const variables declared just above the statement of use. The outer
// array initializer will then be hoisted, inserting themselves between the // initializer will then be hoisted, inserting themselves between the
// inner array declaration and the statement of use. This pattern applies // inner declaration and the statement of use. This pattern applies correctly
// correctly to any nested depth. // to any nested depth.
// //
// Depth-first traversal of the AST is guaranteed because AST nodes are fully // Depth-first traversal of the AST is guaranteed because AST nodes are fully
// immutable and require their children to be constructed first so their // immutable and require their children to be constructed first so their
@ -75,22 +75,23 @@ void Hlsl::PromoteArrayInitializerToConstVar(CloneContext& ctx) const {
if (auto* src_var_decl = src_stmt->As<ast::VariableDeclStatement>()) { if (auto* src_var_decl = src_stmt->As<ast::VariableDeclStatement>()) {
if (src_var_decl->variable()->constructor() == src_init) { if (src_var_decl->variable()->constructor() == src_init) {
// This statement is just a variable declaration with the array // This statement is just a variable declaration with the initializer
// initializer as the constructor value. This is what we're // as the constructor value. This is what we're attempting to
// attempting to transform to, and so ignore. // transform to, and so ignore.
continue; continue;
} }
} }
if (auto* src_array_ty = src_sem_expr->Type()->As<type::Array>()) { auto* src_ty = src_sem_expr->Type();
if (src_ty->IsAnyOf<type::Array, type::Struct>()) {
// Create a new symbol for the constant // Create a new symbol for the constant
auto dst_symbol = ctx.dst->Symbols().New(); auto dst_symbol = ctx.dst->Symbols().New();
// Clone the array type // Clone the type
auto* dst_array_ty = ctx.Clone(src_array_ty); auto* dst_ty = ctx.Clone(src_ty);
// Clone the array initializer // Clone the initializer
auto* dst_init = ctx.Clone(src_init); auto* dst_init = ctx.Clone(src_init);
// Construct the constant that holds the array // Construct the constant that holds the hoisted initializer
auto* dst_var = ctx.dst->Const(dst_symbol, dst_array_ty, dst_init); auto* dst_var = ctx.dst->Const(dst_symbol, dst_ty, dst_init);
// Construct the variable declaration statement // Construct the variable declaration statement
auto* dst_var_decl = auto* dst_var_decl =
ctx.dst->create<ast::VariableDeclStatement>(dst_var); ctx.dst->create<ast::VariableDeclStatement>(dst_var);
@ -100,7 +101,7 @@ void Hlsl::PromoteArrayInitializerToConstVar(CloneContext& ctx) const {
// Insert the constant before the usage // Insert the constant before the usage
ctx.InsertBefore(src_sem_stmt->Block()->statements(), src_stmt, ctx.InsertBefore(src_sem_stmt->Block()->statements(), src_stmt,
dst_var_decl); dst_var_decl);
// Replace the inlined array with a reference to the constant // Replace the inlined initializer with a reference to the constant
ctx.Replace(src_init, dst_ident); ctx.Replace(src_init, dst_ident);
} }
} }

View File

@ -40,10 +40,10 @@ class Hlsl : public Transform {
Output Run(const Program* program, const DataMap& data = {}) override; Output Run(const Program* program, const DataMap& data = {}) override;
private: private:
/// Hoists the array initializer to a constant variable, declared just before /// Hoists the array and structure initializers to a constant variable,
/// the array usage statement. /// declared just before the statement of usage. See crbug.com/tint/406 for
/// See crbug.com/tint/406 for more details /// more details
void PromoteArrayInitializerToConstVar(CloneContext& ctx) const; void PromoteInitializersToConstVar(CloneContext& ctx) const;
/// Add an empty shader entry point if none exist in the module. /// Add an empty shader entry point if none exist in the module.
void AddEmptyEntryPoint(CloneContext& ctx) const; void AddEmptyEntryPoint(CloneContext& ctx) const;
}; };

View File

@ -51,6 +51,39 @@ fn main() -> void {
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
TEST_F(HlslTest, PromoteStructureInitializerToConstVar_Basic) {
auto* src = R"(
struct S {
a : i32;
b : f32;
c : vec3<f32>;
};
[[stage(vertex)]]
fn main() -> void {
var x : f32 = S(1, 2.0, vec3<f32>()).b;
}
)";
auto* expect = R"(
struct S {
a : i32;
b : f32;
c : vec3<f32>;
};
[[stage(vertex)]]
fn main() -> void {
const tint_symbol_1 : S = S(1, 2.0, vec3<f32>());
var x : f32 = tint_symbol_1.b;
}
)";
auto got = Run<Hlsl>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(HlslTest, PromoteArrayInitializerToConstVar_ArrayInArray) { TEST_F(HlslTest, PromoteArrayInitializerToConstVar_ArrayInArray) {
auto* src = R"( auto* src = R"(
[[stage(vertex)]] [[stage(vertex)]]
@ -74,14 +107,115 @@ fn main() -> void {
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
TEST_F(HlslTest, PromoteArrayInitializerToConstVar_NoChangeOnArrayVarDecl) { TEST_F(HlslTest, PromoteStructureInitializerToConstVar_Nested) {
auto* src = R"( auto* src = R"(
struct S1 {
a : i32;
};
struct S2 {
a : i32;
b : S1;
c : i32;
};
struct S3 {
a : S2;
};
[[stage(vertex)]]
fn main() -> void {
var x : i32 = S3(S2(1, S1(2), 3)).a.b.a;
}
)";
auto* expect = R"(
struct S1 {
a : i32;
};
struct S2 {
a : i32;
b : S1;
c : i32;
};
struct S3 {
a : S2;
};
[[stage(vertex)]]
fn main() -> void {
const tint_symbol_1 : S1 = S1(2);
const tint_symbol_4 : S2 = S2(1, tint_symbol_1, 3);
const tint_symbol_8 : S3 = S3(tint_symbol_4);
var x : i32 = tint_symbol_8.a.b.a;
}
)";
auto got = Run<Hlsl>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(HlslTest, PromoteInitializerToConstVar_Mixed) {
auto* src = R"(
struct S1 {
a : i32;
};
struct S2 {
a : array<S1, 3>;
};
[[stage(vertex)]]
fn main() -> void {
var x : i32 = S2(array<S1, 3>(S1(1), S1(2), S1(3))).a[1].a;
}
)";
auto* expect = R"(
struct S1 {
a : i32;
};
struct S2 {
a : array<S1, 3>;
};
[[stage(vertex)]]
fn main() -> void {
const tint_symbol_1 : S1 = S1(1);
const tint_symbol_4 : S1 = S1(2);
const tint_symbol_5 : S1 = S1(3);
const tint_symbol_6 : array<S1, 3> = array<S1, 3>(tint_symbol_1, tint_symbol_4, tint_symbol_5);
const tint_symbol_7 : S2 = S2(tint_symbol_6);
var x : i32 = tint_symbol_7.a[1].a;
}
)";
auto got = Run<Hlsl>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(HlslTest, PromoteInitializerToConstVar_NoChangeOnVarDecl) {
auto* src = R"(
struct S {
a : i32;
b : f32;
c : i32;
};
[[stage(vertex)]] [[stage(vertex)]]
fn main() -> void { fn main() -> void {
var local_arr : array<f32, 4> = array<f32, 4>(0.0, 1.0, 2.0, 3.0); var local_arr : array<f32, 4> = array<f32, 4>(0.0, 1.0, 2.0, 3.0);
var local_str : S = S(1, 2.0, 3);
} }
const module_arr : array<f32, 4> = array<f32, 4>(0.0, 1.0, 2.0, 3.0); const module_arr : array<f32, 4> = array<f32, 4>(0.0, 1.0, 2.0, 3.0);
const module_str : S = S(1, 2.0, 3);
)"; )";
auto* expect = src; auto* expect = src;

View File

@ -1229,7 +1229,16 @@ bool GeneratorImpl::EmitScalarConstructor(
bool GeneratorImpl::EmitTypeConstructor(std::ostream& pre, bool GeneratorImpl::EmitTypeConstructor(std::ostream& pre,
std::ostream& out, std::ostream& out,
ast::TypeConstructorExpression* expr) { ast::TypeConstructorExpression* expr) {
if (expr->type()->Is<type::Array>()) { // If the type constructor is empty then we need to construct with the zero
// value for all components.
if (expr->values().empty()) {
return EmitZeroValue(out, expr->type());
}
bool brackets =
expr->type()->UnwrapAliasIfNeeded()->IsAnyOf<type::Array, type::Struct>();
if (brackets) {
out << "{"; out << "{";
} else { } else {
if (!EmitType(out, expr->type(), "")) { if (!EmitType(out, expr->type(), "")) {
@ -1238,31 +1247,19 @@ bool GeneratorImpl::EmitTypeConstructor(std::ostream& pre,
out << "("; out << "(";
} }
// If the type constructor is empty then we need to construct with the zero bool first = true;
// value for all components. for (auto* e : expr->values()) {
if (expr->values().empty()) { if (!first) {
if (!EmitZeroValue(out, expr->type())) { out << ", ";
}
first = false;
if (!EmitExpression(pre, out, e)) {
return false; return false;
} }
} else {
bool first = true;
for (auto* e : expr->values()) {
if (!first) {
out << ", ";
}
first = false;
if (!EmitExpression(pre, out, e)) {
return false;
}
}
} }
if (expr->type()->Is<type::Array>()) { out << (brackets ? "}" : ")");
out << "}";
} else {
out << ")";
}
return true; return true;
} }
@ -1994,6 +1991,10 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, type::Type* type) {
} else if (type->Is<type::U32>()) { } else if (type->Is<type::U32>()) {
out << "0u"; out << "0u";
} else if (auto* vec = type->As<type::Vector>()) { } else if (auto* vec = type->As<type::Vector>()) {
if (!EmitType(out, type, "")) {
return false;
}
ScopedParen sp(out);
for (uint32_t i = 0; i < vec->size(); i++) { for (uint32_t i = 0; i < vec->size(); i++) {
if (i != 0) { if (i != 0) {
out << ", "; out << ", ";
@ -2003,6 +2004,10 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, type::Type* type) {
} }
} }
} else if (auto* mat = type->As<type::Matrix>()) { } else if (auto* mat = type->As<type::Matrix>()) {
if (!EmitType(out, type, "")) {
return false;
}
ScopedParen sp(out);
for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) { for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
if (i != 0) { if (i != 0) {
out << ", "; out << ", ";
@ -2011,6 +2016,19 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, type::Type* type) {
return false; return false;
} }
} }
} else if (auto* str = type->As<type::Struct>()) {
out << "{";
bool first = true;
for (auto* member : str->impl()->members()) {
if (!first) {
out << ", ";
}
first = false;
if (!EmitZeroValue(out, member->type())) {
return false;
}
}
out << "}";
} else { } else {
diagnostics_.add_error("Invalid type for zero emission: " + diagnostics_.add_error("Invalid type for zero emission: " +
type->type_name()); type->type_name());

View File

@ -194,6 +194,40 @@ TEST_F(HlslGeneratorImplTest_Constructor,
Validate(); Validate();
} }
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Struct) {
auto* str = Structure("S", {
Member("a", ty.i32()),
Member("b", ty.f32()),
Member("c", ty.vec3<i32>()),
});
WrapInFunction(Construct(str, 1, 2.0f, vec3<i32>(3, 4, 5)));
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_THAT(result(), HasSubstr("{1, 2.0f, int3(3, 4, 5)}"));
Validate();
}
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Struct_Empty) {
auto* str = Structure("S", {
Member("a", ty.i32()),
Member("b", ty.f32()),
Member("c", ty.vec3<i32>()),
});
WrapInFunction(Construct(str));
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_THAT(result(), HasSubstr("{0, 0.0f, int3(0, 0, 0)}"));
Validate();
}
} // namespace } // namespace
} // namespace hlsl } // namespace hlsl
} // namespace writer } // namespace writer

View File

@ -124,16 +124,17 @@ TEST_F(HlslGeneratorImplTest_Function,
GeneratorImpl& gen = SanitizeAndBuild(); GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error(); ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_EQ(result(), R"(struct tint_symbol_1 { EXPECT_EQ(result(), R"(struct tint_symbol_5 {
float foo : TEXCOORD0; float foo : TEXCOORD0;
}; };
struct tint_symbol_3 { struct tint_symbol_2 {
float value : SV_Target1; float value : SV_Target1;
}; };
tint_symbol_3 frag_main(tint_symbol_1 tint_symbol_6) { tint_symbol_2 frag_main(tint_symbol_5 tint_symbol_7) {
const float foo = tint_symbol_6.foo; const float foo = tint_symbol_7.foo;
return tint_symbol_3(foo); const tint_symbol_2 tint_symbol_1 = {foo};
return tint_symbol_1;
} }
)"); )");
@ -157,16 +158,17 @@ TEST_F(HlslGeneratorImplTest_Function,
GeneratorImpl& gen = SanitizeAndBuild(); GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error(); ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_EQ(result(), R"(struct tint_symbol_1 { EXPECT_EQ(result(), R"(struct tint_symbol_6 {
float4 coord : SV_Position; float4 coord : SV_Position;
}; };
struct tint_symbol_3 { struct tint_symbol_2 {
float value : SV_Depth; float value : SV_Depth;
}; };
tint_symbol_3 frag_main(tint_symbol_1 tint_symbol_6) { tint_symbol_2 frag_main(tint_symbol_6 tint_symbol_8) {
const float4 coord = tint_symbol_6.coord; const float4 coord = tint_symbol_8.coord;
return tint_symbol_3(coord.x); const tint_symbol_2 tint_symbol_1 = {coord.x};
return tint_symbol_1;
} }
)"); )");
@ -213,22 +215,23 @@ TEST_F(HlslGeneratorImplTest_Function,
float col1; float col1;
float col2; float col2;
}; };
struct tint_symbol_4 { struct tint_symbol_2 {
float col1 : TEXCOORD1; float col1 : TEXCOORD1;
float col2 : TEXCOORD2; float col2 : TEXCOORD2;
}; };
struct tint_symbol_7 { struct tint_symbol_8 {
float col1 : TEXCOORD1; float col1 : TEXCOORD1;
float col2 : TEXCOORD2; float col2 : TEXCOORD2;
}; };
tint_symbol_4 vert_main() { tint_symbol_2 vert_main() {
const Interface tint_symbol_6 = Interface(0.5f, 0.25f); const Interface tint_symbol_5 = {0.5f, 0.25f};
return tint_symbol_4(tint_symbol_6.col1, tint_symbol_6.col2); const tint_symbol_2 tint_symbol_1 = {tint_symbol_5.col1, tint_symbol_5.col2};
return tint_symbol_1;
} }
void frag_main(tint_symbol_7 tint_symbol_9) { void frag_main(tint_symbol_8 tint_symbol_10) {
const Interface colors = Interface(tint_symbol_9.col1, tint_symbol_9.col2); const Interface colors = {tint_symbol_10.col1, tint_symbol_10.col2};
const float r = colors.col1; const float r = colors.col1;
const float g = colors.col2; const float g = colors.col2;
return; return;
@ -236,8 +239,7 @@ void frag_main(tint_symbol_7 tint_symbol_9) {
)"); )");
// TODO(crbug.com/tint/702): This is not legal HLSL Validate();
// Validate();
} }
TEST_F(HlslGeneratorImplTest_Function, TEST_F(HlslGeneratorImplTest_Function,
@ -281,25 +283,28 @@ TEST_F(HlslGeneratorImplTest_Function,
EXPECT_EQ(result(), R"(struct VertexOutput { EXPECT_EQ(result(), R"(struct VertexOutput {
float4 pos; float4 pos;
}; };
struct tint_symbol_5 { struct tint_symbol_2 {
float4 pos : SV_Position; float4 pos : SV_Position;
}; };
struct tint_symbol_8 { struct tint_symbol_6 {
float4 pos : SV_Position; float4 pos : SV_Position;
}; };
VertexOutput foo(float x) { VertexOutput foo(float x) {
return VertexOutput(float4(x, x, x, 1.0f)); const VertexOutput tint_symbol_8 = {float4(x, x, x, 1.0f)};
return tint_symbol_8;
} }
tint_symbol_5 vert_main1() { tint_symbol_2 vert_main1() {
const VertexOutput tint_symbol_7 = VertexOutput(foo(0.5f)); const VertexOutput tint_symbol_4 = {foo(0.5f)};
return tint_symbol_5(tint_symbol_7.pos); const tint_symbol_2 tint_symbol_1 = {tint_symbol_4.pos};
return tint_symbol_1;
} }
tint_symbol_8 vert_main2() { tint_symbol_6 vert_main2() {
const VertexOutput tint_symbol_10 = VertexOutput(foo(0.25f)); const VertexOutput tint_symbol_7 = {foo(0.25f)};
return tint_symbol_8(tint_symbol_10.pos); const tint_symbol_6 tint_symbol_5 = {tint_symbol_7.pos};
return tint_symbol_5;
} }
)"); )");

View File

@ -51,6 +51,45 @@ TEST_F(HlslSanitizerTest, PromoteArrayInitializerToConstVar) {
EXPECT_EQ(expect, got); EXPECT_EQ(expect, got);
} }
TEST_F(HlslSanitizerTest, PromoteStructInitializerToConstVar) {
auto* str = Structure("S", {
Member("a", ty.i32()),
Member("b", ty.vec3<f32>()),
Member("c", ty.i32()),
});
auto* struct_init = Construct(str, 1, vec3<f32>(2.f, 3.f, 4.f), 4);
auto* struct_access = MemberAccessor(struct_init, "b");
auto* pos =
Var("pos", ty.vec3<f32>(), ast::StorageClass::kFunction, struct_access);
Func("main", ast::VariableList{}, ty.void_(),
ast::StatementList{
create<ast::VariableDeclStatement>(pos),
},
ast::DecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex),
});
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
auto got = result();
auto* expect = R"(struct S {
int a;
float3 b;
int c;
};
void main() {
const S tint_symbol_1 = {1, float3(2.0f, 3.0f, 4.0f), 4};
float3 pos = tint_symbol_1.b;
return;
}
)";
EXPECT_EQ(expect, got);
}
} // namespace } // namespace
} // namespace hlsl } // namespace hlsl
} // namespace writer } // namespace writer