Tint&Dawn: Enable f16 override
This CL enable using f16 override, and also fix related tests in Dawn and Tint. Bug: tint:1473, tint:1502 Change-Id: I8336770e8a73e5023c1aba224b7b5f21692fbaa6 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/124544 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Austin Eng <enga@chromium.org> Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com> Reviewed-by: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
parent
a0c34124d1
commit
6af073cecc
|
@ -25,6 +25,14 @@
|
||||||
#include "dawn/native/PipelineLayout.h"
|
#include "dawn/native/PipelineLayout.h"
|
||||||
#include "dawn/native/ShaderModule.h"
|
#include "dawn/native/ShaderModule.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
bool IsDoubleValueRepresentableAsF16(double value) {
|
||||||
|
constexpr double kLowestF16 = -65504.0;
|
||||||
|
constexpr double kMaxF16 = 65504.0;
|
||||||
|
return kLowestF16 <= value && value <= kMaxF16;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace dawn::native {
|
namespace dawn::native {
|
||||||
MaybeError ValidateProgrammableStage(DeviceBase* device,
|
MaybeError ValidateProgrammableStage(DeviceBase* device,
|
||||||
const ShaderModuleBase* module,
|
const ShaderModuleBase* module,
|
||||||
|
@ -80,6 +88,12 @@ MaybeError ValidateProgrammableStage(DeviceBase* device,
|
||||||
"representable in type (%s)",
|
"representable in type (%s)",
|
||||||
constants[i].key, constants[i].value, "f32");
|
constants[i].key, constants[i].value, "f32");
|
||||||
break;
|
break;
|
||||||
|
case EntryPointMetadata::Override::Type::Float16:
|
||||||
|
DAWN_INVALID_IF(!IsDoubleValueRepresentableAsF16(constants[i].value),
|
||||||
|
"Pipeline overridable constant \"%s\" with value (%f) is not "
|
||||||
|
"representable in type (%s)",
|
||||||
|
constants[i].key, constants[i].value, "f16");
|
||||||
|
break;
|
||||||
case EntryPointMetadata::Override::Type::Int32:
|
case EntryPointMetadata::Override::Type::Int32:
|
||||||
DAWN_INVALID_IF(!IsDoubleValueRepresentable<int32_t>(constants[i].value),
|
DAWN_INVALID_IF(!IsDoubleValueRepresentable<int32_t>(constants[i].value),
|
||||||
"Pipeline overridable constant \"%s\" with value (%f) is not "
|
"Pipeline overridable constant \"%s\" with value (%f) is not "
|
||||||
|
|
|
@ -288,6 +288,8 @@ EntryPointMetadata::Override::Type FromTintOverrideType(tint::inspector::Overrid
|
||||||
return EntryPointMetadata::Override::Type::Boolean;
|
return EntryPointMetadata::Override::Type::Boolean;
|
||||||
case tint::inspector::Override::Type::kFloat32:
|
case tint::inspector::Override::Type::kFloat32:
|
||||||
return EntryPointMetadata::Override::Type::Float32;
|
return EntryPointMetadata::Override::Type::Float32;
|
||||||
|
case tint::inspector::Override::Type::kFloat16:
|
||||||
|
return EntryPointMetadata::Override::Type::Float16;
|
||||||
case tint::inspector::Override::Type::kInt32:
|
case tint::inspector::Override::Type::kInt32:
|
||||||
return EntryPointMetadata::Override::Type::Int32;
|
return EntryPointMetadata::Override::Type::Int32;
|
||||||
case tint::inspector::Override::Type::kUint32:
|
case tint::inspector::Override::Type::kUint32:
|
||||||
|
|
|
@ -221,7 +221,7 @@ struct EntryPointMetadata {
|
||||||
|
|
||||||
// Match tint::inspector::Override::Type
|
// Match tint::inspector::Override::Type
|
||||||
// Bool is defined as a macro on linux X11 and cannot compile
|
// Bool is defined as a macro on linux X11 and cannot compile
|
||||||
enum class Type { Boolean, Float32, Uint32, Int32 } type;
|
enum class Type { Boolean, Float32, Uint32, Int32, Float16 } type;
|
||||||
|
|
||||||
// If the constant doesn't not have an initializer in the shader
|
// If the constant doesn't not have an initializer in the shader
|
||||||
// Then it is required for the pipeline stage to have a constant record to initialize a
|
// Then it is required for the pipeline stage to have a constant record to initialize a
|
||||||
|
|
|
@ -21,19 +21,46 @@
|
||||||
|
|
||||||
class ComputePipelineOverridableConstantsValidationTest : public ValidationTest {
|
class ComputePipelineOverridableConstantsValidationTest : public ValidationTest {
|
||||||
protected:
|
protected:
|
||||||
|
WGPUDevice CreateTestDevice(dawn::native::Adapter dawnAdapter) override {
|
||||||
|
std::vector<const char*> enabledToggles;
|
||||||
|
std::vector<const char*> disabledToggles;
|
||||||
|
|
||||||
|
disabledToggles.push_back("disallow_unsafe_apis");
|
||||||
|
|
||||||
|
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
|
||||||
|
deviceTogglesDesc.enabledToggles = enabledToggles.data();
|
||||||
|
deviceTogglesDesc.enabledTogglesCount = enabledToggles.size();
|
||||||
|
deviceTogglesDesc.disabledToggles = disabledToggles.data();
|
||||||
|
deviceTogglesDesc.disabledTogglesCount = disabledToggles.size();
|
||||||
|
|
||||||
|
const wgpu::FeatureName requiredFeatures[] = {wgpu::FeatureName::ShaderF16};
|
||||||
|
|
||||||
|
wgpu::DeviceDescriptor deviceDescriptor;
|
||||||
|
deviceDescriptor.nextInChain = &deviceTogglesDesc;
|
||||||
|
deviceDescriptor.requiredFeatures = requiredFeatures;
|
||||||
|
deviceDescriptor.requiredFeaturesCount = 1;
|
||||||
|
|
||||||
|
return dawnAdapter.CreateDevice(&deviceDescriptor);
|
||||||
|
}
|
||||||
|
|
||||||
void SetUpShadersWithDefaultValueConstants() {
|
void SetUpShadersWithDefaultValueConstants() {
|
||||||
computeModule = utils::CreateShaderModule(device, R"(
|
computeModule = utils::CreateShaderModule(device, R"(
|
||||||
override c0: bool = true; // type: bool
|
enable f16;
|
||||||
override c1: bool = false; // default override
|
|
||||||
override c2: f32 = 0.0; // type: float32
|
override c0: bool = true; // type: bool
|
||||||
override c3: f32 = 0.0; // default override
|
override c1: bool = false; // default override
|
||||||
override c4: f32 = 4.0; // default
|
override c2: f32 = 0.0; // type: float32
|
||||||
override c5: i32 = 0; // type: int32
|
override c3: f32 = 0.0; // default override
|
||||||
override c6: i32 = 0; // default override
|
override c4: f32 = 4.0; // default
|
||||||
override c7: i32 = 7; // default
|
override c5: i32 = 0; // type: int32
|
||||||
override c8: u32 = 0u; // type: uint32
|
override c6: i32 = 0; // default override
|
||||||
override c9: u32 = 0u; // default override
|
override c7: i32 = 7; // default
|
||||||
@id(1000) override c10: u32 = 10u; // default
|
override c8: u32 = 0u; // type: uint32
|
||||||
|
override c9: u32 = 0u; // default override
|
||||||
|
override c10: u32 = 10u; // default
|
||||||
|
override c11: f16 = 0.0h; // type: float16
|
||||||
|
override c12: f16 = 0.0h; // default override
|
||||||
|
@id(1000) override c13: f16 = 4.0h; // default
|
||||||
|
|
||||||
@compute @workgroup_size(1) fn main() {
|
@compute @workgroup_size(1) fn main() {
|
||||||
// make sure the overridable constants are not optimized out
|
// make sure the overridable constants are not optimized out
|
||||||
|
@ -48,22 +75,30 @@ override c9: u32 = 0u; // default override
|
||||||
_ = u32(c8);
|
_ = u32(c8);
|
||||||
_ = u32(c9);
|
_ = u32(c9);
|
||||||
_ = u32(c10);
|
_ = u32(c10);
|
||||||
|
_ = u32(c11);
|
||||||
|
_ = u32(c12);
|
||||||
|
_ = u32(c13);
|
||||||
})");
|
})");
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetUpShadersWithUninitializedConstants() {
|
void SetUpShadersWithUninitializedConstants() {
|
||||||
computeModule = utils::CreateShaderModule(device, R"(
|
computeModule = utils::CreateShaderModule(device, R"(
|
||||||
override c0: bool; // type: bool
|
enable f16;
|
||||||
override c1: bool = false; // default override
|
|
||||||
override c2: f32; // type: float32
|
override c0: bool; // type: bool
|
||||||
override c3: f32 = 0.0; // default override
|
override c1: bool = false; // default override
|
||||||
override c4: f32 = 4.0; // default
|
override c2: f32; // type: float32
|
||||||
override c5: i32; // type: int32
|
override c3: f32 = 0.0; // default override
|
||||||
override c6: i32 = 0; // default override
|
override c4: f32 = 4.0; // default
|
||||||
override c7: i32 = 7; // default
|
override c5: i32; // type: int32
|
||||||
override c8: u32; // type: uint32
|
override c6: i32 = 0; // default override
|
||||||
override c9: u32 = 0u; // default override
|
override c7: i32 = 7; // default
|
||||||
@id(1000) override c10: u32 = 10u; // default
|
override c8: u32; // type: uint32
|
||||||
|
override c9: u32 = 0u; // default override
|
||||||
|
override c10: u32 = 10u; // default
|
||||||
|
override c11: f16; // type: float16
|
||||||
|
override c12: f16 = 0.0h; // default override
|
||||||
|
@id(1000) override c13: f16 = 4.0h; // default
|
||||||
|
|
||||||
@compute @workgroup_size(1) fn main() {
|
@compute @workgroup_size(1) fn main() {
|
||||||
// make sure the overridable constants are not optimized out
|
// make sure the overridable constants are not optimized out
|
||||||
|
@ -78,6 +113,9 @@ override c9: u32 = 0u; // default override
|
||||||
_ = u32(c8);
|
_ = u32(c8);
|
||||||
_ = u32(c9);
|
_ = u32(c9);
|
||||||
_ = u32(c10);
|
_ = u32(c10);
|
||||||
|
_ = u32(c11);
|
||||||
|
_ = u32(c12);
|
||||||
|
_ = u32(c13);
|
||||||
})");
|
})");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,6 +132,11 @@ override c9: u32 = 0u; // default override
|
||||||
wgpu::Buffer buffer;
|
wgpu::Buffer buffer;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Basic constants lookup tests
|
||||||
|
TEST_F(ComputePipelineOverridableConstantsValidationTest, CreateShaderWithOverride) {
|
||||||
|
SetUpShadersWithUninitializedConstants();
|
||||||
|
}
|
||||||
|
|
||||||
// Basic constants lookup tests
|
// Basic constants lookup tests
|
||||||
TEST_F(ComputePipelineOverridableConstantsValidationTest, ConstantsIdentifierLookUp) {
|
TEST_F(ComputePipelineOverridableConstantsValidationTest, ConstantsIdentifierLookUp) {
|
||||||
SetUpShadersWithDefaultValueConstants();
|
SetUpShadersWithDefaultValueConstants();
|
||||||
|
@ -122,7 +165,7 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, ConstantsIdentifierLoo
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
// Error: c10 already has a constant numeric id specified
|
// Error: c10 already has a constant numeric id specified
|
||||||
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c10", 0}};
|
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c13", 0}};
|
||||||
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
|
@ -152,24 +195,23 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, UninitializedConstants
|
||||||
{nullptr, "c2", 1},
|
{nullptr, "c2", 1},
|
||||||
// c5 is missing
|
// c5 is missing
|
||||||
{nullptr, "c8", 1},
|
{nullptr, "c8", 1},
|
||||||
|
{nullptr, "c11", 1},
|
||||||
};
|
};
|
||||||
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
// Valid: all constants initialized
|
// Valid: all constants initialized
|
||||||
std::vector<wgpu::ConstantEntry> constants{
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
{nullptr, "c0", false},
|
{nullptr, "c0", false}, {nullptr, "c2", 1}, {nullptr, "c5", 1},
|
||||||
{nullptr, "c2", 1},
|
{nullptr, "c8", 1}, {nullptr, "c11", 1},
|
||||||
{nullptr, "c5", 1},
|
|
||||||
{nullptr, "c8", 1},
|
|
||||||
};
|
};
|
||||||
TestCreatePipeline(constants);
|
TestCreatePipeline(constants);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
// Error: duplicate initializations
|
// Error: duplicate initializations
|
||||||
std::vector<wgpu::ConstantEntry> constants{
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
{nullptr, "c0", false}, {nullptr, "c2", 1}, {nullptr, "c5", 1},
|
{nullptr, "c0", false}, {nullptr, "c2", 1}, {nullptr, "c5", 1},
|
||||||
{nullptr, "c8", 1}, {nullptr, "c2", 2},
|
{nullptr, "c8", 1}, {nullptr, "c11", 1}, {nullptr, "c2", 2},
|
||||||
};
|
};
|
||||||
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
||||||
}
|
}
|
||||||
|
@ -214,7 +256,7 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, ConstantsIdentifierUni
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
// Error: constant with numeric id cannot be referenced with variable name
|
// Error: constant with numeric id cannot be referenced with variable name
|
||||||
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c10", 0}};
|
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c13", 0}};
|
||||||
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -267,6 +309,34 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, OutofRangeValue) {
|
||||||
{nullptr, "c3", std::numeric_limits<double>::max()}};
|
{nullptr, "c3", std::numeric_limits<double>::max()}};
|
||||||
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
// Valid: max f32 representable value
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
|
{nullptr, "c3", std::numeric_limits<float>::max()}};
|
||||||
|
TestCreatePipeline(constants);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: one ULP higher than max f32 representable value
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
|
{nullptr, "c3",
|
||||||
|
std::nextafter<double>(std::numeric_limits<float>::max(),
|
||||||
|
std::numeric_limits<double>::max())}};
|
||||||
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Valid: lowest f32 representable value
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
|
{nullptr, "c3", std::numeric_limits<float>::lowest()}};
|
||||||
|
TestCreatePipeline(constants);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: one ULP lower than lowest f32 representable value
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
|
{nullptr, "c3",
|
||||||
|
std::nextafter<double>(std::numeric_limits<float>::lowest(),
|
||||||
|
std::numeric_limits<double>::lowest())}};
|
||||||
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
||||||
|
}
|
||||||
{
|
{
|
||||||
// Error: i32 out of range
|
// Error: i32 out of range
|
||||||
std::vector<wgpu::ConstantEntry> constants{
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
|
@ -291,4 +361,27 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, OutofRangeValue) {
|
||||||
{nullptr, "c0", static_cast<double>(std::numeric_limits<int32_t>::max()) + 1.0}};
|
{nullptr, "c0", static_cast<double>(std::numeric_limits<int32_t>::max()) + 1.0}};
|
||||||
TestCreatePipeline(constants);
|
TestCreatePipeline(constants);
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
// Valid: max f16 representable value
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c11", 65504.0}};
|
||||||
|
TestCreatePipeline(constants);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: one ULP higher than max f16 representable value
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
|
{nullptr, "c11", std::nextafter<double>(65504.0, std::numeric_limits<double>::max())}};
|
||||||
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Valid: lowest f16 representable value
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c11", -65504.0}};
|
||||||
|
TestCreatePipeline(constants);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Error: one ULP lower than lowest f16 representable value
|
||||||
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
|
{nullptr, "c11",
|
||||||
|
std::nextafter<double>(-65504.0, std::numeric_limits<double>::lowest())}};
|
||||||
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -505,6 +505,8 @@ std::string OverrideTypeToString(tint::inspector::Override::Type type) {
|
||||||
return "bool";
|
return "bool";
|
||||||
case tint::inspector::Override::Type::kFloat32:
|
case tint::inspector::Override::Type::kFloat32:
|
||||||
return "f32";
|
return "f32";
|
||||||
|
case tint::inspector::Override::Type::kFloat16:
|
||||||
|
return "f16";
|
||||||
case tint::inspector::Override::Type::kUint32:
|
case tint::inspector::Override::Type::kUint32:
|
||||||
return "u32";
|
return "u32";
|
||||||
case tint::inspector::Override::Type::kInt32:
|
case tint::inspector::Override::Type::kInt32:
|
||||||
|
|
|
@ -92,6 +92,7 @@ struct Override {
|
||||||
kFloat32,
|
kFloat32,
|
||||||
kUint32,
|
kUint32,
|
||||||
kInt32,
|
kInt32,
|
||||||
|
kFloat16,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Type of the scalar
|
/// Type of the scalar
|
||||||
|
|
|
@ -204,7 +204,11 @@ EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) {
|
||||||
if (type->is_bool_scalar_or_vector()) {
|
if (type->is_bool_scalar_or_vector()) {
|
||||||
override.type = Override::Type::kBool;
|
override.type = Override::Type::kBool;
|
||||||
} else if (type->is_float_scalar()) {
|
} else if (type->is_float_scalar()) {
|
||||||
override.type = Override::Type::kFloat32;
|
if (type->Is<type::F16>()) {
|
||||||
|
override.type = Override::Type::kFloat16;
|
||||||
|
} else {
|
||||||
|
override.type = Override::Type::kFloat32;
|
||||||
|
}
|
||||||
} else if (type->is_signed_integer_scalar()) {
|
} else if (type->is_signed_integer_scalar()) {
|
||||||
override.type = Override::Type::kInt32;
|
override.type = Override::Type::kInt32;
|
||||||
} else if (type->is_unsigned_integer_scalar()) {
|
} else if (type->is_unsigned_integer_scalar()) {
|
||||||
|
@ -270,6 +274,10 @@ std::map<OverrideId, Scalar> Inspector::GetOverrideDefaultValues() {
|
||||||
[&](const type::I32*) { return Scalar(value->ValueAs<i32>()); },
|
[&](const type::I32*) { return Scalar(value->ValueAs<i32>()); },
|
||||||
[&](const type::U32*) { return Scalar(value->ValueAs<u32>()); },
|
[&](const type::U32*) { return Scalar(value->ValueAs<u32>()); },
|
||||||
[&](const type::F32*) { return Scalar(value->ValueAs<f32>()); },
|
[&](const type::F32*) { return Scalar(value->ValueAs<f32>()); },
|
||||||
|
[&](const type::F16*) {
|
||||||
|
// Default value of f16 override is also stored as float scalar.
|
||||||
|
return Scalar(static_cast<float>(value->ValueAs<f16>()));
|
||||||
|
},
|
||||||
[&](const type::Bool*) { return Scalar(value->ValueAs<bool>()); });
|
[&](const type::Bool*) { return Scalar(value->ValueAs<bool>()); });
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
|
@ -908,18 +908,23 @@ TEST_F(InspectorGetEntryPointTest, OverrideReferencedByArraySizeViaAlias) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(InspectorGetEntryPointTest, OverrideTypes) {
|
TEST_F(InspectorGetEntryPointTest, OverrideTypes) {
|
||||||
|
Enable(builtin::Extension::kF16);
|
||||||
|
|
||||||
Override("bool_var", ty.bool_());
|
Override("bool_var", ty.bool_());
|
||||||
Override("float_var", ty.f32());
|
Override("float_var", ty.f32());
|
||||||
Override("u32_var", ty.u32());
|
Override("u32_var", ty.u32());
|
||||||
Override("i32_var", ty.i32());
|
Override("i32_var", ty.i32());
|
||||||
|
Override("f16_var", ty.f16());
|
||||||
|
|
||||||
MakePlainGlobalReferenceBodyFunction("bool_func", "bool_var", ty.bool_(), utils::Empty);
|
MakePlainGlobalReferenceBodyFunction("bool_func", "bool_var", ty.bool_(), utils::Empty);
|
||||||
MakePlainGlobalReferenceBodyFunction("float_func", "float_var", ty.f32(), utils::Empty);
|
MakePlainGlobalReferenceBodyFunction("float_func", "float_var", ty.f32(), utils::Empty);
|
||||||
MakePlainGlobalReferenceBodyFunction("u32_func", "u32_var", ty.u32(), utils::Empty);
|
MakePlainGlobalReferenceBodyFunction("u32_func", "u32_var", ty.u32(), utils::Empty);
|
||||||
MakePlainGlobalReferenceBodyFunction("i32_func", "i32_var", ty.i32(), utils::Empty);
|
MakePlainGlobalReferenceBodyFunction("i32_func", "i32_var", ty.i32(), utils::Empty);
|
||||||
|
MakePlainGlobalReferenceBodyFunction("f16_func", "f16_var", ty.f16(), utils::Empty);
|
||||||
|
|
||||||
MakeCallerBodyFunction(
|
MakeCallerBodyFunction(
|
||||||
"ep_func", utils::Vector{std::string("bool_func"), "float_func", "u32_func", "i32_func"},
|
"ep_func",
|
||||||
|
utils::Vector{std::string("bool_func"), "float_func", "u32_func", "i32_func", "f16_func"},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
Stage(ast::PipelineStage::kCompute),
|
Stage(ast::PipelineStage::kCompute),
|
||||||
WorkgroupSize(1_i),
|
WorkgroupSize(1_i),
|
||||||
|
@ -930,7 +935,7 @@ TEST_F(InspectorGetEntryPointTest, OverrideTypes) {
|
||||||
auto result = inspector.GetEntryPoints();
|
auto result = inspector.GetEntryPoints();
|
||||||
|
|
||||||
ASSERT_EQ(1u, result.size());
|
ASSERT_EQ(1u, result.size());
|
||||||
ASSERT_EQ(4u, result[0].overrides.size());
|
ASSERT_EQ(5u, result[0].overrides.size());
|
||||||
EXPECT_EQ("bool_var", result[0].overrides[0].name);
|
EXPECT_EQ("bool_var", result[0].overrides[0].name);
|
||||||
EXPECT_EQ(inspector::Override::Type::kBool, result[0].overrides[0].type);
|
EXPECT_EQ(inspector::Override::Type::kBool, result[0].overrides[0].type);
|
||||||
EXPECT_EQ("float_var", result[0].overrides[1].name);
|
EXPECT_EQ("float_var", result[0].overrides[1].name);
|
||||||
|
@ -939,6 +944,8 @@ TEST_F(InspectorGetEntryPointTest, OverrideTypes) {
|
||||||
EXPECT_EQ(inspector::Override::Type::kUint32, result[0].overrides[2].type);
|
EXPECT_EQ(inspector::Override::Type::kUint32, result[0].overrides[2].type);
|
||||||
EXPECT_EQ("i32_var", result[0].overrides[3].name);
|
EXPECT_EQ("i32_var", result[0].overrides[3].name);
|
||||||
EXPECT_EQ(inspector::Override::Type::kInt32, result[0].overrides[3].type);
|
EXPECT_EQ(inspector::Override::Type::kInt32, result[0].overrides[3].type);
|
||||||
|
EXPECT_EQ("f16_var", result[0].overrides[4].name);
|
||||||
|
EXPECT_EQ(inspector::Override::Type::kFloat16, result[0].overrides[4].type);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(InspectorGetEntryPointTest, OverrideInitialized) {
|
TEST_F(InspectorGetEntryPointTest, OverrideInitialized) {
|
||||||
|
@ -1572,7 +1579,7 @@ TEST_F(InspectorGetOverrideDefaultValuesTest, I32) {
|
||||||
EXPECT_EQ(100, result[OverrideId{6000}].AsI32());
|
EXPECT_EQ(100, result[OverrideId{6000}].AsI32());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(InspectorGetOverrideDefaultValuesTest, Float) {
|
TEST_F(InspectorGetOverrideDefaultValuesTest, F32) {
|
||||||
Override("a", ty.f32(), Id(1_a));
|
Override("a", ty.f32(), Id(1_a));
|
||||||
Override("b", ty.f32(), Expr(0_f), Id(20_a));
|
Override("b", ty.f32(), Expr(0_f), Id(20_a));
|
||||||
Override("c", ty.f32(), Expr(-10_f), Id(300_a));
|
Override("c", ty.f32(), Expr(-10_f), Id(300_a));
|
||||||
|
@ -1609,6 +1616,46 @@ TEST_F(InspectorGetOverrideDefaultValuesTest, Float) {
|
||||||
EXPECT_FLOAT_EQ(150.0f, result[OverrideId{6000}].AsFloat());
|
EXPECT_FLOAT_EQ(150.0f, result[OverrideId{6000}].AsFloat());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(InspectorGetOverrideDefaultValuesTest, F16) {
|
||||||
|
Enable(builtin::Extension::kF16);
|
||||||
|
|
||||||
|
Override("a", ty.f16(), Id(1_a));
|
||||||
|
Override("b", ty.f16(), Expr(0_h), Id(20_a));
|
||||||
|
Override("c", ty.f16(), Expr(-10_h), Id(300_a));
|
||||||
|
Override("d", Expr(15_h), Id(4000_a));
|
||||||
|
Override("3", Expr(42.0_h), Id(5000_a));
|
||||||
|
Override("e", ty.f16(), Mul(15_h, 10_a), Id(6000_a));
|
||||||
|
|
||||||
|
Inspector& inspector = Build();
|
||||||
|
|
||||||
|
auto result = inspector.GetOverrideDefaultValues();
|
||||||
|
ASSERT_EQ(6u, result.size());
|
||||||
|
|
||||||
|
ASSERT_TRUE(result.find(OverrideId{1}) != result.end());
|
||||||
|
EXPECT_TRUE(result[OverrideId{1}].IsNull());
|
||||||
|
|
||||||
|
ASSERT_TRUE(result.find(OverrideId{20}) != result.end());
|
||||||
|
// Default value of f16 override is also stored as float scalar.
|
||||||
|
EXPECT_TRUE(result[OverrideId{20}].IsFloat());
|
||||||
|
EXPECT_FLOAT_EQ(0.0f, result[OverrideId{20}].AsFloat());
|
||||||
|
|
||||||
|
ASSERT_TRUE(result.find(OverrideId{300}) != result.end());
|
||||||
|
EXPECT_TRUE(result[OverrideId{300}].IsFloat());
|
||||||
|
EXPECT_FLOAT_EQ(-10.0f, result[OverrideId{300}].AsFloat());
|
||||||
|
|
||||||
|
ASSERT_TRUE(result.find(OverrideId{4000}) != result.end());
|
||||||
|
EXPECT_TRUE(result[OverrideId{4000}].IsFloat());
|
||||||
|
EXPECT_FLOAT_EQ(15.0f, result[OverrideId{4000}].AsFloat());
|
||||||
|
|
||||||
|
ASSERT_TRUE(result.find(OverrideId{5000}) != result.end());
|
||||||
|
EXPECT_TRUE(result[OverrideId{5000}].IsFloat());
|
||||||
|
EXPECT_FLOAT_EQ(42.0f, result[OverrideId{5000}].AsFloat());
|
||||||
|
|
||||||
|
ASSERT_TRUE(result.find(OverrideId{6000}) != result.end());
|
||||||
|
EXPECT_TRUE(result[OverrideId{6000}].IsFloat());
|
||||||
|
EXPECT_FLOAT_EQ(150.0f, result[OverrideId{6000}].AsFloat());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(InspectorGetConstantNameToIdMapTest, WithAndWithoutIds) {
|
TEST_F(InspectorGetConstantNameToIdMapTest, WithAndWithoutIds) {
|
||||||
Override("v1", ty.f32(), Id(1_a));
|
Override("v1", ty.f32(), Id(1_a));
|
||||||
Override("v20", ty.f32(), Id(20_a));
|
Override("v20", ty.f32(), Id(20_a));
|
||||||
|
|
|
@ -66,10 +66,12 @@ TEST_F(ResolverOverrideTest, WithoutId) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverOverrideTest, WithAndWithoutIds) {
|
TEST_F(ResolverOverrideTest, WithAndWithoutIds) {
|
||||||
|
Enable(builtin::Extension::kF16);
|
||||||
|
|
||||||
auto* a = Override("a", ty.f32(), Expr(1_f));
|
auto* a = Override("a", ty.f32(), Expr(1_f));
|
||||||
auto* b = Override("b", ty.f32(), Expr(1_f));
|
auto* b = Override("b", ty.f16(), Expr(1_h));
|
||||||
auto* c = Override("c", ty.f32(), Expr(1_f), Id(2_u));
|
auto* c = Override("c", ty.i32(), Expr(1_i), Id(2_u));
|
||||||
auto* d = Override("d", ty.f32(), Expr(1_f), Id(4_u));
|
auto* d = Override("d", ty.u32(), Expr(1_u), Id(4_u));
|
||||||
auto* e = Override("e", ty.f32(), Expr(1_f));
|
auto* e = Override("e", ty.f32(), Expr(1_f));
|
||||||
auto* f = Override("f", ty.f32(), Expr(1_f), Id(1_u));
|
auto* f = Override("f", ty.f32(), Expr(1_f), Id(1_u));
|
||||||
|
|
||||||
|
@ -102,16 +104,6 @@ TEST_F(ResolverOverrideTest, IdTooLarge) {
|
||||||
EXPECT_EQ(r()->error(), "12:34 error: @id value must be between 0 and 65535");
|
EXPECT_EQ(r()->error(), "12:34 error: @id value must be between 0 and 65535");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverOverrideTest, F16_TemporallyBan) {
|
|
||||||
Enable(builtin::Extension::kF16);
|
|
||||||
|
|
||||||
Override(Source{{12, 34}}, "a", ty.f16(), Expr(1_h), Id(1_u));
|
|
||||||
|
|
||||||
EXPECT_FALSE(r()->Resolve());
|
|
||||||
|
|
||||||
EXPECT_EQ(r()->error(), "12:34 error: 'override' of type f16 is not implemented yet");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ResolverOverrideTest, TransitiveReferences_DirectUse) {
|
TEST_F(ResolverOverrideTest, TransitiveReferences_DirectUse) {
|
||||||
auto* a = Override("a", ty.f32());
|
auto* a = Override("a", ty.f32());
|
||||||
auto* b = Override("b", ty.f32(), Expr(1_f));
|
auto* b = Override("b", ty.f32(), Expr(1_f));
|
||||||
|
|
|
@ -788,11 +788,6 @@ bool Validator::Override(
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (storage_ty->Is<type::F16>()) {
|
|
||||||
AddError("'override' of type f16 is not implemented yet", decl->source);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -84,15 +84,16 @@ fn main() -> @builtin(position) vec4<f32> {
|
||||||
|
|
||||||
TEST_F(SubstituteOverrideTest, ImplicitId) {
|
TEST_F(SubstituteOverrideTest, ImplicitId) {
|
||||||
auto* src = R"(
|
auto* src = R"(
|
||||||
|
enable f16;
|
||||||
|
|
||||||
override i_width: i32;
|
override i_width: i32;
|
||||||
override i_height = 1i;
|
override i_height = 1i;
|
||||||
|
|
||||||
override f_width: f32;
|
override f_width: f32;
|
||||||
override f_height = 1.f;
|
override f_height = 1.f;
|
||||||
|
|
||||||
// TODO(crbug.com/tint/1473)
|
override h_width: f16;
|
||||||
// override h_width: f16;
|
override h_height = 1.h;
|
||||||
// override h_height = 1.h;
|
|
||||||
|
|
||||||
override b_width: bool;
|
override b_width: bool;
|
||||||
override b_height = true;
|
override b_height = true;
|
||||||
|
@ -106,6 +107,8 @@ fn main() -> @builtin(position) vec4<f32> {
|
||||||
)";
|
)";
|
||||||
|
|
||||||
auto* expect = R"(
|
auto* expect = R"(
|
||||||
|
enable f16;
|
||||||
|
|
||||||
const i_width : i32 = 42i;
|
const i_width : i32 = 42i;
|
||||||
|
|
||||||
const i_height = 11i;
|
const i_height = 11i;
|
||||||
|
@ -114,6 +117,10 @@ const f_width : f32 = 22.299999237060546875f;
|
||||||
|
|
||||||
const f_height = 12.3999996185302734375f;
|
const f_height = 12.3999996185302734375f;
|
||||||
|
|
||||||
|
const h_width : f16 = 9.3984375h;
|
||||||
|
|
||||||
|
const h_height = 3.3984375h;
|
||||||
|
|
||||||
const b_width : bool = true;
|
const b_width : bool = true;
|
||||||
|
|
||||||
const b_height = false;
|
const b_height = false;
|
||||||
|
@ -131,10 +138,10 @@ fn main() -> @builtin(position) vec4<f32> {
|
||||||
cfg.map.insert({OverrideId{1}, 11.0});
|
cfg.map.insert({OverrideId{1}, 11.0});
|
||||||
cfg.map.insert({OverrideId{2}, 22.3});
|
cfg.map.insert({OverrideId{2}, 22.3});
|
||||||
cfg.map.insert({OverrideId{3}, 12.4});
|
cfg.map.insert({OverrideId{3}, 12.4});
|
||||||
// cfg.map.insert({OverrideId{4}, 9.4});
|
cfg.map.insert({OverrideId{4}, 9.4});
|
||||||
// cfg.map.insert({OverrideId{5}, 3.4});
|
cfg.map.insert({OverrideId{5}, 3.4});
|
||||||
cfg.map.insert({OverrideId{4}, 1.0});
|
cfg.map.insert({OverrideId{6}, 1.0});
|
||||||
cfg.map.insert({OverrideId{5}, 0.0});
|
cfg.map.insert({OverrideId{7}, 0.0});
|
||||||
|
|
||||||
DataMap data;
|
DataMap data;
|
||||||
data.Add<SubstituteOverride::Config>(cfg);
|
data.Add<SubstituteOverride::Config>(cfg);
|
||||||
|
@ -153,9 +160,8 @@ enable f16;
|
||||||
@id(1) override f_width: f32;
|
@id(1) override f_width: f32;
|
||||||
@id(9) override f_height = 1.f;
|
@id(9) override f_height = 1.f;
|
||||||
|
|
||||||
// TODO(crbug.com/tint/1473)
|
@id(2) override h_width: f16;
|
||||||
// @id(2) override h_width: f16;
|
@id(8) override h_height = 1.h;
|
||||||
// @id(8) override h_height = 1.h;
|
|
||||||
|
|
||||||
@id(3) override b_width: bool;
|
@id(3) override b_width: bool;
|
||||||
@id(7) override b_height = true;
|
@id(7) override b_height = true;
|
||||||
|
@ -179,6 +185,10 @@ const f_width : f32 = 22.299999237060546875f;
|
||||||
|
|
||||||
const f_height = 12.3999996185302734375f;
|
const f_height = 12.3999996185302734375f;
|
||||||
|
|
||||||
|
const h_width : f16 = 9.3984375h;
|
||||||
|
|
||||||
|
const h_height = 3.3984375h;
|
||||||
|
|
||||||
const b_width : bool = true;
|
const b_width : bool = true;
|
||||||
|
|
||||||
const b_height = false;
|
const b_height = false;
|
||||||
|
|
Loading…
Reference in New Issue