spirv-reader: Set workgroup size, but not specializable

The WorkgroupSize builtin decoration applies to a composite constant.
Because WGSL does not yet support specializable constants for this,
use the *default* values for that SPIR-V spec constant.

Update end-to-end test expectations.

Fixed: tint:503
Change-Id: I012b316d13544ab9282e3276b58906327adab133
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/41960
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: Alan Baker <alanbaker@google.com>
This commit is contained in:
David Neto 2021-06-17 22:40:43 +00:00 committed by Tint LUCI CQ
parent 53829a5e42
commit 17287fcf1a
26 changed files with 617 additions and 39 deletions

View File

@ -25,13 +25,15 @@ EntryPointInfo::EntryPointInfo(std::string the_name,
bool the_owns_inner_implementation, bool the_owns_inner_implementation,
std::string the_inner_name, std::string the_inner_name,
std::vector<uint32_t>&& the_inputs, std::vector<uint32_t>&& the_inputs,
std::vector<uint32_t>&& the_outputs) std::vector<uint32_t>&& the_outputs,
GridSize the_wg_size)
: name(the_name), : name(the_name),
stage(the_stage), stage(the_stage),
owns_inner_implementation(the_owns_inner_implementation), owns_inner_implementation(the_owns_inner_implementation),
inner_name(std::move(the_inner_name)), inner_name(std::move(the_inner_name)),
inputs(std::move(the_inputs)), inputs(std::move(the_inputs)),
outputs(std::move(the_outputs)) {} outputs(std::move(the_outputs)),
workgroup_size(the_wg_size) {}
EntryPointInfo::EntryPointInfo(const EntryPointInfo&) = default; EntryPointInfo::EntryPointInfo(const EntryPointInfo&) = default;

View File

