mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-13 07:06:11 +00:00
tint/writer/msl: Inline constant expressions
This is required to handle materialized values, and for constant expressions. Bug: tint:1504 Change-Id: Ic3ac62317241fa6f7009360128f222aeb56f62e4 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/92083 Reviewed-by: Corentin Wallez <cwallez@chromium.org> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
2e22d9285c
commit
cb6ddd2aa6
@@ -36,12 +36,12 @@
|
||||
#include "src/tint/sem/atomic.h"
|
||||
#include "src/tint/sem/bool.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/constant.h"
|
||||
#include "src/tint/sem/depth_multisampled_texture.h"
|
||||
#include "src/tint/sem/depth_texture.h"
|
||||
#include "src/tint/sem/f32.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/i32.h"
|
||||
#include "src/tint/sem/materialize.h"
|
||||
#include "src/tint/sem/matrix.h"
|
||||
#include "src/tint/sem/member_accessor_expression.h"
|
||||
#include "src/tint/sem/module.h"
|
||||
@@ -86,6 +86,31 @@ bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
|
||||
return IsAnyOf<ast::BreakStatement, ast::FallthroughStatement>(stmts->Last());
|
||||
}
|
||||
|
||||
void PrintF32(std::ostream& out, float value) {
|
||||
// Note: Currently inf and nan should not be constructable, but this is implemented for the day
|
||||
// we support them.
|
||||
if (std::isinf(value)) {
|
||||
out << (value >= 0 ? "INFINITY" : "-INFINITY");
|
||||
} else if (std::isnan(value)) {
|
||||
out << "NAN";
|
||||
} else {
|
||||
out << FloatToString(value) << "f";
|
||||
}
|
||||
}
|
||||
|
||||
void PrintI32(std::ostream& out, int32_t value) {
|
||||
// MSL (and C++) parse `-2147483648` as a `long` because it parses unary minus and `2147483648`
|
||||
// as separate tokens, and the latter doesn't fit into an (32-bit) `int`.
|
||||
// WGSL, on the other hand, parses this as an `i32`.
|
||||
// To avoid issues with `long` to `int` casts, emit `(-2147483647 - 1)` instead, which ensures
|
||||
// the expression type is `int`.
|
||||
if (auto int_min = std::numeric_limits<int32_t>::min(); value == int_min) {
|
||||
out << "(" << int_min + 1 << " - 1)";
|
||||
} else {
|
||||
out << value;
|
||||
}
|
||||
}
|
||||
|
||||
class ScopedBitCast {
|
||||
public:
|
||||
ScopedBitCast(GeneratorImpl* generator,
|
||||
@@ -551,12 +576,7 @@ bool GeneratorImpl::EmitBreak(const ast::BreakStatement*) {
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitCall(std::ostream& out, const ast::CallExpression* expr) {
|
||||
auto* sem = program_->Sem().Get(expr);
|
||||
if (auto* m = sem->As<sem::Materialize>()) {
|
||||
// TODO(crbug.com/tint/1504): Just emit the constant value.
|
||||
sem = m->Expr();
|
||||
}
|
||||
auto* call = sem->As<sem::Call>();
|
||||
auto* call = program_->Sem().Get<sem::Call>(expr);
|
||||
auto* target = call->Target();
|
||||
return Switch(
|
||||
target, [&](const sem::Function* func) { return EmitFunctionCall(out, call, func); },
|
||||
@@ -1522,8 +1542,7 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
|
||||
if (!EmitType(out, mat, "")) {
|
||||
return false;
|
||||
}
|
||||
out << "(";
|
||||
TINT_DEFER(out << ")");
|
||||
ScopedParen sp(out);
|
||||
return EmitZeroValue(out, mat->type());
|
||||
},
|
||||
[&](const sem::Array* arr) {
|
||||
@@ -1543,6 +1562,92 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant& constant) {
|
||||
auto emit_bool = [&](size_t element_idx) {
|
||||
out << (constant.Element<AInt>(element_idx) ? "true" : "false");
|
||||
return true;
|
||||
};
|
||||
auto emit_f32 = [&](size_t element_idx) {
|
||||
PrintF32(out, static_cast<float>(constant.Element<AFloat>(element_idx)));
|
||||
return true;
|
||||
};
|
||||
auto emit_i32 = [&](size_t element_idx) {
|
||||
PrintI32(out, static_cast<int32_t>(constant.Element<AInt>(element_idx).value));
|
||||
return true;
|
||||
};
|
||||
auto emit_u32 = [&](size_t element_idx) {
|
||||
out << constant.Element<AInt>(element_idx).value << "u";
|
||||
return true;
|
||||
};
|
||||
auto emit_vector = [&](const sem::Vector* vec_ty, size_t start, size_t end) {
|
||||
if (!EmitType(out, vec_ty, "")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ScopedParen sp(out);
|
||||
|
||||
auto emit_els = [&](auto emit_el) {
|
||||
if (constant.AllEqual(start, end)) {
|
||||
return emit_el(start);
|
||||
}
|
||||
for (size_t i = start; i < end; i++) {
|
||||
if (i > start) {
|
||||
out << ", ";
|
||||
}
|
||||
if (!emit_el(i)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
return Switch(
|
||||
vec_ty->type(), //
|
||||
[&](const sem::Bool*) { return emit_els(emit_bool); }, //
|
||||
[&](const sem::F32*) { return emit_els(emit_f32); }, //
|
||||
[&](const sem::I32*) { return emit_els(emit_i32); }, //
|
||||
[&](const sem::U32*) { return emit_els(emit_u32); }, //
|
||||
[&](Default) {
|
||||
diagnostics_.add_error(diag::System::Writer,
|
||||
"unhandled constant vector element type: " +
|
||||
builder_.FriendlyName(vec_ty->type()));
|
||||
return false;
|
||||
});
|
||||
};
|
||||
auto emit_matrix = [&](const sem::Matrix* m) {
|
||||
if (!EmitType(out, constant.Type(), "")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ScopedParen sp(out);
|
||||
|
||||
for (size_t column_idx = 0; column_idx < m->columns(); column_idx++) {
|
||||
if (column_idx > 0) {
|
||||
out << ", ";
|
||||
}
|
||||
size_t start = m->rows() * column_idx;
|
||||
size_t end = m->rows() * (column_idx + 1);
|
||||
if (!emit_vector(m->ColumnType(), start, end)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
return Switch(
|
||||
constant.Type(), //
|
||||
[&](const sem::Bool*) { return emit_bool(0); }, //
|
||||
[&](const sem::F32*) { return emit_f32(0); }, //
|
||||
[&](const sem::I32*) { return emit_i32(0); }, //
|
||||
[&](const sem::U32*) { return emit_u32(0); }, //
|
||||
[&](const sem::Vector* v) { return emit_vector(v, 0, constant.ElementCount()); }, //
|
||||
[&](const sem::Matrix* m) { return emit_matrix(m); }, //
|
||||
[&](Default) {
|
||||
diagnostics_.add_error(
|
||||
diag::System::Writer,
|
||||
"unhandled constant type: " + builder_.FriendlyName(constant.Type()));
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression* lit) {
|
||||
return Switch(
|
||||
lit,
|
||||
@@ -1551,32 +1656,14 @@ bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression*
|
||||
return true;
|
||||
},
|
||||
[&](const ast::FloatLiteralExpression* l) {
|
||||
auto f32 = static_cast<float>(l->value);
|
||||
if (std::isinf(f32)) {
|
||||
out << (f32 >= 0 ? "INFINITY" : "-INFINITY");
|
||||
} else if (std::isnan(f32)) {
|
||||
out << "NAN";
|
||||
} else {
|
||||
out << FloatToString(f32) << "f";
|
||||
}
|
||||
PrintF32(out, static_cast<float>(l->value));
|
||||
return true;
|
||||
},
|
||||
[&](const ast::IntLiteralExpression* i) {
|
||||
switch (i->suffix) {
|
||||
case ast::IntLiteralExpression::Suffix::kNone:
|
||||
case ast::IntLiteralExpression::Suffix::kI: {
|
||||
// MSL (and C++) parse `-2147483648` as a `long` because it parses
|
||||
// unary minus and `2147483648` as separate tokens, and the latter
|
||||
// doesn't fit into an (32-bit) `int`. WGSL, OTOH, parses this as an
|
||||
// `i32`. To avoid issues with `long` to `int` casts, emit
|
||||
// `(2147483647 - 1)` instead, which ensures the expression type is
|
||||
// `int`.
|
||||
const auto int_min = std::numeric_limits<int32_t>::min();
|
||||
if (i->value == int_min) {
|
||||
out << "(" << int_min + 1 << " - 1)";
|
||||
} else {
|
||||
out << i->value;
|
||||
}
|
||||
PrintI32(out, static_cast<int32_t>(i->value));
|
||||
return true;
|
||||
}
|
||||
case ast::IntLiteralExpression::Suffix::kU: {
|
||||
@@ -1594,6 +1681,11 @@ bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression*
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitExpression(std::ostream& out, const ast::Expression* expr) {
|
||||
if (auto* sem = builder_.Sem().Get(expr)) {
|
||||
if (auto constant = sem->ConstantValue()) {
|
||||
return EmitConstant(out, constant);
|
||||
}
|
||||
}
|
||||
return Switch(
|
||||
expr,
|
||||
[&](const ast::IndexAccessorExpression* a) { //
|
||||
|
||||
@@ -45,6 +45,7 @@
|
||||
// Forward declarations
|
||||
namespace tint::sem {
|
||||
class Call;
|
||||
class Constant;
|
||||
class Builtin;
|
||||
class TypeConstructor;
|
||||
class TypeConversion;
|
||||
@@ -250,6 +251,11 @@ class GeneratorImpl : public TextGenerator {
|
||||
/// @param stmt the statement to emit
|
||||
/// @returns true if the statement was successfully emitted
|
||||
bool EmitIf(const ast::IfStatement* stmt);
|
||||
/// Handles a constant value
|
||||
/// @param out the output stream
|
||||
/// @param constant the constant value to emit
|
||||
/// @returns true if the constant value was successfully emitted
|
||||
bool EmitConstant(std::ostream& out, const sem::Constant& constant);
|
||||
/// Handles a literal
|
||||
/// @param out the output of the expression stream
|
||||
/// @param lit the literal to emit
|
||||
|
||||
@@ -184,7 +184,7 @@ std::string expected_texture_overload(ast::builtin::test::ValidTextureOverload o
|
||||
case ValidTextureOverload::kSampleGrad2dF32:
|
||||
return R"(texture.sample(sampler, float2(1.0f, 2.0f), gradient2d(float2(3.0f, 4.0f), float2(5.0f, 6.0f))))";
|
||||
case ValidTextureOverload::kSampleGrad2dOffsetF32:
|
||||
return R"(texture.sample(sampler, float2(1.0f, 2.0f), gradient2d(float2(3.0f, 4.0f), float2(5.0f, 6.0f)), int2(7, 7)))";
|
||||
return R"(texture.sample(sampler, float2(1.0f, 2.0f), gradient2d(float2(3.0f, 4.0f), float2(5.0f, 6.0f)), int2(7)))";
|
||||
case ValidTextureOverload::kSampleGrad2dArrayF32:
|
||||
return R"(texture.sample(sampler, float2(1.0f, 2.0f), 3, gradient2d(float2(4.0f, 5.0f), float2(6.0f, 7.0f))))";
|
||||
case ValidTextureOverload::kSampleGrad2dArrayOffsetF32:
|
||||
|
||||
@@ -29,7 +29,7 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Cast_Scalar) {
|
||||
|
||||
std::stringstream out;
|
||||
ASSERT_TRUE(gen.EmitExpression(out, cast)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "float(1)");
|
||||
EXPECT_EQ(out.str(), "1.0f");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitExpression_Cast_Vector) {
|
||||
@@ -40,7 +40,7 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Cast_Vector) {
|
||||
|
||||
std::stringstream out;
|
||||
ASSERT_TRUE(gen.EmitExpression(out, cast)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "float3(int3(1, 2, 3))");
|
||||
EXPECT_EQ(out.str(), "float3(1.0f, 2.0f, 3.0f)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitExpression_Cast_IntMin) {
|
||||
@@ -51,7 +51,7 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Cast_IntMin) {
|
||||
|
||||
std::stringstream out;
|
||||
ASSERT_TRUE(gen.EmitExpression(out, cast)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "uint((-2147483647 - 1))");
|
||||
EXPECT_EQ(out.str(), "0u");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -67,7 +67,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Float) {
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("float(-0.000012f)"));
|
||||
EXPECT_THAT(gen.result(), HasSubstr("-0.000012f"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Bool) {
|
||||
@@ -76,7 +76,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Bool) {
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("bool(true)"));
|
||||
EXPECT_THAT(gen.result(), HasSubstr("true"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Int) {
|
||||
@@ -85,7 +85,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Int) {
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("int(-12345)"));
|
||||
EXPECT_THAT(gen.result(), HasSubstr("-12345"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Uint) {
|
||||
@@ -94,7 +94,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Uint) {
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("uint(12345u)"));
|
||||
EXPECT_THAT(gen.result(), HasSubstr("12345u"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec) {
|
||||
@@ -112,7 +112,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty) {
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("float3()"));
|
||||
EXPECT_THAT(gen.result(), HasSubstr("float3(0.0f)"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat) {
|
||||
@@ -134,7 +134,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty) {
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("float4x4()"));
|
||||
EXPECT_THAT(gen.result(), HasSubstr("float4x4(float4(0.0f), float4(0.0f)"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Array) {
|
||||
|
||||
@@ -207,7 +207,7 @@ struct tint_symbol {
|
||||
};
|
||||
|
||||
Interface vert_main_inner() {
|
||||
Interface const tint_symbol_3 = {.col1=0.5f, .col2=0.25f, .pos=float4()};
|
||||
Interface const tint_symbol_3 = {.col1=0.5f, .col2=0.25f, .pos=float4(0.0f)};
|
||||
return tint_symbol_3;
|
||||
}
|
||||
|
||||
|
||||
@@ -157,7 +157,7 @@ struct tint_symbol_3 {
|
||||
|
||||
void comp_main_inner(uint local_invocation_index, threadgroup float2x2* const tint_symbol) {
|
||||
{
|
||||
*(tint_symbol) = float2x2();
|
||||
*(tint_symbol) = float2x2(float2(0.0f), float2(0.0f));
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
float2x2 const x = *(tint_symbol);
|
||||
@@ -199,7 +199,7 @@ struct tint_symbol_3 {
|
||||
void comp_main_inner(uint local_invocation_index, threadgroup tint_array_wrapper* const tint_symbol) {
|
||||
for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
|
||||
uint const i = idx;
|
||||
(*(tint_symbol)).arr[i] = float2x2();
|
||||
(*(tint_symbol)).arr[i] = float2x2(float2(0.0f), float2(0.0f));
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
tint_array_wrapper const x = *(tint_symbol);
|
||||
@@ -333,9 +333,9 @@ struct tint_symbol_23 {
|
||||
|
||||
void main1_inner(uint local_invocation_index, threadgroup float2x2* const tint_symbol, threadgroup float2x3* const tint_symbol_1, threadgroup float2x4* const tint_symbol_2) {
|
||||
{
|
||||
*(tint_symbol) = float2x2();
|
||||
*(tint_symbol_1) = float2x3();
|
||||
*(tint_symbol_2) = float2x4();
|
||||
*(tint_symbol) = float2x2(float2(0.0f), float2(0.0f));
|
||||
*(tint_symbol_1) = float2x3(float3(0.0f), float3(0.0f));
|
||||
*(tint_symbol_2) = float2x4(float4(0.0f), float4(0.0f));
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
float2x2 const a1 = *(tint_symbol);
|
||||
@@ -353,9 +353,9 @@ kernel void main1(threadgroup tint_symbol_7* tint_symbol_4 [[threadgroup(0)]], u
|
||||
|
||||
void main2_inner(uint local_invocation_index_1, threadgroup float3x2* const tint_symbol_8, threadgroup float3x3* const tint_symbol_9, threadgroup float3x4* const tint_symbol_10) {
|
||||
{
|
||||
*(tint_symbol_8) = float3x2();
|
||||
*(tint_symbol_9) = float3x3();
|
||||
*(tint_symbol_10) = float3x4();
|
||||
*(tint_symbol_8) = float3x2(float2(0.0f), float2(0.0f), float2(0.0f));
|
||||
*(tint_symbol_9) = float3x3(float3(0.0f), float3(0.0f), float3(0.0f));
|
||||
*(tint_symbol_10) = float3x4(float4(0.0f), float4(0.0f), float4(0.0f));
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
float3x2 const a1 = *(tint_symbol_8);
|
||||
@@ -373,9 +373,9 @@ kernel void main2(threadgroup tint_symbol_15* tint_symbol_12 [[threadgroup(0)]],
|
||||
|
||||
void main3_inner(uint local_invocation_index_2, threadgroup float4x2* const tint_symbol_16, threadgroup float4x3* const tint_symbol_17, threadgroup float4x4* const tint_symbol_18) {
|
||||
{
|
||||
*(tint_symbol_16) = float4x2();
|
||||
*(tint_symbol_17) = float4x3();
|
||||
*(tint_symbol_18) = float4x4();
|
||||
*(tint_symbol_16) = float4x2(float2(0.0f), float2(0.0f), float2(0.0f), float2(0.0f));
|
||||
*(tint_symbol_17) = float4x3(float3(0.0f), float3(0.0f), float3(0.0f), float3(0.0f));
|
||||
*(tint_symbol_18) = float4x4(float4(0.0f), float4(0.0f), float4(0.0f), float4(0.0f));
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
float4x2 const a1 = *(tint_symbol_16);
|
||||
|
||||
@@ -48,7 +48,7 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const) {
|
||||
gen.increment_indent();
|
||||
|
||||
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
|
||||
EXPECT_EQ(gen.result(), " float const a = float();\n");
|
||||
EXPECT_EQ(gen.result(), " float const a = 0.0f;\n");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Array) {
|
||||
@@ -132,7 +132,7 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_Private) {
|
||||
GeneratorImpl& gen = SanitizeAndBuild();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("thread float tint_symbol_1 = initializer;\n"));
|
||||
EXPECT_THAT(gen.result(), HasSubstr("thread float tint_symbol_1 = 0.0f;\n float const tint_symbol = tint_symbol_1;\n return;\n"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Workgroup) {
|
||||
@@ -158,7 +158,7 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec) {
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
|
||||
EXPECT_EQ(gen.result(), R"(float3 a = float3();
|
||||
EXPECT_EQ(gen.result(), R"(float3 a = float3(0.0f);
|
||||
)");
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user