tint/writer/msl: Inline constant expressions

This is required to handle materialized values, and for constant
expressions.

Bug: tint:1504
Change-Id: Ic3ac62317241fa6f7009360128f222aeb56f62e4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/92083
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton
2022-06-01 10:08:29 +00:00
committed by Dawn LUCI CQ
parent 2e22d9285c
commit cb6ddd2aa6
1012 changed files with 1950 additions and 1852 deletions

View File

@@ -36,12 +36,12 @@
#include "src/tint/sem/atomic.h"
#include "src/tint/sem/bool.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/constant.h"
#include "src/tint/sem/depth_multisampled_texture.h"
#include "src/tint/sem/depth_texture.h"
#include "src/tint/sem/f32.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/i32.h"
#include "src/tint/sem/materialize.h"
#include "src/tint/sem/matrix.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/module.h"
@@ -86,6 +86,31 @@ bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
return IsAnyOf<ast::BreakStatement, ast::FallthroughStatement>(stmts->Last());
}
void PrintF32(std::ostream& out, float value) {
// Note: Currently inf and nan should not be constructable, but this is implemented for the day
// we support them.
if (std::isinf(value)) {
out << (value >= 0 ? "INFINITY" : "-INFINITY");
} else if (std::isnan(value)) {
out << "NAN";
} else {
out << FloatToString(value) << "f";
}
}
void PrintI32(std::ostream& out, int32_t value) {
// MSL (and C++) parse `-2147483648` as a `long` because it parses unary minus and `2147483648`
// as separate tokens, and the latter doesn't fit into an (32-bit) `int`.
// WGSL, on the other hand, parses this as an `i32`.
// To avoid issues with `long` to `int` casts, emit `(-2147483647 - 1)` instead, which ensures
// the expression type is `int`.
if (auto int_min = std::numeric_limits<int32_t>::min(); value == int_min) {
out << "(" << int_min + 1 << " - 1)";
} else {
out << value;
}
}
class ScopedBitCast {
public:
ScopedBitCast(GeneratorImpl* generator,
@@ -551,12 +576,7 @@ bool GeneratorImpl::EmitBreak(const ast::BreakStatement*) {
}
bool GeneratorImpl::EmitCall(std::ostream& out, const ast::CallExpression* expr) {
auto* sem = program_->Sem().Get(expr);
if (auto* m = sem->As<sem::Materialize>()) {
// TODO(crbug.com/tint/1504): Just emit the constant value.
sem = m->Expr();
}
auto* call = sem->As<sem::Call>();
auto* call = program_->Sem().Get<sem::Call>(expr);
auto* target = call->Target();
return Switch(
target, [&](const sem::Function* func) { return EmitFunctionCall(out, call, func); },
@@ -1522,8 +1542,7 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
if (!EmitType(out, mat, "")) {
return false;
}
out << "(";
TINT_DEFER(out << ")");
ScopedParen sp(out);
return EmitZeroValue(out, mat->type());
},
[&](const sem::Array* arr) {
@@ -1543,6 +1562,92 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
});
}
bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant& constant) {
auto emit_bool = [&](size_t element_idx) {
out << (constant.Element<AInt>(element_idx) ? "true" : "false");
return true;
};
auto emit_f32 = [&](size_t element_idx) {
PrintF32(out, static_cast<float>(constant.Element<AFloat>(element_idx)));
return true;
};
auto emit_i32 = [&](size_t element_idx) {
PrintI32(out, static_cast<int32_t>(constant.Element<AInt>(element_idx).value));
return true;
};
auto emit_u32 = [&](size_t element_idx) {
out << constant.Element<AInt>(element_idx).value << "u";
return true;
};
auto emit_vector = [&](const sem::Vector* vec_ty, size_t start, size_t end) {
if (!EmitType(out, vec_ty, "")) {
return false;
}
ScopedParen sp(out);
auto emit_els = [&](auto emit_el) {
if (constant.AllEqual(start, end)) {
return emit_el(start);
}
for (size_t i = start; i < end; i++) {
if (i > start) {
out << ", ";
}
if (!emit_el(i)) {
return false;
}
}
return true;
};
return Switch(
vec_ty->type(), //
[&](const sem::Bool*) { return emit_els(emit_bool); }, //
[&](const sem::F32*) { return emit_els(emit_f32); }, //
[&](const sem::I32*) { return emit_els(emit_i32); }, //
[&](const sem::U32*) { return emit_els(emit_u32); }, //
[&](Default) {
diagnostics_.add_error(diag::System::Writer,
"unhandled constant vector element type: " +
builder_.FriendlyName(vec_ty->type()));
return false;
});
};
auto emit_matrix = [&](const sem::Matrix* m) {
if (!EmitType(out, constant.Type(), "")) {
return false;
}
ScopedParen sp(out);
for (size_t column_idx = 0; column_idx < m->columns(); column_idx++) {
if (column_idx > 0) {
out << ", ";
}
size_t start = m->rows() * column_idx;
size_t end = m->rows() * (column_idx + 1);
if (!emit_vector(m->ColumnType(), start, end)) {
return false;
}
}
return true;
};
return Switch(
constant.Type(), //
[&](const sem::Bool*) { return emit_bool(0); }, //
[&](const sem::F32*) { return emit_f32(0); }, //
[&](const sem::I32*) { return emit_i32(0); }, //
[&](const sem::U32*) { return emit_u32(0); }, //
[&](const sem::Vector* v) { return emit_vector(v, 0, constant.ElementCount()); }, //
[&](const sem::Matrix* m) { return emit_matrix(m); }, //
[&](Default) {
diagnostics_.add_error(
diag::System::Writer,
"unhandled constant type: " + builder_.FriendlyName(constant.Type()));
return false;
});
}
bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression* lit) {
return Switch(
lit,
@@ -1551,32 +1656,14 @@ bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression*
return true;
},
[&](const ast::FloatLiteralExpression* l) {
auto f32 = static_cast<float>(l->value);
if (std::isinf(f32)) {
out << (f32 >= 0 ? "INFINITY" : "-INFINITY");
} else if (std::isnan(f32)) {
out << "NAN";
} else {
out << FloatToString(f32) << "f";
}
PrintF32(out, static_cast<float>(l->value));
return true;
},
[&](const ast::IntLiteralExpression* i) {
switch (i->suffix) {
case ast::IntLiteralExpression::Suffix::kNone:
case ast::IntLiteralExpression::Suffix::kI: {
// MSL (and C++) parse `-2147483648` as a `long` because it parses
// unary minus and `2147483648` as separate tokens, and the latter
// doesn't fit into an (32-bit) `int`. WGSL, OTOH, parses this as an
// `i32`. To avoid issues with `long` to `int` casts, emit
// `(2147483647 - 1)` instead, which ensures the expression type is
// `int`.
const auto int_min = std::numeric_limits<int32_t>::min();
if (i->value == int_min) {
out << "(" << int_min + 1 << " - 1)";
} else {
out << i->value;
}
PrintI32(out, static_cast<int32_t>(i->value));
return true;
}
case ast::IntLiteralExpression::Suffix::kU: {
@@ -1594,6 +1681,11 @@ bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression*
}
bool GeneratorImpl::EmitExpression(std::ostream& out, const ast::Expression* expr) {
if (auto* sem = builder_.Sem().Get(expr)) {
if (auto constant = sem->ConstantValue()) {
return EmitConstant(out, constant);
}
}
return Switch(
expr,
[&](const ast::IndexAccessorExpression* a) { //

View File

@@ -45,6 +45,7 @@
// Forward declarations
namespace tint::sem {
class Call;
class Constant;
class Builtin;
class TypeConstructor;
class TypeConversion;
@@ -250,6 +251,11 @@ class GeneratorImpl : public TextGenerator {
/// @param stmt the statement to emit
/// @returns true if the statement was successfully emitted
bool EmitIf(const ast::IfStatement* stmt);
/// Handles a constant value
/// @param out the output stream
/// @param constant the constant value to emit
/// @returns true if the constant value was successfully emitted
bool EmitConstant(std::ostream& out, const sem::Constant& constant);
/// Handles a literal
/// @param out the output of the expression stream
/// @param lit the literal to emit

View File

@@ -184,7 +184,7 @@ std::string expected_texture_overload(ast::builtin::test::ValidTextureOverload o
case ValidTextureOverload::kSampleGrad2dF32:
return R"(texture.sample(sampler, float2(1.0f, 2.0f), gradient2d(float2(3.0f, 4.0f), float2(5.0f, 6.0f))))";
case ValidTextureOverload::kSampleGrad2dOffsetF32:
return R"(texture.sample(sampler, float2(1.0f, 2.0f), gradient2d(float2(3.0f, 4.0f), float2(5.0f, 6.0f)), int2(7, 7)))";
return R"(texture.sample(sampler, float2(1.0f, 2.0f), gradient2d(float2(3.0f, 4.0f), float2(5.0f, 6.0f)), int2(7)))";
case ValidTextureOverload::kSampleGrad2dArrayF32:
return R"(texture.sample(sampler, float2(1.0f, 2.0f), 3, gradient2d(float2(4.0f, 5.0f), float2(6.0f, 7.0f))))";
case ValidTextureOverload::kSampleGrad2dArrayOffsetF32:

View File

@@ -29,7 +29,7 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Cast_Scalar) {
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, cast)) << gen.error();
EXPECT_EQ(out.str(), "float(1)");
EXPECT_EQ(out.str(), "1.0f");
}
TEST_F(MslGeneratorImplTest, EmitExpression_Cast_Vector) {
@@ -40,7 +40,7 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Cast_Vector) {
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, cast)) << gen.error();
EXPECT_EQ(out.str(), "float3(int3(1, 2, 3))");
EXPECT_EQ(out.str(), "float3(1.0f, 2.0f, 3.0f)");
}
TEST_F(MslGeneratorImplTest, EmitExpression_Cast_IntMin) {
@@ -51,7 +51,7 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Cast_IntMin) {
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, cast)) << gen.error();
EXPECT_EQ(out.str(), "uint((-2147483647 - 1))");
EXPECT_EQ(out.str(), "0u");
}
} // namespace

View File

@@ -67,7 +67,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Float) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("float(-0.000012f)"));
EXPECT_THAT(gen.result(), HasSubstr("-0.000012f"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Bool) {
@@ -76,7 +76,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Bool) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("bool(true)"));
EXPECT_THAT(gen.result(), HasSubstr("true"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Int) {
@@ -85,7 +85,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Int) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("int(-12345)"));
EXPECT_THAT(gen.result(), HasSubstr("-12345"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Uint) {
@@ -94,7 +94,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Uint) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("uint(12345u)"));
EXPECT_THAT(gen.result(), HasSubstr("12345u"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec) {
@@ -112,7 +112,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("float3()"));
EXPECT_THAT(gen.result(), HasSubstr("float3(0.0f)"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat) {
@@ -134,7 +134,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("float4x4()"));
EXPECT_THAT(gen.result(), HasSubstr("float4x4(float4(0.0f), float4(0.0f)"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Array) {

View File

@@ -207,7 +207,7 @@ struct tint_symbol {
};
Interface vert_main_inner() {
Interface const tint_symbol_3 = {.col1=0.5f, .col2=0.25f, .pos=float4()};
Interface const tint_symbol_3 = {.col1=0.5f, .col2=0.25f, .pos=float4(0.0f)};
return tint_symbol_3;
}

View File

@@ -157,7 +157,7 @@ struct tint_symbol_3 {
void comp_main_inner(uint local_invocation_index, threadgroup float2x2* const tint_symbol) {
{
*(tint_symbol) = float2x2();
*(tint_symbol) = float2x2(float2(0.0f), float2(0.0f));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float2x2 const x = *(tint_symbol);
@@ -199,7 +199,7 @@ struct tint_symbol_3 {
void comp_main_inner(uint local_invocation_index, threadgroup tint_array_wrapper* const tint_symbol) {
for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
uint const i = idx;
(*(tint_symbol)).arr[i] = float2x2();
(*(tint_symbol)).arr[i] = float2x2(float2(0.0f), float2(0.0f));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
tint_array_wrapper const x = *(tint_symbol);
@@ -333,9 +333,9 @@ struct tint_symbol_23 {
void main1_inner(uint local_invocation_index, threadgroup float2x2* const tint_symbol, threadgroup float2x3* const tint_symbol_1, threadgroup float2x4* const tint_symbol_2) {
{
*(tint_symbol) = float2x2();
*(tint_symbol_1) = float2x3();
*(tint_symbol_2) = float2x4();
*(tint_symbol) = float2x2(float2(0.0f), float2(0.0f));
*(tint_symbol_1) = float2x3(float3(0.0f), float3(0.0f));
*(tint_symbol_2) = float2x4(float4(0.0f), float4(0.0f));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float2x2 const a1 = *(tint_symbol);
@@ -353,9 +353,9 @@ kernel void main1(threadgroup tint_symbol_7* tint_symbol_4 [[threadgroup(0)]], u
void main2_inner(uint local_invocation_index_1, threadgroup float3x2* const tint_symbol_8, threadgroup float3x3* const tint_symbol_9, threadgroup float3x4* const tint_symbol_10) {
{
*(tint_symbol_8) = float3x2();
*(tint_symbol_9) = float3x3();
*(tint_symbol_10) = float3x4();
*(tint_symbol_8) = float3x2(float2(0.0f), float2(0.0f), float2(0.0f));
*(tint_symbol_9) = float3x3(float3(0.0f), float3(0.0f), float3(0.0f));
*(tint_symbol_10) = float3x4(float4(0.0f), float4(0.0f), float4(0.0f));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float3x2 const a1 = *(tint_symbol_8);
@@ -373,9 +373,9 @@ kernel void main2(threadgroup tint_symbol_15* tint_symbol_12 [[threadgroup(0)]],
void main3_inner(uint local_invocation_index_2, threadgroup float4x2* const tint_symbol_16, threadgroup float4x3* const tint_symbol_17, threadgroup float4x4* const tint_symbol_18) {
{
*(tint_symbol_16) = float4x2();
*(tint_symbol_17) = float4x3();
*(tint_symbol_18) = float4x4();
*(tint_symbol_16) = float4x2(float2(0.0f), float2(0.0f), float2(0.0f), float2(0.0f));
*(tint_symbol_17) = float4x3(float3(0.0f), float3(0.0f), float3(0.0f), float3(0.0f));
*(tint_symbol_18) = float4x4(float4(0.0f), float4(0.0f), float4(0.0f), float4(0.0f));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float4x2 const a1 = *(tint_symbol_16);

View File

@@ -48,7 +48,7 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const) {
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
EXPECT_EQ(gen.result(), " float const a = float();\n");
EXPECT_EQ(gen.result(), " float const a = 0.0f;\n");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Array) {
@@ -132,7 +132,7 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_Private) {
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("thread float tint_symbol_1 = initializer;\n"));
EXPECT_THAT(gen.result(), HasSubstr("thread float tint_symbol_1 = 0.0f;\n float const tint_symbol = tint_symbol_1;\n return;\n"));
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Workgroup) {
@@ -158,7 +158,7 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
EXPECT_EQ(gen.result(), R"(float3 a = float3();
EXPECT_EQ(gen.result(), R"(float3 a = float3(0.0f);
)");
}