@ -24,6 +24,13 @@ namespace tint {
namespace reader { namespace reader {
namespace spirv { namespace spirv {
/// The size of an integer-coordinate grid, in the x, y, and z dimensions.
struct GridSize {
uint32_t x = 0;
uint32_t y = 0;
uint32_t z = 0;
};
/// Entry point information for a function /// Entry point information for a function
struct EntryPointInfo { struct EntryPointInfo {
/// Constructor. /// Constructor.
@ -35,12 +42,14 @@ struct EntryPointInfo {
/// entry point /// entry point
/// @param the_inputs list of IDs for Input variables used by the shader /// @param the_inputs list of IDs for Input variables used by the shader
/// @param the_outputs list of IDs for Output variables used by the shader /// @param the_outputs list of IDs for Output variables used by the shader
/// @param the_wg_size the workgroup_size, for a compute shader
EntryPointInfo(std::string the_name, EntryPointInfo(std::string the_name,
ast::PipelineStage the_stage, ast::PipelineStage the_stage,
bool the_owns_inner_implementation, bool the_owns_inner_implementation,
std::string the_inner_name, std::string the_inner_name,
std::vector<uint32_t>&& the_inputs, std::vector<uint32_t>&& the_inputs,
std::vector<uint32_t>&& the_outputs); std::vector<uint32_t>&& the_outputs,
GridSize the_wg_size);
/// Copy constructor /// Copy constructor
/// @param other the other entry point info to be built from /// @param other the other entry point info to be built from
EntryPointInfo(const EntryPointInfo& other); EntryPointInfo(const EntryPointInfo& other);
@ -55,6 +64,7 @@ struct EntryPointInfo {
std::string name; std::string name;
/// The entry point stage /// The entry point stage
ast::PipelineStage stage = ast::PipelineStage::kNone; ast::PipelineStage stage = ast::PipelineStage::kNone;
/// True when this entry point is responsible for generating the /// True when this entry point is responsible for generating the
/// inner implementation function. False when this is the second entry /// inner implementation function. False when this is the second entry
/// point encountered for the same function in SPIR-V. It's unusual, but /// point encountered for the same function in SPIR-V. It's unusual, but
@ -67,6 +77,12 @@ struct EntryPointInfo {
std::vector<uint32_t> inputs; std::vector<uint32_t> inputs;
/// IDs of pipeline output variables, sorted and without duplicates. /// IDs of pipeline output variables, sorted and without duplicates.
std::vector<uint32_t> outputs; std::vector<uint32_t> outputs;
/// If this is a compute shader, this is the workgroup size in the x, y,
/// and z dimensions set via LocalSize, or via the composite value
/// decorated as the WorkgroupSize BuiltIn. The WorkgroupSize builtin
/// takes priority.
GridSize workgroup_size;
}; };
} // namespace spirv } // namespace spirv

View File

@ -1131,6 +1131,19 @@ bool FunctionEmitter::EmitEntryPointAsWrapper() {
ast::DecorationList fn_decos; ast::DecorationList fn_decos;
fn_decos.emplace_back(create<ast::StageDecoration>(source, ep_info_->stage)); fn_decos.emplace_back(create<ast::StageDecoration>(source, ep_info_->stage));
if (ep_info_->stage == ast::PipelineStage::kCompute) {
auto& size = ep_info_->workgroup_size;
if (size.x != 0 && size.y != 0 && size.z != 0) {
ast::Expression* x = builder_.Expr(static_cast<int>(size.x));
ast::Expression* y =
size.y ? builder_.Expr(static_cast<int>(size.y)) : nullptr;
ast::Expression* z =
size.z ? builder_.Expr(static_cast<int>(size.z)) : nullptr;
fn_decos.emplace_back(
create<ast::WorkgroupDecoration>(Source{}, x, y, z));
}
}
builder_.AST().AddFunction( builder_.AST().AddFunction(
create<ast::Function>(source, builder_.Symbols().Register(ep_info_->name), create<ast::Function>(source, builder_.Symbols().Register(ep_info_->name),
std::move(decl.params), return_type, body, std::move(decl.params), return_type, body,
@ -2261,6 +2274,9 @@ bool FunctionEmitter::EmitFunctionVariables() {
constructor = constructor =
parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1)) parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1))
.expr; .expr;
if (!constructor) {
return false;
}
} }
auto* var = parser_impl_.MakeVariable( auto* var = parser_impl_.MakeVariable(
inst.result_id(), ast::StorageClass::kNone, var_store_type, false, inst.result_id(), ast::StorageClass::kNone, var_store_type, false,

View File

@ -1462,7 +1462,7 @@ TEST_F(SpvParserTest_VectorExtractDynamic, UnsignedIndex) {
using SpvParserTest_VectorInsertDynamic = SpvParserTest; using SpvParserTest_VectorInsertDynamic = SpvParserTest;
TEST_F(SpvParserTest_VectorExtractDynamic, Sample) { TEST_F(SpvParserTest_VectorInsertDynamic, Sample) {
const auto assembly = Preamble() + R"( const auto assembly = Preamble() + R"(
%100 = OpFunction %void None %voidfn %100 = OpFunction %void None %voidfn
%entry = OpLabel %entry = OpLabel
@ -1512,6 +1512,125 @@ VariableDeclStatement{
<< assembly; << assembly;
} }
TEST_F(SpvParserTest, DISABLED_WorkgroupSize_Overridable) {
// TODO(dneto): Support specializable workgroup size. crbug.com/tint/504
const auto* assembly = R"(
OpCapability Shader
OpMemoryModel Logical Simple
OpEntryPoint GLCompute %100 "main"
OpDecorate %1 BuiltIn WorkgroupSize
OpDecorate %uint_2 SpecId 0
OpDecorate %uint_4 SpecId 1
OpDecorate %uint_8 SpecId 2
%uint = OpTypeInt 32 0
%uint_2 = OpSpecConstant %uint 2
%uint_4 = OpSpecConstant %uint 4
%uint_8 = OpSpecConstant %uint 8
%v3uint = OpTypeVector %uint 3
%1 = OpSpecConstantComposite %v3uint %uint_2 %uint_4 %uint_8
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%10 = OpCopyObject %v3uint %1
%11 = OpCopyObject %uint %uint_2
%12 = OpCopyObject %uint %uint_4
%13 = OpCopyObject %uint %uint_8
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
auto fe = p->function_emitter(100);
EXPECT_TRUE(fe.Emit()) << p->error();
const auto got = p->program().to_str();
EXPECT_THAT(got, HasSubstr(R"(
VariableConst{
Decorations{
OverrideDecoration{0}
}
x_2
none
__u32
{
ScalarConstructor[not set]{2}
}
}
VariableConst{
Decorations{
OverrideDecoration{1}
}
x_3
none
__u32
{
ScalarConstructor[not set]{4}
}
}
VariableConst{
Decorations{
OverrideDecoration{2}
}
x_4
none
__u32
{
ScalarConstructor[not set]{8}
}
}
)")) << got;
EXPECT_THAT(got, HasSubstr(R"(
VariableDeclStatement{
VariableConst{
x_10
none
__vec_3__u32
{
TypeConstructor[not set]{
__vec_3__u32
ScalarConstructor[not set]{2}
ScalarConstructor[not set]{4}
ScalarConstructor[not set]{8}
}
}
}
}
VariableDeclStatement{
VariableConst{
x_11
none
__u32
{
Identifier[not set]{x_2}
}
}
}
VariableDeclStatement{
VariableConst{
x_12
none
__u32
{
Identifier[not set]{x_3}
}
}
}
VariableDeclStatement{
VariableConst{
x_13
none
__u32
{
Identifier[not set]{x_4}
}
}
})"))
<< got << assembly;
}
} // namespace } // namespace
} // namespace spirv } // namespace spirv
} // namespace reader } // namespace reader

View File

@ -580,6 +580,9 @@ bool ParserImpl::ParseInternalModuleExceptFunctions() {
if (!RegisterUserAndStructMemberNames()) { if (!RegisterUserAndStructMemberNames()) {
return false; return false;
} }
if (!RegisterWorkgroupSizeBuiltin()) {
return false;
}
if (!RegisterEntryPoints()) { if (!RegisterEntryPoints()) {
return false; return false;
} }
@ -720,13 +723,102 @@ bool ParserImpl::IsValidIdentifier(const std::string& str) {
return true; return true;
} }
bool ParserImpl::RegisterWorkgroupSizeBuiltin() {
WorkgroupSizeInfo& info = workgroup_size_builtin_;
for (const spvtools::opt::Instruction& inst : module_->annotations()) {
if (inst.opcode() != SpvOpDecorate) {
continue;
}
if (inst.GetSingleWordInOperand(1) != SpvDecorationBuiltIn) {
continue;
}
if (inst.GetSingleWordInOperand(2) != SpvBuiltInWorkgroupSize) {
continue;
}
info.id = inst.GetSingleWordInOperand(0);
}
if (info.id == 0) {
return true;
}
// Gather the values.
const spvtools::opt::Instruction* composite_def =
def_use_mgr_->GetDef(info.id);
if (!composite_def) {
return Fail() << "Invalid WorkgroupSize builtin value";
}
// SPIR-V validation checks that the result is a 3-element vector of 32-bit
// integer scalars (signed or unsigned). Rely on validation to check the
// type. In theory the instruction could be OpConstantNull and still
// pass validation, but that would be non-sensical. Be a little more
// stringent here and check for specific opcodes. WGSL does not support
// const-expr yet, so avoid supporting OpSpecConstantOp here.
// TODO(dneto): See https://github.com/gpuweb/gpuweb/issues/1272 for WGSL
// const_expr proposals.
if ((composite_def->opcode() != SpvOpSpecConstantComposite &&
composite_def->opcode() != SpvOpConstantComposite)) {
return Fail() << "Invalid WorkgroupSize builtin. Expected 3-element "
"OpSpecConstantComposite or OpConstantComposite: "
<< composite_def->PrettyPrint();
}
info.type_id = composite_def->type_id();
// Extract the component type from the vector type.
info.component_type_id =
def_use_mgr_->GetDef(info.type_id)->GetSingleWordInOperand(0);
/// Sets the ID and value of the index'th member of the composite constant.
/// Returns false and emits a diagnostic on error.
auto set_param = [this, composite_def](uint32_t* id_ptr, uint32_t* value_ptr,
int index) -> bool {
const auto id = composite_def->GetSingleWordInOperand(index);
const auto* def = def_use_mgr_->GetDef(id);
if (!def ||
(def->opcode() != SpvOpSpecConstant &&
def->opcode() != SpvOpConstant) ||
(def->NumInOperands() != 1)) {
return Fail() << "invalid component " << index << " of workgroupsize "
<< (def ? def->PrettyPrint()
: std::string("no definition"));
}
*id_ptr = id;
// Use the default value of a spec constant.
*value_ptr = def->GetSingleWordInOperand(0);
return true;
};
return set_param(&info.x_id, &info.x_value, 0) &&
set_param(&info.y_id, &info.y_value, 1) &&
set_param(&info.z_id, &info.z_value, 2);
}
bool ParserImpl::RegisterEntryPoints() { bool ParserImpl::RegisterEntryPoints() {
// Mapping from entry point ID to GridSize computed from LocalSize
// decorations.
std::unordered_map<uint32_t, GridSize> local_size;
for (const spvtools::opt::Instruction& inst : module_->execution_modes()) {
auto mode = static_cast<SpvExecutionMode>(inst.GetSingleWordInOperand(1));
if (mode == SpvExecutionModeLocalSize) {
if (inst.NumInOperands() != 5) {
// This won't even get past SPIR-V binary parsing.
return Fail() << "invalid LocalSize execution mode: "
<< inst.PrettyPrint();
}
uint32_t function_id = inst.GetSingleWordInOperand(0);
local_size[function_id] = GridSize{inst.GetSingleWordInOperand(2),
inst.GetSingleWordInOperand(3),
inst.GetSingleWordInOperand(4)};
}
}
for (const spvtools::opt::Instruction& entry_point : for (const spvtools::opt::Instruction& entry_point :
module_->entry_points()) { module_->entry_points()) {
const auto stage = SpvExecutionModel(entry_point.GetSingleWordInOperand(0)); const auto stage = SpvExecutionModel(entry_point.GetSingleWordInOperand(0));
const uint32_t function_id = entry_point.GetSingleWordInOperand(1); const uint32_t function_id = entry_point.GetSingleWordInOperand(1);
const std::string ep_name = entry_point.GetOperand(2).AsString(); const std::string ep_name = entry_point.GetOperand(2).AsString();
if (!IsValidIdentifier(ep_name)) {
return Fail() << "entry point name is not a valid WGSL identifier: "
<< ep_name;
}
bool owns_inner_implementation = false; bool owns_inner_implementation = false;
std::string inner_implementation_name; std::string inner_implementation_name;
@ -769,11 +861,30 @@ bool ParserImpl::RegisterEntryPoints() {
std::vector<uint32_t> sorted_outputs(outputs.begin(), outputs.end()); std::vector<uint32_t> sorted_outputs(outputs.begin(), outputs.end());
std::sort(sorted_inputs.begin(), sorted_inputs.end()); std::sort(sorted_inputs.begin(), sorted_inputs.end());
function_to_ep_info_[function_id].emplace_back( const auto ast_stage = enum_converter_.ToPipelineStage(stage);
ep_name, enum_converter_.ToPipelineStage(stage), GridSize wgsize;
owns_inner_implementation, inner_implementation_name, if (ast_stage == ast::PipelineStage::kCompute) {
std::move(sorted_inputs), std::move(sorted_outputs)); if (workgroup_size_builtin_.id) {
// Store the default values.
// WGSL allows specializing these, but this code doesn't support that
// yet. https://github.com/gpuweb/gpuweb/issues/1442
wgsize = GridSize{workgroup_size_builtin_.x_value,
workgroup_size_builtin_.y_value,
workgroup_size_builtin_.z_value};
} else {
// Use the LocalSize execution mode. This is the second choice.
auto where = local_size.find(function_id);
if (where != local_size.end()) {
wgsize = where->second;
} }
}
}
function_to_ep_info_[function_id].emplace_back(
ep_name, ast_stage, owns_inner_implementation,
inner_implementation_name, std::move(sorted_inputs),
std::move(sorted_outputs), wgsize);
}
// The enum conversion could have failed, so return the existing status value. // The enum conversion could have failed, so return the existing status value.
return success_; return success_;
} }
@ -1521,10 +1632,59 @@ bool ParserImpl::ConvertDecorationsForVariable(uint32_t id,
return success(); return success();
} }
bool ParserImpl::CanMakeConstantExpression(uint32_t id) {
if ((id == workgroup_size_builtin_.id) ||
(id == workgroup_size_builtin_.x_id) ||
(id == workgroup_size_builtin_.y_id) ||
(id == workgroup_size_builtin_.z_id)) {
return true;
}
const auto* inst = def_use_mgr_->GetDef(id);
if (!inst) {
return false;
}
if (inst->opcode() == SpvOpUndef) {
return true;
}
return nullptr != constant_mgr_->FindDeclaredConstant(id);
}
TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) { TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
if (!success_) { if (!success_) {
return {}; return {};
} }
// Handle the special cases for workgroup sizing.
if (id == workgroup_size_builtin_.id) {
auto x = MakeConstantExpression(workgroup_size_builtin_.x_id);
auto y = MakeConstantExpression(workgroup_size_builtin_.y_id);
auto z = MakeConstantExpression(workgroup_size_builtin_.z_id);
auto* ast_type = ty_.Vector(x.type, 3);
return {ast_type, create<ast::TypeConstructorExpression>(
Source{}, ast_type->Build(builder_),
ast::ExpressionList{x.expr, y.expr, z.expr})};
} else if (id == workgroup_size_builtin_.x_id) {
return MakeConstantExpressionForSpirvConstant(
Source{}, ConvertType(workgroup_size_builtin_.component_type_id),
constant_mgr_->GetConstant(
type_mgr_->GetType(workgroup_size_builtin_.component_type_id),
{workgroup_size_builtin_.x_value}));
} else if (id == workgroup_size_builtin_.y_id) {
return MakeConstantExpressionForSpirvConstant(
Source{}, ConvertType(workgroup_size_builtin_.component_type_id),
constant_mgr_->GetConstant(
type_mgr_->GetType(workgroup_size_builtin_.component_type_id),
{workgroup_size_builtin_.y_value}));
} else if (id == workgroup_size_builtin_.z_id) {
return MakeConstantExpressionForSpirvConstant(
Source{}, ConvertType(workgroup_size_builtin_.component_type_id),
constant_mgr_->GetConstant(
type_mgr_->GetType(workgroup_size_builtin_.component_type_id),
{workgroup_size_builtin_.z_value}));
}
// Handle the general case where a constant is already registered
// with the SPIR-V optimizer's analysis framework.
const auto* inst = def_use_mgr_->GetDef(id); const auto* inst = def_use_mgr_->GetDef(id);
if (inst == nullptr) { if (inst == nullptr) {
Fail() << "ID " << id << " is not a registered instruction"; Fail() << "ID " << id << " is not a registered instruction";
@ -1548,6 +1708,14 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
} }
auto source = GetSourceForInst(inst); auto source = GetSourceForInst(inst);
return MakeConstantExpressionForSpirvConstant(source, original_ast_type,
spirv_const);
}
TypedExpression ParserImpl::MakeConstantExpressionForSpirvConstant(
Source source,
const Type* original_ast_type,
const spvtools::opt::analysis::Constant* spirv_const) {
auto* ast_type = original_ast_type->UnwrapAlias(); auto* ast_type = original_ast_type->UnwrapAlias();
// TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0. // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0.
@ -1603,12 +1771,10 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
Source{}, original_ast_type->Build(builder_), Source{}, original_ast_type->Build(builder_),
std::move(ast_components))}; std::move(ast_components))};
} }
auto* spirv_null_const = spirv_const->AsNullConstant(); if (spirv_const->AsNullConstant()) {
if (spirv_null_const != nullptr) {
return {original_ast_type, MakeNullValue(original_ast_type)}; return {original_ast_type, MakeNullValue(original_ast_type)};
} }
Fail() << "Unhandled constant type " << inst->type_id() << " for value ID " Fail() << "Unhandled constant type ";
<< id;
return {}; return {};
} }
@ -2464,6 +2630,10 @@ const spvtools::opt::Instruction* ParserImpl::GetInstructionForTest(
return def_use_mgr_ ? def_use_mgr_->GetDef(id) : nullptr; return def_use_mgr_ ? def_use_mgr_->GetDef(id) : nullptr;
} }
WorkgroupSizeInfo::WorkgroupSizeInfo() = default;
WorkgroupSizeInfo::~WorkgroupSizeInfo() = default;
} // namespace spirv } // namespace spirv
} // namespace reader } // namespace reader
} // namespace tint } // namespace tint

View File

@ -85,6 +85,29 @@ struct TypedExpression {
ast::Expression* expr = nullptr; ast::Expression* expr = nullptr;
}; };
/// Info about the WorkgroupSize builtin.
struct WorkgroupSizeInfo {
/// Constructor
WorkgroupSizeInfo();
/// Destructor
~WorkgroupSizeInfo();
/// The SPIR-V ID of the WorkgroupSize builtin, if any.
uint32_t id = 0u;
/// The SPIR-V type ID of the WorkgroupSize builtin, if any.
uint32_t type_id = 0u;
/// The SPIR-V type IDs of the x, y, and z components.
uint32_t component_type_id = 0u;
/// The SPIR-V IDs of the X, Y, and Z components of the workgroup size
/// builtin.
uint32_t x_id = 0u;
uint32_t y_id = 0u;
uint32_t z_id = 0u;
/// The effective workgroup size, if this is a compute shader.
uint32_t x_value = 0u;
uint32_t y_value = 0u;
uint32_t z_value = 0u;
};
/// Parser implementation for SPIR-V. /// Parser implementation for SPIR-V.
class ParserImpl : Reader { class ParserImpl : Reader {
public: public:
@ -306,6 +329,14 @@ class ParserImpl : Reader {
/// @returns true if parser is still successful. /// @returns true if parser is still successful.
bool RegisterUserAndStructMemberNames(); bool RegisterUserAndStructMemberNames();
/// Register the WorkgroupSize builtin and its associated constant value.
/// @returns true if parser is still successful.
bool RegisterWorkgroupSizeBuiltin();
const WorkgroupSizeInfo& workgroup_size_builtin() {
return workgroup_size_builtin_;
}
/// Register entry point information. /// Register entry point information.
/// This is a no-op if the parser has already failed. /// This is a no-op if the parser has already failed.
/// @returns true if parser is still successful. /// @returns true if parser is still successful.
@ -366,11 +397,27 @@ class ParserImpl : Reader {
ast::Expression* constructor, ast::Expression* constructor,
ast::DecorationList decorations); ast::DecorationList decorations);
/// Creates an AST expression node for a SPIR-V constant. /// Returns true if a constant expression can be generated.
/// @param id the SPIR-V ID of the value
/// @returns true if a constant expression can be generated
bool CanMakeConstantExpression(uint32_t id);
/// Creates an AST expression node for a SPIR-V ID. This is valid to call
/// when `CanMakeConstantExpression` returns true.
/// @param id the SPIR-V ID of the constant /// @param id the SPIR-V ID of the constant
/// @returns a new expression /// @returns a new expression
TypedExpression MakeConstantExpression(uint32_t id); TypedExpression MakeConstantExpression(uint32_t id);
/// Creates an AST expression node for a SPIR-V constant.
/// @param source the source location
/// @param ast_type the AST type for the value
/// @param spirv_const the internal representation of the SPIR-V constant.
/// @returns a new expression
TypedExpression MakeConstantExpressionForSpirvConstant(
Source source,
const Type* ast_type,
const spvtools::opt::analysis::Constant* spirv_const);
/// Creates an AST expression node for the null value for the given type. /// Creates an AST expression node for the null value for the given type.
/// @param type the AST type /// @param type the AST type
/// @returns a new expression /// @returns a new expression
@ -778,6 +825,11 @@ class ParserImpl : Reader {
/// This is temporary while this module is converted to use the new style /// This is temporary while this module is converted to use the new style
/// of pipeline IO. /// of pipeline IO.
bool hlsl_style_pipeline_io_ = false; bool hlsl_style_pipeline_io_ = false;
/// Info about the WorkgroupSize builtin. If it's not present, then the 'id'
/// field will be 0. Sadly, in SPIR-V right now, there's only one workgroup
/// size object in the module.
WorkgroupSizeInfo workgroup_size_builtin_;
}; };
} // namespace spirv } // namespace spirv

View File

@ -151,6 +151,11 @@ TEST_F(SpvParserTest, EmitFunctions_Function_EntryPoint_GLCompute) {
Function )" + program.Symbols().Get("main").to_str() + Function )" + program.Symbols().Get("main").to_str() +
R"( -> __void R"( -> __void
StageDecoration{compute} StageDecoration{compute}
WorkgroupDecoration{
ScalarConstructor[not set]{1}
ScalarConstructor[not set]{1}
ScalarConstructor[not set]{1}
}
() ()
{)")); {)"));
} }
@ -182,6 +187,189 @@ OpEntryPoint Vertex %main "second_shader"
{)")); {)"));
} }
TEST_F(SpvParserTest,
EmitFunctions_Function_EntryPoint_GLCompute_LocalSize_Only) {
std::string input = Caps() + Names({"main"}) +
R"(OpEntryPoint GLCompute %main "comp_main"
OpExecutionMode %main LocalSize 2 4 8
)" + CommonTypes() + R"(
%main = OpFunction %void None %voidfn
%entry = OpLabel
OpReturn
OpFunctionEnd)";
auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error();
Program program = p->program();
const auto program_ast = program.to_str(false);
EXPECT_THAT(program_ast, HasSubstr(R"(
Function )" + program.Symbols().Get("comp_main").to_str() +
R"( -> __void
StageDecoration{compute}
WorkgroupDecoration{
ScalarConstructor[not set]{2}
ScalarConstructor[not set]{4}
ScalarConstructor[not set]{8}
}
()
{)"))
<< program_ast;
}
TEST_F(SpvParserTest,
EmitFunctions_Function_EntryPoint_WorkgroupSizeBuiltin_Constant_Only) {
std::string input = Caps() + R"(OpEntryPoint GLCompute %main "comp_main"
OpDecorate %wgsize BuiltIn WorkgroupSize
)" + CommonTypes() + R"(
%uvec3 = OpTypeVector %uint 3
%uint_3 = OpConstant %uint 3
%uint_5 = OpConstant %uint 5
%uint_7 = OpConstant %uint 7
%wgsize = OpConstantComposite %uvec3 %uint_3 %uint_5 %uint_7
%main = OpFunction %void None %voidfn
%entry = OpLabel
OpReturn
OpFunctionEnd)";
auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error();
Program program = p->program();
const auto program_ast = program.to_str(false);
EXPECT_THAT(program_ast, HasSubstr(R"(
Function )" + program.Symbols().Get("comp_main").to_str() +
R"( -> __void
StageDecoration{compute}
WorkgroupDecoration{
ScalarConstructor[not set]{3}
ScalarConstructor[not set]{5}
ScalarConstructor[not set]{7}
}
()
{)"))
<< program_ast;
}
TEST_F(
SpvParserTest,
EmitFunctions_Function_EntryPoint_WorkgroupSizeBuiltin_SpecConstant_Only) {
std::string input = Caps() +
R"(OpEntryPoint GLCompute %main "comp_main"
OpDecorate %wgsize BuiltIn WorkgroupSize
OpDecorate %uint_3 SpecId 0
OpDecorate %uint_5 SpecId 1
OpDecorate %uint_7 SpecId 2
)" + CommonTypes() + R"(
%uvec3 = OpTypeVector %uint 3
%uint_3 = OpSpecConstant %uint 3
%uint_5 = OpSpecConstant %uint 5
%uint_7 = OpSpecConstant %uint 7
%wgsize = OpSpecConstantComposite %uvec3 %uint_3 %uint_5 %uint_7
%main = OpFunction %void None %voidfn
%entry = OpLabel
OpReturn
OpFunctionEnd)";
auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error();
Program program = p->program();
const auto program_ast = program.to_str(false);
EXPECT_THAT(program_ast, HasSubstr(R"(
Function )" + program.Symbols().Get("comp_main").to_str() +
R"( -> __void
StageDecoration{compute}
WorkgroupDecoration{
ScalarConstructor[not set]{3}
ScalarConstructor[not set]{5}
ScalarConstructor[not set]{7}
}
()
{)"))
<< program_ast;
}
TEST_F(
SpvParserTest,
EmitFunctions_Function_EntryPoint_WorkgroupSize_MixedConstantSpecConstant) {
std::string input = Caps() +
R"(OpEntryPoint GLCompute %main "comp_main"
OpDecorate %wgsize BuiltIn WorkgroupSize
OpDecorate %uint_3 SpecId 0
OpDecorate %uint_7 SpecId 2
)" + CommonTypes() + R"(
%uvec3 = OpTypeVector %uint 3
%uint_3 = OpSpecConstant %uint 3
%uint_5 = OpConstant %uint 5
%uint_7 = OpSpecConstant %uint 7
%wgsize = OpSpecConstantComposite %uvec3 %uint_3 %uint_5 %uint_7
%main = OpFunction %void None %voidfn
%entry = OpLabel
OpReturn
OpFunctionEnd)";
auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error();
Program program = p->program();
const auto program_ast = program.to_str(false);
EXPECT_THAT(program_ast, HasSubstr(R"(
Function )" + program.Symbols().Get("comp_main").to_str() +
R"( -> __void
StageDecoration{compute}
WorkgroupDecoration{
ScalarConstructor[not set]{3}
ScalarConstructor[not set]{5}
ScalarConstructor[not set]{7}
}
()
{)"))
<< program_ast;
}
TEST_F(
SpvParserTest,
// I had to shorten the name to pass the linter.
EmitFunctions_Function_EntryPoint_LocalSize_And_WGSBuiltin_SpecConstant) {
// WorkgroupSize builtin wins.
std::string input = Caps() +
R"(OpEntryPoint GLCompute %main "comp_main"
OpExecutionMode %main LocalSize 2 4 8
OpDecorate %wgsize BuiltIn WorkgroupSize
OpDecorate %uint_3 SpecId 0
OpDecorate %uint_5 SpecId 1
OpDecorate %uint_7 SpecId 2
)" + CommonTypes() + R"(
%uvec3 = OpTypeVector %uint 3
%uint_3 = OpSpecConstant %uint 3
%uint_5 = OpSpecConstant %uint 5
%uint_7 = OpSpecConstant %uint 7
%wgsize = OpSpecConstantComposite %uvec3 %uint_3 %uint_5 %uint_7
%main = OpFunction %void None %voidfn
%entry = OpLabel
OpReturn
OpFunctionEnd)";
auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error();
Program program = p->program();
const auto program_ast = program.to_str(false);
EXPECT_THAT(program_ast, HasSubstr(R"(
Function )" + program.Symbols().Get("comp_main").to_str() +
R"( -> __void
StageDecoration{compute}
WorkgroupDecoration{
ScalarConstructor[not set]{3}
ScalarConstructor[not set]{5}
ScalarConstructor[not set]{7}
}
()
{)"))
<< program_ast;
}
TEST_F(SpvParserTest, EmitFunctions_VoidFunctionWithoutParams) { TEST_F(SpvParserTest, EmitFunctions_VoidFunctionWithoutParams) {
auto p = parser(test::Assemble(Preamble() + Names({"another_function"}) + auto p = parser(test::Assemble(Preamble() + Names({"another_function"}) +
CommonTypes() + R"( CommonTypes() + R"(

View File

@ -4903,6 +4903,11 @@ TEST_P(SpvModuleScopeVarParserTest_ComputeBuiltin, Load_Direct) {
} }
Function main -> __void Function main -> __void
StageDecoration{compute} StageDecoration{compute}
WorkgroupDecoration{
ScalarConstructor[not set]{1}
ScalarConstructor[not set]{1}
ScalarConstructor[not set]{1}
}
( (
VariableConst{ VariableConst{
Decorations{ Decorations{
@ -5001,6 +5006,11 @@ TEST_P(SpvModuleScopeVarParserTest_ComputeBuiltin, Load_CopyObject) {
} }
Function main -> __void Function main -> __void
StageDecoration{compute} StageDecoration{compute}
WorkgroupDecoration{
ScalarConstructor[not set]{1}
ScalarConstructor[not set]{1}
ScalarConstructor[not set]{1}
}
( (
VariableConst{ VariableConst{
Decorations{ Decorations{
@ -5081,6 +5091,11 @@ TEST_P(SpvModuleScopeVarParserTest_ComputeBuiltin, Load_AccessChain) {
} }
Function main -> __void Function main -> __void
StageDecoration{compute} StageDecoration{compute}
WorkgroupDecoration{
ScalarConstructor[not set]{1}
ScalarConstructor[not set]{1}
ScalarConstructor[not set]{1}
}
( (
VariableConst{ VariableConst{
Decorations{ Decorations{

View File

@ -3,7 +3,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -5,7 +5,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -5,7 +5,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -8,7 +8,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -14,7 +14,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -5,7 +5,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -5,7 +5,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -5,7 +5,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -6,7 +6,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -11,7 +11,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -6,7 +6,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -10,7 +10,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -11,7 +11,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -6,7 +6,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -9,7 +9,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -6,7 +6,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -8,7 +8,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }

View File

@ -10,7 +10,7 @@ fn main_1() {
return; return;
} }
[[stage(compute)]] [[stage(compute), workgroup_size(1, 1, 1)]]
fn main() { fn main() {
main_1(); main_1();
} }